1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
# This code requires Python 3: it relies on `raise ... from ...` and on
# `BaseException.with_traceback()`.
9#
10# TODO: Helpers for verbose, common checks like seeing if a
11# descriptor's cpp_type is CPPTYPE_MESSAGE.
12
13"""Contains a metaclass and helper functions used to create
14protocol message classes from Descriptor objects at runtime.
15
16Recall that a metaclass is the "type" of a class.
17(A class is to a metaclass what an instance is to a class.)
18
19In this case, we use the GeneratedProtocolMessageType metaclass
20to inject all the useful functionality into the classes
21output by the protocol compiler at compile-time.
22
23The upshot of all this is that the real implementation
24details for ALL pure-Python protocol buffers are *here in
25this file*.
26"""
27
28__author__ = 'robinson@google.com (Will Robinson)'
29
30import datetime
31from io import BytesIO
32import math
33import struct
34import sys
35import warnings
36import weakref
37
38from google.protobuf import descriptor as descriptor_mod
39from google.protobuf import message as message_mod
40from google.protobuf import text_format
41# We use "as" to avoid name collisions with variables.
42from google.protobuf.internal import api_implementation
43from google.protobuf.internal import containers
44from google.protobuf.internal import decoder
45from google.protobuf.internal import encoder
46from google.protobuf.internal import enum_type_wrapper
47from google.protobuf.internal import extension_dict
48from google.protobuf.internal import message_listener as message_listener_mod
49from google.protobuf.internal import type_checkers
50from google.protobuf.internal import well_known_types
51from google.protobuf.internal import wire_format
52
# Short alias for the FieldDescriptor class; its TYPE_* / CPPTYPE_* constants
# are referenced throughout this module.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Fully-qualified names of well-known message types that this module compares
# descriptor full names against in order to special-case them.
_AnyFullTypeName = 'google.protobuf.Any'
_StructFullTypeName = 'google.protobuf.Struct'
_ListValueFullTypeName = 'google.protobuf.ListValue'
# Adaptor class instantiated by the `Extensions` property of extendable
# message classes.
_ExtensionDict = extension_dict._ExtensionDict
58
class GeneratedProtocolMessageType(type):
  """Metaclass for protocol message classes created at runtime from Descriptors.

  Every class built by this metaclass receives implementations of the methods
  described in the Message class, plus properties for getting/setting each
  field of the protocol message.  The class is also given __slots__, so that
  users cannot accidentally "set" a nonexistent field which would then never
  be serialized / deserialized.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

    mydescriptor = Descriptor(.....)
    factory = symbol_database.Default()
    factory.pool.AddDescriptor(mydescriptor)
    MyProtoClass = message_factory.GetMessageClass(mydescriptor)
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the class object for a runtime-generated message type.

    __new__ is overridden because it is apparently the only place where
    __slots__ can meaningfully be set on the class being created.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).  A well-known-type mixin is appended when the
        descriptor's full name is registered in well_known_types.WKTBASES.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.

    Returns:
      The newly-allocated class, or the class already recorded on the
      descriptor as `_concrete_class` if there is one.

    Raises:
      RuntimeError: The descriptor entry is a str, i.e. the generated code
        only works with the python cpp extension.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, hand it back
    # instead of creating a second one; a duplicate would break any messages
    # that already exist with the existing class.
    #
    # The C++ implementation appears to have its own internal
    # `PyMessageFactory` to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors
    # from a custom pool; it calls message_factory.GetMessageClass() on a
    # descriptor which already has an existing concrete class.
    already_built = getattr(descriptor, '_concrete_class', None)
    if already_built:
      return already_built

    full_name = descriptor.full_name
    if full_name in well_known_types.WKTBASES:
      bases = bases + (well_known_types.WKTBASES[full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly allocated class; most of the work happens here.

    Adds enum getters, an __init__ method, implementations of all Message
    methods, and properties for all fields in the protocol type.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # An _existing_ class found via `_concrete_class` in __new__ is already
    # fully initialized; there is nothing to redo.
    registered = getattr(descriptor, '_concrete_class', None)
    if registered:
      assert registered is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    if descriptor.has_options:
      if descriptor.GetOptions().message_set_wire_format:
        cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
            decoder.MessageSetItemDecoder(descriptor),
            None,
        )

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for own_field in descriptor.fields:
      _AttachFieldHelpers(cls, own_field)

    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      for known_extension in descriptor.file.pool.FindAllExtensions(descriptor):
        _AttachFieldHelpers(cls, known_extension)

    # Record the class on the descriptor before installing the members below.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super().__init__(name, bases, dictionary)
196
197
198# Stateless helpers for GeneratedProtocolMessageType below.
199# Outside clients should not access these directly.
200#
201# I opted not to make any of these methods on the metaclass, to make it more
202# clear that I'm not really using any state there and to keep clients from
203# thinking that they have direct access to these construction helpers.
204
205
def _PropertyName(proto_field_name):
  """Maps a proto field name to the name of its public property attribute.

  Clients use the returned attribute name to get (and, in some cases, set) the
  value of the protocol message field.  The mapping is the identity.

  Args:
    proto_field_name: The protocol message field name, exactly as it appears
      (or would appear) in a .proto file.

  Returns:
    The attribute name, which is `proto_field_name` unchanged.
  """
  # TODO: Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz suggested using the stdlib keyword module, e.g.:
  #
  #   if keyword.iskeyword(proto_field_name):
  #     return proto_field_name + "_"
  #   return proto_field_name
  #
  # Kenton says: The above is a BAD IDEA.  People rely on being able to use
  # getattr() and setattr() to reflectively manipulate field values.  If the
  # properties were renamed, every such user would have to apply the same
  # transformation.  A field named "yield" can still be accessed just fine
  # through getattr/setattr, and doing so is not even that cumbersome.
  # TODO: Remove this method entirely if/when everyone agrees with that
  # position.
  property_name = proto_field_name
  return property_name
233
234
def _AddSlots(message_descriptor, dictionary):
  """Stores a '__slots__' entry in the class dictionary.

  The entry lists the names of all valid instance attributes for this message
  type; the list is the same for every message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type
      (not consulted).
    dictionary: Class dictionary to which the '__slots__' entry is added.
  """
  slot_names = (
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  )
  dictionary['__slots__'] = list(slot_names)
252
253
def _IsMessageSetExtension(field):
  """Tells whether `field` is an extension stored as a MessageSet item.

  That is the case for a singular, non-required message-typed extension whose
  containing type has the message_set_wire_format option set.

  Args:
    field: A FieldDescriptor.
  """
  if not field.is_extension:
    return False
  extended_type = field.containing_type
  if not (extended_type.has_options and
          extended_type.GetOptions().message_set_wire_format):
    return False
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  return not field.is_required and not field.is_repeated
261
262
def _IsMapField(field):
  """Tells whether `field` is a map field.

  A map field is a message-typed field whose message type is flagged as a
  map entry.

  Args:
    field: A FieldDescriptor.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  # pylint: disable=protected-access
  return field.message_type._is_map_entry
