1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8# This code is meant to work on Python 2.4 and above only.
9#
10# TODO: Helpers for verbose, common checks like seeing if a
11# descriptor's cpp_type is CPPTYPE_MESSAGE.
12
13"""Contains a metaclass and helper functions used to create
14protocol message classes from Descriptor objects at runtime.
15
16Recall that a metaclass is the "type" of a class.
17(A class is to a metaclass what an instance is to a class.)
18
19In this case, we use the GeneratedProtocolMessageType metaclass
20to inject all the useful functionality into the classes
21output by the protocol compiler at compile-time.
22
23The upshot of all this is that the real implementation
24details for ALL pure-Python protocol buffers are *here in
25this file*.
26"""
27
28__author__ = 'robinson@google.com (Will Robinson)'
29
30from io import BytesIO
31import struct
32import sys
33import warnings
34import weakref
35
36from google.protobuf import descriptor as descriptor_mod
37from google.protobuf import message as message_mod
38from google.protobuf import text_format
39# We use "as" to avoid name collisions with variables.
40from google.protobuf.internal import api_implementation
41from google.protobuf.internal import containers
42from google.protobuf.internal import decoder
43from google.protobuf.internal import encoder
44from google.protobuf.internal import enum_type_wrapper
45from google.protobuf.internal import extension_dict
46from google.protobuf.internal import message_listener as message_listener_mod
47from google.protobuf.internal import type_checkers
48from google.protobuf.internal import well_known_types
49from google.protobuf.internal import wire_format
50
# Module-level shorthands for names defined in other protobuf modules.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Fully-qualified proto type name of the Any message.
_AnyFullTypeName = 'google.protobuf.Any'
_ExtensionDict = extension_dict._ExtensionDict
54
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  Classes built by this metaclass receive implementations of the methods
  described in the Message class, a property for getting/setting every field
  of the protocol message, and a fixed __slots__ list so that clients cannot
  accidentally "set" a nonexistent field (which then would not be serialized
  or deserialized properly).

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the class object for a runtime-generated message type.

    __new__ is overridden so that __slots__ (and the class attributes for
    nested extensions) can be placed in the class dictionary before the class
    object is created.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).  The matching well-known-type base is appended when
        the descriptor's full name is listed in well_known_types.WKTBASES.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.

    Returns:
      The descriptor's existing concrete class if it already has one,
      otherwise the newly-allocated class.

    Raises:
      RuntimeError: The DESCRIPTOR entry is a str rather than a Descriptor
        object.
    """
    key = GeneratedProtocolMessageType._DESCRIPTOR_KEY
    message_descriptor = dictionary[key]

    if isinstance(message_descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist
    # with the existing class.
    #
    # The C++ implementation appears to have its own internal
    # `PyMessageFactory` to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors
    # from a custom pool; it calls symbol_database.Global().getPrototype() on
    # a descriptor which already has an existing concrete class.
    existing_class = getattr(message_descriptor, '_concrete_class', None)
    if existing_class:
      return existing_class

    wkt_bases = well_known_types.WKTBASES
    if message_descriptor.full_name in wkt_bases:
      bases = bases + (wkt_bases[message_descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
    _AddSlots(message_descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates a freshly allocated message class.

    This is where the majority of the work happens: per-field helpers are
    attached, the class is registered on its descriptor, and enum values, an
    __init__ method, field properties and all Message methods are added.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.
    """
    key = GeneratedProtocolMessageType._DESCRIPTOR_KEY
    message_descriptor = dictionary[key]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    registered_class = getattr(message_descriptor, '_concrete_class', None)
    if registered_class:
      assert registered_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (message_descriptor.full_name))
      return

    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    if (message_descriptor.has_options and
        message_descriptor.GetOptions().message_set_wire_format):
      item_decoder = decoder.MessageSetItemDecoder(message_descriptor)
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          item_decoder,
          None,
      )

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in message_descriptor.fields:
      _AttachFieldHelpers(cls, field)

    if (message_descriptor.is_extendable and
        hasattr(message_descriptor.file, 'pool')):
      pool = message_descriptor.file.pool
      for extension in pool.FindAllExtensions(message_descriptor):
        _AttachFieldHelpers(cls, extension)

    # The class is registered on the descriptor before the remaining helpers
    # run; keep this ordering.
    message_descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(message_descriptor, cls)
    _AddInitMethod(message_descriptor, cls)
    _AddPropertiesForFields(message_descriptor, cls)
    _AddPropertiesForExtensions(message_descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(message_descriptor, cls)
    _AddPrivateHelperMethods(message_descriptor, cls)

    super().__init__(name, bases, dictionary)
192
193
194# Stateless helpers for GeneratedProtocolMessageType below.
195# Outside clients should not access these directly.
196#
197# I opted not to make any of these methods on the metaclass, to make it more
198# clear that I'm not really using any state there and to keep clients from
199# thinking that they have direct access to these construction helpers.
200
201
def _PropertyName(proto_field_name):
  """Maps a proto field name to the name of its public property attribute.

  Clients use the returned attribute name to get and (in some cases) set the
  value of the protocol message field.  The mapping is currently the
  identity.

  Args:
    proto_field_name: The protocol message field name, exactly as it appears
      (or would appear) in a .proto file.

  Returns:
    The attribute name to use for the field's property.
  """
  # TODO: Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz suggested consulting the stdlib `keyword` module:
  #
  #   if keyword.iskeyword(proto_field_name):
  #     return proto_field_name + "_"
  #   return proto_field_name
  #
  # Kenton says: The above is a BAD IDEA.  People rely on being able to use
  # getattr() and setattr() to reflectively manipulate field values.  If the
  # properties were renamed, every such user would also have to apply the
  # same transformation.  A field named "yield" can still be accessed just
  # fine using getattr/setattr -- it's not even that cumbersome to do so.
  # TODO: Remove this method entirely if/when everyone agrees with that
  # position.
  property_name = proto_field_name
  return property_name
229
230
def _AddSlots(message_descriptor, dictionary):
  """Stores the fixed list of instance attribute names under '__slots__'.

  The same attribute names are used for every message type; the descriptor
  argument is not consulted.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
248
249
def _IsMessageSetExtension(field):
  """Tells whether `field` is an optional message extension of a MessageSet.

  Args:
    field: The FieldDescriptor to inspect.

  Returns:
    A truthy value only when the field is an extension, its containing type
    has the message_set_wire_format option set, and the field itself is an
    optional (non-repeated) field of message type.
  """
  # The conditions are chained with `and` so that later attributes are only
  # touched once the earlier checks have passed.
  return (
      field.is_extension
      and field.containing_type.has_options
      and field.containing_type.GetOptions().message_set_wire_format
      and field.type == _FieldDescriptor.TYPE_MESSAGE
      and field.label == _FieldDescriptor.LABEL_OPTIONAL
  )
256
257
def _IsMapField(field):
  """Tells whether `field` is a map field.

  Args:
    field: The FieldDescriptor to inspect.

  Returns:
    False for any field that is not of message type; otherwise the
    `_is_map_entry` flag of the field's message type.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  return field.message_type._is_map_entry
261
262
def _IsMessageMapField(field):
  """Tells whether the values of a map field are messages.

  Args:
    field: FieldDescriptor whose message type has a field named 'value'.

  Returns:
    True if that 'value' field has cpp_type CPPTYPE_MESSAGE.
  """
  entry_fields = field.message_type.fields_by_name
  return entry_fields['value'].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
266
def _AttachFieldHelpers(cls, field_descriptor):
  """Caches a default constructor on the field and indexes the field by tag.

  The field's default-value constructor is stored on the descriptor as
  `_default_constructor`.  The field is then recorded in `cls._fields_by_tag`
  as a `(field_descriptor, is_packed)` tuple keyed by encoded tag bytes.

  Args:
    cls: The message class being constructed; its `_fields_by_tag` dict is
      updated in place.
    field_descriptor: The FieldDescriptor (a regular field or an extension)
      to register.
  """
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor
  )

  field_type = field_descriptor.type

  # Every field is registered under the tag for its declared wire type.
  declared_wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type]
  unpacked_tag = encoder.TagBytes(field_descriptor.number, declared_wire_type)
  cls._fields_by_tag[unpacked_tag] = (field_descriptor, False)

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  if is_repeated and wire_format.IsTypePackable(field_type):
    # To support wire compatibility of adding packed = true, register the
    # packed (length-delimited) tag regardless of the field's options.
    packed_tag = encoder.TagBytes(
        field_descriptor.number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    cls._fields_by_tag[packed_tag] = (field_descriptor, True)
285
286
def _MaybeAddEncoder(cls, field_descriptor):
  """Builds the encoder and sizer for a field unless already cached.

  The results are stored on the descriptor as `_sizer` and `_encoder`.  The
  presence of `_encoder` is what marks the field as already processed.

  Args:
    cls: The message class (not used by this function).
    field_descriptor: The FieldDescriptor to equip with an encoder and sizer.
  """
  if hasattr(field_descriptor, '_encoder'):
    return

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_map_entry = _IsMapField(field_descriptor)
  is_packed = field_descriptor.is_packed

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(
        field_descriptor, _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    make_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]
    field_encoder = make_encoder(
        field_descriptor.number, is_repeated, is_packed)
    make_sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]
    sizer = make_sizer(field_descriptor.number, is_repeated, is_packed)

  # `_encoder` is assigned last because it is the attribute checked by the
  # early return above.
  field_descriptor._sizer = sizer
  field_descriptor._encoder = field_encoder
