1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
# This code is meant to work on Python 3 only.
9#
10# TODO: Helpers for verbose, common checks like seeing if a
11# descriptor's cpp_type is CPPTYPE_MESSAGE.
12
13"""Contains a metaclass and helper functions used to create
14protocol message classes from Descriptor objects at runtime.
15
16Recall that a metaclass is the "type" of a class.
17(A class is to a metaclass what an instance is to a class.)
18
19In this case, we use the GeneratedProtocolMessageType metaclass
20to inject all the useful functionality into the classes
21output by the protocol compiler at compile-time.
22
23The upshot of all this is that the real implementation
24details for ALL pure-Python protocol buffers are *here in
25this file*.
26"""
27
28__author__ = 'robinson@google.com (Will Robinson)'
29
30import datetime
31from io import BytesIO
32import math
33import struct
34import sys
35import warnings
36import weakref
37
38from google.protobuf import descriptor as descriptor_mod
39from google.protobuf import message as message_mod
40from google.protobuf import text_format
41# We use "as" to avoid name collisions with variables.
42from google.protobuf.internal import api_implementation
43from google.protobuf.internal import containers
44from google.protobuf.internal import decoder
45from google.protobuf.internal import encoder
46from google.protobuf.internal import enum_type_wrapper
47from google.protobuf.internal import extension_dict
48from google.protobuf.internal import message_listener as message_listener_mod
49from google.protobuf.internal import type_checkers
50from google.protobuf.internal import well_known_types
51from google.protobuf.internal import wire_format
52
# Short module-local aliases for names defined in other protobuf modules.
_FieldDescriptor = descriptor_mod.FieldDescriptor
_ExtensionDict = extension_dict._ExtensionDict

# Full names of well-known message types that this module special-cases.
_AnyFullTypeName = 'google.protobuf.Any'
_StructFullTypeName = 'google.protobuf.Struct'
_ListValueFullTypeName = 'google.protobuf.ListValue'
58
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class. We
  also create properties to allow getting/setting all fields in the protocol
  message. Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime. Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = message_factory.GetMessageClass(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class; passed through unchanged to type.__new__.
      bases: Base classes of the class we're constructing (normally
        message.Message). If the descriptor's full name is a key of
        well_known_types.WKTBASES, the matching extra base is appended.
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type. A '__slots__' entry and one attribute per nested extension
        are added to it before the class is created.

    Returns:
      Newly-allocated class, or the already-existing concrete class if one
      was previously created for this descriptor.

    Raises:
      RuntimeError: If DESCRIPTOR is a str, meaning the generated code only
        works with the python cpp extension but the pure python runtime is
        in use.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # A str DESCRIPTOR cannot be served by this pure-Python runtime.
    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another. Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls message_factory.GetMessageClass() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    # Well-known types get the matching extra base class from WKTBASES.
    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    If the descriptor already has a concrete class (returned by __new__),
    nothing is re-initialized.

    Args:
      name: Name of the class; passed through to type.__init__.
      bases: Base classes of the class we're constructing; passed through to
        type.__init__.
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    # Per-class lookup tables keyed by wire tag. _fields_by_tag maps tag bytes
    # to (field_descriptor, is_packed); it is filled by _AttachFieldHelpers.
    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor),
          None,
      )

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    # Also register the extensions already known to the descriptor's pool.
    # The descriptor's file may have no pool, hence the hasattr() check.
    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      extensions = descriptor.file.pool.FindAllExtensions(descriptor)
      for ext in extensions:
        _AttachFieldHelpers(cls, ext)

    # Recorded before the _Add* helpers run. __new__ and __init__ both check
    # this attribute to reuse an already-built class for the descriptor.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)
196
197
198# Stateless helpers for GeneratedProtocolMessageType below.
199# Outside clients should not access these directly.
200#
201# I opted not to make any of these methods on the metaclass, to make it more
202# clear that I'm not really using any state there and to keep clients from
203# thinking that they have direct access to these construction helpers.
204
205
def _PropertyName(proto_field_name):
  """Maps a .proto field name to the Python attribute that exposes it.

  The mapping is deliberately the identity. A field named after a Python
  keyword (e.g. "yield") keeps that name and stays reachable through
  getattr()/setattr().

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The public property name, which is proto_field_name itself.
  """
  # Escaping keywords (e.g. "yield" -> "yield_") was considered and rejected.
  # People rely on getattr()/setattr() with the .proto name to reflectively
  # manipulate field values. Renaming properties would force every such
  # caller to apply the same transformation.
  # TODO: Remove this method entirely if/when everyone agrees with this
  # position.
  property_name = proto_field_name
  return property_name