266
267
def _IsMessageMapField(field):
  """Tells whether the values of map field `field` are messages.

  Args:
    field: A FieldDescriptor whose message type has a field named 'value'
      (presumably only called for map fields -- see _IsMapField).
  """
  entry_fields = field.message_type.fields_by_name
  return entry_fields['value'].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
271
def _AttachFieldHelpers(cls, field_descriptor):
  """Prepares `field_descriptor` and `cls` for fast lookups later on.

  Stores a default-value constructor on the descriptor and registers the
  field in cls._fields_by_tag, keyed by its encoded tag bytes.

  Args:
    cls: The message class being constructed; must already have a
      `_fields_by_tag` dict.
    field_descriptor: FieldDescriptor of a field or extension of `cls`.
  """
  # pylint: disable=protected-access
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor
  )

  field_type = field_descriptor.type

  # Entry for the field's natural wire type: (descriptor, is_packed=False).
  natural_tag = encoder.TagBytes(
      field_descriptor.number,
      type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type],
  )
  cls._fields_by_tag[natural_tag] = (field_descriptor, False)

  if field_descriptor.is_repeated and wire_format.IsTypePackable(field_type):
    # To support wire compatibility of adding packed = true, register the
    # packed (length-delimited) tag regardless of the field's options.
    packed_tag = encoder.TagBytes(
        field_descriptor.number, wire_format.WIRETYPE_LENGTH_DELIMITED
    )
    cls._fields_by_tag[packed_tag] = (field_descriptor, True)
291
292
def _MaybeAddEncoder(cls, field_descriptor):
  """Attaches `_sizer` and `_encoder` to the descriptor unless already there.

  Args:
    cls: The message class owning the field (not consulted).
    field_descriptor: FieldDescriptor to equip with an encoder and a sizer.
  """
  # pylint: disable=protected-access
  if hasattr(field_descriptor, '_encoder'):
    return

  repeated = field_descriptor.is_repeated
  map_field = _IsMapField(field_descriptor)
  packed = field_descriptor.is_packed

  if map_field:
    encode_field = encoder.MapEncoder(field_descriptor)
    size_field = encoder.MapSizer(
        field_descriptor, _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    encode_field = encoder.MessageSetItemEncoder(field_descriptor.number)
    size_field = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    make_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]
    encode_field = make_encoder(field_descriptor.number, repeated, packed)
    make_sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]
    size_field = make_sizer(field_descriptor.number, repeated, packed)

  # `_encoder` is stored last because its presence is what the guard at the
  # top of this function tests for.
  field_descriptor._sizer = size_field
  field_descriptor._encoder = encode_field
