1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
# This code requires Python 3 (it uses Python 3-only syntax such as `raise ... from`).
9#
10# TODO: Helpers for verbose, common checks like seeing if a
11# descriptor's cpp_type is CPPTYPE_MESSAGE.
12
13"""Contains a metaclass and helper functions used to create
14protocol message classes from Descriptor objects at runtime.
15
16Recall that a metaclass is the "type" of a class.
17(A class is to a metaclass what an instance is to a class.)
18
19In this case, we use the GeneratedProtocolMessageType metaclass
20to inject all the useful functionality into the classes
21output by the protocol compiler at compile-time.
22
23The upshot of all this is that the real implementation
24details for ALL pure-Python protocol buffers are *here in
25this file*.
26"""
27
28__author__ = 'robinson@google.com (Will Robinson)'
29
30import datetime
31from io import BytesIO
32import math
33import struct
34import sys
35import warnings
36import weakref
37
38from google.protobuf import descriptor as descriptor_mod
39from google.protobuf import message as message_mod
40from google.protobuf import text_format
41# We use "as" to avoid name collisions with variables.
42from google.protobuf.internal import api_implementation
43from google.protobuf.internal import containers
44from google.protobuf.internal import decoder
45from google.protobuf.internal import encoder
46from google.protobuf.internal import enum_type_wrapper
47from google.protobuf.internal import extension_dict
48from google.protobuf.internal import message_listener as message_listener_mod
49from google.protobuf.internal import type_checkers
50from google.protobuf.internal import well_known_types
51from google.protobuf.internal import wire_format
52
# Shorthand for the FieldDescriptor class; its TYPE_* / CPPTYPE_* constants are
# compared against throughout this module.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Full names of well-known message types. The Struct and ListValue names
# select special handling during field initialization and assignment below.
_AnyFullTypeName = 'google.protobuf.Any'
_StructFullTypeName = 'google.protobuf.Struct'
_ListValueFullTypeName = 'google.protobuf.ListValue'
# Adaptor class returned by the `Extensions` property of extendable messages.
_ExtensionDict = extension_dict._ExtensionDict
58
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class. We
  also create properties to allow getting/setting all fields in the protocol
  message. Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime. Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = message_factory.GetMessageClass(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...

  At most one class is built per Descriptor. Once a class exists it is stored
  on the descriptor as ``_concrete_class``, and later uses of this metaclass
  for the same descriptor return that class.
  """

  # Key under which the class dictionary carries the message Descriptor.
  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Besides installing __slots__, this appends the well-known-type helper base
    (if any) to ``bases`` and exposes nested extensions as class attributes.

    Args:
      name: Name of the class being created; passed through to type.__new__.
      bases: Base classes of the class we're constructing (normally
        message.Message). If the descriptor is a well-known type, the matching
        base from well_known_types.WKTBASES is appended.
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class, or the class already created for the same
      descriptor (``descriptor._concrete_class``) if there is one.

    Raises:
      RuntimeError: if dictionary[_DESCRIPTOR_KEY] is a str rather than a
        Descriptor. This means generated code that only works with the python
        cpp extension is being loaded by the pure-Python runtime.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another. Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls message_factory.GetMessageClass() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    If ``cls`` is the already-registered class for this descriptor (returned
    by __new__), nothing is done.

    Args:
      name: Name of the class; passed through to type.__init__.
      bases: Base classes of the class we're constructing; passed through to
        type.__init__.
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    # Per-class lookup tables keyed by encoded tag bytes. _AttachFieldHelpers
    # fills _fields_by_tag; MessageSet messages also get an item decoder.
    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor),
          None,
      )

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    # Extensions already known to the descriptor's pool get the same helpers.
    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      extensions = descriptor.file.pool.FindAllExtensions(descriptor)
      for ext in extensions:
        _AttachFieldHelpers(cls, ext)

    # Registered before the helpers below run. NOTE(review): presumably so
    # that anything they trigger can already find the class through the
    # descriptor; confirm before reordering.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)
196
197
198# Stateless helpers for GeneratedProtocolMessageType below.
199# Outside clients should not access these directly.
200#
201# I opted not to make any of these methods on the metaclass, to make it more
202# clear that I'm not really using any state there and to keep clients from
203# thinking that they have direct access to these construction helpers.
204
205
def _PropertyName(proto_field_name):
  """Maps a .proto field name to the Python attribute name exposed for it.

  The mapping is currently the identity. Escaping names that collide with
  Python keywords (e.g. ``yield``) was considered and rejected. Callers rely
  on getattr()/setattr() working with the exact .proto name, and a
  keyword-named field remains reachable that way.

  Args:
    proto_field_name: The field name exactly as it appears (or would appear)
      in a .proto file.

  Returns:
    The public property name, which is ``proto_field_name`` unchanged.
  """
  # TODO: Drop this indirection entirely if the identity mapping is kept.
  attribute_name = proto_field_name
  return attribute_name
233
234
def _AddSlots(message_descriptor, dictionary):
  """Installs a ``__slots__`` list in a message class dictionary.

  Every generated message class gets the same fixed set of instance
  attributes. Restricting them this way keeps users from setting nonexistent
  fields that would never be serialized.

  Args:
    message_descriptor: Descriptor of the message type. It is not consulted,
      because the slot list is identical for every message.
    dictionary: Class dictionary that receives the ``'__slots__'`` entry.
  """
  slot_names = (
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  )
  dictionary['__slots__'] = list(slot_names)
252
253
def _IsMessageSetExtension(field):
  """Returns whether ``field`` is an extension encoded as a MessageSet item.

  That requires an optional (neither required nor repeated) message-typed
  extension whose containing type sets message_set_wire_format.
  """
  if not field.is_extension:
    return False
  container = field.containing_type
  if not (container.has_options and
          container.GetOptions().message_set_wire_format):
    return False
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          not (field.is_required or field.is_repeated))
261
262
def _IsMapField(field):
  """Returns whether ``field`` is a map field.

  A map field is message-typed, and its message type is a synthesized map
  entry.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  return field.message_type._is_map_entry
266
267
def _IsMessageMapField(field):
  """Returns whether a map field's values are messages rather than scalars.

  Args:
    field: A map field descriptor. Its message type must be a map entry with
      a ``value`` field.
  """
  entry_fields = field.message_type.fields_by_name
  value_cpp_type = entry_fields['value'].cpp_type
  return value_cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
271
def _AttachFieldHelpers(cls, field_descriptor):
  """Caches a default-value factory on a field and registers its wire tags.

  Sets ``field_descriptor._default_constructor``. The field is recorded in
  ``cls._fields_by_tag`` under the tag bytes for its natural wire type, with
  value ``(field_descriptor, False)``. Repeated fields of packable types are
  also recorded under the length-delimited tag, with value
  ``(field_descriptor, True)``.

  Args:
    cls: Message class being built; must already have ``_fields_by_tag``.
    field_descriptor: FieldDescriptor of a regular field or an extension.
  """
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  number = field_descriptor.number
  fields_by_tag = cls._fields_by_tag

  natural_wire_type = (
      type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type])
  fields_by_tag[encoder.TagBytes(number, natural_wire_type)] = (
      field_descriptor, False)

  if not field_descriptor.is_repeated:
    return
  if wire_format.IsTypePackable(field_descriptor.type):
    # Accept packed data regardless of the field's own options. This keeps
    # wire compatibility when `packed = true` is added later.
    packed_tag = encoder.TagBytes(number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    fields_by_tag[packed_tag] = (field_descriptor, True)
291
292
def _MaybeAddEncoder(cls, field_descriptor):
  """Attaches ``_encoder`` and ``_sizer`` callables to a field, once.

  Does nothing if the field already has an ``_encoder``. Map fields,
  MessageSet extensions, and all other fields each get a different
  encoder/sizer pair.

  Args:
    cls: Message class (unused; kept for a uniform helper signature).
    field_descriptor: FieldDescriptor to decorate.
  """
  if hasattr(field_descriptor, '_encoder'):
    return

  if _IsMapField(field_descriptor):
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(
        field_descriptor, _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    number = field_descriptor.number
    field_encoder = encoder.MessageSetItemEncoder(number)
    sizer = encoder.MessageSetItemSizer(number)
  else:
    shape = (field_descriptor.number, field_descriptor.is_repeated,
             field_descriptor.is_packed)
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](*shape)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](*shape)

  field_descriptor._sizer = sizer
  field_descriptor._encoder = field_encoder
