1import dataclasses
2import enum
3import json
4import math
5import struct
6import sys
7import typing
8import warnings
9from abc import ABC
10from base64 import (
11 b64decode,
12 b64encode,
13)
14from copy import deepcopy
15from datetime import (
16 datetime,
17 timedelta,
18 timezone,
19)
20from typing import (
21 TYPE_CHECKING,
22 Any,
23 Callable,
24 Dict,
25 Generator,
26 Iterable,
27 List,
28 Mapping,
29 Optional,
30 Set,
31 Tuple,
32 Type,
33 Union,
34 get_type_hints,
35)
36
37from dateutil.parser import isoparse
38
39from ._types import T
40from ._version import __version__
41from .casing import (
42 camel_case,
43 safe_snake_case,
44 snake_case,
45)
46from .grpc.grpclib_client import ServiceStub
47
48
# Protobuf (proto3) type names, as stored in FieldMetadata.proto_type.
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
TYPE_INT32 = "int32"
TYPE_INT64 = "int64"
TYPE_UINT32 = "uint32"
TYPE_UINT64 = "uint64"
TYPE_SINT32 = "sint32"
TYPE_SINT64 = "sint64"
TYPE_FLOAT = "float"
TYPE_DOUBLE = "double"
TYPE_FIXED32 = "fixed32"
TYPE_SFIXED32 = "sfixed32"
TYPE_FIXED64 = "fixed64"
TYPE_SFIXED64 = "sfixed64"
TYPE_STRING = "string"
TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message"
TYPE_MAP = "map"

# Wire types (the low three bits of every record key).
# https://developers.google.com/protocol-buffers/docs/encoding#structure
WIRE_VARINT = 0
WIRE_FIXED_64 = 1
WIRE_LEN_DELIM = 2
WIRE_FIXED_32 = 5

# Which proto types are carried by which wire type.
WIRE_VARINT_TYPES = [
    TYPE_ENUM,
    TYPE_BOOL,
    TYPE_INT32,
    TYPE_INT64,
    TYPE_UINT32,
    TYPE_UINT64,
    TYPE_SINT32,
    TYPE_SINT64,
]
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]

# Types whose encoding has a fixed width (4 or 8 bytes).
FIXED_TYPES = [
    TYPE_FLOAT,
    TYPE_DOUBLE,
    TYPE_FIXED32,
    TYPE_SFIXED32,
    TYPE_FIXED64,
    TYPE_SFIXED64,
]

# 64-bit integer types; to_dict renders their values as JSON strings.
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]

# Scalar numeric types that are written in packed form when repeated: every
# varint type followed by every fixed-width type (a new, independent list).
PACKED_TYPES = WIRE_VARINT_TYPES + FIXED_TYPES
123
124
def datetime_default_gen() -> datetime:
    """Return the protobuf zero timestamp: the Unix epoch, 1970-01-01T00:00:00Z (UTC)."""
    return datetime.fromtimestamp(0, tz=timezone.utc)


# The zero timestamp, computed once; to_dict compares timestamp fields against it.
DATETIME_ZERO = datetime_default_gen()
131
132
# Strings the protobuf JSON mapping uses for non-finite floating point values.
INFINITY, NEG_INFINITY, NAN = "Infinity", "-Infinity", "NaN"
137
138
class Casing(enum.Enum):
    """Casing constants for serialization.

    NOTE(review): ``camel_case`` and ``snake_case`` are plain functions, and
    :class:`enum.Enum` treats function values in a class body as methods, not
    members. ``Casing.CAMEL`` / ``Casing.SNAKE`` therefore evaluate to the
    functions themselves. ``Message.to_dict`` relies on this by calling
    ``casing(field_name)`` directly, so do not wrap these values.
    """

    CAMEL = camel_case  #: A camelCase serialization function.
    SNAKE = snake_case  #: A snake_case serialization function.
144
145
# Sentinel stored in a message field that has not been assigned yet. It is the
# dataclass default for every non-optional field (see dataclass_field), and
# Message.__getattribute__ swaps it for the field's real default on first read.
PLACEHOLDER: Any = object()
147
148
@dataclasses.dataclass(frozen=True)
class FieldMetadata:
    """Immutable protobuf metadata for one message field.

    Instances live under the ``"betterproto"`` key of a dataclass field's
    ``metadata`` mapping and drive parsing and serialization.
    """

    number: int  # protobuf field number
    proto_type: str  # protobuf type name, e.g. "int32" or "message"
    map_types: Optional[Tuple[str, str]] = None  # (key type, value type) for map fields
    group: Optional[str] = None  # name of the one-of group this field belongs to, if any
    wraps: Optional[str] = None  # wrapped type, e.g. for google.protobuf.BoolValue
    optional: Optional[bool] = False  # whether the field is optional

    @staticmethod
    def get(field: dataclasses.Field) -> "FieldMetadata":
        """Fetch the metadata attached to the dataclass ``field``."""
        attached = field.metadata
        return attached["betterproto"]
170
171
def dataclass_field(
    number: int,
    proto_type: str,
    *,
    map_types: Optional[Tuple[str, str]] = None,
    group: Optional[str] = None,
    wraps: Optional[str] = None,
    optional: bool = False,
) -> dataclasses.Field:
    """Build a dataclass field that carries protobuf metadata.

    Optional fields default to ``None``. Every other field defaults to the
    ``PLACEHOLDER`` sentinel, so its real default can be produced lazily on
    first access.
    """
    meta = FieldMetadata(number, proto_type, map_types, group, wraps, optional)
    initial = None if optional else PLACEHOLDER
    return dataclasses.field(default=initial, metadata={"betterproto": meta})
190
191
# Note: the helpers below return `Any` to prevent type errors in the generated
# data classes. The types won't match `Field`, and the values get swapped out
# at runtime; the generated dataclass variables are still typed correctly.


def enum_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Declare a protobuf ``enum`` field."""
    return dataclass_field(number, proto_type=TYPE_ENUM, group=group, optional=optional)


def bool_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any:
    """Declare a protobuf ``bool`` field."""
    return dataclass_field(number, proto_type=TYPE_BOOL, group=group, optional=optional)


def int32_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``int32`` field."""
    return dataclass_field(number, proto_type=TYPE_INT32, group=group, optional=optional)


def int64_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``int64`` field."""
    return dataclass_field(number, proto_type=TYPE_INT64, group=group, optional=optional)


def uint32_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``uint32`` field."""
    return dataclass_field(number, proto_type=TYPE_UINT32, group=group, optional=optional)


def uint64_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``uint64`` field."""
    return dataclass_field(number, proto_type=TYPE_UINT64, group=group, optional=optional)


def sint32_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``sint32`` (zig-zag encoded) field."""
    return dataclass_field(number, proto_type=TYPE_SINT32, group=group, optional=optional)


def sint64_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``sint64`` (zig-zag encoded) field."""
    return dataclass_field(number, proto_type=TYPE_SINT64, group=group, optional=optional)


def float_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``float`` field."""
    return dataclass_field(number, proto_type=TYPE_FLOAT, group=group, optional=optional)


def double_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``double`` field."""
    return dataclass_field(number, proto_type=TYPE_DOUBLE, group=group, optional=optional)


def fixed32_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``fixed32`` field."""
    return dataclass_field(number, proto_type=TYPE_FIXED32, group=group, optional=optional)


def fixed64_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``fixed64`` field."""
    return dataclass_field(number, proto_type=TYPE_FIXED64, group=group, optional=optional)


def sfixed32_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``sfixed32`` field."""
    return dataclass_field(
        number, proto_type=TYPE_SFIXED32, group=group, optional=optional
    )