315
316
def _MaybeAddDecoder(cls, field_descriptor):
  """Attaches a `_decoders` dict to the descriptor unless already present.

  The dict maps `is_packed` (bool) to the decoder to use for that encoding.
  An unpacked decoder is always created; a packed one is added for repeated
  fields of a packable type.

  Args:
    cls: The message class owning the field (not consulted).
    field_descriptor: FieldDescriptor to equip with decoders.
  """
  if hasattr(field_descriptor, '_decoders'):
    return

  is_repeated = field_descriptor.is_repeated
  is_map_entry = _IsMapField(field_descriptor)
  helper_decoders = {}

  def AddDecoder(is_packed):
    decode_type = field_descriptor.type
    # Enum fields whose enum type is not closed are decoded as int32.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Unlike the other decoder factories used here, this one is not given
      # the trailing `not has_presence` argument.
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)

    helper_decoders[is_packed] = field_decoder

  AddDecoder(False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(True)

  field_descriptor._decoders = helper_decoders
367
368
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each nested extension as an entry of the class dictionary.

  Args:
    descriptor: Descriptor whose `extensions_by_name` mapping is copied.
    dictionary: Class dictionary to update; it must not already contain any
      of the extension names.
  """
  nested_extensions = descriptor.extensions_by_name
  for ext_name, ext_field in nested_extensions.items():
    assert ext_name not in dictionary
    dictionary[ext_name] = ext_field
374
375
def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum types defined in this message.

  For every nested enum type the class receives an EnumTypeWrapper under the
  enum's name, and one integer attribute per enum value, named after the
  value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class being constructed for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_descriptor)
    setattr(cls, enum_descriptor.name, wrapper)
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
389
390
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds the empty container for a map field.

  Args:
    field: FieldDescriptor of a map field.

  Returns:
    A one-argument function taking the parent message and returning a new
    containers.MessageMap (message-valued maps) or containers.ScalarMap
    (all other maps) wired to the parent's `_listener_for_children`.

  Raises:
    ValueError: `field` is not repeated.
  """
  if not field.is_repeated:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))

  entry_type = field.message_type
  key_checker = type_checkers.GetTypeChecker(entry_type.fields_by_name['key'])
  value_field = entry_type.fields_by_name['value']

  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type,
          key_checker, field.message_type)
    return MakeMessageMapDefault

  value_checker = type_checkers.GetTypeChecker(value_field)

  def MakePrimitiveMapDefault(message):
    return containers.ScalarMap(
        message._listener_for_children, key_checker, value_checker,
        field.message_type)
  return MakePrimitiveMapDefault
412
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function taking one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field.  The
    default value may refer back to |message| via a weak reference.

  Raises:
    ValueError: A repeated field declares a default other than the empty
      list, or a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.is_repeated:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      message_type = field.message_type
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized; asking the message
      # factory for the class creates it.  The import is local, presumably to
      # avoid a circular import at module load time.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Members of a oneof get a dedicated listener; every other submessage
      # reports to the parent's shared child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