309
310
def _MaybeAddDecoder(cls, field_descriptor):
  """Builds the wire decoders for a field unless they are already cached.

  The decoders are stored on the descriptor as `_decoders`, a dict keyed by
  `is_packed` (False for the unpacked decoder, True for the packed one).  The
  presence of `_decoders` is what marks the field as already processed.

  Args:
    cls: The message class (not used by this function).
    field_descriptor: The FieldDescriptor to equip with decoders.
  """
  if hasattr(field_descriptor, '_decoders'):
    return

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_map_entry = _IsMapField(field_descriptor)
  helper_decoders = {}

  def AddDecoder(is_packed):
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      # Enum fields whose enum type is not closed use the int32 decoder.
      decode_type = _FieldDescriptor.TYPE_INT32

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Message decoders do not take the trailing "no presence" argument.
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)

    helper_decoders[is_packed] = field_decoder

  AddDecoder(False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(True)

  field_descriptor._decoders = helper_decoders
361
362
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Copies the descriptor's nested extensions into the class dictionary.

  Args:
    descriptor: Descriptor whose `extensions_by_name` mapping is copied.
    dictionary: Class dictionary that receives one entry per extension, keyed
      by extension name.  No extension name may already be present in it.
  """
  for name, extension in descriptor.extensions_by_name.items():
    assert name not in dictionary
    dictionary[name] = extension
368
369
def _AddEnumValues(descriptor, cls):
  """Exposes the message's nested enums as class-level attributes.

  For every enum type in the message, the class gets an EnumTypeWrapper
  attribute named after the enum type, plus one integer attribute per enum
  value named after that value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_descriptor)
    setattr(cls, enum_descriptor.name, wrapper)
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
383
384
def _GetInitializeDefaultForMap(field):
  """Returns a factory that creates the empty container for a map field.

  Args:
    field: FieldDescriptor of the map field.  It must be repeated, and its
      message type must have fields named 'key' and 'value'.

  Returns:
    A one-argument function taking the owning message and returning a
    containers.MessageMap (when the map values are messages) or a
    containers.ScalarMap (otherwise).

  Raises:
    ValueError: The field is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))

  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
406
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function with one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field: a map or
    repeated container, a new sub-message instance, or the field's scalar
    default value.  Containers and sub-messages are wired to the listener of
    |message|.

  Raises:
    ValueError: A repeated field declares a default value other than an
      empty list, or a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # The message descriptor is captured instead and handed to the
      # container.
      message_type = field.message_type
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized.
      if not hasattr(message_type, '_concrete_class'):
        # Imported here rather than at module level.
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Sub-messages that live in a oneof get a dedicated listener; all
      # others share the parent's listener for children.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
463
464
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, naming the field if possible.

  Must be called while an exception is being handled.  When that exception is
  exactly a TypeError (not a subclass) carrying a single argument, it is
  replaced by a new TypeError whose message has ' for field <message>.<field>'
  appended.  Any other exception is re-raised unchanged.

  Args:
    message_name: Name of the message type that owns the field.
    field_name: Name of the field that was being set.
  """
  _, error, trace = sys.exc_info()
  if len(error.args) == 1 and type(error) is TypeError:
    # simple TypeError; add field name to exception message
    detail = '%s for field %s.%s' % (str(error), message_name, field_name)
    error = TypeError(detail)

  # re-raise possibly-amended exception with original traceback:
  raise error.with_traceback(trace)
