Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/python_message.py: 11%
772 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:40 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:40 +0000
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
# This code requires Python 3: it relies on `raise ... from ...` and
# `BaseException.with_traceback()`.
9#
10# TODO: Helpers for verbose, common checks like seeing if a
11# descriptor's cpp_type is CPPTYPE_MESSAGE.
13"""Contains a metaclass and helper functions used to create
14protocol message classes from Descriptor objects at runtime.
16Recall that a metaclass is the "type" of a class.
17(A class is to a metaclass what an instance is to a class.)
19In this case, we use the GeneratedProtocolMessageType metaclass
20to inject all the useful functionality into the classes
21output by the protocol compiler at compile-time.
23The upshot of all this is that the real implementation
24details for ALL pure-Python protocol buffers are *here in
25this file*.
26"""
28__author__ = 'robinson@google.com (Will Robinson)'
30from io import BytesIO
31import struct
32import sys
33import warnings
34import weakref
36from google.protobuf import descriptor as descriptor_mod
37from google.protobuf import message as message_mod
38from google.protobuf import text_format
39# We use "as" to avoid name collisions with variables.
40from google.protobuf.internal import api_implementation
41from google.protobuf.internal import containers
42from google.protobuf.internal import decoder
43from google.protobuf.internal import encoder
44from google.protobuf.internal import enum_type_wrapper
45from google.protobuf.internal import extension_dict
46from google.protobuf.internal import message_listener as message_listener_mod
47from google.protobuf.internal import type_checkers
48from google.protobuf.internal import well_known_types
49from google.protobuf.internal import wire_format
# Module-level shorthands used by the helpers in this file.
_AnyFullTypeName = 'google.protobuf.Any'
_ExtensionDict = extension_dict._ExtensionDict
_FieldDescriptor = descriptor_mod.FieldDescriptor
class GeneratedProtocolMessageType(type):
  """Metaclass that builds pure-Python message classes from Descriptors.

  Implementations of the methods described in the Message class are injected
  into the class being created, together with one property per field.
  ``__slots__`` is also populated so that users cannot accidentally "set" a
  nonexistent field, which would then not be serialized / deserialized
  properly.

  The protocol compiler uses this metaclass for the classes it generates.
  Clients can also create their own classes at runtime, for example:

    mydescriptor = Descriptor(.....)
    factory = symbol_database.Default()
    factory.pool.AddDescriptor(mydescriptor)
    MyProtoClass = factory.GetPrototype(mydescriptor)
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the message class, or returns the one that already exists.

    __slots__ has to be placed in the class dictionary before the class
    object is created, which is why this work happens in __new__.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).  A well-known-type mixin is appended when the
        descriptor's full name is listed in well_known_types.WKTBASES.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.

    Returns:
      The newly-allocated class, or the class already recorded on the
      descriptor as `_concrete_class`.

    Raises:
      RuntimeError: The descriptor entry is a str, i.e. the generated code
        only works with the python cpp extension.
    """
    message_descriptor = dictionary[
        GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(message_descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # Never build a second concrete class for the same descriptor: doing so
    # would break any messages that already exist with the first class.
    # (The C++ implementation appears to have its own internal
    # `PyMessageFactory` to achieve similar results.)  This most commonly
    # happens in `text_format.py` when using descriptors from a custom pool.
    already_built = getattr(message_descriptor, '_concrete_class', None)
    if already_built:
      return already_built

    wkt_bases = well_known_types.WKTBASES
    full_name = message_descriptor.full_name
    if full_name in wkt_bases:
      bases += (wkt_bases[full_name],)
    _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
    _AddSlots(message_descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly created class with all message behaviour.

    Adds enum getters, an __init__ method, implementations of the Message
    methods, and properties for all fields of the protocol type.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.
    """
    message_descriptor = dictionary[
        GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # When __new__ handed back an _existing_ class found via
    # `_concrete_class`, everything below has already been done for it.
    registered_class = getattr(message_descriptor, '_concrete_class', None)
    if registered_class:
      assert registered_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (message_descriptor.full_name))
      return

    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    uses_message_set_format = (
        message_descriptor.has_options and
        message_descriptor.GetOptions().message_set_wire_format)
    if uses_message_set_format:
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(message_descriptor),
          None,
      )

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for regular_field in message_descriptor.fields:
      _AttachFieldHelpers(cls, regular_field)

    if (message_descriptor.is_extendable and
        hasattr(message_descriptor.file, 'pool')):
      known_extensions = message_descriptor.file.pool.FindAllExtensions(
          message_descriptor)
      for extension_field in known_extensions:
        _AttachFieldHelpers(cls, extension_field)

    # Record the class on the descriptor before the helpers below run.
    message_descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(message_descriptor, cls)
    _AddInitMethod(message_descriptor, cls)
    _AddPropertiesForFields(message_descriptor, cls)
    _AddPropertiesForExtensions(message_descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(message_descriptor, cls)
    _AddPrivateHelperMethods(message_descriptor, cls)

    super().__init__(name, bases, dictionary)
194# Stateless helpers for GeneratedProtocolMessageType below.
195# Outside clients should not access these directly.
196#
197# I opted not to make any of these methods on the metaclass, to make it more
198# clear that I'm not really using any state there and to keep clients from
199# thinking that they have direct access to these construction helpers.
def _PropertyName(proto_field_name):
  """Maps a proto field name to the name of its public Python property.

  The mapping is currently the identity: the property carries exactly the
  name used in the .proto file.

  Args:
    proto_field_name: The protocol message field name, exactly as it appears
      (or would appear) in a .proto file.

  Returns:
    The attribute name clients use to get (and in some cases set) the field.
  """
  # Design note: escaping Python keywords (for example turning "yield" into
  # "yield_" with the help of keyword.iskeyword) was proposed and rejected.
  # People rely on getattr() and setattr() to manipulate field values
  # reflectively; renaming properties would force every such user to apply
  # the same transformation.  A field named "yield" is still reachable
  # through getattr/setattr without much trouble.
  # TODO: Remove this method entirely if/when everyone agrees with that
  # position.
  return proto_field_name
def _AddSlots(message_descriptor, dictionary):
  """Stores the fixed list of instance attribute names as '__slots__'.

  The same names are used for every message type; `message_descriptor` is
  accepted but not consulted.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary which receives the '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_unknown_field_set',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
def _IsMessageSetExtension(field):
  """Tells whether `field` is an optional message extension of a MessageSet.

  Args:
    field: A FieldDescriptor.

  Returns:
    A truthy value only when the field is an extension, its containing type
    has the message_set_wire_format option set, and the field itself is an
    optional field of message type.
  """
  # The conditions are evaluated lazily from left to right; the options of
  # the containing type are only fetched for extensions that have options.
  return (
      field.is_extension
      and field.containing_type.has_options
      and field.containing_type.GetOptions().message_set_wire_format
      and field.type == _FieldDescriptor.TYPE_MESSAGE
      and field.label == _FieldDescriptor.LABEL_OPTIONAL
  )
def _IsMapField(field):
  """Tells whether `field` is a map field.

  A field counts as a map when it has message type and that message type is
  flagged as a map entry.

  Args:
    field: A FieldDescriptor.
  """
  return (
      field.type == _FieldDescriptor.TYPE_MESSAGE
      and field.message_type._is_map_entry
  )
def _IsMessageMapField(field):
  """Tells whether the values of the map field `field` are messages.

  Args:
    field: A FieldDescriptor whose message type has a field named 'value'
      (as the entry type of a map field does).
  """
  entry_fields = field.message_type.fields_by_name
  return entry_fields['value'].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
def _AttachFieldHelpers(cls, field_descriptor):
  """Prepares a field for use by the message class `cls`.

  Stores the default-value constructor on the descriptor and registers the
  field in `cls._fields_by_tag` under every tag it can be parsed from.

  Args:
    cls: The message class; its `_fields_by_tag` dict is updated.
    field_descriptor: FieldDescriptor of a regular field or an extension.
  """
  repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def _RegisterTag(wire_type, packed):
    # Maps the encoded tag bytes to (field descriptor, is_packed).
    tag = encoder.TagBytes(field_descriptor.number, wire_type)
    cls._fields_by_tag[tag] = (field_descriptor, packed)

  declared_wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
      field_descriptor.type]
  _RegisterTag(declared_wire_type, False)

  if repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, register the
    # packed tag as well, regardless of the field's options.
    _RegisterTag(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
def _MaybeAddEncoder(cls, field_descriptor):
  """Creates the encoder and sizer of a field unless they already exist.

  Nothing happens when the descriptor already has an `_encoder` attribute.
  Otherwise `_sizer` and `_encoder` are stored on the descriptor.

  Args:
    cls: The message class (not used by this helper).
    field_descriptor: The FieldDescriptor to equip.
  """
  if hasattr(field_descriptor, '_encoder'):
    return

  repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  map_field = _IsMapField(field_descriptor)
  packed = field_descriptor.is_packed

  if map_field:
    new_encoder = encoder.MapEncoder(field_descriptor)
    new_sizer = encoder.MapSizer(field_descriptor,
                                 _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    new_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    new_sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_type = field_descriptor.type
    new_encoder = type_checkers.TYPE_TO_ENCODER[field_type](
        field_descriptor.number, repeated, packed)
    new_sizer = type_checkers.TYPE_TO_SIZER[field_type](
        field_descriptor.number, repeated, packed)

  # `_encoder` doubles as the "already done" marker checked above, so it is
  # assigned last.
  field_descriptor._sizer = new_sizer
  field_descriptor._encoder = new_encoder
def _MaybeAddDecoder(cls, field_descriptor):
  """Creates the wire decoders of a field unless they already exist.

  Nothing happens when the descriptor already has a `_decoders` attribute.
  Otherwise a dict mapping `is_packed` (bool) to a decoder is stored on the
  descriptor as `_decoders`.  The non-packed decoder is always present; the
  packed one is added for repeated fields of a packable type.

  Args:
    cls: The message class (not used by this helper).
    field_descriptor: The FieldDescriptor to equip.
  """
  if hasattr(field_descriptor, '_decoders'):
    return

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_map_entry = _IsMapField(field_descriptor)
  helper_decoders = {}

  def AddDecoder(is_packed):
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      # Enums that are not closed are decoded with the int32 decoder.
      decode_type = _FieldDescriptor.TYPE_INT32

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Message decoders do not take the trailing "no presence" flag.
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)

    helper_decoders[is_packed] = field_decoder

  AddDecoder(False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(True)

  field_descriptor._decoders = helper_decoders
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes the extensions nested in a message as class attributes.

  Args:
    descriptor: Descriptor whose `extensions_by_name` mapping is copied.
    dictionary: Class dictionary receiving one entry per extension name.  A
      name that is already present trips an assertion.
  """
  nested_extensions = descriptor.extensions_by_name
  for name, field in nested_extensions.items():
    assert name not in dictionary
    dictionary[name] = field
def _AddEnumValues(descriptor, cls):
  """Publishes the enums nested in a message on the message class.

  For every enum type an EnumTypeWrapper is stored under the enum's name, and
  every enum value is stored as a class-level integer under the value's name.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class being constructed for this message type.
  """
  for nested_enum in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(nested_enum)
    setattr(cls, nested_enum.name, wrapper)
    for value_descriptor in nested_enum.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
def _GetInitializeDefaultForMap(field):
  """Returns a function that creates the empty container of a map field.

  Args:
    field: FieldDescriptor of a map field.

  Returns:
    A one-argument function.  Given the owning message it returns a
    containers.MessageMap when the map values are messages and a
    containers.ScalarMap otherwise.

  Raises:
    ValueError: `field` is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))

  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function taking one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field: a map
    container, a repeated container, a new sub-message or the scalar default
    value.  The default value may refer back to |message| through the
    listener it is given.

  Raises:
    ValueError: A repeated field declares a default value other than the
      empty list, or a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # The container is therefore given the message descriptor instead.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized; asking the message
      # factory for the class creates it.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Members of a oneof report changes through a _OneofListener, all
      # other sub-messages through the parent's listener for children.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, naming the field if possible.

  Must be called while an exception is being handled.  A plain TypeError
  (exact type, exactly one argument) is replaced by a new TypeError whose
  message ends in " for field <message_name>.<field_name>".  Any other
  exception is re-raised as is.  The original traceback is kept either way.

  Args:
    message_name: Name of the message type owning the field.
    field_name: Name of the field that was being set.
  """
  current, trace = sys.exc_info()[1:]
  to_raise = current
  if len(current.args) == 1 and type(current) is TypeError:
    # Simple TypeError: append the field name to the exception message.
    to_raise = TypeError(
        '%s for field %s.%s' % (str(current), message_name, field_name))

  raise to_raise.with_traceback(trace)
def _AddInitMethod(message_descriptor, cls):
  """Installs the keyword-argument constructor on the message class.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class receiving the generated `__init__`.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Translates an enum label into its number.

    Strings are looked up by name in `enum_type`; every other value is
    returned untouched, without conversion or bounds-checking.

    Raises:
      ValueError: `value` is a string that names no value of `enum_type`.
    """
    if not isinstance(value, str):
      return value
    try:
      return enum_type.values_by_name[value].number
    except KeyError:
      raise ValueError('Enum type %s: unknown label "%s"' % (
          enum_type.full_name, value))

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Maps each oneof descriptor to the descriptor of the member field that
    # is currently set in that oneof.
    self._oneofs = {}

    # For efficiency the empty state is the tuple (); it is turned into a
    # list once fields are added.
    self._unknown_fields = ()
    # For efficiency the empty state is None; it is turned into an
    # UnknownFieldSet struct once fields are added.
    self._unknown_field_set = None      # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)

    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE: _GetFieldByName raises ValueError for unknown names instead of
      # returning None, so this branch is not expected to run.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
          # Repeated scalar; enum labels are converted to numbers first.
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          container.extend(field_value)
        elif not _IsMapField(field):
          # Repeated message; dict elements are expanded into keyword args.
          for val in field_value:
            if isinstance(val, dict):
              container.add(**val)
            else:
              container.add().MergeFrom(val)
        elif _IsMessageMapField(field):
          for key in field_value:
            container[key].MergeFrom(field_value[key])
        else:
          container.update(field_value)
        self._fields[field] = container
        continue

      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        submessage = field._default_constructor(self)
        source = field_value
        if isinstance(field_value, dict):
          source = field.message_type._concrete_class(**field_value)
        try:
          submessage.MergeFrom(source)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = submessage
        continue

      # Singular scalar: assign through the generated property setter.
      if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
        field_value = _GetIntegerEnumValue(field.enum_type, field_value)
      try:
        setattr(self, field_name, field_value)
      except TypeError:
        _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: The message has no field called `field_name`.
  """
  known_fields = message_descriptor.fields_by_name
  try:
    return known_fields[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
def _AddPropertiesForFields(descriptor, cls):
  """Adds one property per field, plus `Extensions` for extendable types.

  Args:
    descriptor: Descriptor of the message type.
    cls: The message class being constructed.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state so we allocate a new one
  # every time it is accessed.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
def _AddPropertiesForField(field, cls):
  """Adds the public property and the number constant of one field.

  The class gets a `<FIELD_NAME>_FIELD_NUMBER` constant holding the field
  number and a property named after the field.  The kind of property depends
  on whether the field is repeated, a singular message or a singular scalar.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
class _FieldProperty(property):
  """A property that also carries the descriptor of the field it exposes."""

  # The descriptor is the only per-instance attribute added to `property`.
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    """Creates the property.

    Args:
      descriptor: Stored on the property as `DESCRIPTOR`.
      getter: Function used to read the attribute.
      setter: Function used to assign the attribute.
      doc: Docstring of the property.
    """
    super().__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
def _AddPropertiesForRepeatedField(field, cls):
  """Adds the public property of a "repeated" protocol message field.

  Reading the property returns the container stored for the field, creating
  it with the field's default constructor on first access.  Assigning to the
  property raises AttributeError.

  Note that when clients add values to these containers, type-checking is
  performed in the case of repeated scalar fields, and any necessary "has"
  bits are set as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    stored = self._fields.get(field)
    if stored is not None:
      return stored

    # Construct a new object to represent this field.
    fresh = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to fail with a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds the public property of a nonrepeated, scalar message field.

  Clients use the property to read and directly assign the field.  An
  assignment is type-checked, updates the stored value and, as a side-effect,
  marks the message as modified and updates the state of the containing
  oneof, if any.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).  A field without presence that is set to
    # such a value is simply dropped from the stored fields.
    if field.has_presence or checked_value:
      self._fields[field] = checked_value
    else:
      self._fields.pop(field, None)
    # Check _cached_byte_size_dirty inline to improve performance, since
    # scalar setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Wrap the getter/setter pair in a property on the class.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds the public property of a nonrepeated, composite message field.

  A composite field is a "group" or "message" field.

  Reading the property returns the sub-message stored for the field,
  creating it with the field's default constructor on first access.
  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    stored = self._fields.get(field)
    if stored is not None:
      return stored

    # Construct a new object to represent this field.
    fresh = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to fail with a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Wrap the getter/setter pair in a property on the class.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a field-number constant for every extension nested in the message.

  For each entry of `descriptor.extensions_by_name` the class receives an
  attribute `<EXTENSION_NAME>_FIELD_NUMBER` holding the extension's number.

  Args:
    descriptor: Descriptor of the message type.
    cls: The message class being constructed.
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)

  # TODO: Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  # TODO: Use cls.MESSAGE_FACTORY.pool when available.
  # NOTE: The descriptor pool used to be fetched here from descriptor.file
  # but the result was never used, so the lookup has been dropped.
def _AddStaticMethods(cls):
  """Adds the static methods RegisterExtension and FromString to `cls`.

  Args:
    cls: The message class being constructed.
  """

  # TODO: This probably needs to be thread-safe(?)
  def RegisterExtension(field_descriptor):
    # Bind the extension to this message type, record it in the pool of the
    # message's file and prepare it for parsing.
    field_descriptor.containing_type = cls.DESCRIPTOR
    # TODO: Use cls.MESSAGE_FACTORY.pool when available.
    # pylint: disable=protected-access
    cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(field_descriptor)
    _AttachFieldHelpers(cls, field_descriptor)

  def FromString(s):
    # Builds a new message and merges the serialized data `s` into it.
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
def _IsPresent(item):
  """Decides whether an entry of `_fields` belongs in ListFields().

  Args:
    item: A (FieldDescriptor, value) tuple taken from `_fields`.

  Returns:
    For repeated fields, whether the container is non-empty; for singular
    message fields, the sub-message's `_is_present_in_parent` flag; True for
    every other field.
  """
  field = item[0]
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on `cls`.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class being constructed.
  """

  def ListFields(self):
    # Keep the present (descriptor, value) pairs, ordered by field number.
    present = (entry for entry in self._fields.items() if _IsPresent(entry))
    return sorted(present, key=lambda entry: entry[0].number)

  cls.ListFields = ListFields
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField() on `cls`.

  HasField() accepts the names of singular fields that have presence and the
  names of oneofs; any other name raises ValueError.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class being constructed.
  """

  # Repeated fields never support HasField().  For proto3, only submessages
  # and fields inside a oneof have presence.
  hassable_fields = {
      candidate.name: candidate
      for candidate in message_descriptor.fields
      if candidate.label != _FieldDescriptor.LABEL_REPEATED
      and candidate.has_presence
  }

  # Has methods are supported for oneof descriptors.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is "set" when the member currently recorded for it is set.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields

    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField on cls."""

  def ClearField(self, field_name):
    """Clears the named field, or the active member of the named oneof.

    Raises:
      ValueError: If field_name is neither a field nor a oneof of the message.
    """
    try:
      target = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
        if oneof not in self._oneofs:
          # No member of this oneof is set, so there is nothing to clear
          # (note that _Modified() is not called on this path).
          return
        target = self._oneofs[oneof]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    stored = self._fields
    if target in stored:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      current = stored[target]
      if hasattr(current, 'InvalidateIterators'):
        current.InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us. That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size. Big deal.
      del stored[target]

      containing_oneof = target.containing_oneof
      if self._oneofs.get(containing_oneof, None) is target:
        del self._oneofs[containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension on cls."""

  def ClearExtension(self, field_descriptor):
    """Removes the given extension from this message, if it is stored."""
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Similar to ClearField(): drop the entry when present, and always mark
    # the message as modified afterwards.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension on cls."""

  def HasExtension(self, field_descriptor):
    """Reports whether the given singular extension is set on this message.

    Raises:
      KeyError: If field_descriptor describes a repeated extension.
    """
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields

    # A sub-message object may exist in _fields without being present.
    value = self._fields.get(field_descriptor)
    return value is not None and value._is_present_in_parent

  cls.HasExtension = HasExtension
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg.type_url is empty or names a type
    that the default symbol database's pool cannot resolve.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO: For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # DescriptorPool.FindMessageTypeByName() is documented to raise KeyError
    # for a type it does not know. Treat that the same as the None case below
    # so callers can fall back to comparing the raw Any fields instead of
    # crashing on an unresolvable type_url.
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls."""

  def __eq__(self, other):
    """Compares two messages of the same type field by field.

    Returns NotImplemented when other is not a Message or has a different
    descriptor.
    """
    if not isinstance(other, message_mod.Message):
      return NotImplemented
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return NotImplemented

    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      # Compare Any messages by their unpacked payloads when both unpack;
      # otherwise fall through to the ordinary field comparison.
      any_a = _InternalUnpackAny(self)
      any_b = _InternalUnpackAny(other)
      if any_a and any_b:
        return any_a == any_b

    if not self.ListFields() == other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    # Unknown fields are compared order-insensitively.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__ on cls."""

  def __str__(self):
    """Returns text_format.MessageToString() of this message."""
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __repr__ on cls."""

  def __repr__(self):
    """Returns text_format.MessageToString() of this message."""
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__ on cls."""

  def __unicode__(self):
    """Returns the UTF-8 text-format rendering decoded to text."""
    # NOTE(review): this relies on MessageToString(as_utf8=True) returning an
    # object with a .decode() method; presumably a Python 2 leftover --
    # confirm before relying on it.
    encoded = text_format.MessageToString(self, as_utf8=True)
    return encoded.decode('utf-8')

  cls.__unicode__ = __unicode__
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    Whatever the size function registered for field_type returns for
    (field_number, value).

  Raises:
    message_mod.EncodeError: If field_type has no entry in
      type_checkers.TYPE_TO_BYTE_SIZE_FN.
  """
  # Only the table lookup is guarded: a KeyError raised from inside the size
  # function itself must not be misreported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError as exc:
    raise message_mod.EncodeError(
        'Unrecognized field type: %d' % field_type) from exc
  return fn(field_number, value)
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize on cls."""

  def ByteSize(self):
    """Returns the serialized size of this message, caching the result."""
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      key_field = descriptor.fields_by_name['key']
      _MaybeAddEncoder(cls, key_field)
      total = key_field._sizer(self.key)
      value_field = descriptor.fields_by_name['value']
      _MaybeAddEncoder(cls, value_field)
      total += value_field._sizer(self.value)
    else:
      total = 0
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        total += field_descriptor._sizer(field_value)
      # Unknown fields are kept as raw (tag bytes, value bytes) pairs.
      total += sum(len(tag_bytes) + len(value_bytes)
                   for tag_bytes, value_bytes in self._unknown_fields)

    # Record the result and clear the dirty flags so the next call can reuse
    # the cached size.
    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString on cls."""

  def SerializeToString(self, **kwargs):
    """Serializes the message after checking that it is initialized.

    Keyword arguments are forwarded to SerializePartialToString().

    Raises:
      message_mod.EncodeError: If IsInitialized() returns a false value.
    """
    # Check if the message has all of its required fields set.
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)

    full_name = self.DESCRIPTOR.full_name
    missing = ','.join(self.FindInitializationErrors())
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (full_name, missing))

  cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString and _InternalSerialize on cls.
  """

  def SerializePartialToString(self, **kwargs):
    """Serializes the message to bytes without any initialization check."""
    buf = BytesIO()
    self._InternalSerialize(buf.write, **kwargs)
    return buf.getvalue()

  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes, deterministic=None):
    """Writes the serialized form of the message through write_bytes.

    Args:
      write_bytes: Callable that receives each chunk of serialized bytes.
      deterministic: Passed to every field encoder after coercion to bool.
        When None, the value of
        api_implementation.IsPythonDefaultSerializationDeterministic() is
        used instead.
    """
    deterministic = (
        api_implementation.IsPythonDefaultSerializationDeterministic()
        if deterministic is None else bool(deterministic))

    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      key_field = descriptor.fields_by_name['key']
      _MaybeAddEncoder(cls, key_field)
      key_field._encoder(write_bytes, self.key, deterministic)
      value_field = descriptor.fields_by_name['value']
      _MaybeAddEncoder(cls, value_field)
      value_field._encoder(write_bytes, self.value, deterministic)
      return

    for field_descriptor, field_value in self.ListFields():
      _MaybeAddEncoder(cls, field_descriptor)
      field_descriptor._encoder(write_bytes, field_value, deterministic)
    # Unknown fields are replayed verbatim: tag bytes, then value bytes.
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)

  cls._InternalSerialize = InternalSerialize
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString and _InternalParse on cls.
  """

  def MergeFromString(self, serialized):
    """Parses serialized data and merges it into this message.

    Args:
      serialized: Object accepted by memoryview() holding the wire data.

    Returns:
      The length of the input (kept for legacy reasons).

    Raises:
      message_mod.DecodeError: If parsing stops before the end of the input,
        or raises IndexError, TypeError or struct.error.
    """
    view = memoryview(serialized)
    length = len(view)
    try:
      consumed = self._InternalParse(view, 0, length)
      if consumed != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length  # Return this for legacy reasons.

  cls.MergeFromString = MergeFromString

  # Bound once here so the parse loop below uses plain local lookups.
  read_tag = decoder.ReadTag
  skip_field = decoder.SkipField
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses serialized bytes from buffer[pos:end] into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position at which parsing stopped. This equals end unless an
      unknown-field helper returned -1, in which case the position of the
      offending tag is returned.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      tag_bytes, after_tag = read_tag(buffer, pos)

      # Decoders registered in _message_set_decoders_by_tag take priority.
      set_decoder, _ = message_set_decoders_by_tag.get(
          tag_bytes, (None, None))
      if set_decoder:
        pos = set_decoder(buffer, after_tag, end, self, field_dict)
        continue

      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is not None:
        # Known field: decode it and keep the oneof bookkeeping in sync.
        _MaybeAddDecoder(cls, field_des)
        known_decoder = field_des._decoders[is_packed]
        pos = known_decoder(buffer, after_tag, end, self, field_dict)
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
        continue

      # Unknown field: record it both in the UnknownFieldSet and in the
      # legacy _unknown_fields list of raw (tag bytes, value bytes) pairs.
      if not self._unknown_fields:
        self._unknown_fields = []
      if unknown_field_set is None:
        unknown_field_set = containers.UnknownFieldSet()
        self._unknown_field_set = unknown_field_set
      tag, _ = decoder._DecodeVarint(tag_bytes, 0)
      field_number, wire_type = wire_format.UnpackTag(tag)
      if field_number == 0:
        raise message_mod.DecodeError('Field number 0 is illegal.')
      data, data_end = decoder._DecodeUnknownField(
          buffer, after_tag, wire_type)
      if data_end == -1:
        return pos
      unknown_field_set._add(field_number, wire_type, data)
      # TODO: remove _unknown_fields.
      # The payload is skipped a second time to obtain its raw bytes.
      skipped_end = skip_field(buffer, after_tag, end, tag_bytes)
      if skipped_end == -1:
        return pos
      self._unknown_fields.append(
          (tag_bytes, buffer[after_tag:skipped_end].tobytes()))
      pos = skipped_end
    return pos

  cls._InternalParse = InternalParse
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [
      candidate for candidate in message_descriptor.fields
      if candidate.label == _FieldDescriptor.LABEL_REQUIRED
  ]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        # Repeated fields whose message type is a map entry are skipped.
        if field.message_type._is_map_entry:
          continue
        uninitialized = not all(
            element.IsInitialized() for element in value)
      else:
        uninitialized = (
            value._is_present_in_parent and not value.IsInitialized())
      if uninitialized:
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    # Required fields missing directly on this message come first.
    errors = [
        field.name for field in required_fields
        if not self.HasField(field.name)
    ]

    # Then recurse into every present sub-message, prefixing its paths.
    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      name = '(%s)' % field.full_name if field.is_extension else field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            element = value[key]
            prefix = '%s[%s].' % (name, key)
            errors.extend(
                prefix + error
                for error in element.FindInitializationErrors())
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for i, element in enumerate(value):
          prefix = '%s[%d].' % (name, i)
          errors.extend(
              prefix + error
              for error in element.FindInitializationErrors())
      else:
        prefix = name + '.'
        errors.extend(
            prefix + error for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
def _FullyQualifiedClassName(klass):
  """Returns klass's dotted name, omitting the module for builtins.

  Args:
    klass: The class to name.

  Returns:
    The qualified name (or __name__ when __qualname__ is missing), prefixed
    with '<module>.' unless the module is None, 'builtins' or '__builtin__'.
  """
  module = klass.__module__
  name = getattr(klass, '__qualname__', klass.__name__)
  is_builtin = module in (None, 'builtins', '__builtin__')
  return name if is_builtin else module + '.' + name
def _AddMergeFromMethod(cls):
  """Installs MergeFrom on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the contents of msg into this message.

    Raises:
      TypeError: If msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      is_repeated = field.label == LABEL_REPEATED
      if is_repeated or field.cpp_type == CPPTYPE_MESSAGE:
        # Singular sub-messages are merged only when present in msg.
        if not is_repeated and not value._is_present_in_parent:
          continue
        target = fields.get(field)
        if target is None:
          # Construct a new object to represent this field.
          target = field._default_constructor(self)
          fields[field] = target
        target.MergeFrom(value)
      else:
        # Singular scalars are overwritten.
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)
      # pylint: disable=protected-access
      if self._unknown_field_set is None:
        self._unknown_field_set = containers.UnknownFieldSet()
      self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof on cls."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    # The field recorded in _oneofs counts only if HasField() confirms it.
    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
def _Clear(self):
  """Resets fields, unknown fields and oneof state, then calls _Modified()."""
  # pylint: disable=protected-access
  stale_unknown_set = self._unknown_field_set

  # Clear fields.
  self._fields = {}
  self._unknown_fields = ()
  if stale_unknown_set is not None:
    # Empty the old set before this message drops its reference to it.
    stale_unknown_set._clear()
  self._unknown_field_set = None

  self._oneofs = {}
  self._Modified()
def _UnknownFields(self):
  """Deprecated accessor for this message's UnknownFieldSet.

  Emits a warning on every call, creates the set on first use, and returns
  it.
  """
  warnings.warn(
      'message.UnknownFields() is deprecated. Please use the add one '
      'feature unknown_fields.UnknownFieldSet(message) in '
      'unknown_fields.py instead.'
  )
  # pylint: disable=protected-access
  existing = self._unknown_field_set
  if existing is None:
    existing = containers.UnknownFieldSet()
    self._unknown_field_set = existing
  return existing
def _DiscardUnknownFields(self):
  """Drops unknown fields from this message and from all its sub-messages."""
  self._unknown_fields = []
  self._unknown_field_set = None  # pylint: disable=protected-access
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Only maps with message values have anything to recurse into.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
def _SetListener(self, listener):
  """Stores listener on the message, substituting a null listener for None."""
  self._listener = (
      message_listener_mod.NullMessageListener()
      if listener is None else listener)
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (
      _AddListFieldsMethod,
      _AddHasFieldMethod,
      _AddClearFieldMethod,
  ):
    add_method(message_descriptor, cls)

  # The extension accessors are installed only on extendable messages.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (
      _AddEqualsMethod,
      _AddStrMethod,
      _AddReprMethod,
      _AddUnicodeMethod,
      _AddByteSizeMethod,
      _AddSerializeToStringMethod,
      _AddSerializePartialToStringMethod,
      _AddMergeFromStringMethod,
      _AddIsInitializedMethod,
  ):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls.

  Installs _Modified (also exposed as SetInParent) and _UpdateOneofState.
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is field:
      return
    # Another member was active: drop its value and record the new one.
    del self._fields[previous]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
class _Listener:

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # This listener establishes a back reference from a child (contained) object
    # to its parent (containing) object.  We make this a weak reference to avoid
    # creating cyclic garbage when the client finishes with the 'parent' object
    # in the tree. A parent that is already a weak proxy is stored as-is.
    already_proxy = isinstance(parent_message, weakref.ProxyType)
    self._parent_message_weakref = (
        parent_message if already_proxy else weakref.proxy(parent_message))

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    """Forwards the notification to the parent unless already dirty."""
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # We can get here if a client has kept a reference to a child object,
        # and is now setting a field on it, but the child's parent has been
        # garbage-collected.  This is not an error.
        pass
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # The weakly referenced parent is gone, so there is nothing to update.
      pass