315
316
def _MaybeAddDecoder(cls, field_descriptor):
  """Attaches a ``_decoders`` dict to a field descriptor, once.

  The dict maps ``is_packed`` (bool) to a decoder callable. An unpacked
  decoder is always created. Repeated fields of packable types also get a
  packed decoder, so packed data is accepted regardless of the field's
  declared options. Does nothing if ``_decoders`` is already present.

  Args:
    cls: Message class (unused; kept for a uniform helper signature).
    field_descriptor: FieldDescriptor to decorate. Its
      ``_default_constructor`` must already be set (see _AttachFieldHelpers).
  """
  if hasattr(field_descriptor, '_decoders'):
    return

  is_repeated = field_descriptor.is_repeated
  is_map_entry = _IsMapField(field_descriptor)

  # Loop-invariant: open (non-closed) enums are decoded as plain int32 values.
  decode_type = field_descriptor.type
  if (decode_type == _FieldDescriptor.TYPE_ENUM and
      not field_descriptor.enum_type.is_closed):
    decode_type = _FieldDescriptor.TYPE_INT32

  def MakeDecoder(is_packed):
    """Builds the decoder for one packing mode of this field."""
    # pylint: disable=protected-access
    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)
      return decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    if decode_type == _FieldDescriptor.TYPE_STRING:
      return decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Message decoders take no "clear if default" flag.
      return type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    return type_checkers.TYPE_TO_DECODER[decode_type](
        field_descriptor.number, is_repeated, is_packed,
        field_descriptor, field_descriptor._default_constructor,
        not field_descriptor.has_presence)

  helper_decoders = {False: MakeDecoder(False)}

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    helper_decoders[True] = MakeDecoder(True)

  field_descriptor._decoders = helper_decoders
