Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dissect/cstruct/types/structure.py: 59%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

469 statements  

1from __future__ import annotations 

2 

3import io 

4from collections import ChainMap 

5from collections.abc import MutableMapping 

6from contextlib import contextmanager 

7from enum import Enum 

8from functools import lru_cache 

9from itertools import chain 

10from operator import attrgetter 

11from textwrap import dedent 

12from types import FunctionType 

13from typing import TYPE_CHECKING, Any, BinaryIO, Callable 

14 

15from dissect.cstruct.bitbuffer import BitBuffer 

16from dissect.cstruct.types.base import ( 

17 BaseType, 

18 MetaType, 

19 _is_buffer_type, 

20 _is_readable_type, 

21) 

22from dissect.cstruct.types.enum import EnumMetaType 

23from dissect.cstruct.types.pointer import Pointer 

24 

25if TYPE_CHECKING: 

26 from collections.abc import Iterator, Mapping 

27 from types import FunctionType 

28 

29 from typing_extensions import Self 

30 

31 

class Field:
    """Description of one member of a structure or union.

    Attributes:
        name: The declared field name, or ``None`` for an anonymous field.
        _name: ``name`` when it is set, otherwise the ``__name__`` of the field type.
        type: The type of the field.
        bits: Bit width for bit fields, ``None`` otherwise.
        offset: Offset of the field inside its structure, or ``None`` when not (yet) known.
        alignment: ``type_.alignment``, falling back to ``1`` when that value is falsy.
    """

    def __init__(self, name: str | None, type_: type[BaseType], bits: int | None = None, offset: int | None = None):
        type_alignment = type_.alignment

        self.name = name
        # Anonymous fields are tracked internally under the name of their type
        self._name = name if name else type_.__name__
        self.type = type_
        self.bits = bits
        self.offset = offset
        self.alignment = type_alignment if type_alignment else 1

    def __repr__(self) -> str:
        suffix = ""
        if self.bits:
            # Only bit fields show their width
            suffix = f" : {self.bits}"

        return f"<Field {self.name} {self.type.__name__}{suffix}>"

46 

47 

