Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/betterproto/__init__.py: 38%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

918 statements  

1from __future__ import annotations 

2 

3import dataclasses 

4import enum as builtin_enum 

5import json 

6import math 

7import struct 

8import sys 

9import typing 

10import warnings 

11from abc import ABC 

12from base64 import ( 

13 b64decode, 

14 b64encode, 

15) 

16from copy import deepcopy 

17from datetime import ( 

18 datetime, 

19 timedelta, 

20 timezone, 

21) 

22from io import BytesIO 

23from itertools import count 

24from typing import ( 

25 TYPE_CHECKING, 

26 Any, 

27 Callable, 

28 ClassVar, 

29 Dict, 

30 Generator, 

31 Iterable, 

32 Mapping, 

33 Optional, 

34 Set, 

35 Tuple, 

36 Type, 

37 Union, 

38 get_type_hints, 

39) 

40 

41from dateutil.parser import isoparse 

42from typing_extensions import Self 

43 

44from ._types import T 

45from ._version import __version__ 

46from .casing import ( 

47 camel_case, 

48 safe_snake_case, 

49 snake_case, 

50) 

51from .enum import Enum as Enum 

52from .grpc.grpclib_client import ServiceStub as ServiceStub 

53from .utils import ( 

54 classproperty, 

55 hybridmethod, 

56) 

57 

58 

59if TYPE_CHECKING: 

60 from _typeshed import ( 

61 SupportsRead, 

62 SupportsWrite, 

63 ) 

64 

65if sys.version_info >= (3, 10): 

66 from types import UnionType as _types_UnionType 

67else: 

68 

69 class _types_UnionType: 

70 ... 

71 

72 

# Proto 3 data types
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
TYPE_INT32 = "int32"
TYPE_INT64 = "int64"
TYPE_UINT32 = "uint32"
TYPE_UINT64 = "uint64"
TYPE_SINT32 = "sint32"
TYPE_SINT64 = "sint64"
TYPE_FLOAT = "float"
TYPE_DOUBLE = "double"
TYPE_FIXED32 = "fixed32"
TYPE_SFIXED32 = "sfixed32"
TYPE_FIXED64 = "fixed64"
TYPE_SFIXED64 = "sfixed64"
TYPE_STRING = "string"
TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message"
TYPE_MAP = "map"

# Fields that use a fixed amount of space (4 or 8 bytes)
FIXED_TYPES = [
    TYPE_FLOAT,
    TYPE_DOUBLE,
    TYPE_FIXED32,
    TYPE_SFIXED32,
    TYPE_FIXED64,
    TYPE_SFIXED64,
]

# Fields that are numerical 64-bit types
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]

# Fields that are efficiently packed when they appear in a repeated field: all
# items are encoded back to back inside a single length-delimited record
# instead of repeating the tag for every item.
PACKED_TYPES = [
    TYPE_ENUM,
    TYPE_BOOL,
    TYPE_INT32,
    TYPE_INT64,
    TYPE_UINT32,
    TYPE_UINT64,
    TYPE_SINT32,
    TYPE_SINT64,
    TYPE_FLOAT,
    TYPE_DOUBLE,
    TYPE_FIXED32,
    TYPE_SFIXED32,
    TYPE_FIXED64,
    TYPE_SFIXED64,
]

# Wire types
# https://developers.google.com/protocol-buffers/docs/encoding#structure
WIRE_VARINT = 0
WIRE_FIXED_64 = 1
WIRE_LEN_DELIM = 2
WIRE_FIXED_32 = 5

# Mappings of which Proto 3 types correspond to which wire types.
WIRE_VARINT_TYPES = [
    TYPE_ENUM,
    TYPE_BOOL,
    TYPE_INT32,
    TYPE_INT64,
    TYPE_UINT32,
    TYPE_UINT64,
    TYPE_SINT32,
    TYPE_SINT64,
]

WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]

# Indicator of message delimitation in streams: passed as ``delimit`` to
# ``Message.dump`` to request a varint size prefix before the message.
SIZE_DELIMITED = -1

149 

150 

# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen() -> datetime:
    """Return a new timezone-aware ``datetime`` for 1970-01-01T00:00:00 UTC."""
    epoch = datetime(year=1970, month=1, day=1, tzinfo=timezone.utc)
    return epoch


# Module-level copy of the epoch, computed once at import time.
DATETIME_ZERO = datetime_default_gen()

157 

158 

# Special protobuf JSON doubles: the string spellings used in JSON for the
# non-finite floating point values.
INFINITY = "Infinity"  # +inf
NEG_INFINITY = "-Infinity"  # -inf
NAN = "NaN"  # not-a-number

163 

164 

class Casing(builtin_enum.Enum):
    """Casing constants for serialization."""

    CAMEL = camel_case  #: A camelCase serialization function.
    SNAKE = snake_case  #: A snake_case serialization function.

170 

171 

# Unique sentinel stored in message fields that have not been assigned yet.
# It is always compared by identity (``is``), never by equality.
PLACEHOLDER: Any = object()

173 

174 

@dataclasses.dataclass(frozen=True)
class FieldMetadata:
    """Stores internal metadata used for parsing & serialization.

    An instance is attached to every message dataclass field under the
    ``"betterproto"`` metadata key (see :func:`dataclass_field`).
    """

    # Protobuf field number
    number: int
    # Protobuf type name (one of the ``TYPE_*`` constants)
    proto_type: str
    # Map information if the proto_type is a map: (key type, value type)
    map_types: Optional[Tuple[str, str]] = None
    # Groups several "one-of" fields together
    group: Optional[str] = None
    # Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
    wraps: Optional[str] = None
    # Is the field optional (proto3 explicit presence)
    optional: bool = False

    @staticmethod
    def get(field: dataclasses.Field) -> "FieldMetadata":
        """Returns the field metadata for a dataclass field.

        Raises ``KeyError`` if *field* carries no ``"betterproto"`` metadata entry.
        """
        return field.metadata["betterproto"]

196 

197 

def dataclass_field(
    number: int,
    proto_type: str,
    *,
    map_types: Optional[Tuple[str, str]] = None,
    group: Optional[str] = None,
    wraps: Optional[str] = None,
    optional: bool = False,
) -> dataclasses.Field:
    """Creates a dataclass field with attached protobuf metadata.

    Optional (explicit-presence) fields default to ``None``; every other field
    defaults to the ``PLACEHOLDER`` sentinel, which marks it as "not assigned".
    """
    field_meta = FieldMetadata(
        number=number,
        proto_type=proto_type,
        map_types=map_types,
        group=group,
        wraps=wraps,
        optional=optional,
    )
    initial = None if optional else PLACEHOLDER
    return dataclasses.field(default=initial, metadata={"betterproto": field_meta})

216 

217 

218# Note: the fields below return `Any` to prevent type errors in the generated 

219# data classes since the types won't match with `Field` and they get swapped 

220# out at runtime. The generated dataclass variables are still typed correctly. 

221 

222 

def enum_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``enum`` value."""
    return dataclass_field(
        number,
        TYPE_ENUM,
        optional=optional,
        group=group,
    )

225 

226 

def bool_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``bool`` value."""
    return dataclass_field(
        number,
        TYPE_BOOL,
        optional=optional,
        group=group,
    )

229 

230 

def int32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``int32`` value."""
    return dataclass_field(
        number,
        TYPE_INT32,
        optional=optional,
        group=group,
    )

235 

236 

def int64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``int64`` value."""
    return dataclass_field(
        number,
        TYPE_INT64,
        optional=optional,
        group=group,
    )

241 

242 

def uint32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``uint32`` value."""
    return dataclass_field(
        number,
        TYPE_UINT32,
        optional=optional,
        group=group,
    )

247 

248 

def uint64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``uint64`` value."""
    return dataclass_field(
        number,
        TYPE_UINT64,
        optional=optional,
        group=group,
    )

253 

254 

def sint32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``sint32`` (zig-zag) value."""
    return dataclass_field(
        number,
        TYPE_SINT32,
        optional=optional,
        group=group,
    )

259 

260 

def sint64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``sint64`` (zig-zag) value."""
    return dataclass_field(
        number,
        TYPE_SINT64,
        optional=optional,
        group=group,
    )

265 

266 

def float_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``float`` (32-bit) value."""
    return dataclass_field(
        number,
        TYPE_FLOAT,
        optional=optional,
        group=group,
    )

271 

272 

def double_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``double`` (64-bit) value."""
    return dataclass_field(
        number,
        TYPE_DOUBLE,
        optional=optional,
        group=group,
    )

277 

278 

def fixed32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``fixed32`` value."""
    return dataclass_field(
        number,
        TYPE_FIXED32,
        optional=optional,
        group=group,
    )

283 

284 

def fixed64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``fixed64`` value."""
    return dataclass_field(
        number,
        TYPE_FIXED64,
        optional=optional,
        group=group,
    )

289 

290 

def sfixed32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``sfixed32`` value."""
    return dataclass_field(
        number,
        TYPE_SFIXED32,
        optional=optional,
        group=group,
    )

295 

296 

def sfixed64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``sfixed64`` value."""
    return dataclass_field(
        number,
        TYPE_SFIXED64,
        optional=optional,
        group=group,
    )

301 

302 

def string_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``string`` value."""
    return dataclass_field(
        number,
        TYPE_STRING,
        optional=optional,
        group=group,
    )

307 

308 

def bytes_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create a dataclass field for a protobuf ``bytes`` value."""
    return dataclass_field(
        number,
        TYPE_BYTES,
        optional=optional,
        group=group,
    )