367
368
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes every extension declared inside the message as a class attribute.

  Args:
    descriptor: Descriptor whose ``extensions_by_name`` maps names to
      extension FieldDescriptors.
    dictionary: Class dictionary being prepared. None of the extension names
      may already be present in it.
  """
  for ext_name, ext_field in descriptor.extensions_by_name.items():
    assert ext_name not in dictionary
    dictionary[ext_name] = ext_field
374
375
def _AddEnumValues(descriptor, cls):
  """Publishes the message's nested enum types on the class.

  For each nested enum, the class gets an EnumTypeWrapper under the enum's
  name. It also gets one integer attribute per enum value, mapping the value
  name to its number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_desc in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_desc)
    setattr(cls, enum_desc.name, wrapper)
    for value_desc in enum_desc.values:
      setattr(cls, value_desc.name, value_desc.number)
389
390
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for ``field``.

  The factory takes the owning message. It returns a MessageMap when map
  values are messages, and a ScalarMap otherwise. Both are created with
  ``message._listener_for_children``.

  Args:
    field: FieldDescriptor of a map field (repeated, with a map-entry
      message type).

  Raises:
    ValueError: if ``field`` is not repeated.
  """
  if not field.is_repeated:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
412
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field:
    * map fields: an empty ScalarMap / MessageMap
      (see _GetInitializeDefaultForMap);
    * repeated message fields: an empty RepeatedCompositeFieldContainer;
    * repeated scalar fields: an empty RepeatedScalarFieldContainer;
    * singular message fields: a new instance of the field's message class,
      whose listener reports changes back to ``message``;
    * singular scalar fields: ``field.default_value``.
  The default value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if a repeated field declares a non-empty default value.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.is_repeated:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set. (Depends on order in which we initialize the classes.)
      # So field.message_type is only read when a container is built.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker, field
        )
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Oneof members use a dedicated listener; others use the parent's.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
470
471
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, naming the offending field.

  Must be called from inside an ``except`` block. A plain single-argument
  ``TypeError`` is replaced by a new TypeError whose message ends with
  ``for field <message_name>.<field_name>``. Any other exception is re-raised
  as-is. The original traceback is preserved in both cases.

  Args:
    message_name: Name of the message type, used in the amended message.
    field_name: Name of the field, used in the amended message.
  """
  _, current, tb = sys.exc_info()
  should_amend = len(current.args) == 1 and type(current) is TypeError
  if should_amend:
    current = TypeError(
        '%s for field %s.%s' % (str(current), message_name, field_name))
  raise current.with_traceback(tb)