class StructureMetaType(MetaType):
    """Base metaclass for cstruct structure type classes.

    The metaclass owns the field bookkeeping of a structure class: it derives the lookup tables,
    generated dunder methods, size and alignment from ``__fields__`` and implements reading and
    writing of structure instances.
    """

    # TODO: resolve field types in _update_fields, remove resolves elsewhere?

    fields: dict[str, Field]
    """Mapping of field names to :class:`Field` objects, including "folded" fields from anonymous structures."""
    lookup: dict[str, Field]
    """Mapping of "raw" field names to :class:`Field` objects. E.g. holds the anonymous struct and not its fields."""
    __fields__: list[Field]
    """List of :class:`Field` objects for this structure. This is the structures' Single Source Of Truth."""

    # Internal
    __align__: bool
    __anonymous__: bool
    __updating__ = False
    __compiled__ = False
    __static_sizes__: dict[str, int]  # Cache of static sizes by field name

    def __new__(metacls, name: str, bases: tuple[type, ...], classdict: dict[str, Any]) -> Self:  # type: ignore
        """Create a structure class, expanding a ``fields`` list in ``classdict`` into the generated attributes."""
        if (fields := classdict.pop("fields", None)) is not None:
            # _update_fields fills the dict that is passed in, so the return value is not needed here
            metacls._update_fields(metacls, fields, align=classdict.get("__align__", False), classdict=classdict)

        return super().__new__(metacls, name, bases, classdict)

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Instantiate the structure, short-circuiting two cheap cases before deferring to ``MetaType``."""
        if (
            cls.__fields__
            and len(args) == len(cls.__fields__) == 1
            and isinstance(args[0], bytes)
            and issubclass(cls.__fields__[0].type, bytes)
            and len(args[0]) == cls.__fields__[0].type.size
        ):
            # Shortcut for single char/bytes type
            return type.__call__(cls, *args, **kwargs)
        if not args and not kwargs:
            # Default construction: no field was read, so there are no dynamic sizes to record
            obj = type.__call__(cls)
            object.__setattr__(obj, "__dynamic_sizes__", {})
            return obj

        return super().__call__(*args, **kwargs)

    def _update_fields(
        cls, fields: list[Field], align: bool = False, classdict: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Derive all class attributes that depend on the field list.

        This is called both on the metaclass itself (from ``__new__``) and on an existing structure
        class (from ``commit``).

        Args:
            fields: The fields of the structure, in declaration order.
            align: Whether the structure is aligned.
            classdict: Dictionary to store the generated attributes in. It is updated in place; a new
                dictionary is created only when ``None`` is passed.

        Returns:
            The dictionary holding the generated attributes.

        Raises:
            ValueError: If two fields share a name other than ``_``.
        """
        # Explicit None check: ``classdict or {}`` would replace an *empty* caller-supplied dict with a
        # fresh one, and ``__new__`` (which ignores the return value) would lose every generated attribute.
        if classdict is None:
            classdict = {}

        lookup = {}
        raw_lookup = {}
        static_sizes = {}
        for field in fields:
            if field._name in lookup and field._name != "_":
                raise ValueError(f"Duplicate field name: {field._name}")

            if not field.type.dynamic:
                static_sizes[field._name] = field.type.size

            if isinstance(field.type, StructureMetaType) and field.name is None:
                # Anonymous structure: expose its fields on this class through forwarding properties
                for anon_field in field.type.fields.values():
                    attr = f"{field._name}.{anon_field.name}"
                    classdict[anon_field.name] = property(attrgetter(attr), attrsetter(attr))

                lookup.update(field.type.fields)
            else:
                lookup[field._name] = field

            raw_lookup[field._name] = field

        field_names = lookup.keys()
        classdict["fields"] = lookup
        classdict["lookup"] = raw_lookup
        classdict["__fields__"] = fields
        classdict["__static_sizes__"] = static_sizes
        classdict["__bool__"] = _generate__bool__(field_names)

        if issubclass(cls, UnionMetaType) or isinstance(cls, UnionMetaType):
            classdict["__init__"] = _generate_union__init__(raw_lookup.values())
            # Not a great way to do this but it works for now
            classdict["__eq__"] = Union.__eq__
        else:
            classdict["__init__"] = _generate_structure__init__(raw_lookup.values())
            classdict["__eq__"] = _generate__eq__(field_names)

        classdict["__hash__"] = _generate__hash__(field_names)

        # If we're calling this as a class method or a function on the metaclass
        if issubclass(cls, type):
            size, alignment = cls._calculate_size_and_offsets(cls, fields, align)
        else:
            size, alignment = cls._calculate_size_and_offsets(fields, align)

        if cls.__compiled__:
            # If the previous class was compiled try to compile this too
            from dissect.cstruct import compiler  # noqa: PLC0415

            try:
                classdict["_read"] = compiler.Compiler(cls.cs).compile_read(fields, cls.__name__, align=cls.__align__)
                classdict["__compiled__"] = True
            except Exception:
                # Deliberate best effort: revert _read to the slower loop based method
                classdict["_read"] = classmethod(Structure._read.__func__)
                classdict["__compiled__"] = False

        # TODO: compile _write
        # TODO: generate cached_property for lazy reading

        classdict["size"] = size
        classdict["alignment"] = alignment
        classdict["dynamic"] = size is None

        return classdict

    def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]:
        """Iterate all fields in this structure to calculate the field offsets and total structure size.

        The ``offset`` attribute of every field is updated in place. Once a dynamic field is
        encountered, the offsets of all following fields (unless they carry their own offset) and
        the returned size are ``None``.

        Args:
            fields: The fields of the structure, in declaration order.
            align: Whether to align fields and add tail padding.

        Returns:
            A tuple of the structure size (``None`` if dynamic) and the structure alignment.

        Raises:
            ValueError: If a bit field does not fit in the remaining bits of its storage unit.
        """
        # NOTE(review): computed offsets are written back to the fields and treated as leading on the
        # next call. For a bit field that shares a storage unit this rewinds ``offset`` to the start of
        # that unit, which looks like it can misplace a field appended in a later commit - verify.

        # The current offset, set to None if we become dynamic
        offset = 0
        # The current alignment for this structure
        alignment = 0

        # The current bit field type
        bits_type = None
        # The offset of the current bit field, set to None if we become dynamic
        bits_field_offset = 0
        # How many bits we have left in the current bit field
        bits_remaining = 0

        for field in fields:
            if field.offset is not None:
                # If a field already has an offset, it's leading
                offset = field.offset

            if align and offset is not None:
                # Round to next alignment
                offset += -offset & (field.alignment - 1)

            # The alignment of this struct is equal to its largest members' alignment
            alignment = max(alignment, field.alignment)

            if field.bits:
                field_type = field.type

                if isinstance(field_type, EnumMetaType):
                    field_type = field_type.type

                # Bit fields have special logic
                if (
                    # Exhausted a bit field
                    bits_remaining == 0
                    # Moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8;
                    or field_type != bits_type
                    # Still processing a bit field, but it's at a different offset due to alignment or a manual offset.
                    # Offsets are None after a dynamic field: with both unknown we keep filling the current
                    # unit, and when only the unit offset is unknown a known offset starts a new unit.
                    or (
                        bits_type is not None
                        and offset is not None
                        and (bits_field_offset is None or offset > bits_field_offset + bits_type.size)
                    )
                ):
                    # ... if any of this is true, we have to move to the next field
                    bits_type = field_type
                    bits_count = bits_type.size * 8
                    bits_remaining = bits_count
                    bits_field_offset = offset

                    if offset is not None:
                        # We're not dynamic, update the structure size and current offset
                        offset += bits_type.size

                field.offset = bits_field_offset

                bits_remaining -= field.bits

                if bits_remaining < 0:
                    raise ValueError("Straddled bit fields are unsupported")
            else:
                # Reset bits stuff
                bits_type = None
                bits_field_offset = bits_remaining = 0

                field.offset = offset

                if offset is not None:
                    # We're not dynamic, update the structure size and current offset
                    try:
                        field_len = len(field.type)
                    except TypeError:
                        # This field is dynamic
                        offset = None
                        continue

                    offset += field_len

        if align and offset is not None:
            # Add "tail padding" if we need to align
            # This bit magic rounds up to the next alignment boundary
            # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5
            offset += -offset & (alignment - 1)

        # The structure size is whatever the currently calculated offset is
        return offset, alignment

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Read one structure from ``stream`` by reading its fields in order.

        Args:
            stream: The stream to read from; it must support ``tell`` and ``seek``.
            context: Not used by this implementation; fields receive the values read so far instead.

        Returns:
            The structure instance, with the consumed size of every dynamic field recorded in
            ``__dynamic_sizes__``.
        """
        bit_buffer = BitBuffer(stream, cls.cs.endian)
        struct_start = stream.tell()

        result = {}
        sizes = {}
        for field in cls.__fields__:
            offset = stream.tell()

            if field.offset is not None and offset != struct_start + field.offset:
                # Field is at a specific offset, either aligned or added that way
                offset = struct_start + field.offset
                stream.seek(offset)

            if cls.__align__ and field.offset is None:
                # Previous field was dynamically sized and we need to align
                offset += -offset & (field.alignment - 1)
                stream.seek(offset)

            if field.bits:
                if isinstance(field.type, EnumMetaType):
                    value = field.type(bit_buffer.read(field.type.type, field.bits))
                else:
                    value = bit_buffer.read(field.type, field.bits)

                result[field._name] = value
                continue

            bit_buffer.reset()

            # The values read so far are passed along as the context of the field
            value = field.type._read(stream, result)

            result[field._name] = value
            if field.type.dynamic:
                sizes[field._name] = stream.tell() - offset

        if cls.__align__:
            # Align the stream
            stream.seek(-stream.tell() & (cls.alignment - 1), io.SEEK_CUR)

        # Using type.__call__ directly calls the __init__ method of the class
        # This is faster than calling cls() and bypasses the metaclass __call__ method
        obj = type.__call__(cls, **result)
        obj.__dynamic_sizes__ = sizes
        return obj

    def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]:  # type: ignore
        """Read structures from ``stream`` until a falsy one is read; the falsy terminator is not returned."""
        result = []

        while obj := cls._read(stream, context):
            result.append(obj)

        return result

    def _write(cls, stream: BinaryIO, data: Structure) -> int:
        """Write the fields of ``data`` to ``stream``.

        Args:
            stream: The stream to write to; it must support ``tell``.
            data: The structure instance to write. Fields that are ``None`` are written as the
                default value of their type.

        Returns:
            The number of bytes the stream advanced, including padding and flushed bit fields.
        """
        bit_buffer = BitBuffer(stream, cls.cs.endian)
        struct_start = stream.tell()

        for field in cls.__fields__:
            field_type = cls.cs.resolve(field.type)

            bit_field_type = (
                (field_type.type if isinstance(field_type, EnumMetaType) else field_type) if field.bits else None
            )
            # Current field is not a bit field, but previous was
            # Or, moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8;
            if (not field.bits and bit_buffer._type is not None) or (
                bit_buffer._type and bit_buffer._type != bit_field_type
            ):
                # Flush the current bit buffer so we can process alignment properly
                bit_buffer.flush()

            offset = stream.tell()

            if field.offset is not None and offset < struct_start + field.offset:
                # Field is at a specific offset, either aligned or added that way
                stream.write(b"\x00" * (struct_start + field.offset - offset))

            if cls.__align__ and field.offset is None:
                is_bitbuffer_boundary = bit_buffer._type and (
                    bit_buffer._remaining == 0 or bit_buffer._type != field_type
                )
                if not bit_buffer._type or is_bitbuffer_boundary:
                    # Previous field was dynamically sized and we need to align
                    stream.write(b"\x00" * (-offset & (field.alignment - 1)))

            value = getattr(data, field._name, None)
            if value is None:
                value = field_type.__default__()

            if field.bits:
                if isinstance(field_type, EnumMetaType):
                    bit_buffer.write(field_type.type, value.value, field.bits)
                else:
                    bit_buffer.write(field_type, value, field.bits)
            else:
                field_type._write(stream, value)

        if bit_buffer._type is not None:
            bit_buffer.flush()

        if cls.__align__:
            # Align the stream
            stream.write(b"\x00" * (-stream.tell() & (cls.alignment - 1)))

        # Measure from the stream position: summing only the regular field writes would leave out
        # the padding and everything that goes through the bit buffer.
        return stream.tell() - struct_start

    def add_field(cls, name: str, type_: type[BaseType], bits: int | None = None, offset: int | None = None) -> None:
        """Append a field to the structure and commit it, unless a ``start_update`` block is active."""
        field = Field(name, type_, bits=bits, offset=offset)
        cls.__fields__.append(field)

        if not cls.__updating__:
            cls.commit()

    @contextmanager
    def start_update(cls) -> Iterator[None]:
        """Context manager that postpones ``commit`` until the end of the block.

        The commit also runs when the block raises. The updating flag is always cleared, even if
        the commit itself fails.
        """
        cls.__updating__ = True
        try:
            yield
        finally:
            try:
                cls.commit()
            finally:
                # Never leave the class stuck in "updating" mode, or later add_field calls stop committing
                cls.__updating__ = False

    def commit(cls) -> None:
        """Recalculate everything derived from ``__fields__`` and store it on the class."""
        classdict = cls._update_fields(cls.__fields__, cls.__align__)

        for key, value in classdict.items():
            setattr(cls, key, value)

382 

383 

class Structure(BaseType, metaclass=StructureMetaType):
    """Base class for cstruct structure type classes.

    Assigning an attribute that is not a field of the structure results in undefined behavior:
    for performance reasons attribute writes are not checked against the field list.
    """

    __dynamic_sizes__: dict[str, int]

    def __len__(self) -> int:
        serialized = self.dumps()
        return len(serialized)

    def __bytes__(self) -> bytes:
        return self.dumps()

    def __getitem__(self, item: str) -> Any:
        # Item access is plain attribute access by name
        return getattr(self, item)

    def __repr__(self) -> str:
        parts = []
        for name, field in self.__class__.fields.items():
            value = self[name]
            # Integer types are shown in hex, except for pointers and enums which use their own repr
            shown_as_hex = issubclass(field.type, int) and not issubclass(field.type, (Pointer, Enum))
            text = hex(value) if shown_as_hex else repr(value)
            parts.append(f"{name}={text}")

        return f"<{self.__class__.__name__} {' '.join(parts)}>"

    @property
    def __values__(self) -> MutableMapping[str, Any]:
        """Mapping style view on the field values of this structure."""
        return StructureValuesProxy(self)

    @property
    def __sizes__(self) -> Mapping[str, int | None]:
        """Field sizes by name: the static sizes of the class chained with the dynamic sizes of this instance."""
        static_sizes = self.__class__.__static_sizes__
        return ChainMap(static_sizes, self.__dynamic_sizes__)

421 

422 

class StructureValuesProxy(MutableMapping):
    """Mutable mapping view on the field values of a structure.

    The keys are the names in the ``fields`` mapping of the structure class. Values are read from
    and written to the attributes of the wrapped structure.
    """

    def __init__(self, struct: Structure):
        self._struct: Structure = struct

    def __getitem__(self, key: str) -> Any:
        if key not in self:
            raise KeyError(key)
        return getattr(self._struct, key)

    def __setitem__(self, key: str, value: Any) -> None:
        if key not in self:
            raise KeyError(key)
        return setattr(self._struct, key, value)

    def __contains__(self, key: str) -> bool:
        known_fields = self._struct.__class__.fields
        return key in known_fields

    def __iter__(self) -> Iterator[str]:
        known_fields = self._struct.__class__.fields
        return iter(known_fields)

    def __len__(self) -> int:
        known_fields = self._struct.__class__.fields
        return len(known_fields)

    def __repr__(self) -> str:
        snapshot = dict(self)
        return repr(snapshot)

    def __delitem__(self, _: str):
        # Required by the abstract base class, but fields cannot be removed from a structure.
        raise NotImplementedError("Cannot delete fields from a Structure")

454 

455 

class UnionMetaType(StructureMetaType):
    """Base metaclass for cstruct union type classes."""

    def __call__(cls, *args, **kwargs) -> Self:  # type: ignore
        """Instantiate the union and bring its buffer and nested proxies in line with how it was created.

        Raises:
            NotImplementedError: If a dynamic union is initialized with user-provided values.
        """
        obj: Union = super().__call__(*args, **kwargs)

        # Calling with non-stream args or kwargs means we are initializing with values
        if (args and not (len(args) == 1 and (_is_readable_type(args[0]) or _is_buffer_type(args[0])))) or kwargs:
            # We don't support user initialization of dynamic unions yet
            if cls.dynamic:
                raise NotImplementedError("Initializing a dynamic union is not yet supported")

            # User (partial) initialization, rebuild the union
            # First user-provided field is the one used to rebuild the union
            arg_fields = (field._name for _, field in zip(args, cls.__fields__))
            kwarg_fields = (name for name in kwargs if name in cls.lookup)
            if (first_field := next(chain(arg_fields, kwarg_fields), None)) is not None:
                obj._rebuild(first_field)
        elif not args and not kwargs:
            # Initialized with default values
            # Note that we proxify here in case we have a default initialization (cls())
            # We don't proxify in case we read from a stream, as we do that later on in _read at a more appropriate time
            # Same with (partial) user initialization, we do that after rebuilding the union
            obj._proxify()

        return obj

    def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]:
        """Calculate the size and alignment of the union.

        Args:
            fields: The fields of the union.
            align: Whether to add tail padding up to the alignment of the union.

        Returns:
            A tuple of the size of the largest field (``None`` if any field is dynamic) and the
            largest field alignment.
        """
        size = 0
        alignment = 0

        for field in fields:
            if size is not None:
                try:
                    size = max(len(field.type), size)
                except TypeError:
                    # A field without a length makes the whole union dynamic
                    size = None

            alignment = max(field.alignment, alignment)

        if align and size is not None:
            # Add "tail padding" if we need to align
            # This bit magic rounds up to the next alignment boundary
            # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5
            size += -size & (alignment - 1)

        return size, alignment

    def _read_fields(
        cls, stream: BinaryIO, context: dict[str, Any] | None = None
    ) -> tuple[dict[str, Any], dict[str, int]]:
        """Read every field of the union from the same starting position.

        A statically sized union is read from an in-memory copy of ``cls.size`` bytes; a dynamic
        union is read directly from ``stream``.

        Args:
            stream: The stream to read from.
            context: Not used by this implementation; fields receive the values read so far instead.

        Returns:
            A tuple of the field values by name and the consumed sizes of the dynamic fields by name.
        """
        result = {}
        sizes = {}

        if cls.size is None:
            offset = stream.tell()
            buf = stream
        else:
            offset = 0
            buf = io.BytesIO(stream.read(cls.size))

        for field in cls.__fields__:
            field_type = cls.cs.resolve(field.type)

            # Absolute position of this field in ``buf``
            start = offset + (field.offset if field.offset is not None else 0)

            buf.seek(start)
            value = field_type._read(buf, result)

            result[field._name] = value
            if field.type.dynamic:
                # Measure from where the field actually started. Using only the field offset gives
                # the wrong size when a dynamic union is read from a non-zero stream position.
                sizes[field._name] = buf.tell() - start

        return result, sizes

    def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:  # type: ignore
        """Read a union from ``stream`` and keep the raw bytes it was read from on the instance."""
        if cls.size is None:
            # Dynamic union: read the fields first to learn the size, then grab the raw bytes
            start = stream.tell()
            result, sizes = cls._read_fields(stream, context)
            size = stream.tell() - start
            stream.seek(start)
            buf = stream.read(size)
        else:
            # Static union: the field values are filled in from the buffer by _update below
            result = {}
            sizes = {}
            buf = stream.read(cls.size)

        # Create the object and set the values
        # Using type.__call__ directly calls the __init__ method of the class
        # This is faster than calling cls() and bypasses the metaclass __call__ method
        # It also makes it easier to differentiate between user-initialization of the class
        # and initialization from a stream read
        obj: Union = type.__call__(cls, **result)
        object.__setattr__(obj, "__dynamic_sizes__", sizes)
        object.__setattr__(obj, "_buf", buf)

        if cls.size is not None:
            obj._update()

        # Proxify any nested structures
        obj._proxify()

        return obj

    def _write(cls, stream: BinaryIO, data: Union) -> int:
        """Write a union to ``stream`` using its largest regular field, padded to the union size.

        Args:
            stream: The stream to write to; it must support ``tell``.
            data: The union instance to write.

        Returns:
            The number of bytes the stream advanced.

        Raises:
            NotImplementedError: If the union is dynamic.
        """
        if cls.dynamic:
            raise NotImplementedError("Writing dynamic unions is not yet supported")

        offset = stream.tell()
        expected_offset = offset + len(cls)

        # Sort by largest field
        fields = sorted(cls.__fields__, key=lambda e: e.type.size or 0, reverse=True)
        anonymous_struct = False

        # Try to write by largest field
        for field in fields:
            if isinstance(field.type, StructureMetaType) and field.name is None:
                # Prefer to write regular fields initially
                anonymous_struct = field.type
                continue

            # Write the value
            field.type._write(stream, getattr(data, field._name))
            break

        # If we haven't written anything yet and we initially skipped an anonymous struct, write it now
        if stream.tell() == offset and anonymous_struct:
            anonymous_struct._write(stream, data)

        # If we haven't filled the union size yet, pad it
        if remaining := expected_offset - stream.tell():
            stream.write(b"\x00" * remaining)

        return stream.tell() - offset

593 

594 

class Union(Structure, metaclass=UnionMetaType):
    """Base class for cstruct union type classes.

    A union keeps the raw bytes of its value in ``_buf``. Every attribute write re-serializes the
    written field into that buffer and re-reads all fields from it.
    """

    _buf: bytes

    def __eq__(self, other: object) -> bool:
        if self.__class__ is not other.__class__:
            return False
        # Same union type: equal when the serialized forms are equal
        return bytes(self) == bytes(other)

    def __setattr__(self, attr: str, value: Any) -> None:
        if self.__class__.dynamic:
            raise NotImplementedError("Modifying a dynamic union is not yet supported")

        super().__setattr__(attr, value)
        self._rebuild(attr)

    def _rebuild(self, attr: str) -> None:
        """Write the current value of field ``attr`` into the buffer and refresh all fields from it."""
        current = getattr(self, "_buf", None)
        if current is None:
            # No buffer yet, start from a zero filled one of the union size
            current = b"\x00" * self.__class__.size

        scratch = io.BytesIO(current)
        field = self.__class__.lookup[attr]
        if field.offset:
            scratch.seek(field.offset)

        value = getattr(self, attr)
        if value is None:
            value = field.type.__default__()

        field.type._write(scratch, value)

        object.__setattr__(self, "_buf", scratch.getvalue())
        self._update()

        # (Re-)proxify all values
        self._proxify()

    def _update(self) -> None:
        """Re-read all fields from the buffer and store them directly in the instance dictionary."""
        values, dynamic_sizes = self.__class__._read_fields(io.BytesIO(self._buf))
        self.__dict__.update(values)
        object.__setattr__(self, "__dynamic_sizes__", dynamic_sizes)

    def _proxify(self) -> None:
        """Wrap every nested structure value, at any depth, in a :class:`UnionProxy` tied to this union."""

        def _wrap_nested(owner: Structure) -> None:
            for field in owner.__class__.__fields__:
                if not issubclass(field.type, Structure):
                    continue

                inner = getattr(owner, field._name)
                # object.__setattr__ stores the proxy without triggering a rebuild
                object.__setattr__(owner, field._name, UnionProxy(self, field._name, inner))
                _wrap_nested(inner)

        _wrap_nested(self)

645 

646 

class UnionProxy:
    """Proxy around a structure value nested in a union.

    Reads are forwarded to the wrapped structure. Attribute writes are forwarded as well, after
    which the owning union is asked to rebuild itself from the proxied field.

    Attributes:
        __union__: The union that owns the proxied value.
        __attr__: The field name that is passed to ``_rebuild`` of the union after a write.
        __target__: The wrapped structure.
    """

    __union__: Union
    __attr__: str
    __target__: Structure

    def __init__(self, union: Union, attr: str, target: Structure):
        # Use object.__setattr__ because our own __setattr__ forwards to the target
        object.__setattr__(self, "__union__", union)
        object.__setattr__(self, "__attr__", attr)
        object.__setattr__(self, "__target__", target)

    def __len__(self) -> int:
        return len(self.__target__.dumps())

    def __bytes__(self) -> bytes:
        return self.__target__.dumps()

    def __getitem__(self, item: str) -> Any:
        return getattr(self.__target__, item)

    def __repr__(self) -> str:
        return repr(self.__target__)

    def __getattr__(self, attr: str) -> Any:
        # __getattr__ only runs when normal lookup failed. If one of our own slots is missing the
        # proxy was created without __init__ (e.g. by ``copy.copy``, which probes ``__setstate__``
        # on a blank instance). Forwarding would then recurse forever, so fail like a regular
        # missing attribute instead.
        if attr in ("__union__", "__attr__", "__target__"):
            raise AttributeError(attr)

        return getattr(self.__target__, attr)

    def __setattr__(self, attr: str, value: Any) -> None:
        setattr(self.__target__, attr, value)
        # Keep the union buffer in sync with the modified nested structure
        self.__union__._rebuild(self.__attr__)

675 

676 

def attrsetter(path: str) -> Callable[[Any, Any], None]:
    """Return a callable that assigns a value to a (dotted) attribute path.

    This is the writing counterpart of :func:`operator.attrgetter`: ``attrsetter("a.b.c")(obj, 1)``
    performs ``obj.a.b.c = 1`` and ``attrsetter("c")(obj, 1)`` performs ``obj.c = 1``.

    Args:
        path: The attribute name, optionally prefixed with a dotted path of parent attributes.

    Returns:
        A function taking the root object and the value to assign.
    """
    parent_path, _, attr = path.rpartition(".")
    # Without a dot there are no parents to walk. ``"".split(".")`` would yield ``[""]``,
    # which is not a valid attribute name.
    parents = parent_path.split(".") if parent_path else []

    def _func(obj: Any, value: Any) -> None:
        for name in parents:
            obj = getattr(obj, name)
        setattr(obj, attr, value)

    return _func

687 

688 

689def _codegen(func: FunctionType) -> FunctionType: 

690 """Decorator that generates a template function with a specified number of fields. 