474
475
def _AddInitMethod(message_descriptor, cls):
  """Installs a keyword-argument __init__ on cls.

  The installed constructor resets the instance's bookkeeping attributes and
  then initializes one field per keyword argument.

  Args:
    message_descriptor: Descriptor of the message type being constructed.
    cls: The message class that receives the __init__ method.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    A string is looked up by name among the values of enum_type.  Anything
    that is not a string is returned as-is.  (No conversion or bounds-checking
    is done.)
    """
    if not isinstance(value, str):
      return value
    try:
      return enum_type.values_by_name[value].number
    except KeyError:
      raise ValueError('Enum type %s: unknown label "%s"' % (
          enum_type.full_name, value))

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)

    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
          # Repeated scalar: enum labels are converted to numbers first.
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, element)
                           for element in field_value]
          container.extend(field_value)
        elif not _IsMapField(field):
          # Repeated message: each element is a dict of kwargs or a message.
          for element in field_value:
            if isinstance(element, dict):
              container.add(**element)
            else:
              container.add().MergeFrom(element)
        elif _IsMessageMapField(field):
          for key in field_value:
            container[key].MergeFrom(field_value[key])
        else:
          container.update(field_value)
        self._fields[field] = container
        continue

      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        submessage = field._default_constructor(self)
        source = field_value
        if isinstance(field_value, dict):
          source = field.message_type._concrete_class(**field_value)
        try:
          submessage.MergeFrom(source)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = submessage
        continue

      # Singular scalar: assigned through the field's property.
      if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
        field_value = _GetIntegerEnumValue(field.enum_type, field_value)
      try:
        setattr(self, field_name, field_value)
      except TypeError:
        _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