481
482
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts one keyword argument per field of
  ``message_descriptor``; ``None`` values are skipped.
    * Repeated message fields take an iterable of dicts or messages.
    * Map fields take a mapping.
    * Repeated enum fields take ints or enum label strings.
    * Singular message fields take a dict of constructor kwargs, a message,
      or (for messages defining ``_internal_assign``) other values.
    * Singular scalars go through the field property, so they are
      type-checked.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: Class being constructed for the message type.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name. If the value is not a string, it's
    returned as-is. (No conversion or bounds-checking is done.)

    Raises:
      ValueError: if ``value`` is a string that names no value of enum_type.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):

    def init_wkt_or_merge(field, msg, value):
      """Initializes sub-message ``msg`` of ``field`` from ``value``."""
      if isinstance(value, message_mod.Message):
        msg.MergeFrom(value)
      elif (
          isinstance(value, dict)
          and field.message_type.full_name == _StructFullTypeName
      ):
        msg.Clear()
        if len(value) == 1 and 'fields' in value:
          # Ambiguous: first treat it as Struct content with a key literally
          # named 'fields'. If that fails, retry it as constructor kwargs for
          # the Struct message itself. `Exception` (not a bare except) so
          # KeyboardInterrupt/SystemExit are not swallowed.
          try:
            msg.update(value)
          except Exception:  # pylint: disable=broad-except
            msg.Clear()
            msg.__init__(**value)
        else:
          msg.update(value)
      elif hasattr(msg, '_internal_assign'):
        msg._internal_assign(value)
      else:
        raise TypeError(
            'Message field {0}.{1} must be initialized with a '
            'dict or instance of same class, got {2}.'.format(
                message_descriptor.name,
                field.name,
                type(value).__name__,
            )
        )

    self._cached_byte_size = 0
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # Defensive only: _GetFieldByName, as currently written, raises
      # ValueError for unknown names rather than returning None.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.is_repeated:
        field_copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                item_value = field_value[key]
                if isinstance(item_value, dict):
                  field_copy[key].__init__(**item_value)
                else:
                  field_copy[key].MergeFrom(item_value)
            else:
              field_copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict) and (
                  field.message_type.full_name != _StructFullTypeName
              ):
                field_copy.add(**val)
              else:
                new_msg = field_copy.add()
                init_wkt_or_merge(field, new_msg, val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          field_copy.extend(field_value)
        self._fields[field] = field_copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        field_copy = field._default_constructor(self)
        if isinstance(field_value, dict) and (
            field.message_type.full_name != _StructFullTypeName
        ):
          new_val = field.message_type._concrete_class(**field_value)
          field_copy.MergeFrom(new_val)
        else:
          try:
            init_wkt_or_merge(field, field_copy, field_value)
          except TypeError:
            _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = field_copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
604
605
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by its .proto name.

  Args:
    message_descriptor: Descriptor of the message that declares the field.
    field_name: Name of the field to look up.

  Returns:
    The FieldDescriptor registered under ``field_name``.

  Raises:
    ValueError: if the message has no field with that name.
  """
  by_name = message_descriptor.fields_by_name
  try:
    found = by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found
620
621
def _AddPropertiesForFields(descriptor, cls):
  """Adds one property per field, plus ``Extensions`` if the message is
  extendable.

  Args:
    descriptor: Descriptor of the message type.
    cls: Class being constructed for the message type.
  """
  for field_desc in descriptor.fields:
    _AddPropertiesForField(field_desc, cls)

  if not descriptor.is_extendable:
    return

  def _Extensions(self):
    # _ExtensionDict is just an adaptor with no state, so a new one is
    # allocated on every access.
    return _ExtensionDict(self)

  cls.Extensions = property(_Extensions)
631
632
def _AddPropertiesForField(field, cls):
  """Adds the ``<NAME>_FIELD_NUMBER`` constant and the property for one field.

  Repeated fields, singular message fields and singular scalar fields are
  handed to their dedicated helpers. Only singular scalars can be assigned
  directly through the resulting property.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Trips if new C++ types are added that may need special handling here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.is_repeated:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
656
657
class _FieldProperty(property):
  # A property that also records the FieldDescriptor it exposes, as its
  # DESCRIPTOR attribute. This is documented with comments only, not a
  # docstring: instances have no __dict__, so a class docstring could change
  # the __doc__ they report.
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    super().__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
664
665
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a read-only property for a repeated (or map) field.

  Reading the property returns the field's container, built by the field's
  default constructor and stored on first access. Assigning to the property
  raises AttributeError; callers mutate the container instead. Type checking
  for repeated scalars, and the "has" bits, are handled by the containers.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    # Build an empty container and publish it with setdefault(). If another
    # thread stored one first, that one wins and ours is discarded. This relies
    # on dict.setdefault() being atomic, which holds in CPython; other
    # interpreters have not been investigated.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to produce a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
708
709
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a read/write property for a singular scalar field.

  The getter returns the stored value, or ``field.default_value`` when unset.
  The setter validates the value with the field's type checker, which may
  also convert it. For fields without presence, assigning the default scalar
  value removes the field instead of storing it. Setting a oneof member also
  updates the oneof state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    if field.has_presence or not decoder.IsDefaultScalarValue(checked):
      self._fields[field] = checked
    else:
      # Implicit-presence fields never store their default value.
      self._fields.pop(field, None)
    # The dirty flag is checked inline rather than always calling _Modified(),
    # because scalar setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
764
765
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a property for a singular composite ("group" or "message") field.

  Reading the property returns the sub-message, which is built and stored on
  first access. Direct assignment is supported only for these well-known
  types:
    * Timestamp, via FromDatetime(value);
    * Duration, via FromTimedelta(value);
    * Struct, via Clear() then update(value);
    * ListValue, via Clear() then extend(value).
  Assigning to any other message field raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is None:
      # Build the sub-message and publish it with setdefault(). If another
      # thread stored one first, that one is kept. This relies on
      # dict.setdefault() being atomic, which holds in CPython.
      sub_message = self._fields.setdefault(
          field, field._default_constructor(self))
    return sub_message
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    type_name = field.message_type.full_name
    if type_name == 'google.protobuf.Timestamp':
      getter(self).FromDatetime(new_value)
    elif type_name == 'google.protobuf.Duration':
      getter(self).FromTimedelta(new_value)
    elif type_name == _StructFullTypeName:
      struct_value = getter(self)
      struct_value.Clear()
      struct_value.update(new_value)
    elif type_name == _ListValueFullTypeName:
      list_value = getter(self)
      list_value.Clear()
      list_value.extend(new_value)
    else:
      raise AttributeError(
          'Assignment not allowed to composite field '
          '"%s" in protocol message object.' % proto_field_name
      )

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
825
826
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds ``<NAME>_FIELD_NUMBER`` constants for the message's nested extensions.

  For every extension declared inside ``descriptor``, the class gets an
  integer attribute named after the upper-cased extension name plus
  ``_FIELD_NUMBER``.

  Args:
    descriptor: Descriptor of the message type.
    cls: Class being constructed for the message type.
  """
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)
  # A former `pool = descriptor.file.pool` lookup here was dead code: its
  # result was never used. It also raised AttributeError for files without a
  # `pool` attribute, which the metaclass explicitly tolerates via hasattr().
  # So it was removed.
839
def _AddStaticMethods(cls):
  """Adds the ``RegisterExtension`` and ``FromString`` static methods to cls.

  ``RegisterExtension`` does nothing; it is kept for older generated code.
  ``FromString(s)`` creates a new instance of ``cls`` and merges the
  serialized data ``s`` into it via MergeFromString.
  """
  def RegisterExtension(_):
    """no-op to keep generated code <=4.23 working with new runtimes."""
    # This was originally removed in 5.26 (cl/595989309).
    pass

  def FromString(s):
    message = cls()
    message.MergeFromString(s)
    return message

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
853
854
def _IsPresent(item):
  """Decides whether a `_fields` entry belongs in the ListFields() output.

  Args:
    item: A (FieldDescriptor, value) pair taken from a message's `_fields`.

  Returns:
    For repeated fields, whether the container is non-empty; for singular
    message fields, the sub-message's `_is_present_in_parent` flag; True for
    any other (singular scalar) field.
  """
  field, value = item
  if field.is_repeated:
    return bool(value)
  if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
    return True
  return value._is_present_in_parent
865
866
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ListFields."""

  def _FieldNumber(entry):
    # Sort key: the field number of a (descriptor, value) pair.
    return entry[0].number

  def ListFields(self):
    # Report only entries that count as set, ordered by field number.
    return sorted(
        (entry for entry in self._fields.items() if _IsPresent(entry)),
        key=_FieldNumber)

  cls.ListFields = ListFields
876
877
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.HasField.

  HasField accepts the name of a singular field that tracks presence, or the
  name of a oneof (true when the oneof's recorded member is itself set).
  """

  # Queryable names -> field or oneof descriptor. Repeated fields and fields
  # without presence (for proto3, only submessages and oneof members have
  # presence) are excluded.
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if not field.is_repeated and field.has_presence
  }
  # Has methods are supported for oneof descriptors.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # Delegate to the member currently recorded for this oneof, if any.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    # A sub-message may sit in _fields without having been set in the parent.
    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
915
916
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ClearField.

  ClearField accepts either a field name or a oneof name. Clearing a oneof
  clears whichever of its members is currently recorded as active, and is a
  silent no-op when none is.

  Raises (from ClearField):
    ValueError: If `field_name` names neither a field nor a oneof. The
      underlying KeyError is chained as the cause, matching HasField().
  """
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError as exc:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name)) from exc
      field = self._oneofs.get(oneof)
      if field is None:
        # The oneof exists but none of its members is set: nothing to clear.
        return

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us. That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size. Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
953
954
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.ClearExtension."""

  def ClearExtension(self, field_descriptor):
    # Validate the handle against this message (see extension_dict).
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    # Mirrors ClearField(): drop any stored value, then always mark the
    # message as modified.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
965
966
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.HasExtension.

  HasExtension raises KeyError for repeated extensions, which have no
  presence.
  """

  def HasExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.is_repeated:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields
    # A sub-message may be stored without having been set in the parent.
    stored = self._fields.get(field_descriptor)
    return stored is not None and stored._is_present_in_parent

  cls.HasExtension = HasExtension
