1from __future__ import annotations
2
3import dataclasses
4import enum as builtin_enum
5import json
6import math
7import struct
8import sys
9import typing
10import warnings
11from abc import ABC
12from base64 import (
13 b64decode,
14 b64encode,
15)
16from copy import deepcopy
17from datetime import (
18 datetime,
19 timedelta,
20 timezone,
21)
22from io import BytesIO
23from itertools import count
24from typing import (
25 TYPE_CHECKING,
26 Any,
27 Callable,
28 ClassVar,
29 Dict,
30 Generator,
31 Iterable,
32 Mapping,
33 Optional,
34 Set,
35 Tuple,
36 Type,
37 Union,
38 get_type_hints,
39)
40
41from dateutil.parser import isoparse
42from typing_extensions import Self
43
44from ._types import T
45from ._version import __version__
46from .casing import (
47 camel_case,
48 safe_snake_case,
49 snake_case,
50)
51from .enum import Enum as Enum
52from .grpc.grpclib_client import ServiceStub as ServiceStub
53from .utils import (
54 classproperty,
55 hybridmethod,
56)
57
58
59if TYPE_CHECKING:
60 from _typeshed import (
61 SupportsRead,
62 SupportsWrite,
63 )
64
# `types.UnionType` is imported only on Python 3.10 and newer. Older
# interpreters get an empty stand-in class under the same name so that
# `_types_UnionType` can be referenced unconditionally.
if sys.version_info >= (3, 10):
    from types import UnionType as _types_UnionType
else:

    class _types_UnionType:
        ...
71
72
# Proto 3 data types
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
TYPE_INT32 = "int32"
TYPE_INT64 = "int64"
TYPE_UINT32 = "uint32"
TYPE_UINT64 = "uint64"
TYPE_SINT32 = "sint32"
TYPE_SINT64 = "sint64"
TYPE_FLOAT = "float"
TYPE_DOUBLE = "double"
TYPE_FIXED32 = "fixed32"
TYPE_SFIXED32 = "sfixed32"
TYPE_FIXED64 = "fixed64"
TYPE_SFIXED64 = "sfixed64"
TYPE_STRING = "string"
TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message"
TYPE_MAP = "map"

# Fields that use a fixed amount of space (4 or 8 bytes)
FIXED_TYPES = [
    TYPE_FLOAT,
    TYPE_DOUBLE,
    TYPE_FIXED32,
    TYPE_SFIXED32,
    TYPE_FIXED64,
    TYPE_SFIXED64,
]

# Fields that are numerical 64-bit types
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]

# Fields that are efficiently packed when repeated: a list of one of these
# types is written as a single length-delimited field (see `Message.dump`).
PACKED_TYPES = [
    TYPE_ENUM,
    TYPE_BOOL,
    TYPE_INT32,
    TYPE_INT64,
    TYPE_UINT32,
    TYPE_UINT64,
    TYPE_SINT32,
    TYPE_SINT64,
    TYPE_FLOAT,
    TYPE_DOUBLE,
    TYPE_FIXED32,
    TYPE_SFIXED32,
    TYPE_FIXED64,
    TYPE_SFIXED64,
]

# Wire types (the low three bits of every field key)
# https://developers.google.com/protocol-buffers/docs/encoding#structure
WIRE_VARINT = 0
WIRE_FIXED_64 = 1
WIRE_LEN_DELIM = 2
WIRE_FIXED_32 = 5

# Mappings of which Proto 3 types correspond to which wire types.
WIRE_VARINT_TYPES = [
    TYPE_ENUM,
    TYPE_BOOL,
    TYPE_INT32,
    TYPE_INT64,
    TYPE_UINT32,
    TYPE_UINT64,
    TYPE_SINT32,
    TYPE_SINT64,
]

WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]

# Indicator of message delimitation in streams: `Message.dump` compares its
# `delimit` argument against this value to decide whether to write a size prefix.
SIZE_DELIMITED = -1
149
150
def datetime_default_gen() -> datetime:
    """Return a fresh timezone-aware datetime for the Unix epoch in UTC.

    Protobuf timestamps count from 1970-01-01T00:00:00Z, so this value is the
    zero point for datetime fields.
    """
    epoch = datetime(year=1970, month=1, day=1, tzinfo=timezone.utc)
    return epoch


# Module-level instance of the zero datetime.
DATETIME_ZERO = datetime_default_gen()
157
158
# Special protobuf json doubles: the string tokens that `_dump_float` emits and
# `_parse_float` accepts for the non-finite float values.
INFINITY = "Infinity"
NEG_INFINITY = "-Infinity"
NAN = "NaN"
163
164
class Casing(builtin_enum.Enum):
    """Casing constants for serialization."""

    CAMEL = camel_case  #: A camelCase serialization function.
    SNAKE = snake_case  #: A snake_case serialization function.
170
171
# Unique sentinel used as the dataclass default for non-optional fields (see
# `dataclass_field`); it marks a field that has not been given a value yet.
PLACEHOLDER: Any = object()
173
174
@dataclasses.dataclass(frozen=True)
class FieldMetadata:
    """Immutable protobuf description attached to a single dataclass field.

    An instance is stored under the ``"betterproto"`` key of the field's
    ``metadata`` mapping and is consulted when parsing and serializing.
    """

    number: int  # protobuf field number
    proto_type: str  # protobuf type name
    map_types: Optional[Tuple[str, str]] = None  # (key, value) types of a map
    group: Optional[str] = None  # name of the "one-of" group, if any
    wraps: Optional[str] = None  # wrapped type, e.g. google.protobuf.BoolValue
    optional: Optional[bool] = False  # whether the field is optional

    @staticmethod
    def get(field: dataclasses.Field) -> "FieldMetadata":
        """Look up the metadata stored on ``field``.

        Raises ``KeyError`` when the field carries no ``"betterproto"`` entry.
        """
        field_metadata = field.metadata
        return field_metadata["betterproto"]
196
197
def dataclass_field(
    number: int,
    proto_type: str,
    *,
    map_types: Optional[Tuple[str, str]] = None,
    group: Optional[str] = None,
    wraps: Optional[str] = None,
    optional: bool = False,
) -> dataclasses.Field:
    """Build a dataclass field that carries protobuf metadata.

    Optional fields default to ``None``; every other field defaults to the
    ``PLACEHOLDER`` sentinel. The metadata is stored under the ``"betterproto"``
    key of the field's metadata mapping.
    """
    proto_meta = FieldMetadata(number, proto_type, map_types, group, wraps, optional)
    if optional:
        default_value = None
    else:
        default_value = PLACEHOLDER
    return dataclasses.field(
        default=default_value,
        metadata={"betterproto": proto_meta},
    )
216
217
218# Note: the fields below return `Any` to prevent type errors in the generated
219# data classes since the types won't match with `Field` and they get swapped
220# out at runtime. The generated dataclass variables are still typed correctly.
221
222
def enum_field(
    number: int,
    group: Optional[str] = None,
    optional: bool = False,
) -> Any:
    """Create the dataclass field for a protobuf ``enum`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_ENUM, group=group, optional=optional
    )
225
226
def bool_field(
    number: int,
    group: Optional[str] = None,
    optional: bool = False,
) -> Any:
    """Create the dataclass field for a protobuf ``bool`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_BOOL, group=group, optional=optional
    )
229
230
def int32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``int32`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_INT32, group=group, optional=optional
    )
235
236
def int64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``int64`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_INT64, group=group, optional=optional
    )
241
242
def uint32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``uint32`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_UINT32, group=group, optional=optional
    )
247
248
def uint64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``uint64`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_UINT64, group=group, optional=optional
    )
253
254
def sint32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``sint32`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_SINT32, group=group, optional=optional
    )
259
260
def sint64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``sint64`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_SINT64, group=group, optional=optional
    )
265
266
def float_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``float`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_FLOAT, group=group, optional=optional
    )
271
272
def double_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``double`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_DOUBLE, group=group, optional=optional
    )
277
278
def fixed32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``fixed32`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_FIXED32, group=group, optional=optional
    )
283
284
def fixed64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``fixed64`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_FIXED64, group=group, optional=optional
    )
289
290
def sfixed32_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``sfixed32`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_SFIXED32, group=group, optional=optional
    )
295
296
def sfixed64_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``sfixed64`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_SFIXED64, group=group, optional=optional
    )
301
302
def string_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``string`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_STRING, group=group, optional=optional
    )