558
559
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by its field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: The message has no field with that name.
  """
  known_fields = message_descriptor.fields_by_name
  try:
    return known_fields[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
574
575
def _AddPropertiesForFields(descriptor, cls):
  """Adds a property for every field of this protocol message type.

  Extendable message types additionally get an `Extensions` property.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state so we allocate a new one
  # every time it is accessed.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
585
586
def _AddPropertiesForField(field, cls):
  """Adds the public property and field-number constant for one field.

  The class first receives a `<FIELD_NAME>_FIELD_NUMBER` constant.  The
  property itself is then installed by the helper that matches the kind of
  field: repeated, non-repeated message, or non-repeated scalar.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
610
611
# A property that also remembers, in its DESCRIPTOR attribute, the descriptor
# it was created for.  No class docstring is added on purpose, so the class's
# __doc__ stays unset.
class _FieldProperty(property):
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    # getter/setter/doc are handed to property unchanged.
    super().__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
618
619
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property returns the field's container, creating it with the
  field's default constructor on first access.  Assigning to the property
  always raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Construct a new object to represent this field.
    fresh = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
662
663
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Reading the property returns the stored value or the field's default.
  Writing it type-checks the value, stores it (or drops the entry, see below),
  marks the message as modified and, for oneof members, updates the oneof
  state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).  A field without presence that is set to
    # such a value is removed from _fields instead of being stored.
    if field.has_presence or checked_value:
      self._fields[field] = checked_value
    else:
      self._fields.pop(field, None)
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
718
719
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.  Reading the property
  returns the sub-message, creating it with the field's default constructor
  on first access.  Assigning to the property always raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Construct a new object to represent this field.
    fresh = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
762
763
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a `<NAME>_FIELD_NUMBER` class constant for each nested extension.

  Args:
    descriptor: Descriptor of the message type; its `extensions_by_name`
      mapping lists the extensions declared inside the message.
    cls: The class we're constructing.  For every extension it receives an
      attribute named after the upper-cased extension name followed by
      '_FIELD_NUMBER', holding the extension's field number.
  """
  # TODO: Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)
776
def _AddStaticMethods(cls):
  """Adds the FromString static method to cls.

  FromString(s) creates a new instance of cls, merges the serialized data `s`
  into it with MergeFromString, and returns the instance.

  Args:
    cls: The message class we're constructing.
  """
  def FromString(s):
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.FromString = staticmethod(FromString)
783
784
def _IsPresent(item):
  """Decides whether a `_fields` entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from a message's `_fields`.

  Returns:
    For a repeated field, whether the container is non-empty; for a message
    field, the sub-message's `_is_present_in_parent` flag; True otherwise.
  """
  field_descriptor, value = item
  if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