469
470
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, naming the offending field.

  Must be called from inside an `except` block.  A plain TypeError (exactly
  that type) carrying a single argument is replaced by a new TypeError whose
  text ends in " for field <message_name>.<field_name>"; any other exception
  is re-raised unchanged.  The original traceback is kept in both cases.

  Args:
    message_name: Name of the message type owning the field.
    field_name: Name of the field being set.
  """
  _, caught, trace = sys.exc_info()
  if len(caught.args) == 1 and type(caught) is TypeError:
    # Simple TypeError; add the field name to the exception message.
    caught = TypeError(
        '%s for field %s.%s' % (str(caught), message_name, field_name))

  # Re-raise the possibly-amended exception with the original traceback.
  raise caught.with_traceback(trace)
480
481
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated initializer resets the message's internal bookkeeping and
  then applies every keyword argument to the field of the same name.  A value
  of None is treated like an absent argument.

  Args:
    message_descriptor: Descriptor of the message type being constructed.
    cls: The message class that receives the generated __init__.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: `value` is a string that names no value of `enum_type`.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  # No docstring on purpose: __doc__ is reset to None below.
  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE: _GetFieldByName raises ValueError for unknown names instead of
      # returning None, so this branch is not reached as the code stands.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.is_repeated:
        field_copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                item_value = field_value[key]
                if isinstance(item_value, dict):
                  field_copy[key].__init__(**item_value)
                else:
                  field_copy[key].MergeFrom(item_value)
            else:
              field_copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                field_copy.add(**val)
              else:
                field_copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          field_copy.extend(field_value)
        self._fields[field] = field_copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        field_copy = field._default_constructor(self)
        # Message to merge into field_copy; stays None when the value was
        # applied to field_copy directly.
        new_val = None
        if isinstance(field_value, message_mod.Message):
          new_val = field_value
        elif isinstance(field_value, dict):
          if field.message_type.full_name == _StructFullTypeName:
            field_copy.Clear()
            if len(field_value) == 1 and 'fields' in field_value:
              # A dict whose only key is 'fields' is first tried as Struct
              # content; if that fails it is used as keyword arguments for
              # the message class instead.
              try:
                field_copy.update(field_value)
              except Exception:  # pylint: disable=broad-except
                # Fall back to init normal message field.  Only Exception is
                # caught, so KeyboardInterrupt/SystemExit still propagate.
                field_copy.Clear()
                new_val = field.message_type._concrete_class(**field_value)
            else:
              field_copy.update(field_value)
          else:
            new_val = field.message_type._concrete_class(**field_value)
        elif hasattr(field_copy, '_internal_assign'):
          field_copy._internal_assign(field_value)
        else:
          raise TypeError(
              'Message field {0}.{1} must be initialized with a '
              'dict or instance of same class, got {2}.'.format(
                  message_descriptor.name,
                  field_name,
                  type(field_value).__name__,
              )
          )

        # Identity test: `!= None` would go through message comparison.
        if new_val is not None:
          try:
            field_copy.MergeFrom(new_val)
          except TypeError:
            _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = field_copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
595
596
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: The message has no field called `field_name`.
  """
  known_fields = message_descriptor.fields_by_name
  try:
    found = known_fields[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  else:
    return found
611
612
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable message types additionally receive an `Extensions` property.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class being constructed.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state, so a new one is
  # allocated every time the property is accessed.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
622
623
def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.

  Clients can use this property to get and (in the case of non-repeated
  scalar fields) directly set the value of a protocol message field.  The
  class also receives a `<FIELD_NAME>_FIELD_NUMBER` constant.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class being constructed.
  """
  # Catch it if other types are added that should be handled specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  if field.is_repeated:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
647
648
class _FieldProperty(property):
  """A property that also carries the descriptor of the field it exposes."""

  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    """Initializes the property.

    Args:
      descriptor: Stored on the instance as `DESCRIPTOR`.
      getter: Function used to read the attribute.
      setter: Function used to write the attribute.
      doc: Docstring of the property.
    """
    super().__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
655
656
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Clients can use this property to get the value of the field, which is the
  container built by the field's default constructor.  Assigning to the
  property is rejected.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class being constructed.
  """
  proto_field_name = field.name

  def getter(self):
    current = self._fields.get(field)
    if current is not None:
      return current

    # Construct a new object to represent this field, then atomically check
    # whether another thread got there first: setdefault() stores our object
    # only if none is present and returns whichever object won.
    # WARNING: We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    fresh = field._default_constructor(self)
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # A setter is defined only so that a more helpful error message can be
  # raised.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  field_property = _FieldProperty(
      field,
      getter,
      setter,
      doc='Magic attribute generated for "%s" proto field.' % proto_field_name)
  setattr(cls, _PropertyName(proto_field_name), field_property)
699
700
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field.  Setting a value type-checks it, stores (or, for a default value of
  a field without presence, removes) the entry in `_fields`, and updates the
  oneof state when the field belongs to a oneof.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class being constructed.
  """
  proto_field_name = field.name
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))

    # A field without presence that is set to a default scalar value (as
    # judged by decoder.IsDefaultScalarValue) is dropped from _fields rather
    # than stored.
    store_as_unset = (
        not field.has_presence and decoder.IsDefaultScalarValue(checked_value))
    if store_as_unset:
      self._fields.pop(field, None)
    else:
      self._fields[field] = checked_value

    # Check _cached_byte_size_dirty inline to improve performance, since
    # scalar setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  field_property = _FieldProperty(
      field,
      getter,
      setter,
      doc='Magic attribute generated for "%s" proto field.' % proto_field_name)
  setattr(cls, _PropertyName(proto_field_name), field_property)
755
756
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field.  Direct
  assignment is rejected, except for Timestamp, Duration, Struct and
  ListValue fields, where the assigned value is converted into the existing
  submessage.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class being constructed.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name

  def getter(self):
    current = self._fields.get(field)
    if current is not None:
      return current

    # Construct a new object to represent this field, then atomically check
    # whether another thread got there first: setdefault() stores our object
    # only if none is present and returns whichever object won.
    # WARNING: We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    fresh = field._default_constructor(self)
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # For most message types the setter exists only to raise an exception with
  # a more helpful error message.
  def setter(self, new_value):
    type_name = field.message_type.full_name
    if type_name == 'google.protobuf.Timestamp':
      getter(self)  # Makes sure the submessage exists in _fields.
      self._fields[field].FromDatetime(new_value)
    elif type_name == 'google.protobuf.Duration':
      getter(self)
      self._fields[field].FromTimedelta(new_value)
    elif type_name == _StructFullTypeName:
      getter(self)
      struct_message = self._fields[field]
      struct_message.Clear()
      struct_message.update(new_value)
    elif type_name == _ListValueFullTypeName:
      getter(self)
      list_message = self._fields[field]
      list_message.Clear()
      list_message.extend(new_value)
    else:
      raise AttributeError(
          'Assignment not allowed to composite field '
          '"%s" in protocol message object.' % proto_field_name
      )

  # Add a property to encapsulate the getter and setter.
  field_property = _FieldProperty(
      field,
      getter,
      setter,
      doc='Magic attribute generated for "%s" proto field.' % proto_field_name)
  setattr(cls, _PropertyName(proto_field_name), field_property)
816
817
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a `<NAME>_FIELD_NUMBER` constant for every nested extension.

  For each extension declared inside this message type, `cls` receives a
  class attribute named after the upper-cased extension name with the suffix
  '_FIELD_NUMBER', holding the extension's field number.

  Args:
    descriptor: Descriptor of the message type; its `extensions_by_name`
      mapping is read.
    cls: The class being constructed.
  """
  # TODO: Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)