691 

692 This code is a little complex but allows use to cache generated functions for a specific number of fields. 

693 For example, if we generate a structure with 10 fields, we can cache the generated code for that structure. 

694 We can then reuse that code and patch it with the correct field names when we create a new structure with 10 fields. 

695 

696 The functions that are decorated with this decorator should take a list of field names and return a string of code. 

697 The decorated function is needs to be called with the number of fields, instead of the field names. 

698 The confusing part is that that the original function takes field names, but you then call it with 

699 the number of fields instead. 

700 

701 Inspired by https://github.com/dabeaz/dataklasses. 

702 

703 Args: 

704 func: The decorated function that takes a list of field names and returns a string of code. 

705 

706 Returns: 

707 A cached function that generates the desired function code, to be called with the number of fields. 

708 """ 

709 

710 def make_func_code(num_fields: int) -> FunctionType: 

711 exec(func([f"_{n}" for n in range(num_fields)]), {}, d := {}) 

712 return d.popitem()[1] 

713 

714 make_func_code.__wrapped__ = func 

715 return lru_cache(make_func_code) 

716 

717 

@_codegen
def _make_structure__init__(fields: list[str]) -> str:
    """Generates the source of an ``__init__`` method for a structure with the specified fields.

    Every field becomes a parameter that defaults to ``None``. A field that is left at ``None`` is
    assigned the value of the global ``<field>_default`` instead.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__init__`` function.
    """
    field_args = ", ".join(f"{field} = None" for field in fields)
    field_init = "\n".join(
        f" self.{name} = {name} if {name} is not None else _{i}_default" for i, name in enumerate(fields)
    )

    # Only add the separator when there are parameters. The previous ``', ' + field_args or ''``
    # never reached its fallback and emitted ``def __init__(self, ):`` for an empty field list.
    signature = f"self, {field_args}" if field_args else "self"

    code = f"def __init__({signature}):\n"
    return code + (field_init or " pass")

732 

733 

@_codegen
def _make_union__init__(fields: list[str]) -> str:
    """Generates the source of an ``__init__`` method for a union with the specified fields.

    Unlike the structure variant, the values are stored with ``object.__setattr__``, which bypasses
    any ``__setattr__`` defined on the class. A field that is left at ``None`` is assigned the value
    of the global ``<field>_default`` instead.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__init__`` function.
    """
    field_args = ", ".join(f"{field} = None" for field in fields)
    field_init = "\n".join(
        f" object.__setattr__(self, '{name}', {name} if {name} is not None else _{i}_default)"
        for i, name in enumerate(fields)
    )

    # Only add the separator when there are parameters. The previous ``', ' + field_args or ''``
    # never reached its fallback and emitted ``def __init__(self, ):`` for an empty field list.
    signature = f"self, {field_args}" if field_args else "self"

    code = f"def __init__({signature}):\n"
    return code + (field_init or " pass")

749 

750 

@_codegen
def _make__eq__(fields: list[str]) -> str:
    """Generates the source of an ``__eq__`` method for a class with the specified fields.

    The generated method returns ``False`` for objects of another class, and otherwise compares a
    tuple of all field values of both objects.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__eq__`` function.
    """
    # Every element carries its own trailing comma, so a single field still forms a tuple
    # and no fields form the empty tuple ``()``
    own_values = "".join(f"self.{name}," for name in fields)
    their_values = "".join(f"other.{name}," for name in fields)

    # In the future this could be a looser check, e.g. an __eq__ on the classes, which compares the fields
    lines = [
        "",
        "def __eq__(self, other):",
        "    if self.__class__ is other.__class__:",
        f"        return ({own_values}) == ({their_values})",
        "    return False",
        "",
    ]
    return "\n".join(lines)

775 

776 

@_codegen
def _make__bool__(fields: list[str]) -> str:
    """Generates the source of a ``__bool__`` method for a class with the specified fields.

    The generated method is truthy when at least one of the field values is truthy.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__bool__`` function.
    """
    attributes = [f"self.{name}" for name in fields]
    listing = ", ".join(attributes)

    lines = [
        "",
        "def __bool__(self):",
        f"    return any([{listing}])",
        "",
    ]
    return "\n".join(lines)

792 

793 

@_codegen
def _make__hash__(fields: list[str]) -> str:
    """Generates the source of a ``__hash__`` method for a class with the specified fields.

    The generated method hashes the parenthesized, comma separated list of all field values.

    Args:
        fields: List of field names.

    Returns:
        The source code of the ``__hash__`` function.
    """
    attributes = [f"self.{name}" for name in fields]
    listing = ", ".join(attributes)

    lines = [
        "",
        "def __hash__(self):",
        f"    return hash(({listing}))",
        "",
    ]
    return "\n".join(lines)

809 

810 

811def _patch_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType: 

812 """Patches a function's attributes. 