980
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when `msg.type_url` is empty or names a type
    that the default descriptor pool cannot resolve.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO: For now we just strip the hostname. Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The descriptor pool raises KeyError (rather than returning None) for an
    # unknown type. Report "cannot unpack" so callers such as __eq__ fall back
    # to comparing the packed fields instead of propagating the error.
    return None

  if descriptor is None:
    return None

  # Unable to import message_factory at top because of circular import.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import message_factory
  message_class = message_factory.GetMessageClass(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
1022
1023
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__eq__."""

  def __eq__(self, other):
    type_name = self.DESCRIPTOR.full_name

    # ListValue and Struct also compare against plain Python lists / dicts.
    if type_name == _ListValueFullTypeName and isinstance(other, list):
      return self._internal_compare(other)
    if type_name == _StructFullTypeName and isinstance(other, dict):
      return self._internal_compare(other)

    # Only messages with the same descriptor are compared here.
    if not isinstance(other, message_mod.Message):
      return NotImplemented
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return NotImplemented

    if self is other:
      return True

    if type_name == _AnyFullTypeName:
      # Compare Any payloads after unpacking when both sides resolve;
      # otherwise fall through to the raw field comparison below.
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    same_fields = self.ListFields() == other.ListFields()
    if not same_fields:
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1061
1062
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__str__."""

  def __str__(self):
    # str(msg) is the text-format rendering of the message.
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
1068
1069
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__repr__."""

  def __repr__(self):
    # repr(msg) is the text-format rendering of the message.
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
1075
1076
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs the legacy cls.__unicode__.

  __unicode__ is a Python 2 hook; it is kept so explicit callers still get the
  UTF-8 text-format rendering of the message as text.
  """

  def __unicode__(self):
    text = text_format.MessageToString(self, as_utf8=True)
    # text_format.MessageToString returns str on Python 3, where calling
    # .decode() unconditionally raised AttributeError; only decode bytes.
    if isinstance(text, bytes):
      text = text.decode('utf-8')
    return text

  cls.__unicode__ = __unicode__
1083
1084
def _AddContainsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__contains__.

  Struct tests membership against the keys of its `fields`, ListValue
  against its items(), and any other message treats `name in msg` as
  msg.HasField(name).
  """
  type_name = message_descriptor.full_name

  if type_name == 'google.protobuf.Struct':
    def __contains__(self, key):
      # Membership in a Struct means "has this key".
      return key in self.fields
  elif type_name == 'google.protobuf.ListValue':
    def __contains__(self, value):
      # Membership in a ListValue means "one of the elements equals value".
      return value in self.items()
  else:
    def __contains__(self, field):
      # For regular messages `in` is a presence check.
      return self.HasField(field)

  cls.__contains__ = __contains__
1098
1099
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The byte count computed by the size function registered for field_type.

  Raises:
    message_mod.EncodeError: If no size function is registered for
      field_type.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  # Called outside the try block so that a KeyError raised while sizing the
  # value is not misreported as an unrecognized field type.
  return fn(field_number, value)
1118
1119
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs the caching cls.ByteSize."""

  def ByteSize(self):
    # Serve the cached value until something marks the message dirty.
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    total = 0
    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized, so both key and
      # value contribute whether or not they are set.
      for entry_name in ('key', 'value'):
        entry_field = descriptor.fields_by_name[entry_name]
        _MaybeAddEncoder(cls, entry_field)
        total += entry_field._sizer(getattr(self, entry_name))
    else:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        total += field_descriptor._sizer(field_value)
      # Unknown fields are stored as raw (tag, payload) byte strings.
      total += sum(len(tag_bytes) + len(value_bytes)
                   for tag_bytes, value_bytes in self._unknown_fields)

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
1150
1151
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.SerializeToString.

  SerializeToString forwards its keyword arguments to
  SerializePartialToString, but first raises message_mod.EncodeError listing
  the missing required fields if the message is not initialized.
  """

  def SerializeToString(self, **kwargs):
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name,
            ','.join(self.FindInitializationErrors())))

  cls.SerializeToString = SerializeToString
1163
1164
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs the serialization methods.

  Adds SerializePartialToString, which collects the output of
  _InternalSerialize into bytes, and _InternalSerialize itself, which streams
  the encoding through a caller-supplied `write_bytes` callback.
  """

  def SerializePartialToString(self, **kwargs):
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes, deterministic=None):
    # None means "use the process-wide default"; anything else is coerced.
    deterministic = (
        api_implementation.IsPythonDefaultSerializationDeterministic()
        if deterministic is None else bool(deterministic))

    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized, set or not.
      for entry_name in ('key', 'value'):
        entry_field = descriptor.fields_by_name[entry_name]
        _MaybeAddEncoder(cls, entry_field)
        entry_field._encoder(write_bytes, getattr(self, entry_name),
                             deterministic)
      return

    for field_descriptor, field_value in self.ListFields():
      _MaybeAddEncoder(cls, field_descriptor)
      field_descriptor._encoder(write_bytes, field_value, deterministic)
    # Unknown fields are re-emitted verbatim: tag bytes, then payload bytes.
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)

  cls._InternalSerialize = InternalSerialize