def sfixed64_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``sfixed64`` field."""
    return dataclass_field(
        number, proto_type=TYPE_SFIXED64, group=group, optional=optional
    )


def string_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``string`` field."""
    return dataclass_field(number, proto_type=TYPE_STRING, group=group, optional=optional)


def bytes_field(
    number: int, group: Optional[str] = None, optional: bool = False
) -> Any:
    """Declare a protobuf ``bytes`` field."""
    return dataclass_field(number, proto_type=TYPE_BYTES, group=group, optional=optional)


def message_field(
    number: int,
    group: Optional[str] = None,
    wraps: Optional[str] = None,
    optional: bool = False,
) -> Any:
    """Declare a nested message field; ``wraps`` names a wrapped scalar type."""
    return dataclass_field(
        number,
        proto_type=TYPE_MESSAGE,
        group=group,
        wraps=wraps,
        optional=optional,
    )


def map_field(
    number: int, key_type: str, value_type: str, group: Optional[str] = None
) -> Any:
    """Declare a ``map<key_type, value_type>`` field."""
    return dataclass_field(
        number,
        proto_type=TYPE_MAP,
        map_types=(key_type, value_type),
        group=group,
    )
306
307
class Enum(enum.IntEnum):
    """
    Base class that every generated protobuf enumeration inherits from.
    Bases :class:`enum.IntEnum`.
    """

    @classmethod
    def from_string(cls, name: str) -> "Enum":
        """Look up the member declared under ``name``.

        Parameters
        -----------
        name: :class:`str`
            The name of the enum member to get

        Raises
        -------
        :exc:`ValueError`
            No member with that name exists in the Enum.
        """
        try:
            member = cls._member_map_[name]  # type: ignore
        except KeyError as missing:
            raise ValueError(
                f"Unknown value {name} for enum {cls.__name__}"
            ) from missing
        return member
332
333
# Little-endian struct formats for the fixed-width protobuf types. Built once
# at import time instead of on every call, because _pack_fmt runs for every
# fixed-width value that is parsed or serialized.
_PACK_FMTS: Dict[str, str] = {
    TYPE_DOUBLE: "<d",
    TYPE_FLOAT: "<f",
    TYPE_FIXED32: "<I",
    TYPE_FIXED64: "<Q",
    TYPE_SFIXED32: "<i",
    TYPE_SFIXED64: "<q",
}


def _pack_fmt(proto_type: str) -> str:
    """Returns a little-endian format string for reading/writing binary.

    Raises :exc:`KeyError` if ``proto_type`` is not a fixed-width type.
    """
    return _PACK_FMTS[proto_type]
344
345
def encode_varint(value: int) -> bytes:
    """Encodes a single varint value for serialization.

    Non-negative values use standard base-128 little-endian groups. Negative
    values are first mapped to their 64-bit two's complement (value + 2**64),
    as protobuf does for negative int32/int64/enum values. Values in the int64
    range therefore encode to 10 bytes.

    Raises
    ------
    ValueError
        If ``value`` is below ``-2**64``. Such a number has no 64-bit two's
        complement form and previously made the encoder loop forever.
    """
    if value < 0:
        value += 1 << 64
        if value < 0:
            # Right-shifting a negative int never reaches 0, so the loop below
            # would never end; reject instead.
            raise ValueError(
                f"Value {value - (1 << 64)} is out of range for a varint"
            )

    out = bytearray()
    while value > 0x7F:
        # Emit the low 7 bits with the continuation bit set.
        out.append(0x80 | (value & 0x7F))
        value >>= 7
    out.append(value)
    return bytes(out)
360
361
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
    """Convert one Python value into the raw payload bytes for ``proto_type``.

    - varint types: varint-encoded (sint32/sint64 after zig-zag mapping);
    - fixed-width types: packed little-endian via ``struct``;
    - strings: UTF-8 encoded;
    - messages: ``datetime`` / ``timedelta`` are converted to Timestamp /
      Duration messages, wrapper values are boxed with their wrapper type
      (``None`` becomes ``b""``), and the result is serialized with ``bytes``;
    - anything else (e.g. ``bytes`` or a prebuilt map entry payload) is
      returned unchanged.
    """
    if proto_type in (
        TYPE_ENUM,
        TYPE_BOOL,
        TYPE_INT32,
        TYPE_INT64,
        TYPE_UINT32,
        TYPE_UINT64,
    ):
        return encode_varint(value)

    if proto_type in (TYPE_SINT32, TYPE_SINT64):
        # Zig-zag: 0, -1, 1, -2, ... map to 0, 1, 2, 3, ...
        if value >= 0:
            return encode_varint(value << 1)
        return encode_varint((value << 1) ^ (~0))

    if proto_type in FIXED_TYPES:
        return struct.pack(_pack_fmt(proto_type), value)

    if proto_type == TYPE_STRING:
        return value.encode("utf-8")

    if proto_type != TYPE_MESSAGE:
        return value

    if isinstance(value, datetime):
        # A `datetime` travels as a Timestamp message.
        value = _Timestamp.from_datetime(value)
    elif isinstance(value, timedelta):
        # A `timedelta` travels as a Duration message.
        value = _Duration.from_timedelta(value)
    elif wraps:
        if value is None:
            return b""
        value = _get_wrapper(wraps)(value=value)
    return bytes(value)
395
396
def _serialize_single(
    field_number: int,
    proto_type: str,
    value: Any,
    *,
    serialize_empty: bool = False,
    wraps: str = "",
) -> bytes:
    """Encode one field as a complete wire record: key, then payload.

    A length-delimited value with an empty payload is dropped (``b""`` is
    returned) unless ``serialize_empty`` is true or the field is a wrapper
    type. Raises :exc:`NotImplementedError` for a ``proto_type`` with no
    known wire type.
    """
    payload = _preprocess_single(proto_type, wraps, value)

    if proto_type in WIRE_VARINT_TYPES:
        wire_type = WIRE_VARINT
    elif proto_type in WIRE_FIXED_32_TYPES:
        wire_type = WIRE_FIXED_32
    elif proto_type in WIRE_FIXED_64_TYPES:
        wire_type = WIRE_FIXED_64
    elif proto_type in WIRE_LEN_DELIM_TYPES:
        if not (len(payload) or serialize_empty or wraps):
            return b""
        wire_type = WIRE_LEN_DELIM
        # Length-delimited payloads are prefixed with their byte length.
        payload = encode_varint(len(payload)) + payload
    else:
        raise NotImplementedError(proto_type)

    key = encode_varint((field_number << 3) | wire_type)
    return bytes(key + payload)
426
427
def _parse_float(value: Any) -> float:
    """Parse the given value to a float

    Parameters
    ----------
    value: Any
        Value to parse. The protobuf JSON spellings ``"Infinity"``,
        ``"-Infinity"`` and ``"NaN"`` are recognised explicitly; anything
        else is passed to :class:`float`.

    Returns
    -------
    float
        Parsed value
    """
    if value == INFINITY:
        spelling = "inf"
    elif value == NEG_INFINITY:
        spelling = "-inf"
    elif value == NAN:
        spelling = "nan"
    else:
        return float(value)
    return float(spelling)
