Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/python_message.py: 67%
732 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:57 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:57 +0000
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31# This code requires Python 3 (it uses `raise ... from` and `BaseException.with_traceback`).
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
51__author__ = 'robinson@google.com (Will Robinson)'
53from io import BytesIO
54import struct
55import sys
56import weakref
58# We use "as" to avoid name collisions with variables.
59from google.protobuf.internal import api_implementation
60from google.protobuf.internal import containers
61from google.protobuf.internal import decoder
62from google.protobuf.internal import encoder
63from google.protobuf.internal import enum_type_wrapper
64from google.protobuf.internal import extension_dict
65from google.protobuf.internal import message_listener as message_listener_mod
66from google.protobuf.internal import type_checkers
67from google.protobuf.internal import well_known_types
68from google.protobuf.internal import wire_format
69from google.protobuf import descriptor as descriptor_mod
70from google.protobuf import message as message_mod
71from google.protobuf import text_format
73_FieldDescriptor = descriptor_mod.FieldDescriptor
74_AnyFullTypeName = 'google.protobuf.Any'
75_ExtensionDict = extension_dict._ExtensionDict
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  Every class built by this metaclass receives implementations of the methods
  declared by message.Message, one property per field, and a fixed
  ``__slots__`` list. The slots keep users from accidentally "setting" a
  nonexistent field, which would otherwise never be serialized or parsed.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime. Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates (or reuses) the class object for a message Descriptor.

    __slots__ has to be in the class dictionary before type.__new__ runs, so
    it is injected here rather than in __init__.

    Args:
      name: Name of the class; forwarded to type.__new__.
      bases: Base classes of the class we're constructing (should be
        message.Message). A well-known-type mixin is appended when the
        descriptor's full name is listed in well_known_types.WKTBASES.
      dictionary: The class dictionary of the class we're constructing.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.

    Returns:
      A newly allocated class, or the class already cached on the descriptor
      as ``_concrete_class``.

    Raises:
      RuntimeError: if the descriptor entry is a string, i.e. code generated
        for the C++ extension is running on the pure-Python runtime.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # Building a second class for the same descriptor would break messages
    # that already exist with the first one, so hand back the cached class.
    # This mostly happens when text_format.py asks
    # symbol_database.Global().getPrototype() for a descriptor from a custom
    # pool. (The C++ implementation has its own PyMessageFactory for this.)
    cached_class = getattr(descriptor, '_concrete_class', None)
    if cached_class:
      return cached_class

    full_name = descriptor.full_name
    if full_name in well_known_types.WKTBASES:
      bases = bases + (well_known_types.WKTBASES[full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super(GeneratedProtocolMessageType, cls).__new__(
        cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates a freshly created message class.

    Installs the decoder table, per-field helpers, enum constants, __init__,
    field and extension properties, static methods and the Message methods.
    A class that __new__ returned from the descriptor's cache is left as is.

    Args:
      name: Name of the class; forwarded to type.__init__.
      bases: Base classes; forwarded to type.__init__.
      dictionary: The class dictionary. dictionary[_DESCRIPTOR_KEY] must
        contain the Descriptor for this message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # A cached class was fully set up when it was first created.
    cached_class = getattr(descriptor, '_concrete_class', None)
    if cached_class:
      assert cached_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    by_tag = {}
    cls._decoders_by_tag = by_tag
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach encoders, sizers and decoders to each FieldDescriptor so later
    # lookups are cheap.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    for install in (_AddEnumValues, _AddInitMethod,
                    _AddPropertiesForFields, _AddPropertiesForExtensions):
      install(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super(GeneratedProtocolMessageType, cls).__init__(name, bases, dictionary)
208# Stateless helpers for GeneratedProtocolMessageType below.
209# Outside clients should not access these directly.
210#
211# I opted not to make any of these methods on the metaclass, to make it more
212# clear that I'm not really using any state there and to keep clients from
213# thinking that they have direct access to these construction helpers.
def _PropertyName(proto_field_name):
  """Returns the attribute name under which a proto field is exposed.

  This is deliberately the identity mapping. Escaping Python keywords (for
  example turning "yield" into "yield_") was considered and rejected:
  callers rely on getattr()/setattr() with the exact .proto field name, and a
  field named after a keyword is still reachable that way.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    `proto_field_name`, unchanged.
  """
  # TODO(kenton): Remove this method entirely if/when everyone agrees that no
  # renaming should ever happen.
  attribute_name = proto_field_name
  return attribute_name
def _AddSlots(message_descriptor, dictionary):
  """Stores the fixed ``__slots__`` list for a message class in `dictionary`.

  Every message type gets the same slots; `message_descriptor` is accepted
  for symmetry with the other construction helpers but is not consulted.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      # Field values keyed by FieldDescriptor.
      '_fields',
      '_unknown_fields',
      '_unknown_field_set',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      # Keeps message instances weak-referenceable despite __slots__.
      '__weakref__',
      # Oneof descriptor -> field currently set in that oneof.
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
def _IsMessageSetExtension(field):
  """Tells whether `field` is an optional message extension of a MessageSet.

  Such extensions are written with the MessageSet item encoder/sizer instead
  of the per-type ones. The checks short-circuit left to right, so later
  attributes are only read once the earlier checks have passed.

  Args:
    field: A FieldDescriptor.

  Returns:
    A truthy value for such extensions; otherwise the first falsy check.
  """
  return (
      field.is_extension
      and field.containing_type.has_options
      and field.containing_type.GetOptions().message_set_wire_format
      and field.type == _FieldDescriptor.TYPE_MESSAGE
      and field.label == _FieldDescriptor.LABEL_OPTIONAL
  )
def _IsMapField(field):
  """Tells whether `field` is a map field.

  A map field is a message-typed field whose message type sets the
  ``map_entry`` option.

  Args:
    field: A FieldDescriptor.

  Returns:
    A truthy value for map fields; otherwise the first falsy check.
  """
  is_message = field.type == _FieldDescriptor.TYPE_MESSAGE
  return (is_message and
          field.message_type.has_options and
          field.message_type.GetOptions().map_entry)
def _IsMessageMapField(field):
  """Tells whether the values of map field `field` are messages.

  Args:
    field: A map FieldDescriptor whose entry type has a field named 'value'.

  Returns:
    True if the entry's 'value' field has CPPTYPE_MESSAGE.
  """
  entry_value = field.message_type.fields_by_name['value']
  return entry_value.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches serialization helpers for one field and registers its decoders.

  Stores an encoder, a sizer and a default-value constructor on
  `field_descriptor` (as _encoder, _sizer and _default_constructor). It then
  adds entries to `cls._decoders_by_tag` that map tag bytes to
  (decoder, oneof member descriptor or None). Repeated packable fields also
  get a packed decoder, whatever their own packing option, so either
  encoding parses.

  Args:
    cls: The message class being built; `cls._decoders_by_tag` must exist.
    field_descriptor: FieldDescriptor of a regular field or an extension.
  """
  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_map_entry = _IsMapField(field_descriptor)
  is_packed = field_descriptor.is_packed

  # Maps and MessageSet extensions have dedicated encoders/sizers; every other
  # field looks them up by field type.
  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_number = field_descriptor.number
    field_encoder = encoder.MessageSetItemEncoder(field_number)
    sizer = encoder.MessageSetItemSizer(field_number)
  else:
    field_type = field_descriptor.type
    field_number = field_descriptor.number
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_type](
        field_number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_type](
        field_number, is_repeated, is_packed)

  # pylint: disable=protected-access
  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  # Set before registering decoders: the non-map decoders are handed it.
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def _RegisterDecoder(wiretype, packed):
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)

    decode_type = field_descriptor.type
    # Enums that are not closed are decoded as plain int32 values.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    # Oneof members register themselves next to their decoder (presumably so
    # the parser can update the oneof state); other fields register None.
    oneof_descriptor = (
        field_descriptor if field_descriptor.containing_oneof is not None
        else None)

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    else:
      shared_args = (field_descriptor.number, is_repeated, packed,
                     field_descriptor, field_descriptor._default_constructor)
      if decode_type == _FieldDescriptor.TYPE_STRING:
        field_decoder = decoder.StringDecoder(
            *shared_args, not field_descriptor.has_presence)
      elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        # Message decoders do not take the trailing `not has_presence` flag.
        field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
            *shared_args)
      else:
        field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
            *shared_args, not field_descriptor.has_presence)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  _RegisterDecoder(
      type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False)

  # Accept packed data for repeated packable fields regardless of the field's
  # own option, so adding `packed = true` stays wire-compatible.
  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    _RegisterDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension nested in `descriptor` as a class attribute.

  Args:
    descriptor: Descriptor of the message type being built.
    dictionary: Class dictionary that receives one entry per extension,
      keyed by the extension's name.
  """
  for name, extension in descriptor.extensions_by_name.items():
    # A clash would silently replace another class member; treat it as a bug.
    assert name not in dictionary
    dictionary[name] = extension
def _AddEnumValues(descriptor, cls):
  """Exposes the enums nested in `descriptor` as class attributes.

  For each nested enum type, the class gets an EnumTypeWrapper stored under
  the enum's name. It also gets one integer attribute per enum value, named
  after the value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_type)
    setattr(cls, enum_type.name, wrapper)
    for value_name, value_number in (
        (value.name, value.number) for value in enum_type.values):
      setattr(cls, value_name, value_number)
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for `field`.

  Args:
    field: A map FieldDescriptor; its entry type has 'key' and 'value' fields.

  Returns:
    A callable taking the owning message and returning a MessageMap when the
    values are messages, or a ScalarMap otherwise.

  Raises:
    ValueError: if `field` is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field. The default
  value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if a repeated field declares a non-empty default value, or
      (via _GetInitializeDefaultForMap) if a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # The container only needs the element Descriptor. _concrete_class is
      # not looked up here because it may not be set yet (that depends on
      # the order in which classes are initialized).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is resolved lazily on
    # each call rather than captured now.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      # Oneof members report changes through a _OneofListener; other fields
      # use the parent's shared child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception currently being handled, naming the field.

  Must be called from inside an ``except`` block. A plain TypeError with
  exactly one argument is replaced by a new TypeError whose message is
  suffixed with " for field <message_name>.<field_name>". Any other exception
  is re-raised unchanged. Either way, the original traceback is preserved.

  Args:
    message_name: Name of the message type, used in the amended message.
    field_name: Name of the offending field.
  """
  _, original, tb = sys.exc_info()
  to_raise = original
  # Only amend simple TypeErrors; subclasses and multi-arg errors pass through.
  if len(original.args) == 1 and type(original) is TypeError:
    to_raise = TypeError(
        '%s for field %s.%s' % (str(original), message_name, field_name))
  raise to_raise.with_traceback(tb)
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets every per-instance slot, then applies its
  keyword arguments, one per field name:
    * None values are ignored (same as omitting the argument).
    * Repeated/map fields are filled from an iterable/mapping; message
      elements may be given as dicts of constructor kwargs.
    * Singular message fields accept a message or a dict of kwargs.
    * Enum fields accept either the integer value or the value's name.

  Unknown field names raise ValueError (from _GetFieldByName). Type errors
  are re-raised with the field name appended.

  Args:
    message_descriptor: Descriptor of the message type being built.
    cls: The class being constructed; its __init__ is replaced.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name. If the value is not a string, it's
    returned as-is. (No conversion or bounds-checking is done.)

    Raises:
      ValueError: if `value` is a string that names no value of `enum_type`.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Any keyword argument populates the message, so the cached size is stale.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None  # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      # _GetFieldByName raises ValueError for names that are not fields of
      # this message; it never returns None, so no separate check is needed.
      field = _GetFieldByName(message_descriptor, field_name)
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field of a message type by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: if the message type has no field called `field_name`.
  """
  try:
    found = message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable message types additionally get an ``Extensions`` property.

  Args:
    descriptor: Descriptor of the message type being built.
    cls: The class being constructed.
  """
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if not descriptor.is_extendable:
    return
  # _ExtensionDict is just an adaptor with no state, so a new one is built on
  # every access instead of being cached.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
def _AddPropertiesForField(field, cls):
  """Adds the <NAME>_FIELD_NUMBER constant and the public property for a field.

  Repeated fields and singular message fields get properties whose setter
  raises AttributeError. Singular scalar fields get a property whose setter
  type-checks and stores the value.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Trips if a new C++ type is added that may need special handling here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