813 

814 Args: 

815 func: The function to patch. 

816 fields: List of field names to add. 

817 start: The starting index for patching. Defaults to 0. 

818 """ 

819 return type(func)( 

820 func.__code__.replace(co_names=(*func.__code__.co_names[:start], *fields)), 

821 func.__globals__, 

822 ) 

823 

824 

def _generate_structure__init__(fields: list[Field]) -> FunctionType:
    """Generates an ``__init__`` method for a structure with the specified fields.

    The cached template for this number of fields is copied, with its placeholder names replaced
    by the actual field names.

    Args:
        fields: The :class:`Field` objects of the structure.

    Returns:
        The ``__init__`` function for the structure.
    """
    renames = _generate_co_mapping(fields)

    template: FunctionType = _make_structure__init__(len(fields))
    template_code = template.__code__
    patched_code = template_code.replace(
        co_names=_remap_co_values(template_code.co_names, renames),
        co_varnames=_remap_co_values(template_code.co_varnames, renames),
    )

    # The default of every field is created once, here, and looked up as a global by the generated code
    defaults = {f"__{field._name}_default__": field.type.__default__() for field in fields}

    function_type = type(template)
    return function_type(patched_code, template.__globals__ | defaults, argdefs=template.__defaults__)

842 

843 

def _generate_union__init__(fields: list[Field]) -> FunctionType:
    """Generates an ``__init__`` method for a union with the specified fields.

    The cached template for this number of fields is copied, with its placeholder names replaced
    by the actual field names. The constants of the template are remapped too.

    Args:
        fields: The :class:`Field` objects of the union.

    Returns:
        The ``__init__`` function for the union.
    """
    renames = _generate_co_mapping(fields)

    template: FunctionType = _make_union__init__(len(fields))
    template_code = template.__code__
    patched_code = template_code.replace(
        co_consts=_remap_co_values(template_code.co_consts, renames),
        co_names=_remap_co_values(template_code.co_names, renames),
        co_varnames=_remap_co_values(template_code.co_varnames, renames),
    )

    # The default of every field is created once, here, and looked up as a global by the generated code
    defaults = {f"__{field._name}_default__": field.type.__default__() for field in fields}

    function_type = type(template)
    return function_type(patched_code, template.__globals__ | defaults, argdefs=template.__defaults__)

862 

863 

864def _generate_co_mapping(fields: list[Field]) -> dict[str, str]: 

865 """Generates a mapping of generated code object names to field names. 

