Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dissect/cstruct/types/structure.py: 59%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

434 statements  

1from __future__ import annotations 

2 

3import io 

4from contextlib import contextmanager 

5from enum import Enum 

6from functools import lru_cache 

7from itertools import chain 

8from operator import attrgetter 

9from textwrap import dedent 

10from types import FunctionType 

11from typing import TYPE_CHECKING, Any, BinaryIO, Callable 

12 

13from dissect.cstruct.bitbuffer import BitBuffer 

14from dissect.cstruct.types.base import ( 

15 BaseType, 

16 MetaType, 

17 _is_buffer_type, 

18 _is_readable_type, 

19) 

20from dissect.cstruct.types.enum import EnumMetaType 

21from dissect.cstruct.types.pointer import Pointer 

22 

23if TYPE_CHECKING: 

24 from collections.abc import Iterator 

25 from types import FunctionType 

26 

27 from typing_extensions import Self 

28 

29 

class Field:
    """Structure field.

    Args:
        name: The field name, or ``None`` for an anonymous member.
        type_: The cstruct type of the field.
        bits: Bit width when this is a bit field, ``None`` otherwise.
        offset: Byte offset within the parent type, when it is already known.
    """

    def __init__(self, name: str | None, type_: type[BaseType], bits: int | None = None, offset: int | None = None):
        # Public name; stays None for anonymous members
        self.name = name
        # Internal name used as key/attribute; anonymous members fall back to their type's name
        self._name = name if name else type_.__name__
        self.type = type_
        self.bits = bits
        self.offset = offset
        # A type without an alignment is treated as byte aligned
        self.alignment = type_.alignment or 1

    def __repr__(self) -> str:
        parts = [self.type.__name__]
        if self.bits:
            parts.append(f" : {self.bits}")
        return f"<Field {self.name} {''.join(parts)}>"

44 

45 

