1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
# This code is meant to work on Python 3 only (it relies on `raise ... from`
# and `BaseException.with_traceback`).
9#
10# TODO: Helpers for verbose, common checks like seeing if a
11# descriptor's cpp_type is CPPTYPE_MESSAGE.
12
13"""Contains a metaclass and helper functions used to create
14protocol message classes from Descriptor objects at runtime.
15
16Recall that a metaclass is the "type" of a class.
17(A class is to a metaclass what an instance is to a class.)
18
19In this case, we use the GeneratedProtocolMessageType metaclass
20to inject all the useful functionality into the classes
21output by the protocol compiler at compile-time.
22
23The upshot of all this is that the real implementation
24details for ALL pure-Python protocol buffers are *here in
25this file*.
26"""
27
28__author__ = 'robinson@google.com (Will Robinson)'
29
30import datetime
31from io import BytesIO
32import math
33import struct
34import sys
35import warnings
36import weakref
37
38from google.protobuf import descriptor as descriptor_mod
39from google.protobuf import message as message_mod
40from google.protobuf import text_format
41# We use "as" to avoid name collisions with variables.
42from google.protobuf.internal import api_implementation
43from google.protobuf.internal import containers
44from google.protobuf.internal import decoder
45from google.protobuf.internal import encoder
46from google.protobuf.internal import enum_type_wrapper
47from google.protobuf.internal import extension_dict
48from google.protobuf.internal import message_listener as message_listener_mod
49from google.protobuf.internal import type_checkers
50from google.protobuf.internal import well_known_types
51from google.protobuf.internal import wire_format
52
# Short alias for the FieldDescriptor class; the helpers below read its
# TYPE_* / CPPTYPE_* constants through this name.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Full proto names of well-known message types.  Struct and ListValue are
# compared against `field.message_type.full_name` further down in this file.
_AnyFullTypeName = 'google.protobuf.Any'
_StructFullTypeName = 'google.protobuf.Struct'
_ListValueFullTypeName = 'google.protobuf.ListValue'
# Alias for the adaptor returned by the `Extensions` property of extendable
# messages.
_ExtensionDict = extension_dict._ExtensionDict
58
class GeneratedProtocolMessageType(type):
  """Metaclass that builds concrete protocol message classes from Descriptors.

  Classes created through this metaclass receive implementations of the
  Message methods and one property per field of the protocol message.  A
  fixed __slots__ list is installed as well, so that assigning to a name that
  is not a field fails instead of silently creating an attribute that would
  never be serialized or deserialized.

  The protocol compiler uses this metaclass to create message classes at
  runtime.  Clients can also create their own classes at runtime:

    mydescriptor = Descriptor(.....)
    factory = symbol_database.Default()
    factory.pool.AddDescriptor(mydescriptor)
    MyProtoClass = message_factory.GetMessageClass(mydescriptor)
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the class object for a runtime-generated message type.

    __new__ is overridden because __slots__ has to be present in the class
    dictionary before the class object is created.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed.  The matching
        entry of well_known_types.WKTBASES is appended when the descriptor's
        full name is listed there.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.

    Returns:
      The class already recorded on the descriptor as `_concrete_class` if
      there is one, otherwise the newly-allocated class.

    Raises:
      RuntimeError: The descriptor entry is a str, i.e. the generated code
        only works with the python cpp extension.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # Reuse the concrete class if the descriptor already has one.  Creating a
    # second class would break messages that were built with the first one.
    # This most commonly happens in `text_format.py` when using descriptors
    # from a custom pool: it calls message_factory.GetMessageClass() on a
    # descriptor which already has an existing concrete class.
    already_built = getattr(descriptor, '_concrete_class', None)
    if already_built:
      return already_built

    full_name = descriptor.full_name
    if full_name in well_known_types.WKTBASES:
      bases = bases + (well_known_types.WKTBASES[full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly created class with its message behaviour.

    Adds enum values, an __init__ method, the Message method
    implementations and a property for every field of the protocol type.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # A class that __new__ found through `_concrete_class` has been fully
    # initialized before; there is nothing left to do for it.
    already_built = getattr(descriptor, '_concrete_class', None)
    if already_built:
      assert already_built is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    uses_message_set_format = (
        descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format)
    if uses_message_set_format:
      item_decoder = decoder.MessageSetItemDecoder(descriptor)
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          item_decoder,
          None,
      )

    # Attach lookup helpers to every regular field ...
    for regular_field in descriptor.fields:
      _AttachFieldHelpers(cls, regular_field)

    # ... and to every extension the file's pool knows for this message.
    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      for extension in descriptor.file.pool.FindAllExtensions(descriptor):
        _AttachFieldHelpers(cls, extension)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super().__init__(name, bases, dictionary)
196
197
198# Stateless helpers for GeneratedProtocolMessageType below.
199# Outside clients should not access these directly.
200#
201# I opted not to make any of these methods on the metaclass, to make it more
202# clear that I'm not really using any state there and to keep clients from
203# thinking that they have direct access to these construction helpers.
204
205
206def _PropertyName(proto_field_name):
207 """Returns the name of the public property attribute which
208 clients can use to get and (in some cases) set the value
209 of a protocol message field.
210
211 Args:
212 proto_field_name: The protocol message field name, exactly
213 as it appears (or would appear) in a .proto file.
214 """
215 # TODO: Escape Python keywords (e.g., yield), and test this support.
216 # nnorwitz makes my day by writing:
217 # """
218 # FYI. See the keyword module in the stdlib. This could be as simple as:
219 #
220 # if keyword.iskeyword(proto_field_name):
221 # return proto_field_name + "_"
222 # return proto_field_name
223 # """
224 # Kenton says: The above is a BAD IDEA. People rely on being able to use
225 # getattr() and setattr() to reflectively manipulate field values. If we
226 # rename the properties, then every such user has to also make sure to apply
227 # the same transformation. Note that currently if you name a field "yield",
228 # you can still access it just fine using getattr/setattr -- it's not even
229 # that cumbersome to do so.
230 # TODO: Remove this method entirely if/when everyone agrees with my
231 # position.
232 return proto_field_name
233
234
235def _AddSlots(message_descriptor, dictionary):
236 """Adds a __slots__ entry to dictionary, containing the names of all valid
237 attributes for this message type.
238
239 Args:
240 message_descriptor: A Descriptor instance describing this message type.
241 dictionary: Class dictionary to which we'll add a '__slots__' entry.
242 """
243 dictionary['__slots__'] = ['_cached_byte_size',
244 '_cached_byte_size_dirty',
245 '_fields',
246 '_unknown_fields',
247 '_is_present_in_parent',
248 '_listener',
249 '_listener_for_children',
250 '__weakref__',
251 '_oneofs']
252
253
254def _IsMessageSetExtension(field):
255 return (field.is_extension and
256 field.containing_type.has_options and
257 field.containing_type.GetOptions().message_set_wire_format and
258 field.type == _FieldDescriptor.TYPE_MESSAGE and
259 not field.is_required and
260 not field.is_repeated)
261
262
def _IsMapField(field):
  """Tells whether field is a message-typed field whose type is a map entry."""
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  return field.message_type._is_map_entry
266
267
def _IsMessageMapField(field):
  """Tells whether the 'value' field of a map field's entry type is a message.

  Args:
    field: A FieldDescriptor whose message_type has a field named 'value'.
  """
  entry_fields = field.message_type.fields_by_name
  return entry_fields['value'].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