448
449
def _dump_float(value: float) -> Union[float, str]:
    """Dump the given float to JSON

    Parameters
    ----------
    value: float
        Value to dump

    Returns
    -------
    Union[float, str]
        The string ``"Infinity"``, ``"-Infinity"`` or ``"NaN"`` for the
        non-finite values (which JSON cannot represent as numbers); otherwise
        ``value`` unchanged.
    """
    positive_inf = float("inf")
    if value == positive_inf:
        return INFINITY
    if value == -positive_inf:
        return NEG_INFINITY
    is_nan = isinstance(value, float) and math.isnan(value)
    return NAN if is_nan else value
470
471
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
    """
    Read one varint from ``buffer`` starting at index ``pos``.

    Returns ``(value, next_pos)``, where ``next_pos`` is the index just past
    the varint. At most 10 bytes are consumed.

    Raises
    ------
    ValueError
        If the varint is longer than 10 bytes.
    IndexError
        If ``buffer`` ends before the varint does.
    """
    value = 0
    for shift in range(0, 64, 7):
        byte = buffer[pos]
        value |= (byte & 0x7F) << shift
        pos += 1
        if not byte & 0x80:
            return value, pos
    raise ValueError("Too many bytes when decoding varint.")
488
489
@dataclasses.dataclass(frozen=True)
class ParsedField:
    """One raw field record, as yielded by ``parse_fields``."""

    number: int  # field number, taken from the record key
    wire_type: int  # wire type, taken from the low 3 bits of the record key
    value: Any  # int for varints, bytes for fixed/length-delimited, else None
    raw: bytes  # the whole record (key included), exactly as it appeared
496
497
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
    """Split serialized protobuf bytes into raw field records.

    Yields one :class:`ParsedField` per record, in the order the records
    appear. Varint payloads are decoded to ``int``. Fixed-width (8- and
    4-byte) and length-delimited payloads are returned as slices of
    ``value``. Records that use the group wire types (3 and 4) are yielded
    with ``value=None`` and no payload consumed, as before.

    Raises
    ------
    ValueError
        If a fixed-width or length-delimited payload extends past the end of
        ``value`` (previously it was silently truncated), or if a record uses
        wire type 6 or 7, which protobuf does not define (previously the
        bytes after the key were misread as new records).
    IndexError
        If ``value`` ends in the middle of a varint (raised by
        ``decode_varint``).
    """
    # Deprecated start-group / end-group wire types; passed through undecoded.
    group_wire_types = (3, 4)
    end = len(value)
    i = 0
    while i < end:
        start = i
        num_wire, i = decode_varint(value, i)
        number = num_wire >> 3
        wire_type = num_wire & 0x7

        decoded: Any = None
        if wire_type == WIRE_VARINT:
            decoded, i = decode_varint(value, i)
        elif wire_type in (WIRE_FIXED_64, WIRE_FIXED_32, WIRE_LEN_DELIM):
            if wire_type == WIRE_LEN_DELIM:
                size, i = decode_varint(value, i)
            elif wire_type == WIRE_FIXED_64:
                size = 8
            else:
                size = 4
            if i + size > end:
                raise ValueError(
                    f"Truncated payload for field {number}: need {size} bytes "
                    f"at offset {i}, only {end - i} available"
                )
            decoded = value[i : i + size]
            i += size
        elif wire_type not in group_wire_types:
            raise ValueError(f"Invalid wire type {wire_type} for field {number}")

        yield ParsedField(
            number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
        )
521
522
class ProtoClassMetadata:
    """Protobuf metadata derived once from a message dataclass.

    Holds the lookups the parser and serializers need: one-of group
    membership, field number/name mappings, default-value factories and the
    Python class used for each field.
    """

    __slots__ = (
        "oneof_group_by_field",
        "oneof_field_by_group",
        "default_gen",
        "cls_by_field",
        "field_name_by_number",
        "meta_by_field_name",
        "sorted_field_names",
    )

    oneof_group_by_field: Dict[str, str]  # field name -> one-of group name
    oneof_field_by_group: Dict[str, Set[dataclasses.Field]]  # group -> member fields
    field_name_by_number: Dict[int, str]  # field number -> field name
    meta_by_field_name: Dict[str, FieldMetadata]  # field name -> metadata (declaration order)
    sorted_field_names: Tuple[str, ...]  # field names ordered by field number
    default_gen: Dict[str, Callable[[], Any]]  # field name -> default-value factory
    cls_by_field: Dict[str, Type]  # field name -> Python class (maps: synthetic Entry)

    def __init__(self, cls: Type["Message"]):
        fields = dataclasses.fields(cls)

        group_of_field: Dict[str, str] = {}
        fields_in_group: Dict[str, Set] = {}
        meta_of_field: Dict[str, FieldMetadata] = {}
        name_of_number: Dict[int, str] = {}

        for field in fields:
            meta = FieldMetadata.get(field)
            if meta.group:
                # Member of a one-of group.
                group_of_field[field.name] = meta.group
                fields_in_group.setdefault(meta.group, set()).add(field)
            meta_of_field[field.name] = meta
            name_of_number[meta.number] = field.name

        self.oneof_group_by_field = group_of_field
        self.oneof_field_by_group = fields_in_group
        self.field_name_by_number = name_of_number
        self.meta_by_field_name = meta_of_field
        # Field numbers are unique keys, so sorting the pairs orders by number.
        self.sorted_field_names = tuple(
            name for _, name in sorted(name_of_number.items())
        )
        self.default_gen = self._get_default_gen(cls, fields)
        self.cls_by_field = self._get_cls_by_field(cls, fields)

    @staticmethod
    def _get_default_gen(
        cls: Type["Message"], fields: Iterable[dataclasses.Field]
    ) -> Dict[str, Callable[[], Any]]:
        """Map every field name to the callable that builds its default value."""
        generators: Dict[str, Callable[[], Any]] = {}
        for field in fields:
            generators[field.name] = cls._get_field_default_gen(field)
        return generators

    @staticmethod
    def _get_cls_by_field(
        cls: Type["Message"], fields: Iterable[dataclasses.Field]
    ) -> Dict[str, Type]:
        """Map every field name to the Python class used to parse its values.

        A map field gets a synthesized ``Entry`` message class (key = field 1,
        value = field 2), plus an extra ``"<name>.value"`` entry holding the
        map's value class.
        """
        classes: Dict[str, Type] = {}

        for field in fields:
            meta = FieldMetadata.get(field)
            if meta.proto_type != TYPE_MAP:
                classes[field.name] = cls._cls_for(field)
                continue

            assert meta.map_types
            key_cls = cls._cls_for(field, index=0)
            value_cls = cls._cls_for(field, index=1)
            classes[field.name] = dataclasses.make_dataclass(
                "Entry",
                [
                    ("key", key_cls, dataclass_field(1, meta.map_types[0])),
                    ("value", value_cls, dataclass_field(2, meta.map_types[1])),
                ],
                bases=(Message,),
            )
            classes[f"{field.name}.value"] = value_cls

        return classes