866 

867 The generated code uses names like ``_0``, ``_1``, etc. for fields, and ``_0_default``, ``_1_default``, etc. 

868 for default initializer values. Return a mapping of these names to the actual field names. 

869 

870 Args: 

871 fields: List of field names. 

872 """ 

873 return { 

874 key: value 

875 for i, field in enumerate(fields) 

876 for key, value in [(f"_{i}", field._name), (f"_{i}_default", f"__{field._name}_default__")] 

877 } 

878 

879 

880def _remap_co_values(value: tuple[Any, ...], mapping: dict[str, str]) -> tuple[Any, ...]: 

881 """Remaps code object values using a mapping. 

882 

883 This is used to replace generated code object names with actual field names. 

884 

885 Args: 

886 value: The original code object values. 

887 mapping: A mapping of generated code object names to field names. 

888 """ 

889 # Only attempt to remap if the value is a string, otherwise return it as is 

890 # This is to avoid issues with trying to remap non-hashable types, and we only need to replace strings anyway 

891 return tuple(mapping.get(v, v) if isinstance(v, str) else v for v in value) 

892 

893 

def _generate__eq__(fields: list[str]) -> FunctionType:
    """Generates an ``__eq__`` method for a class with the specified fields.

    Args:
        fields: List of field names.

    Returns:
        The cached ``__eq__`` template for this number of fields, patched with the field names.
    """
    template = _make__eq__(len(fields))
    # Starting at 1 keeps the first name referenced by the template and replaces all others
    return _patch_attributes(template, fields, 1)

901 

902 

def _generate__bool__(fields: list[str]) -> FunctionType:
    """Generates a ``__bool__`` method for a class with the specified fields.

    Args:
        fields: List of field names.

    Returns:
        The cached ``__bool__`` template for this number of fields, patched with the field names.
    """
    template = _make__bool__(len(fields))
    # Starting at 1 keeps the first name referenced by the template and replaces all others
    return _patch_attributes(template, fields, 1)

910 

911 

def _generate__hash__(fields: list[str]) -> FunctionType:
    """Generates a ``__hash__`` method for a class with the specified fields.

    Args:
        fields: List of field names.

    Returns:
        The cached ``__hash__`` template for this number of fields, patched with the field names.
    """
    template = _make__hash__(len(fields))
    # Starting at 1 keeps the first name referenced by the template and replaces all others
    return _patch_attributes(template, fields, 1)