271
def _AttachFieldHelpers(cls, field_descriptor):
  """Prepares a field for fast lookup while parsing and constructing.

  Stores the field's default-value constructor on the descriptor and
  registers the field in cls._fields_by_tag under each wire tag it may
  arrive with.

  Args:
    cls: The message class being constructed; must already have a
      `_fields_by_tag` dict.
    field_descriptor: FieldDescriptor of the field (or extension).
  """
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor
  )

  def register_tag(wire_type, packed):
    # Key: encoded tag bytes.  Value: (descriptor, whether the tag is packed).
    key = encoder.TagBytes(field_descriptor.number, wire_type)
    cls._fields_by_tag[key] = (field_descriptor, packed)

  natural_wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
      field_descriptor.type
  ]
  register_tag(natural_wire_type, False)

  packable = field_descriptor.is_repeated and wire_format.IsTypePackable(
      field_descriptor.type
  )
  if packable:
    # To support wire compatibility of adding packed = true, the packed
    # encoding is registered regardless of the field's options.
    register_tag(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
291
292
293def _MaybeAddEncoder(cls, field_descriptor):
294 if hasattr(field_descriptor, '_encoder'):
295 return
296 is_repeated = field_descriptor.is_repeated
297 is_map_entry = _IsMapField(field_descriptor)
298 is_packed = field_descriptor.is_packed
299
300 if is_map_entry:
301 field_encoder = encoder.MapEncoder(field_descriptor)
302 sizer = encoder.MapSizer(field_descriptor,
303 _IsMessageMapField(field_descriptor))
304 elif _IsMessageSetExtension(field_descriptor):
305 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
306 sizer = encoder.MessageSetItemSizer(field_descriptor.number)
307 else:
308 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
309 field_descriptor.number, is_repeated, is_packed)
310 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
311 field_descriptor.number, is_repeated, is_packed)
312
313 field_descriptor._sizer = sizer
314 field_descriptor._encoder = field_encoder
315
316
317def _MaybeAddDecoder(cls, field_descriptor):
318 if hasattr(field_descriptor, '_decoders'):
319 return
320
321 is_repeated = field_descriptor.is_repeated
322 is_map_entry = _IsMapField(field_descriptor)
323 helper_decoders = {}
324
325 def AddDecoder(is_packed):
326 decode_type = field_descriptor.type
327 if (decode_type == _FieldDescriptor.TYPE_ENUM and
328 not field_descriptor.enum_type.is_closed):
329 decode_type = _FieldDescriptor.TYPE_INT32
330
331 oneof_descriptor = None
332 if field_descriptor.containing_oneof is not None:
333 oneof_descriptor = field_descriptor
334
335 if is_map_entry:
336 is_message_map = _IsMessageMapField(field_descriptor)
337
338 field_decoder = decoder.MapDecoder(
339 field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
340 is_message_map)
341 elif decode_type == _FieldDescriptor.TYPE_STRING:
342 field_decoder = decoder.StringDecoder(
343 field_descriptor.number, is_repeated, is_packed,
344 field_descriptor, field_descriptor._default_constructor,
345 not field_descriptor.has_presence)
346 elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
347 field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
348 field_descriptor.number, is_repeated, is_packed,
349 field_descriptor, field_descriptor._default_constructor)
350 else:
351 field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
352 field_descriptor.number, is_repeated, is_packed,
353 # pylint: disable=protected-access
354 field_descriptor, field_descriptor._default_constructor,
355 not field_descriptor.has_presence)
356
357 helper_decoders[is_packed] = field_decoder
358
359 AddDecoder(False)
360
361 if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
362 # To support wire compatibility of adding packed = true, add a decoder for
363 # packed values regardless of the field's options.
364 AddDecoder(True)
365
366 field_descriptor._decoders = helper_decoders
367
368
369def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
370 extensions = descriptor.extensions_by_name
371 for extension_name, extension_field in extensions.items():
372 assert extension_name not in dictionary
373 dictionary[extension_name] = extension_field
374
375
376def _AddEnumValues(descriptor, cls):
377 """Sets class-level attributes for all enum fields defined in this message.
378
379 Also exporting a class-level object that can name enum values.
380
381 Args:
382 descriptor: Descriptor object for this message type.
383 cls: Class we're constructing for this message type.
384 """
385 for enum_type in descriptor.enum_types:
386 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
387 for enum_value in enum_type.values:
388 setattr(cls, enum_value.name, enum_value.number)
389
390
391def _GetInitializeDefaultForMap(field):
392 if not field.is_repeated:
393 raise ValueError('map_entry set on non-repeated field %s' % (
394 field.name))
395 fields_by_name = field.message_type.fields_by_name
396 key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
397
398 value_field = fields_by_name['value']
399 if _IsMessageMapField(field):
400 def MakeMessageMapDefault(message):
401 return containers.MessageMap(
402 message._listener_for_children, value_field.message_type, key_checker,
403 field.message_type)
404 return MakeMessageMapDefault
405 else:
406 value_checker = type_checkers.GetTypeChecker(value_field)
407 def MakePrimitiveMapDefault(message):
408 return containers.ScalarMap(
409 message._listener_for_children, key_checker, value_checker,
410 field.message_type)
411 return MakePrimitiveMapDefault
412
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function of one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field: a map
    container, a repeated container, a sub-message wired to the parent's
    listener, or the scalar `field.default_value`.

  Raises:
    ValueError: A repeated field declares a default value other than the
      empty list (or a map field is not repeated, see
      _GetInitializeDefaultForMap).
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.is_repeated:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # The container receives the message *descriptor*, not the concrete
      # class: _concrete_class might not have been set yet (it depends on the
      # order in which the classes are initialized).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized; asking the message
      # factory for the class creates it.  The import is local to this
      # function rather than at module level.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Members of a oneof get a dedicated listener; every other sub-message
      # shares the parent's listener for children.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