602
603
604class Message(ABC):
605 """
606 The base class for protobuf messages, all generated messages will inherit from
607 this. This class registers the message fields which are used by the serializers and
608 parsers to go between the Python, binary and JSON representations of the message.
609
610 .. container:: operations
611
612 .. describe:: bytes(x)
613
614 Calls :meth:`__bytes__`.
615
616 .. describe:: bool(x)
617
618 Calls :meth:`__bool__`.
619 """
620
621 _serialized_on_wire: bool
622 _unknown_fields: bytes
623 _group_current: Dict[str, str]
624
625 def __post_init__(self) -> None:
626 # Keep track of whether every field was default
627 all_sentinel = True
628
629 # Set current field of each group after `__init__` has already been run.
630 group_current: Dict[str, Optional[str]] = {}
631 for field_name, meta in self._betterproto.meta_by_field_name.items():
632 if meta.group:
633 group_current.setdefault(meta.group)
634
635 value = self.__raw_get(field_name)
636 if value != PLACEHOLDER and not (meta.optional and value is None):
637 # Found a non-sentinel value
638 all_sentinel = False
639
640 if meta.group:
641 # This was set, so make it the selected value of the one-of.
642 group_current[meta.group] = field_name
643
644 # Now that all the defaults are set, reset it!
645 self.__dict__["_serialized_on_wire"] = not all_sentinel
646 self.__dict__["_unknown_fields"] = b""
647 self.__dict__["_group_current"] = group_current
648
649 def __raw_get(self, name: str) -> Any:
650 return super().__getattribute__(name)
651
652 def __eq__(self, other) -> bool:
653 if type(self) is not type(other):
654 return False
655
656 for field_name in self._betterproto.meta_by_field_name:
657 self_val = self.__raw_get(field_name)
658 other_val = other.__raw_get(field_name)
659 if self_val is PLACEHOLDER:
660 if other_val is PLACEHOLDER:
661 continue
662 self_val = self._get_field_default(field_name)
663 elif other_val is PLACEHOLDER:
664 other_val = other._get_field_default(field_name)
665
666 if self_val != other_val:
667 # We consider two nan values to be the same for the
668 # purposes of comparing messages (otherwise a message
669 # is not equal to itself)
670 if (
671 isinstance(self_val, float)
672 and isinstance(other_val, float)
673 and math.isnan(self_val)
674 and math.isnan(other_val)
675 ):
676 continue
677 else:
678 return False
679
680 return True
681
682 def __repr__(self) -> str:
683 parts = [
684 f"{field_name}={value!r}"
685 for field_name in self._betterproto.sorted_field_names
686 for value in (self.__raw_get(field_name),)
687 if value is not PLACEHOLDER
688 ]
689 return f"{self.__class__.__name__}({', '.join(parts)})"
690
    # Defined only at runtime: the TYPE_CHECKING guard hides these overrides
    # from static type checkers. NOTE(review): presumably so checkers keep
    # using the declared field types rather than `Any` -- confirm.
    if not TYPE_CHECKING:

        def __getattribute__(self, name: str) -> Any:
            """
            Lazily initialize default values to avoid infinite recursion for recursive
            message types

            If the stored value is ``PLACEHOLDER``, the field's default is generated,
            stored through the next ``__setattr__`` in the MRO (bypassing
            ``Message.__setattr__``, so ``_serialized_on_wire`` and the one-of
            selection are left untouched), and returned.
            """
            value = super().__getattribute__(name)
            if value is not PLACEHOLDER:
                return value

            value = self._get_field_default(name)
            super().__setattr__(name, value)
            return value

        def __setattr__(self, attr: str, value: Any) -> None:
            """Set ``attr`` while keeping wire-presence and one-of bookkeeping up to date."""
            # Assigning an instance of a message type that declares no fields
            # marks that instance as serialized on the wire. NOTE(review):
            # presumably so __bytes__ still emits it, since such a message can
            # never hold a non-default value.
            if (
                isinstance(value, Message)
                and hasattr(value, "_betterproto")
                and not value._betterproto.meta_by_field_name
            ):
                value._serialized_on_wire = True

            if attr != "_serialized_on_wire":
                # Track when a field has been set.
                self.__dict__["_serialized_on_wire"] = True

            if hasattr(self, "_group_current"):  # __post_init__ had already run
                if attr in self._betterproto.oneof_group_by_field:
                    # Select this member of its one-of group and reset every
                    # other member of the group back to PLACEHOLDER (unset).
                    group = self._betterproto.oneof_group_by_field[attr]
                    for field in self._betterproto.oneof_field_by_group[group]:
                        if field.name == attr:
                            self._group_current[group] = field.name
                        else:
                            super().__setattr__(field.name, PLACEHOLDER)

            super().__setattr__(attr, value)
728
729 def __bool__(self) -> bool:
730 """True if the Message has any fields with non-default values."""
731 return any(
732 self.__raw_get(field_name)
733 not in (PLACEHOLDER, self._get_field_default(field_name))
734 for field_name in self._betterproto.meta_by_field_name
735 )
736
737 def __deepcopy__(self: T, _: Any = {}) -> T:
738 kwargs = {}
739 for name in self._betterproto.sorted_field_names:
740 value = self.__raw_get(name)
741 if value is not PLACEHOLDER:
742 kwargs[name] = deepcopy(value)
743 return self.__class__(**kwargs) # type: ignore
744
745 @property
746 def _betterproto(self) -> ProtoClassMetadata:
747 """
748 Lazy initialize metadata for each protobuf class.
749 It may be initialized multiple times in a multi-threaded environment,
750 but that won't affect the correctness.
751 """
752 meta = getattr(self.__class__, "_betterproto_meta", None)
753 if not meta:
754 meta = ProtoClassMetadata(self.__class__)
755 self.__class__._betterproto_meta = meta # type: ignore
756 return meta
757
    def __bytes__(self) -> bytes:
        """
        Get the binary encoded Protobuf representation of this message instance.

        Fields are emitted in declaration order (the order of
        ``meta_by_field_name``). ``None`` values are always skipped. Default
        values are skipped unless the field is the selected member of a one-of
        or an empty message marked as serialized on the wire. Unknown fields
        kept from parsing are appended unchanged at the end.

        Note: reading each field with ``getattr`` materializes defaults for
        fields that were never assigned (see ``__getattribute__``).
        """
        output = bytearray()
        for field_name, meta in self._betterproto.meta_by_field_name.items():
            value = getattr(self, field_name)

            if value is None:
                # Optional items should be skipped. This is used for the Google
                # wrapper types and proto3 field presence/optional fields.
                continue

            # Being selected in a group means this field is the one that is
            # currently set in a `oneof` group, so it must be serialized even
            # if the value is the default zero value.
            #
            # Note that proto3 field presence/optional fields are put in a
            # synthetic single-item oneof by protoc, which helps us ensure we
            # send the value even if the value is the default zero value.
            selected_in_group = (
                meta.group and self._group_current[meta.group] == field_name
            )

            # Empty messages can still be sent on the wire if they were
            # set (or received empty).
            serialize_empty = isinstance(value, Message) and value._serialized_on_wire

            # NOTE(review): this computes the same condition as
            # `selected_in_group` above (using .get() instead of indexing).
            include_default_value_for_oneof = self._include_default_value_for_oneof(
                field_name=field_name, meta=meta
            )

            if value == self._get_field_default(field_name) and not (
                selected_in_group or serialize_empty or include_default_value_for_oneof
            ):
                # Default (zero) values are not serialized. Two exceptions are
                # if this is the selected oneof item or if we know we have to
                # serialize an empty message (i.e. zero value was explicitly
                # set by the user).
                continue

            if isinstance(value, list):
                if meta.proto_type in PACKED_TYPES:
                    # Packed lists look like a length-delimited field. First,
                    # preprocess/encode each value into a buffer and then
                    # treat it like a field of raw bytes.
                    buf = bytearray()
                    for item in value:
                        buf += _preprocess_single(meta.proto_type, "", item)
                    output += _serialize_single(meta.number, TYPE_BYTES, buf)
                else:
                    # Non-packed repeated items (strings, bytes, messages) are
                    # written as one record per item.
                    for item in value:
                        output += (
                            _serialize_single(
                                meta.number,
                                meta.proto_type,
                                item,
                                wraps=meta.wraps or "",
                                serialize_empty=True,
                            )
                            # if it's an empty message it still needs to be represented
                            # as an item in the repeated list
                            # NOTE(review): with serialize_empty=True,
                            # _serialize_single always emits at least the key
                            # for length-delimited types, so this fallback looks
                            # unreachable. It also hard-codes field number 1
                            # (tag byte 0x0a). Confirm before relying on it.
                            or b"\n\x00"
                        )

            elif isinstance(value, dict):
                # Each map entry is written as a nested message with the key
                # as field 1 and the value as field 2.
                for k, v in value.items():
                    assert meta.map_types
                    sk = _serialize_single(1, meta.map_types[0], k)
                    sv = _serialize_single(2, meta.map_types[1], v)
                    output += _serialize_single(meta.number, meta.proto_type, sk + sv)
            else:
                # If we have an empty string and we're including the default value for
                # a oneof, make sure we serialize it. This ensures that the byte string
                # output isn't simply an empty string. This also ensures that round trip
                # serialization will keep `which_one_of` calls consistent.
                if (
                    isinstance(value, str)
                    and value == ""
                    and include_default_value_for_oneof
                ):
                    serialize_empty = True

                output += _serialize_single(
                    meta.number,
                    meta.proto_type,
                    value,
                    serialize_empty=serialize_empty or bool(selected_in_group),
                    wraps=meta.wraps or "",
                )

        output += self._unknown_fields
        return bytes(output)