795
796
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields on cls.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class that receives the method.
  """

  def ListFields(self):
    # Keep only the entries that count as present, ordered by field number.
    present_items = filter(_IsPresent, self._fields.items())
    return sorted(present_items, key=lambda entry: entry[0].number)

  cls.ListFields = ListFields
806
807
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField on cls.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class that receives the method.
  """

  # Names HasField accepts: every non-repeated field that has presence (for
  # proto3, only submessages and fields inside a oneof have presence) ...
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED and field.has_presence
  }
  # ... plus every oneof, since Has methods are supported for oneof
  # descriptors.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      target = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(target, descriptor_mod.OneofDescriptor):
      # A oneof is "set" when its currently selected member field is set; a
      # oneof with no selected member is missing from _oneofs.
      try:
        return HasField(self, self._oneofs[target].name)
      except KeyError:
        return False

    if target.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return target in self._fields

    submessage = self._fields.get(target)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
845
846
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ClearField(self, field_name):
    """Clears the named field, or the active member of the named oneof.

    Raises:
      ValueError: If field_name names neither a field nor a oneof.
    """
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      # Not a regular field; it may still name a oneof.
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))
      if oneof not in self._oneofs:
        # Nothing is set in this oneof, so there is nothing to clear.
        return
      field = self._oneofs[oneof]

    stored = self._fields
    if field in stored:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      current = stored[field]
      if hasattr(current, 'InvalidateIterators'):
        current.InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us.  That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size.  Big deal.
      del stored[field]

    owner = field.containing_oneof
    if self._oneofs.get(owner, None) is field:
      del self._oneofs[owner]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
883
884
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""

  def ClearExtension(self, field_descriptor):
    """Removes any stored value for the extension and marks self modified."""
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Similar to ClearField(): drop the stored value if there is one, and
    # report the modification unconditionally.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
895
896
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""

  def HasExtension(self, field_descriptor):
    """Returns whether the given singular extension is set on this message.

    Raises:
      KeyError: If the extension is a repeated field.
    """
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields

    # A stored sub-message only counts once it is marked present in parent.
    submessage = self._fields.get(field_descriptor)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
910
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg has no type_url or when the
    referenced message type cannot be found in the default descriptor pool.
  """
  type_url = msg.type_url

  if not type_url:
    return None

  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  # TODO: For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The descriptor pool signals an unknown type by raising KeyError rather
    # than returning None.  Treat that as "cannot unpack" so that callers
    # such as __eq__ fall back to comparing the raw Any fields instead of
    # failing.
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
949
950
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def __eq__(self, other):
    """Compares two messages of the same type by their content.

    Returns NotImplemented when other is not a Message or has a different
    descriptor.
    """
    if not isinstance(other, message_mod.Message):
      return NotImplemented
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return NotImplemented

    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      # Prefer comparing the packed payloads as messages when both sides can
      # be unpacked; otherwise fall through to the field-wise comparison.
      any_a = _InternalUnpackAny(self)
      any_b = _InternalUnpackAny(other)
      if any_a and any_b:
        return any_a == any_b

    if not self.ListFields() == other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    own_unknown = sorted(self._unknown_fields)
    their_unknown = sorted(other._unknown_fields)
    return own_unknown == their_unknown

  cls.__eq__ = __eq__
979
980
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def __str__(self):
    # str(message) is the text-format rendering of the message.
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
986
987
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def __repr__(self):
    # repr(message) uses the same text-format rendering as str(message).
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
993
994
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs __unicode__(), which returns the UTF-8 text-format rendering of the
  message as a text string.
  """

  def __unicode__(self):
    text = text_format.MessageToString(self, as_utf8=True)
    # Only a byte string needs decoding.  A text (str) result has no decode()
    # method on Python 3, so calling it unconditionally would raise
    # AttributeError there.
    if isinstance(text, bytes):
      return text.decode('utf-8')
    return text
  cls.__unicode__ = __unicode__