469
470
471def _ReraiseTypeErrorWithFieldName(message_name, field_name):
472 """Re-raise the currently-handled TypeError with the field name added."""
473 exc = sys.exc_info()[1]
474 if len(exc.args) == 1 and type(exc) is TypeError:
475 # simple TypeError; add field name to exception message
476 exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
477
478 # re-raise possibly-amended exception with original traceback:
479 raise exc.with_traceback(sys.exc_info()[2])
480
481
482def _AddInitMethod(message_descriptor, cls):
483 """Adds an __init__ method to cls."""
484
485 def _GetIntegerEnumValue(enum_type, value):
486 """Convert a string or integer enum value to an integer.
487
488 If the value is a string, it is converted to the enum value in
489 enum_type with the same name. If the value is not a string, it's
490 returned as-is. (No conversion or bounds-checking is done.)
491 """
492 if isinstance(value, str):
493 try:
494 return enum_type.values_by_name[value].number
495 except KeyError:
496 raise ValueError('Enum type %s: unknown label "%s"' % (
497 enum_type.full_name, value))
498 return value
499
500 def init(self, **kwargs):
501
502 def init_wkt_or_merge(field, msg, value):
503 if isinstance(value, message_mod.Message):
504 msg.MergeFrom(value)
505 elif (
506 isinstance(value, dict)
507 and field.message_type.full_name == _StructFullTypeName
508 ):
509 msg.Clear()
510 if len(value) == 1 and 'fields' in value:
511 try:
512 msg.update(value)
513 except:
514 msg.Clear()
515 msg.__init__(**value)
516 else:
517 msg.update(value)
518 elif hasattr(msg, '_internal_assign'):
519 msg._internal_assign(value)
520 else:
521 raise TypeError(
522 'Message field {0}.{1} must be initialized with a '
523 'dict or instance of same class, got {2}.'.format(
524 message_descriptor.name,
525 field.name,
526 type(value).__name__,
527 )
528 )
529
530 self._cached_byte_size = 0
531 self._cached_byte_size_dirty = len(kwargs) > 0
532 self._fields = {}
533 # Contains a mapping from oneof field descriptors to the descriptor
534 # of the currently set field in that oneof field.
535 self._oneofs = {}
536
537 # _unknown_fields is () when empty for efficiency, and will be turned into
538 # a list if fields are added.
539 self._unknown_fields = ()
540 self._is_present_in_parent = False
541 self._listener = message_listener_mod.NullMessageListener()
542 self._listener_for_children = _Listener(self)
543 for field_name, field_value in kwargs.items():
544 field = _GetFieldByName(message_descriptor, field_name)
545 if field is None:
546 raise TypeError('%s() got an unexpected keyword argument "%s"' %
547 (message_descriptor.name, field_name))
548 if field_value is None:
549 # field=None is the same as no field at all.
550 continue
551 if field.is_repeated:
552 field_copy = field._default_constructor(self)
553 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
554 if _IsMapField(field):
555 if _IsMessageMapField(field):
556 for key in field_value:
557 item_value = field_value[key]
558 if isinstance(item_value, dict):
559 field_copy[key].__init__(**item_value)
560 else:
561 field_copy[key].MergeFrom(item_value)
562 else:
563 field_copy.update(field_value)
564 else:
565 for val in field_value:
566 if isinstance(val, dict) and (
567 field.message_type.full_name != _StructFullTypeName
568 ):
569 field_copy.add(**val)
570 else:
571 new_msg = field_copy.add()
572 init_wkt_or_merge(field, new_msg, val)
573 else: # Scalar
574 if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
575 field_value = [_GetIntegerEnumValue(field.enum_type, val)
576 for val in field_value]
577 field_copy.extend(field_value)
578 self._fields[field] = field_copy
579 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
580 field_copy = field._default_constructor(self)
581 if isinstance(field_value, dict) and (
582 field.message_type.full_name != _StructFullTypeName
583 ):
584 new_val = field.message_type._concrete_class(**field_value)
585 field_copy.MergeFrom(new_val)
586 else:
587 try:
588 init_wkt_or_merge(field, field_copy, field_value)
589 except TypeError:
590 _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
591 self._fields[field] = field_copy
592 else:
593 if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
594 field_value = _GetIntegerEnumValue(field.enum_type, field_value)
595 try:
596 setattr(self, field_name, field_value)
597 except TypeError:
598 _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
599
600 init.__module__ = None
601 init.__doc__ = None
602 cls.__init__ = init
603
604
605def _GetFieldByName(message_descriptor, field_name):
606 """Returns a field descriptor by field name.
607
608 Args:
609 message_descriptor: A Descriptor describing all fields in message.
610 field_name: The name of the field to retrieve.
611 Returns:
612 The field descriptor associated with the field name.
613 """
614 try:
615 return message_descriptor.fields_by_name[field_name]
616 except KeyError:
617 raise ValueError('Protocol message %s has no "%s" field.' %
618 (message_descriptor.name, field_name))
619
620
621def _AddPropertiesForFields(descriptor, cls):
622 """Adds properties for all fields in this protocol message type."""
623 for field in descriptor.fields:
624 _AddPropertiesForField(field, cls)
625
626 if descriptor.is_extendable:
627 # _ExtensionDict is just an adaptor with no state so we allocate a new one
628 # every time it is accessed.
629 cls.Extensions = property(lambda self: _ExtensionDict(self))
630
631
def _AddPropertiesForField(field, cls):
  """Installs the public property and number constant of one field.

  The class gets a `<FIELD_NAME>_FIELD_NUMBER` constant and a property that
  clients use to read the field and, for non-repeated scalar fields, to set
  it directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Fails loudly if new cpp types appear that may need special handling here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  if field.is_repeated:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
655
656
657class _FieldProperty(property):
658 __slots__ = ('DESCRIPTOR',)
659
660 def __init__(self, descriptor, getter, setter, doc):
661 property.__init__(self, getter, setter, doc=doc)
662 self.DESCRIPTOR = descriptor
663
664
def _AddPropertiesForRepeatedField(field, cls):
  """Installs the property of a "repeated" protocol message field.

  Reading the property returns the container stored for the field, creating
  it through the field's default constructor on first access.  Assigning to
  the property always raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container

    # Nothing stored yet: build a new container, then let setdefault() decide
    # atomically whether ours or one stored meanwhile by another thread wins.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    fresh_container = field._default_constructor(self)
    return self._fields.setdefault(field, fresh_container)

  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to fail with a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