851
852 # For compatibility with other libraries
853 def SerializeToString(self: T) -> bytes:
854 """
855 Get the binary encoded Protobuf representation of this message instance.
856
857 .. note::
858 This is a method for compatibility with other libraries,
859 you should really use ``bytes(x)``.
860
861 Returns
862 --------
863 :class:`bytes`
864 The binary encoded Protobuf representation of this message instance
865 """
866 return bytes(self)
867
868 @classmethod
869 def _type_hint(cls, field_name: str) -> Type:
870 return cls._type_hints()[field_name]
871
872 @classmethod
873 def _type_hints(cls) -> Dict[str, Type]:
874 module = sys.modules[cls.__module__]
875 return get_type_hints(cls, module.__dict__, {})
876
877 @classmethod
878 def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
879 """Get the message class for a field from the type hints."""
880 field_cls = cls._type_hint(field.name)
881 if hasattr(field_cls, "__args__") and index >= 0:
882 if field_cls.__args__ is not None:
883 field_cls = field_cls.__args__[index]
884 return field_cls
885
886 def _get_field_default(self, field_name: str) -> Any:
887 with warnings.catch_warnings():
888 # ignore warnings when initialising deprecated field defaults
889 warnings.filterwarnings("ignore", category=DeprecationWarning)
890 return self._betterproto.default_gen[field_name]()
891
892 @classmethod
893 def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
894 t = cls._type_hint(field.name)
895
896 if hasattr(t, "__origin__"):
897 if t.__origin__ is dict:
898 # This is some kind of map (dict in Python).
899 return dict
900 elif t.__origin__ is list:
901 # This is some kind of list (repeated) field.
902 return list
903 elif t.__origin__ is Union and t.__args__[1] is type(None):
904 # This is an optional field (either wrapped, or using proto3
905 # field presence). For setting the default we really don't care
906 # what kind of field it is.
907 return type(None)
908 else:
909 return t
910 elif issubclass(t, Enum):
911 # Enums always default to zero.
912 return int
913 elif t is datetime:
914 # Offsets are relative to 1970-01-01T00:00:00Z
915 return datetime_default_gen
916 else:
917 # This is either a primitive scalar or another message type. Calling
918 # it should result in its zero value.
919 return t
920
921 def _postprocess_single(
922 self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
923 ) -> Any:
924 """Adjusts values after parsing."""
925 if wire_type == WIRE_VARINT:
926 if meta.proto_type in (TYPE_INT32, TYPE_INT64):
927 bits = int(meta.proto_type[3:])
928 value = value & ((1 << bits) - 1)
929 signbit = 1 << (bits - 1)
930 value = int((value ^ signbit) - signbit)
931 elif meta.proto_type in (TYPE_SINT32, TYPE_SINT64):
932 # Undo zig-zag encoding
933 value = (value >> 1) ^ (-(value & 1))
934 elif meta.proto_type == TYPE_BOOL:
935 # Booleans use a varint encoding, so convert it to true/false.
936 value = value > 0
937 elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
938 fmt = _pack_fmt(meta.proto_type)
939 value = struct.unpack(fmt, value)[0]
940 elif wire_type == WIRE_LEN_DELIM:
941 if meta.proto_type == TYPE_STRING:
942 value = str(value, "utf-8")
943 elif meta.proto_type == TYPE_MESSAGE:
944 cls = self._betterproto.cls_by_field[field_name]
945
946 if cls == datetime:
947 value = _Timestamp().parse(value).to_datetime()
948 elif cls == timedelta:
949 value = _Duration().parse(value).to_timedelta()
950 elif meta.wraps:
951 # This is a Google wrapper value message around a single
952 # scalar type.
953 value = _get_wrapper(meta.wraps)().parse(value).value
954 else:
955 value = cls().parse(value)
956 value._serialized_on_wire = True
957 elif meta.proto_type == TYPE_MAP:
958 value = self._betterproto.cls_by_field[field_name]().parse(value)
959
960 return value
961
962 def _include_default_value_for_oneof(
963 self, field_name: str, meta: FieldMetadata
964 ) -> bool:
965 return (
966 meta.group is not None and self._group_current.get(meta.group) == field_name
967 )
968
969 def parse(self: T, data: bytes) -> T:
970 """
971 Parse the binary encoded Protobuf into this message instance. This
972 returns the instance itself and is therefore assignable and chainable.
973
974 Parameters
975 -----------
976 data: :class:`bytes`
977 The data to parse the protobuf from.
978
979 Returns
980 --------
981 :class:`Message`
982 The initialized message.
983 """
984 # Got some data over the wire
985 self._serialized_on_wire = True
986 proto_meta = self._betterproto
987 for parsed in parse_fields(data):
988 field_name = proto_meta.field_name_by_number.get(parsed.number)
989 if not field_name:
990 self._unknown_fields += parsed.raw
991 continue
992
993 meta = proto_meta.meta_by_field_name[field_name]
994
995 value: Any
996 if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
997 # This is a packed repeated field.
998 pos = 0
999 value = []
1000 while pos < len(parsed.value):
1001 if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32):
1002 decoded, pos = parsed.value[pos : pos + 4], pos + 4
1003 wire_type = WIRE_FIXED_32
1004 elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64):
1005 decoded, pos = parsed.value[pos : pos + 8], pos + 8
1006 wire_type = WIRE_FIXED_64
1007 else:
1008 decoded, pos = decode_varint(parsed.value, pos)
1009 wire_type = WIRE_VARINT
1010 decoded = self._postprocess_single(
1011 wire_type, meta, field_name, decoded
1012 )
1013 value.append(decoded)
1014 else:
1015 value = self._postprocess_single(
1016 parsed.wire_type, meta, field_name, parsed.value
1017 )
1018
1019 current = getattr(self, field_name)
1020 if meta.proto_type == TYPE_MAP:
1021 # Value represents a single key/value pair entry in the map.
1022 current[value.key] = value.value
1023 elif isinstance(current, list) and not isinstance(value, list):
1024 current.append(value)
1025 else:
1026 setattr(self, field_name, value)
1027
1028 return self
1029
1030 # For compatibility with other libraries.
1031 @classmethod
1032 def FromString(cls: Type[T], data: bytes) -> T:
1033 """
1034 Parse the binary encoded Protobuf into this message instance. This
1035 returns the instance itself and is therefore assignable and chainable.
1036
1037 .. note::
1038 This is a method for compatibility with other libraries,
1039 you should really use :meth:`parse`.
1040
1041
1042 Parameters
1043 -----------
1044 data: :class:`bytes`
1045 The data to parse the protobuf from.
1046
1047 Returns
1048 --------
1049 :class:`Message`
1050 The initialized message.
1051 """
1052 return cls().parse(data)
1053
1054 def to_dict(
1055 self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
1056 ) -> Dict[str, Any]:
1057 """
1058 Returns a JSON serializable dict representation of this object.
1059
1060 Parameters
1061 -----------
1062 casing: :class:`Casing`
1063 The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1064 compatibility purposes.
1065 include_default_values: :class:`bool`
1066 If ``True`` will include the default values of fields. Default is ``False``.
1067 E.g. an ``int32`` field will be included with a value of ``0`` if this is
1068 set to ``True``, otherwise this would be ignored.
1069
1070 Returns
1071 --------
1072 Dict[:class:`str`, Any]
1073 The JSON serializable dict representation of this object.
1074 """
1075 output: Dict[str, Any] = {}
1076 field_types = self._type_hints()
1077 defaults = self._betterproto.default_gen
1078 for field_name, meta in self._betterproto.meta_by_field_name.items():
1079 field_is_repeated = defaults[field_name] is list
1080 value = getattr(self, field_name)
1081 cased_name = casing(field_name).rstrip("_") # type: ignore
1082 if meta.proto_type == TYPE_MESSAGE:
1083 if isinstance(value, datetime):
1084 if (
1085 value != DATETIME_ZERO
1086 or include_default_values
1087 or self._include_default_value_for_oneof(
1088 field_name=field_name, meta=meta
1089 )
1090 ):
1091 output[cased_name] = _Timestamp.timestamp_to_json(value)
1092 elif isinstance(value, timedelta):
1093 if (
1094 value != timedelta(0)
1095 or include_default_values
1096 or self._include_default_value_for_oneof(
1097 field_name=field_name, meta=meta
1098 )
1099 ):
1100 output[cased_name] = _Duration.delta_to_json(value)
1101 elif meta.wraps:
1102 if value is not None or include_default_values:
1103 output[cased_name] = value
1104 elif field_is_repeated:
1105 # Convert each item.
1106 cls = self._betterproto.cls_by_field[field_name]
1107 if cls == datetime:
1108 value = [_Timestamp.timestamp_to_json(i) for i in value]
1109 elif cls == timedelta:
1110 value = [_Duration.delta_to_json(i) for i in value]
1111 else:
1112 value = [
1113 i.to_dict(casing, include_default_values) for i in value
1114 ]
1115 if value or include_default_values:
1116 output[cased_name] = value
1117 elif value is None:
1118 if include_default_values:
1119 output[cased_name] = value
1120 elif (
1121 value._serialized_on_wire
1122 or include_default_values
1123 or self._include_default_value_for_oneof(
1124 field_name=field_name, meta=meta
1125 )
1126 ):
1127 output[cased_name] = value.to_dict(casing, include_default_values)
1128 elif meta.proto_type == TYPE_MAP:
1129 output_map = {**value}
1130 for k in value:
1131 if hasattr(value[k], "to_dict"):
1132 output_map[k] = value[k].to_dict(casing, include_default_values)
1133
1134 if value or include_default_values:
1135 output[cased_name] = output_map
1136 elif (
1137 value != self._get_field_default(field_name)
1138 or include_default_values
1139 or self._include_default_value_for_oneof(
1140 field_name=field_name, meta=meta
1141 )
1142 ):
1143 if meta.proto_type in INT_64_TYPES:
1144 if field_is_repeated:
1145 output[cased_name] = [str(n) for n in value]
1146 elif value is None:
1147 if include_default_values:
1148 output[cased_name] = value
1149 else:
1150 output[cased_name] = str(value)
1151 elif meta.proto_type == TYPE_BYTES:
1152 if field_is_repeated:
1153 output[cased_name] = [
1154 b64encode(b).decode("utf8") for b in value
1155 ]
1156 elif value is None and include_default_values:
1157 output[cased_name] = value
1158 else:
1159 output[cased_name] = b64encode(value).decode("utf8")
1160 elif meta.proto_type == TYPE_ENUM:
1161 if field_is_repeated:
1162 enum_class = field_types[field_name].__args__[0]
1163 if isinstance(value, typing.Iterable) and not isinstance(
1164 value, str
1165 ):
1166 output[cased_name] = [enum_class(el).name for el in value]
1167 else:
1168 # transparently upgrade single value to repeated
1169 output[cased_name] = [enum_class(value).name]
1170 elif value is None:
1171 if include_default_values:
1172 output[cased_name] = value
1173 elif meta.optional:
1174 enum_class = field_types[field_name].__args__[0]
1175 output[cased_name] = enum_class(value).name
1176 else:
1177 enum_class = field_types[field_name] # noqa
1178 output[cased_name] = enum_class(value).name
1179 elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1180 if field_is_repeated:
1181 output[cased_name] = [_dump_float(n) for n in value]
1182 else:
1183 output[cased_name] = _dump_float(value)
1184 else:
1185 output[cased_name] = value
1186 return output
1187
1188 def from_dict(self: T, value: Mapping[str, Any]) -> T:
1189 """
1190 Parse the key/value pairs into the current message instance. This returns the
1191 instance itself and is therefore assignable and chainable.
1192
1193 Parameters
1194 -----------
1195 value: Dict[:class:`str`, Any]
1196 The dictionary to parse from.
1197
1198 Returns
1199 --------
1200 :class:`Message`
1201 The initialized message.
1202 """
1203 self._serialized_on_wire = True
1204 for key in value:
1205 field_name = safe_snake_case(key)
1206 meta = self._betterproto.meta_by_field_name.get(field_name)
1207 if not meta:
1208 continue
1209
1210 if value[key] is not None:
1211 if meta.proto_type == TYPE_MESSAGE:
1212 v = getattr(self, field_name)
1213 cls = self._betterproto.cls_by_field[field_name]
1214 if isinstance(v, list):
1215 if cls == datetime:
1216 v = [isoparse(item) for item in value[key]]
1217 elif cls == timedelta:
1218 v = [
1219 timedelta(seconds=float(item[:-1]))
1220 for item in value[key]
1221 ]
1222 else:
1223 v = [cls().from_dict(item) for item in value[key]]
1224 elif cls == datetime:
1225 v = isoparse(value[key])
1226 setattr(self, field_name, v)
1227 elif cls == timedelta:
1228 v = timedelta(seconds=float(value[key][:-1]))
1229 setattr(self, field_name, v)
1230 elif meta.wraps:
1231 setattr(self, field_name, value[key])
1232 elif v is None:
1233 setattr(self, field_name, cls().from_dict(value[key]))
1234 else:
1235 # NOTE: `from_dict` mutates the underlying message, so no
1236 # assignment here is necessary.
1237 v.from_dict(value[key])
1238 elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1239 v = getattr(self, field_name)
1240 cls = self._betterproto.cls_by_field[f"{field_name}.value"]
1241 for k in value[key]:
1242 v[k] = cls().from_dict(value[key][k])
1243 else:
1244 v = value[key]
1245 if meta.proto_type in INT_64_TYPES:
1246 if isinstance(value[key], list):
1247 v = [int(n) for n in value[key]]
1248 else:
1249 v = int(value[key])
1250 elif meta.proto_type == TYPE_BYTES:
1251 if isinstance(value[key], list):
1252 v = [b64decode(n) for n in value[key]]
1253 else:
1254 v = b64decode(value[key])
1255 elif meta.proto_type == TYPE_ENUM:
1256 enum_cls = self._betterproto.cls_by_field[field_name]
1257 if isinstance(v, list):
1258 v = [enum_cls.from_string(e) for e in v]
1259 elif isinstance(v, str):
1260 v = enum_cls.from_string(v)
1261 elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1262 if isinstance(value[key], list):
1263 v = [_parse_float(n) for n in value[key]]
1264 else:
1265 v = _parse_float(value[key])
1266
1267 if v is not None:
1268 setattr(self, field_name, v)
1269 return self
1270
1271 def to_json(
1272 self,
1273 indent: Union[None, int, str] = None,
1274 include_default_values: bool = False,
1275 casing: Casing = Casing.CAMEL,
1276 ) -> str:
1277 """A helper function to parse the message instance into its JSON
1278 representation.
1279
1280 This is equivalent to::
1281
1282 json.dumps(message.to_dict(), indent=indent)
1283
1284 Parameters
1285 -----------
1286 indent: Optional[Union[:class:`int`, :class:`str`]]
1287 The indent to pass to :func:`json.dumps`.
1288
1289 include_default_values: :class:`bool`
1290 If ``True`` will include the default values of fields. Default is ``False``.
1291 E.g. an ``int32`` field will be included with a value of ``0`` if this is
1292 set to ``True``, otherwise this would be ignored.
1293
1294 casing: :class:`Casing`
1295 The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1296 compatibility purposes.
1297
1298 Returns
1299 --------
1300 :class:`str`
1301 The JSON representation of the message.
1302 """
1303 return json.dumps(
1304 self.to_dict(include_default_values=include_default_values, casing=casing),
1305 indent=indent,
1306 )
1307
1308 def from_json(self: T, value: Union[str, bytes]) -> T:
1309 """A helper function to return the message instance from its JSON
1310 representation. This returns the instance itself and is therefore assignable
1311 and chainable.
1312
1313 This is equivalent to::
1314
1315 return message.from_dict(json.loads(value))
1316
1317 Parameters
1318 -----------
1319 value: Union[:class:`str`, :class:`bytes`]
1320 The value to pass to :func:`json.loads`.
1321
1322 Returns
1323 --------
1324 :class:`Message`
1325 The initialized message.
1326 """
1327 return self.from_dict(json.loads(value))
1328
1329 def to_pydict(
1330 self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
1331 ) -> Dict[str, Any]:
1332 """
1333 Returns a python dict representation of this object.
1334
1335 Parameters
1336 -----------
1337 casing: :class:`Casing`
1338 The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1339 compatibility purposes.
1340 include_default_values: :class:`bool`
1341 If ``True`` will include the default values of fields. Default is ``False``.
1342 E.g. an ``int32`` field will be included with a value of ``0`` if this is
1343 set to ``True``, otherwise this would be ignored.
1344
1345 Returns
1346 --------
1347 Dict[:class:`str`, Any]
1348 The python dict representation of this object.
1349 """
1350 output: Dict[str, Any] = {}
1351 defaults = self._betterproto.default_gen
1352 for field_name, meta in self._betterproto.meta_by_field_name.items():
1353 field_is_repeated = defaults[field_name] is list
1354 value = getattr(self, field_name)
1355 cased_name = casing(field_name).rstrip("_") # type: ignore
1356 if meta.proto_type == TYPE_MESSAGE:
1357 if isinstance(value, datetime):
1358 if (
1359 value != DATETIME_ZERO
1360 or include_default_values
1361 or self._include_default_value_for_oneof(
1362 field_name=field_name, meta=meta
1363 )
1364 ):
1365 output[cased_name] = value
1366 elif isinstance(value, timedelta):
1367 if (
1368 value != timedelta(0)
1369 or include_default_values
1370 or self._include_default_value_for_oneof(
1371 field_name=field_name, meta=meta
1372 )
1373 ):
1374 output[cased_name] = value
1375 elif meta.wraps:
1376 if value is not None or include_default_values:
1377 output[cased_name] = value
1378 elif field_is_repeated:
1379 # Convert each item.
1380 value = [i.to_pydict(casing, include_default_values) for i in value]
1381 if value or include_default_values:
1382 output[cased_name] = value
1383 elif value is None:
1384 if include_default_values:
1385 output[cased_name] = None
1386 elif (
1387 value._serialized_on_wire
1388 or include_default_values
1389 or self._include_default_value_for_oneof(
1390 field_name=field_name, meta=meta
1391 )
1392 ):
1393 output[cased_name] = value.to_pydict(casing, include_default_values)
1394 elif meta.proto_type == TYPE_MAP:
1395 for k in value:
1396 if hasattr(value[k], "to_pydict"):
1397 value[k] = value[k].to_pydict(casing, include_default_values)
1398
1399 if value or include_default_values:
1400 output[cased_name] = value
1401 elif (
1402 value != self._get_field_default(field_name)
1403 or include_default_values
1404 or self._include_default_value_for_oneof(
1405 field_name=field_name, meta=meta
1406 )
1407 ):
1408 output[cased_name] = value
1409 return output
1410
1411 def from_pydict(self: T, value: Mapping[str, Any]) -> T:
1412 """
1413 Parse the key/value pairs into the current message instance. This returns the
1414 instance itself and is therefore assignable and chainable.
1415
1416 Parameters
1417 -----------
1418 value: Dict[:class:`str`, Any]
1419 The dictionary to parse from.
1420
1421 Returns
1422 --------
1423 :class:`Message`
1424 The initialized message.
1425 """
1426 self._serialized_on_wire = True
1427 for key in value:
1428 field_name = safe_snake_case(key)
1429 meta = self._betterproto.meta_by_field_name.get(field_name)
1430 if not meta:
1431 continue
1432
1433 if value[key] is not None:
1434 if meta.proto_type == TYPE_MESSAGE:
1435 v = getattr(self, field_name)
1436 if isinstance(v, list):
1437 cls = self._betterproto.cls_by_field[field_name]
1438 for item in value[key]:
1439 v.append(cls().from_pydict(item))
1440 elif isinstance(v, datetime):
1441 v = value[key]
1442 elif isinstance(v, timedelta):
1443 v = value[key]
1444 elif meta.wraps:
1445 v = value[key]
1446 else:
1447 # NOTE: `from_pydict` mutates the underlying message, so no
1448 # assignment here is necessary.
1449 v.from_pydict(value[key])
1450 elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1451 v = getattr(self, field_name)
1452 cls = self._betterproto.cls_by_field[f"{field_name}.value"]
1453 for k in value[key]:
1454 v[k] = cls().from_pydict(value[key][k])
1455 else:
1456 v = value[key]
1457
1458 if v is not None:
1459 setattr(self, field_name, v)
1460 return self
1461
1462 def is_set(self, name: str) -> bool:
1463 """
1464 Check if field with the given name has been set.
1465
1466 Parameters
1467 -----------
1468 name: :class:`str`
1469 The name of the field to check for.
1470
1471 Returns
1472 --------
1473 :class:`bool`
1474 `True` if field has been set, otherwise `False`.
1475 """
1476 default = (
1477 PLACEHOLDER
1478 if not self._betterproto.meta_by_field_name[name].optional
1479 else None
1480 )
1481 return self.__raw_get(name) is not default
1482
1483 @classmethod
1484 def _validate_field_groups(cls, values):
1485 group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
1486 field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
1487
1488 for group, field_set in group_to_one_ofs.items():
1489
1490 if len(field_set) == 1:
1491 (field,) = field_set
1492 field_name = field.name
1493 meta = field_name_to_meta[field_name]
1494
1495 # This is a synthetic oneof; we should ignore it's presence and not consider it as a oneof.
1496 if meta.optional:
1497 continue
1498
1499 set_fields = [
1500 field.name for field in field_set if values[field.name] is not None
1501 ]
1502
1503 if not set_fields:
1504 raise ValueError(f"Group {group} has no value; all fields are None")
1505 elif len(set_fields) > 1:
1506 set_fields_str = ", ".join(set_fields)
1507 raise ValueError(
1508 f"Group {group} has more than one value; fields {set_fields_str} are not None"
1509 )
1510
1511 return values
1512
1513
def serialized_on_wire(message: Message) -> bool:
    """
    Report whether ``message`` was, or should be, serialized on the wire.

    This exposes the message's ``_serialized_on_wire`` flag. It can be used to
    detect presence (e.g. of an optional wrapper message) and is used internally
    during parsing/serialization.

    Parameters
    -----------
    message: :class:`Message`
        The message to inspect.

    Returns
    --------
    :class:`bool`
        The current value of ``message._serialized_on_wire``.
    """
    on_wire = message._serialized_on_wire
    return on_wire
