Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/python_message.py: 11%
736 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:37 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:37 +0000
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
51__author__ = 'robinson@google.com (Will Robinson)'
53from io import BytesIO
54import struct
55import sys
56import weakref
58# We use "as" to avoid name collisions with variables.
59from google.protobuf.internal import api_implementation
60from google.protobuf.internal import containers
61from google.protobuf.internal import decoder
62from google.protobuf.internal import encoder
63from google.protobuf.internal import enum_type_wrapper
64from google.protobuf.internal import extension_dict
65from google.protobuf.internal import message_listener as message_listener_mod
66from google.protobuf.internal import type_checkers
67from google.protobuf.internal import well_known_types
68from google.protobuf.internal import wire_format
69from google.protobuf import descriptor as descriptor_mod
70from google.protobuf import message as message_mod
71from google.protobuf import text_format
73_FieldDescriptor = descriptor_mod.FieldDescriptor
74_AnyFullTypeName = 'google.protobuf.Any'
75_ExtensionDict = extension_dict._ExtensionDict
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes built at runtime from Descriptors.

  The metaclass installs implementations of the methods described by the
  Message class, creates one property per field so that fields can be read
  (and, where permitted, written) as attributes, and declares __slots__ so
  that assigning to a nonexistent field fails instead of silently creating an
  attribute that would never be serialized or deserialized.

  The protocol compiler uses this metaclass to create message classes at
  runtime.  Clients may also build their own classes at runtime, for example:

    mydescriptor = Descriptor(.....)
    factory = symbol_database.Default()
    factory.pool.AddDescriptor(mydescriptor)
    MyProtoClass = factory.GetPrototype(mydescriptor)
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the class object for a runtime-generated message type.

    __new__ is overridden because __slots__ has to be present in the class
    dictionary before the class object is created.

    Args:
      name: Name of the class; forwarded to type.__new__.
      bases: Base classes of the class being constructed (should be
        message.Message).  For descriptors listed in
        well_known_types.WKTBASES the matching helper base is appended.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.

    Returns:
      The class already recorded on the descriptor as `_concrete_class`, if
      there is one; otherwise the newly allocated class.

    Raises:
      RuntimeError: The descriptor entry is a str, i.e. the generated code
        targets the cpp extension but the pure-Python runtime is in use.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      return existing_class

    wkt_bases = well_known_types.WKTBASES
    if descriptor.full_name in wkt_bases:
      bases += (wkt_bases[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly created class; most of the work happens here.

    Attaches encoders/decoders to the field descriptors, then adds enum
    values, an __init__ method, field and extension properties, static
    methods, the Message method implementations and private helpers.

    Args:
      name: Name of the class; forwarded to type.__init__.
      bases: Base classes of the class being constructed (should be
        message.Message); forwarded to type.__init__.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If __new__ handed back a class that already existed for this descriptor
    # there is nothing left to initialize.
    registered_class = getattr(descriptor, '_concrete_class', None)
    if registered_class:
      assert registered_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    decoders_by_tag = cls._decoders_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      for ext in descriptor.file.pool.FindAllExtensions(descriptor):
        _AttachFieldHelpers(cls, ext)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super().__init__(name, bases, dictionary)
213# Stateless helpers for GeneratedProtocolMessageType below.
214# Outside clients should not access these directly.
215#
216# I opted not to make any of these methods on the metaclass, to make it more
217# clear that I'm not really using any state there and to keep clients from
218# thinking that they have direct access to these construction helpers.
221def _PropertyName(proto_field_name):
222 """Returns the name of the public property attribute which
223 clients can use to get and (in some cases) set the value
224 of a protocol message field.
226 Args:
227 proto_field_name: The protocol message field name, exactly
228 as it appears (or would appear) in a .proto file.
229 """
230 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
231 # nnorwitz makes my day by writing:
232 # """
233 # FYI. See the keyword module in the stdlib. This could be as simple as:
234 #
235 # if keyword.iskeyword(proto_field_name):
236 # return proto_field_name + "_"
237 # return proto_field_name
238 # """
239 # Kenton says: The above is a BAD IDEA. People rely on being able to use
240 # getattr() and setattr() to reflectively manipulate field values. If we
241 # rename the properties, then every such user has to also make sure to apply
242 # the same transformation. Note that currently if you name a field "yield",
243 # you can still access it just fine using getattr/setattr -- it's not even
244 # that cumbersome to do so.
245 # TODO(kenton): Remove this method entirely if/when everyone agrees with my
246 # position.
247 return proto_field_name
250def _AddSlots(message_descriptor, dictionary):
251 """Adds a __slots__ entry to dictionary, containing the names of all valid
252 attributes for this message type.
254 Args:
255 message_descriptor: A Descriptor instance describing this message type.
256 dictionary: Class dictionary to which we'll add a '__slots__' entry.
257 """
258 dictionary['__slots__'] = ['_cached_byte_size',
259 '_cached_byte_size_dirty',
260 '_fields',
261 '_unknown_fields',
262 '_unknown_field_set',
263 '_is_present_in_parent',
264 '_listener',
265 '_listener_for_children',
266 '__weakref__',
267 '_oneofs']
def _IsMessageSetExtension(field):
  """Tells whether field is an optional message extension of a MessageSet.

  The checks short-circuit in the same order as a chained `and`, and the first
  falsy intermediate value is what gets returned.

  Args:
    field: A FieldDescriptor.

  Returns:
    A truthy value only when the field is an extension, its containing type
    has options with message_set_wire_format set, and the field is an
    optional TYPE_MESSAGE field.
  """
  is_extension = field.is_extension
  if not is_extension:
    return is_extension

  extended_type = field.containing_type
  has_options = extended_type.has_options
  if not has_options:
    return has_options

  message_set_format = extended_type.GetOptions().message_set_wire_format
  if not message_set_format:
    return message_set_format

  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
def _IsMapField(field):
  """Tells whether field is a map field.

  Args:
    field: A FieldDescriptor.

  Returns:
    A truthy value only when the field is of TYPE_MESSAGE and its message
    type has options with map_entry set.
  """
  is_message_typed = field.type == _FieldDescriptor.TYPE_MESSAGE
  if not is_message_typed:
    return is_message_typed

  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
def _IsMessageMapField(field):
  """Tells whether the 'value' field of a map entry type is a message.

  Args:
    field: A FieldDescriptor whose message_type has a field named 'value'.

  Returns:
    True if that 'value' field has cpp_type CPPTYPE_MESSAGE.
  """
  entry_fields = field.message_type.fields_by_name
  value_field = entry_fields['value']
  return value_field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches codec helpers for one field to its descriptor and to cls.

  Stores `_encoder`, `_sizer` and `_default_constructor` on the field
  descriptor, and registers the field's decoder(s) in cls._decoders_by_tag,
  keyed by the encoded tag bytes.

  Args:
    cls: The message class being constructed; must already have a
      `_decoders_by_tag` dict.
    field_descriptor: FieldDescriptor of a field or extension of the message.
  """
  repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  map_field = _IsMapField(field_descriptor)
  packed = field_descriptor.is_packed

  if map_field:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    make_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]
    field_encoder = make_encoder(field_descriptor.number, repeated, packed)
    make_sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]
    sizer = make_sizer(field_descriptor.number, repeated, packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    """Registers the decoder for this field under the given wire type."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)

    # Enums whose type is not closed are decoded as TYPE_INT32.
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    # For a field inside a oneof, the field descriptor itself is stored
    # next to the decoder; otherwise None.
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor
    else:
      oneof_descriptor = None

    if map_field:
      is_message_map = _IsMessageMapField(field_descriptor)
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    else:
      # pylint: disable=protected-access
      common_args = (field_descriptor.number, repeated, is_packed,
                     field_descriptor, field_descriptor._default_constructor)
      if decode_type == _FieldDescriptor.TYPE_STRING:
        field_decoder = decoder.StringDecoder(
            *common_args, not field_descriptor.has_presence)
      elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
            *common_args)
      else:
        field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
            *common_args, not field_descriptor.has_presence)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