313 

314 

def message_field(
    number: int,
    group: Optional[str] = None,
    wraps: Optional[str] = None,
    optional: bool = False,
) -> Any:
    """Create a dataclass field holding a nested message.

    *wraps* names the wrapped type when the field is a well-known wrapper
    message (e.g. ``google.protobuf.BoolValue``).
    """
    return dataclass_field(
        number,
        TYPE_MESSAGE,
        optional=optional,
        wraps=wraps,
        group=group,
    )

324 

325 

def map_field(
    number: int, key_type: str, value_type: str, group: Optional[str] = None
) -> Any:
    """Create a dataclass field for a protobuf ``map<key_type, value_type>``."""
    entry_types = (key_type, value_type)
    return dataclass_field(number, TYPE_MAP, group=group, map_types=entry_types)

332 

333 

def _pack_fmt(proto_type: str) -> str:
    """Returns a little-endian :mod:`struct` format string for reading/writing binary.

    Raises ``KeyError`` for proto types that are not fixed-width.
    """
    formats_by_type = {
        TYPE_FIXED32: "<I",
        TYPE_SFIXED32: "<i",
        TYPE_FLOAT: "<f",
        TYPE_FIXED64: "<Q",
        TYPE_SFIXED64: "<q",
        TYPE_DOUBLE: "<d",
    }
    return formats_by_type[proto_type]

344 

345 

def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None:
    """Encodes a single varint and dumps it into the provided stream.

    Negative values are encoded as their 64-bit two's complement, which always
    takes ten bytes.

    Raises
    ------
    ValueError
        If ``value`` is below ``-(2**63)`` and therefore not representable as a
        64-bit integer.
    """
    if value < -(1 << 63):
        raise ValueError(
            "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes."
        )
    elif value < 0:
        value += 1 << 64

    # Assemble the whole varint first and hand it to the stream in one call,
    # instead of one write (and one 1-byte bytes object) per encoded byte.
    encoded = bytearray()
    while value > 0x7F:
        # Low 7 bits with the continuation flag set: more bytes follow.
        encoded.append(0x80 | (value & 0x7F))
        value >>= 7
    encoded.append(value)
    stream.write(bytes(encoded))

362 

363 

def encode_varint(value: int) -> bytes:
    """Return the varint encoding of ``value`` as a standalone ``bytes`` object."""
    with BytesIO() as sink:
        dump_varint(value, sink)
        encoded = sink.getvalue()
    return encoded

369 

370 

def size_varint(value: int) -> int:
    """Calculates the size in bytes that a value would take as a varint.

    Raises ``ValueError`` for values below ``-(2**63)``.
    """
    if value < -(1 << 63):
        raise ValueError(
            "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes."
        )
    if value < 0:
        # Negative numbers are sign-extended to 64 bits: always ten bytes.
        return 10
    if value == 0:
        return 1
    # Each varint byte carries 7 payload bits; integer ceiling division.
    return (value.bit_length() + 6) // 7

383 

384 

def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
    """Adjusts values before serialization.

    Returns the encoded payload for one value of ``proto_type``:

    * varint types -> varint bytes
    * ``sint32``/``sint64`` -> zig-zag encoded varint bytes
    * fixed types -> little-endian packed bytes
    * strings -> UTF-8 bytes
    * messages -> the serialized message. ``datetime``/``timedelta`` are first
      converted to Timestamp/Duration, and wrapper values are boxed in their
      wrapper message; a ``None`` wrapper value yields ``b""``.

    Any other type (bytes, map entries) is returned unchanged.
    """
    if proto_type in (TYPE_ENUM, TYPE_BOOL, TYPE_INT32, TYPE_INT64, TYPE_UINT32, TYPE_UINT64):
        return encode_varint(value)
    if proto_type in (TYPE_SINT32, TYPE_SINT64):
        # Zig-zag: 0, -1, 1, -2, ... -> 0, 1, 2, 3, ...
        zigzag = value << 1 if value >= 0 else ~(value << 1)
        return encode_varint(zigzag)
    if proto_type in FIXED_TYPES:
        return struct.pack(_pack_fmt(proto_type), value)
    if proto_type == TYPE_STRING:
        return value.encode("utf-8")
    if proto_type != TYPE_MESSAGE:
        return value

    if isinstance(value, datetime):
        value = _Timestamp.from_datetime(value)
    elif isinstance(value, timedelta):
        value = _Duration.from_timedelta(value)
    elif wraps:
        if value is None:
            return b""
        value = _get_wrapper(wraps)(value=value)
    return bytes(value)

418 

419 

def _len_preprocessed_single(proto_type: str, wraps: str, value: Any) -> int:
    """Calculate the size of adjusted values for serialization without fully serializing them.

    For values that can be serialized, the result equals
    ``len(_preprocess_single(proto_type, wraps, value))``. Fixed-width types are
    sized from their format string alone, so the value is not packed (and
    therefore not range-checked) here; serializing an out-of-range value
    still fails in ``_preprocess_single``.
    """
    if proto_type in (
        TYPE_ENUM,
        TYPE_BOOL,
        TYPE_INT32,
        TYPE_INT64,
        TYPE_UINT32,
        TYPE_UINT64,
    ):
        return size_varint(value)
    elif proto_type in (TYPE_SINT32, TYPE_SINT64):
        # Handle zig-zag encoding.
        return size_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
    elif proto_type in FIXED_TYPES:
        # The width of a fixed type depends only on its struct format.
        return struct.calcsize(_pack_fmt(proto_type))
    elif proto_type == TYPE_STRING:
        return len(value.encode("utf-8"))
    elif proto_type == TYPE_MESSAGE:
        if isinstance(value, datetime):
            # Convert the `datetime` to a timestamp message.
            value = _Timestamp.from_datetime(value)
        elif isinstance(value, timedelta):
            # Convert the `timedelta` to a duration message.
            value = _Duration.from_timedelta(value)
        elif wraps:
            if value is None:
                return 0
            value = _get_wrapper(wraps)(value=value)

        return len(bytes(value))

    return len(value)

453 

454 

def _serialize_single(
    field_number: int,
    proto_type: str,
    value: Any,
    *,
    serialize_empty: bool = False,
    wraps: str = "",
) -> bytes:
    """Serializes a single field and value.

    Produces the tag followed by the encoded payload. Length-delimited payloads
    get a varint length prefix, and an empty one is omitted entirely (returns
    ``b""``) unless ``serialize_empty`` or ``wraps`` is set.

    Raises ``NotImplementedError`` for an unknown ``proto_type``.
    """
    payload = _preprocess_single(proto_type, wraps, value)

    if proto_type in WIRE_VARINT_TYPES:
        wire_type = WIRE_VARINT
    elif proto_type in WIRE_FIXED_32_TYPES:
        wire_type = WIRE_FIXED_32
    elif proto_type in WIRE_FIXED_64_TYPES:
        wire_type = WIRE_FIXED_64
    elif proto_type in WIRE_LEN_DELIM_TYPES:
        if not (len(payload) or serialize_empty or wraps):
            return b""
        tag = encode_varint((field_number << 3) | WIRE_LEN_DELIM)
        return bytes(tag + encode_varint(len(payload)) + payload)
    else:
        raise NotImplementedError(proto_type)

    tag = encode_varint((field_number << 3) | wire_type)
    return bytes(tag + payload)

484 

485 

def _len_single(
    field_number: int,
    proto_type: str,
    value: Any,
    *,
    serialize_empty: bool = False,
    wraps: str = "",
) -> int:
    """Calculates the size of a serialized single field and value.

    Mirrors ``_serialize_single``: the tag size is added, plus a length prefix
    for length-delimited fields, which are counted as 0 bytes when empty and
    neither ``serialize_empty`` nor ``wraps`` is set.
    """
    payload_size = _len_preprocessed_single(proto_type, wraps, value)

    if proto_type in WIRE_VARINT_TYPES:
        return payload_size + size_varint((field_number << 3) | WIRE_VARINT)
    if proto_type in WIRE_FIXED_32_TYPES:
        return payload_size + size_varint((field_number << 3) | WIRE_FIXED_32)
    if proto_type in WIRE_FIXED_64_TYPES:
        return payload_size + size_varint((field_number << 3) | WIRE_FIXED_64)
    if proto_type in WIRE_LEN_DELIM_TYPES:
        if not (payload_size or serialize_empty or wraps):
            return payload_size
        tag_size = size_varint((field_number << 3) | WIRE_LEN_DELIM)
        return payload_size + tag_size + size_varint(payload_size)
    raise NotImplementedError(proto_type)

509 

510 

def _parse_float(value: Any) -> float:
    """Parse the given value to a float

    Parameters
    ----------
    value: Any
        Value to parse; the JSON spellings ``"Infinity"``, ``"-Infinity"``
        and ``"NaN"`` are recognised in addition to anything ``float()`` accepts.

    Returns
    -------
    float
        Parsed value
    """
    for token, spelling in ((INFINITY, "inf"), (NEG_INFINITY, "-inf"), (NAN, "nan")):
        if value == token:
            return float(spelling)
    return float(value)

531 

532 