1526
1527
def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
    """
    Return the name and value of the active field in a ``oneof`` group.

    Parameters
    -----------
    message: :class:`Message`
        The message to inspect.
    group_name: :class:`str`
        The name of the ``oneof`` group.

    Returns
    --------
    Tuple[:class:`str`, Any]
        ``(field_name, value)`` for the field recorded as current in the group,
        or ``("", None)`` when no (non-empty) field name is recorded.
    """
    active = message._group_current.get(group_name)
    return (active, getattr(message, active)) if active else ("", None)
1541
1542
1543# Circular import workaround: google.protobuf depends on base classes defined above.
1544from .lib.google.protobuf import ( # noqa
1545 BoolValue,
1546 BytesValue,
1547 DoubleValue,
1548 Duration,
1549 EnumValue,
1550 FloatValue,
1551 Int32Value,
1552 Int64Value,
1553 StringValue,
1554 Timestamp,
1555 UInt32Value,
1556 UInt64Value,
1557)
1558
1559
class _Duration(Duration):
    """:class:`Duration` with conversions to and from :class:`datetime.timedelta`."""

    @classmethod
    def from_timedelta(
        cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1)
    ) -> "_Duration":
        """
        Build a duration equal to ``delta``.

        ``seconds`` and ``nanos`` always share the sign of ``delta``, as the
        protobuf ``Duration`` message requires. For example, -1.5s becomes
        ``(-1, -500000000)`` rather than the mixed-sign ``(-1, +500000000)``.
        Integer arithmetic keeps microsecond precision even for very large deltas.
        """
        total_us = delta // _1_microsecond
        sign = -1 if total_us < 0 else 1
        seconds, micros = divmod(abs(total_us), 1_000_000)
        return cls(sign * seconds, sign * micros * 1000)

    def to_timedelta(self) -> timedelta:
        """Convert to a :class:`datetime.timedelta` (microsecond resolution)."""
        return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)

    @staticmethod
    def delta_to_json(delta: timedelta) -> str:
        """
        Format ``delta`` as a protobuf JSON duration string.

        The output is decimal seconds with 3 or 6 fractional digits followed by
        ``s``, e.g. ``"1.500s"``, ``"-0.000010s"``. It is built from integers so
        that tiny durations never come out in scientific notation (``"1e-05s"``).
        """
        total_us = delta // timedelta(microseconds=1)
        sign = "-" if total_us < 0 else ""
        seconds, micros = divmod(abs(total_us), 1_000_000)
        if micros % 1000 == 0:
            # Millisecond precision suffices: 3 fractional digits.
            fraction = f"{micros // 1000:03d}"
        else:
            fraction = f"{micros:06d}"
        return f"{sign}{seconds}.{fraction}s"
