1from __future__ import annotations
2
3import io
4from contextlib import contextmanager
5from enum import Enum
6from functools import lru_cache
7from itertools import chain
8from operator import attrgetter
9from textwrap import dedent
10from types import FunctionType
11from typing import TYPE_CHECKING, Any, BinaryIO, Callable
12
13from dissect.cstruct.bitbuffer import BitBuffer
14from dissect.cstruct.types.base import (
15 BaseType,
16 MetaType,
17 _is_buffer_type,
18 _is_readable_type,
19)
20from dissect.cstruct.types.enum import EnumMetaType
21from dissect.cstruct.types.pointer import Pointer
22
23if TYPE_CHECKING:
24 from collections.abc import Iterator
25 from types import FunctionType
26
27 from typing_extensions import Self
28
29
class Field:
    """A single member of a structure or union.

    Attributes:
        name: The declared field name, or ``None`` for an anonymous field.
        _name: The declared name, falling back to the type's ``__name__`` when the field is anonymous.
        type: The type of the field.
        bits: The bit width for bit fields, otherwise ``None``.
        offset: The offset of the field within its structure, or ``None`` when not (yet) known.
        alignment: The alignment of the field type, with a minimum of 1.
    """

    def __init__(self, name: str | None, type_: type[BaseType], bits: int | None = None, offset: int | None = None):
        type_alignment = type_.alignment

        self.name = name
        # Anonymous fields are addressed by the name of their type
        self._name = name if name else type_.__name__
        self.type = type_
        self.bits = bits
        self.offset = offset
        # Types without a (truthy) alignment are treated as byte aligned
        self.alignment = type_alignment if type_alignment else 1

    def __repr__(self) -> str:
        # Only bit fields get the " : <bits>" suffix
        if self.bits:
            bits_str = f" : {self.bits}"
        else:
            bits_str = ""
        return f"<Field {self.name} {self.type.__name__}{bits_str}>"