1198
1199
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs MergeFromString/_InternalParse.

  Reads cls._fields_by_tag and cls._message_set_decoders_by_tag once, when
  this helper runs, and closes over them in the installed parser.
  """
  def MergeFromString(self, serialized):
    """Merges the wire-format bytes in `serialized` into this message.

    Low-level parse failures (IndexError, TypeError, struct.error) and an
    early stop of _InternalParse are re-raised as message_mod.DecodeError.

    Returns:
      The length of `serialized` (kept for legacy callers).
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Tag-bytes -> (decoder / descriptor) lookup tables used by InternalParse.
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end, current_depth=0):
    """Parses the serialized bytes in buffer[pos:end] into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.
      current_depth: int, forwarded unchanged to every field decoder;
        presumably the nesting depth used for recursion limits (TODO confirm
        against decoder.py).

    Returns:
      int, the position where parsing stopped. This equals `end` when the
      whole range was consumed. It is earlier when decoding an unknown field
      reports -1 (MergeFromString treats that as an end-group tag), in which
      case the position of that tag is returned.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and it's easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    while pos != end:
      (tag_bytes, new_pos) = decoder.ReadTag(buffer, pos)
      # Tags registered in message_set_decoders_by_tag take priority over the
      # regular per-field table.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(
            buffer, new_pos, end, self, field_dict, current_depth
        )
        continue
      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is None:
        # Unknown tag: keep its raw tag and payload bytes so they can be
        # written back out by _InternalSerialize.
        if not self._unknown_fields:  # pylint: disable=protected-access
          self._unknown_fields = []  # pylint: disable=protected-access
        field_number, wire_type = decoder.DecodeTag(tag_bytes)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, end, field_number, wire_type
        )  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[pos + len(tag_bytes) : new_pos].tobytes())
        )
        pos = new_pos
      else:
        # _MaybeAddDecoder presumably attaches field_des._decoders on first
        # use; the decoder is then selected by the is_packed flag.
        _MaybeAddDecoder(cls, field_des)
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(
            buffer, new_pos, end, self, field_dict, current_depth
        )
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
    return pos

  cls._InternalParse = InternalParse
1275
1276
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class."""

  # The set of required fields is fixed per message type; collect it once.
  required_fields = [
      field for field in message_descriptor.fields if field.is_required
  ]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().
    message_cpptype = _FieldDescriptor.CPPTYPE_MESSAGE
    fields = self._fields

    for field in required_fields:
      missing = field not in fields or (
          field.cpp_type == message_cpptype and
          not fields[field]._is_present_in_parent)
      if missing:
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Iterate over a snapshot: the dict can change size while children are
    # being checked.
    for field, value in list(fields.items()):
      if field.cpp_type != message_cpptype:
        continue
      if not field.is_repeated:
        child_ok = not value._is_present_in_parent or value.IsInitialized()
      elif field.message_type._is_map_entry:
        # Map fields are not recursed into here.
        continue
      else:
        child_ok = all(element.IsInitialized() for element in value)
      if not child_ok:
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings. Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = [field.name for field in required_fields
              if not self.HasField(field.name)]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      # Extensions are reported in parenthesized full-name form.
      name = '(%s)' % field.full_name if field.is_extension else field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors; only message-valued
        # maps are inspected.
        if _IsMessageMapField(field):
          for key in value:
            prefix = '%s[%s].' % (name, key)
            errors.extend(
                prefix + error
                for error in value[key].FindInitializationErrors())
      elif field.is_repeated:
        for index, element in enumerate(value):
          prefix = '%s[%d].' % (name, index)
          errors.extend(
              prefix + error for error in element.FindInitializationErrors())
      else:
        errors.extend(
            name + '.' + error for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1369
1370
def _FullyQualifiedClassName(klass):
  """Returns klass's qualified name, prefixed by its module unless builtin.

  Classes whose module is None, 'builtins' or '__builtin__' are reported by
  their bare (qualified) name.
  """
  module = klass.__module__
  name = getattr(klass, '__qualname__', klass.__name__)
  is_builtin = module in (None, 'builtins', '__builtin__')
  return name if is_builtin else module + '.' + name
1377
1378
def _AddMergeFromMethod(cls):
  """Installs cls.MergeFrom, which merges another instance of cls into self.

  Singular scalars are overwritten, and singular sub-messages that are set
  are merged recursively. Repeated fields are merged into (possibly newly
  created) containers, and unknown fields are appended.
  """
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if not field.is_repeated and field.cpp_type != CPPTYPE_MESSAGE:
        # Singular scalar: the incoming value simply replaces ours.
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)
        continue
      if not field.is_repeated and not value._is_present_in_parent:
        # Singular sub-message that was never set: nothing to merge.
        continue
      target = fields.get(field)
      if target is None:
        # Construct a new object to represent this field.
        target = field._default_constructor(self)
        fields[field] = target
      target.MergeFrom(value)

    incoming_unknown = msg._unknown_fields
    if incoming_unknown:
      if not self._unknown_fields:
        # An empty value here may be an immutable tuple; start a fresh list.
        self._unknown_fields = []
      self._unknown_fields.extend(incoming_unknown)

  cls.MergeFrom = MergeFrom