233
234
def _AddSlots(message_descriptor, dictionary):
  """Stores the '__slots__' entry for a message class in its class dictionary.

  Every message class gets the same fixed set of instance attributes, so
  message_descriptor is not consulted.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  slot_names = (
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  )
  # A fresh list per class, as before.
  dictionary['__slots__'] = list(slot_names)
252
253
def _IsMessageSetExtension(field):
  """Tells whether field is an item of a MessageSet.

  That is, an optional, message-typed extension whose containing type uses
  message_set_wire_format.

  Args:
    field: FieldDescriptor to inspect.

  Returns:
    A truthy value if field is a MessageSet extension, otherwise a falsy one.
  """
  # Short-circuit evaluation is preserved: the option checks only run for
  # extensions, and the type/label checks only for MessageSet containers.
  extends_message_set = (
      field.is_extension and
      field.containing_type.has_options and
      field.containing_type.GetOptions().message_set_wire_format)
  return (extends_message_set and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
260
261
def _IsMapField(field):
  """Tells whether field is a map field.

  A map field is a message-typed field whose message type is flagged as a
  map entry (`_is_map_entry`).
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  return field.message_type._is_map_entry
265
266
def _IsMessageMapField(field):
  """Tells whether a map field's values are messages.

  field must be a map field: its message type needs a 'value' field.
  """
  entry_fields = field.message_type.fields_by_name
  return entry_fields['value'].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
270
def _AttachFieldHelpers(cls, field_descriptor):
  """Prepares field_descriptor for use by the message class cls.

  Caches a default-value constructor on the descriptor. Every wire tag under
  which the field may appear is registered in cls._fields_by_tag, mapping
  the tag bytes to (field_descriptor, is_packed).

  Args:
    cls: Message class whose _fields_by_tag table is updated.
    field_descriptor: FieldDescriptor of a field or extension of cls.
  """
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  fields_by_tag = cls._fields_by_tag
  number = field_descriptor.number

  # Tag for the field's natural (unpacked) wire type.
  natural_wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
      field_descriptor.type]
  fields_by_tag[encoder.TagBytes(number, natural_wire_type)] = (
      field_descriptor, False)

  # To support wire compatibility of adding packed = true, also accept the
  # packed (length-delimited) encoding regardless of the field's options.
  if (field_descriptor.is_repeated and
      wire_format.IsTypePackable(field_descriptor.type)):
    packed_tag = encoder.TagBytes(number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
    fields_by_tag[packed_tag] = (field_descriptor, True)
290
291
def _MaybeAddEncoder(cls, field_descriptor):
  """Lazily attaches `_encoder` and `_sizer` to field_descriptor.

  Does nothing if an encoder is already attached.

  Args:
    cls: The message class (unused here).
    field_descriptor: FieldDescriptor to attach the encoder and sizer to.
  """
  if hasattr(field_descriptor, '_encoder'):
    return

  is_repeated = field_descriptor.is_repeated
  is_map_entry = _IsMapField(field_descriptor)
  is_packed = field_descriptor.is_packed

  if is_map_entry:
    field_encoder, sizer = (
        encoder.MapEncoder(field_descriptor),
        encoder.MapSizer(field_descriptor,
                         _IsMessageMapField(field_descriptor)),
    )
  elif _IsMessageSetExtension(field_descriptor):
    number = field_descriptor.number
    field_encoder, sizer = (
        encoder.MessageSetItemEncoder(number),
        encoder.MessageSetItemSizer(number),
    )
  else:
    field_type = field_descriptor.type
    factory_args = (field_descriptor.number, is_repeated, is_packed)
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_type](*factory_args)
    sizer = type_checkers.TYPE_TO_SIZER[field_type](*factory_args)

  # _encoder is set last; its presence is what the hasattr() guard tests.
  field_descriptor._sizer = sizer
  field_descriptor._encoder = field_encoder