356def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
357 extensions = descriptor.extensions_by_name
358 for extension_name, extension_field in extensions.items():
359 assert extension_name not in dictionary
360 dictionary[extension_name] = extension_field
def _AddEnumValues(descriptor, cls):
  """Exposes the enums declared in this message as class-level attributes.

  For each enum type an EnumTypeWrapper is stored under the enum's name, and
  each enum value is stored under its own name with its number as the value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    setattr(cls, enum_descriptor.name,
            enum_type_wrapper.EnumTypeWrapper(enum_descriptor))
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds the empty container for a map field.

  Args:
    field: FieldDescriptor of a map field; its message_type must have 'key'
      and 'value' fields.

  Returns:
    A one-argument function taking the owning message and returning a
    containers.MessageMap (message-valued maps) or containers.ScalarMap.

  Raises:
    ValueError: The field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))

  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type, key_checker,
          field.message_type)
    return MakeMessageMapDefault

  value_checker = type_checkers.GetTypeChecker(value_field)

  def MakePrimitiveMapDefault(message):
    return containers.ScalarMap(
        message._listener_for_children, key_checker, value_checker,
        field.message_type)
  return MakePrimitiveMapDefault
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function of one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field.  The
    default value may refer back to |message| (through its
    `_listener_for_children`, or a _OneofListener wrapping it).

  Raises:
    ValueError: A repeated field declares a default value other than the
      empty list, or (via _GetInitializeDefaultForMap) a map field is not
      repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # The message descriptor is captured once here and reused by the
      # closure, as in the singular sub-message case below.
      message_type = field.message_type
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized; asking the message
      # factory for the class is what creates it.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Sub-messages inside a oneof get a listener that also knows the field.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
458def _ReraiseTypeErrorWithFieldName(message_name, field_name):
459 """Re-raise the currently-handled TypeError with the field name added."""
460 exc = sys.exc_info()[1]
461 if len(exc.args) == 1 and type(exc) is TypeError:
462 # simple TypeError; add field name to exception message
463 exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
465 # re-raise possibly-amended exception with original traceback:
466 raise exc.with_traceback(sys.exc_info()[2])
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts field values as keyword arguments.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    A string is looked up by name in enum_type and the matching number is
    returned.  Anything else is returned as-is.  (No conversion or
    bounds-checking is done.)
    """
    if not isinstance(value, str):
      return value
    try:
      return enum_type.values_by_name[value].number
    except KeyError:
      raise ValueError('Enum type %s: unknown label "%s"' % (
          enum_type.full_name, value))

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None      # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)

    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # Defensive: _GetFieldByName in this file raises ValueError for an
      # unknown name instead of returning None.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          container.extend(field_value)
        elif not _IsMapField(field):  # Repeated message
          for val in field_value:
            if isinstance(val, dict):
              container.add(**val)
            else:
              container.add().MergeFrom(val)
        elif _IsMessageMapField(field):  # Map with message values
          for key in field_value:
            container[key].MergeFrom(field_value[key])
        else:  # Map with scalar values
          container.update(field_value)
        self._fields[field] = container
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        submessage = field._default_constructor(self)
        source = field_value
        if isinstance(field_value, dict):
          # A dict is turned into a message first, then merged.
          source = field.message_type._concrete_class(**field_value)
        try:
          submessage.MergeFrom(source)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = submessage
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          # Goes through the field's property setter.
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
556def _GetFieldByName(message_descriptor, field_name):
557 """Returns a field descriptor by field name.
559 Args:
560 message_descriptor: A Descriptor describing all fields in message.
561 field_name: The name of the field to retrieve.
562 Returns:
563 The field descriptor associated with the field name.
564 """
565 try:
566 return message_descriptor.fields_by_name[field_name]
567 except KeyError:
568 raise ValueError('Protocol message %s has no "%s" field.' %
569 (message_descriptor.name, field_name))
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable message types additionally get an `Extensions` property.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state so we allocate a new one
  # every time it is accessed.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