1421
1422
def _AddWhichOneofMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.WhichOneof."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active = self._oneofs.get(oneof, None)
    # The recorded member must also still report presence.
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1439
1440
def _Clear(self):
  """Resets fields, unknown fields and oneof state, then marks self modified."""
  self._oneofs = {}
  self._unknown_fields = ()
  self._fields = {}
  self._Modified()
1448
1449
def _UnknownFields(self):
  """Unsupported accessor: always raises NotImplementedError.

  The error message points callers at unknown_fields.UnknownFieldSet.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError('Please use the add-on feature '
                            'unknown_fields.UnknownFieldSet(message) in '
                            'unknown_fields.py instead.')
1454
1455
def _DiscardUnknownFields(self):
  """Drops unknown fields from self and, recursively, from set sub-messages."""
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue  # Scalars carry no nested unknown fields.
    if not _IsMapField(field):
      # Repeated message fields are iterated; a singular one is visited alone.
      children = value if field.is_repeated else (value,)
    elif _IsMessageMapField(field):
      children = (value[key] for key in value)
    else:
      continue  # Scalar-valued maps hold no sub-messages.
    for child in children:
      child.DiscardUnknownFields()
1469
1470
def _SetListener(self, listener):
  """Installs `listener` on self, substituting a NullMessageListener for None."""
  self._listener = (listener if listener is not None
                    else message_listener_mod.NullMessageListener())
1476
1477
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  # Extension accessors only make sense on extendable messages.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddContainsMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # These methods do not depend on cls, so module-level functions are shared.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1502
1503
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls.

  Installs _Modified (also exposed as SetInParent) and _UpdateOneofState.
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note: Some callers check _cached_byte_size_dirty before calling
    # _Modified() as an extra optimization. So, if this method is ever
    # changed such that it does stuff even when _cached_byte_size_dirty is
    # already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is field:
      return
    del self._fields[previous]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1536
1537
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The child keeps only a weak back-reference to its parent, so a parent
    # tree the client has dropped does not turn into cyclic garbage. An
    # already-proxied parent is stored as-is instead of being re-wrapped.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message

    # Cached "parent is already dirty" flag: lets Modified() skip walking up
    # the tree in the common case.
    self.dirty = False

  def Modified(self):
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # A client kept a reference to a child whose parent has since been
        # garbage-collected. This is not an error.
        pass
1582
1583
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # The oneof bookkeeping runs on every call, even when the parent is
      # already dirty; only the propagation below is short-circuited.
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # The parent message has been garbage-collected; nothing to update.
      pass