1580
1581
class _Timestamp(Timestamp):
    """:class:`Timestamp` with conversions to and from :class:`datetime.datetime`."""

    @classmethod
    def from_datetime(cls, dt: datetime) -> "_Timestamp":
        """
        Build a timestamp for ``dt``.

        ``seconds`` is floored, so ``nanos`` is always non-negative. Before 1970
        this gives e.g. ``(-1, 500000000)`` for half a second before the epoch.
        The fractional part is taken from ``dt.microsecond`` rather than a float
        timestamp, which would truncate toward zero and lose precision for
        far-future dates. Naive datetimes are interpreted the same way as
        :meth:`datetime.timestamp` does.
        """
        # Dropping the microseconds first makes timestamp() return a whole
        # number of seconds.
        seconds = int(dt.replace(microsecond=0).timestamp())
        nanos = dt.microsecond * 1000
        return cls(seconds, nanos)

    def to_datetime(self) -> datetime:
        """
        Convert to a UTC-aware :class:`datetime.datetime`.

        Works by adding the offset to the epoch instead of building a float
        timestamp, which avoids float precision loss.
        """
        epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
        return epoch + timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)

    @staticmethod
    def timestamp_to_json(dt: datetime) -> str:
        """
        Format ``dt`` as an RFC 3339 UTC string ending in ``Z``.

        Aware datetimes are converted to UTC first. The fraction uses 0, 3, 6 or
        9 digits, whichever is the shortest exact form.
        """
        # Integer nanoseconds so the digit formatting below cannot hit floats.
        nanos = dt.microsecond * 1000
        if dt.tzinfo is not None:
            # change timezone aware datetime objects to utc
            dt = dt.astimezone(timezone.utc)
        result = dt.replace(microsecond=0, tzinfo=None).isoformat()
        if nanos == 0:
            # If there are 0 fractional digits, the fractional
            # point '.' should be omitted when serializing.
            return f"{result}Z"
        if nanos % 1_000_000 == 0:
            # Serialize 3 fractional digits.
            return f"{result}.{nanos // 1_000_000:03d}Z"
        if nanos % 1_000 == 0:
            # Serialize 6 fractional digits.
            return f"{result}.{nanos // 1_000:06d}Z"
        # Serialize 9 fractional digits.
        return f"{result}.{nanos:09d}Z"
1613
1614
def _get_wrapper(proto_type: str) -> Type:
    """
    Look up the ``google.protobuf`` wrapper message class for a scalar proto type.

    Raises :class:`KeyError` when ``proto_type`` has no entry in the table below.
    """
    # TODO: include ListValue and NullValue?
    wrapper_by_type = {
        TYPE_STRING: StringValue,
        TYPE_BYTES: BytesValue,
        TYPE_BOOL: BoolValue,
        TYPE_INT32: Int32Value,
        TYPE_INT64: Int64Value,
        TYPE_UINT32: UInt32Value,
        TYPE_UINT64: UInt64Value,
        TYPE_FLOAT: FloatValue,
        TYPE_DOUBLE: DoubleValue,
        TYPE_ENUM: EnumValue,
    }
    return wrapper_by_type[proto_type]