830
def _AddStaticMethods(cls):
  """Installs the static methods RegisterExtension and FromString on cls.

  Args:
    cls: The message class being constructed.
  """

  def RegisterExtension(_):
    """no-op to keep generated code <=4.23 working with new runtimes."""
    # This was originally removed in 5.26 (cl/595989309).
    return None

  def FromString(s):
    # Builds a new instance of cls and merges the serialized data into it.
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
844
845
def _IsPresent(item):
  """Decides whether a _fields entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from `_fields`.

  Returns:
    For repeated fields, whether the container is non-empty; for singular
    message fields, the submessage's `_is_present_in_parent` flag; True for
    everything else.
  """
  field_descriptor = item[0]
  if field_descriptor.is_repeated:
    return bool(item[1])
  if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
856
857
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on cls."""

  def ListFields(self):
    """Returns the present (descriptor, value) pairs sorted by field number."""
    present = (entry for entry in self._fields.items() if _IsPresent(entry))
    return sorted(present, key=lambda entry: entry[0].number)

  cls.ListFields = ListFields
867
868
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField() on cls."""

  # Names HasField() accepts: non-repeated fields that track presence.
  # For proto3, only submessages and fields inside a oneof have presence.
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if not field.is_repeated and field.has_presence
  }
  # Has methods are supported for oneof descriptors.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs
  )

  def HasField(self, field_name):
    """Reports whether the named field (or any member of the oneof) is set.

    Args:
      field_name: Name of a non-repeated field with presence, or of a oneof.

    Returns:
      For a oneof name, the result for the oneof's currently recorded member
      (False when none is recorded). For a message field, whether a value is
      stored and flagged `_is_present_in_parent`. Otherwise, whether the
      field has an entry in `_fields`.

    Raises:
      ValueError: If field_name is not one of the accepted names.
    """
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError(
          'Protocol message %s has no non-repeated field "%s" '
          'nor has presence is not available for this field.'
          % (message_descriptor.full_name, field_name)
      ) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        # No member of this oneof has been recorded as set.
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
906
907
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls."""

  def ClearField(self, field_name):
    """Clears the named field, or the recorded member of the named oneof.

    Args:
      field_name: Name of a field or of a oneof of this message.

    Raises:
      ValueError: If field_name names neither a field nor a oneof.
    """
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      # Not a regular field; see whether the name refers to a oneof.
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
        if oneof not in self._oneofs:
          # No member of the oneof is recorded, so there is nothing to clear.
          return
        field = self._oneofs[oneof]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      current = self._fields[field]
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(current, 'InvalidateIterators'):
        current.InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us. That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size. Big deal.
      del self._fields[field]

      containing_oneof = field.containing_oneof
      if self._oneofs.get(containing_oneof, None) is field:
        del self._oneofs[containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
944
945
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""

  def ClearExtension(self, field_descriptor):
    """Removes the stored value of the given extension, if there is one.

    The handle is first checked with extension_dict._VerifyExtensionHandle().
    The message is marked as modified whether or not a value was stored.
    """
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Similar to ClearField(), above.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
956
957
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""

  def HasExtension(self, field_descriptor):
    """Reports whether the given singular extension is set on this message.

    Raises:
      KeyError: If the extension is repeated.
    """
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.is_repeated:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields
    # A message value only counts once it is flagged as present.
    submessage = self._fields.get(field_descriptor)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
971
def _InternalUnpackAny(msg):
  """Builds and returns the message packed inside an Any message.

  The public Any Unpack method is handed the target message by its caller.
  This internal variant has no target: it resolves the message type by name
  in the pool of the default symbol database and instantiates it itself.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when the type URL is empty or the pool
    lookup yields None.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  database = symbol_database.Default()

  type_url = msg.type_url
  if not type_url:
    return None

  # TODO: For now we just strip the hostname. Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  descriptor = database.pool.FindMessageTypeByName(type_name)
  # NOTE(review): whether FindMessageTypeByName can return None (rather than
  # raise) for an unknown type is not visible from here -- confirm.
  if descriptor is None:
    return None

  # Unable to import message_factory at top because of circular import.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import message_factory
  unpacked = message_factory.GetMessageClass(descriptor)()
  unpacked.ParseFromString(msg.value)
  return unpacked