class _FieldProperty(property):
  # A property that also records, as DESCRIPTOR, the FieldDescriptor of the
  # field it exposes. There is deliberately no class docstring: for this
  # slotted property subclass the class-level __doc__ is what instances
  # report, so adding one would change their __doc__.
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    super(_FieldProperty, self).__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property returns the field's container, created on first
  access; it is a RepeatedScalarFieldContainer or a
  RepeatedCompositeFieldContainer. Assigning to the property raises
  AttributeError, so clients modify the container in place. Type-checking of
  repeated scalars and any "has" bookkeeping happen in the container.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing
    # Build a fresh container and publish it with setdefault(). If another
    # thread stored one first, that one wins and ours is discarded. This
    # relies on dict.setdefault() being atomic, which holds in CPython but
    # has not been investigated for other interpreters.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # This setter exists only to produce a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a read/write public property for a singular scalar field.

  Reading returns the stored value, or the field's default. Assigning
  type-checks the value (TypeError on failure) and stores it. For fields
  without presence, a falsy value (0, 0.0, False, enum 0) removes the stored
  entry instead. Unless the cached size is already dirty, assigning also
  calls _Modified(). Setting a oneof member also updates the oneof state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    if field.has_presence or new_value:
      self._fields[field] = new_value
    else:
      # Without presence, a falsy value is the default: store nothing.
      self._fields.pop(field, None)
    # _cached_byte_size_dirty is checked inline for speed, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a singular composite ("group" or "message") field.

  Reading the property returns the sub-message, created on first access.
  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing
    # Build a fresh sub-message and publish it with setdefault(). If another
    # thread stored one first, that one wins and ours is discarded. This
    # relies on dict.setdefault() being atomic, which holds in CPython but
    # has not been investigated for other interpreters.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # This setter exists only to produce a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds class-level attributes for this message type's extensions.

  Sets an <NAME>_FIELD_NUMBER constant for every extension nested in
  `descriptor`. When the descriptor belongs to a file, it also copies that
  file pool's extension tables for this message type onto the class.

  Args:
    descriptor: Descriptor of the message type being built.
    cls: The class being constructed.
  """
  for ext_name, ext_field in descriptor.extensions_by_name.items():
    setattr(cls, ext_name.upper() + '_FIELD_NUMBER', ext_field.number)

  file_descriptor = descriptor.file
  if file_descriptor is None:
    return
  # TODO(amauryfa): Migrate all users of these attributes to functions like
  # pool.FindExtensionByNumber(descriptor).
  # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
  pool = file_descriptor.pool
  cls._extensions_by_number = pool._extensions_by_number[descriptor]
  cls._extensions_by_name = pool._extensions_by_name[descriptor]
def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls."""

  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(field_descriptor):
    # Binds the extension to this message type, adds it to the descriptor's
    # pool and attaches its encoder/sizer/decoder helpers.
    field_descriptor.containing_type = cls.DESCRIPTOR
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    # pylint: disable=protected-access
    cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(field_descriptor)
    _AttachFieldHelpers(cls, field_descriptor)

  def FromString(s):
    # Builds a new cls instance and merges `s` into it via MergeFromString.
    message = cls()
    message.MergeFromString(s)
    return message

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
def _IsPresent(item):
  """Tells whether a (FieldDescriptor, value) pair from _fields is listed.

  Used by ListFields(). Repeated fields count when non-empty, message fields
  when they are marked present in their parent, and any other stored field
  always counts.

  Args:
    item: A (FieldDescriptor, value) pair from a message's _fields.

  Returns:
    True if the value should be included in ListFields().
  """
  field = item[0]
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds ListFields() to cls.

  ListFields() returns a list of (FieldDescriptor, value) pairs for the
  fields that _IsPresent() accepts, sorted by field number.
  """

  def ListFields(self):
    present = (item for item in self._fields.items() if _IsPresent(item))
    return sorted(present, key=lambda item: item[0].number)

  cls.ListFields = ListFields
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds HasField() to cls.

  HasField() accepts either the name of a singular field with presence or
  the name of a oneof. Any other name raises ValueError.
  """
  # Singular fields with presence. For proto3, only submessages and fields
  # inside a oneof have presence.
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED and field.has_presence
  }
  # A oneof can be queried too; it is answered through its active member.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is set exactly when its recorded member is set.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    # A submessage can have an entry in _fields (the property getter creates
    # one lazily) without being marked present in its parent.
    value = self._fields.get(field)
    return value is not None and value._is_present_in_parent

  cls.HasField = HasField
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls.

  Args:
    message_descriptor: descriptor of the message type, used to resolve field
      and oneof names.
    cls: the generated message class to install the method on.
  """
  def ClearField(self, field_name):
    """Clears a field, or the currently active member of a oneof.

    Args:
      field_name: name of a field, or of a oneof. For a oneof, whichever
        member is currently set is cleared; if none is set this returns
        immediately without marking the message modified.

    Raises:
      ValueError: if the message has no field or oneof named field_name.
    """
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          field = self._oneofs[field]
        else:
          return
      except KeyError as exc:
        # Chain the lookup failure, matching HasField()'s error handling.
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name)) from exc

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us. That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size. Big deal.
      del self._fields[field]

      # If the cleared field was the active member of its oneof, the oneof
      # becomes unset.
      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""

  def ClearExtension(self, field_descriptor):
    """Clears the extension field_descriptor; always marks self modified."""
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    # Like ClearField(): drop any stored value, then mark modified regardless.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""

  def HasExtension(self, field_descriptor):
    """Returns whether the singular extension field_descriptor is set.

    Raises:
      KeyError: if field_descriptor is a repeated extension.
    """
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields

    # Message-typed extensions must also be flagged present in self.
    stored = self._fields.get(field_descriptor)
    return stored is not None and stored._is_present_in_parent

  cls.HasExtension = HasExtension
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg has an empty type_url or its type
    is not registered in the default descriptor pool.
  """
  type_url = msg.type_url
  # Nothing to unpack: return before importing/consulting the symbol database.
  if not type_url:
    return None

  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  # TODO(haberman): For now we just strip the hostname. Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The pool raises KeyError for unknown types. Report "cannot unpack" so
    # callers such as __eq__ fall back to comparing the raw Any fields
    # instead of propagating the error.
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls."""

  def __eq__(self, other):
    """Messages are equal when they share a descriptor, list equal fields and
    carry the same unknown fields (order-insensitive). Any messages whose
    payloads can both be unpacked are compared by payload instead."""
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False

    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): str(message) renders text format."""

  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): repr(message) renders text format."""

  def __repr__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__ on cls.

  The text format is produced with as_utf8=True and then decoded as UTF-8.
  """

  def __unicode__(self):
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')

  cls.__unicode__ = __unicode__
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The serialized size in bytes, as computed by the per-type size function.

  Raises:
    message_mod.EncodeError: if field_type has no registered size function.
  """
  # Only the table lookup is guarded: a KeyError raised *inside* the size
  # function is a different failure and must not be reported as an
  # unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError as exc:
    raise message_mod.EncodeError(
        'Unrecognized field type: %d' % field_type) from exc
  return fn(field_number, value)
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs a caching ByteSize() on cls."""

  def ByteSize(self):
    """Returns the serialized size, reusing the cached value while clean."""
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Map entries always serialize both key and value.
      by_name = descriptor.fields_by_name
      total = (by_name['key']._sizer(self.key) +
               by_name['value']._sizer(self.value))
    else:
      total = sum(field._sizer(value) for field, value in self.ListFields())
      # Unknown fields are re-emitted verbatim, so they count byte-for-byte.
      total += sum(len(tag) + len(payload)
                   for tag, payload in self._unknown_fields)

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls."""

  def SerializeToString(self, **kwargs):
    """Serializes self, refusing if any required field is missing.

    Raises:
      message_mod.EncodeError: if IsInitialized() reports missing fields.
    """
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name,
            ','.join(self.FindInitializationErrors())))

  cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializePartialToString() and
  _InternalSerialize() on cls."""

  def SerializePartialToString(self, **kwargs):
    """Serializes self without checking required fields."""
    out_stream = BytesIO()
    self._InternalSerialize(out_stream.write, **kwargs)
    return out_stream.getvalue()

  def InternalSerialize(self, write_bytes, deterministic=None):
    """Writes the wire encoding of self through the write_bytes callable.

    When deterministic is None, the process-wide Python default decides;
    any other value is coerced to bool.
    """
    deterministic = (
        api_implementation.IsPythonDefaultSerializationDeterministic()
        if deterministic is None else bool(deterministic))

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Map entries always write both key and value, in that order.
      for name in ('key', 'value'):
        descriptor.fields_by_name[name]._encoder(
            write_bytes, getattr(self, name), deterministic)
      return

    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value, deterministic)
    # Unknown fields are replayed as the raw tag and value bytes.
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs MergeFromString() and
  _InternalParse() on cls.

  Args:
    message_descriptor: not used by this helper; accepted so all the
      _Add*Method helpers share a signature.
    cls: the generated message class. Its _decoders_by_tag table is read once,
      when this helper runs, and captured by _InternalParse().
  """
  def MergeFromString(self, serialized):
    """Parses wire-format bytes and merges them into self.

    Args:
      serialized: a bytes-like object; it is wrapped in a memoryview so the
        parser can slice it without copying.

    Returns:
      len(serialized), kept for legacy reasons.

    Raises:
      message_mod.DecodeError: on an unexpected end-group tag, on a truncated
        buffer (surfacing as IndexError/TypeError), or on a struct.error.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length  # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bound once per class so the parse loop below uses closure variables
  # rather than repeated module/class attribute lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Create a message from serialized bytes.

    Parses buffer[pos:end] into self. Each tag is either dispatched to its
    registered field decoder or recorded as an unknown field.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int: the position where parsing stopped. This equals `end` after a
      complete parse. An earlier position is returned when an unknown-field
      helper reports -1, which MergeFromString() treats as an end-group tag.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and it's easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # Unknown tag. _Clear() leaves _unknown_fields as an empty tuple, so
        # swap in a list before appending below.
        if not self._unknown_fields:  # pylint: disable=protected-access
          self._unknown_fields = []  # pylint: disable=protected-access
        if unknown_field_set is None:
          # pylint: disable=protected-access
          self._unknown_field_set = containers.UnknownFieldSet()
          # pylint: disable=protected-access
          unknown_field_set = self._unknown_field_set
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO(jieluo): remove old_pos.
        old_pos = new_pos
        # The unknown field is recorded twice: decoded into the
        # UnknownFieldSet, and as raw (tag, bytes) in _unknown_fields, which
        # serialization replays verbatim.
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        # pylint: disable=protected-access
        unknown_field_set._add(field_number, wire_type, data)
        # TODO(jieluo): remove _unknown_fields.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        # field_desc is set for decoders that need oneof bookkeeping; the
        # newly decoded field becomes the active member of its oneof.
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class.

  Args:
    message_descriptor: descriptor of the message type; its fields are
      scanned once, here, for LABEL_REQUIRED fields.
    cls: the generated message class to install the methods on.
  """

  # Computed once per class and shared by both closures below.
  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    # A required field is missing if it was never stored, or if it is a
    # sub-message object that exists but is not flagged present in self.
    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          # Only on failure: run the slower full search so the caller gets
          # every missing path, not just the first one found here.
          errors.extend(self.FindInitializationErrors())
        return False

    # Recurse into stored sub-messages (singular and repeated).
    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          # NOTE(review): map fields are skipped entirely here, while
          # FindInitializationErrors() below descends into message-valued
          # maps, so the two can disagree for maps whose values have
          # required fields -- confirm this is intended.
          if (field.message_type.has_options and
              field.message_type.GetOptions().map_entry):
            continue
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings. Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = []  # simplify things

    # Direct required fields of this message.
    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    # Nested errors, prefixed with the path to the sub-message. Extensions
    # are written as "(full.name)" in the path.
    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.is_extension:
          name = '(%s)' % field.full_name
        else:
          name = field.name

        if _IsMapField(field):
          if _IsMessageMapField(field):
            for key in value:
              element = value[key]
              prefix = '%s[%s].' % (name, key)
              sub_errors = element.FindInitializationErrors()
              errors += [prefix + error for error in sub_errors]
          else:
            # ScalarMaps can't have any initialization errors.
            pass
        elif field.label == _FieldDescriptor.LABEL_REPEATED:
          for i in range(len(value)):
            element = value[i]
            prefix = '%s[%d].' % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + '.'
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
def _FullyQualifiedClassName(klass):
  """Returns "module.QualifiedName" for klass, used in error messages.

  Built-in classes (and classes without a module) are returned without a
  module prefix.
  """
  qualified = getattr(klass, '__qualname__', klass.__name__)
  owner = klass.__module__
  if owner not in (None, 'builtins', '__builtin__'):
    qualified = owner + '.' + qualified
  return qualified
def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): installs MergeFrom() on cls.

  Args:
    cls: the generated message class; MergeFrom() only accepts instances of
      it (checked with isinstance).
  """
  # Read once into closure variables; they are compared for every field on
  # every MergeFrom() call.
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the fields stored in msg into self.

    Repeated fields are merged through the container's own MergeFrom()
    (creating the container first if absent). Singular sub-messages are
    merged recursively, but only when present in msg. Singular scalars are
    overwritten. msg's unknown fields are appended to self's.

    Args:
      msg: the message to merge from.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    # Self-merge is unsupported. This is an assert, so the check disappears
    # under `python -O`.
    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            # Construct a new object to represent this field.
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
      else:
        self._fields[field] = value
        # Only scalars update the oneof here. NOTE(review): composite oneof
        # members are presumably handled by the child's listener when it is
        # modified -- confirm against _default_constructor.
        if field.containing_oneof:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      # _Clear() leaves _unknown_fields as an empty tuple; replace it with a
      # list so it can be extended.
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)
      # pylint: disable=protected-access
      if self._unknown_field_set is None:
        self._unknown_field_set = containers.UnknownFieldSet()
      self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom
def _AddWhichOneofMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs WhichOneof() on cls.

  Args:
    message_descriptor: descriptor of the message type, used to resolve
      oneof names.
    cls: the generated message class to install the method on.
  """
  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None.

    Raises:
      ValueError: if the message has no oneof named oneof_name.
    """
    try:
      field = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError as exc:
      # Chain the lookup failure, matching HasField()'s error handling.
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name) from exc

    # The recorded member only counts if HasField() confirms it is set.
    nested_field = self._oneofs.get(field, None)
    if nested_field is not None and self.HasField(nested_field.name):
      return nested_field.name
    else:
      return None

  cls.WhichOneof = WhichOneof
def _Clear(self):
  """Resets the message to empty: fields, unknown fields and oneof state are
  dropped, and the message is marked modified."""
  self._fields = {}
  self._unknown_fields = ()
  # pylint: disable=protected-access
  stale_set = self._unknown_field_set
  if stale_set is not None:
    stale_set._clear()
  self._unknown_field_set = None

  self._oneofs = {}
  self._Modified()
def _UnknownFields(self):
  """Returns the message's UnknownFieldSet, creating an empty one on first use."""
  # pylint: disable=protected-access
  field_set = self._unknown_field_set
  if field_set is None:
    field_set = containers.UnknownFieldSet()
    self._unknown_field_set = field_set
  return field_set