class StructureMetaType(MetaType):
    """Base metaclass for cstruct structure type classes.

    Structure classes describe their members with a list of :class:`Field` objects. From that list this metaclass
    derives the lookup tables, generated ``__init__``/``__eq__``/``__hash__``/``__bool__`` methods, the field
    offsets, and the total size and alignment of the structure.
    """

    # TODO: resolve field types in _update_fields, remove resolves elsewhere?

    fields: dict[str, Field]
    """Mapping of field names to :class:`Field` objects, including "folded" fields from anonymous structures."""
    lookup: dict[str, Field]
    """Mapping of "raw" field names to :class:`Field` objects. E.g. holds the anonymous struct and not its fields."""
    __fields__: list[Field]
    """List of :class:`Field` objects for this structure. This is the structure's Single Source Of Truth."""

    # Internal
    __align__: bool  # Whether C-style alignment/padding is applied
    __anonymous__: bool
    __updating__ = False  # True while inside start_update(); suppresses the per-field commit in add_field()
    __compiled__ = False  # True when _read was generated by the compiler

    def __new__(metacls, name: str, bases: tuple[type, ...], classdict: dict[str, Any]) -> Self:  # type: ignore
        """Create a structure class.

        When ``classdict`` contains a ``fields`` list, it is consumed and replaced by the generated class
        attributes (see :meth:`_update_fields`) before the class is created.
        """
        if (fields := classdict.pop("fields", None)) is not None:
            # The class does not exist yet, so this is called as a plain function on the metaclass;
            # ``classdict`` is filled in place
            metacls._update_fields(metacls, fields, align=classdict.get("__align__", False), classdict=classdict)

        return super().__new__(metacls, name, bases, classdict)

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Instantiate the structure.

        A single ``bytes`` argument that exactly matches a single fixed-size bytes-like field is passed straight to
        ``__init__``. Without any arguments, an instance with default values and empty ``_values``/``_sizes`` is
        created. Everything else is delegated to the parent metaclass (defined outside this module).
        """
        if (
            cls.__fields__
            and len(args) == len(cls.__fields__) == 1
            and isinstance(args[0], bytes)
            and issubclass(cls.__fields__[0].type, bytes)
            and len(args[0]) == cls.__fields__[0].type.size
        ):
            # Shortcut for single char/bytes type
            return type.__call__(cls, *args, **kwargs)
        if not args and not kwargs:
            obj = type.__call__(cls)
            object.__setattr__(obj, "_values", {})
            object.__setattr__(obj, "_sizes", {})
            return obj

        return super().__call__(*args, **kwargs)

    def _update_fields(
        cls, fields: list[Field], align: bool = False, classdict: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Derive all field-dependent class attributes for ``fields``.

        Args:
            fields: The structure's fields, in declaration order.
            align: Whether to apply C-style alignment when computing offsets and size.
            classdict: Dictionary to populate in place; a new one is created when omitted.

        Returns:
            The populated dictionary: ``fields``, ``lookup``, ``__fields__``, generated ``__init__``, ``__eq__``,
            ``__hash__`` and ``__bool__``, ``size``, ``alignment``, ``dynamic``, properties for fields folded in
            from anonymous structures and, when the previous class was compiled, ``_read``/``__compiled__``.

        Raises:
            ValueError: On a duplicate field name (other than ``_``).
        """
        # Only create a new dict when none was given: an empty dict passed by __new__ must be filled in place,
        # otherwise everything generated here would be lost
        if classdict is None:
            classdict = {}

        lookup = {}
        raw_lookup = {}
        for field in fields:
            if field._name in lookup and field._name != "_":
                raise ValueError(f"Duplicate field name: {field._name}")

            if isinstance(field.type, StructureMetaType) and field.name is None:
                # Anonymous nested structure: expose its members directly on this class through properties
                # that forward to the nested instance
                for anon_field in field.type.fields.values():
                    attr = f"{field._name}.{anon_field.name}"
                    classdict[anon_field.name] = property(attrgetter(attr), attrsetter(attr))

                lookup.update(field.type.fields)
            else:
                lookup[field._name] = field

            raw_lookup[field._name] = field

        field_names = lookup.keys()
        classdict["fields"] = lookup
        classdict["lookup"] = raw_lookup
        classdict["__fields__"] = fields
        classdict["__bool__"] = _generate__bool__(field_names)

        if issubclass(cls, UnionMetaType) or isinstance(cls, UnionMetaType):
            classdict["__init__"] = _generate_union__init__(raw_lookup.values())
            # Not a great way to do this but it works for now
            classdict["__eq__"] = Union.__eq__
        else:
            classdict["__init__"] = _generate_structure__init__(raw_lookup.values())
            classdict["__eq__"] = _generate__eq__(field_names)

        classdict["__hash__"] = _generate__hash__(field_names)

        # If we're calling this as a class method or a function on the metaclass
        if issubclass(cls, type):
            size, alignment = cls._calculate_size_and_offsets(cls, fields, align)
        else:
            size, alignment = cls._calculate_size_and_offsets(fields, align)

        if cls.__compiled__:
            # If the previous class was compiled try to compile this too
            from dissect.cstruct import compiler

            try:
                classdict["_read"] = compiler.Compiler(cls.cs).compile_read(fields, cls.__name__, align=cls.__align__)
                classdict["__compiled__"] = True
            except Exception:
                # Deliberate best-effort: revert _read to the slower loop based method
                classdict["_read"] = classmethod(Structure._read.__func__)
                classdict["__compiled__"] = False

        # TODO: compile _write
        # TODO: generate cached_property for lazy reading

        classdict["size"] = size
        classdict["alignment"] = alignment
        classdict["dynamic"] = size is None

        return classdict

    def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]:
        """Iterate all fields in this structure to calculate the field offsets and total structure size.

        Each field's ``offset`` is updated in place. Once a dynamically sized field is encountered, the offsets of
        the following fields become ``None`` (unless a field carries an explicit offset) and the returned size is
        ``None``.

        Args:
            fields: The fields to lay out, in declaration order.
            align: Whether to align every field to its alignment and add tail padding to the structure.

        Returns:
            A ``(size, alignment)`` tuple; ``size`` is ``None`` for dynamically sized structures.

        Raises:
            ValueError: If a bit field would straddle its storage unit.
        """
        # The current offset, set to None if we become dynamic
        offset = 0
        # The current alignment for this structure
        alignment = 0

        # The type of the current bit field storage unit
        bits_type = None
        # The offset of the current bit field, set to None if we become dynamic
        bits_field_offset = 0
        # How many bits we have left in the current bit field
        bits_remaining = 0

        for field in fields:
            if field.offset is not None:
                # If a field already has an offset, it's leading
                offset = field.offset

            if align and offset is not None:
                # Round to next alignment
                offset += -offset & (field.alignment - 1)

            # The alignment of this struct is equal to its largest members' alignment
            alignment = max(alignment, field.alignment)

            if field.bits:
                field_type = field.type

                if isinstance(field_type, EnumMetaType):
                    field_type = field_type.type

                # Bit fields have special logic
                if (
                    # Exhausted a bit field
                    bits_remaining == 0
                    # Moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8;
                    or field_type != bits_type
                    # Still processing a bit field, but it's at a different offset due to alignment or a manual
                    # offset. Offsets can only be compared while the current offset is known (it is None once the
                    # structure became dynamic); a known offset after a unit of unknown position starts a new unit.
                    or (
                        bits_type is not None
                        and offset is not None
                        and (bits_field_offset is None or offset > bits_field_offset + bits_type.size)
                    )
                ):
                    # ... if any of this is true, we have to move to the next field
                    bits_type = field_type
                    bits_remaining = bits_type.size * 8
                    bits_field_offset = offset

                    if offset is not None:
                        # We're not dynamic, update the structure size and current offset
                        offset += bits_type.size
                elif offset is not None:
                    # Same storage unit. A pre-assigned offset (e.g. from a previous commit) points at the start of
                    # the unit, so keep the cursor past the end of the unit instead of rewinding it
                    offset = max(offset, bits_field_offset + bits_type.size)

                field.offset = bits_field_offset

                bits_remaining -= field.bits

                if bits_remaining < 0:
                    raise ValueError("Straddled bit fields are unsupported")
            else:
                # Reset bits stuff
                bits_type = None
                bits_field_offset = bits_remaining = 0

                field.offset = offset

                if offset is not None:
                    # We're not dynamic, update the structure size and current offset
                    try:
                        field_len = len(field.type)
                    except TypeError:
                        # This field is dynamic
                        offset = None
                        continue

                    offset += field_len

        if align and offset is not None:
            # Add "tail padding" if we need to align
            # This bit magic rounds up to the next alignment boundary
            # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5
            offset += -offset & (alignment - 1)

        # The structure size is whatever the currently calculated offset is
        return offset, alignment

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Read an instance field by field from ``stream`` (the non-compiled path).

        Values read so far are passed as context to each field's reader. Per-field byte sizes (excluding bit
        fields) are stored on the instance as ``_sizes`` and the raw values as ``_values``.
        """
        bit_buffer = BitBuffer(stream, cls.cs.endian)
        struct_start = stream.tell()

        result = {}
        sizes = {}
        for field in cls.__fields__:
            offset = stream.tell()

            if field.offset is not None and offset != struct_start + field.offset:
                # Field is at a specific offset, either aligned or added that way.
                # NOTE(review): bit fields sharing a storage unit all carry the unit's start offset, so continuing a
                # unit seeks back to its start; if the structure ends with such a field the stream may be left before
                # the end of the unit — confirm against BitBuffer's read behavior.
                offset = struct_start + field.offset
                stream.seek(offset)

            if cls.__align__ and field.offset is None:
                # Previous field was dynamically sized and we need to align
                offset += -offset & (field.alignment - 1)
                stream.seek(offset)

            if field.bits:
                if isinstance(field.type, EnumMetaType):
                    value = field.type(bit_buffer.read(field.type.type, field.bits))
                else:
                    value = bit_buffer.read(field.type, field.bits)

                result[field._name] = value
                continue

            bit_buffer.reset()

            value = field.type._read(stream, result)

            sizes[field._name] = stream.tell() - offset
            result[field._name] = value

        if cls.__align__:
            # Align the stream
            stream.seek(-stream.tell() & (cls.alignment - 1), io.SEEK_CUR)

        # Using type.__call__ directly calls the __init__ method of the class
        # This is faster than calling cls() and bypasses the metaclass __call__ method
        obj = type.__call__(cls, **result)
        obj._sizes = sizes
        obj._values = result
        return obj

    def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]:  # type: ignore
        """Read consecutive instances until one evaluates as false; that terminating instance is not included."""
        result = []

        while obj := cls._read(stream, context):
            result.append(obj)

        return result

    def _write(cls, stream: BinaryIO, data: Structure) -> int:
        """Write ``data`` to ``stream``, inserting zero padding for explicit offsets and alignment.

        Missing (``None``) values are written as the field type's default.

        Returns:
            The number of bytes written, including padding and bit field storage units.
        """
        bit_buffer = BitBuffer(stream, cls.cs.endian)
        struct_start = stream.tell()

        for field in cls.__fields__:
            field_type = cls.cs.resolve(field.type)

            # Storage type of a bit field: enum bit fields are stored as their underlying type
            bit_field_type = (
                (field_type.type if isinstance(field_type, EnumMetaType) else field_type) if field.bits else None
            )
            # Current field is not a bit field, but previous was
            # Or, moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8;
            if (not field.bits and bit_buffer._type is not None) or (
                bit_buffer._type and bit_buffer._type != bit_field_type
            ):
                # Flush the current bit buffer so we can process alignment properly
                bit_buffer.flush()

            offset = stream.tell()

            if field.offset is not None and offset < struct_start + field.offset:
                # Field is at a specific offset, either aligned or added that way
                stream.write(b"\x00" * (struct_start + field.offset - offset))
                offset = struct_start + field.offset

            if cls.__align__ and field.offset is None:
                # Compare against the storage type, not the (possibly enum) field type, which never equals it
                is_bitbuffer_boundary = bit_buffer._type and (
                    bit_buffer._remaining == 0 or bit_buffer._type != bit_field_type
                )
                if not bit_buffer._type or is_bitbuffer_boundary:
                    # Previous field was dynamically sized and we need to align
                    align_pad = -offset & (field.alignment - 1)
                    stream.write(b"\x00" * align_pad)
                    offset += align_pad

            value = getattr(data, field._name, None)
            if value is None:
                value = field_type.__default__()

            if field.bits:
                if isinstance(field_type, EnumMetaType):
                    bit_buffer.write(field_type.type, value.value, field.bits)
                else:
                    bit_buffer.write(field_type, value, field.bits)
            else:
                field_type._write(stream, value)

        if bit_buffer._type is not None:
            bit_buffer.flush()

        if cls.__align__:
            # Align the stream
            stream.write(b"\x00" * (-stream.tell() & (cls.alignment - 1)))

        # Count everything that went to the stream (fields, bit fields and padding)
        return stream.tell() - struct_start

    def add_field(cls, name: str, type_: type[BaseType], bits: int | None = None, offset: int | None = None) -> None:
        """Append a new field; the structure is re-committed immediately unless inside :meth:`start_update`."""
        field = Field(name, type_, bits=bits, offset=offset)
        cls.__fields__.append(field)

        if not cls.__updating__:
            cls.commit()

    @contextmanager
    def start_update(cls) -> Iterator[None]:
        """Context manager that batches :meth:`add_field` calls into a single :meth:`commit` on exit."""
        try:
            cls.__updating__ = True
            yield
        finally:
            try:
                cls.commit()
            finally:
                # Always leave update mode, even if committing the new fields fails
                cls.__updating__ = False

    def commit(cls) -> None:
        """Regenerate the attributes derived from ``__fields__`` and apply them to the class."""
        classdict = cls._update_fields(cls.__fields__, cls.__align__)

        for key, value in classdict.items():
            setattr(cls, key, value)

375 

376 

class Structure(BaseType, metaclass=StructureMetaType):
    """Base class for cstruct structure type classes."""

    _values: dict[str, Any]  # Raw field values as read/initialized
    _sizes: dict[str, int]  # Number of bytes each field occupied when read

    def __len__(self) -> int:
        """Return the length of the serialized structure."""
        packed = self.dumps()
        return len(packed)

    def __bytes__(self) -> bytes:
        """Return the serialized structure."""
        return self.dumps()

    def __getitem__(self, item: str) -> Any:
        """Return the value of the field named ``item``."""
        return getattr(self, item)

    def __repr__(self) -> str:
        """Render all fields; plain integers (not pointers or enums) are shown in hex."""
        cls = self.__class__
        rendered = []
        for name, field in cls.fields.items():
            value = self[name]
            plain_int = issubclass(field.type, int) and not issubclass(field.type, (Pointer, Enum))
            rendered.append(f"{name}={hex(value) if plain_int else repr(value)}")

        return f"<{cls.__name__} {' '.join(rendered)}>"

403 

404 

class UnionMetaType(StructureMetaType):
    """Base metaclass for cstruct union type classes."""

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Instantiate the union.

        When called with values (anything other than a single readable stream or buffer argument), the union's
        buffer is rebuilt from the first user-provided member. Without arguments, nested structure values are
        proxified.

        Raises:
            NotImplementedError: When initializing a dynamically sized union with values.
        """
        obj: Union = super().__call__(*args, **kwargs)

        # Calling with non-stream args or kwargs means we are initializing with values
        if (args and not (len(args) == 1 and (_is_readable_type(args[0]) or _is_buffer_type(args[0])))) or kwargs:
            # We don't support user initialization of dynamic unions yet
            if cls.dynamic:
                raise NotImplementedError("Initializing a dynamic union is not yet supported")

            # User (partial) initialization, rebuild the union
            # First user-provided field is the one used to rebuild the union
            arg_fields = (field._name for _, field in zip(args, cls.__fields__))
            kwarg_fields = (name for name in kwargs if name in cls.lookup)
            if (first_field := next(chain(arg_fields, kwarg_fields), None)) is not None:
                obj._rebuild(first_field)
        elif not args and not kwargs:
            # Initialized with default values
            # Note that we proxify here in case we have a default initialization (cls())
            # We don't proxify in case we read from a stream, as we do that later on in _read at a more appropriate time
            # Same with (partial) user initialization, we do that after rebuilding the union
            obj._proxify()

        return obj

    def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]:
        """Calculate the union size (its largest member) and alignment.

        Members with an explicit offset extend the union up to ``offset + len(member)``.

        Returns:
            A ``(size, alignment)`` tuple; ``size`` is ``None`` when any member is dynamically sized.
        """
        size = 0
        alignment = 0

        for field in fields:
            if size is not None:
                try:
                    # Members are read/written at their explicit offset, so it counts towards the union size
                    size = max((field.offset or 0) + len(field.type), size)
                except TypeError:
                    size = None

            alignment = max(field.alignment, alignment)

        if align and size is not None:
            # Add "tail padding" if we need to align
            # This bit magic rounds up to the next alignment boundary
            # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5
            size += -size & (alignment - 1)

        return size, alignment

    def _read_fields(
        cls, stream: BinaryIO, context: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, int]]:
        """Read every member of the union, each starting at the union start (plus its explicit offset).

        Fixed-size unions read their whole size into a private buffer first; dynamically sized unions read
        directly from ``stream``, which is left positioned after the last member read.

        Args:
            stream: The stream to read from, positioned at the start of the union.
            context: Not used; the values read so far are passed as context to each member instead.

        Returns:
            A ``(values, sizes)`` tuple keyed by member name; sizes are the number of bytes each member consumed.
        """
        result = {}
        sizes = {}

        if cls.size is None:
            offset = stream.tell()
            buf = stream
        else:
            offset = 0
            buf = io.BytesIO(stream.read(cls.size))

        for field in cls.__fields__:
            field_type = cls.cs.resolve(field.type)

            start = offset + (field.offset or 0)
            buf.seek(start)
            value = field_type._read(buf, result)

            # Measure from where this member started; ``offset`` is non-zero when a dynamic union is read directly
            # from a stream that is not at position 0
            sizes[field._name] = buf.tell() - start
            result[field._name] = value

        return result, sizes

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Read a union instance from ``stream``, keeping its raw bytes in ``_buf``."""
        if cls.size is None:
            start = stream.tell()
            result, sizes = cls._read_fields(stream, context)
            # A union spans up to the furthest byte consumed by any member, not just the last one read
            size = max(
                ((cls.lookup[name].offset or 0) + length for name, length in sizes.items()),
                default=0,
            )
            stream.seek(start)
            buf = stream.read(size)
        else:
            result = {}
            sizes = {}
            buf = stream.read(cls.size)

        # Create the object and set the values
        # Using type.__call__ directly calls the __init__ method of the class
        # This is faster than calling cls() and bypasses the metaclass __call__ method
        # It also makes it easier to differentiate between user-initialization of the class
        # and initialization from a stream read
        obj: Union = type.__call__(cls, **result)
        object.__setattr__(obj, "_values", result)
        object.__setattr__(obj, "_sizes", sizes)
        object.__setattr__(obj, "_buf", buf)

        if cls.size is not None:
            obj._update()

        # Proxify any nested structures
        obj._proxify()

        return obj

    def _write(cls, stream: BinaryIO, data: Union) -> int:
        """Write ``data`` by serializing its largest regular member and zero-padding to the union size.

        Anonymous structure members are only written when no regular member was written.

        Returns:
            The number of bytes written.

        Raises:
            NotImplementedError: For dynamically sized unions.
        """
        if cls.dynamic:
            raise NotImplementedError("Writing dynamic unions is not yet supported")

        offset = stream.tell()
        expected_offset = offset + len(cls)

        # Sort by largest field
        fields = sorted(cls.__fields__, key=lambda e: e.type.size or 0, reverse=True)
        anonymous_struct = False

        # Try to write by largest field
        for field in fields:
            if isinstance(field.type, StructureMetaType) and field.name is None:
                # Prefer to write regular fields initially
                anonymous_struct = field.type
                continue

            # Write the value
            field.type._write(stream, getattr(data, field._name))
            break

        # If we haven't written anything yet and we initially skipped an anonymous struct, write it now
        if stream.tell() == offset and anonymous_struct:
            anonymous_struct._write(stream, data)

        # If we haven't filled the union size yet, pad it
        if remaining := expected_offset - stream.tell():
            stream.write(b"\x00" * remaining)

        return stream.tell() - offset

542 

543 

class Union(Structure, metaclass=UnionMetaType):
    """Base class for cstruct union type classes.

    All members share one byte buffer (``_buf``). Assigning a member re-encodes it into that buffer and re-reads
    every member from it, so all members stay consistent with each other.
    """

    _buf: bytes  # The raw bytes backing all members of the union

    def __eq__(self, other: object) -> bool:
        """Unions are equal when they are of the same class and serialize to the same bytes."""
        return self.__class__ is other.__class__ and bytes(self) == bytes(other)

    def __setattr__(self, attr: str, value: Any) -> None:
        """Set a member and rebuild the shared buffer from it.

        Raises:
            NotImplementedError: If the union is dynamically sized.
        """
        if self.__class__.dynamic:
            raise NotImplementedError("Modifying a dynamic union is not yet supported")

        super().__setattr__(attr, value)
        self._rebuild(attr)

    def _rebuild(self, attr: str) -> None:
        """Re-encode member ``attr`` into the shared buffer, then refresh all members from it.

        Args:
            attr: A raw member name (a key of ``lookup``, including anonymous structures) or the name of a member
                folded in from an anonymous structure, in which case that anonymous structure is rebuilt.

        Raises:
            KeyError: If ``attr`` is not a member of this union.
        """
        cls = self.__class__
        if (cur_buf := getattr(self, "_buf", None)) is None:
            cur_buf = b"\x00" * cls.size

        buf = io.BytesIO(cur_buf)

        lookup = cls.lookup
        if attr not in lookup:
            # Members folded in from anonymous structures are set through generated properties and do not appear
            # in ``lookup``; rebuild the anonymous structure that owns them instead
            owner = next(
                (f._name for f in lookup.values() if f.name is None and attr in getattr(f.type, "fields", {})),
                None,
            )
            if owner is None:
                raise KeyError(attr)
            attr = owner

        field = lookup[attr]
        if field.offset:
            buf.seek(field.offset)

        if (value := getattr(self, attr)) is None:
            value = field.type.__default__()

        field.type._write(buf, value)

        object.__setattr__(self, "_buf", buf.getvalue())
        self._update()

        # (Re-)proxify all values
        self._proxify()

    def _update(self) -> None:
        """Re-read all members from ``_buf`` and store them as instance attributes."""
        result, sizes = self.__class__._read_fields(io.BytesIO(self._buf))
        self.__dict__.update(result)
        object.__setattr__(self, "_values", result)
        object.__setattr__(self, "_sizes", sizes)

    def _proxify(self) -> None:
        """Wrap nested structure values in :class:`UnionProxy` objects.

        Writes through a proxy, at any nesting depth, rebuild the top-level union member containing the value.
        """

        def _proxy_structure(value: Structure, member: str | None) -> None:
            for field in value.__class__.__fields__:
                if issubclass(field.type, Structure):
                    nested_value = getattr(value, field._name)
                    # Nested values report changes through the top-level member they live in, since only those
                    # are known to this union's ``lookup``
                    root = field._name if member is None else member
                    proxy = UnionProxy(self, root, nested_value)
                    object.__setattr__(value, field._name, proxy)
                    _proxy_structure(nested_value, root)

        _proxy_structure(self, None)

595 

596 

class UnionProxy:
    """Stand-in for a nested structure value inside a union.

    Attribute reads are forwarded to the wrapped ``__target__``. Attribute writes are applied to the target and
    then trigger ``__union__._rebuild(__attr__)``.
    """

    __union__: Union
    __attr__: str
    __target__: Structure

    def __init__(self, union: Union, attr: str, target: Structure):
        # Bypass our own __setattr__, which would otherwise forward these to the target
        for key, val in (("__union__", union), ("__attr__", attr), ("__target__", target)):
            object.__setattr__(self, key, val)

    def __len__(self) -> int:
        serialized = self.__target__.dumps()
        return len(serialized)

    def __bytes__(self) -> bytes:
        target = self.__target__
        return target.dumps()

    def __getitem__(self, item: str) -> Any:
        target = self.__target__
        return getattr(target, item)

    def __repr__(self) -> str:
        target = self.__target__
        return repr(target)

    def __getattr__(self, attr: str) -> Any:
        # Only invoked for names that are not found on the proxy itself
        target = self.__target__
        return getattr(target, attr)

    def __setattr__(self, attr: str, value: Any) -> None:
        target = self.__target__
        setattr(target, attr, value)
        self.__union__._rebuild(self.__attr__)

625 

626 

def attrsetter(path: str) -> Callable[[Any, Any], None]:
    """Return a function assigning a value to a (possibly dotted) attribute path; the counterpart of ``attrgetter``.

    ``attrsetter("a.b.c")(obj, value)`` is equivalent to ``obj.a.b.c = value``.

    Args:
        path: Attribute name, optionally dotted to reach into nested objects.

    Returns:
        A function ``(obj, value) -> None`` performing the assignment.
    """
    parent_path, _, attr = path.rpartition(".")
    # Without a dot there is nothing to traverse; "".split(".") would yield [""] and fail on getattr(obj, "")
    parents = parent_path.split(".") if parent_path else []

    def _func(obj: Any, value: Any) -> None:
        for name in parents:
            obj = getattr(obj, name)
        setattr(obj, attr, value)

    return _func

637 

638 

def _codegen(func: FunctionType) -> FunctionType:
    """Decorator that turns a source-code generator into a cached, per-field-count function factory.

    ``func`` takes a list of field names and returns Python source defining exactly one function. The decorated
    result is called with the *number* of fields instead: it generates the source for the placeholder names
    ``_0``, ``_1``, ..., executes it and returns the resulting function. Results are cached per field count. A
    structure with, say, 10 fields therefore reuses the template built for any earlier 10-field structure. Callers
    then patch the placeholder names with the real field names.

    Inspired by https://github.com/dabeaz/dataklasses.

    Args:
        func: Function mapping a list of field names to the source code of a single function definition.

    Returns:
        An ``lru_cache``-wrapped callable that takes the number of fields and returns the template function.
    """

    def make_func_code(num_fields: int) -> FunctionType:
        placeholders = [f"_{index}" for index in range(num_fields)]
        source = func(placeholders)
        namespace = {}
        # Executed with fresh, empty globals: the template does not see this module's globals
        exec(source, {}, namespace)
        _, generated = namespace.popitem()
        return generated

    make_func_code.__wrapped__ = func
    return lru_cache(make_func_code)

666 

667 

@_codegen
def _make_structure__init__(fields: list[str]) -> str:
    """Generates the source of an ``__init__`` method for a structure with the specified fields.

    Every argument defaults to ``None``. When no value is given, the attribute is set from a global named
    ``_{i}_default``, which is renamed and populated when the template is specialized.

    Args:
        fields: List of field names.
    """
    field_args = ", ".join(f"{field} = None" for field in fields)
    field_init = "\n".join(
        f"    self.{name} = {name} if {name} is not None else _{i}_default" for i, name in enumerate(fields)
    )

    # Parenthesized so the empty fallback actually applies when there are no fields
    code = f"def __init__(self{(', ' + field_args) if field_args else ''}):\n"
    return code + (field_init or "    pass")

682 

683 

@_codegen
def _make_union__init__(fields: list[str]) -> str:
    """Generates the source of an ``__init__`` method that assigns the specified fields with ``object.__setattr__``.

    Every argument defaults to ``None``. When no value is given, the attribute is set from a global named
    ``_{i}_default``, which is renamed and populated when the template is specialized.

    Args:
        fields: List of field names.
    """
    field_args = ", ".join(f"{field} = None" for field in fields)
    field_init = "\n".join(
        f"    object.__setattr__(self, '{name}', {name} if {name} is not None else _{i}_default)"
        for i, name in enumerate(fields)
    )

    # Parenthesized so the empty fallback actually applies when there are no fields
    code = f"def __init__(self{(', ' + field_args) if field_args else ''}):\n"
    return code + (field_init or "    pass")

699 

700 

@_codegen
def _make__eq__(fields: list[str]) -> str:
    """Build the source of an ``__eq__`` method comparing the given fields.

    Two objects compare equal only when they are of the exact same class and all listed fields are equal.

    Args:
        fields: List of field names.
    """
    lhs = "".join(f"self.{name}," for name in fields)
    rhs = "".join(f"other.{name}," for name in fields)

    # In the future this could be a looser check, e.g. an __eq__ on the classes, which compares the fields
    return dedent(
        f"""
        def __eq__(self, other):
            if self.__class__ is other.__class__:
                return ({lhs}) == ({rhs})
            return False
        """
    )

725 

726 

@_codegen
def _make__bool__(fields: list[str]) -> str:
    """Build the source of a ``__bool__`` method that is true when any of the given fields is truthy.

    Args:
        fields: List of field names.
    """
    values = ", ".join(map("self.{}".format, fields))
    return "\n".join(["", "def __bool__(self):", f"    return any([{values}])", ""])

742 

743 

@_codegen
def _make__hash__(fields: list[str]) -> str:
    """Build the source of a ``__hash__`` method over the given fields.

    The field values are hashed as a parenthesized expression; note that a single field is hashed directly, as
    ``(x)`` is not a tuple.

    Args:
        fields: List of field names.
    """
    values = ", ".join(map("self.{}".format, fields))
    return "\n".join(["", "def __hash__(self):", f"    return hash(({values}))", ""])

759 

760 

def _patch_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType:
    """Return a copy of ``func`` whose code names from index ``start`` onward are replaced by ``fields``.

    Args:
        func: The function to patch.
        fields: The replacement names, in the order the template references its placeholders.
        start: Number of leading code names to keep unchanged. Defaults to 0.

    Returns:
        A new function object using the patched code and ``func``'s globals.
    """
    code = func.__code__
    kept = code.co_names[:start]
    patched = code.replace(co_names=kept + tuple(fields))
    return type(func)(patched, func.__globals__)

773 

774 

def _generate_structure__init__(fields: list[Field]) -> FunctionType:
    """Build the ``__init__`` method of a structure from its fields.

    The cached template for ``len(fields)`` arguments is retargeted. Its placeholder arguments and attributes are
    renamed to the real field names, and each placeholder default global is renamed to ``__<name>_default__``.

    Args:
        fields: The :class:`Field` objects to generate arguments for, in declaration order.

    Returns:
        A new function object to install as the class' ``__init__``.
    """
    names = [field._name for field in fields]
    template: FunctionType = _make_structure__init__(len(names))

    # Per field, the template's code names are (default global, attribute to store), in that order
    interleaved = []
    for name in names:
        interleaved += [f"__{name}_default__", name]

    code = template.__code__.replace(co_names=tuple(interleaved), co_varnames=("self", *names))

    # NOTE(review): each default is computed once here and stored in the new function's globals; if
    # ``__default__()`` returns a mutable object, instances initialized without that argument presumably share it
    defaults = {f"__{field._name}_default__": field.type.__default__() for field in fields}
    return type(template)(code, template.__globals__ | defaults, argdefs=template.__defaults__)

792 

793 

def _generate_union__init__(fields: list[Field]) -> FunctionType:
    """Build the ``__init__`` method of a union from its fields.

    The template assigns every member through ``object.__setattr__``. Initialization therefore bypasses the
    union's own ``__setattr__``.

    Args:
        fields: The :class:`Field` objects to generate arguments for, in declaration order.

    Returns:
        A new function object to install as the class' ``__init__``.
    """
    names = [field._name for field in fields]
    template: FunctionType = _make_union__init__(len(names))

    # Member names are string constants in the template (after ``None``); defaults are globals that follow
    # ``object`` and ``__setattr__`` in its code names
    default_globals = [f"__{name}_default__" for name in names]
    code = template.__code__.replace(
        co_consts=(None, *names),
        co_names=("object", "__setattr__", *default_globals),
        co_varnames=("self", *names),
    )

    defaults = {f"__{field._name}_default__": field.type.__default__() for field in fields}
    return type(template)(code, template.__globals__ | defaults, argdefs=template.__defaults__)

812 

813 

def _generate__eq__(fields: list[str]) -> FunctionType:
    """Generate an ``__eq__`` method comparing the specified fields.

    Args:
        fields: List of field names.

    Returns:
        The cached template for ``len(fields)`` fields with its placeholder attribute names replaced by ``fields``
        (the leading ``__class__`` name is kept).
    """
    template = _make__eq__(len(fields))
    return _patch_attributes(template, fields, start=1)

821 

822 

def _generate__bool__(fields: list[str]) -> FunctionType:
    """Generate a ``__bool__`` method that is true when any of the specified fields is truthy.

    Args:
        fields: List of field names.

    Returns:
        The cached template for ``len(fields)`` fields with its placeholder attribute names replaced by ``fields``
        (the leading ``any`` name is kept).
    """
    template = _make__bool__(len(fields))
    return _patch_attributes(template, fields, start=1)

830 

831 

def _generate__hash__(fields: list[str]) -> FunctionType:
    """Generate a ``__hash__`` method over the specified fields.

    Args:
        fields: List of field names.

    Returns:
        The cached template for ``len(fields)`` fields with its placeholder attribute names replaced by ``fields``
        (the leading ``hash`` name is kept).
    """
    template = _make__hash__(len(fields))
    return _patch_attributes(template, fields, start=1)