44
45
class StructureMetaType(MetaType):
    """Base metaclass for cstruct structure type classes.

    Keeps the field bookkeeping of a structure class (``fields``, ``lookup`` and ``__fields__``), generates
    the per-class ``__init__``/``__eq__``/``__bool__``/``__hash__`` methods from the field list, calculates
    size and field offsets, and implements reading and writing of structure instances.
    """

    # TODO: resolve field types in _update_fields, remove resolves elsewhere?

    fields: dict[str, Field]
    """Mapping of field names to :class:`Field` objects, including "folded" fields from anonymous structures."""
    lookup: dict[str, Field]
    """Mapping of "raw" field names to :class:`Field` objects. E.g. holds the anonymous struct and not its fields."""
    __fields__: list[Field]
    """List of :class:`Field` objects for this structure. This is the structures' Single Source Of Truth."""

    # Internal
    __align__: bool
    __anonymous__: bool
    __updating__ = False
    __compiled__ = False

    def __new__(metacls, name: str, bases: tuple[type, ...], classdict: dict[str, Any]) -> Self:  # type: ignore
        """Create a new structure class.

        If ``classdict`` contains a ``fields`` entry (a list of :class:`Field` objects), it is removed and
        expanded into the generated class attributes by :meth:`_update_fields`.
        """
        if (fields := classdict.pop("fields", None)) is not None:
            # Use the returned mapping rather than relying on in-place mutation of our argument
            classdict = metacls._update_fields(
                metacls, fields, align=classdict.get("__align__", False), classdict=classdict
            )

        return super().__new__(metacls, name, bases, classdict)

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Instantiate the structure.

        Two cases are handled here, everything else is passed on to the parent metaclass:

        - A single ``bytes`` argument for a structure whose only field is a ``bytes`` based type of exactly
          that size is treated as the value of that field.
        - A call without arguments creates an instance with empty ``_values`` and ``_sizes``.
        """
        if (
            cls.__fields__
            and len(args) == len(cls.__fields__) == 1
            and isinstance(args[0], bytes)
            and issubclass(cls.__fields__[0].type, bytes)
            and len(args[0]) == cls.__fields__[0].type.size
        ):
            # Shortcut for single char/bytes type
            return type.__call__(cls, *args, **kwargs)
        if not args and not kwargs:
            obj = type.__call__(cls)
            object.__setattr__(obj, "_values", {})
            object.__setattr__(obj, "_sizes", {})
            return obj

        return super().__call__(*args, **kwargs)

    def _update_fields(
        cls, fields: list[Field], align: bool = False, classdict: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Build all class attributes that are derived from the list of fields.

        This can be called on a structure class (e.g. from :meth:`commit`), or as a plain function on the
        metaclass itself with the metaclass passed as ``cls`` (from :meth:`__new__`).

        Args:
            fields: The fields of the structure.
            align: Whether field offsets and the structure size should be aligned.
            classdict: Optional mapping to store the generated attributes in. It is updated in place.

        Returns:
            The mapping holding the generated attributes (``classdict`` itself when one was provided).

        Raises:
            ValueError: If a field name occurs more than once (the name ``_`` is exempt).
        """
        # Only substitute a fresh dict when nothing was passed; an empty dict from the caller must be
        # updated in place, otherwise the generated attributes would be lost to that caller
        if classdict is None:
            classdict = {}

        lookup = {}
        raw_lookup = {}
        for field in fields:
            if field._name in lookup and field._name != "_":
                raise ValueError(f"Duplicate field name: {field._name}")

            if isinstance(field.type, StructureMetaType) and field.name is None:
                # Fold the fields of an anonymous structure into this one by exposing them as properties
                # that forward to the nested structure
                for anon_field in field.type.fields.values():
                    attr = f"{field._name}.{anon_field.name}"
                    classdict[anon_field.name] = property(attrgetter(attr), attrsetter(attr))

                lookup.update(field.type.fields)
            else:
                lookup[field._name] = field

            raw_lookup[field._name] = field

        field_names = lookup.keys()
        classdict["fields"] = lookup
        classdict["lookup"] = raw_lookup
        classdict["__fields__"] = fields
        classdict["__bool__"] = _generate__bool__(field_names)

        if issubclass(cls, UnionMetaType) or isinstance(cls, UnionMetaType):
            classdict["__init__"] = _generate_union__init__(raw_lookup.values())
            # Not a great way to do this but it works for now
            classdict["__eq__"] = Union.__eq__
        else:
            classdict["__init__"] = _generate_structure__init__(raw_lookup.values())
            classdict["__eq__"] = _generate__eq__(field_names)

        classdict["__hash__"] = _generate__hash__(field_names)

        # If we're calling this as a class method or a function on the metaclass
        if issubclass(cls, type):
            size, alignment = cls._calculate_size_and_offsets(cls, fields, align)
        else:
            size, alignment = cls._calculate_size_and_offsets(fields, align)

        if cls.__compiled__:
            # If the previous class was compiled try to compile this too
            from dissect.cstruct import compiler

            try:
                classdict["_read"] = compiler.Compiler(cls.cs).compile_read(fields, cls.__name__, align=cls.__align__)
                classdict["__compiled__"] = True
            except Exception:
                # Compilation is best effort: revert _read to the slower loop based method
                classdict["_read"] = classmethod(Structure._read.__func__)
                classdict["__compiled__"] = False

        # TODO: compile _write
        # TODO: generate cached_property for lazy reading

        classdict["size"] = size
        classdict["alignment"] = alignment
        classdict["dynamic"] = size is None

        return classdict

    def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]:
        """Iterate all fields in this structure to calculate the field offsets and total structure size.

        Offsets are stored on the :class:`Field` objects. Once a field without a fixed length is encountered,
        the offsets of the following fields are set to ``None`` (unless they already carry an offset).

        Args:
            fields: The fields of the structure.
            align: Whether to align field offsets and add tail padding to the size.

        Returns:
            A tuple of the structure size (``None`` if the size is dynamic) and the structure alignment.

        Raises:
            ValueError: If a bit field does not fit in the remaining bits of its storage unit.
        """
        # The current offset, set to None if we become dynamic
        offset = 0
        # The current alignment for this structure
        alignment = 0

        # The current bit field type
        bits_type = None
        # The offset of the current bit field, set to None if we become dynamic
        bits_field_offset = 0
        # How many bits we have left in the current bit field
        bits_remaining = 0

        for field in fields:
            if field.offset is not None:
                # If a field already has an offset, it's leading
                offset = field.offset

            if align and offset is not None:
                # Round to next alignment
                offset += -offset & (field.alignment - 1)

            # The alignment of this struct is equal to its largest members' alignment
            alignment = max(alignment, field.alignment)

            if field.bits:
                field_type = field.type

                if isinstance(field_type, EnumMetaType):
                    field_type = field_type.type

                # Bit fields have special logic
                if (
                    # Exhausted a bit field
                    bits_remaining == 0
                    # Moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8;
                    or field_type != bits_type
                    # Still processing a bit field, but it's at a different offset due to alignment or a manual offset
                    # Offsets are None once the structure is dynamic, so only compare offsets that are known.
                    # A known offset after an unknown bit field offset always starts a new bit field.
                    or (
                        bits_type is not None
                        and offset is not None
                        and (bits_field_offset is None or offset > bits_field_offset + bits_type.size)
                    )
                ):
                    # ... if any of this is true, we have to move to the next field
                    bits_type = field_type
                    bits_count = bits_type.size * 8
                    bits_remaining = bits_count
                    bits_field_offset = offset

                    if offset is not None:
                        # We're not dynamic, update the structure size and current offset
                        offset += bits_type.size

                    field.offset = bits_field_offset

                bits_remaining -= field.bits

                if bits_remaining < 0:
                    raise ValueError("Straddled bit fields are unsupported")
            else:
                # Reset bits stuff
                bits_type = None
                bits_field_offset = bits_remaining = 0

                field.offset = offset

                if offset is not None:
                    # We're not dynamic, update the structure size and current offset
                    try:
                        field_len = len(field.type)
                    except TypeError:
                        # This field is dynamic
                        offset = None
                        continue

                    offset += field_len

        if align and offset is not None:
            # Add "tail padding" if we need to align
            # This bit magic rounds up to the next alignment boundary
            # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5
            offset += -offset & (alignment - 1)

        # The structure size is whatever the currently calculated offset is
        return offset, alignment

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Read a structure from ``stream`` by reading its fields one by one.

        The values read so far are passed as context to the ``_read`` of each following field type.

        Args:
            stream: The stream to read from. It must support ``tell`` and ``seek``.
            context: Unused by this implementation.

        Returns:
            A new instance with ``_values`` set to the values read and ``_sizes`` to the number of bytes
            each non-bit field consumed.
        """
        bit_buffer = BitBuffer(stream, cls.cs.endian)
        struct_start = stream.tell()

        result = {}
        sizes = {}
        for field in cls.__fields__:
            offset = stream.tell()

            if field.offset is not None and offset != struct_start + field.offset:
                # Field is at a specific offset, either aligned or added that way
                offset = struct_start + field.offset
                stream.seek(offset)

            if cls.__align__ and field.offset is None:
                # Previous field was dynamically sized and we need to align
                offset += -offset & (field.alignment - 1)
                stream.seek(offset)

            if field.bits:
                if isinstance(field.type, EnumMetaType):
                    value = field.type(bit_buffer.read(field.type.type, field.bits))
                else:
                    value = bit_buffer.read(field.type, field.bits)

                result[field._name] = value
                continue

            bit_buffer.reset()

            value = field.type._read(stream, result)

            sizes[field._name] = stream.tell() - offset
            result[field._name] = value

        if cls.__align__:
            # Align the stream
            stream.seek(-stream.tell() & (cls.alignment - 1), io.SEEK_CUR)

        # Using type.__call__ directly calls the __init__ method of the class
        # This is faster than calling cls() and bypasses the metaclass __call__ method
        obj = type.__call__(cls, **result)
        obj._sizes = sizes
        obj._values = result
        return obj

    def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]:  # type: ignore
        """Read structures from ``stream`` until one is read that evaluates to ``False``.

        The terminating (falsy) structure is consumed from the stream but not included in the result.
        """
        result = []

        while obj := cls._read(stream, context):
            result.append(obj)

        return result

    def _write(cls, stream: BinaryIO, data: Structure) -> int:
        """Write the structure ``data`` to ``stream``.

        Fields that are missing on ``data`` or are ``None`` are written with the default value of their type.

        Args:
            stream: The stream to write to. It must support ``tell``.
            data: The structure instance to write.

        Returns:
            The sum of the bytes written for the non-bit fields. Bytes written for bit fields and padding
            are not included in this count.
        """
        bit_buffer = BitBuffer(stream, cls.cs.endian)
        struct_start = stream.tell()
        num = 0

        for field in cls.__fields__:
            field_type = cls.cs.resolve(field.type)

            bit_field_type = (
                (field_type.type if isinstance(field_type, EnumMetaType) else field_type) if field.bits else None
            )
            # Current field is not a bit field, but previous was
            # Or, moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8;
            if (not field.bits and bit_buffer._type is not None) or (
                bit_buffer._type and bit_buffer._type != bit_field_type
            ):
                # Flush the current bit buffer so we can process alignment properly
                bit_buffer.flush()

            offset = stream.tell()

            if field.offset is not None and offset < struct_start + field.offset:
                # Field is at a specific offset, either aligned or added that way
                stream.write(b"\x00" * (struct_start + field.offset - offset))
                offset = struct_start + field.offset

            if cls.__align__ and field.offset is None:
                is_bitbuffer_boundary = bit_buffer._type and (
                    bit_buffer._remaining == 0 or bit_buffer._type != field_type
                )
                if not bit_buffer._type or is_bitbuffer_boundary:
                    # Previous field was dynamically sized and we need to align
                    align_pad = -offset & (field.alignment - 1)
                    stream.write(b"\x00" * align_pad)
                    offset += align_pad

            value = getattr(data, field._name, None)
            if value is None:
                value = field_type.__default__()

            if field.bits:
                if isinstance(field_type, EnumMetaType):
                    bit_buffer.write(field_type.type, value.value, field.bits)
                else:
                    bit_buffer.write(field_type, value, field.bits)
            else:
                field_type._write(stream, value)
                num += stream.tell() - offset

        if bit_buffer._type is not None:
            bit_buffer.flush()

        if cls.__align__:
            # Align the stream
            stream.write(b"\x00" * (-stream.tell() & (cls.alignment - 1)))

        return num

    def add_field(cls, name: str, type_: type[BaseType], bits: int | None = None, offset: int | None = None) -> None:
        """Append a field to the structure.

        The change is committed immediately, unless we are inside a :meth:`start_update` block.

        Args:
            name: The name of the field.
            type_: The type of the field.
            bits: The bit width if the field is a bit field.
            offset: An explicit offset for the field.
        """
        field = Field(name, type_, bits=bits, offset=offset)
        cls.__fields__.append(field)

        if not cls.__updating__:
            cls.commit()

    @contextmanager
    def start_update(cls) -> Iterator[None]:
        """Context manager to batch multiple :meth:`add_field` calls into a single :meth:`commit`.

        The commit happens when the block is left, also when it is left with an exception.
        """
        cls.__updating__ = True
        try:
            yield
        finally:
            try:
                cls.commit()
            finally:
                # Always leave update mode, even if the commit itself fails,
                # otherwise later add_field calls would never commit
                cls.__updating__ = False

    def commit(cls) -> None:
        """Regenerate all field derived attributes from ``__fields__`` and set them on the class."""
        classdict = cls._update_fields(cls.__fields__, cls.__align__)

        for key, value in classdict.items():
            setattr(cls, key, value)
375
376
class Structure(BaseType, metaclass=StructureMetaType):
    """Base class for cstruct structure type classes.

    Instances expose their fields as attributes and through item access.
    """

    _values: dict[str, Any]
    _sizes: dict[str, int]

    def __len__(self) -> int:
        # The length of a structure is the length of its serialized form
        dumped = self.dumps()
        return len(dumped)

    def __bytes__(self) -> bytes:
        return self.dumps()

    def __getitem__(self, item: str) -> Any:
        # Item access is plain attribute access
        return getattr(self, item)

    def __repr__(self) -> str:
        values = []
        for name, field in self.__class__.fields.items():
            value = self[name]
            # Integers are shown in hex, except for pointers and enums which use their own repr
            shown_as_hex = issubclass(field.type, int) and not issubclass(field.type, (Pointer, Enum))
            value = hex(value) if shown_as_hex else repr(value)
            values.append(f"{name}={value}")

        return f"<{self.__class__.__name__} {' '.join(values)}>"
403
404
class UnionMetaType(StructureMetaType):
    """Base metaclass for cstruct union type classes.

    All fields of a union are read from the same starting position (plus their own offset, if any), and
    the size of a union is the size of its largest field.
    """

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Instantiate the union.

        When initialized with values by the user, the union buffer is rebuilt from the first provided field.
        When called without arguments, nested structures of the default instance are proxified.

        Raises:
            NotImplementedError: If values are provided for a dynamic union.
        """
        obj: Union = super().__call__(*args, **kwargs)

        # Calling with non-stream args or kwargs means we are initializing with values
        if (args and not (len(args) == 1 and (_is_readable_type(args[0]) or _is_buffer_type(args[0])))) or kwargs:
            # We don't support user initialization of dynamic unions yet
            if cls.dynamic:
                raise NotImplementedError("Initializing a dynamic union is not yet supported")

            # User (partial) initialization, rebuild the union
            # First user-provided field is the one used to rebuild the union
            arg_fields = (field._name for _, field in zip(args, cls.__fields__))
            kwarg_fields = (name for name in kwargs if name in cls.lookup)
            if (first_field := next(chain(arg_fields, kwarg_fields), None)) is not None:
                obj._rebuild(first_field)
        elif not args and not kwargs:
            # Initialized with default values
            # Note that we proxify here in case we have a default initialization (cls())
            # We don't proxify in case we read from a stream, as we do that later on in _read at a more appropriate time
            # Same with (partial) user initialization, we do that after rebuilding the union
            obj._proxify()

        return obj

    def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]:
        """Calculate the size and alignment of the union.

        Field offsets are not modified.

        Args:
            fields: The fields of the union.
            align: Whether to add tail padding to the size.

        Returns:
            A tuple of the union size (``None`` if any field has no fixed length) and the union alignment.
        """
        size = 0
        alignment = 0

        for field in fields:
            if size is not None:
                try:
                    size = max(len(field.type), size)
                except TypeError:
                    # This field is dynamic, and so is the union
                    size = None

            alignment = max(field.alignment, alignment)

        if align and size is not None:
            # Add "tail padding" if we need to align
            # This bit magic rounds up to the next alignment boundary
            # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5
            size += -size & (alignment - 1)

        return size, alignment

    def _read_fields(
        cls, stream: BinaryIO, context: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, int]]:
        """Read the value of every field of the union.

        For a fixed size union, ``cls.size`` bytes are read from ``stream`` and the fields are parsed from
        that buffer. For a dynamic union the fields are parsed directly from ``stream``, each starting at
        the position the stream had on entry (plus the field offset).

        Args:
            stream: The stream to read from.
            context: Unused by this implementation.

        Returns:
            A tuple of a mapping of field names to values, and a mapping of field names to the number of
            bytes read for that field.
        """
        result = {}
        sizes = {}

        if cls.size is None:
            offset = stream.tell()
            buf = stream
        else:
            offset = 0
            buf = io.BytesIO(stream.read(cls.size))

        for field in cls.__fields__:
            field_type = cls.cs.resolve(field.type)

            start = 0
            if field.offset is not None:
                start = field.offset

            # The absolute position in buf where this field starts
            field_start = offset + start

            buf.seek(field_start)
            value = field_type._read(buf, result)

            # Measure from the absolute start of the field. For dynamic unions buf is the original stream,
            # which does not necessarily start at position 0
            sizes[field._name] = buf.tell() - field_start
            result[field._name] = value

        return result, sizes

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Read a union from ``stream``.

        Args:
            stream: The stream to read from.
            context: Passed on to :meth:`_read_fields` for dynamic unions.

        Returns:
            A new union instance with its raw bytes stored in ``_buf``.
        """
        if cls.size is None:
            start = stream.tell()
            result, sizes = cls._read_fields(stream, context)
            # NOTE(review): this takes the end of the last field as the end of the union, which looks
            # too small when an earlier field is larger than the last one - confirm
            size = stream.tell() - start
            stream.seek(start)
            buf = stream.read(size)
        else:
            # The field values of a fixed size union are parsed from the buffer in _update below
            result = {}
            sizes = {}
            buf = stream.read(cls.size)

        # Create the object and set the values
        # Using type.__call__ directly calls the __init__ method of the class
        # This is faster than calling cls() and bypasses the metaclass __call__ method
        # It also makes it easier to differentiate between user-initialization of the class
        # and initialization from a stream read
        obj: Union = type.__call__(cls, **result)
        object.__setattr__(obj, "_values", result)
        object.__setattr__(obj, "_sizes", sizes)
        object.__setattr__(obj, "_buf", buf)

        if cls.size is not None:
            obj._update()

        # Proxify any nested structures
        obj._proxify()

        return obj

    def _write(cls, stream: BinaryIO, data: Union) -> int:
        """Write the union ``data`` to ``stream``.

        The value of the largest named field is written. An anonymous structure is only written if
        nothing else was written. The output is zero padded up to the size of the union.

        Args:
            stream: The stream to write to. It must support ``tell``.
            data: The union instance to write.

        Returns:
            The number of bytes written.

        Raises:
            NotImplementedError: If the union is dynamic.
        """
        if cls.dynamic:
            raise NotImplementedError("Writing dynamic unions is not yet supported")

        offset = stream.tell()
        expected_offset = offset + len(cls)

        # Sort by largest field
        fields = sorted(cls.__fields__, key=lambda e: e.type.size or 0, reverse=True)
        anonymous_struct = False

        # Try to write by largest field
        for field in fields:
            if isinstance(field.type, StructureMetaType) and field.name is None:
                # Prefer to write regular fields initially
                anonymous_struct = field.type
                continue

            # Write the value
            field.type._write(stream, getattr(data, field._name))
            break

        # If we haven't written anything yet and we initially skipped an anonymous struct, write it now
        if stream.tell() == offset and anonymous_struct:
            anonymous_struct._write(stream, data)

        # If we haven't filled the union size yet, pad it
        if remaining := expected_offset - stream.tell():
            stream.write(b"\x00" * remaining)

        return stream.tell() - offset
542
543
class Union(Structure, metaclass=UnionMetaType):
    """Base class for cstruct union type classes.

    The raw bytes of the union are kept in ``_buf``. Setting an attribute writes that value into the
    buffer, after which all field values are read back from it.
    """

    _buf: bytes

    def __eq__(self, other: object) -> bool:
        # Unions are equal when they are of the exact same class and serialize to the same bytes
        if self.__class__ is not other.__class__:
            return False
        return bytes(self) == bytes(other)

    def __setattr__(self, attr: str, value: Any) -> None:
        union_cls = self.__class__
        if union_cls.dynamic:
            raise NotImplementedError("Modifying a dynamic union is not yet supported")

        super().__setattr__(attr, value)
        # Propagate the new value to all other fields
        self._rebuild(attr)

    def _rebuild(self, attr: str) -> None:
        """Write the current value of field ``attr`` into the union buffer and refresh all fields from it."""
        union_cls = self.__class__

        cur_buf = getattr(self, "_buf", None)
        if cur_buf is None:
            # No buffer yet, start out with all zeroes
            cur_buf = b"\x00" * union_cls.size

        buf = io.BytesIO(cur_buf)
        field = union_cls.lookup[attr]
        if field.offset:
            buf.seek(field.offset)

        value = getattr(self, attr)
        if value is None:
            value = field.type.__default__()

        field.type._write(buf, value)

        object.__setattr__(self, "_buf", buf.getvalue())
        self._update()

        # (Re-)proxify all values
        self._proxify()

    def _update(self) -> None:
        """Re-read all field values from the union buffer and store them on the instance."""
        stream = io.BytesIO(self._buf)
        result, sizes = self.__class__._read_fields(stream)

        # Write to __dict__ directly so we don't go through __setattr__ and trigger another rebuild
        self.__dict__.update(result)
        object.__setattr__(self, "_values", result)
        object.__setattr__(self, "_sizes", sizes)

    def _proxify(self) -> None:
        """Recursively replace nested structure values with :class:`UnionProxy` objects bound to this union."""

        def _proxy_structure(value: Structure) -> None:
            for field in value.__class__.__fields__:
                if not issubclass(field.type, Structure):
                    continue

                field_name = field._name
                nested_value = getattr(value, field_name)
                object.__setattr__(value, field_name, UnionProxy(self, field_name, nested_value))
                _proxy_structure(nested_value)

        _proxy_structure(self)
595
596
class UnionProxy:
    """Proxy around a structure value that is nested in a union.

    Reads are forwarded to the wrapped structure. Writes are forwarded too, after which the owning union
    is asked to rebuild itself from the proxied attribute.
    """

    __union__: Union
    __attr__: str
    __target__: Structure

    def __init__(self, union: Union, attr: str, target: Structure):
        # Bypass our own __setattr__, which would forward these to the target
        for slot, slot_value in (("__union__", union), ("__attr__", attr), ("__target__", target)):
            object.__setattr__(self, slot, slot_value)

    def __len__(self) -> int:
        dumped = self.__target__.dumps()
        return len(dumped)

    def __bytes__(self) -> bytes:
        target = self.__target__
        return target.dumps()

    def __getitem__(self, item: str) -> Any:
        target = self.__target__
        return getattr(target, item)

    def __repr__(self) -> str:
        target = self.__target__
        return repr(target)

    def __getattr__(self, attr: str) -> Any:
        # Only called for attributes that are not found on the proxy itself
        target = self.__target__
        return getattr(target, attr)

    def __setattr__(self, attr: str, value: Any) -> None:
        target = self.__target__
        setattr(target, attr, value)
        # Let the union pick up the change
        self.__union__._rebuild(self.__attr__)
625
626
def attrsetter(path: str) -> Callable[[Any, Any], None]:
    """Return a callable that sets a (nested) attribute, the counterpart of :func:`operator.attrgetter`.

    For the path ``"a.b.c"``, calling the result with ``(obj, value)`` performs ``obj.a.b.c = value``.
    A path without dots sets the attribute directly on the object.

    Args:
        path: Dotted attribute path. The last component is the attribute that is set.

    Returns:
        A function taking the object and the value to set.
    """
    parent_path, _, attr = path.rpartition(".")
    # Splitting an empty string yields [""], which would result in a lookup of the attribute ""
    parents = parent_path.split(".") if parent_path else []

    def _func(obj: Any, value: Any) -> None:
        # Walk down to the object that owns the final attribute
        for name in parents:
            obj = getattr(obj, name)
        setattr(obj, attr, value)

    return _func
637
638
639def _codegen(func: FunctionType) -> FunctionType:
640 """Decorator that generates a template function with a specified number of fields.
641
642 This code is a little complex but allows use to cache generated functions for a specific number of fields.
643 For example, if we generate a structure with 10 fields, we can cache the generated code for that structure.
644 We can then reuse that code and patch it with the correct field names when we create a new structure with 10 fields.
645
646 The functions that are decorated with this decorator should take a list of field names and return a string of code.
647 The decorated function is needs to be called with the number of fields, instead of the field names.
648 The confusing part is that that the original function takes field names, but you then call it with
649 the number of fields instead.
650
651 Inspired by https://github.com/dabeaz/dataklasses.
652
653 Args:
654 func: The decorated function that takes a list of field names and returns a string of code.
655
656 Returns:
657 A cached function that generates the desired function code, to be called with the number of fields.
658 """
659
660 def make_func_code(num_fields: int) -> FunctionType:
661 exec(func([f"_{n}" for n in range(num_fields)]), {}, d := {})
662 return d.popitem()[1]
663
664 make_func_code.__wrapped__ = func
665 return lru_cache(make_func_code)
666
667
@_codegen
def _make_structure__init__(fields: list[str]) -> str:
    """Generates the source of an ``__init__`` method for a structure with the specified fields.

    Every field becomes a parameter that defaults to ``None`` and is assigned with regular attribute
    assignment. When an argument is ``None``, the global ``_<index>_default`` is assigned instead, where
    ``<index>`` is the position of the field in ``fields``.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__init__`` function.
    """
    field_args = ", ".join(f"{field} = None" for field in fields)
    field_init = "\n".join(
        f" self.{name} = {name} if {name} is not None else _{i}_default" for i, name in enumerate(fields)
    )

    # Only add the separator when there are parameters. The conditional has to cover the whole
    # expression: `', ' + field_args or ''` binds as `(', ' + field_args) or ''` and is never empty
    code = f"def __init__(self{', ' + field_args if field_args else ''}):\n"
    return code + (field_init or " pass")
682
683
@_codegen
def _make_union__init__(fields: list[str]) -> str:
    """Generates the source of an ``__init__`` method for a class with the specified fields using setattr.

    Every field becomes a parameter that defaults to ``None`` and is assigned through
    ``object.__setattr__``, which bypasses a ``__setattr__`` defined on the class itself. When an argument
    is ``None``, the global ``_<index>_default`` is assigned instead, where ``<index>`` is the position of
    the field in ``fields``.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__init__`` function.
    """
    field_args = ", ".join(f"{field} = None" for field in fields)
    field_init = "\n".join(
        f" object.__setattr__(self, '{name}', {name} if {name} is not None else _{i}_default)"
        for i, name in enumerate(fields)
    )

    # Only add the separator when there are parameters. The conditional has to cover the whole
    # expression: `', ' + field_args or ''` binds as `(', ' + field_args) or ''` and is never empty
    code = f"def __init__(self{', ' + field_args if field_args else ''}):\n"
    return code + (field_init or " pass")
699
700
@_codegen
def _make__eq__(fields: list[str]) -> str:
    """Generates the source of an ``__eq__`` method for a class with the specified fields.

    The generated method compares a tuple of all field values of both objects, and only considers objects
    of the exact same class to be equal.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__eq__`` function.
    """
    self_attrs = [f"self.{name}" for name in fields]
    other_attrs = [f"other.{name}" for name in fields]

    # The trailing comma makes sure a single value still forms a tuple
    self_vals = ",".join(self_attrs) + ("," if self_attrs else "")
    other_vals = ",".join(other_attrs) + ("," if other_attrs else "")

    # In the future this could be a looser check, e.g. an __eq__ on the classes, which compares the fields
    code = f"""
    def __eq__(self, other):
        if self.__class__ is other.__class__:
            return ({self_vals}) == ({other_vals})
        return False
    """

    return dedent(code)
725
726
@_codegen
def _make__bool__(fields: list[str]) -> str:
    """Generates the source of a ``__bool__`` method for a class with the specified fields.

    The generated method returns whether any of the field values is truthy.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__bool__`` function.
    """
    attrs = [f"self.{name}" for name in fields]
    vals = ", ".join(attrs)

    code = f"""
    def __bool__(self):
        return any([{vals}])
    """

    return dedent(code)
742
743
@_codegen
def _make__hash__(fields: list[str]) -> str:
    """Generates the source of a ``__hash__`` method for a class with the specified fields.

    The generated method hashes the parenthesized, comma separated field values.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__hash__`` function.
    """
    attrs = [f"self.{name}" for name in fields]
    vals = ", ".join(attrs)

    code = f"""
    def __hash__(self):
        return hash(({vals}))
    """

    return dedent(code)
759
760
761def _patch_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType:
762 """Patches a function's attributes.
763
764 Args:
765 func: The function to patch.
766 fields: List of field names to add.
767 start: The starting index for patching. Defaults to 0.
768 """
769 return type(func)(
770 func.__code__.replace(co_names=(*func.__code__.co_names[:start], *fields)),
771 func.__globals__,
772 )
773
774
def _generate_structure__init__(fields: list[Field]) -> FunctionType:
    """Generates an ``__init__`` method for a structure with the specified fields.

    Takes the cached template for this number of fields and creates a new function from it, in which the
    placeholder names are replaced by the real field names. The default value of every field type is
    created once, here, and stored in the globals of the new function as ``__<name>_default__``.

    Args:
        fields: The :class:`Field` objects to generate the ``__init__`` method for.

    Returns:
        The generated ``__init__`` function.
    """
    field_names = [field._name for field in fields]

    template: FunctionType = _make_structure__init__(len(field_names))

    # The template references its names as "<default>, <attribute>" pairs, one pair per field
    patched_names = []
    for name in field_names:
        patched_names.append(f"__{name}_default__")
        patched_names.append(name)

    code = template.__code__.replace(
        co_names=tuple(patched_names),
        co_varnames=("self", *field_names),
    )
    defaults = {f"__{field._name}_default__": field.type.__default__() for field in fields}

    return type(template)(code, template.__globals__ | defaults, argdefs=template.__defaults__)
792
793
def _generate_union__init__(fields: list[Field]) -> FunctionType:
    """Generates an ``__init__`` method for a union with the specified fields.

    Takes the cached template for this number of fields and creates a new function from it, in which the
    placeholder names and string constants are replaced by the real field names. The default value of
    every field type is created once, here, and stored in the globals of the new function as
    ``__<name>_default__``.

    Args:
        fields: The :class:`Field` objects to generate the ``__init__`` method for.

    Returns:
        The generated ``__init__`` function.
    """
    field_names = [field._name for field in fields]

    template: FunctionType = _make_union__init__(len(field_names))

    default_names = [f"__{name}_default__" for name in field_names]
    code = template.__code__.replace(
        # NOTE(review): presumably mirrors how the compiler lays out the constants of the template
        # (None followed by the attribute name strings) - confirm for the supported Python versions
        co_consts=(None, *field_names),
        co_names=("object", "__setattr__", *default_names),
        co_varnames=("self", *field_names),
    )
    defaults = {f"__{field._name}_default__": field.type.__default__() for field in fields}

    return type(template)(code, template.__globals__ | defaults, argdefs=template.__defaults__)
812
813
def _generate__eq__(fields: list[str]) -> FunctionType:
    """Generates an ``__eq__`` method for a class with the specified fields.

    Args:
        fields: List of field names.

    Returns:
        The cached ``__eq__`` template for this number of fields, patched with the given field names.
    """
    template = _make__eq__(len(fields))
    # The first name in the template is not a field placeholder and is kept as is
    return _patch_attributes(template, fields, 1)
821
822
def _generate__bool__(fields: list[str]) -> FunctionType:
    """Generates a ``__bool__`` method for a class with the specified fields.

    Args:
        fields: List of field names.

    Returns:
        The cached ``__bool__`` template for this number of fields, patched with the given field names.
    """
    template = _make__bool__(len(fields))
    # The first name in the template is not a field placeholder and is kept as is
    return _patch_attributes(template, fields, 1)
830
831
def _generate__hash__(fields: list[str]) -> FunctionType:
    """Generates a ``__hash__`` method for a class with the specified fields.

    Args:
        fields: List of field names.

    Returns:
        The cached ``__hash__`` template for this number of fields, patched with the given field names.
    """
    template = _make__hash__(len(fields))
    # The first name in the template is not a field placeholder and is kept as is
    return _patch_attributes(template, fields, 1)