def _dump_float(value: float) -> Union[float, str]:
    """Dump the given float to JSON

    Parameters
    ----------
    value: float
        Value to dump

    Returns
    -------
    Union[float, str]
        The value unchanged when it is finite, otherwise one of the JSON
        strings ``"Infinity"``, ``"-Infinity"`` or ``"NaN"``.
    """
    positive_infinity = float("inf")
    if value == positive_infinity:
        return INFINITY
    if value == -positive_infinity:
        return NEG_INFINITY
    is_nan = isinstance(value, float) and math.isnan(value)
    return NAN if is_nan else value

553 

554 

def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]:
    """
    Load a single varint value from a stream. Returns the value and the raw bytes read.

    Raises ``EOFError`` if the stream ends mid-varint and ``ValueError`` if
    more than ten bytes would be needed.
    """
    accumulated = 0
    consumed = []
    shift = 0
    while True:
        if shift >= 64:
            raise ValueError("Too many bytes when decoding varint.")
        chunk = stream.read(1)
        if not chunk:
            raise EOFError("Stream ended unexpectedly while attempting to load varint.")
        consumed.append(chunk)
        octet = chunk[0]
        accumulated |= (octet & 0x7F) << shift
        if octet & 0x80 == 0:
            return accumulated, b"".join(consumed)
        shift += 7

572 

573 

def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
    """
    Decode a single varint value from a byte buffer. Returns the value and the
    new position in the buffer.

    Decodes directly from ``buffer`` (any bytes-like object indexable to ints)
    rather than wrapping it in a stream, since this sits on the hot path of
    field parsing.

    Raises
    ------
    ValueError
        If ``pos`` is negative, or the varint is longer than ten bytes.
    EOFError
        If the buffer ends before the varint is complete.
    """
    if pos < 0:
        raise ValueError(f"negative seek value {pos}")
    result = 0
    shift = 0
    end = len(buffer)
    while True:
        if shift >= 64:
            raise ValueError("Too many bytes when decoding varint.")
        if pos >= end:
            raise EOFError("Stream ended unexpectedly while attempting to load varint.")
        octet = buffer[pos]
        pos += 1
        result |= (octet & 0x7F) << shift
        if not octet & 0x80:
            return result, pos
        shift += 7

583 

584 

@dataclasses.dataclass(frozen=True)
class ParsedField:
    """A single field record decoded from the protobuf wire format."""

    # Field number taken from the tag (``tag >> 3``).
    number: int
    # Wire type taken from the tag (``tag & 0x7``).
    wire_type: int
    # Decoded payload: an ``int`` for varints, raw ``bytes`` for fixed-size and
    # length-delimited records, ``None`` for wire types the parsers do not decode.
    value: Any
    # The exact bytes of the whole record (tag, any length prefix, payload).
    raw: bytes

591 

592 

def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]:
    """Yield every field record read from ``stream`` until it is exhausted.

    Reaching the end of the stream where a new tag would start ends the
    generator cleanly. Running out of data inside a varint payload or length
    prefix propagates ``EOFError``. Fixed-size and length-delimited payloads are
    yielded exactly as read (possibly short on a truncated stream). Unrecognised
    wire types yield ``value=None`` with only the tag in ``raw``.
    """
    while True:
        try:
            tag, raw = load_varint(stream)
        except EOFError:
            return

        number, wire_type = tag >> 3, tag & 0x7
        payload: Any = None

        if wire_type == WIRE_VARINT:
            payload, payload_raw = load_varint(stream)
            raw += payload_raw
        elif wire_type == WIRE_LEN_DELIM:
            length, length_raw = load_varint(stream)
            payload = stream.read(length)
            raw += length_raw + payload
        elif wire_type in (WIRE_FIXED_64, WIRE_FIXED_32):
            payload = stream.read(8 if wire_type == WIRE_FIXED_64 else 4)
            raw += payload

        yield ParsedField(number=number, wire_type=wire_type, value=payload, raw=raw)

619 

620 

def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
    """Yield every field record contained in the byte buffer ``value``.

    Fixed-size and length-delimited payloads are slices of ``value``.
    Unrecognised wire types yield ``value=None`` with only the tag in ``raw``.
    """
    position = 0
    while position < len(value):
        record_start = position
        tag, position = decode_varint(value, position)
        number, wire_type = tag >> 3, tag & 0x7

        payload: Any = None
        if wire_type == WIRE_VARINT:
            payload, position = decode_varint(value, position)
        elif wire_type == WIRE_LEN_DELIM:
            length, position = decode_varint(value, position)
            payload = value[position : position + length]
            position += length
        elif wire_type in (WIRE_FIXED_64, WIRE_FIXED_32):
            width = 8 if wire_type == WIRE_FIXED_64 else 4
            payload = value[position : position + width]
            position += width

        yield ParsedField(
            number=number,
            wire_type=wire_type,
            value=payload,
            raw=value[record_start:position],
        )

643 ) 

644 

645 

class ProtoClassMetadata:
    """Field lookup tables computed once per message class and cached on it."""

    __slots__ = (
        "oneof_group_by_field",
        "oneof_field_by_group",
        "default_gen",
        "cls_by_field",
        "field_name_by_number",
        "meta_by_field_name",
        "sorted_field_names",
    )

    oneof_group_by_field: Dict[str, str]
    oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
    field_name_by_number: Dict[int, str]
    meta_by_field_name: Dict[str, FieldMetadata]
    sorted_field_names: Tuple[str, ...]
    default_gen: Dict[str, Callable[[], Any]]
    cls_by_field: Dict[str, Type]

    def __init__(self, cls: Type["Message"]):
        fields = dataclasses.fields(cls)

        group_of_field = {}
        fields_of_group: Dict[str, Set] = {}
        meta_of_name = {}
        name_of_number = {}
        for field in fields:
            meta = FieldMetadata.get(field)
            if meta.group:
                # Member of a one-of group.
                group_of_field[field.name] = meta.group
                fields_of_group.setdefault(meta.group, set()).add(field)
            meta_of_name[field.name] = meta
            name_of_number[meta.number] = field.name

        self.oneof_group_by_field = group_of_field
        self.oneof_field_by_group = fields_of_group
        self.field_name_by_number = name_of_number
        self.meta_by_field_name = meta_of_name
        # Field names ordered by ascending field number.
        self.sorted_field_names = tuple(
            name for _, name in sorted(name_of_number.items())
        )
        self.default_gen = self._get_default_gen(cls, fields)
        self.cls_by_field = self._get_cls_by_field(cls, fields)

    @staticmethod
    def _get_default_gen(
        cls: Type["Message"], fields: Iterable[dataclasses.Field]
    ) -> Dict[str, Callable[[], Any]]:
        """Map each field name to the callable that produces its default value."""
        defaults: Dict[str, Callable[[], Any]] = {}
        for field in fields:
            defaults[field.name] = cls._get_field_default_gen(field)
        return defaults

    @staticmethod
    def _get_cls_by_field(
        cls: Type["Message"], fields: Iterable[dataclasses.Field]
    ) -> Dict[str, Type]:
        """Map each field name to the Python class used to (de)serialize it.

        Map fields get a synthesized ``Entry`` message (key = field 1,
        value = field 2), plus an extra ``"<name>.value"`` key holding the value
        class.
        """
        resolved: Dict[str, Type] = {}
        for field in fields:
            meta = FieldMetadata.get(field)
            if meta.proto_type != TYPE_MAP:
                resolved[field.name] = cls._cls_for(field)
                continue

            assert meta.map_types
            key_cls = cls._cls_for(field, index=0)
            value_cls = cls._cls_for(field, index=1)
            key_type = meta.map_types[0]
            value_type = meta.map_types[1]
            resolved[field.name] = dataclasses.make_dataclass(
                "Entry",
                [
                    ("key", key_cls, dataclass_field(1, key_type)),
                    ("value", value_cls, dataclass_field(2, value_type)),
                ],
                bases=(Message,),
            )
            resolved[f"{field.name}.value"] = value_cls
        return resolved

725 

726 

727class Message(ABC): 

728 """ 

729 The base class for protobuf messages, all generated messages will inherit from 

730 this. This class registers the message fields which are used by the serializers and 

731 parsers to go between the Python, binary and JSON representations of the message. 

732 

733 .. container:: operations 

734 

735 .. describe:: bytes(x) 

736 

737 Calls :meth:`__bytes__`. 

738 

739 .. describe:: bool(x) 

740 

741 Calls :meth:`__bool__`. 

742 """ 

743 

744 _serialized_on_wire: bool 

745 _unknown_fields: bytes 

746 _group_current: Dict[str, str] 

747 _betterproto_meta: ClassVar[ProtoClassMetadata] 

748 

749 def __post_init__(self) -> None: 

750 # Keep track of whether every field was default 

751 all_sentinel = True 

752 

753 # Set current field of each group after `__init__` has already been run. 

754 group_current: Dict[str, Optional[str]] = {} 

755 for field_name, meta in self._betterproto.meta_by_field_name.items(): 

756 if meta.group: 

757 group_current.setdefault(meta.group) 

758 

759 value = self.__raw_get(field_name) 

760 if value is not PLACEHOLDER and not (meta.optional and value is None): 

761 # Found a non-sentinel value 

762 all_sentinel = False 

763 

764 if meta.group: 

765 # This was set, so make it the selected value of the one-of. 

766 group_current[meta.group] = field_name 

767 

768 # Now that all the defaults are set, reset it! 

769 self.__dict__["_serialized_on_wire"] = not all_sentinel 