307
308
def bytes_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Create the dataclass field for a protobuf ``bytes`` value."""
    return dataclass_field(
        number=number, proto_type=TYPE_BYTES, group=group, optional=optional
    )
313
314
def message_field(
    number: int,
    group: Optional[str] = None,
    wraps: Optional[str] = None,
    optional: bool = False,
) -> Any:
    """Create the dataclass field for a nested protobuf message.

    ``wraps`` names the wrapped type when the message is a wrapper type.
    """
    return dataclass_field(
        number=number,
        proto_type=TYPE_MESSAGE,
        group=group,
        wraps=wraps,
        optional=optional,
    )
324
325
def map_field(
    number: int,
    key_type: str,
    value_type: str,
    group: Optional[str] = None,
) -> Any:
    """Create the dataclass field for a protobuf ``map``.

    The key and value proto types are recorded together as ``map_types``.
    """
    entry_types = (key_type, value_type)
    return dataclass_field(
        number=number, proto_type=TYPE_MAP, map_types=entry_types, group=group
    )
332
333
def _pack_fmt(proto_type: str) -> str:
    """Return the little-endian ``struct`` format for a fixed-width proto type.

    Raises ``KeyError`` for any type other than the six fixed-width ones.
    """
    formats = {
        TYPE_DOUBLE: "<d",
        TYPE_FLOAT: "<f",
        TYPE_FIXED32: "<I",
        TYPE_FIXED64: "<Q",
        TYPE_SFIXED32: "<i",
        TYPE_SFIXED64: "<q",
    }
    return formats[proto_type]
344
345
def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None:
    """Write ``value`` to ``stream`` as a varint, one byte per ``write`` call.

    Negative values are written as their 64-bit two's complement. Values below
    ``-(1 << 63)`` raise ``ValueError``.
    """
    if value < -(1 << 63):
        raise ValueError(
            "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes."
        )
    if value < 0:
        # Reinterpret as an unsigned 64-bit integer.
        value += 1 << 64

    while True:
        low_bits = value & 0x7F
        value >>= 7
        if not value:
            # Final group: continuation bit stays clear.
            stream.write(bytes((low_bits,)))
            return
        stream.write(bytes((0x80 | low_bits,)))
362
363
def encode_varint(value: int) -> bytes:
    """Return the varint encoding of ``value`` as a bytes object."""
    buffer = BytesIO()
    try:
        dump_varint(value, buffer)
        return buffer.getvalue()
    finally:
        buffer.close()
369
370
def size_varint(value: int) -> int:
    """Return how many bytes ``value`` occupies when encoded as a varint.

    Negative values always take 10 bytes. Values below ``-(1 << 63)`` raise
    ``ValueError``.
    """
    if value < -(1 << 63):
        raise ValueError(
            "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes."
        )
    if value < 0:
        return 10
    if value == 0:
        return 1
    # Each varint byte carries 7 payload bits; round the bit count up.
    return (value.bit_length() + 6) // 7
383
384
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
    """Convert one value to the payload bytes written after its field key.

    Varint types are varint-encoded (signed ``sint`` types zig-zag first),
    fixed-width types are struct-packed, strings are UTF-8 encoded and messages
    are serialized with ``bytes()``. Values of any other type are returned
    unchanged.
    """
    if proto_type in (
        TYPE_ENUM,
        TYPE_BOOL,
        TYPE_INT32,
        TYPE_INT64,
        TYPE_UINT32,
        TYPE_UINT64,
    ):
        return encode_varint(value)

    if proto_type in (TYPE_SINT32, TYPE_SINT64):
        # Zig-zag encoding maps signed integers onto unsigned ones.
        if value >= 0:
            zigzag = value << 1
        else:
            zigzag = (value << 1) ^ (~0)
        return encode_varint(zigzag)

    if proto_type in FIXED_TYPES:
        return struct.pack(_pack_fmt(proto_type), value)

    if proto_type == TYPE_STRING:
        return value.encode("utf-8")

    if proto_type != TYPE_MESSAGE:
        return value

    if isinstance(value, datetime):
        # A `datetime` is sent as a timestamp message.
        message = _Timestamp.from_datetime(value)
    elif isinstance(value, timedelta):
        # A `timedelta` is sent as a duration message.
        message = _Duration.from_timedelta(value)
    elif wraps:
        if value is None:
            return b""
        message = _get_wrapper(wraps)(value=value)
    else:
        message = value

    return bytes(message)
418
419
def _len_preprocessed_single(proto_type: str, wraps: str, value: Any) -> int:
    """Return the payload size in bytes that ``_preprocess_single`` would produce.

    Varint sizes are computed arithmetically; the other types are measured
    from their encoded form.
    """
    if proto_type in (
        TYPE_ENUM,
        TYPE_BOOL,
        TYPE_INT32,
        TYPE_INT64,
        TYPE_UINT32,
        TYPE_UINT64,
    ):
        return size_varint(value)

    if proto_type in (TYPE_SINT32, TYPE_SINT64):
        # Zig-zag encoding maps signed integers onto unsigned ones.
        if value >= 0:
            zigzag = value << 1
        else:
            zigzag = (value << 1) ^ (~0)
        return size_varint(zigzag)

    if proto_type in FIXED_TYPES:
        packed = struct.pack(_pack_fmt(proto_type), value)
        return len(packed)

    if proto_type == TYPE_STRING:
        encoded = value.encode("utf-8")
        return len(encoded)

    if proto_type != TYPE_MESSAGE:
        return len(value)

    if isinstance(value, datetime):
        # A `datetime` is sent as a timestamp message.
        message = _Timestamp.from_datetime(value)
    elif isinstance(value, timedelta):
        # A `timedelta` is sent as a duration message.
        message = _Duration.from_timedelta(value)
    elif wraps:
        if value is None:
            return 0
        message = _get_wrapper(wraps)(value=value)
    else:
        message = value

    return len(bytes(message))
453
454
def _serialize_single(
    field_number: int,
    proto_type: str,
    value: Any,
    *,
    serialize_empty: bool = False,
    wraps: str = "",
) -> bytes:
    """Encode one field as its key followed by its payload.

    Length-delimited fields with an empty payload produce ``b""`` unless
    ``serialize_empty`` or ``wraps`` is set. Proto types that map to no known
    wire type raise ``NotImplementedError``.
    """
    payload = _preprocess_single(proto_type, wraps, value)
    tag_base = field_number << 3

    if proto_type in WIRE_VARINT_TYPES:
        pieces = [encode_varint(tag_base), payload]
    elif proto_type in WIRE_FIXED_32_TYPES:
        pieces = [encode_varint(tag_base | 5), payload]
    elif proto_type in WIRE_FIXED_64_TYPES:
        pieces = [encode_varint(tag_base | 1), payload]
    elif proto_type in WIRE_LEN_DELIM_TYPES:
        if len(payload) or serialize_empty or wraps:
            pieces = [
                encode_varint(tag_base | 2),
                encode_varint(len(payload)),
                payload,
            ]
        else:
            pieces = []
    else:
        raise NotImplementedError(proto_type)

    return b"".join(pieces)
484
485
def _len_single(
    field_number: int,
    proto_type: str,
    value: Any,
    *,
    serialize_empty: bool = False,
    wraps: str = "",
) -> int:
    """Return the number of bytes ``_serialize_single`` would emit for a field.

    Mirrors the serializer: an empty length-delimited payload counts as zero
    bytes unless ``serialize_empty`` or ``wraps`` is set, and unknown proto
    types raise ``NotImplementedError``.
    """
    payload_size = _len_preprocessed_single(proto_type, wraps, value)
    tag_base = field_number << 3

    if proto_type in WIRE_LEN_DELIM_TYPES:
        if not (payload_size or serialize_empty or wraps):
            return payload_size
        key_size = size_varint(tag_base | 2)
        length_size = size_varint(payload_size)
        return payload_size + key_size + length_size

    if proto_type in WIRE_VARINT_TYPES:
        key = tag_base
    elif proto_type in WIRE_FIXED_32_TYPES:
        key = tag_base | 5
    elif proto_type in WIRE_FIXED_64_TYPES:
        key = tag_base | 1
    else:
        raise NotImplementedError(proto_type)

    return payload_size + size_varint(key)
509
510
def _parse_float(value: Any) -> float:
    """Convert a JSON value to a float, honouring the special string tokens.

    Parameters
    ----------
    value: Any
        Either one of the ``INFINITY`` / ``NEG_INFINITY`` / ``NAN`` tokens or
        anything ``float()`` accepts.

    Returns
    -------
    float
        The parsed number.
    """
    if value == INFINITY:
        return math.inf
    if value == NEG_INFINITY:
        return -math.inf
    if value == NAN:
        return math.nan
    parsed = float(value)
    return parsed
531
532
def _dump_float(value: float) -> Union[float, str]:
    """Prepare a float for JSON output.

    Parameters
    ----------
    value: float
        Number to convert.

    Returns
    -------
    Union[float, str]
        The ``INFINITY``, ``NEG_INFINITY`` or ``NAN`` token for the matching
        non-finite value; otherwise ``value`` itself, unchanged.
    """
    if value == math.inf:
        return INFINITY
    if value == -math.inf:
        return NEG_INFINITY
    is_nan = isinstance(value, float) and math.isnan(value)
    return NAN if is_nan else value
553
554
def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]:
    """
    Read one varint from ``stream``, one byte at a time.

    Returns the decoded integer together with the exact bytes consumed. Raises
    ``EOFError`` if the stream ends before the varint is complete and
    ``ValueError`` if the tenth byte still has its continuation bit set.
    """
    decoded = 0
    consumed = []
    # Shifts 0, 7, ..., 63 cover the ten bytes a 64-bit varint may use.
    for shift in range(0, 64, 7):
        chunk = stream.read(1)
        if not chunk:
            raise EOFError("Stream ended unexpectedly while attempting to load varint.")
        consumed.append(chunk)
        octet = int.from_bytes(chunk, byteorder="little")
        decoded |= (octet & 0x7F) << shift
        if not octet & 0x80:
            return decoded, b"".join(consumed)
    raise ValueError("Too many bytes when decoding varint.")
572
573
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
    """
    Decode a single varint value from a byte buffer. Returns the value and the
    new position in the buffer.

    The buffer is indexed in place rather than wrapped in a stream, so decoding
    allocates no intermediate objects. Raises ``EOFError`` if the buffer ends
    before the varint is complete, and ``ValueError`` if ``pos`` is negative or
    the tenth byte still has its continuation bit set.
    """
    if pos < 0:
        # A negative index would silently wrap around to the end of the buffer.
        raise ValueError(f"negative seek value {pos}")

    result = 0
    end = len(buffer)
    # Shifts 0, 7, ..., 63 cover the ten bytes a 64-bit varint may use.
    for shift in range(0, 64, 7):
        if pos >= end:
            raise EOFError("Stream ended unexpectedly while attempting to load varint.")
        byte = buffer[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
    raise ValueError("Too many bytes when decoding varint.")
583
584
@dataclasses.dataclass(frozen=True)
class ParsedField:
    """One field read off the wire by ``load_fields`` or ``parse_fields``."""

    # Field number: the key varint shifted right by three bits.
    number: int
    # Wire type: the low three bits of the key varint.
    wire_type: int
    # Decoded varint (an int) or the payload bytes for fixed-width and
    # length-delimited fields; None when the wire type is not one of those.
    value: Any
    # The bytes consumed for this field, including the key.
    raw: bytes
591
592
def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]:
    """Yield every field read from ``stream`` until it is exhausted.

    Iteration stops when the stream ends at a field boundary (the key varint
    cannot be read). Fields with an unrecognised wire type are yielded with a
    ``None`` value and only their key bytes as ``raw``.
    """
    while True:
        try:
            key, consumed = load_varint(stream)
        except EOFError:
            return

        wire_type = key & 0x7
        payload: Any = None

        if wire_type == WIRE_VARINT:
            payload, extra = load_varint(stream)
            consumed += extra
        elif wire_type == WIRE_LEN_DELIM:
            length, extra = load_varint(stream)
            payload = stream.read(length)
            consumed += extra
            consumed += payload
        elif wire_type in (WIRE_FIXED_64, WIRE_FIXED_32):
            width = 8 if wire_type == WIRE_FIXED_64 else 4
            payload = stream.read(width)
            consumed += payload

        yield ParsedField(
            number=key >> 3,
            wire_type=wire_type,
            value=payload,
            raw=consumed,
        )
619
620
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
    """Yield every field found in the in-memory buffer ``value``.

    Each yielded ``ParsedField`` carries the slice of ``value`` it was decoded
    from as ``raw``. Fields with an unrecognised wire type are yielded with a
    ``None`` value and only their key bytes consumed.
    """
    cursor = 0
    while cursor < len(value):
        field_start = cursor
        key, cursor = decode_varint(value, cursor)
        wire_type = key & 0x7

        payload: Any = None
        if wire_type == WIRE_VARINT:
            payload, cursor = decode_varint(value, cursor)
        elif wire_type == WIRE_LEN_DELIM:
            length, cursor = decode_varint(value, cursor)
            payload = value[cursor : cursor + length]
            cursor += length
        elif wire_type == WIRE_FIXED_64:
            payload = value[cursor : cursor + 8]
            cursor += 8
        elif wire_type == WIRE_FIXED_32:
            payload = value[cursor : cursor + 4]
            cursor += 4

        yield ParsedField(
            number=key >> 3,
            wire_type=wire_type,
            value=payload,
            raw=value[field_start:cursor],
        )
644
645
class ProtoClassMetadata:
    """Lookup tables derived from the dataclass fields of one message class.

    The tables map field names to their one-of group, metadata, default
    generator and Python class, and field numbers back to field names.
    """

    __slots__ = (
        "oneof_group_by_field",
        "oneof_field_by_group",
        "default_gen",
        "cls_by_field",
        "field_name_by_number",
        "meta_by_field_name",
        "sorted_field_names",
    )

    oneof_group_by_field: Dict[str, str]
    oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
    field_name_by_number: Dict[int, str]
    meta_by_field_name: Dict[str, FieldMetadata]
    sorted_field_names: Tuple[str, ...]
    default_gen: Dict[str, Callable[[], Any]]
    cls_by_field: Dict[str, Type]

    def __init__(self, cls: Type["Message"]):
        """Index every dataclass field of ``cls``."""
        fields = dataclasses.fields(cls)

        group_by_field = {}
        fields_by_group: Dict[str, Set] = {}
        meta_by_name = {}
        name_by_number = {}

        for field in fields:
            meta = FieldMetadata.get(field)
            name = field.name

            if meta.group:
                # The field belongs to a one-of group: index it both ways.
                group_by_field[name] = meta.group
                if meta.group not in fields_by_group:
                    fields_by_group[meta.group] = set()
                fields_by_group[meta.group].add(field)

            meta_by_name[name] = meta
            name_by_number[meta.number] = name

        ordered_numbers = sorted(name_by_number)

        self.oneof_group_by_field = group_by_field
        self.oneof_field_by_group = fields_by_group
        self.field_name_by_number = name_by_number
        self.meta_by_field_name = meta_by_name
        self.sorted_field_names = tuple(
            name_by_number[field_number] for field_number in ordered_numbers
        )
        self.default_gen = self._get_default_gen(cls, fields)
        self.cls_by_field = self._get_cls_by_field(cls, fields)

    @staticmethod
    def _get_default_gen(
        cls: Type["Message"], fields: Iterable[dataclasses.Field]
    ) -> Dict[str, Callable[[], Any]]:
        """Map each field name to the default generator ``cls`` provides for it."""
        generators = {}
        for field in fields:
            generators[field.name] = cls._get_field_default_gen(field)
        return generators

    @staticmethod
    def _get_cls_by_field(
        cls: Type["Message"], fields: Iterable[dataclasses.Field]
    ) -> Dict[str, Type]:
        """Map each field name to the class ``cls._cls_for`` resolves for it.

        Map fields get a generated ``Entry`` message class with ``key`` and
        ``value`` fields, plus an extra ``"<name>.value"`` entry holding the
        value class.
        """
        resolved = {}

        for field in fields:
            meta = FieldMetadata.get(field)
            if meta.proto_type != TYPE_MAP:
                resolved[field.name] = cls._cls_for(field)
                continue

            assert meta.map_types
            key_cls = cls._cls_for(field, index=0)
            value_cls = cls._cls_for(field, index=1)
            entry_fields = [
                ("key", key_cls, dataclass_field(1, meta.map_types[0])),
                ("value", value_cls, dataclass_field(2, meta.map_types[1])),
            ]
            resolved[field.name] = dataclasses.make_dataclass(
                "Entry",
                entry_fields,
                bases=(Message,),
            )
            resolved[f"{field.name}.value"] = value_cls

        return resolved
725
726
727class Message(ABC):
728 """
729 The base class for protobuf messages, all generated messages will inherit from
730 this. This class registers the message fields which are used by the serializers and
731 parsers to go between the Python, binary and JSON representations of the message.
732
733 .. container:: operations
734
735 .. describe:: bytes(x)
736
737 Calls :meth:`__bytes__`.
738
739 .. describe:: bool(x)
740
741 Calls :meth:`__bool__`.
742 """
743
744 _serialized_on_wire: bool
745 _unknown_fields: bytes
746 _group_current: Dict[str, str]
747 _betterproto_meta: ClassVar[ProtoClassMetadata]
748
749 def __post_init__(self) -> None:
750 # Keep track of whether every field was default
751 all_sentinel = True
752
753 # Set current field of each group after `__init__` has already been run.
754 group_current: Dict[str, Optional[str]] = {}
755 for field_name, meta in self._betterproto.meta_by_field_name.items():
756 if meta.group:
757 group_current.setdefault(meta.group)
758
759 value = self.__raw_get(field_name)
760 if value is not PLACEHOLDER and not (meta.optional and value is None):
761 # Found a non-sentinel value
762 all_sentinel = False
763
764 if meta.group:
765 # This was set, so make it the selected value of the one-of.
766 group_current[meta.group] = field_name
767
768 # Now that all the defaults are set, reset it!
769 self.__dict__["_serialized_on_wire"] = not all_sentinel
770 self.__dict__["_unknown_fields"] = b""
771 self.__dict__["_group_current"] = group_current
772
773 def __raw_get(self, name: str) -> Any:
774 return super().__getattribute__(name)
775
776 def __eq__(self, other) -> bool:
777 if type(self) is not type(other):
778 return NotImplemented
779
780 for field_name in self._betterproto.meta_by_field_name:
781 self_val = self.__raw_get(field_name)
782 other_val = other.__raw_get(field_name)
783 if self_val is PLACEHOLDER:
784 if other_val is PLACEHOLDER:
785 continue
786 self_val = self._get_field_default(field_name)
787 elif other_val is PLACEHOLDER:
788 other_val = other._get_field_default(field_name)
789
790 if self_val != other_val:
791 # We consider two nan values to be the same for the
792 # purposes of comparing messages (otherwise a message
793 # is not equal to itself)
794 if (
795 isinstance(self_val, float)
796 and isinstance(other_val, float)
797 and math.isnan(self_val)
798 and math.isnan(other_val)
799 ):
800 continue
801 else:
802 return False
803
804 return True
805
806 def __repr__(self) -> str:
807 parts = [
808 f"{field_name}={value!r}"
809 for field_name in self._betterproto.sorted_field_names
810 for value in (self.__raw_get(field_name),)
811 if value is not PLACEHOLDER
812 ]
813 return f"{self.__class__.__name__}({', '.join(parts)})"
814
815 def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
816 for field_name in self._betterproto.sorted_field_names:
817 yield field_name, self.__raw_get(field_name), PLACEHOLDER
818
    # Defined only when not type checking. NOTE(review): presumably so that
    # static type checkers keep resolving the annotated fields normally — confirm.
    if not TYPE_CHECKING:

        def __getattribute__(self, name: str) -> Any:
            """
            Lazily initialize default values to avoid infinite recursion for recursive
            message types.
            Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields.
            """
            try:
                # Looked up through super() so this method is not re-entered.
                group_current = super().__getattribute__("_group_current")
            except AttributeError:
                # `_group_current` is only assigned at the end of
                # `__post_init__`; until then there is nothing to check.
                pass
            else:
                # `__class__` and `_betterproto` are read while evaluating the
                # check itself, so they are exempt from it.
                if name not in {"__class__", "_betterproto"}:
                    group = self._betterproto.oneof_group_by_field.get(name)
                    # Reject reads of a oneof member that is not the one
                    # currently selected for its group.
                    if group is not None and group_current[group] != name:
                        if sys.version_info < (3, 10):
                            raise AttributeError(
                                f"{group!r} is set to {group_current[group]!r}, not {name!r}"
                            )
                        else:
                            # On 3.10+ the `name` and `obj` keyword arguments
                            # are also passed to AttributeError.
                            raise AttributeError(
                                f"{group!r} is set to {group_current[group]!r}, not {name!r}",
                                name=name,
                                obj=self,
                            )

            value = super().__getattribute__(name)
            if value is not PLACEHOLDER:
                return value

            # First read of an unset field: build its default and store it so
            # later reads return the same object.
            value = self._get_field_default(name)
            super().__setattr__(name, value)
            return value
853
854 def __setattr__(self, attr: str, value: Any) -> None:
855 if (
856 isinstance(value, Message)
857 and hasattr(value, "_betterproto")
858 and not value._betterproto.meta_by_field_name
859 ):
860 value._serialized_on_wire = True
861
862 if attr != "_serialized_on_wire":
863 # Track when a field has been set.
864 self.__dict__["_serialized_on_wire"] = True
865
866 if hasattr(self, "_group_current"): # __post_init__ had already run
867 if attr in self._betterproto.oneof_group_by_field:
868 group = self._betterproto.oneof_group_by_field[attr]
869 for field in self._betterproto.oneof_field_by_group[group]:
870 if field.name == attr:
871 self._group_current[group] = field.name
872 else:
873 super().__setattr__(field.name, PLACEHOLDER)
874
875 super().__setattr__(attr, value)
876
877 def __bool__(self) -> bool:
878 """True if the Message has any fields with non-default values."""
879 return any(
880 self.__raw_get(field_name)
881 not in (PLACEHOLDER, self._get_field_default(field_name))
882 for field_name in self._betterproto.meta_by_field_name
883 )
884
885 def __deepcopy__(self: T, _: Any = {}) -> T:
886 kwargs = {}
887 for name in self._betterproto.sorted_field_names:
888 value = self.__raw_get(name)
889 if value is not PLACEHOLDER:
890 kwargs[name] = deepcopy(value)
891 return self.__class__(**kwargs) # type: ignore
892
893 def __copy__(self: T, _: Any = {}) -> T:
894 kwargs = {}
895 for name in self._betterproto.sorted_field_names:
896 value = self.__raw_get(name)
897 if value is not PLACEHOLDER:
898 kwargs[name] = value
899 return self.__class__(**kwargs) # type: ignore
900
901 @classproperty
902 def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
903 """
904 Lazy initialize metadata for each protobuf class.
905 It may be initialized multiple times in a multi-threaded environment,
906 but that won't affect the correctness.
907 """
908 try:
909 return cls._betterproto_meta
910 except AttributeError:
911 cls._betterproto_meta = meta = ProtoClassMetadata(cls)
912 return meta
913
914 def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
915 """
916 Dumps the binary encoded Protobuf message to the stream.
917
918 Parameters
919 -----------
920 stream: :class:`BinaryIO`
921 The stream to dump the message to.
922 delimit:
923 Whether to prefix the message with a varint declaring its size.
924 """
925 if delimit == SIZE_DELIMITED:
926 dump_varint(len(self), stream)
927
928 for field_name, meta in self._betterproto.meta_by_field_name.items():
929 try:
930 value = getattr(self, field_name)
931 except AttributeError:
932 continue
933
934 if value is None:
935 # Optional items should be skipped. This is used for the Google
936 # wrapper types and proto3 field presence/optional fields.
937 continue
938
939 # Being selected in a a group means this field is the one that is
940 # currently set in a `oneof` group, so it must be serialized even
941 # if the value is the default zero value.
942 #
943 # Note that proto3 field presence/optional fields are put in a
944 # synthetic single-item oneof by protoc, which helps us ensure we
945 # send the value even if the value is the default zero value.
946 selected_in_group = bool(meta.group) or meta.optional
947
948 # Empty messages can still be sent on the wire if they were
949 # set (or received empty).
950 serialize_empty = isinstance(value, Message) and value._serialized_on_wire
951
952 include_default_value_for_oneof = self._include_default_value_for_oneof(
953 field_name=field_name, meta=meta
954 )
955
956 if value == self._get_field_default(field_name) and not (
957 selected_in_group or serialize_empty or include_default_value_for_oneof
958 ):
959 # Default (zero) values are not serialized. Two exceptions are
960 # if this is the selected oneof item or if we know we have to
961 # serialize an empty message (i.e. zero value was explicitly
962 # set by the user).
963 continue
964
965 if isinstance(value, list):
966 if meta.proto_type in PACKED_TYPES:
967 # Packed lists look like a length-delimited field. First,
968 # preprocess/encode each value into a buffer and then
969 # treat it like a field of raw bytes.
970 buf = bytearray()
971 for item in value:
972 buf += _preprocess_single(meta.proto_type, "", item)
973 stream.write(_serialize_single(meta.number, TYPE_BYTES, buf))
974 else:
975 for item in value:
976 stream.write(
977 _serialize_single(
978 meta.number,
979 meta.proto_type,
980 item,
981 wraps=meta.wraps or "",
982 serialize_empty=True,
983 )
984 # if it's an empty message it still needs to be represented
985 # as an item in the repeated list
986 or b"\n\x00"
987 )
988
989 elif isinstance(value, dict):
990 for k, v in value.items():
991 assert meta.map_types
992 sk = _serialize_single(1, meta.map_types[0], k)
993 sv = _serialize_single(2, meta.map_types[1], v)
994 stream.write(
995 _serialize_single(meta.number, meta.proto_type, sk + sv)
996 )
997 else:
998 # If we have an empty string and we're including the default value for
999 # a oneof, make sure we serialize it. This ensures that the byte string
1000 # output isn't simply an empty string. This also ensures that round trip
1001 # serialization will keep `which_one_of` calls consistent.
1002 if (
1003 isinstance(value, str)
1004 and value == ""
1005 and include_default_value_for_oneof
1006 ):
1007 serialize_empty = True
1008
1009 stream.write(
1010 _serialize_single(
1011 meta.number,
1012 meta.proto_type,
1013 value,
1014 serialize_empty=serialize_empty or bool(selected_in_group),
1015 wraps=meta.wraps or "",
1016 )
1017 )
1018
1019 stream.write(self._unknown_fields)
1020
1021 def __bytes__(self) -> bytes:
1022 """
1023 Get the binary encoded Protobuf representation of this message instance.
1024 """
1025 with BytesIO() as stream:
1026 self.dump(stream)
1027 return stream.getvalue()
1028
1029 def __len__(self) -> int:
1030 """
1031 Get the size of the encoded Protobuf representation of this message instance.
1032 """
1033 size = 0
1034 for field_name, meta in self._betterproto.meta_by_field_name.items():
1035 try:
1036 value = getattr(self, field_name)
1037 except AttributeError:
1038 continue
1039
1040 if value is None:
1041 # Optional items should be skipped. This is used for the Google
1042 # wrapper types and proto3 field presence/optional fields.
1043 continue
1044
1045 # Being selected in a group means this field is the one that is
1046 # currently set in a `oneof` group, so it must be serialized even
1047 # if the value is the default zero value.
1048 #
1049 # Note that proto3 field presence/optional fields are put in a
1050 # synthetic single-item oneof by protoc, which helps us ensure we
1051 # send the value even if the value is the default zero value.
1052 selected_in_group = bool(meta.group)
1053
1054 # Empty messages can still be sent on the wire if they were
1055 # set (or received empty).
1056 serialize_empty = isinstance(value, Message) and value._serialized_on_wire
1057
1058 include_default_value_for_oneof = self._include_default_value_for_oneof(
1059 field_name=field_name, meta=meta
1060 )
1061
1062 if value == self._get_field_default(field_name) and not (
1063 selected_in_group or serialize_empty or include_default_value_for_oneof
1064 ):
1065 # Default (zero) values are not serialized. Two exceptions are
1066 # if this is the selected oneof item or if we know we have to
1067 # serialize an empty message (i.e. zero value was explicitly
1068 # set by the user).
1069 continue
1070
1071 if isinstance(value, list):
1072 if meta.proto_type in PACKED_TYPES:
1073 # Packed lists look like a length-delimited field. First,
1074 # preprocess/encode each value into a buffer and then
1075 # treat it like a field of raw bytes.
1076 buf = bytearray()
1077 for item in value:
1078 buf += _preprocess_single(meta.proto_type, "", item)
1079 size += _len_single(meta.number, TYPE_BYTES, buf)
1080 else:
1081 for item in value:
1082 size += (
1083 _len_single(
1084 meta.number,
1085 meta.proto_type,
1086 item,
1087 wraps=meta.wraps or "",
1088 serialize_empty=True,
1089 )
1090 # if it's an empty message it still needs to be represented
1091 # as an item in the repeated list
1092 or 2
1093 )
1094
1095 elif isinstance(value, dict):
1096 for k, v in value.items():
1097 assert meta.map_types
1098 sk = _serialize_single(1, meta.map_types[0], k)
1099 sv = _serialize_single(2, meta.map_types[1], v)
1100 size += _len_single(meta.number, meta.proto_type, sk + sv)
1101 else:
1102 # If we have an empty string and we're including the default value for
1103 # a oneof, make sure we serialize it. This ensures that the byte string
1104 # output isn't simply an empty string. This also ensures that round trip
1105 # serialization will keep `which_one_of` calls consistent.
1106 if (
1107 isinstance(value, str)
1108 and value == ""
1109 and include_default_value_for_oneof
1110 ):
1111 serialize_empty = True
1112
1113 size += _len_single(
1114 meta.number,
1115 meta.proto_type,
1116 value,
1117 serialize_empty=serialize_empty or bool(selected_in_group),
1118 wraps=meta.wraps or "",
1119 )
1120
1121 size += len(self._unknown_fields)
1122 return size
1123
1124 # For compatibility with other libraries
1125 def SerializeToString(self: T) -> bytes:
1126 """
1127 Get the binary encoded Protobuf representation of this message instance.
1128
1129 .. note::
1130 This is a method for compatibility with other libraries,
1131 you should really use ``bytes(x)``.
1132
1133 Returns
1134 --------
1135 :class:`bytes`
1136 The binary encoded Protobuf representation of this message instance
1137 """
1138 return bytes(self)
1139
1140 def __getstate__(self) -> bytes:
1141 return bytes(self)
1142
1143 def __setstate__(self: T, pickled_bytes: bytes) -> T:
1144 return self.parse(pickled_bytes)
1145
1146 def __reduce__(self) -> Tuple[Any, ...]:
1147 return (self.__class__.FromString, (bytes(self),))
1148
1149 @classmethod
1150 def _type_hint(cls, field_name: str) -> Type:
1151 return cls._type_hints()[field_name]
1152
1153 @classmethod
1154 def _type_hints(cls) -> Dict[str, Type]:
1155 module = sys.modules[cls.__module__]
1156 return get_type_hints(cls, module.__dict__, {})
1157
1158 @classmethod
1159 def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
1160 """Get the message class for a field from the type hints."""
1161 field_cls = cls._type_hint(field.name)
1162 if hasattr(field_cls, "__args__") and index >= 0:
1163 if field_cls.__args__ is not None:
1164 field_cls = field_cls.__args__[index]
1165 return field_cls
1166
1167 def _get_field_default(self, field_name: str) -> Any:
1168 with warnings.catch_warnings():
1169 # ignore warnings when initialising deprecated field defaults
1170 warnings.filterwarnings("ignore", category=DeprecationWarning)
1171 return self._betterproto.default_gen[field_name]()
1172
1173 @classmethod
1174 def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
1175 t = cls._type_hint(field.name)
1176
1177 is_310_union = isinstance(t, _types_UnionType)
1178 if hasattr(t, "__origin__") or is_310_union:
1179 if is_310_union or t.__origin__ is Union:
1180 # This is an optional field (either wrapped, or using proto3
1181 # field presence). For setting the default we really don't care
1182 # what kind of field it is.
1183 return type(None)
1184 if t.__origin__ is list:
1185 # This is some kind of list (repeated) field.
1186 return list
1187 if t.__origin__ is dict:
1188 # This is some kind of map (dict in Python).
1189 return dict
1190 return t
1191 if issubclass(t, Enum):
1192 # Enums always default to zero.
1193 return t.try_value
1194 if t is datetime:
1195 # Offsets are relative to 1970-01-01T00:00:00Z
1196 return datetime_default_gen
1197 # This is either a primitive scalar or another message type. Calling
1198 # it should result in its zero value.
1199 return t
1200
1201 def _postprocess_single(
1202 self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
1203 ) -> Any:
1204 """Adjusts values after parsing."""
1205 if wire_type == WIRE_VARINT:
1206 if meta.proto_type in (TYPE_INT32, TYPE_INT64):
1207 bits = int(meta.proto_type[3:])
1208 value = value & ((1 << bits) - 1)
1209 signbit = 1 << (bits - 1)
1210 value = int((value ^ signbit) - signbit)
1211 elif meta.proto_type in (TYPE_SINT32, TYPE_SINT64):
1212 # Undo zig-zag encoding
1213 value = (value >> 1) ^ (-(value & 1))
1214 elif meta.proto_type == TYPE_BOOL:
1215 # Booleans use a varint encoding, so convert it to true/false.
1216 value = value > 0
1217 elif meta.proto_type == TYPE_ENUM:
1218 # Convert enum ints to python enum instances
1219 value = self._betterproto.cls_by_field[field_name].try_value(value)
1220 elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
1221 fmt = _pack_fmt(meta.proto_type)
1222 value = struct.unpack(fmt, value)[0]
1223 elif wire_type == WIRE_LEN_DELIM:
1224 if meta.proto_type == TYPE_STRING:
1225 value = str(value, "utf-8")
1226 elif meta.proto_type == TYPE_MESSAGE:
1227 cls = self._betterproto.cls_by_field[field_name]
1228
1229 if cls == datetime:
1230 value = _Timestamp().parse(value).to_datetime()
1231 elif cls == timedelta:
1232 value = _Duration().parse(value).to_timedelta()
1233 elif meta.wraps:
1234 # This is a Google wrapper value message around a single
1235 # scalar type.
1236 value = _get_wrapper(meta.wraps)().parse(value).value
1237 else:
1238 value = cls().parse(value)
1239 value._serialized_on_wire = True
1240 elif meta.proto_type == TYPE_MAP:
1241 value = self._betterproto.cls_by_field[field_name]().parse(value)
1242
1243 return value
1244
1245 def _include_default_value_for_oneof(
1246 self, field_name: str, meta: FieldMetadata
1247 ) -> bool:
1248 return (
1249 meta.group is not None and self._group_current.get(meta.group) == field_name
1250 )
1251
1252 def load(
1253 self: T,
1254 stream: "SupportsRead[bytes]",
1255 size: Optional[int] = None,
1256 ) -> T:
1257 """
1258 Load the binary encoded Protobuf from a stream into this message instance. This
1259 returns the instance itself and is therefore assignable and chainable.
1260
1261 Parameters
1262 -----------
1263 stream: :class:`bytes`
1264 The stream to load the message from.
1265 size: :class:`Optional[int]`
1266 The size of the message in the stream.
1267 Reads stream until EOF if ``None`` is given.
1268 Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.
1269
1270 Returns
1271 --------
1272 :class:`Message`
1273 The initialized message.
1274 """
1275 # If the message is delimited, parse the message delimiter
1276 if size == SIZE_DELIMITED:
1277 size, _ = load_varint(stream)
1278
1279 # Got some data over the wire
1280 self._serialized_on_wire = True
1281 proto_meta = self._betterproto
1282 read = 0
1283 for parsed in load_fields(stream):
1284 field_name = proto_meta.field_name_by_number.get(parsed.number)
1285 if not field_name:
1286 self._unknown_fields += parsed.raw
1287 continue
1288
1289 meta = proto_meta.meta_by_field_name[field_name]
1290
1291 value: Any
1292 if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
1293 # This is a packed repeated field.
1294 pos = 0
1295 value = []
1296 while pos < len(parsed.value):
1297 if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32):
1298 decoded, pos = parsed.value[pos : pos + 4], pos + 4
1299 wire_type = WIRE_FIXED_32
1300 elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64):
1301 decoded, pos = parsed.value[pos : pos + 8], pos + 8
1302 wire_type = WIRE_FIXED_64
1303 else:
1304 decoded, pos = decode_varint(parsed.value, pos)
1305 wire_type = WIRE_VARINT
1306 decoded = self._postprocess_single(
1307 wire_type, meta, field_name, decoded
1308 )
1309 value.append(decoded)
1310 else:
1311 value = self._postprocess_single(
1312 parsed.wire_type, meta, field_name, parsed.value
1313 )
1314
1315 try:
1316 current = getattr(self, field_name)
1317 except AttributeError:
1318 current = self._get_field_default(field_name)
1319 setattr(self, field_name, current)
1320
1321 if meta.proto_type == TYPE_MAP:
1322 # Value represents a single key/value pair entry in the map.
1323 current[value.key] = value.value
1324 elif isinstance(current, list) and not isinstance(value, list):
1325 current.append(value)
1326 else:
1327 setattr(self, field_name, value)
1328
1329 # If we have now loaded the expected length of the message, stop
1330 if size is not None:
1331 prev = read
1332 read += len(parsed.raw)
1333 if read == size:
1334 break
1335 elif read > size:
1336 raise ValueError(
1337 f"Expected message of size {size}, can only read "
1338 f"either {prev} or {read} bytes - there is no "
1339 "message of the expected size in the stream."
1340 )
1341
1342 if size is not None and read < size:
1343 raise ValueError(
1344 f"Expected message of size {size}, but was only able to "
1345 f"read {read} bytes - the stream may have ended too soon,"
1346 " or the expected size may have been incorrect."
1347 )
1348
1349 return self
1350
1351 def parse(self: T, data: bytes) -> T:
1352 """
1353 Parse the binary encoded Protobuf into this message instance. This
1354 returns the instance itself and is therefore assignable and chainable.
1355
1356 Parameters
1357 -----------
1358 data: :class:`bytes`
1359 The data to parse the message from.
1360
1361 Returns
1362 --------
1363 :class:`Message`
1364 The initialized message.
1365 """
1366 with BytesIO(data) as stream:
1367 return self.load(stream)
1368
1369 # For compatibility with other libraries.
1370 @classmethod
1371 def FromString(cls: Type[T], data: bytes) -> T:
1372 """
1373 Parse the binary encoded Protobuf into this message instance. This
1374 returns the instance itself and is therefore assignable and chainable.
1375
1376 .. note::
1377 This is a method for compatibility with other libraries,
1378 you should really use :meth:`parse`.
1379
1380
1381 Parameters
1382 -----------
1383 data: :class:`bytes`
1384 The data to parse the protobuf from.
1385
1386 Returns
1387 --------
1388 :class:`Message`
1389 The initialized message.
1390 """
1391 return cls().parse(data)
1392
1393 def to_dict(
1394 self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
1395 ) -> Dict[str, Any]:
1396 """
1397 Returns a JSON serializable dict representation of this object.
1398
1399 Parameters
1400 -----------
1401 casing: :class:`Casing`
1402 The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1403 compatibility purposes.
1404 include_default_values: :class:`bool`
1405 If ``True`` will include the default values of fields. Default is ``False``.
1406 E.g. an ``int32`` field will be included with a value of ``0`` if this is
1407 set to ``True``, otherwise this would be ignored.
1408
1409 Returns
1410 --------
1411 Dict[:class:`str`, Any]
1412 The JSON serializable dict representation of this object.
1413 """
1414 output: Dict[str, Any] = {}
1415 field_types = self._type_hints()
1416 defaults = self._betterproto.default_gen
1417 for field_name, meta in self._betterproto.meta_by_field_name.items():
1418 field_is_repeated = defaults[field_name] is list
1419 try:
1420 value = getattr(self, field_name)
1421 except AttributeError:
1422 value = self._get_field_default(field_name)
1423 cased_name = casing(field_name).rstrip("_") # type: ignore
1424 if meta.proto_type == TYPE_MESSAGE:
1425 if isinstance(value, datetime):
1426 if (
1427 value != DATETIME_ZERO
1428 or include_default_values
1429 or self._include_default_value_for_oneof(
1430 field_name=field_name, meta=meta
1431 )
1432 ):
1433 output[cased_name] = _Timestamp.timestamp_to_json(value)
1434 elif isinstance(value, timedelta):
1435 if (
1436 value != timedelta(0)
1437 or include_default_values
1438 or self._include_default_value_for_oneof(
1439 field_name=field_name, meta=meta
1440 )
1441 ):
1442 output[cased_name] = _Duration.delta_to_json(value)
1443 elif meta.wraps:
1444 if value is not None or include_default_values:
1445 output[cased_name] = value
1446 elif field_is_repeated:
1447 # Convert each item.
1448 cls = self._betterproto.cls_by_field[field_name]
1449 if cls == datetime:
1450 value = [_Timestamp.timestamp_to_json(i) for i in value]
1451 elif cls == timedelta:
1452 value = [_Duration.delta_to_json(i) for i in value]
1453 else:
1454 value = [
1455 i.to_dict(casing, include_default_values) for i in value
1456 ]
1457 if value or include_default_values:
1458 output[cased_name] = value
1459 elif value is None:
1460 if include_default_values:
1461 output[cased_name] = value
1462 elif (
1463 value._serialized_on_wire
1464 or include_default_values
1465 or self._include_default_value_for_oneof(
1466 field_name=field_name, meta=meta
1467 )
1468 ):
1469 output[cased_name] = value.to_dict(casing, include_default_values)
1470 elif meta.proto_type == TYPE_MAP:
1471 output_map = {**value}
1472 for k in value:
1473 if hasattr(value[k], "to_dict"):
1474 output_map[k] = value[k].to_dict(casing, include_default_values)
1475
1476 if value or include_default_values:
1477 output[cased_name] = output_map
1478 elif (
1479 value != self._get_field_default(field_name)
1480 or include_default_values
1481 or self._include_default_value_for_oneof(
1482 field_name=field_name, meta=meta
1483 )
1484 ):
1485 if meta.proto_type in INT_64_TYPES:
1486 if field_is_repeated:
1487 output[cased_name] = [str(n) for n in value]
1488 elif value is None:
1489 if include_default_values:
1490 output[cased_name] = value
1491 else:
1492 output[cased_name] = str(value)
1493 elif meta.proto_type == TYPE_BYTES:
1494 if field_is_repeated:
1495 output[cased_name] = [
1496 b64encode(b).decode("utf8") for b in value
1497 ]
1498 elif value is None and include_default_values:
1499 output[cased_name] = value
1500 else:
1501 output[cased_name] = b64encode(value).decode("utf8")
1502 elif meta.proto_type == TYPE_ENUM:
1503 if field_is_repeated:
1504 enum_class = field_types[field_name].__args__[0]
1505 if isinstance(value, typing.Iterable) and not isinstance(
1506 value, str
1507 ):
1508 output[cased_name] = [enum_class(el).name for el in value]
1509 else:
1510 # transparently upgrade single value to repeated
1511 output[cased_name] = [enum_class(value).name]
1512 elif value is None:
1513 if include_default_values:
1514 output[cased_name] = value
1515 elif meta.optional:
1516 enum_class = field_types[field_name].__args__[0]
1517 output[cased_name] = enum_class(value).name
1518 else:
1519 enum_class = field_types[field_name] # noqa
1520 output[cased_name] = enum_class(value).name
1521 elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1522 if field_is_repeated:
1523 output[cased_name] = [_dump_float(n) for n in value]
1524 else:
1525 output[cased_name] = _dump_float(value)
1526 else:
1527 output[cased_name] = value
1528 return output
1529
1530 @classmethod
1531 def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
1532 init_kwargs: Dict[str, Any] = {}
1533 for key, value in mapping.items():
1534 field_name = safe_snake_case(key)
1535 try:
1536 meta = cls._betterproto.meta_by_field_name[field_name]
1537 except KeyError:
1538 continue
1539 if value is None:
1540 continue
1541
1542 if meta.proto_type == TYPE_MESSAGE:
1543 sub_cls = cls._betterproto.cls_by_field[field_name]
1544 if sub_cls == datetime:
1545 value = (
1546 [isoparse(item) for item in value]
1547 if isinstance(value, list)
1548 else isoparse(value)
1549 )
1550 elif sub_cls == timedelta:
1551 value = (
1552 [timedelta(seconds=float(item[:-1])) for item in value]
1553 if isinstance(value, list)
1554 else timedelta(seconds=float(value[:-1]))
1555 )
1556 elif not meta.wraps:
1557 value = (
1558 [sub_cls.from_dict(item) for item in value]
1559 if isinstance(value, list)
1560 else sub_cls.from_dict(value)
1561 )
1562 elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1563 sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
1564 value = {k: sub_cls.from_dict(v) for k, v in value.items()}
1565 else:
1566 if meta.proto_type in INT_64_TYPES:
1567 value = (
1568 [int(n) for n in value]
1569 if isinstance(value, list)
1570 else int(value)
1571 )
1572 elif meta.proto_type == TYPE_BYTES:
1573 value = (
1574 [b64decode(n) for n in value]
1575 if isinstance(value, list)
1576 else b64decode(value)
1577 )
1578 elif meta.proto_type == TYPE_ENUM:
1579 enum_cls = cls._betterproto.cls_by_field[field_name]
1580 if isinstance(value, list):
1581 value = [enum_cls.from_string(e) for e in value]
1582 elif isinstance(value, str):
1583 value = enum_cls.from_string(value)
1584 elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1585 value = (
1586 [_parse_float(n) for n in value]
1587 if isinstance(value, list)
1588 else _parse_float(value)
1589 )
1590
1591 init_kwargs[field_name] = value
1592 return init_kwargs
1593
1594 @hybridmethod
1595 def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
1596 """
1597 Parse the key/value pairs into the a new message instance.
1598
1599 Parameters
1600 -----------
1601 value: Dict[:class:`str`, Any]
1602 The dictionary to parse from.
1603
1604 Returns
1605 --------
1606 :class:`Message`
1607 The initialized message.
1608 """
1609 self = cls(**cls._from_dict_init(value))
1610 self._serialized_on_wire = True
1611 return self
1612
1613 @from_dict.instancemethod
1614 def from_dict(self, value: Mapping[str, Any]) -> Self:
1615 """
1616 Parse the key/value pairs into the current message instance. This returns the
1617 instance itself and is therefore assignable and chainable.
1618
1619 Parameters
1620 -----------
1621 value: Dict[:class:`str`, Any]
1622 The dictionary to parse from.
1623
1624 Returns
1625 --------
1626 :class:`Message`
1627 The initialized message.
1628 """
1629 self._serialized_on_wire = True
1630 for field, value in self._from_dict_init(value).items():
1631 setattr(self, field, value)
1632 return self
1633
1634 def to_json(
1635 self,
1636 indent: Union[None, int, str] = None,
1637 include_default_values: bool = False,
1638 casing: Casing = Casing.CAMEL,
1639 ) -> str:
1640 """A helper function to parse the message instance into its JSON
1641 representation.
1642
1643 This is equivalent to::
1644
1645 json.dumps(message.to_dict(), indent=indent)
1646
1647 Parameters
1648 -----------
1649 indent: Optional[Union[:class:`int`, :class:`str`]]
1650 The indent to pass to :func:`json.dumps`.
1651
1652 include_default_values: :class:`bool`
1653 If ``True`` will include the default values of fields. Default is ``False``.
1654 E.g. an ``int32`` field will be included with a value of ``0`` if this is
1655 set to ``True``, otherwise this would be ignored.
1656
1657 casing: :class:`Casing`
1658 The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1659 compatibility purposes.
1660
1661 Returns
1662 --------
1663 :class:`str`
1664 The JSON representation of the message.
1665 """
1666 return json.dumps(
1667 self.to_dict(include_default_values=include_default_values, casing=casing),
1668 indent=indent,
1669 )
1670
1671 def from_json(self: T, value: Union[str, bytes]) -> T:
1672 """A helper function to return the message instance from its JSON
1673 representation. This returns the instance itself and is therefore assignable
1674 and chainable.
1675
1676 This is equivalent to::
1677
1678 return message.from_dict(json.loads(value))
1679
1680 Parameters
1681 -----------
1682 value: Union[:class:`str`, :class:`bytes`]
1683 The value to pass to :func:`json.loads`.
1684
1685 Returns
1686 --------
1687 :class:`Message`
1688 The initialized message.
1689 """
1690 return self.from_dict(json.loads(value))
1691
1692 def to_pydict(
1693 self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
1694 ) -> Dict[str, Any]:
1695 """
1696 Returns a python dict representation of this object.
1697
1698 Parameters
1699 -----------
1700 casing: :class:`Casing`
1701 The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1702 compatibility purposes.
1703 include_default_values: :class:`bool`
1704 If ``True`` will include the default values of fields. Default is ``False``.
1705 E.g. an ``int32`` field will be included with a value of ``0`` if this is
1706 set to ``True``, otherwise this would be ignored.
1707
1708 Returns
1709 --------
1710 Dict[:class:`str`, Any]
1711 The python dict representation of this object.
1712 """
1713 output: Dict[str, Any] = {}
1714 defaults = self._betterproto.default_gen
1715 for field_name, meta in self._betterproto.meta_by_field_name.items():
1716 field_is_repeated = defaults[field_name] is list
1717 value = getattr(self, field_name)
1718 cased_name = casing(field_name).rstrip("_") # type: ignore
1719 if meta.proto_type == TYPE_MESSAGE:
1720 if isinstance(value, datetime):
1721 if (
1722 value != DATETIME_ZERO
1723 or include_default_values
1724 or self._include_default_value_for_oneof(
1725 field_name=field_name, meta=meta
1726 )
1727 ):
1728 output[cased_name] = value
1729 elif isinstance(value, timedelta):
1730 if (
1731 value != timedelta(0)
1732 or include_default_values
1733 or self._include_default_value_for_oneof(
1734 field_name=field_name, meta=meta
1735 )
1736 ):
1737 output[cased_name] = value
1738 elif meta.wraps:
1739 if value is not None or include_default_values:
1740 output[cased_name] = value
1741 elif field_is_repeated:
1742 # Convert each item.
1743 value = [i.to_pydict(casing, include_default_values) for i in value]
1744 if value or include_default_values:
1745 output[cased_name] = value
1746 elif value is None:
1747 if include_default_values:
1748 output[cased_name] = None
1749 elif (
1750 value._serialized_on_wire
1751 or include_default_values
1752 or self._include_default_value_for_oneof(
1753 field_name=field_name, meta=meta
1754 )
1755 ):
1756 output[cased_name] = value.to_pydict(casing, include_default_values)
1757 elif meta.proto_type == TYPE_MAP:
1758 for k in value:
1759 if hasattr(value[k], "to_pydict"):
1760 value[k] = value[k].to_pydict(casing, include_default_values)
1761
1762 if value or include_default_values:
1763 output[cased_name] = value
1764 elif (
1765 value != self._get_field_default(field_name)
1766 or include_default_values
1767 or self._include_default_value_for_oneof(
1768 field_name=field_name, meta=meta
1769 )
1770 ):
1771 output[cased_name] = value
1772 return output
1773
1774 def from_pydict(self: T, value: Mapping[str, Any]) -> T:
1775 """
1776 Parse the key/value pairs into the current message instance. This returns the
1777 instance itself and is therefore assignable and chainable.
1778
1779 Parameters
1780 -----------
1781 value: Dict[:class:`str`, Any]
1782 The dictionary to parse from.
1783
1784 Returns
1785 --------
1786 :class:`Message`
1787 The initialized message.
1788 """
1789 self._serialized_on_wire = True
1790 for key in value:
1791 field_name = safe_snake_case(key)
1792 meta = self._betterproto.meta_by_field_name.get(field_name)
1793 if not meta:
1794 continue
1795
1796 if value[key] is not None:
1797 if meta.proto_type == TYPE_MESSAGE:
1798 v = getattr(self, field_name)
1799 if isinstance(v, list):
1800 cls = self._betterproto.cls_by_field[field_name]
1801 for item in value[key]:
1802 v.append(cls().from_pydict(item))
1803 elif isinstance(v, datetime):
1804 v = value[key]
1805 elif isinstance(v, timedelta):
1806 v = value[key]
1807 elif meta.wraps:
1808 v = value[key]
1809 else:
1810 # NOTE: `from_pydict` mutates the underlying message, so no
1811 # assignment here is necessary.
1812 v.from_pydict(value[key])
1813 elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1814 v = getattr(self, field_name)
1815 cls = self._betterproto.cls_by_field[f"{field_name}.value"]
1816 for k in value[key]:
1817 v[k] = cls().from_pydict(value[key][k])
1818 else:
1819 v = value[key]
1820
1821 if v is not None:
1822 setattr(self, field_name, v)
1823 return self
1824
1825 def is_set(self, name: str) -> bool:
1826 """
1827 Check if field with the given name has been set.
1828
1829 Parameters
1830 -----------
1831 name: :class:`str`
1832 The name of the field to check for.
1833
1834 Returns
1835 --------
1836 :class:`bool`
1837 `True` if field has been set, otherwise `False`.
1838 """
1839 default = (
1840 PLACEHOLDER
1841 if not self._betterproto.meta_by_field_name[name].optional
1842 else None
1843 )
1844 return self.__raw_get(name) is not default
1845
1846 @classmethod
1847 def _validate_field_groups(cls, values):
1848 group_to_one_ofs = cls._betterproto.oneof_field_by_group
1849 field_name_to_meta = cls._betterproto.meta_by_field_name
1850
1851 for group, field_set in group_to_one_ofs.items():
1852 if len(field_set) == 1:
1853 (field,) = field_set
1854 field_name = field.name
1855 meta = field_name_to_meta[field_name]
1856
1857 # This is a synthetic oneof; we should ignore it's presence and not consider it as a oneof.
1858 if meta.optional:
1859 continue
1860
1861 set_fields = [
1862 field.name
1863 for field in field_set
1864 if getattr(values, field.name, None) is not None
1865 ]
1866
1867 if not set_fields:
1868 raise ValueError(f"Group {group} has no value; all fields are None")
1869 elif len(set_fields) > 1:
1870 set_fields_str = ", ".join(set_fields)
1871 raise ValueError(
1872 f"Group {group} has more than one value; fields {set_fields_str} are not None"
1873 )
1874
1875 return values
1876
1877
# HACK to avoid typing.get_type_hints breaking :)
Message.__annotations__ = {}

# Monkey patch the (de-)serialization functions of class `Message` with the
# implementations from `betterproto-rust-codec` when that package is installed;
# without it the pure Python implementations stay in place.
try:
    import betterproto_rust_codec
except ModuleNotFoundError:
    pass
else:

    def __parse_patch(self: T, data: bytes) -> T:
        betterproto_rust_codec.deserialize(self, data)
        return self

    def __bytes_patch(self) -> bytes:
        return betterproto_rust_codec.serialize(self)

    Message.parse = __parse_patch
    Message.__bytes__ = __bytes_patch
1896
1897
def serialized_on_wire(message: Message) -> bool:
    """
    Report whether this message was, or should be, serialized on the wire. This
    can be used to detect presence (e.g. optional wrapper message) and is used
    internally during parsing/serialization.

    Returns
    --------
    :class:`bool`
        Whether this message was or should be serialized on the wire.
    """
    on_wire = message._serialized_on_wire
    return on_wire
1910
1911
def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
    """
    Look up which field of a one-of group is currently selected.

    Returns
    --------
    Tuple[:class:`str`, Any]
        The field name and the value for that field, or ``("", None)`` when no
        field of the group is selected.
    """
    selected = message._group_current.get(group_name)
    if selected:
        return selected, getattr(message, selected)
    return "", None
1925
1926
1927# Circular import workaround: google.protobuf depends on base classes defined above.
1928from .lib.google.protobuf import ( # noqa
1929 BoolValue,
1930 BytesValue,
1931 DoubleValue,
1932 Duration,
1933 EnumValue,
1934 FloatValue,
1935 Int32Value,
1936 Int64Value,
1937 StringValue,
1938 Timestamp,
1939 UInt32Value,
1940 UInt64Value,
1941)
1942
1943
class _Duration(Duration):
    """``Duration`` message with conversions to and from :class:`datetime.timedelta`."""

    @classmethod
    def from_timedelta(
        cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1)
    ) -> "_Duration":
        """
        Build a duration message from a ``timedelta``.

        The conversion uses integer arithmetic only, so it is exact for every
        ``timedelta``. For negative deltas both ``seconds`` and ``nanos`` are
        negative (or zero), as the Duration well-known type requires; mixing
        truncated seconds with a floored remainder would encode ``-1.5s`` as
        ``-0.5s``.
        """
        total_us = delta // _1_microsecond
        sign = -1 if total_us < 0 else 1
        # Split the magnitude, then give both parts the sign of the delta.
        seconds, micros = divmod(abs(total_us), 10**6)
        return cls(sign * seconds, sign * micros * 1000)

    def to_timedelta(self) -> timedelta:
        """Convert this message to a ``timedelta`` (microsecond resolution)."""
        return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)

    @staticmethod
    def delta_to_json(delta: timedelta) -> str:
        """
        Render a ``timedelta`` as a JSON duration string such as ``"1.500s"``.

        The fractional part has 3 digits when the value is a whole number of
        milliseconds and 6 digits otherwise. The digits are produced from the
        exact microsecond count, so tiny durations never come out in exponent
        notation (``str(1e-05)`` is ``"1e-05"``).
        """
        total_us = delta // timedelta(microseconds=1)
        sign = "-" if total_us < 0 else ""
        seconds, micros = divmod(abs(total_us), 10**6)
        if micros % 1000 == 0:
            fraction = f"{micros // 1000:03d}"
        else:
            fraction = f"{micros:06d}"
        return f"{sign}{seconds}.{fraction}s"
1964
1965
class _Timestamp(Timestamp):
    """``Timestamp`` message with conversions to and from :class:`datetime.datetime`."""

    @classmethod
    def from_datetime(cls, dt: datetime) -> "_Timestamp":
        """Build a timestamp message from the offset of ``dt`` to ``DATETIME_ZERO``."""
        # manual epoch offset calculation to avoid rounding errors,
        # to support negative timestamps (before 1970) and skirt
        # around datetime bugs (apparently 0 isn't a year in [0, 9999]??)
        offset = dt - DATETIME_ZERO
        # below is the same as timedelta.total_seconds() but without dividing by 1e6
        # so we end up with microseconds as integers instead of seconds as float
        offset_us = (
            offset.days * 24 * 60 * 60 + offset.seconds
        ) * 10**6 + offset.microseconds
        seconds, us = divmod(offset_us, 10**6)
        return cls(seconds, us * 1000)

    def to_datetime(self) -> datetime:
        """Convert this message to a ``datetime``, dropping sub-microsecond nanos."""
        # datetime.fromtimestamp() expects a timestamp in seconds, not microseconds
        # if we pass it as a floating point number, we will run into rounding errors
        # see also #407
        offset = timedelta(seconds=self.seconds, microseconds=self.nanos // 1000)
        return DATETIME_ZERO + offset

    @staticmethod
    def timestamp_to_json(dt: datetime) -> str:
        """
        Render ``dt`` as an ISO 8601 string ending in ``Z``.

        Timezone-aware values are converted to UTC first. The fraction is omitted
        for whole seconds, has 3 digits for whole milliseconds and 6 digits
        otherwise. ``datetime`` has microsecond resolution, so a 9-digit form
        can never be needed.
        """
        micros = dt.microsecond
        if dt.tzinfo is not None:
            # change timezone aware datetime objects to utc
            dt = dt.astimezone(timezone.utc)
        copy = dt.replace(microsecond=0, tzinfo=None)
        result = copy.isoformat()
        if micros == 0:
            # If there are 0 fractional digits, the fractional
            # point '.' should be omitted when serializing.
            return f"{result}Z"
        if micros % 1000 == 0:
            # Serialize 3 fractional digits.
            return f"{result}.{micros // 1000:03d}Z"
        # Serialize 6 fractional digits.
        return f"{result}.{micros:06d}Z"
2008
2009
def _get_wrapper(proto_type: str) -> Type:
    """Return the wrapper message class registered for the given proto type."""

    # TODO: include ListValue and NullValue?
    wrapper_by_type = dict(
        (
            (TYPE_BOOL, BoolValue),
            (TYPE_BYTES, BytesValue),
            (TYPE_DOUBLE, DoubleValue),
            (TYPE_FLOAT, FloatValue),
            (TYPE_ENUM, EnumValue),
            (TYPE_INT32, Int32Value),
            (TYPE_INT64, Int64Value),
            (TYPE_STRING, StringValue),
            (TYPE_UINT32, UInt32Value),
            (TYPE_UINT64, UInt64Value),
        )
    )
    return wrapper_by_type[proto_type]