1from __future__ import annotations
2
3import io
4from collections import ChainMap
5from collections.abc import MutableMapping
6from contextlib import contextmanager
7from enum import Enum
8from functools import lru_cache
9from itertools import chain
10from operator import attrgetter
11from textwrap import dedent
12from types import FunctionType
13from typing import TYPE_CHECKING, Any, BinaryIO, Callable
14
15from dissect.cstruct.bitbuffer import BitBuffer
16from dissect.cstruct.types.base import (
17 BaseType,
18 MetaType,
19 _is_buffer_type,
20 _is_readable_type,
21)
22from dissect.cstruct.types.enum import EnumMetaType
23from dissect.cstruct.types.pointer import Pointer
24
25if TYPE_CHECKING:
26 from collections.abc import Iterator, Mapping
27 from types import FunctionType
28
29 from typing_extensions import Self
30
31
class Field:
    """Description of a single member of a structure or union.

    Attributes:
        name: The declared field name, or ``None`` for an anonymous field.
        _name: The declared field name, falling back to the type's ``__name__`` for anonymous fields.
        type: The type of the field.
        bits: The bit width for bit fields, otherwise ``None``.
        offset: The offset of the field, or ``None`` if not (yet) known.
        alignment: The alignment of the field type, with a minimum of 1.
    """

    def __init__(self, name: str | None, type_: type[BaseType], bits: int | None = None, offset: int | None = None):
        # Anonymous fields still need a usable internal key, so borrow the name of their type
        internal_name = name if name else type_.__name__
        # A falsy type alignment (``None`` or 0) is treated as byte alignment
        type_alignment = type_.alignment

        self.name = name
        self._name = internal_name
        self.type = type_
        self.bits = bits
        self.offset = offset
        self.alignment = type_alignment if type_alignment else 1

    def __repr__(self) -> str:
        suffix = ""
        if self.bits:
            # Only bit fields show their width
            suffix = f" : {self.bits}"
        return f"<Field {self.name} {self.type.__name__}{suffix}>"