770 self.__dict__["_unknown_fields"] = b"" 

771 self.__dict__["_group_current"] = group_current 

772 

773 def __raw_get(self, name: str) -> Any: 

774 return super().__getattribute__(name) 

775 

776 def __eq__(self, other) -> bool: 

777 if type(self) is not type(other): 

778 return NotImplemented 

779 

780 for field_name in self._betterproto.meta_by_field_name: 

781 self_val = self.__raw_get(field_name) 

782 other_val = other.__raw_get(field_name) 

783 if self_val is PLACEHOLDER: 

784 if other_val is PLACEHOLDER: 

785 continue 

786 self_val = self._get_field_default(field_name) 

787 elif other_val is PLACEHOLDER: 

788 other_val = other._get_field_default(field_name) 

789 

790 if self_val != other_val: 

791 # We consider two nan values to be the same for the 

792 # purposes of comparing messages (otherwise a message 

793 # is not equal to itself) 

794 if ( 

795 isinstance(self_val, float) 

796 and isinstance(other_val, float) 

797 and math.isnan(self_val) 

798 and math.isnan(other_val) 

799 ): 

800 continue 

801 else: 

802 return False 

803 

804 return True 

805 

806 def __repr__(self) -> str: 

807 parts = [ 

808 f"{field_name}={value!r}" 

809 for field_name in self._betterproto.sorted_field_names 

810 for value in (self.__raw_get(field_name),) 

811 if value is not PLACEHOLDER 

812 ] 

813 return f"{self.__class__.__name__}({', '.join(parts)})" 

814 

815 def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]: 

816 for field_name in self._betterproto.sorted_field_names: 

817 yield field_name, self.__raw_get(field_name), PLACEHOLDER 

818 

    # Defined only at runtime so static type checkers keep seeing the declared
    # field types instead of an ``Any``-returning __getattribute__.
    if not TYPE_CHECKING:

        def __getattribute__(self, name: str) -> Any:
            """
            Lazily initialize default values to avoid infinite recursion for recursive
            message types.
            Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields.
            """
            try:
                # `_group_current` is only created in __post_init__; before that
                # (e.g. while the dataclass __init__ runs) the one-of check is skipped.
                group_current = super().__getattribute__("_group_current")
            except AttributeError:
                pass
            else:
                # `_betterproto` is read just below, so it must bypass this check
                # to avoid recursing; `__class__` is likewise let through.
                if name not in {"__class__", "_betterproto"}:
                    group = self._betterproto.oneof_group_by_field.get(name)
                    # Only the currently selected member of a one-of is readable.
                    if group is not None and group_current[group] != name:
                        # AttributeError accepts `name`/`obj` keywords only on 3.10+.
                        if sys.version_info < (3, 10):
                            raise AttributeError(
                                f"{group!r} is set to {group_current[group]!r}, not {name!r}"
                            )
                        else:
                            raise AttributeError(
                                f"{group!r} is set to {group_current[group]!r}, not {name!r}",
                                name=name,
                                obj=self,
                            )

            value = super().__getattribute__(name)
            if value is not PLACEHOLDER:
                return value

            # Never assigned: build the default now and cache it on the instance.
            # super().__setattr__ bypasses Message.__setattr__, so this does not
            # flip `_serialized_on_wire` or change the one-of selection.
            value = self._get_field_default(name)
            super().__setattr__(name, value)
            return value

853 

    def __setattr__(self, attr: str, value: Any) -> None:
        """Assign an attribute while tracking presence and one-of selection.

        Any assignment other than to ``_serialized_on_wire`` marks this message as
        set. After ``__post_init__`` has run, assigning a one-of member selects it
        and resets the group's other members to ``PLACEHOLDER``.
        """
        # A message type with no fields at all can only signal presence through
        # `_serialized_on_wire`, so assigning one marks it as present.
        if (
            isinstance(value, Message)
            and hasattr(value, "_betterproto")
            and not value._betterproto.meta_by_field_name
        ):
            value._serialized_on_wire = True

        if attr != "_serialized_on_wire":
            # Track when a field has been set.
            self.__dict__["_serialized_on_wire"] = True

        if hasattr(self, "_group_current"):  # __post_init__ had already run
            if attr in self._betterproto.oneof_group_by_field:
                group = self._betterproto.oneof_group_by_field[attr]
                for field in self._betterproto.oneof_field_by_group[group]:
                    if field.name == attr:
                        self._group_current[group] = field.name
                    else:
                        # Clear sibling members without re-entering this method.
                        super().__setattr__(field.name, PLACEHOLDER)

        super().__setattr__(attr, value)

876 

877 def __bool__(self) -> bool: 

878 """True if the Message has any fields with non-default values.""" 

879 return any( 

880 self.__raw_get(field_name) 

881 not in (PLACEHOLDER, self._get_field_default(field_name)) 

882 for field_name in self._betterproto.meta_by_field_name 

883 ) 

884 

885 def __deepcopy__(self: T, _: Any = {}) -> T: 

886 kwargs = {} 

887 for name in self._betterproto.sorted_field_names: 

888 value = self.__raw_get(name) 

889 if value is not PLACEHOLDER: 

890 kwargs[name] = deepcopy(value) 

891 return self.__class__(**kwargs) # type: ignore 

892 

893 def __copy__(self: T, _: Any = {}) -> T: 

894 kwargs = {} 

895 for name in self._betterproto.sorted_field_names: 

896 value = self.__raw_get(name) 

897 if value is not PLACEHOLDER: 

898 kwargs[name] = value 

899 return self.__class__(**kwargs) # type: ignore 

900 

901 @classproperty 

902 def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore 

903 """ 

904 Lazy initialize metadata for each protobuf class. 

905 It may be initialized multiple times in a multi-threaded environment, 

906 but that won't affect the correctness. 

907 """ 

908 try: 

909 return cls._betterproto_meta 

910 except AttributeError: 

911 cls._betterproto_meta = meta = ProtoClassMetadata(cls) 

912 return meta 

913 

914 def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None: 

915 """ 

916 Dumps the binary encoded Protobuf message to the stream. 

917 

918 Parameters 

919 ----------- 

920 stream: :class:`BinaryIO` 

921 The stream to dump the message to. 

922 delimit: 

923 Whether to prefix the message with a varint declaring its size. 

924 """ 

925 if delimit == SIZE_DELIMITED: 

926 dump_varint(len(self), stream) 

927 

928 for field_name, meta in self._betterproto.meta_by_field_name.items(): 

929 try: 

930 value = getattr(self, field_name) 

931 except AttributeError: 

932 continue 

933 

934 if value is None: 

935 # Optional items should be skipped. This is used for the Google 

936 # wrapper types and proto3 field presence/optional fields. 

937 continue 

938 

939 # Being selected in a a group means this field is the one that is 

940 # currently set in a `oneof` group, so it must be serialized even 

941 # if the value is the default zero value. 

942 # 

943 # Note that proto3 field presence/optional fields are put in a 

944 # synthetic single-item oneof by protoc, which helps us ensure we 

945 # send the value even if the value is the default zero value. 

946 selected_in_group = bool(meta.group) or meta.optional 

947 

948 # Empty messages can still be sent on the wire if they were 

949 # set (or received empty). 

950 serialize_empty = isinstance(value, Message) and value._serialized_on_wire 

951 

952 include_default_value_for_oneof = self._include_default_value_for_oneof( 

953 field_name=field_name, meta=meta 

954 ) 

955 

956 if value == self._get_field_default(field_name) and not ( 

957 selected_in_group or serialize_empty or include_default_value_for_oneof 

958 ): 

959 # Default (zero) values are not serialized. Two exceptions are 

960 # if this is the selected oneof item or if we know we have to 

961 # serialize an empty message (i.e. zero value was explicitly 

962 # set by the user). 

963 continue 

964 

965 if isinstance(value, list): 

966 if meta.proto_type in PACKED_TYPES: 

967 # Packed lists look like a length-delimited field. First, 

968 # preprocess/encode each value into a buffer and then 

969 # treat it like a field of raw bytes. 

970 buf = bytearray() 

971 for item in value: 

972 buf += _preprocess_single(meta.proto_type, "", item) 

973 stream.write(_serialize_single(meta.number, TYPE_BYTES, buf)) 

974 else: 

975 for item in value: 

976 stream.write( 

977 _serialize_single( 

978 meta.number, 

979 meta.proto_type, 

980 item, 

981 wraps=meta.wraps or "", 

982 serialize_empty=True, 

983 ) 

984 # if it's an empty message it still needs to be represented 

985 # as an item in the repeated list 

986 or b"\n\x00" 

987 ) 

988 

989 elif isinstance(value, dict): 

990 for k, v in value.items(): 

991 assert meta.map_types 

992 sk = _serialize_single(1, meta.map_types[0], k) 

993 sv = _serialize_single(2, meta.map_types[1], v) 

994 stream.write( 

995 _serialize_single(meta.number, meta.proto_type, sk + sv) 

996 ) 

997 else: 

998 # If we have an empty string and we're including the default value for 

999 # a oneof, make sure we serialize it. This ensures that the byte string 

1000 # output isn't simply an empty string. This also ensures that round trip 

1001 # serialization will keep `which_one_of` calls consistent. 

1002 if ( 

1003 isinstance(value, str) 

1004 and value == "" 

1005 and include_default_value_for_oneof 

1006 ): 