def _AddPropertiesForField(field, cls):
  """Adds the public property and number constant for one message field.

  The property lets clients get the field value and, for non-repeated scalar
  fields, set it directly.  A `<FIELD_NAME>_FIELD_NUMBER` class constant
  holding the field number is added as well.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)
608class _FieldProperty(property):
609 __slots__ = ('DESCRIPTOR',)
611 def __init__(self, descriptor, getter, setter, doc):
612 property.__init__(self, getter, setter, doc=doc)
613 self.DESCRIPTOR = descriptor
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property yields the container built by the field's
  `_default_constructor`; the container is created on first access and then
  kept in `_fields`.  Assigning to the property raises AttributeError.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container

    # Nothing stored yet: construct a new object to represent this field,
    # then atomically check if another thread has preempted us and, if not,
    # swap in the new object.  If someone has preempted us, we take that
    # object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    fresh_container = field._default_constructor(self)
    return self._fields.setdefault(field, fresh_container)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a helpful message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients use the property to get and directly set the value of the field.
  Setting a value type-checks it, updates the stored fields and marks the
  message as modified; for a field in a oneof the oneof state is updated too.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))

    # For fields without presence, a falsy value is stored as "absent".
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    if field.has_presence or new_value:
      self._fields[field] = new_value
    else:
      self._fields.pop(field, None)

    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.

  Reading the property yields the sub-message, which is created on first
  access by the field's `_default_constructor` and then kept in `_fields`.
  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    submessage = self._fields.get(field)
    if submessage is not None:
      return submessage

    # Nothing stored yet: construct a new object to represent this field,
    # then atomically check if another thread has preempted us and, if not,
    # swap in the new object.  If someone has preempted us, we take that
    # object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    fresh_submessage = field._default_constructor(self)
    return self._fields.setdefault(field, fresh_submessage)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a helpful message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
760def _AddPropertiesForExtensions(descriptor, cls):
761 """Adds properties for all fields in this protocol message type."""
762 extensions = descriptor.extensions_by_name
763 for extension_name, extension_field in extensions.items():
764 constant_name = extension_name.upper() + '_FIELD_NUMBER'
765 setattr(cls, constant_name, extension_field.number)
767 # TODO(amauryfa): Migrate all users of these attributes to functions like
768 # pool.FindExtensionByNumber(descriptor).
769 if descriptor.file is not None:
770 # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
771 pool = descriptor.file.pool
773def _AddStaticMethods(cls):
774 # TODO(robinson): This probably needs to be thread-safe(?)
775 def RegisterExtension(field_descriptor):
776 field_descriptor.containing_type = cls.DESCRIPTOR
777 # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
778 # pylint: disable=protected-access
779 cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(field_descriptor)
780 _AttachFieldHelpers(cls, field_descriptor)
781 cls.RegisterExtension = staticmethod(RegisterExtension)
783 def FromString(s):
784 message = cls()
785 message.MergeFromString(s)
786 return message
787 cls.FromString = staticmethod(FromString)
def _IsPresent(item):
  """Decides whether an entry of _fields belongs in ListFields() output.

  Args:
    item: A (FieldDescriptor, value) tuple from _fields.

  Returns:
    For a repeated field, whether the container is non-empty; for a message
    field, the sub-message's `_is_present_in_parent`; otherwise True.
  """
  field_descriptor = item[0]
  if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields on cls.

  Args:
    message_descriptor: Descriptor of the message type (not consulted).
    cls: The class we're constructing.
  """

  def ListFields(self):
    # Keep only the entries that count as present, ordered by field number.
    present_items = filter(_IsPresent, self._fields.items())
    return sorted(present_items, key=lambda item: item[0].number)

  cls.ListFields = ListFields
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField on cls.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """

  # Names that HasField accepts: non-repeated fields that have presence
  # (for proto3, only submessages and fields inside a oneof) ...
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED and field.has_presence
  }
  # ... plus the oneofs themselves.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      target = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(target, descriptor_mod.OneofDescriptor):
      # A oneof is "set" when the field currently recorded for it is set.
      try:
        return HasField(self, self._oneofs[target].name)
      except KeyError:
        return False

    if target.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      submessage = self._fields.get(target)
      return submessage is not None and submessage._is_present_in_parent
    return target in self._fields

  cls.HasField = HasField
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls."""

  def ClearField(self, field_name):
    """Clears the named field, or the active member of the named oneof.

    Args:
      field_name: Name of a field or of a oneof of this message.

    Raises:
      ValueError: If field_name names neither a field nor a oneof.
    """
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))
      if oneof not in self._oneofs:
        # No member of this oneof is set, so there is nothing to clear.
        return
      field = self._oneofs[oneof]

    if field in self._fields:
      current = self._fields[field]
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(current, 'InvalidateIterators'):
        current.InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us.  That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[field]

      oneof_of_field = field.containing_oneof
      if self._oneofs.get(oneof_of_field, None) is field:
        del self._oneofs[oneof_of_field]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""

  def ClearExtension(self, field_descriptor):
    """Drops the extension's stored value, if any, and marks self modified."""
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Similar to ClearField(): removing an absent entry is a no-op, and
    # _Modified() is called either way.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""

  def HasExtension(self, field_descriptor):
    """Reports whether the given singular extension is set on this message.

    Raises:
      KeyError: If field_descriptor describes a repeated extension.
    """
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields

    value = self._fields.get(field_descriptor)
    return value is not None and value._is_present_in_parent

  cls.HasExtension = HasExtension
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg has an empty type_url or when the
    type it names cannot be found in the default symbol database's pool.
  """
  type_url = msg.type_url
  if not type_url:
    return None

  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The pool signals an unknown type name by raising KeyError; treat that
    # the same as "no descriptor" so callers get None instead of an error.
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls."""

  def __eq__(self, other):
    """Compares two messages of the same descriptor by content."""
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False

    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      # For Any messages, prefer comparing the unpacked payloads when both
      # sides could be unpacked.
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    own_unknown = sorted(self._unknown_fields)
    their_unknown = sorted(other._unknown_fields)
    return own_unknown == their_unknown

  cls.__eq__ = __eq__
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__ on cls."""

  def __str__(self):
    """Returns text_format.MessageToString(self)."""
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __repr__ on cls."""

  def __repr__(self):
    """Returns text_format.MessageToString(self)."""
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__ on cls."""

  def __unicode__(self):
    """Returns the UTF-8 text rendering of the message, decoded to text."""
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')

  cls.__unicode__ = __unicode__
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    EncodeError: If field_type has no entry in
      type_checkers.TYPE_TO_BYTE_SIZE_FN.
  """
  # Only the table lookup is guarded: a KeyError raised while the size
  # function itself runs must not be reported as an unknown field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize() on cls."""

  def ByteSize(self):
    """Returns the serialized size, recomputing it only when marked dirty."""
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      size = by_name['key']._sizer(self.key)
      size += by_name['value']._sizer(self.value)
    else:
      size = sum(field_descriptor._sizer(field_value)
                 for field_descriptor, field_value in self.ListFields())
      size += sum(len(tag_bytes) + len(value_bytes)
                  for tag_bytes, value_bytes in self._unknown_fields)

    # Store the result and clear the dirty markers so the next call can
    # return the cached value.
    self._cached_byte_size = size
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return size

  cls.ByteSize = ByteSize
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls."""

  def SerializeToString(self, **kwargs):
    """Serializes the message after checking that it is fully initialized.

    Raises:
      EncodeError: If IsInitialized() reports missing required fields.
    """
    # Check if the message has all of its required fields set.
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)

    full_name = self.DESCRIPTOR.full_name
    missing = ','.join(self.FindInitializationErrors())
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (full_name, missing))

  cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() and _InternalSerialize() on cls.
  """

  def SerializePartialToString(self, **kwargs):
    """Serializes the message into bytes via _InternalSerialize()."""
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  def InternalSerialize(self, write_bytes, deterministic=None):
    """Writes the encoded fields, then the unknown fields, to write_bytes.

    Args:
      write_bytes: Callable that receives each chunk of serialized bytes.
      deterministic: Passed to the field encoders as a bool; when None, the
        value of api_implementation.IsPythonDefaultSerializationDeterministic()
        is used instead.
    """
    deterministic = (
        api_implementation.IsPythonDefaultSerializationDeterministic()
        if deterministic is None else bool(deterministic))

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      by_name['key']._encoder(write_bytes, self.key, deterministic)
      by_name['value']._encoder(write_bytes, self.value, deterministic)
      return

    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value, deterministic)
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and _InternalParse() on cls.
  """

  def MergeFromString(self, serialized):
    """Parses serialized and merges the result into this message.

    Returns:
      The length of serialized.

    Raises:
      DecodeError: If parsing stops before the end of the data, or if an
        IndexError, TypeError or struct.error occurs while parsing.
    """
    view = memoryview(serialized)
    total = len(view)
    try:
      if self._InternalParse(view, 0, total) != total:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return total   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Merges serialized fields found in buffer[pos:end] into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position at which parsing stopped.  This is the loop's final
      position, or the position of the tag being processed when decoding or
      skipping an unknown field reported -1.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      tag_bytes, new_pos = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is not None:
        # Known tag: let the field's decoder consume its payload.
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_desc:
          self._UpdateOneofState(field_desc)
        continue

      # Unknown tag: record it in both unknown-field stores.
      if not self._unknown_fields:
        self._unknown_fields = []
      if unknown_field_set is None:
        self._unknown_field_set = containers.UnknownFieldSet()
        unknown_field_set = self._unknown_field_set
      tag, _ = decoder._DecodeVarint(tag_bytes, 0)
      field_number, wire_type = wire_format.UnpackTag(tag)
      if field_number == 0:
        raise message_mod.DecodeError('Field number 0 is illegal.')
      # TODO(jieluo): remove old_pos.
      old_pos = new_pos
      data, new_pos = decoder._DecodeUnknownField(buffer, new_pos, wire_type)
      if new_pos == -1:
        return pos
      unknown_field_set._add(field_number, wire_type, data)
      # TODO(jieluo): remove _unknown_fields.
      new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
      if new_pos == -1:
        return pos
      self._unknown_fields.append(
          (tag_bytes, buffer[old_pos:new_pos].tobytes()))
      pos = new_pos
    return pos
  cls._InternalParse = InternalParse
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def _ReportUninitialized(message, errors):
    """Fills errors (when given) with the missing paths and returns False."""
    if errors is not None:
      errors.extend(message.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      if field not in self._fields:
        return _ReportUninitialized(self, errors)
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not self._fields[field]._is_present_in_parent):
        return _ReportUninitialized(self, errors)

    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        # Repeated fields whose message type is a map entry are not checked.
        if (field.message_type.has_options and
            field.message_type.GetOptions().map_entry):
          continue
        for element in value:
          if not element.IsInitialized():
            return _ReportUninitialized(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _ReportUninitialized(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = [field.name for field in required_fields
              if not self.HasField(field.name)]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      name = '(%s)' % field.full_name if field.is_extension else field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors, so only maps with
        # message values are inspected.
        if _IsMessageMapField(field):
          for key in value:
            element = value[key]
            prefix = '%s[%s].' % (name, key)
            errors.extend(
                prefix + error for error in element.FindInitializationErrors())
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for i, element in enumerate(value):
          prefix = '%s[%d].' % (name, i)
          errors.extend(
              prefix + error for error in element.FindInitializationErrors())
      else:
        prefix = name + '.'
        errors.extend(
            prefix + error for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
def _FullyQualifiedClassName(klass):
  """Returns klass's qualified name, prefixed by its module when meaningful.

  The module prefix is omitted when klass.__module__ is None, 'builtins' or
  '__builtin__'.  The class part is __qualname__ when available, otherwise
  __name__.
  """
  module = klass.__module__
  qualified = getattr(klass, '__qualname__', klass.__name__)
  has_no_prefix = module in (None, 'builtins', '__builtin__')
  return qualified if has_no_prefix else module + '.' + qualified
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the fields and unknown fields of msg into this message.

    Args:
      msg: Another instance of cls; must not be self.

    Raises:
      TypeError: If msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if field.label != LABEL_REPEATED:
        if field.cpp_type != CPPTYPE_MESSAGE:
          # Singular non-message field: the incoming value replaces ours.
          self._fields[field] = value
          if field.containing_oneof:
            self._UpdateOneofState(field)
          continue
        if not value._is_present_in_parent:
          # Sub-message that is not present in msg: nothing to merge.
          continue

      # Repeated field, or a sub-message present in msg: merge into our own
      # value, constructing it first when we do not have one yet.
      target = fields.get(field)
      if target is None:
        # Construct a new object to represent this field.
        target = field._default_constructor(self)
        fields[field] = target
      target.MergeFrom(value)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)
      # pylint: disable=protected-access
      if self._unknown_field_set is None:
        self._unknown_field_set = containers.UnknownFieldSet()
      self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active_field = self._oneofs.get(oneof, None)
    if active_field is None or not self.HasField(active_field.name):
      return None
    return active_field.name

  cls.WhichOneof = WhichOneof
def _Clear(self):
  """Resets fields, unknown fields and oneof state, then calls _Modified()."""
  # Clear fields.
  self._fields = {}
  self._unknown_fields = ()

  # pylint: disable=protected-access
  unknown_set = self._unknown_field_set
  if unknown_set is not None:
    # Empty the existing set before dropping our reference to it.
    unknown_set._clear()
    self._unknown_field_set = None

  self._oneofs = {}
  self._Modified()
def _UnknownFields(self):
  """Returns the message's UnknownFieldSet, creating it on first access."""
  # pylint: disable=protected-access
  if self._unknown_field_set is not None:
    return self._unknown_field_set
  self._unknown_field_set = containers.UnknownFieldSet()
  return self._unknown_field_set