def _DiscardUnknownFields(self):
  """Drops unknown fields from this message and, recursively, from every
  present sub-message (singular, repeated, and message-valued map entries).

  ByteSize() caches a size that includes unknown-field bytes. If this message
  actually held unknown fields, it is therefore marked modified, which
  invalidates its cached size and notifies its parent. Without this, a later
  ByteSize() would return a stale value (e.g. when the size is presumably
  used as a length prefix during a parent's serialization).
  """
  had_unknown_fields = bool(self._unknown_fields)
  self._unknown_fields = []
  self._unknown_field_set = None  # pylint: disable=protected-access
  for field, value in self.ListFields():
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      if _IsMapField(field):
        # Scalar-valued maps hold no messages, so there is nothing to recurse.
        if _IsMessageMapField(field):
          for key in value:
            value[key].DiscardUnknownFields()
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for sub_message in value:
          sub_message.DiscardUnknownFields()
      else:
        value.DiscardUnknownFields()
  if had_unknown_fields:
    self._Modified()
def _SetListener(self, listener):
  """Installs listener on the message, using a NullMessageListener for None."""
  self._listener = (listener if listener is not None
                    else message_listener_mod.NullMessageListener())
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  Args:
    message_descriptor: descriptor of the message type being generated.
    cls: the generated message class that receives the methods.
  """
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  # Extension accessors are installed only on extendable message types.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Module-level functions that need no per-class closure state.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Installs _Modified (also exposed as SetInParent) and _UpdateOneofState."""

  def Modified(self):
    """Marks the cached byte size dirty, flags self present in its parent and
    notifies the listener -- but only on the clean-to-dirty transition.
    """
    # Some callers test _cached_byte_size_dirty themselves before calling, as
    # an optimization; this must stay a no-op when already dirty.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Makes field the active member of its oneof, deleting the previously
    active member's value if different. Does not mark self as modified.
    """
    oneof = field.containing_oneof
    current = self._oneofs.setdefault(oneof, field)
    if current is field:
      return
    del self._fields[current]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  It lets a write deep in a message tree mark every ancestor as present, e.g.:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  Child objects therefore need back references to their parents, which this
  helper provides.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # Keep only a weak back reference (child -> parent) so that a discarded
    # parent tree does not become cyclic garbage. An existing proxy is reused.
    already_proxy = isinstance(parent_message, weakref.ProxyType)
    self._parent_message_weakref = (
        parent_message if already_proxy else weakref.proxy(parent_message))

    # Mirrors whether the parent is already dirty, so repeated notifications
    # can stop here instead of walking up the tree.
    self.dirty = False

  def Modified(self):
    """Forwards the first modification to the parent's _Modified()."""
    if not self.dirty:
      try:
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # The parent was garbage-collected while a client still holds the
        # child and is mutating it. That is allowed; there is no one to tell.
        pass
class _OneofListener(_Listener):
  """Listener for a composite field inside a oneof: besides propagating the
  modification, it makes that field the parent's active oneof member."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # Parent already garbage-collected; nothing left to update.
      pass