1007 serialize_empty = True 

1008 

1009 stream.write( 

1010 _serialize_single( 

1011 meta.number, 

1012 meta.proto_type, 

1013 value, 

1014 serialize_empty=serialize_empty or bool(selected_in_group), 

1015 wraps=meta.wraps or "", 

1016 ) 

1017 ) 

1018 

1019 stream.write(self._unknown_fields) 

1020 

1021 def __bytes__(self) -> bytes: 

1022 """ 

1023 Get the binary encoded Protobuf representation of this message instance. 

1024 """ 

1025 with BytesIO() as stream: 

1026 self.dump(stream) 

1027 return stream.getvalue() 

1028 

def __len__(self) -> int:
    """
    Return the size in bytes of the encoded Protobuf representation of this
    message instance.

    The per-field rules below follow the same structure as the serialization
    loop in ``dump``: unset optional fields and zero values are skipped unless
    the field must be emitted, repeated packed scalars are counted as one
    length-delimited blob, and unknown fields are added verbatim.
    """
    total = 0
    for field_name, meta in self._betterproto.meta_by_field_name.items():
        try:
            value = getattr(self, field_name)
        except AttributeError:
            # Fields that cannot be read contribute nothing.
            continue

        if value is None:
            # Optional items (Google wrapper types, proto3 optional fields)
            # that are unset are not encoded at all.
            continue

        # A field belonging to a `oneof` group is encoded even when it holds
        # the zero value. protoc places proto3 optional fields in a synthetic
        # single-item oneof, so this also covers explicit presence.
        selected_in_group = bool(meta.group)
        # Messages that were set (or received empty) still go on the wire.
        serialize_empty = isinstance(value, Message) and value._serialized_on_wire
        include_default_value_for_oneof = self._include_default_value_for_oneof(
            field_name=field_name, meta=meta
        )
        force_emit = (
            selected_in_group or serialize_empty or include_default_value_for_oneof
        )

        if value == self._get_field_default(field_name) and not force_emit:
            # Default (zero) values are normally omitted.
            continue

        wraps = meta.wraps or ""
        if isinstance(value, list):
            if meta.proto_type in PACKED_TYPES:
                # Packed repeated scalars: encode each item into one buffer and
                # size it as a single length-delimited bytes field.
                packed = bytearray()
                for item in value:
                    packed += _preprocess_single(meta.proto_type, "", item)
                total += _len_single(meta.number, TYPE_BYTES, packed)
            else:
                # An empty element still counts as 2 bytes, matching the
                # placeholder item written for it during serialization.
                total += sum(
                    _len_single(
                        meta.number,
                        meta.proto_type,
                        item,
                        wraps=wraps,
                        serialize_empty=True,
                    )
                    or 2
                    for item in value
                )
        elif isinstance(value, dict):
            # Each map entry is an embedded message with key = 1 and value = 2.
            for key, item in value.items():
                assert meta.map_types
                entry = _serialize_single(
                    1, meta.map_types[0], key
                ) + _serialize_single(2, meta.map_types[1], item)
                total += _len_single(meta.number, meta.proto_type, entry)
        else:
            # An explicitly selected empty-string oneof member is still encoded,
            # so round trips keep `which_one_of` consistent.
            if (
                isinstance(value, str)
                and value == ""
                and include_default_value_for_oneof
            ):
                serialize_empty = True

            total += _len_single(
                meta.number,
                meta.proto_type,
                value,
                serialize_empty=serialize_empty or bool(selected_in_group),
                wraps=wraps,
            )

    return total + len(self._unknown_fields)

1123 

# For compatibility with other libraries
def SerializeToString(self: T) -> bytes:
    """
    Return the binary encoded Protobuf representation of this message instance.

    .. note::
        This method exists for API compatibility with other Protobuf
        libraries; prefer ``bytes(x)`` in new code.

    Returns
    --------
    :class:`bytes`
        The binary encoded Protobuf representation of this message instance.
    """
    encoded = bytes(self)
    return encoded

1139 

def __getstate__(self) -> bytes:
    # The pickled state is simply the message's wire encoding.
    state: bytes = bytes(self)
    return state

1142 

def __setstate__(self: T, pickled_bytes: bytes) -> T:
    # Rebuild from the wire encoding produced by ``__getstate__``; ``parse``
    # fills this instance and returns it.
    restored = self.parse(pickled_bytes)
    return restored

1145 

def __reduce__(self) -> Tuple[Any, ...]:
    # Pickle as a call to the class's ``FromString`` on the encoded bytes, so
    # unpickling goes through the normal parsing path.
    rebuild = self.__class__.FromString
    payload = bytes(self)
    return rebuild, (payload,)

1148 

@classmethod
def _type_hint(cls, field_name: str) -> Type:
    """Return the resolved type annotation of ``field_name`` on this class."""
    hints = cls._type_hints()
    return hints[field_name]

1152 

@classmethod
def _type_hints(cls) -> Dict[str, Type]:
    """
    Resolve this class's annotations.

    String annotations are evaluated against the globals of the module that
    defines the class, with an empty local namespace.
    """
    module_globals = vars(sys.modules[cls.__module__])
    return get_type_hints(cls, module_globals, {})

1157 

@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
    """
    Get the message class for a field from the type hints.

    For a generic hint (anything exposing ``__args__``, e.g. ``List[X]`` or
    ``Optional[X]``), the type argument at ``index`` is returned. A negative
    ``index``, a non-generic hint, or ``__args__`` of ``None`` returns the hint
    itself.
    """
    hint = cls._type_hint(field.name)
    if index < 0 or not hasattr(hint, "__args__"):
        return hint
    type_args = hint.__args__
    return hint if type_args is None else type_args[index]

1166 

def _get_field_default(self, field_name: str) -> Any:
    """Build a fresh default value for ``field_name`` using its registered factory."""
    with warnings.catch_warnings():
        # Creating the defaults of deprecated fields may emit
        # DeprecationWarnings; they are silenced for the duration of the call.
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        factory = self._betterproto.default_gen[field_name]
        return factory()

1172 

@classmethod
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
    """
    Return a zero-argument callable that produces the default value of ``field``.

    - ``Optional[...]`` / ``X | None`` hints produce ``None`` (via ``type(None)``).
    - ``List[...]`` and ``Dict[...]`` hints produce empty ``list`` / ``dict``.
    - Any other generic alias is returned unchanged.
    - Enum subclasses use their ``try_value`` factory (enums default to zero).
    - ``datetime`` uses ``datetime_default_gen`` (offsets relative to
      1970-01-01T00:00:00Z).
    - Anything else (scalars, message types) is itself called for its zero value.
    """
    hint = cls._type_hint(field.name)
    is_310_union = isinstance(hint, _types_UnionType)

    if is_310_union or hasattr(hint, "__origin__"):
        if is_310_union or hint.__origin__ is Union:
            # Optional field (wrapped, or proto3 field presence): the kind of
            # field does not matter, the default is simply None.
            return type(None)
        origin = hint.__origin__
        if origin is list or origin is dict:
            # Repeated -> list, map -> dict.
            return origin
        return hint

    if issubclass(hint, Enum):
        return hint.try_value
    if hint is datetime:
        return datetime_default_gen
    return hint

1200 

def _postprocess_single(
    self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
) -> Any:
    """
    Adjust a single raw decoded value after parsing.

    Varints get sign handling (two's complement for int32/int64, zig-zag for
    sint32/sint64), bool conversion or enum lookup. Fixed-width values are
    unpacked with ``struct``. Length-delimited values become ``str``, nested
    messages, well-known datetime/timedelta values, unwrapped wrapper values or
    map entries. Any other combination is returned unchanged.
    """
    proto_type = meta.proto_type

    if wire_type == WIRE_VARINT:
        if proto_type in (TYPE_INT32, TYPE_INT64):
            # Reinterpret the low `width` bits as a two's-complement integer.
            width = int(proto_type[3:])
            value &= (1 << width) - 1
            sign_bit = 1 << (width - 1)
            return int((value ^ sign_bit) - sign_bit)
        if proto_type in (TYPE_SINT32, TYPE_SINT64):
            # Undo zig-zag encoding.
            return (value >> 1) ^ (-(value & 1))
        if proto_type == TYPE_BOOL:
            return value > 0
        if proto_type == TYPE_ENUM:
            return self._betterproto.cls_by_field[field_name].try_value(value)
        return value

    if wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
        return struct.unpack(_pack_fmt(proto_type), value)[0]

    if wire_type == WIRE_LEN_DELIM:
        if proto_type == TYPE_STRING:
            return str(value, "utf-8")
        if proto_type == TYPE_MESSAGE:
            cls = self._betterproto.cls_by_field[field_name]
            if cls == datetime:
                return _Timestamp().parse(value).to_datetime()
            if cls == timedelta:
                return _Duration().parse(value).to_timedelta()
            if meta.wraps:
                # Google wrapper message around a single scalar value.
                return _get_wrapper(meta.wraps)().parse(value).value
            message = cls().parse(value)
            message._serialized_on_wire = True
            return message
        if proto_type == TYPE_MAP:
            return self._betterproto.cls_by_field[field_name]().parse(value)

    return value

1244 

def _include_default_value_for_oneof(
    self, field_name: str, meta: FieldMetadata
) -> bool:
    """
    Return ``True`` when ``field_name`` is the currently selected member of its
    oneof group, so its value must be emitted even if it is the default.
    """
    if meta.group is None:
        return False
    return self._group_current.get(meta.group) == field_name