707
708
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Installs the property of a non-repeated, scalar protocol message field.

  Reading returns the stored value or the field's default value.  Writing
  type-checks the value, stores it (or drops the entry when a field without
  presence is set to its default), marks the message as modified and, for
  oneof members, updates the oneof state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # A field without presence that is set to its default scalar value is
    # represented by the absence of an entry.
    if field.has_presence or not decoder.IsDefaultScalarValue(new_value):
      self._fields[field] = new_value
    else:
      self._fields.pop(field, None)
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Wrap getter and setter in a property.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
763
764
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Installs the property of a non-repeated, composite message field.

  A composite field is a "group" or "message" field.

  Reading the property returns the sub-message, creating it on first access.
  Assignment is accepted only for Timestamp, Duration, Struct and ListValue
  fields, where the assigned value is converted into the existing
  sub-message; for every other message type it raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is not None:
      return sub_message

    # Nothing stored yet: build a new object, then let setdefault() decide
    # atomically whether ours or one stored meanwhile by another thread wins.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    fresh_message = field._default_constructor(self)
    return self._fields.setdefault(field, fresh_message)

  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    type_name = field.message_type.full_name
    if type_name == 'google.protobuf.Timestamp':
      getter(self)  # Makes sure the sub-message exists.
      self._fields[field].FromDatetime(new_value)
      return
    if type_name == 'google.protobuf.Duration':
      getter(self)
      self._fields[field].FromTimedelta(new_value)
      return
    if type_name == _StructFullTypeName:
      getter(self)
      self._fields[field].Clear()
      self._fields[field].update(new_value)
      return
    if type_name == _ListValueFullTypeName:
      getter(self)
      self._fields[field].Clear()
      self._fields[field].extend(new_value)
      return
    raise AttributeError(
        'Assignment not allowed to composite field '
        '"%s" in protocol message object.' % proto_field_name
    )

  # Wrap getter and setter in a property.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