1001
1002
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Computes the serialized size of a single non-repeated element.

  The result covers the tag information as well as any other additional
  space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The number of bytes reported by the sizing function for field_type.

  Raises:
    message_mod.EncodeError: If a KeyError occurs while looking up or
      applying the sizing function for field_type.
  """
  try:
    size_fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
    return size_fn(field_number, value)
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
1021
1022
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ByteSize(self):
    """Returns the serialized size, reusing the cached value when clean."""
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    descriptor = self.DESCRIPTOR
    total = 0
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      for entry_name in ('key', 'value'):
        entry_field = descriptor.fields_by_name[entry_name]
        _MaybeAddEncoder(cls, entry_field)
        total += entry_field._sizer(getattr(self, entry_name))
    else:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        total += field_descriptor._sizer(field_value)
      # Unknown fields are stored as raw (tag, value) byte pairs.
      total += sum(len(tag_bytes) + len(value_bytes)
                   for tag_bytes, value_bytes in self._unknown_fields)

    # Cache the result and clear the dirty markers.
    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
1053
1054
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def SerializeToString(self, **kwargs):
    """Serializes the message, refusing when required fields are missing.

    Raises:
      message_mod.EncodeError: If IsInitialized() reports the message as
        incomplete.
    """
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)

    # Not every required field is set: report which ones are missing.
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name,
            ','.join(self.FindInitializationErrors())))

  cls.SerializeToString = SerializeToString
1066
1067
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def SerializePartialToString(self, **kwargs):
    """Serializes the message to bytes without checking required fields."""
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes, deterministic=None):
    """Writes the serialized message through the write_bytes callable.

    Args:
      write_bytes: Callable that receives each chunk of serialized bytes.
      deterministic: Coerced to bool when given; when None the
        api_implementation default is used.
    """
    if deterministic is not None:
      deterministic = bool(deterministic)
    else:
      deterministic = (
          api_implementation.IsPythonDefaultSerializationDeterministic())

    descriptor = self.DESCRIPTOR
    if not descriptor._is_map_entry:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        field_descriptor._encoder(write_bytes, field_value, deterministic)
      # Unknown fields are written back verbatim, tag first.
      for tag_bytes, value_bytes in self._unknown_fields:
        write_bytes(tag_bytes)
        write_bytes(value_bytes)
      return

    # Fields of map entry should always be serialized.
    for entry_name in ('key', 'value'):
      entry_field = descriptor.fields_by_name[entry_name]
      _MaybeAddEncoder(cls, entry_field)
      entry_field._encoder(
          write_bytes, getattr(self, entry_name), deterministic)

  cls._InternalSerialize = InternalSerialize
1101
1102
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the lower-level _InternalParse() on cls.
  """
  def MergeFromString(self, serialized):
    """Parses serialized data and merges it into this message.

    Args:
      serialized: Buffer holding the serialized message; it is wrapped in a
        memoryview before parsing.

    Returns:
      The length of the serialized input (kept for legacy reasons).

    Raises:
      message_mod.DecodeError: If parsing stops before the end of the input,
        or if IndexError, TypeError or struct.error is raised while parsing.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind the helpers and lookup tables once so InternalParse reaches them
  # through its closure instead of repeated attribute lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses serialized bytes from buffer into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position at which parsing stopped.  This equals end when the
      whole range was consumed.  When decoding or skipping an unknown field
      yields -1, the position of that field's tag is returned instead.

    Raises:
      message_mod.DecodeError: If an unknown field carries field number 0.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      # Decoders registered in the message-set table are tried before the
      # regular per-field lookup.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        continue
      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is None:
        # Unknown tag: keep the raw bytes so they can be written back later.
        if not self._unknown_fields:   # pylint: disable=protected-access
          self._unknown_fields = []    # pylint: disable=protected-access
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO: remove old_pos.
        old_pos = new_pos
        # Only the returned position is inspected here; the decoded data is
        # not used.
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        # TODO: remove _unknown_fields.
        # Skip the same field again from old_pos to find where its raw bytes
        # end, then store the (tag, payload) pair.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        # Known field: decoders are attached to the descriptor on demand and
        # selected by whether this tag is the packed encoding.
        _MaybeAddDecoder(cls, field_des)
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
    return pos
  cls._InternalParse = InternalParse
1181
1182
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class."""

  required_fields = [
      candidate for candidate in message_descriptor.fields
      if candidate.label == _FieldDescriptor.LABEL_REQUIRED
  ]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      unset = field not in self._fields
      if not unset and field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        # A stored sub-message only counts once it is present in its parent.
        unset = not self._fields[field]._is_present_in_parent
      if unset:
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Iterate over a snapshot: the dict can change size!
    for field, value in list(self._fields.items()):
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        if field.message_type._is_map_entry:
          continue
        complete = all(element.IsInitialized() for element in value)
      else:
        complete = not value._is_present_in_parent or value.IsInitialized()
      if not complete:
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = [required.name for required in required_fields
              if not self.HasField(required.name)]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      name = ('(%s)' % field.full_name) if field.is_extension else field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            element = value[key]
            prefix = '%s[%s].' % (name, key)
            errors.extend(
                prefix + error for error in element.FindInitializationErrors())
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          prefix = '%s[%d].' % (name, index)
          errors.extend(
              prefix + error for error in element.FindInitializationErrors())
      else:
        prefix = name + '.'
        errors.extend(
            prefix + error for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1275
1276
def _FullyQualifiedClassName(klass):
  """Returns the class name, prefixed with its module unless it is a builtin.

  Args:
    klass: The class to describe.

  Returns:
    The __qualname__ (or __name__ when there is no __qualname__) on its own
    when the module is None, 'builtins' or '__builtin__'; otherwise
    'module.name'.
  """
  owner = klass.__module__
  qualified = getattr(klass, '__qualname__', klass.__name__)
  has_no_prefix = owner in (None, 'builtins', '__builtin__')
  return qualified if has_no_prefix else owner + '.' + qualified
1283
1284
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on the message class."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the stored fields and unknown fields of msg into self.

    Raises:
      TypeError: If msg is not an instance of this message class.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      repeated = field.label == LABEL_REPEATED
      if not repeated and field.cpp_type != CPPTYPE_MESSAGE:
        # Singular non-message field: take the source value as is.
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)
        continue

      if not repeated and not value._is_present_in_parent:
        # A sub-message that is not present in msg contributes nothing.
        continue

      target = fields.get(field)
      if target is None:
        # Construct a new object to represent this field.
        target = field._default_constructor(self)
        fields[field] = target
      target.MergeFrom(value)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1328