1251 

def load(
    self: T,
    stream: "SupportsRead[bytes]",
    size: Optional[int] = None,
) -> T:
    """
    Load the binary encoded Protobuf from a stream into this message instance. This
    returns the instance itself and is therefore assignable and chainable.

    Parameters
    -----------
    stream: SupportsRead[:class:`bytes`]
        The binary stream to load the message from.
    size: :class:`Optional[int]`
        The size of the message in the stream.
        Reads stream until EOF if ``None`` is given.
        Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.

    Returns
    --------
    :class:`Message`
        The initialized message.

    Raises
    ------
    :class:`ValueError`
        If ``size`` is known and the fields read from the stream (known and
        unknown alike) do not add up to exactly ``size`` bytes.
    """
    # If the message is delimited, parse the message delimiter
    if size == SIZE_DELIMITED:
        size, _ = load_varint(stream)

    # Got some data over the wire
    self._serialized_on_wire = True

    if size == 0:
        # A zero-length message has no fields. Reading a field here would
        # consume bytes that belong to whatever follows in the stream.
        return self

    proto_meta = self._betterproto
    read = 0
    for parsed in load_fields(stream):
        field_name = proto_meta.field_name_by_number.get(parsed.number)
        if field_name:
            meta = proto_meta.meta_by_field_name[field_name]

            value: Any
            if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
                # This is a packed repeated field.
                pos = 0
                value = []
                while pos < len(parsed.value):
                    if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32):
                        decoded, pos = parsed.value[pos : pos + 4], pos + 4
                        wire_type = WIRE_FIXED_32
                    elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64):
                        decoded, pos = parsed.value[pos : pos + 8], pos + 8
                        wire_type = WIRE_FIXED_64
                    else:
                        decoded, pos = decode_varint(parsed.value, pos)
                        wire_type = WIRE_VARINT
                    decoded = self._postprocess_single(
                        wire_type, meta, field_name, decoded
                    )
                    value.append(decoded)
            else:
                value = self._postprocess_single(
                    parsed.wire_type, meta, field_name, parsed.value
                )

            try:
                current = getattr(self, field_name)
            except AttributeError:
                current = self._get_field_default(field_name)
                setattr(self, field_name, current)

            if meta.proto_type == TYPE_MAP:
                # Value represents a single key/value pair entry in the map.
                current[value.key] = value.value
            elif isinstance(current, list) and not isinstance(value, list):
                current.append(value)
            else:
                setattr(self, field_name, value)
        else:
            # Unknown fields are kept verbatim; serialization writes them back out.
            self._unknown_fields += parsed.raw

        # Every field read, known or unknown, counts toward the expected size;
        # stop once the expected length of the message has been consumed.
        if size is not None:
            prev = read
            read += len(parsed.raw)
            if read == size:
                break
            elif read > size:
                raise ValueError(
                    f"Expected message of size {size}, can only read "
                    f"either {prev} or {read} bytes - there is no "
                    "message of the expected size in the stream."
                )

    if size is not None and read < size:
        raise ValueError(
            f"Expected message of size {size}, but was only able to "
            f"read {read} bytes - the stream may have ended too soon,"
            " or the expected size may have been incorrect."
        )

    return self

1350 

def parse(self: T, data: bytes) -> T:
    """
    Parse the binary encoded Protobuf ``data`` into this message instance.

    The bytes are wrapped in an in-memory stream and handed to :meth:`load`.
    The instance itself is returned, so the call is assignable and chainable.

    Parameters
    -----------
    data: :class:`bytes`
        The data to parse the message from.

    Returns
    --------
    :class:`Message`
        The initialized message.
    """
    buffer = BytesIO(data)
    try:
        return self.load(buffer)
    finally:
        buffer.close()

1368 

# For compatibility with other libraries.
@classmethod
def FromString(cls: Type[T], data: bytes) -> T:
    """
    Parse the binary encoded Protobuf ``data`` into a new instance of this class.

    .. note::
        This method exists for API compatibility with other Protobuf
        libraries; prefer :meth:`parse` in new code.

    Parameters
    -----------
    data: :class:`bytes`
        The data to parse the protobuf from.

    Returns
    --------
    :class:`Message`
        The initialized message.
    """
    instance = cls()
    return instance.parse(data)

1392 

def to_dict(
    self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
) -> Dict[str, Any]:
    """
    Returns a JSON serializable dict representation of this object.

    Parameters
    -----------
    casing: :class:`Casing`
        The casing to use for key values. Default is :attr:`Casing.CAMEL` for
        compatibility purposes.
    include_default_values: :class:`bool`
        If ``True`` will include the default values of fields. Default is ``False``.
        E.g. an ``int32`` field will be included with a value of ``0`` if this is
        set to ``True``, otherwise this would be ignored.

    Returns
    --------
    Dict[:class:`str`, Any]
        The JSON serializable dict representation of this object. 64-bit
        integers are emitted as strings, bytes as base64 strings, enums by name,
        timestamps/durations in their JSON string forms.
    """
    output: Dict[str, Any] = {}
    field_types = self._type_hints()
    defaults = self._betterproto.default_gen
    for field_name, meta in self._betterproto.meta_by_field_name.items():
        field_is_repeated = defaults[field_name] is list
        try:
            value = getattr(self, field_name)
        except AttributeError:
            # Unreadable fields are rendered from their default value.
            value = self._get_field_default(field_name)
        # Trailing underscores are stripped from the output key.
        cased_name = casing(field_name).rstrip("_")  # type: ignore
        if meta.proto_type == TYPE_MESSAGE:
            if isinstance(value, datetime):
                if (
                    value != DATETIME_ZERO
                    or include_default_values
                    or self._include_default_value_for_oneof(
                        field_name=field_name, meta=meta
                    )
                ):
                    output[cased_name] = _Timestamp.timestamp_to_json(value)
            elif isinstance(value, timedelta):
                if (
                    value != timedelta(0)
                    or include_default_values
                    or self._include_default_value_for_oneof(
                        field_name=field_name, meta=meta
                    )
                ):
                    output[cased_name] = _Duration.delta_to_json(value)
            elif meta.wraps:
                if value is not None or include_default_values:
                    output[cased_name] = value
            elif field_is_repeated:
                # Convert each item.
                cls = self._betterproto.cls_by_field[field_name]
                if cls == datetime:
                    value = [_Timestamp.timestamp_to_json(i) for i in value]
                elif cls == timedelta:
                    value = [_Duration.delta_to_json(i) for i in value]
                else:
                    value = [
                        i.to_dict(casing, include_default_values) for i in value
                    ]
                if value or include_default_values:
                    output[cased_name] = value
            elif value is None:
                if include_default_values:
                    output[cased_name] = value
            elif (
                value._serialized_on_wire
                or include_default_values
                or self._include_default_value_for_oneof(
                    field_name=field_name, meta=meta
                )
            ):
                output[cased_name] = value.to_dict(casing, include_default_values)
        elif meta.proto_type == TYPE_MAP:
            # Work on a copy so message values can be converted without
            # modifying the field itself.
            output_map = {**value}
            for k in value:
                if hasattr(value[k], "to_dict"):
                    output_map[k] = value[k].to_dict(casing, include_default_values)

            if value or include_default_values:
                output[cased_name] = output_map
        elif (
            value != self._get_field_default(field_name)
            or include_default_values
            or self._include_default_value_for_oneof(
                field_name=field_name, meta=meta
            )
        ):
            if meta.proto_type in INT_64_TYPES:
                if field_is_repeated:
                    output[cased_name] = [str(n) for n in value]
                elif value is None:
                    if include_default_values:
                        output[cased_name] = value
                else:
                    output[cased_name] = str(value)
            elif meta.proto_type == TYPE_BYTES:
                if field_is_repeated:
                    output[cased_name] = [
                        b64encode(b).decode("utf8") for b in value
                    ]
                elif value is None:
                    # Like the 64-bit and enum branches: a None value is only
                    # emitted when defaults are requested. It must never reach
                    # b64encode, which raises TypeError for None.
                    if include_default_values:
                        output[cased_name] = value
                else:
                    output[cased_name] = b64encode(value).decode("utf8")
            elif meta.proto_type == TYPE_ENUM:
                if field_is_repeated:
                    enum_class = field_types[field_name].__args__[0]
                    if isinstance(value, typing.Iterable) and not isinstance(
                        value, str
                    ):
                        output[cased_name] = [enum_class(el).name for el in value]
                    else:
                        # transparently upgrade single value to repeated
                        output[cased_name] = [enum_class(value).name]
                elif value is None:
                    if include_default_values:
                        output[cased_name] = value
                elif meta.optional:
                    enum_class = field_types[field_name].__args__[0]
                    output[cased_name] = enum_class(value).name
                else:
                    enum_class = field_types[field_name]  # noqa
                    output[cased_name] = enum_class(value).name
            elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
                if field_is_repeated:
                    output[cased_name] = [_dump_float(n) for n in value]
                else:
                    output[cased_name] = _dump_float(value)
            else:
                output[cased_name] = value
    return output

1529 