824
825
826def _AddPropertiesForExtensions(descriptor, cls):
827 """Adds properties for all fields in this protocol message type."""
828 extensions = descriptor.extensions_by_name
829 for extension_name, extension_field in extensions.items():
830 constant_name = extension_name.upper() + '_FIELD_NUMBER'
831 setattr(cls, constant_name, extension_field.number)
832
833 # TODO: Migrate all users of these attributes to functions like
834 # pool.FindExtensionByNumber(descriptor).
835 if descriptor.file is not None:
836 # TODO: Use cls.MESSAGE_FACTORY.pool when available.
837 pool = descriptor.file.pool
838
def _AddStaticMethods(cls):
  """Installs the static helpers RegisterExtension and FromString on cls.

  Args:
    cls: The message class being populated.
  """

  def FromString(s):
    # Build an empty instance and merge the serialized payload into it.
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  def RegisterExtension(_):
    """no-op to keep generated code <=4.23 working with new runtimes."""
    # This was originally removed in 5.26 (cl/595989309).
    pass

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
852
853
def _IsPresent(item):
  """Decides whether a `_fields` entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from `_fields`.

  Returns:
    For a repeated field, whether the value is non-empty; for a singular
    message field, the value's `_is_present_in_parent` flag; True otherwise.
  """
  field, value = item[0], item[1]
  if field.is_repeated:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