46
47
class StructureMetaType(MetaType):
    """Base metaclass for cstruct structure type classes.

    Everything a structure class needs is derived from its list of :class:`Field` objects (``__fields__``):
    the field lookups, the generated ``__init__``/``__eq__``/``__bool__``/``__hash__`` methods, the field offsets
    and the total size and alignment. The metaclass also implements reading and writing of structure instances.
    """

    # TODO: resolve field types in _update_fields, remove resolves elsewhere?

    fields: dict[str, Field]
    """Mapping of field names to :class:`Field` objects, including "folded" fields from anonymous structures."""
    lookup: dict[str, Field]
    """Mapping of "raw" field names to :class:`Field` objects. E.g. holds the anonymous struct and not its fields."""
    __fields__: list[Field]
    """List of :class:`Field` objects for this structure. This is the structures' Single Source Of Truth."""

    # Internal
    __align__: bool
    __anonymous__: bool
    __updating__ = False
    __compiled__ = False
    __static_sizes__: dict[str, int]  # Cache of static sizes by field name

    def __new__(metacls, name: str, bases: tuple[type, ...], classdict: dict[str, Any]) -> Self:  # type: ignore
        """Create a new structure class.

        If ``classdict`` contains a ``fields`` entry, it is removed and expanded in place into all derived
        class attributes by :meth:`_update_fields` before the class is created.
        """
        if (fields := classdict.pop("fields", None)) is not None:
            metacls._update_fields(metacls, fields, align=classdict.get("__align__", False), classdict=classdict)

        return super().__new__(metacls, name, bases, classdict)

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Instantiate the structure.

        Two cases are handled here directly through ``type.__call__``; everything else is delegated to the
        parent metaclass:

        - A single ``bytes`` argument for a structure with exactly one ``bytes``-derived field of the same size
          is treated as the value for that field.
        - A call without any arguments creates a default initialized instance with empty dynamic sizes.
        """
        if (
            cls.__fields__
            and len(args) == len(cls.__fields__) == 1
            and isinstance(args[0], bytes)
            and issubclass(cls.__fields__[0].type, bytes)
            and len(args[0]) == cls.__fields__[0].type.size
        ):
            # Shortcut for single char/bytes type
            return type.__call__(cls, *args, **kwargs)
        if not args and not kwargs:
            obj = type.__call__(cls)
            object.__setattr__(obj, "__dynamic_sizes__", {})
            return obj

        return super().__call__(*args, **kwargs)

    def _update_fields(
        cls, fields: list[Field], align: bool = False, classdict: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Derive all field dependent class attributes from a list of fields.

        This is called both as a plain function on the metaclass (from :meth:`__new__`, where ``cls`` is the
        metaclass itself) and as a method on an existing structure class (from :meth:`commit`).

        Args:
            fields: The fields of the structure.
            align: Whether field offsets and the structure size should be aligned.
            classdict: Optional dictionary to update in place. A new dictionary is created if omitted.

        Returns:
            The (updated) class dictionary.

        Raises:
            ValueError: If a field name other than ``_`` occurs more than once.
        """
        # Explicitly test for None: a caller-provided empty dictionary must still be updated in place
        if classdict is None:
            classdict = {}

        lookup = {}
        raw_lookup = {}
        static_sizes = {}
        for field in fields:
            if field._name in lookup and field._name != "_":
                raise ValueError(f"Duplicate field name: {field._name}")

            if not field.type.dynamic:
                static_sizes[field._name] = field.type.size

            if isinstance(field.type, StructureMetaType) and field.name is None:
                # Anonymous structure: "fold" its fields into this structure by means of forwarding properties
                for anon_field in field.type.fields.values():
                    attr = f"{field._name}.{anon_field.name}"
                    classdict[anon_field.name] = property(attrgetter(attr), attrsetter(attr))

                lookup.update(field.type.fields)
            else:
                lookup[field._name] = field

            raw_lookup[field._name] = field

        field_names = lookup.keys()
        classdict["fields"] = lookup
        classdict["lookup"] = raw_lookup
        classdict["__fields__"] = fields
        classdict["__static_sizes__"] = static_sizes
        classdict["__bool__"] = _generate__bool__(field_names)

        if issubclass(cls, UnionMetaType) or isinstance(cls, UnionMetaType):
            classdict["__init__"] = _generate_union__init__(raw_lookup.values())
            # Not a great way to do this but it works for now
            classdict["__eq__"] = Union.__eq__
        else:
            classdict["__init__"] = _generate_structure__init__(raw_lookup.values())
            classdict["__eq__"] = _generate__eq__(field_names)

        classdict["__hash__"] = _generate__hash__(field_names)

        # If we're calling this as a class method or a function on the metaclass
        if issubclass(cls, type):
            size, alignment = cls._calculate_size_and_offsets(cls, fields, align)
        else:
            size, alignment = cls._calculate_size_and_offsets(fields, align)

        if cls.__compiled__:
            # If the previous class was compiled try to compile this too
            from dissect.cstruct import compiler  # noqa: PLC0415

            try:
                classdict["_read"] = compiler.Compiler(cls.cs).compile_read(fields, cls.__name__, align=cls.__align__)
                classdict["__compiled__"] = True
            except Exception:
                # Revert _read to the slower loop based method
                classdict["_read"] = classmethod(Structure._read.__func__)
                classdict["__compiled__"] = False

        # TODO: compile _write
        # TODO: generate cached_property for lazy reading

        classdict["size"] = size
        classdict["alignment"] = alignment
        classdict["dynamic"] = size is None

        return classdict

    def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]:
        """Iterate all fields in this structure to calculate the field offsets and total structure size.

        If a structure has a dynamic field, further field offsets will be set to None and the returned size
        will be None as well.

        Args:
            fields: The fields to calculate the offsets for. Their ``offset`` attribute is updated in place.
            align: Whether to align field offsets and add tail padding.

        Returns:
            A tuple of the structure size (``None`` if dynamic) and the structure alignment.

        Raises:
            ValueError: If a bit field does not fit in the remaining bits of its underlying type.
        """
        # The current offset, set to None if we become dynamic
        offset = 0
        # The current alignment for this structure
        alignment = 0

        # The current bit field type
        bits_type = None
        # The offset of the current bit field, set to None if we become dynamic
        bits_field_offset = 0
        # How many bits we have left in the current bit field
        bits_remaining = 0

        for field in fields:
            if field.offset is not None:
                # If a field already has an offset, it's leading
                offset = field.offset

            if align and offset is not None:
                # Round to next alignment
                offset += -offset & (field.alignment - 1)

            # The alignment of this struct is equal to its largest members' alignment
            alignment = max(alignment, field.alignment)

            if field.bits:
                field_type = field.type

                if isinstance(field_type, EnumMetaType):
                    field_type = field_type.type

                # Bit fields have special logic
                if (
                    # Exhausted a bit field
                    bits_remaining == 0
                    # Moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8;
                    or field_type != bits_type
                    # Still processing a bit field, but it's at a different offset due to alignment or a manual offset
                    # Offsets are None once the structure is dynamic, so only compare them when they are known
                    # A known offset while the current bit field offset is unknown can only come from a manual offset
                    or (
                        bits_type is not None
                        and offset is not None
                        and (bits_field_offset is None or offset > bits_field_offset + bits_type.size)
                    )
                ):
                    # ... if any of this is true, we have to move to the next field
                    bits_type = field_type
                    bits_count = bits_type.size * 8
                    bits_remaining = bits_count
                    bits_field_offset = offset

                    if offset is not None:
                        # We're not dynamic, update the structure size and current offset
                        offset += bits_type.size

                field.offset = bits_field_offset

                bits_remaining -= field.bits

                if bits_remaining < 0:
                    raise ValueError("Straddled bit fields are unsupported")
            else:
                # Reset bits stuff
                bits_type = None
                bits_field_offset = bits_remaining = 0

                field.offset = offset

                if offset is not None:
                    # We're not dynamic, update the structure size and current offset
                    try:
                        field_len = len(field.type)
                    except TypeError:
                        # This field is dynamic
                        offset = None
                        continue

                    offset += field_len

        if align and offset is not None:
            # Add "tail padding" if we need to align
            # This bit magic rounds up to the next alignment boundary
            # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5
            offset += -offset & (alignment - 1)

        # The structure size is whatever the currently calculated offset is
        return offset, alignment

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Read a structure instance from a stream, field by field.

        Args:
            stream: The stream to read from. It needs to support ``tell`` and ``seek``.
            context: Unused by this implementation; the values read so far are passed on to the field types instead.

        Returns:
            The structure instance, with the sizes of dynamic fields stored in ``__dynamic_sizes__``.
        """
        bit_buffer = BitBuffer(stream, cls.cs.endian)
        struct_start = stream.tell()

        result = {}
        sizes = {}
        for field in cls.__fields__:
            offset = stream.tell()

            if field.offset is not None and offset != struct_start + field.offset:
                # Field is at a specific offset, either aligned or added that way
                offset = struct_start + field.offset
                stream.seek(offset)

            if cls.__align__ and field.offset is None:
                # Previous field was dynamically sized and we need to align
                offset += -offset & (field.alignment - 1)
                stream.seek(offset)

            if field.bits:
                if isinstance(field.type, EnumMetaType):
                    value = field.type(bit_buffer.read(field.type.type, field.bits))
                else:
                    value = bit_buffer.read(field.type, field.bits)

                result[field._name] = value
                continue

            bit_buffer.reset()

            value = field.type._read(stream, result)

            result[field._name] = value
            if field.type.dynamic:
                sizes[field._name] = stream.tell() - offset

        if cls.__align__:
            # Align the stream
            stream.seek(-stream.tell() & (cls.alignment - 1), io.SEEK_CUR)

        # Using type.__call__ directly calls the __init__ method of the class
        # This is faster than calling cls() and bypasses the metaclass __call__ method
        obj = type.__call__(cls, **result)
        obj.__dynamic_sizes__ = sizes
        return obj

    def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]:  # type: ignore
        """Read structures from a stream until a structure is read that evaluates to False.

        The terminating structure is consumed from the stream, but not included in the result.
        """
        result = []

        while obj := cls._read(stream, context):
            result.append(obj)

        return result

    def _write(cls, stream: BinaryIO, data: Structure) -> int:
        """Write a structure instance to a stream, field by field.

        Fields for which ``data`` has no value (or ``None``) are written with the default value of their type.

        Args:
            stream: The stream to write to. It needs to support ``tell``.
            data: The structure instance to take the field values from.

        Returns:
            The number of bytes written by the non bit field fields. Padding bytes and bytes flushed from the
            bit buffer are not included in this number.
        """
        bit_buffer = BitBuffer(stream, cls.cs.endian)
        struct_start = stream.tell()
        num = 0

        for field in cls.__fields__:
            field_type = cls.cs.resolve(field.type)

            bit_field_type = (
                (field_type.type if isinstance(field_type, EnumMetaType) else field_type) if field.bits else None
            )
            # Current field is not a bit field, but previous was
            # Or, moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8;
            if (not field.bits and bit_buffer._type is not None) or (
                bit_buffer._type and bit_buffer._type != bit_field_type
            ):
                # Flush the current bit buffer so we can process alignment properly
                bit_buffer.flush()

            offset = stream.tell()

            if field.offset is not None and offset < struct_start + field.offset:
                # Field is at a specific offset, either aligned or added that way
                stream.write(b"\x00" * (struct_start + field.offset - offset))
                offset = struct_start + field.offset

            if cls.__align__ and field.offset is None:
                is_bitbuffer_boundary = bit_buffer._type and (
                    bit_buffer._remaining == 0 or bit_buffer._type != field_type
                )
                if not bit_buffer._type or is_bitbuffer_boundary:
                    # Previous field was dynamically sized and we need to align
                    align_pad = -offset & (field.alignment - 1)
                    stream.write(b"\x00" * align_pad)
                    offset += align_pad

            value = getattr(data, field._name, None)
            if value is None:
                value = field_type.__default__()

            if field.bits:
                if isinstance(field_type, EnumMetaType):
                    bit_buffer.write(field_type.type, value.value, field.bits)
                else:
                    bit_buffer.write(field_type, value, field.bits)
            else:
                field_type._write(stream, value)
                num += stream.tell() - offset

        if bit_buffer._type is not None:
            bit_buffer.flush()

        if cls.__align__:
            # Align the stream
            stream.write(b"\x00" * (-stream.tell() & (cls.alignment - 1)))

        return num

    def add_field(cls, name: str, type_: type[BaseType], bits: int | None = None, offset: int | None = None) -> None:
        """Append a new field to this structure.

        The structure is committed immediately, unless it is currently being updated with :meth:`start_update`.

        Args:
            name: The name of the field.
            type_: The type of the field.
            bits: The bit width if the field is a bit field.
            offset: The explicit offset of the field, if any.
        """
        field = Field(name, type_, bits=bits, offset=offset)
        cls.__fields__.append(field)

        if not cls.__updating__:
            cls.commit()

    @contextmanager
    def start_update(cls) -> Iterator[None]:
        """Context manager to batch multiple :meth:`add_field` calls into a single :meth:`commit`.

        The structure is committed when the context exits, also when the body raised an exception.
        """
        cls.__updating__ = True
        try:
            yield
        finally:
            try:
                cls.commit()
            finally:
                # Always leave update mode, otherwise a failing commit would permanently disable
                # the automatic commit in add_field
                cls.__updating__ = False

    def commit(cls) -> None:
        """Recalculate all field dependent attributes from ``__fields__`` and apply them to this class."""
        classdict = cls._update_fields(cls.__fields__, cls.__align__)

        for key, value in classdict.items():
            setattr(cls, key, value)
382
383
class Structure(BaseType, metaclass=StructureMetaType):
    """Base class for cstruct structure type classes.

    Assigning to an attribute that is not a field of the structure results in undefined behavior: for
    performance reasons there is no check whether a field exists when an attribute is written.
    """

    __dynamic_sizes__: dict[str, int]

    def __len__(self) -> int:
        # The length of an instance is the length of its serialized form
        serialized = self.dumps()
        return len(serialized)

    def __bytes__(self) -> bytes:
        return self.dumps()

    def __getitem__(self, item: str) -> Any:
        # Item access is an alias for attribute access
        return getattr(self, item)

    def __repr__(self) -> str:
        def _render(name: str, field: Field) -> str:
            value = self[name]
            # Integer fields are shown in hex, except for pointers and enums
            is_plain_int = issubclass(field.type, int) and not issubclass(field.type, (Pointer, Enum))
            rendered = hex(value) if is_plain_int else repr(value)
            return f"{name}={rendered}"

        body = " ".join(_render(name, field) for name, field in self.__class__.fields.items())
        return f"<{self.__class__.__name__} {body}>"

    @property
    def __values__(self) -> MutableMapping[str, Any]:
        """A mutable mapping view on the field values of this structure."""
        return StructureValuesProxy(self)

    @property
    def __sizes__(self) -> Mapping[str, int | None]:
        """A mapping of field names to sizes, combining the static and dynamic field sizes."""
        static_sizes = self.__class__.__static_sizes__
        return ChainMap(static_sizes, self.__dynamic_sizes__)
421
422
class StructureValuesProxy(MutableMapping):
    """Mutable mapping view on the field values of a Structure.

    The keys are the names in the ``fields`` mapping of the structure class; values are read from and
    written to the attributes of the wrapped structure instance. Entries can not be deleted.
    """

    def __init__(self, struct: Structure):
        self._struct: Structure = struct

    def __getitem__(self, key: str) -> Any:
        if key not in self:
            raise KeyError(key)
        return getattr(self._struct, key)

    def __setitem__(self, key: str, value: Any) -> None:
        # Only existing fields can be assigned through the proxy
        if key not in self:
            raise KeyError(key)
        setattr(self._struct, key, value)

    def __contains__(self, key: str) -> bool:
        known_fields = self._struct.__class__.fields
        return key in known_fields

    def __iter__(self) -> Iterator[str]:
        known_fields = self._struct.__class__.fields
        return iter(known_fields)

    def __len__(self) -> int:
        known_fields = self._struct.__class__.fields
        return len(known_fields)

    def __repr__(self) -> str:
        snapshot = {name: self[name] for name in self}
        return repr(snapshot)

    def __delitem__(self, _: str):
        # Is abstract in base, but deleting is not supported.
        raise NotImplementedError("Cannot delete fields from a Structure")
454
455
class UnionMetaType(StructureMetaType):
    """Base metaclass for cstruct union type classes.

    All fields of a union share the same backing buffer. Compared to a structure, the size of a union is the
    size of its largest field and every field is read from its own offset within that shared buffer.
    """

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Instantiate the union.

        When the union is initialized with values (anything but a single readable or buffer argument), the
        union is rebuilt from the first user-provided field. A default initialization only proxifies the
        nested structures.

        Raises:
            NotImplementedError: If a dynamic union is initialized with values.
        """
        obj: Union = super().__call__(*args, **kwargs)

        # Calling with non-stream args or kwargs means we are initializing with values
        if (args and not (len(args) == 1 and (_is_readable_type(args[0]) or _is_buffer_type(args[0])))) or kwargs:
            # We don't support user initialization of dynamic unions yet
            if cls.dynamic:
                raise NotImplementedError("Initializing a dynamic union is not yet supported")

            # User (partial) initialization, rebuild the union
            # First user-provided field is the one used to rebuild the union
            arg_fields = (field._name for _, field in zip(args, cls.__fields__))
            kwarg_fields = (name for name in kwargs if name in cls.lookup)
            if (first_field := next(chain(arg_fields, kwarg_fields), None)) is not None:
                obj._rebuild(first_field)
        elif not args and not kwargs:
            # Initialized with default values
            # Note that we proxify here in case we have a default initialization (cls())
            # We don't proxify in case we read from a stream, as we do that later on in _read at a more appropriate time
            # Same with (partial) user initialization, we do that after rebuilding the union
            obj._proxify()

        return obj

    def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]:
        """Calculate the size and alignment of the union.

        The size is the size of the largest field, or ``None`` as soon as a field has no static length.
        Field offsets are not modified.

        Args:
            fields: The fields of the union.
            align: Whether to add tail padding to align the size.

        Returns:
            A tuple of the union size (``None`` if dynamic) and the union alignment.
        """
        size = 0
        alignment = 0

        for field in fields:
            if size is not None:
                try:
                    size = max(len(field.type), size)
                except TypeError:
                    # This field is dynamic, and so is the union
                    size = None

            alignment = max(field.alignment, alignment)

        if align and size is not None:
            # Add "tail padding" if we need to align
            # This bit magic rounds up to the next alignment boundary
            # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5
            size += -size & (alignment - 1)

        return size, alignment

    def _read_fields(
        cls, stream: BinaryIO, context: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, int]]:
        """Read the values of all fields of the union.

        A statically sized union is first read in full into a temporary buffer. A dynamic union is read
        directly from the stream, relative to the stream position at the time of the call.

        Args:
            stream: The stream to read from.
            context: Unused by this implementation; the values read so far are passed on to the field types instead.

        Returns:
            A tuple of the field values by name and the sizes of the dynamic fields by name.
        """
        result = {}
        sizes = {}

        if cls.size is None:
            offset = stream.tell()
            buf = stream
        else:
            offset = 0
            buf = io.BytesIO(stream.read(cls.size))

        for field in cls.__fields__:
            field_type = cls.cs.resolve(field.type)

            start = 0
            if field.offset is not None:
                start = field.offset

            buf.seek(offset + start)
            value = field_type._read(buf, result)

            result[field._name] = value
            if field.type.dynamic:
                # The field started at offset + start, which is an absolute stream position for dynamic unions
                sizes[field._name] = buf.tell() - (offset + start)

        return result, sizes

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Read a union instance from a stream.

        The raw bytes of the union are kept in the ``_buf`` attribute of the instance. For a statically sized
        union the field values are parsed from that buffer after the instance has been created.
        """
        if cls.size is None:
            start = stream.tell()
            result, sizes = cls._read_fields(stream, context)
            size = stream.tell() - start
            stream.seek(start)
            buf = stream.read(size)
        else:
            result = {}
            sizes = {}
            buf = stream.read(cls.size)

        # Create the object and set the values
        # Using type.__call__ directly calls the __init__ method of the class
        # This is faster than calling cls() and bypasses the metaclass __call__ method
        # It also makes it easier to differentiate between user-initialization of the class
        # and initialization from a stream read
        obj: Union = type.__call__(cls, **result)
        object.__setattr__(obj, "__dynamic_sizes__", sizes)
        object.__setattr__(obj, "_buf", buf)

        if cls.size is not None:
            obj._update()

        # Proxify any nested structures
        obj._proxify()

        return obj

    def _write(cls, stream: BinaryIO, data: Union) -> int:
        """Write a union instance to a stream.

        The largest regular field is written; an anonymous structure is only written if nothing else was.
        The output is padded with null bytes up to the size of the union.

        Returns:
            The number of bytes written, including padding.

        Raises:
            NotImplementedError: If the union is dynamic.
        """
        if cls.dynamic:
            raise NotImplementedError("Writing dynamic unions is not yet supported")

        offset = stream.tell()
        expected_offset = offset + len(cls)

        # Sort by largest field
        fields = sorted(cls.__fields__, key=lambda e: e.type.size or 0, reverse=True)
        anonymous_struct = False

        # Try to write by largest field
        for field in fields:
            if isinstance(field.type, StructureMetaType) and field.name is None:
                # Prefer to write regular fields initially
                anonymous_struct = field.type
                continue

            # Write the value
            field.type._write(stream, getattr(data, field._name))
            break

        # If we haven't written anything yet and we initially skipped an anonymous struct, write it now
        if stream.tell() == offset and anonymous_struct:
            anonymous_struct._write(stream, data)

        # If we haven't filled the union size yet, pad it
        if remaining := expected_offset - stream.tell():
            stream.write(b"\x00" * remaining)

        return stream.tell() - offset
593
594
class Union(Structure, metaclass=UnionMetaType):
    """Base class for cstruct union type classes.

    The raw bytes shared by all fields are kept in ``_buf``. Writing to a field serializes that field into
    the buffer, after which all field values are read again from the buffer.
    """

    _buf: bytes

    def __eq__(self, other: object) -> bool:
        # Unions are equal when they are of the exact same class and serialize to the same bytes
        if self.__class__ is not other.__class__:
            return False
        return bytes(self) == bytes(other)

    def __setattr__(self, attr: str, value: Any) -> None:
        if self.__class__.dynamic:
            raise NotImplementedError("Modifying a dynamic union is not yet supported")

        super().__setattr__(attr, value)
        # Propagate the new value to all other fields
        self._rebuild(attr)

    def _rebuild(self, attr: str) -> None:
        """Serialize the value of field ``attr`` into the buffer and refresh all fields from it."""
        current = getattr(self, "_buf", None)
        if current is None:
            # No buffer yet, start out with a zeroed one
            current = b"\x00" * self.__class__.size

        scratch = io.BytesIO(current)
        field = self.__class__.lookup[attr]
        if field.offset:
            scratch.seek(field.offset)

        value = getattr(self, attr)
        if value is None:
            value = field.type.__default__()

        field.type._write(scratch, value)

        object.__setattr__(self, "_buf", scratch.getvalue())
        self._update()

        # (Re-)proxify all values
        self._proxify()

    def _update(self) -> None:
        """Re-read all field values and dynamic sizes from the buffer."""
        values, dynamic_sizes = self.__class__._read_fields(io.BytesIO(self._buf))
        self.__dict__.update(values)
        object.__setattr__(self, "__dynamic_sizes__", dynamic_sizes)

    def _proxify(self) -> None:
        """Recursively replace nested structure values with :class:`UnionProxy` objects bound to this union."""

        def _wrap_nested(owner: Structure) -> None:
            for field in owner.__class__.__fields__:
                if not issubclass(field.type, Structure):
                    continue

                inner = getattr(owner, field._name)
                object.__setattr__(owner, field._name, UnionProxy(self, field._name, inner))
                _wrap_nested(inner)

        _wrap_nested(self)
645
646
class UnionProxy:
    """Transparent wrapper around a structure value that lives inside a union.

    Reads are forwarded to the wrapped structure. Writes are forwarded as well, after which the owning union
    is asked to rebuild itself from the field this proxy was created for.
    """

    __union__: Union
    __attr__: str
    __target__: Structure

    def __init__(self, union: Union, attr: str, target: Structure):
        # Bypass our own __setattr__, which would forward the assignment to the target
        for slot, slot_value in (("__union__", union), ("__attr__", attr), ("__target__", target)):
            object.__setattr__(self, slot, slot_value)

    def __len__(self) -> int:
        target = self.__target__
        return len(target.dumps())

    def __bytes__(self) -> bytes:
        target = self.__target__
        return target.dumps()

    def __getitem__(self, item: str) -> Any:
        target = self.__target__
        return getattr(target, item)

    def __repr__(self) -> str:
        target = self.__target__
        return repr(target)

    def __getattr__(self, attr: str) -> Any:
        # Only reached for attributes that are not found on the proxy itself
        target = self.__target__
        return getattr(target, attr)

    def __setattr__(self, attr: str, value: Any) -> None:
        target = self.__target__
        setattr(target, attr, value)
        # Let the owning union pick up the change
        owner = self.__union__
        owner._rebuild(self.__attr__)
675
676
def attrsetter(path: str) -> Callable[[Any, Any], None]:
    """Return a callable that sets the attribute at a (dotted) path on an object.

    This is the setter counterpart of :func:`operator.attrgetter`: ``attrsetter("a.b.c")(obj, value)`` is
    equivalent to ``obj.a.b.c = value``. A path without any dots sets the attribute on the object itself.

    Args:
        path: The attribute path, with the components separated by dots.

    Returns:
        A function taking an object and a value, which assigns the value to the attribute at ``path``.
    """
    parent_path, _, attr = path.rpartition(".")
    # Without a dot there are no intermediate objects to traverse; "".split(".") would yield [""]
    parents = parent_path.split(".") if parent_path else []

    def _func(obj: Any, value: Any) -> None:
        for name in parents:
            obj = getattr(obj, name)
        setattr(obj, attr, value)

    return _func
687
688
689def _codegen(func: FunctionType) -> FunctionType:
690 """Decorator that generates a template function with a specified number of fields.
691
692 This code is a little complex but allows use to cache generated functions for a specific number of fields.
693 For example, if we generate a structure with 10 fields, we can cache the generated code for that structure.
694 We can then reuse that code and patch it with the correct field names when we create a new structure with 10 fields.
695
696 The functions that are decorated with this decorator should take a list of field names and return a string of code.
697 The decorated function is needs to be called with the number of fields, instead of the field names.
698 The confusing part is that that the original function takes field names, but you then call it with
699 the number of fields instead.
700
701 Inspired by https://github.com/dabeaz/dataklasses.
702
703 Args:
704 func: The decorated function that takes a list of field names and returns a string of code.
705
706 Returns:
707 A cached function that generates the desired function code, to be called with the number of fields.
708 """
709
710 def make_func_code(num_fields: int) -> FunctionType:
711 exec(func([f"_{n}" for n in range(num_fields)]), {}, d := {})
712 return d.popitem()[1]
713
714 make_func_code.__wrapped__ = func
715 return lru_cache(make_func_code)
716
717
@_codegen
def _make_structure__init__(fields: list[str]) -> str:
    """Generates an ``__init__`` method for a structure with the specified fields.

    Every field becomes an optional argument defaulting to ``None``. A field that is ``None`` is initialized
    with the value of the name ``_<index>_default``, which has to be resolvable by the generated function.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__init__`` method.
    """
    field_args = ", ".join(f"{field} = None" for field in fields)
    field_init = "\n".join(
        f" self.{name} = {name} if {name} is not None else _{i}_default" for i, name in enumerate(fields)
    )

    # Only add the separator when there are arguments besides ``self``
    # Note that ``', ' + field_args or ''`` would never take the fallback, as ``+`` binds tighter than ``or``
    signature = f"self, {field_args}" if field_args else "self"
    code = f"def __init__({signature}):\n"
    return code + (field_init or " pass")
732
733
@_codegen
def _make_union__init__(fields: list[str]) -> str:
    """Generates an ``__init__`` method for a class with the specified fields using setattr.

    The values are assigned with ``object.__setattr__``, which bypasses any ``__setattr__`` defined on the
    class. A field that is ``None`` is initialized with the value of the name ``_<index>_default``, which has
    to be resolvable by the generated function.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__init__`` method.
    """
    field_args = ", ".join(f"{field} = None" for field in fields)
    field_init = "\n".join(
        f" object.__setattr__(self, '{name}', {name} if {name} is not None else _{i}_default)"
        for i, name in enumerate(fields)
    )

    # Only add the separator when there are arguments besides ``self``
    # Note that ``', ' + field_args or ''`` would never take the fallback, as ``+`` binds tighter than ``or``
    signature = f"self, {field_args}" if field_args else "self"
    code = f"def __init__({signature}):\n"
    return code + (field_init or " pass")
749
750
@_codegen
def _make__eq__(fields: list[str]) -> str:
    """Generates an ``__eq__`` method for a class with the specified fields.

    Two objects are equal when they are of the exact same class and all fields compare equal.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__eq__`` method.
    """
    # Every element carries its own trailing comma, so a single field still produces a tuple
    # and no fields produce an empty tuple
    lhs = "".join(f"self.{name}," for name in fields)
    rhs = "".join(f"other.{name}," for name in fields)

    # In the future this could be a looser check, e.g. an __eq__ on the classes, which compares the fields
    return dedent(
        f"""
    def __eq__(self, other):
        if self.__class__ is other.__class__:
            return ({lhs}) == ({rhs})
        return False
    """
    )
775
776
@_codegen
def _make__bool__(fields: list[str]) -> str:
    """Generates a ``__bool__`` method for a class with the specified fields.

    The object is truthy when at least one of its fields is truthy.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__bool__`` method.
    """
    attributes = ", ".join(f"self.{name}" for name in fields)

    return dedent(
        f"""
    def __bool__(self):
        return any([{attributes}])
    """
    )
792
793
@_codegen
def _make__hash__(fields: list[str]) -> str:
    """Generates a ``__hash__`` method for a class with the specified fields.

    The hash is calculated over the parenthesized, comma separated field values.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__hash__`` method.
    """
    attributes = ", ".join(f"self.{name}" for name in fields)

    return dedent(
        f"""
    def __hash__(self):
        return hash(({attributes}))
    """
    )
809
810
811def _patch_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType:
812 """Patches a function's attributes.
813
814 Args:
815 func: The function to patch.
816 fields: List of field names to add.
817 start: The starting index for patching. Defaults to 0.
818 """
819 return type(func)(
820 func.__code__.replace(co_names=(*func.__code__.co_names[:start], *fields)),
821 func.__globals__,
822 )
823
824
def _generate_structure__init__(fields: list[Field]) -> FunctionType:
    """Generates an ``__init__`` method for a structure with the specified fields.

    The cached template for this number of fields is copied, with its placeholder names replaced by the
    actual field names and the default value of every field type made available as a global.

    Args:
        fields: The :class:`Field` objects to generate the method for.
    """
    mapping = _generate_co_mapping(fields)
    template: FunctionType = _make_structure__init__(len(fields))

    template_code = template.__code__
    patched_code = template_code.replace(
        co_names=_remap_co_values(template_code.co_names, mapping),
        co_varnames=_remap_co_values(template_code.co_varnames, mapping),
    )

    # The generated code looks up the defaults as globals
    default_values = {f"__{field._name}_default__": field.type.__default__() for field in fields}
    return type(template)(
        patched_code,
        template.__globals__ | default_values,
        argdefs=template.__defaults__,
    )
842
843
def _generate_union__init__(fields: list[Field]) -> FunctionType:
    """Generates an ``__init__`` method for a union with the specified fields.

    The cached template for this number of fields is copied, with its placeholder names replaced by the
    actual field names and the default value of every field type made available as a global. The constants
    are remapped too, as the template passes the field names to ``object.__setattr__`` as strings.

    Args:
        fields: The :class:`Field` objects to generate the method for.
    """
    mapping = _generate_co_mapping(fields)
    template: FunctionType = _make_union__init__(len(fields))

    template_code = template.__code__
    patched_code = template_code.replace(
        co_consts=_remap_co_values(template_code.co_consts, mapping),
        co_names=_remap_co_values(template_code.co_names, mapping),
        co_varnames=_remap_co_values(template_code.co_varnames, mapping),
    )

    # The generated code looks up the defaults as globals
    default_values = {f"__{field._name}_default__": field.type.__default__() for field in fields}
    return type(template)(
        patched_code,
        template.__globals__ | default_values,
        argdefs=template.__defaults__,
    )
862
863
864def _generate_co_mapping(fields: list[Field]) -> dict[str, str]:
865 """Generates a mapping of generated code object names to field names.
866
867 The generated code uses names like ``_0``, ``_1``, etc. for fields, and ``_0_default``, ``_1_default``, etc.
868 for default initializer values. Return a mapping of these names to the actual field names.
869
870 Args:
871 fields: List of field names.
872 """
873 return {
874 key: value
875 for i, field in enumerate(fields)
876 for key, value in [(f"_{i}", field._name), (f"_{i}_default", f"__{field._name}_default__")]
877 }
878
879
880def _remap_co_values(value: tuple[Any, ...], mapping: dict[str, str]) -> tuple[Any, ...]:
881 """Remaps code object values using a mapping.
882
883 This is used to replace generated code object names with actual field names.
884
885 Args:
886 value: The original code object values.
887 mapping: A mapping of generated code object names to field names.
888 """
889 # Only attempt to remap if the value is a string, otherwise return it as is
890 # This is to avoid issues with trying to remap non-hashable types, and we only need to replace strings anyway
891 return tuple(mapping.get(v, v) if isinstance(v, str) else v for v in value)
892
893
def _generate__eq__(fields: list[str]) -> FunctionType:
    """Generates an ``__eq__`` method for a class with the specified fields.

    Args:
        fields: List of field names.
    """
    template = _make__eq__(len(fields))
    # The first name referenced by the template is not a field placeholder and has to be kept
    return _patch_attributes(template, fields, start=1)
901
902
def _generate__bool__(fields: list[str]) -> FunctionType:
    """Generates a ``__bool__`` method for a class with the specified fields.

    Args:
        fields: List of field names.
    """
    template = _make__bool__(len(fields))
    # The first name referenced by the template is not a field placeholder and has to be kept
    return _patch_attributes(template, fields, start=1)
910
911
def _generate__hash__(fields: list[str]) -> FunctionType:
    """Generates a ``__hash__`` method for a class with the specified fields.

    Args:
        fields: List of field names.
    """
    template = _make__hash__(len(fields))
    # The first name referenced by the template is not a field placeholder and has to be kept
    return _patch_attributes(template, fields, start=1)