1329
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on the message class."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1346
1347
def _Clear(self):
  """Resets the message to an empty state and reports the modification."""
  # Drop every stored field value, unknown field and oneof selection.
  self._fields, self._unknown_fields, self._oneofs = {}, (), {}
  self._Modified()
1355
1356
def _UnknownFields(self):
  """Placeholder that always raises.

  Raises:
    NotImplementedError: Always; the message directs callers to
      unknown_fields.UnknownFieldSet(message) instead.
  """
  raise NotImplementedError('Please use the add-on feature '
                            'unknown_fields.UnknownFieldSet(message) in '
                            'unknown_fields.py instead.')
1361
1362
def _DiscardUnknownFields(self):
  """Drops unknown fields from this message and from its sub-messages."""
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue

    if _IsMapField(field):
      # Only maps with message values have anything to descend into.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
      continue

    if field.label == _FieldDescriptor.LABEL_REPEATED:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
1376
1377
def _SetListener(self, listener):
  """Stores listener on the message, substituting a null listener for None."""
  self._listener = (
      listener if listener is not None
      else message_listener_mod.NullMessageListener())
1383
1384
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for install in (_AddListFieldsMethod,
                  _AddHasFieldMethod,
                  _AddClearFieldMethod):
    install(message_descriptor, cls)

  # Extension accessors are only installed on extendable messages.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for install in (_AddEqualsMethod,
                  _AddStrMethod,
                  _AddReprMethod,
                  _AddUnicodeMethod,
                  _AddByteSizeMethod,
                  _AddSerializeToStringMethod,
                  _AddSerializePartialToStringMethod,
                  _AddMergeFromStringMethod,
                  _AddIsInitializedMethod):
    install(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1408
1409
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    # setdefault registers field when the oneof had no active member yet.
    active = self._oneofs.setdefault(oneof, field)
    if active is field:
      return
    del self._fields[active]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1442
1443
class _Listener:

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # This listener establishes a back reference from a child (contained) object
    # to its parent (containing) object.  We make this a weak reference to avoid
    # creating cyclic garbage when the client finishes with the 'parent' object
    # in the tree.  A parent that already is a weak proxy is stored as given.
    already_proxy = isinstance(parent_message, weakref.ProxyType)
    self._parent_message_weakref = (
        parent_message if already_proxy else weakref.proxy(parent_message))

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    """Forwards the modification to the parent unless already marked dirty."""
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # We can get here if a client has kept a reference to a child object,
        # and is now setting a field on it, but the child's parent has been
        # garbage-collected.  This is not an error.
        pass
1488
1489
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Make this field the active oneof member before forwarding the
      # notification through the base listener.
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # The parent message no longer exists; there is nothing to update.
      pass