864
865
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on cls."""

  def ListFields(self):
    # Keep only the entries that _IsPresent() accepts, ordered by field number.
    return sorted(
        (entry for entry in self._fields.items() if _IsPresent(entry)),
        key=lambda entry: entry[0].number,
    )

  cls.ListFields = ListFields
875
876
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField() on cls."""

  # Names HasField() accepts: non-repeated fields that report presence.
  # For proto3, only submessages and fields inside a oneof have presence.
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if not field.is_repeated and field.has_presence
  }
  # Has methods are supported for oneof descriptors.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs
  )

  def HasField(self, field_name):
    try:
      target = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(target, descriptor_mod.OneofDescriptor):
      # A oneof counts as set when the member currently recorded for it is
      # set; an unrecorded oneof raises KeyError below and is reported unset.
      try:
        return HasField(self, self._oneofs[target].name)
      except KeyError:
        return False

    if target.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return target in self._fields
    submessage = self._fields.get(target)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
914
915
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls."""

  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      # Not a plain field; try to interpret the name as a oneof instead.
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
        if oneof not in self._oneofs:
          # No member of the oneof is recorded, so there is nothing to clear.
          return
        field = self._oneofs[oneof]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    stored = self._fields
    if field in stored:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      current = stored[field]
      if hasattr(current, 'InvalidateIterators'):
        current.InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del stored[field]

      containing_oneof = field.containing_oneof
      if self._oneofs.get(containing_oneof, None) is field:
        del self._oneofs[containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
952
953
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""

  def ClearExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Similar to ClearField(), above: drop the stored value if there is one
    # and mark the message as modified either way.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
964
965
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""

  def HasExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.is_repeated:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields
    # A stored sub-message only counts once it is flagged as present.
    submessage = self._fields.get(field_descriptor)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
979
def _InternalUnpackAny(msg):
  """Unpacks an Any message into an instance of the type it names.

  Unlike the public Any Unpack method, which is handed the target message,
  this helper has no target type and looks the message type up in the
  descriptor pool of the default symbol database.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when the type URL is empty or the pool
    lookup returns None.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  database = symbol_database.Default()

  url = msg.type_url
  if not url:
    return None

  # TODO: For now we just strip the hostname.  Better logic will be
  # required.
  message_name = url.rsplit('/', 1)[-1]
  found = database.pool.FindMessageTypeByName(message_name)
  if found is None:
    return None

  # Unable to import message_factory at top because of circular import.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import message_factory
  unpacked = message_factory.GetMessageClass(found)()
  unpacked.ParseFromString(msg.value)
  return unpacked
1021
1022
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls."""

  def __eq__(self, other):
    type_name = self.DESCRIPTOR.full_name

    # ListValue and Struct can be compared directly against list and dict.
    if type_name == _ListValueFullTypeName and isinstance(other, list):
      return self._internal_compare(other)
    if type_name == _StructFullTypeName and isinstance(other, dict):
      return self._internal_compare(other)

    same_kind = (isinstance(other, message_mod.Message) and
                 not other.DESCRIPTOR != self.DESCRIPTOR)
    if not same_kind:
      return NotImplemented

    if self is other:
      return True

    if type_name == _AnyFullTypeName:
      # Compare the payloads when both sides unpack to something truthy;
      # otherwise fall through to the field-by-field comparison below.
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1060
1061
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__ on cls."""

  def __str__(self):
    # str(message) is whatever text_format.MessageToString() produces.
    return text_format.MessageToString(self)

  setattr(cls, '__str__', __str__)
1067
1068
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __repr__ on cls."""

  def __repr__(self):
    # repr(message) uses the same text_format rendering as str(message).
    return text_format.MessageToString(self)

  setattr(cls, '__repr__', __repr__)
1074
1075
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__ on cls."""

  def __unicode__(self):
    # Render with as_utf8=True, then decode the result as UTF-8.
    rendered = text_format.MessageToString(self, as_utf8=True)
    return rendered.decode('utf-8')

  setattr(cls, '__unicode__', __unicode__)