1013
1014
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls."""

  def __eq__(self, other):
    """Compares this message with `other`.

    A ListValue compared with a list, or a Struct compared with a dict, is
    delegated to `_internal_compare`. Anything that is not a message with the
    same descriptor yields NotImplemented. Two Any messages are compared by
    their unpacked payloads when both unpack to something truthy. Otherwise
    the messages are equal when their ListFields() results and their sorted
    unknown fields match.
    """
    full_name = self.DESCRIPTOR.full_name
    if (full_name == _ListValueFullTypeName and isinstance(other, list)) or (
        full_name == _StructFullTypeName and isinstance(other, dict)):
      return self._internal_compare(other)

    if not isinstance(other, message_mod.Message):
      return NotImplemented
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return NotImplemented

    if self is other:
      return True

    if full_name == _AnyFullTypeName:
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1052
1053
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__ on cls."""

  def __str__(self):
    """Returns what text_format.MessageToString() produces for this message."""
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
1059
1060
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __repr__ on cls."""

  def __repr__(self):
    """Returns what text_format.MessageToString() produces for this message."""
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
1066
1067
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__ on cls."""

  def __unicode__(self):
    """Returns the as_utf8 text rendering of the message, decoded as UTF-8."""
    # NOTE(review): .decode() needs MessageToString(..., as_utf8=True) to
    # return a bytes-like object -- confirm that still holds on Python 3.
    encoded = text_format.MessageToString(self, as_utf8=True)
    return encoded.decode('utf-8')

  cls.__unicode__ = __unicode__
1074
1075
def _AddContainsMethod(message_descriptor, cls):
  """Installs the `__contains__` variant matching the message type on cls.

  For google.protobuf.Struct membership is tested against `self.fields`, for
  google.protobuf.ListValue against `self.items()`, and for every other
  message type `x in msg` is forwarded to `msg.HasField(x)`.
  """
  full_name = message_descriptor.full_name

  if full_name == 'google.protobuf.Struct':

    def __contains__(self, key):
      return key in self.fields

  elif full_name == 'google.protobuf.ListValue':

    def __contains__(self, value):
      return value in self.items()

  else:

    def __contains__(self, field):
      return self.HasField(field)

  cls.__contains__ = __contains__