def _DiscardUnknownFields(self):
  """Drops unknown fields from this message and from its sub-messages."""
  self._unknown_fields = []
  self._unknown_field_set = None    # pylint: disable=protected-access
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Only maps whose values are messages need to be visited.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
def _SetListener(self, listener):
  """Stores listener on the message; None is replaced by a null listener."""
  if listener is not None:
    self._listener = listener
    return
  self._listener = message_listener_mod.NullMessageListener()
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  # Extension accessors are only installed on extendable messages.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is field:
      return
    del self._fields[previous]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # This listener establishes a back reference from a child (contained)
    # object to its parent (containing) object.  The reference is kept weak
    # to avoid creating cyclic garbage when the client finishes with the
    # 'parent' object in the tree.  An argument that is already a weak proxy
    # is stored as-is.
    already_proxy = isinstance(parent_message, weakref.ProxyType)
    self._parent_message_weakref = (
        parent_message if already_proxy else weakref.proxy(parent_message))

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    """Forwards the notification to the parent unless already marked dirty."""
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # We can get here if a client has kept a reference to a child object,
        # and is now setting a field on it, but the child's parent has been
        # garbage-collected.  This is not an error.
        pass
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Make our field the active member of its oneof before notifying.
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # The parent is gone; there is nothing left to update.
      pass