1082
1083
def _AddContainsMethod(message_descriptor, cls):
  """Installs `__contains__` on cls so that `x in message` works.

  google.protobuf.Struct tests membership in its `fields`,
  google.protobuf.ListValue tests membership in `items()`, and every other
  message type delegates to HasField().
  """
  type_name = message_descriptor.full_name

  if type_name == 'google.protobuf.ListValue':
    def __contains__(self, value):
      return value in self.items()
  elif type_name == 'google.protobuf.Struct':
    def __contains__(self, key):
      return key in self.fields
  else:
    def __contains__(self, field):
      return self.HasField(field)

  cls.__contains__ = __contains__
1097
1098
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Computes the serialized size in bytes of one non-repeated element.

  The count covers the tag information as well as any other space associated
  with serializing the value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (The field number is stored as
      part of a varint-encoded tag, so it affects the total byte count.)
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    EncodeError: If a KeyError is raised while looking up or running the
      size function for field_type.
  """
  try:
    return type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type](field_number, value)
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
1117
1118
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize() on cls."""

  def ByteSize(self):
    # Serve the cached value unless something was modified since it was
    # computed.
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    total = 0
    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      for entry_part in ('key', 'value'):
        part_field = descriptor.fields_by_name[entry_part]
        _MaybeAddEncoder(cls, part_field)
        total += part_field._sizer(getattr(self, entry_part))
    else:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        total += field_descriptor._sizer(field_value)
      # Unknown fields are stored as raw (tag, value) byte pairs.
      for tag_bytes, value_bytes in self._unknown_fields:
        total += len(tag_bytes) + len(value_bytes)

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
1149
1150
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls."""

  def SerializeToString(self, **kwargs):
    # Only serialize when the message has all of its required fields set.
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)

    missing = ','.join(self.FindInitializationErrors())
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name, missing))

  cls.SerializeToString = SerializeToString
1162
1163
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() and _InternalSerialize() on cls.
  """

  def SerializePartialToString(self, **kwargs):
    # Collect everything _InternalSerialize() writes into one bytes object.
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  def InternalSerialize(self, write_bytes, deterministic=None):
    # None means "use the process-wide default"; anything else is coerced
    # to a bool.
    if deterministic is not None:
      deterministic = bool(deterministic)
    else:
      deterministic = (
          api_implementation.IsPythonDefaultSerializationDeterministic())

    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      for entry_part in ('key', 'value'):
        part_field = descriptor.fields_by_name[entry_part]
        _MaybeAddEncoder(cls, part_field)
        part_field._encoder(
            write_bytes, getattr(self, entry_part), deterministic)
    else:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        field_descriptor._encoder(write_bytes, field_value, deterministic)
      # Unknown fields are written back as their stored (tag, value) bytes.
      for tag_bytes, value_bytes in self._unknown_fields:
        write_bytes(tag_bytes)
        write_bytes(value_bytes)

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
1197
1198
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and _InternalParse() on cls.
  """

  def MergeFromString(self, serialized):
    view = memoryview(serialized)
    length = len(view)
    try:
      consumed = self._InternalParse(view, 0, length)
      if consumed != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.

  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end, current_depth=0):
    """Merges serialized data from buffer[pos:end] into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.
      current_depth: int, passed through unchanged to the field decoders.

    Returns:
      int, the position at which parsing stopped: `end` when the whole range
      was consumed, or the position of the tag at which unknown-field
      decoding reported -1.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    while pos != end:
      (tag_bytes, new_pos) = decoder.ReadTag(buffer, pos)

      # Message-set decoders get the first chance at the tag.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        continue

      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is not None:
        # Known field: decode it and keep the oneof bookkeeping in sync.
        _MaybeAddDecoder(cls, field_des)
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(
            buffer, new_pos, end, self, field_dict, current_depth
        )
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
        continue

      # Unknown field: keep its raw bytes so they can be written back later.
      if not self._unknown_fields:  # pylint: disable=protected-access
        self._unknown_fields = []  # pylint: disable=protected-access
      field_number, wire_type = decoder.DecodeTag(tag_bytes)
      if field_number == 0:
        raise message_mod.DecodeError('Field number 0 is illegal.')
      (data, new_pos) = decoder._DecodeUnknownField(
          buffer, new_pos, end, field_number, wire_type
      )  # pylint: disable=protected-access
      if new_pos == -1:
        return pos
      raw_value = buffer[pos + len(tag_bytes) : new_pos].tobytes()
      self._unknown_fields.append((tag_bytes, raw_value))
      pos = new_pos
    return pos

  cls.MergeFromString = MergeFromString
  cls._InternalParse = InternalParse