314
315
def _MaybeAddDecoder(cls, field_descriptor):
  """Lazily attaches a `_decoders` dict to field_descriptor.

  The dict maps is_packed (False, plus True for packable repeated fields) to
  the decoder used for that wire encoding. Does nothing if the decoders were
  already attached.

  Args:
    cls: The message class (unused here).
    field_descriptor: FieldDescriptor to attach decoders to.
  """
  if hasattr(field_descriptor, '_decoders'):
    return

  is_repeated = field_descriptor.is_repeated
  is_map_entry = _IsMapField(field_descriptor)
  helper_decoders = {}

  def AddDecoder(is_packed):
    """Builds the decoder for one encoding and stores it in helper_decoders."""
    decode_type = field_descriptor.type
    # Values of open (non-closed) enums are decoded as plain int32.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Message decoders are not passed the `not has_presence` flag.
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)

    helper_decoders[is_packed] = field_decoder

  AddDecoder(False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(True)

  field_descriptor._decoders = helper_decoders
366
367
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside the message as a class attribute.

  Args:
    descriptor: Descriptor of the message type being built.
    dictionary: Class dictionary that receives one entry per extension,
      keyed by the extension's name.
  """
  for name, extension in descriptor.extensions_by_name.items():
    # An extension must not shadow an existing class-dictionary entry.
    assert name not in dictionary
    dictionary[name] = extension
373
374
def _AddEnumValues(descriptor, cls):
  """Exposes the message's nested enum types on the class.

  For every enum type declared in the message, cls gets an attribute named
  after the enum, holding an EnumTypeWrapper. It also gets one integer
  attribute per enum value name.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for nested_enum in descriptor.enum_types:
    setattr(cls, nested_enum.name,
            enum_type_wrapper.EnumTypeWrapper(nested_enum))
    for value_descriptor in nested_enum.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
388
389
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for a map field.

  The factory takes the owning message and returns a containers.MessageMap
  when the map's values are messages, or a containers.ScalarMap otherwise.

  Args:
    field: FieldDescriptor of a map field.

  Raises:
    ValueError: If field is not repeated.
  """
  if not field.is_repeated:
    raise ValueError('map_entry set on non-repeated field %s' % field.name)

  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
411
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field. The default
  value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value, or
      (via _GetInitializeDefaultForMap) a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.is_repeated:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set. (Depends on order in which we initialize the classes).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Submessages inside a oneof get an _OneofListener (defined elsewhere in
      # this module); all others report to the parent's child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
468
469
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception currently being handled, naming the field.

  Must be called from inside an `except` block. A plain single-argument
  TypeError is replaced by a new TypeError whose message ends with
  "for field <message_name>.<field_name>". Any other exception, including
  TypeError subclasses, is re-raised unchanged. The original traceback is
  kept in both cases.

  Args:
    message_name: Name of the message type containing the field.
    field_name: Name of the field being set.
  """
  current = sys.exc_info()[1]
  is_simple_type_error = (
      len(current.args) == 1 and type(current) is TypeError)
  if is_simple_type_error:
    current = TypeError(
        '%s for field %s.%s' % (str(current), message_name, field_name))

  tb = sys.exc_info()[2]
  raise current.with_traceback(tb)
479
480
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The installed __init__ resets the per-instance bookkeeping attributes. It
  then initializes fields from keyword arguments, one keyword per field name.

  Args:
    message_descriptor: Descriptor of the message type being built.
    cls: The class we're constructing.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name. If the value is not a string, it's
    returned as-is. (No conversion or bounds-checking is done.)

    Raises:
      ValueError: If value is a string naming no value of enum_type.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Keyword arguments may set fields, so the cached size starts dirty.
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # Defensive only: _GetFieldByName raises ValueError for unknown names
      # rather than returning None.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.is_repeated:
        field_copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                item_value = field_value[key]
                if isinstance(item_value, dict):
                  field_copy[key].__init__(**item_value)
                else:
                  field_copy[key].MergeFrom(item_value)
            else:
              field_copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                field_copy.add(**val)
              else:
                field_copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          field_copy.extend(field_value)
        self._fields[field] = field_copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        field_copy = field._default_constructor(self)
        new_val = None
        if isinstance(field_value, message_mod.Message):
          new_val = field_value
        elif isinstance(field_value, dict):
          if field.message_type.full_name == _StructFullTypeName:
            field_copy.Clear()
            if len(field_value) == 1 and 'fields' in field_value:
              # A dict whose only key is 'fields' may be either Struct
              # contents or keyword arguments for the Struct message's own
              # 'fields' field. Try the former first.
              try:
                field_copy.update(field_value)
              except Exception:  # pylint: disable=broad-except
                # Fall back to init normal message field
                field_copy.Clear()
                new_val = field.message_type._concrete_class(**field_value)
            else:
              field_copy.update(field_value)
          else:
            new_val = field.message_type._concrete_class(**field_value)
        elif hasattr(field_copy, '_internal_assign'):
          field_copy._internal_assign(field_value)
        else:
          raise TypeError(
              'Message field {0}.{1} must be initialized with a '
              'dict or instance of same class, got {2}.'.format(
                  message_descriptor.name,
                  field_name,
                  type(field_value).__name__,
              )
          )

        # Identity test: `!=` would dispatch to the message's __eq__.
        if new_val is not None:
          try:
            field_copy.MergeFrom(new_val)
          except TypeError:
            _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = field_copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
594
595
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor of a message by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: If the message has no field with that name.
  """
  fields = message_descriptor.fields_by_name
  try:
    found = fields[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found
610
611
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable messages also get a read-only `Extensions` property.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  def _GetExtensions(self):
    # _ExtensionDict is just an adaptor with no state, so a new one is
    # allocated every time the property is accessed.
    return _ExtensionDict(self)

  cls.Extensions = property(_GetExtensions)
621
622
def _AddPropertiesForField(field, cls):
  """Adds the public property and <NAME>_FIELD_NUMBER constant for a field.

  Clients can use the property to get and, for non-repeated scalar fields,
  directly set the value of a protocol message field.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.is_repeated:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
646
647
class _FieldProperty(property):
  # A property that also carries the FieldDescriptor it exposes, so the
  # descriptor is reachable as <MessageClass>.<field_name>.DESCRIPTOR.
  # Documented with comments only: a class docstring would be inherited as
  # the __doc__ of every such property.
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    super().__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
654
655
def _AddPropertiesForRepeatedField(field, cls):
  """Adds the public property for a "repeated" protocol message field.

  Reading the property returns the field's container, built on first access
  by field._default_constructor. This is a repeated-field container or, for
  map fields, a map container. Assigning to the property raises
  AttributeError; clients mutate the container instead.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing
    # Construct a new object to represent this field, then atomically check
    # whether another thread preempted us: setdefault() keeps whichever
    # object was stored first and we discard ours.
    # WARNING: We are relying on setdefault() being atomic. This is true
    # in CPython but we haven't investigated others. This warning appears
    # in several other locations in this file.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to raise a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
698
699
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds the public property for a non-repeated scalar field.

  The getter returns the stored value, or the field's default value when the
  field is unset. The setter does the following:
    * type-checks the new value (TypeError on failure);
    * stores it, or, for fields without presence, removes the entry when
      the value is a default scalar value;
    * marks the message modified;
    * for oneof members, updates the oneof state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Without presence, assigning a default scalar value (as judged by
    # decoder.IsDefaultScalarValue) is the same as clearing the field.
    if field.has_presence or not decoder.IsDefaultScalarValue(new_value):
      self._fields[field] = new_value
    else:
      self._fields.pop(field, None)
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
754
755
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds the public property for a non-repeated composite field.

  A composite field is a "group" or "message" field. Reading the property
  returns the submessage, creating it on first access.

  Assignment is only supported when the field's type is one of these:
    * google.protobuf.Timestamp, via FromDatetime;
    * google.protobuf.Duration, via FromTimedelta;
    * Struct, which is cleared and then updated;
    * ListValue, which is cleared and then extended.
  Any other assignment raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing
    # Construct a new object to represent this field, then atomically check
    # whether another thread preempted us: setdefault() keeps whichever
    # object was stored first and we discard ours.
    # WARNING: We are relying on setdefault() being atomic. This is true
    # in CPython but we haven't investigated others. This warning appears
    # in several other locations in this file.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    type_name = field.message_type.full_name
    if type_name == 'google.protobuf.Timestamp':
      getter(self).FromDatetime(new_value)
    elif type_name == 'google.protobuf.Duration':
      getter(self).FromTimedelta(new_value)
    elif type_name == _StructFullTypeName:
      struct_value = getter(self)
      struct_value.Clear()
      struct_value.update(new_value)
    elif type_name == _ListValueFullTypeName:
      list_value = getter(self)
      list_value.Clear()
      list_value.extend(new_value)
    else:
      raise AttributeError(
          'Assignment not allowed to composite field '
          '"%s" in protocol message object.' % proto_field_name
      )

  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
815
816
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for each nested extension.

  Args:
    descriptor: Descriptor of the message type. Its extensions_by_name maps
      extension names to their FieldDescriptors.
    cls: The class we're constructing.
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)
  # The former trailing `pool = descriptor.file.pool` lookup was dead code.
  # It also raised AttributeError for files without a pool, a case the
  # metaclass explicitly tolerates, so it was removed.
829
def _AddStaticMethods(cls):
  """Adds a static FromString(s) method that parses s into a new cls."""

  def FromString(s):
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.FromString = staticmethod(FromString)
836
837
def _IsPresent(item):
  """Tells whether a (FieldDescriptor, value) pair from _fields is "set".

  Pairs that are set are the ones ListFields() reports:
    * repeated fields when non-empty;
    * message fields when present in their parent;
    * every other stored scalar.
  """
  field = item[0]
  if field.is_repeated:
    return bool(item[1])
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
848
849
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.ListFields``.

  ListFields() returns the ``(FieldDescriptor, value)`` pairs of ``_fields``
  accepted by _IsPresent(), ordered by field number.
  """

  def ListFields(self):
    present = filter(_IsPresent, self._fields.items())
    return sorted(present, key=lambda pair: pair[0].number)

  cls.ListFields = ListFields
859
860
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.HasField``.

  HasField(name) accepts the name of a singular field that tracks presence,
  or the name of a oneof. Any other name raises ValueError.
  """

  # name -> FieldDescriptor (singular and has_presence) or OneofDescriptor.
  # Repeated fields, and (in proto3) fields without presence, are excluded.
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if not field.is_repeated and field.has_presence
  }
  # Has methods are supported for oneof descriptors too.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is set iff its currently recorded member is set.
      try:
        active_member = self._oneofs[field]
      except KeyError:
        return False
      return HasField(self, active_member.name)

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields

    # A sub-message can be stored in _fields without being present; only a
    # present one counts.
    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
898
899
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.ClearField``.

  ClearField(name) accepts a field name or a oneof name. For a oneof, the
  currently recorded member is cleared. If no member is recorded, the call
  returns immediately without marking the message modified. Unknown names
  raise ValueError.
  """

  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      # Not a field name: try to resolve it as a oneof name instead.
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
        if oneof not in self._oneofs:
          return
        field = self._oneofs[oneof]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    stored_fields = self._fields
    if field in stored_fields:
      stored = stored_fields[field]
      # Map fields: invalidate outstanding iterators when ClearField()
      # happens, to match the C++ implementation.
      if hasattr(stored, 'InvalidateIterators'):
        stored.InvalidateIterators()

      # A removed sub-message keeps its listener pointing at this message.
      # The worst that can happen is that it later calls _Modified() and
      # invalidates our cached byte size, which is harmless.
      del stored_fields[field]

      owning_oneof = field.containing_oneof
      if self._oneofs.get(owning_oneof, None) is field:
        del self._oneofs[owning_oneof]

    # Always call _Modified(), even when nothing was removed. This is a
    # mutating method, so calling it makes the message present in its parent.
    self._Modified()

  cls.ClearField = ClearField
936
937
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ``cls.ClearExtension``."""

  def ClearExtension(self, field_descriptor):
    # Presumably raises if the handle is not a valid extension of this
    # message; see extension_dict._VerifyExtensionHandle.
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Like ClearField(): drop any stored value, then always mark modified.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
948
949
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ``cls.HasExtension``.

  Repeated extensions have no presence; asking about one raises KeyError.
  """

  def HasExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.is_repeated:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields

    # Message-typed extensions count only once marked present.
    stored = self._fields.get(field_descriptor)
    return stored is not None and stored._is_present_in_parent

  cls.HasExtension = HasExtension
963
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when ``msg.type_url`` is empty or names a
    type that the default symbol database's pool cannot resolve.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO: For now we just strip the hostname. Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # DescriptorPool.FindMessageTypeByName() reports an unknown type by
    # raising KeyError, not by returning None. Treat it as "cannot unpack"
    # so callers such as __eq__ fall back to comparing the packed fields
    # instead of crashing.
    return None

  if descriptor is None:
    return None

  # Unable to import message_factory at top because of circular import.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import message_factory
  message_class = message_factory.GetMessageClass(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
1005
1006
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.__eq__``.

  Besides message-to-message comparison:
    * a ListValue also compares against a plain ``list``, and a Struct
      against a plain ``dict``, via ``_internal_compare``;
    * two Any messages whose payloads can both be unpacked compare by their
      unpacked contents.
  """

  def __eq__(self, other):
    full_name = self.DESCRIPTOR.full_name

    if full_name == _ListValueFullTypeName and isinstance(other, list):
      return self._internal_compare(other)
    if full_name == _StructFullTypeName and isinstance(other, dict):
      return self._internal_compare(other)

    # Not a message of the very same type: let Python try the reflected
    # comparison.
    if not isinstance(other, message_mod.Message):
      return NotImplemented
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return NotImplemented

    if self is other:
      return True

    if full_name == _AnyFullTypeName:
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if self.ListFields() != other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    # Unknown fields are compared order-insensitively (as multisets).
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1044
1045
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): makes ``str(msg)`` return text format."""

  def __str__(self):
    # Protobuf text format, as produced by text_format.MessageToString().
    return text_format.MessageToString(self)

  setattr(cls, '__str__', __str__)
1051
1052
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): makes ``repr(msg)`` return text format.

  The output is the same text produced by text_format.MessageToString().
  """

  def __repr__(self):
    return text_format.MessageToString(self)

  setattr(cls, '__repr__', __repr__)
1058
1059
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.__unicode__``.

  ``__unicode__`` returns the text-format rendering of the message, produced
  by text_format.MessageToString(..., as_utf8=True), as a ``str``.
  """

  def __unicode__(self):
    text = text_format.MessageToString(self, as_utf8=True)
    # On Python 3, MessageToString() already returns str, which has no
    # .decode(). Only decode when a bytes result is handed back.
    if isinstance(text, bytes):
      text = text.decode('utf-8')
    return text
  cls.__unicode__ = __unicode__
1066
1067
def _AddContainsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.__contains__``.

    * google.protobuf.Struct: ``key in msg`` tests membership in
      ``msg.fields``.
    * google.protobuf.ListValue: ``value in msg`` tests membership in
      ``msg.items()``.
    * any other message: ``name in msg`` is ``msg.HasField(name)``.
  """
  type_name = message_descriptor.full_name

  if type_name == 'google.protobuf.Struct':

    def __contains__(self, key):
      return key in self.fields

    cls.__contains__ = __contains__
    return

  if type_name == 'google.protobuf.ListValue':

    def __contains__(self, value):
      return value in self.items()

    cls.__contains__ = __contains__
    return

  def __contains__(self, field):
    return self.HasField(field)

  cls.__contains__ = __contains__
1081
1082
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The size reported by the sizer registered for ``field_type`` in
    type_checkers.TYPE_TO_BYTE_SIZE_FN.

  Raises:
    message_mod.EncodeError: if no sizer is registered for ``field_type``.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError(
        'Unrecognized field type: %d' % field_type) from None
  # Called outside the try: a KeyError raised by the sizer itself is a real
  # bug and must not be misreported as an unrecognized field type.
  return fn(field_number, value)
1101
1102
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.ByteSize``.

  ByteSize() returns the serialized size of the message in bytes. The result
  is cached on the instance and recomputed only while
  ``_cached_byte_size_dirty`` is set.
  """

  def ByteSize(self):
    if self._cached_byte_size_dirty:
      desc = self.DESCRIPTOR
      total = 0
      if desc._is_map_entry:
        # Fields of a map entry are always serialized, so both are counted.
        for entry_name in ('key', 'value'):
          entry_field = desc.fields_by_name[entry_name]
          _MaybeAddEncoder(cls, entry_field)
          total += entry_field._sizer(getattr(self, entry_name))
      else:
        for field_desc, field_val in self.ListFields():
          _MaybeAddEncoder(cls, field_desc)
          total += field_desc._sizer(field_val)
        # Unknown fields are written back verbatim: tag bytes + value bytes.
        total += sum(len(tag) + len(payload)
                     for tag, payload in self._unknown_fields)

      self._cached_byte_size = total
      self._cached_byte_size_dirty = False
      # Clearing the flag lets children propagate modifications again.
      self._listener_for_children.dirty = False
    return self._cached_byte_size

  cls.ByteSize = ByteSize
1133
1134
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.SerializeToString``.

  SerializeToString(**kwargs) raises message_mod.EncodeError when required
  fields are missing. Otherwise it returns SerializePartialToString(**kwargs).
  """

  def SerializeToString(self, **kwargs):
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)
    type_name = self.DESCRIPTOR.full_name
    missing = ','.join(self.FindInitializationErrors())
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (type_name, missing))

  cls.SerializeToString = SerializeToString
1146
1147
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs ``cls.SerializePartialToString``, which serializes without
  checking required fields, and ``cls._InternalSerialize``, which streams
  the encoded bytes through a ``write_bytes`` callback.
  """

  def SerializePartialToString(self, **kwargs):
    out_stream = BytesIO()
    self._InternalSerialize(out_stream.write, **kwargs)
    return out_stream.getvalue()

  def InternalSerialize(self, write_bytes, deterministic=None):
    # None selects the process-wide default; anything else is coerced to bool.
    deterministic = (
        api_implementation.IsPythonDefaultSerializationDeterministic()
        if deterministic is None else bool(deterministic))

    desc = self.DESCRIPTOR
    if desc._is_map_entry:
      # Fields of a map entry are always serialized: key first, then value.
      for entry_name in ('key', 'value'):
        entry_field = desc.fields_by_name[entry_name]
        _MaybeAddEncoder(cls, entry_field)
        entry_field._encoder(write_bytes, getattr(self, entry_name),
                             deterministic)
      return

    for field_desc, field_val in self.ListFields():
      _MaybeAddEncoder(cls, field_desc)
      field_desc._encoder(write_bytes, field_val, deterministic)
    # Unknown fields are re-emitted verbatim after the known ones.
    for tag, payload in self._unknown_fields:
      write_bytes(tag)
      write_bytes(payload)

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
1181
1182
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs ``cls.MergeFromString`` (the public entry point) and
  ``cls._InternalParse`` (the wire-format parsing loop it drives).
  """
  def MergeFromString(self, serialized):
    """Parses ``serialized`` and merges the decoded fields into ``self``.

    Args:
      serialized: a bytes-like object. It is wrapped in a memoryview, which
        _InternalParse requires.

    Returns:
      The number of bytes consumed, which is always ``len(serialized)``.

    Raises:
      message_mod.DecodeError: on an unexpected end-group tag, truncated
        input, or a ``struct.error`` raised while decoding.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Reading past the end of the buffer surfaces as IndexError or
      # TypeError; both are reported as truncation.
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length  # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Per-class tag lookup tables, bound once here. Judging from their use
  # below, they map raw tag bytes to (decoder, field descriptor) for
  # MessageSet decoders and to (field descriptor, is_packed) for regular
  # fields.
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end, current_depth=0):
    """Parses serialized fields from ``buffer[pos:end]`` into ``self``.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.
      current_depth: int, nesting depth forwarded to the regular field
        decoders (presumably for recursion-limit checks -- TODO confirm).
        MessageSet decoders are not given it.

    Returns:
      int, the position where parsing stopped. This is ``end`` when the
      whole range was consumed. Otherwise it is the start of the tag for
      which _DecodeUnknownField reported -1; per the comment in
      MergeFromString, that only happens at an end-group tag.

    Raises:
      message_mod.DecodeError: if a tag carries field number 0.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    while pos != end:
      (tag_bytes, new_pos) = decoder.ReadTag(buffer, pos)
      # MessageSet decoders are consulted before the regular field table.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        continue
      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is None:
        # Unknown tag: keep its raw bytes in _unknown_fields. The attribute
        # may hold a falsy placeholder when empty, so a list is swapped in
        # before appending.
        if not self._unknown_fields:  # pylint: disable=protected-access
          self._unknown_fields = []  # pylint: disable=protected-access
        field_number, wire_type = decoder.DecodeTag(tag_bytes)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # `data` is not used; only the raw byte range is stored below.
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, end, field_number, wire_type
        )  # pylint: disable=protected-access
        if new_pos == -1:
          # Stop here and report the start of this tag to the caller.
          return pos
        # Store the tag bytes together with the raw value bytes after it.
        self._unknown_fields.append(
            (tag_bytes, buffer[pos + len(tag_bytes) : new_pos].tobytes())
        )
        pos = new_pos
      else:
        # _MaybeAddDecoder presumably attaches field_des._decoders on first
        # use; that table is indexed by is_packed.
        _MaybeAddDecoder(cls, field_des)
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(
            buffer, new_pos, end, self, field_dict, current_depth
        )
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
    return pos

  cls._InternalParse = InternalParse
1256
1257
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class.

  Both methods close over ``required_fields``: the fields of
  ``message_descriptor`` labelled LABEL_REQUIRED, computed once here.
  """

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Present singular sub-messages and elements of repeated message fields
    are checked recursively.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    # A required sub-message counts as set only if it is present in this
    # parent, not merely stored in _fields.
    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          # Report every missing path, not just this first one.
          errors.extend(self.FindInitializationErrors())
        return False

    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.is_repeated:
          # NOTE(review): map fields are skipped here, whereas
          # FindInitializationErrors() below descends into message-valued
          # maps. The two can disagree for a map whose values miss required
          # fields; confirm this is intended.
          if (field.message_type._is_map_entry):
            continue
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings. Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = []  # simplify things

    # Required fields of this message that are not set.
    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    # Recurse into every set message-typed field and prefix the child's paths
    # with this field's name. Extensions are written as "(full.name)".
    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.is_extension:
          name = '(%s)' % field.full_name
        else:
          name = field.name

        if _IsMapField(field):
          if _IsMessageMapField(field):
            for key in value:
              element = value[key]
              prefix = '%s[%s].' % (name, key)
              sub_errors = element.FindInitializationErrors()
              errors += [prefix + error for error in sub_errors]
          else:
            # ScalarMaps can't have any initialization errors.
            pass
        elif field.is_repeated:
          for i in range(len(value)):
            element = value[i]
            prefix = '%s[%d].' % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + '.'
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1350
1351
def _FullyQualifiedClassName(klass):
  """Returns ``module.QualifiedName`` for ``klass``.

  Classes from builtins, or with no module, are returned unqualified.
  """
  module_name = klass.__module__
  class_name = getattr(klass, '__qualname__', klass.__name__)
  is_builtin = module_name in (None, 'builtins', '__builtin__')
  return class_name if is_builtin else module_name + '.' + class_name