@classmethod
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Convert a JSON-style ``mapping`` into keyword arguments for this class's
    constructor.

    Keys are converted with ``safe_snake_case``. Keys that do not name a known
    field, and ``None`` values, are dropped. Values are decoded from their JSON
    forms: ISO timestamps, ``"<seconds>s"`` durations, nested message dicts,
    message-valued maps, stringified 64-bit integers, base64 bytes, enum names
    and special float strings. Repeated fields (lists) are converted item by
    item.
    """

    def convert_each(convert: Callable[[Any], Any], raw: Any) -> Any:
        # Apply `convert` to every element of a list, or to a single value.
        if isinstance(raw, list):
            return [convert(item) for item in raw]
        return convert(raw)

    init_kwargs: Dict[str, Any] = {}
    for key, value in mapping.items():
        field_name = safe_snake_case(key)
        try:
            meta = cls._betterproto.meta_by_field_name[field_name]
        except KeyError:
            continue
        if value is None:
            continue

        proto_type = meta.proto_type
        if proto_type == TYPE_MESSAGE:
            sub_cls = cls._betterproto.cls_by_field[field_name]
            if sub_cls == datetime:
                value = convert_each(isoparse, value)
            elif sub_cls == timedelta:
                # Durations look like "1.5s": drop the unit suffix.
                value = convert_each(
                    lambda text: timedelta(seconds=float(text[:-1])), value
                )
            elif not meta.wraps:
                value = convert_each(sub_cls.from_dict, value)
        elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
            sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
            value = {k: sub_cls.from_dict(v) for k, v in value.items()}
        elif proto_type in INT_64_TYPES:
            value = convert_each(int, value)
        elif proto_type == TYPE_BYTES:
            value = convert_each(b64decode, value)
        elif proto_type == TYPE_ENUM:
            enum_cls = cls._betterproto.cls_by_field[field_name]
            if isinstance(value, list):
                value = [enum_cls.from_string(e) for e in value]
            elif isinstance(value, str):
                value = enum_cls.from_string(value)
        elif proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
            value = convert_each(_parse_float, value)

        init_kwargs[field_name] = value
    return init_kwargs

1593 

@hybridmethod
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self:  # type: ignore
    """
    Parse the key/value pairs into a new message instance.

    Parameters
    -----------
    value: Dict[:class:`str`, Any]
        The dictionary to parse from.

    Returns
    --------
    :class:`Message`
        The initialized message.
    """
    instance = cls(**cls._from_dict_init(value))
    # Treat the new instance as populated, exactly as parsing from bytes does.
    instance._serialized_on_wire = True
    return instance

1612 

@from_dict.instancemethod
def from_dict(self, value: Mapping[str, Any]) -> Self:
    """
    Parse the key/value pairs into the current message instance.

    Each converted field is assigned onto this instance, which is then
    returned, so the call is assignable and chainable.

    Parameters
    -----------
    value: Dict[:class:`str`, Any]
        The dictionary to parse from.

    Returns
    --------
    :class:`Message`
        The initialized message.
    """
    self._serialized_on_wire = True
    init_kwargs = self._from_dict_init(value)
    for field_name, field_value in init_kwargs.items():
        setattr(self, field_name, field_value)
    return self

1633 

def to_json(
    self,
    indent: Union[None, int, str] = None,
    include_default_values: bool = False,
    casing: Casing = Casing.CAMEL,
) -> str:
    """Serialize the message instance to its JSON representation.

    This is equivalent to::

        json.dumps(message.to_dict(), indent=indent)

    Parameters
    -----------
    indent: Optional[Union[:class:`int`, :class:`str`]]
        The indent to pass to :func:`json.dumps`.

    include_default_values: :class:`bool`
        If ``True`` will include the default values of fields. Default is ``False``.
        E.g. an ``int32`` field will be included with a value of ``0`` if this is
        set to ``True``, otherwise this would be ignored.

    casing: :class:`Casing`
        The casing to use for key values. Default is :attr:`Casing.CAMEL` for
        compatibility purposes.

    Returns
    --------
    :class:`str`
        The JSON representation of the message.
    """
    as_dict = self.to_dict(
        include_default_values=include_default_values, casing=casing
    )
    return json.dumps(as_dict, indent=indent)

1670 

def from_json(self: T, value: Union[str, bytes]) -> T:
    """Populate this message instance from its JSON representation.

    The instance itself is returned, so the call is assignable and chainable.

    This is equivalent to::

        return message.from_dict(json.loads(value))

    Parameters
    -----------
    value: Union[:class:`str`, :class:`bytes`]
        The value to pass to :func:`json.loads`.

    Returns
    --------
    :class:`Message`
        The initialized message.
    """
    decoded = json.loads(value)
    return self.from_dict(decoded)

1691 

def to_pydict(
    self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
) -> Dict[str, Any]:
    """
    Returns a python dict representation of this object.

    Unlike :meth:`to_dict`, values keep their Python types (``datetime``,
    ``timedelta``, ``bytes``, enums, ints). Nested messages are converted
    recursively. The message itself is never modified.

    Parameters
    -----------
    casing: :class:`Casing`
        The casing to use for key values. Default is :attr:`Casing.CAMEL` for
        compatibility purposes.
    include_default_values: :class:`bool`
        If ``True`` will include the default values of fields. Default is ``False``.
        E.g. an ``int32`` field will be included with a value of ``0`` if this is
        set to ``True``, otherwise this would be ignored.

    Returns
    --------
    Dict[:class:`str`, Any]
        The python dict representation of this object.
    """
    output: Dict[str, Any] = {}
    defaults = self._betterproto.default_gen
    for field_name, meta in self._betterproto.meta_by_field_name.items():
        field_is_repeated = defaults[field_name] is list
        try:
            value = getattr(self, field_name)
        except AttributeError:
            # Same fallback as `to_dict`: unreadable fields take their default
            # instead of aborting the whole conversion.
            value = self._get_field_default(field_name)
        cased_name = casing(field_name).rstrip("_")  # type: ignore
        if meta.proto_type == TYPE_MESSAGE:
            if isinstance(value, datetime):
                if (
                    value != DATETIME_ZERO
                    or include_default_values
                    or self._include_default_value_for_oneof(
                        field_name=field_name, meta=meta
                    )
                ):
                    output[cased_name] = value
            elif isinstance(value, timedelta):
                if (
                    value != timedelta(0)
                    or include_default_values
                    or self._include_default_value_for_oneof(
                        field_name=field_name, meta=meta
                    )
                ):
                    output[cased_name] = value
            elif meta.wraps:
                if value is not None or include_default_values:
                    output[cased_name] = value
            elif field_is_repeated:
                cls = self._betterproto.cls_by_field[field_name]
                if cls in (datetime, timedelta):
                    # Repeated well-known types are already Python values and
                    # have no `to_pydict`; copy them as-is.
                    value = list(value)
                else:
                    # Convert each item.
                    value = [
                        i.to_pydict(casing, include_default_values) for i in value
                    ]
                if value or include_default_values:
                    output[cased_name] = value
            elif value is None:
                if include_default_values:
                    output[cased_name] = None
            elif (
                value._serialized_on_wire
                or include_default_values
                or self._include_default_value_for_oneof(
                    field_name=field_name, meta=meta
                )
            ):
                output[cased_name] = value.to_pydict(casing, include_default_values)
        elif meta.proto_type == TYPE_MAP:
            # Convert into a copy: writing the converted values back into
            # `value` would replace the message's own map entries with dicts.
            output_map = {**value}
            for k in value:
                if hasattr(value[k], "to_pydict"):
                    output_map[k] = value[k].to_pydict(casing, include_default_values)

            if value or include_default_values:
                output[cased_name] = output_map
        elif (
            value != self._get_field_default(field_name)
            or include_default_values
            or self._include_default_value_for_oneof(
                field_name=field_name, meta=meta
            )
        ):
            output[cased_name] = value
    return output

1773 

def from_pydict(self: T, value: Mapping[str, Any]) -> T:
    """
    Parse the key/value pairs into the current message instance. This returns the
    instance itself and is therefore assignable and chainable.

    Keys are converted with ``safe_snake_case``; unknown keys and ``None``
    values are ignored.
    - Items of repeated message fields are appended to the existing list.
    - Existing nested messages are updated in place.
    - Unset optional message fields get a freshly built message.
    - Map values are merged into the existing map.

    Parameters
    -----------
    value: Dict[:class:`str`, Any]
        The dictionary to parse from.

    Returns
    --------
    :class:`Message`
        The initialized message.
    """
    self._serialized_on_wire = True
    for key in value:
        field_name = safe_snake_case(key)
        meta = self._betterproto.meta_by_field_name.get(field_name)
        if not meta:
            continue

        if value[key] is not None:
            if meta.proto_type == TYPE_MESSAGE:
                try:
                    v = getattr(self, field_name)
                except AttributeError:
                    # Mirror `to_dict`: fall back to the field default when the
                    # attribute cannot be read.
                    v = self._get_field_default(field_name)

                if isinstance(v, list):
                    cls = self._betterproto.cls_by_field[field_name]
                    for item in value[key]:
                        v.append(cls().from_pydict(item))
                elif isinstance(v, (datetime, timedelta)) or meta.wraps:
                    v = value[key]
                elif v is None:
                    # Unset optional message field: there is no instance to
                    # update in place, so build one (or take the value as-is
                    # for well-known datetime/timedelta fields).
                    cls = self._betterproto.cls_by_field[field_name]
                    if cls in (datetime, timedelta):
                        v = value[key]
                    else:
                        v = cls().from_pydict(value[key])
                else:
                    # NOTE: `from_pydict` mutates the underlying message, so no
                    # assignment here is necessary.
                    v.from_pydict(value[key])
            elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
                v = getattr(self, field_name)
                cls = self._betterproto.cls_by_field[f"{field_name}.value"]
                for k in value[key]:
                    v[k] = cls().from_pydict(value[key][k])
            else:
                v = value[key]

            if v is not None:
                setattr(self, field_name, v)
    return self

1824 

def is_set(self, name: str) -> bool:
    """
    Check whether the field called ``name`` has been set.

    Optional fields count as set when their raw value is not ``None``. Other
    fields count as set when their raw value is not the ``PLACEHOLDER``
    sentinel.

    Parameters
    -----------
    name: :class:`str`
        The name of the field to check for.

    Returns
    --------
    :class:`bool`
        `True` if field has been set, otherwise `False`.
    """
    if self._betterproto.meta_by_field_name[name].optional:
        unset_marker = None
    else:
        unset_marker = PLACEHOLDER
    return self.__raw_get(name) is not unset_marker

1845 

@classmethod
def _validate_field_groups(cls, values):
    """
    Check that each real ``oneof`` group on ``values`` has exactly one member
    that is not ``None``.

    Single-member groups whose field is ``optional`` are synthetic oneofs and
    are skipped.

    Returns ``values`` unchanged.

    Raises :class:`ValueError` if a group has no member set, or more than one.
    """
    group_to_one_ofs = cls._betterproto.oneof_field_by_group
    field_name_to_meta = cls._betterproto.meta_by_field_name

    for group, field_set in group_to_one_ofs.items():
        if len(field_set) == 1:
            (only_field,) = field_set
            # This is a synthetic oneof; its presence is not validated.
            if field_name_to_meta[only_field.name].optional:
                continue

        populated = [
            member.name
            for member in field_set
            if getattr(values, member.name, None) is not None
        ]

        if not populated:
            raise ValueError(f"Group {group} has no value; all fields are None")
        if len(populated) > 1:
            set_fields_str = ", ".join(populated)
            raise ValueError(
                f"Group {group} has more than one value; fields {set_fields_str} are not None"
            )

    return values

1876 

1877 

Message.__annotations__ = {}  # HACK to avoid typing.get_type_hints breaking :)

# When the optional `betterproto-rust-codec` package is installed, replace the
# (de-)serialization functions of class `Message` with its implementations.
try:
    import betterproto_rust_codec
except ModuleNotFoundError:
    pass
else:

    def __parse_patch(self: T, data: bytes) -> T:
        # The codec fills `self`; return it so `parse` stays chainable.
        betterproto_rust_codec.deserialize(self, data)
        return self

    def __bytes_patch(self) -> bytes:
        return betterproto_rust_codec.serialize(self)

    Message.parse = __parse_patch
    Message.__bytes__ = __bytes_patch

1896 

1897 

def serialized_on_wire(message: Message) -> bool:
    """
    Report whether ``message`` was, or should be, serialized on the wire.

    This can be used to detect presence (e.g. an optional wrapper message) and
    is used internally during parsing/serialization.

    Returns
    --------
    :class:`bool`
        Whether this message was or should be serialized on the wire.
    """
    on_wire = message._serialized_on_wire
    return on_wire

1910 

1911 

def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
    """
    Return the name and value of the selected field in a message's oneof group.

    Returns ``("", None)`` when no field of the group is selected.

    Returns
    --------
    Tuple[:class:`str`, Any]
        The field name and the value for that field.
    """
    selected = message._group_current.get(group_name)
    if selected:
        return selected, getattr(message, selected)
    return "", None

1925 

1926 

1927# Circular import workaround: google.protobuf depends on base classes defined above. 

1928from .lib.google.protobuf import ( # noqa 

1929 BoolValue, 

1930 BytesValue, 

1931 DoubleValue, 

1932 Duration, 

1933 EnumValue, 

1934 FloatValue, 

1935 Int32Value, 

1936 Int64Value, 

1937 StringValue, 

1938 Timestamp, 

1939 UInt32Value, 

1940 UInt64Value, 

1941) 

1942 

1943 

class _Duration(Duration):
    """``Duration`` with conversions to/from :class:`datetime.timedelta` and JSON."""

    @classmethod
    def from_timedelta(
        cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1)
    ) -> "_Duration":
        """
        Build a Duration from ``delta`` (microsecond resolution).

        ``seconds`` and ``nanos`` always carry the same sign, as the protobuf
        Duration contract requires. Negative deltas therefore round-trip
        through :meth:`to_timedelta`.
        """
        total_us = delta // _1_microsecond
        # Integer arithmetic only. Float division loses precision for large
        # deltas, and Python's floor modulo would yield a positive `nanos` for
        # negative deltas (e.g. -1.5s would become -1s + 0.5s = -0.5s).
        sign = -1 if total_us < 0 else 1
        seconds, micros = divmod(abs(total_us), 10**6)
        return cls(sign * seconds, sign * micros * 1000)

    def to_timedelta(self) -> timedelta:
        """Convert to a :class:`datetime.timedelta` (sub-microsecond nanos are rounded)."""
        return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)

    @staticmethod
    def delta_to_json(delta: timedelta) -> str:
        """
        Format ``delta`` as a protobuf JSON duration string, e.g. ``"1.500s"``.

        The fraction always has 3 or 6 digits: 3 when the value is a whole
        number of milliseconds, 6 otherwise. Digits are computed with integer
        arithmetic, so tiny values are never rendered in float exponent form
        (``"1e-06s"``) and large values keep microsecond precision.
        """
        total_us = delta // timedelta(microseconds=1)
        sign = "-" if total_us < 0 else ""
        seconds, micros = divmod(abs(total_us), 10**6)
        if micros % 1000 == 0:
            fraction = f"{micros // 1000:03d}"
        else:
            fraction = f"{micros:06d}"
        return f"{sign}{seconds}.{fraction}s"

1964 

1965 

class _Timestamp(Timestamp):
    """``Timestamp`` with conversions to/from :class:`datetime.datetime` and JSON."""

    @classmethod
    def from_datetime(cls, dt: datetime) -> "_Timestamp":
        """Build a Timestamp from ``dt`` as an offset from ``DATETIME_ZERO``."""
        # manual epoch offset calculation to avoid rounding errors,
        # to support negative timestamps (before 1970) and skirt
        # around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
        offset = dt - DATETIME_ZERO
        # below is the same as timedelta.total_seconds() but without dividing by 1e6
        # so we end up with microseconds as integers instead of seconds as float
        offset_us = (
            offset.days * 24 * 60 * 60 + offset.seconds
        ) * 10**6 + offset.microseconds
        # divmod floors, so `us` is always non-negative (nanos in [0, 1e9)).
        seconds, us = divmod(offset_us, 10**6)
        return cls(seconds, us * 1000)

    def to_datetime(self) -> datetime:
        """Convert to a datetime relative to ``DATETIME_ZERO`` (microsecond resolution)."""
        # datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
        # if we pass it as a floating point number, we will run into rounding errors
        # see also #407
        offset = timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
        return DATETIME_ZERO + offset

    @staticmethod
    def timestamp_to_json(dt: datetime) -> str:
        """
        Format ``dt`` as a protobuf JSON timestamp, e.g.
        ``"1970-01-01T00:00:00.500Z"``.

        Aware datetimes are converted to UTC first. Naive datetimes are
        formatted as-is with a ``Z`` suffix. The fraction is omitted when zero
        and otherwise uses 3, 6 or 9 digits.
        """
        # Integer nanoseconds: keeps the arithmetic exact and makes the
        # 9-digit branch format correctly (a float with `:09d` raises).
        nanos = dt.microsecond * 1000
        if dt.tzinfo is not None:
            # change timezone aware datetime objects to utc
            dt = dt.astimezone(timezone.utc)
        result = dt.replace(microsecond=0, tzinfo=None).isoformat()
        if nanos == 0:
            # If there are 0 fractional digits, the fractional
            # point '.' should be omitted when serializing.
            return f"{result}Z"
        if nanos % 10**6 == 0:
            # Serialize 3 fractional digits.
            return f"{result}.{nanos // 10**6:03d}Z"
        if nanos % 10**3 == 0:
            # Serialize 6 fractional digits.
            return f"{result}.{nanos // 10**3:06d}Z"
        # Serialize 9 fractional digits.
        return f"{result}.{nanos:09d}Z"

2008 

2009 

# Wrapper message class for each proto type that has a well-known wrapper.
# Built once at import time instead of on every _get_wrapper() call.
# TODO: include ListValue and NullValue?
_PROTO_TYPE_TO_WRAPPER: Dict[str, Type] = {
    TYPE_BOOL: BoolValue,
    TYPE_BYTES: BytesValue,
    TYPE_DOUBLE: DoubleValue,
    TYPE_FLOAT: FloatValue,
    TYPE_ENUM: EnumValue,
    TYPE_INT32: Int32Value,
    TYPE_INT64: Int64Value,
    TYPE_STRING: StringValue,
    TYPE_UINT32: UInt32Value,
    TYPE_UINT64: UInt64Value,
}


def _get_wrapper(proto_type: str) -> Type:
    """
    Get the wrapper message class for a wrapped type.

    Raises
    ------
    KeyError
        If ``proto_type`` has no wrapper message class.
    """
    return _PROTO_TYPE_TO_WRAPPER[proto_type]