1089
1090
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The count covers the tag information as well as any other space that
  serializing `value` requires.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (The field number is stored
      as part of a varint-encoded tag, so it affects the total number of
      bytes required to serialize the value.)
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: If a KeyError is raised while looking up or
      running the size function for field_type.
  """
  try:
    size_fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
    return size_fn(field_number, value)
  except KeyError:
    # NOTE(review): the try also covers the size function call, so a KeyError
    # raised inside it is reported as an unrecognized field type as well.
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
1109
1110
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize() on cls."""

  def ByteSize(self):
    """Returns the byte size of this message, recomputing it only when dirty.

    A fresh result is cached on the instance, and the dirty flags of the
    message and of its listener for children are cleared.
    """
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    total = 0
    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      for part in ('key', 'value'):
        part_field = descriptor.fields_by_name[part]
        _MaybeAddEncoder(cls, part_field)
        total += part_field._sizer(getattr(self, part))
    else:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        total += field_descriptor._sizer(field_value)
      # Unknown fields are kept as (tag bytes, value bytes) pairs.
      total += sum(
          len(tag_bytes) + len(value_bytes)
          for tag_bytes, value_bytes in self._unknown_fields
      )

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
1141
1142
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls."""

  def SerializeToString(self, **kwargs):
    """Serializes the message after checking that it is initialized.

    Args:
      **kwargs: Forwarded unchanged to SerializePartialToString().

    Returns:
      Whatever SerializePartialToString() returns.

    Raises:
      message_mod.EncodeError: If IsInitialized() reports missing required
        fields.
    """
    # Check if the message has all of its required fields set.
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name,
            ','.join(self.FindInitializationErrors())))

  cls.SerializeToString = SerializeToString
1154
1155
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() and the _InternalSerialize() it drives.
  """

  def SerializePartialToString(self, **kwargs):
    """Serializes the message into bytes via _InternalSerialize().

    Args:
      **kwargs: Forwarded unchanged to _InternalSerialize().

    Returns:
      The collected output as bytes.
    """
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  def InternalSerialize(self, write_bytes, deterministic=None):
    """Writes the encoded message through `write_bytes`.

    Args:
      write_bytes: Callable invoked with each chunk of output.
      deterministic: Passed to the field encoders after being coerced to
        bool. None selects the value reported by
        api_implementation.IsPythonDefaultSerializationDeterministic().
    """
    deterministic = (
        api_implementation.IsPythonDefaultSerializationDeterministic()
        if deterministic is None
        else bool(deterministic)
    )

    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      for part in ('key', 'value'):
        part_field = descriptor.fields_by_name[part]
        _MaybeAddEncoder(cls, part_field)
        part_field._encoder(write_bytes, getattr(self, part), deterministic)
    else:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        field_descriptor._encoder(write_bytes, field_value, deterministic)
      # Unknown fields are written back as their stored tag and value bytes.
      for tag_bytes, value_bytes in self._unknown_fields:
        write_bytes(tag_bytes)
        write_bytes(value_bytes)

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
1189
1190
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the _InternalParse() it relies on onto cls.
  """
  def MergeFromString(self, serialized):
    """Parses `serialized` and merges its contents into this message.

    Args:
      serialized: Object accepted by memoryview() holding the serialized data.

    Returns:
      The length of `serialized`.

    Raises:
      message_mod.DecodeError: If parsing stops before the end of the data,
        or if an IndexError, TypeError or struct.error is raised while
        parsing.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length  # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Captured once here so InternalParse does not look them up on cls per call.
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end, current_depth=0):
    """Parses serialized data from `buffer` into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.
      current_depth: Value handed on unchanged to the regular field decoders
        (presumably the current nesting depth -- confirm against the
        decoders).

    Returns:
      int, the position in `buffer` where parsing stopped. This equals `end`
      unless the unknown-field decoder reported a new position of -1, in
      which case the position of the tag being processed is returned.

    Raises:
      message_mod.DecodeError: If a tag with field number 0 is encountered.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    while pos != end:
      (tag_bytes, new_pos) = decoder.ReadTag(buffer, pos)
      # Decoders registered in _message_set_decoders_by_tag are tried first.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        continue
      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is None:
        # Tag not known to this message type: keep the raw bytes around.
        if not self._unknown_fields:  # pylint: disable=protected-access
          self._unknown_fields = []  # pylint: disable=protected-access
        field_number, wire_type = decoder.DecodeTag(tag_bytes)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, end, field_number, wire_type
        )  # pylint: disable=protected-access
        if new_pos == -1:
          # Stop early and hand the position of this tag back to the caller.
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[pos + len(tag_bytes) : new_pos].tobytes())
        )
        pos = new_pos
      else:
        _MaybeAddDecoder(cls, field_des)
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(
            buffer, new_pos, end, self, field_dict, current_depth
        )
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
    return pos

  cls._InternalParse = InternalParse
1264
1265
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [
      field for field in message_descriptor.fields if field.is_required
  ]

  def _ReportUninitialized(message, errors):
    """Fills `errors` (when given) with the missing paths; returns False."""
    if errors is not None:
      errors.extend(message.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """
    # Performance is critical so we avoid HasField() and ListFields().
    for field in required_fields:
      if field not in self._fields:
        return _ReportUninitialized(self, errors)
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not self._fields[field]._is_present_in_parent):
        return _ReportUninitialized(self, errors)

    # Then make sure every stored sub-message is initialized as well.
    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.is_repeated:
        if field.message_type._is_map_entry:
          continue
        for element in value:
          if not element.IsInitialized():
            return _ReportUninitialized(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _ReportUninitialized(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings. Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = [
        field.name for field in required_fields
        if not self.HasField(field.name)
    ]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      # Extensions are reported by their parenthesized full name.
      name = '(%s)' % field.full_name if field.is_extension else field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            element = value[key]
            prefix = '%s[%s].' % (name, key)
            errors.extend(
                prefix + error for error in element.FindInitializationErrors()
            )
      elif field.is_repeated:
        for index, element in enumerate(value):
          prefix = '%s[%d].' % (name, index)
          errors.extend(
              prefix + error for error in element.FindInitializationErrors()
          )
      else:
        prefix = name + '.'
        errors.extend(
            prefix + error for error in value.FindInitializationErrors()
        )

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1358
1359
def _FullyQualifiedClassName(klass):
  """Returns the module-qualified dotted name of a class.

  The class' `__qualname__` is used when it has one, its `__name__`
  otherwise. Classes whose module is None or one of the builtins modules are
  reported without a module prefix.
  """
  module = klass.__module__
  name = getattr(klass, '__qualname__', klass.__name__)
  is_builtin = module in (None, 'builtins', '__builtin__')
  return name if is_builtin else module + '.' + name
1366
1367
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the contents of `msg` into this message.

    Repeated fields and present sub-messages are merged into this message's
    own containers (created on demand); other values overwrite the current
    ones. Unknown fields of `msg` are appended to this message's.

    Args:
      msg: Another instance of cls; must not be this very message.

    Raises:
      TypeError: If msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    own_fields = self._fields

    for field, incoming in msg._fields.items():
      if field.is_repeated:
        merge_into_container = True
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Sub-messages that are not flagged as present are left out.
        merge_into_container = incoming._is_present_in_parent
      else:
        self._fields[field] = incoming
        if field.containing_oneof:
          self._UpdateOneofState(field)
        continue

      if merge_into_container:
        target = own_fields.get(field)
        if target is None:
          # Construct a new object to represent this field.
          target = field._default_constructor(self)
          own_fields[field] = target
        target.MergeFrom(incoming)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1410
1411
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active = self._oneofs.get(oneof, None)
    # A recorded member only counts while HasField() still reports it as set.
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1428
1429
def _Clear(self):
  """Empties the message: its fields, unknown fields and oneof records."""
  # Clear fields.
  self._fields, self._unknown_fields = {}, ()

  self._oneofs = {}
  self._Modified()
1437
1438
def _UnknownFields(self):
  """Always raises, pointing callers at unknown_fields.UnknownFieldSet.

  Raises:
    NotImplementedError: On every call.
  """
  raise NotImplementedError('Please use the add-on feature '
                            'unknown_fields.UnknownFieldSet(message) in '
                            'unknown_fields.py instead.')
1443
1444
def _DiscardUnknownFields(self):
  """Drops the unknown fields of this message and of its listed sub-messages.

  The message's own unknown fields are reset to an empty list, then
  DiscardUnknownFields() is called on every sub-message reachable through
  ListFields(): singular messages, repeated message elements, and the values
  of maps accepted by _IsMessageMapField().
  """
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Only descend into maps that _IsMessageMapField() accepts.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
      continue
    sub_messages = value if field.is_repeated else (value,)
    for sub_message in sub_messages:
      sub_message.DiscardUnknownFields()
1458
1459
def _SetListener(self, listener):
  """Stores `listener` on the message; None installs a NullMessageListener."""
  if listener is not None:
    self._listener = listener
  else:
    self._listener = message_listener_mod.NullMessageListener()
1465
1466
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for install in (
      _AddListFieldsMethod,
      _AddHasFieldMethod,
      _AddClearFieldMethod,
  ):
    install(message_descriptor, cls)

  # The extension accessors are only installed on extendable messages.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for install in (
      _AddEqualsMethod,
      _AddStrMethod,
      _AddReprMethod,
      _AddUnicodeMethod,
      _AddContainsMethod,
      _AddByteSizeMethod,
      _AddSerializeToStringMethod,
      _AddSerializePartialToStringMethod,
      _AddMergeFromStringMethod,
      _AddIsInitializedMethod,
  ):
    install(message_descriptor, cls)
  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1491
1492
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """
    # Note: Some callers check _cached_byte_size_dirty before calling
    # _Modified() as an extra optimization. So, if this method is ever
    # changed such that it does stuff even when _cached_byte_size_dirty is
    # already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is field:
      return
    # A different member was active: drop its value and record the new one.
    del self._fields[previous]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1525
1526
class _Listener(object):
  """MessageListener implementation that a parent message registers with its
  child message.

  Semantics such as

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  require child objects to hold back references to their parents, and this
  helper class is what provides them.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The back reference from the child (contained) object to its parent
    # (containing) object is kept weak, so that no cyclic garbage is created
    # when the client finishes with the 'parent' object in the tree. A
    # parent that already is a weak proxy is stored as it is.
    already_proxy = isinstance(parent_message, weakref.ProxyType)
    self._parent_message_weakref = (
        parent_message if already_proxy else weakref.proxy(parent_message)
    )

    # As an optimization, the listener itself also records whether the
    # parent message is dirty, which avoids walking up the tree in the
    # common case.
    self.dirty = False

  def Modified(self):
    """Forwards the notification to the parent unless it is already dirty."""
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # We can get here if a client has kept a reference to a child object,
        # and is now setting a field on it, but the child's parent has been
        # garbage-collected. This is not an error.
        pass
1571
1572
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # The parent is gone, so there is nothing left to update.
      pass