1358
1359
def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): installs ``cls.MergeFrom``.

  MergeFrom(msg) requires ``isinstance(msg, cls)`` and merges every set field
  of ``msg`` into ``self``:
    * repeated fields are merged through the container's MergeFrom();
    * sub-messages present in ``msg`` are merged recursively;
    * other singular values overwrite ours;
    * unknown fields are appended.
  Oneof bookkeeping is updated only for fields that were actually merged.

  Raises:
    TypeError: if ``msg`` is not an instance of ``cls``.
  """
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if field.is_repeated:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        if not value._is_present_in_parent:
          # Stored in msg._fields but never set: nothing to merge. Also skip
          # the oneof update below. Activating this field would evict the
          # member of the same oneof that is really set in self, and would
          # record as active a field that self._fields may not contain.
          continue
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      else:
        self._fields[field] = value
      if field.containing_oneof:
        self._UpdateOneofState(field)

    if msg._unknown_fields:
      # _unknown_fields may be a falsy placeholder; make it a list first.
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1402
1403
def _AddWhichOneofMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ``cls.WhichOneof``."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    # The recorded member only counts if HasField() also reports it as set.
    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1420
1421
def _Clear(self):
  """Resets the message to its empty state and marks it modified.

  Set fields and oneof records are dropped. Unknown fields are reset to an
  empty tuple; code that appends to them swaps in a list first.
  """
  self._fields, self._unknown_fields, self._oneofs = {}, (), {}
  self._Modified()
1429
1430
def _UnknownFields(self):
  """Always raises; unknown fields are exposed through UnknownFieldSet.

  Raises:
    NotImplementedError: always, pointing callers at
      ``unknown_fields.UnknownFieldSet(message)``.
  """
  raise NotImplementedError('Please use the add-on feature '
                            'unknown_fields.UnknownFieldSet(message) in '
                            'unknown_fields.py instead.')
1435
1436
def _DiscardUnknownFields(self):
  """Drops unknown fields here and, recursively, in every set sub-message."""
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Scalar-valued maps have no sub-messages to recurse into.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.is_repeated:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
1450
1451
def _SetListener(self, listener):
  """Installs ``listener`` as this message's listener.

  ``None`` is replaced by a NullMessageListener, so ``self._listener`` is
  never None.
  """
  self._listener = (
      message_listener_mod.NullMessageListener()
      if listener is None else listener)
1457
1458
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The installers run in a fixed order. The extension accessors are added
  only for extendable message types.
  """
  for install in (_AddListFieldsMethod,
                  _AddHasFieldMethod,
                  _AddClearFieldMethod):
    install(message_descriptor, cls)

  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for install in (_AddEqualsMethod,
                  _AddStrMethod,
                  _AddReprMethod,
                  _AddUnicodeMethod,
                  _AddContainsMethod,
                  _AddByteSizeMethod,
                  _AddSerializeToStringMethod,
                  _AddSerializePartialToStringMethod,
                  _AddMergeFromStringMethod,
                  _AddIsInitializedMethod):
    install(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Methods which do not depend on cls are shared module-level functions.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1483
1484
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls.

  Installs ``_Modified`` (also exposed as ``SetInParent``) and
  ``_UpdateOneofState``.
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note: Some callers check _cached_byte_size_dirty before calling
    # _Modified() as an extra optimization. So, if this method is ever
    # changed such that it does stuff even when _cached_byte_size_dirty is
    # already true, the callers need to be updated.
    if not self._cached_byte_size_dirty:
      self._cached_byte_size_dirty = True
      self._listener_for_children.dirty = True
      self._is_present_in_parent = True
      self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    other_field = self._oneofs.setdefault(field.containing_oneof, field)
    if other_field is not field:
      # The previously active member is normally stored in _fields, but it
      # need not be. A sub-message removed by ClearField() or a oneof switch
      # keeps a listener pointing at this message, and can later re-register
      # its field as active. Tolerate a missing entry instead of raising
      # KeyError.
      self._fields.pop(other_field, None)
      self._oneofs[field.containing_oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1517
1518
class _Listener(object):
  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The child holds this listener, so keep only a weak back reference to
    # the parent. A strong one would create cyclic garbage once the client
    # is done with the parent. An existing proxy is reused, not re-wrapped.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message

    # Mirrors whether the parent is already dirty, so repeated child
    # modifications need not walk up the tree.
    self.dirty = False

  def Modified(self):
    """Forwards a child modification to the parent unless it is already dirty."""
    if not self.dirty:
      try:
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # The parent was garbage-collected while a client still holds and
        # mutates the child. That is not an error; nobody needs notifying.
        pass
1563
1564
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message.

    The oneof update happens on every call, even when the parent is already
    dirty. Only the _Modified() propagation is skipped in that case.
    """
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # Parent already garbage-collected: nothing to update.
      pass