1272
1273
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [
      candidate for candidate in message_descriptor.fields
      if candidate.is_required
  ]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    stored = self._fields
    for field in required_fields:
      if (field not in stored or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not stored[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    for field, value in list(stored.items()):  # dict can change size!
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      if not field.is_repeated:
        # Singular sub-message: only checked once it is present.
        if value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False
        continue

      # Repeated fields whose element type is a map entry are skipped.
      if field.message_type._is_map_entry:
        continue
      for element in value:
        if not element.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = [
        field.name for field in required_fields
        if not self.HasField(field.name)
    ]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      name = '(%s)' % field.full_name if field.is_extension else field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors, so only maps with
        # message values are descended into.
        if _IsMessageMapField(field):
          for key in value:
            prefix = '%s[%s].' % (name, key)
            errors.extend(
                prefix + error
                for error in value[key].FindInitializationErrors())
      elif field.is_repeated:
        for index, element in enumerate(value):
          prefix = '%s[%d].' % (name, index)
          errors.extend(
              prefix + error for error in element.FindInitializationErrors())
      else:
        prefix = name + '.'
        errors.extend(
            prefix + error for error in value.FindInitializationErrors())

    return errors

  cls.IsInitialized = IsInitialized
  cls.FindInitializationErrors = FindInitializationErrors
1366
1367
def _FullyQualifiedClassName(klass):
  """Returns `module.qualname` for klass, or just the name for builtins.

  The module prefix is left out when `klass.__module__` is None, 'builtins'
  or '__builtin__'.
  """
  owner = klass.__module__
  qualified = getattr(klass, '__qualname__', klass.__name__)
  if owner not in (None, 'builtins', '__builtin__'):
    qualified = owner + '.' + qualified
  return qualified
1374
1375
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if not field.is_repeated:
        if field.cpp_type != CPPTYPE_MESSAGE:
          # Singular scalar: the incoming value replaces ours.
          self._fields[field] = value
          if field.containing_oneof:
            self._UpdateOneofState(field)
          continue
        if not value._is_present_in_parent:
          # A sub-message that is not present contributes nothing.
          continue

      # Repeated field or present sub-message: merge into our own container,
      # constructing a new object to represent the field when needed.
      target = fields.get(field)
      if target is None:
        target = field._default_constructor(self)
        fields[field] = target
      target.MergeFrom(value)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1418
1419
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    # The recorded member only counts while HasField() still reports it set.
    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1436
1437
def _Clear(self):
  """Resets fields, unknown fields and oneof state, then calls _Modified()."""
  self._fields, self._unknown_fields, self._oneofs = {}, (), {}
  self._Modified()
1445
1446
def _UnknownFields(self):
  """Always raises, pointing callers at unknown_fields.UnknownFieldSet.

  Raises:
    NotImplementedError: Unconditionally.
  """
  # Fixed the "feaure" typo in the user-facing message.
  raise NotImplementedError('Please use the add-on feature '
                            'unknown_fields.UnknownFieldSet(message) in '
                            'unknown_fields.py instead.')
1451
1452
def _DiscardUnknownFields(self):
  """Empties this message's unknown fields and those of its sub-messages."""
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue

    if _IsMapField(field):
      # Only maps whose values are messages have anything to discard.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.is_repeated:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
1466
1467
def _SetListener(self, listener):
  """Stores listener on the message, or a NullMessageListener for None."""
  if listener is not None:
    self._listener = listener
  else:
    self._listener = message_listener_mod.NullMessageListener()
1473
1474
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for install in (
      _AddListFieldsMethod,
      _AddHasFieldMethod,
      _AddClearFieldMethod,
  ):
    install(message_descriptor, cls)

  # Extension accessors are only installed on extendable message types.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for install in (
      _AddEqualsMethod,
      _AddStrMethod,
      _AddReprMethod,
      _AddUnicodeMethod,
      _AddContainsMethod,
      _AddByteSizeMethod,
      _AddSerializeToStringMethod,
      _AddSerializePartialToStringMethod,
      _AddMergeFromStringMethod,
      _AddIsInitializedMethod,
  ):
    install(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1499
1500
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is not field:
      del self._fields[previous]
      self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1533
1534
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The back reference from a child (contained) object to its parent
    # (containing) object is kept weak, to avoid creating cyclic garbage when
    # the client finishes with the 'parent' object in the tree. A parent that
    # already arrives as a weak proxy is stored as-is.
    already_proxy = isinstance(parent_message, weakref.ProxyType)
    self._parent_message_weakref = (
        parent_message if already_proxy else weakref.proxy(parent_message))

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # We can get here if a client has kept a reference to a child object,
        # and is now setting a field on it, but the child's parent has been
        # garbage-collected.  This is not an error.
        pass
1579
1580
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Record this field as the active oneof member before the regular
      # propagation; a dead parent reference is silently ignored.
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      pass