1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
# This code requires Python 3 (it uses Python 3-only constructs such as
# `raise ... from` and `BaseException.with_traceback`).
9#
10# TODO: Helpers for verbose, common checks like seeing if a
11# descriptor's cpp_type is CPPTYPE_MESSAGE.
12
13"""Contains a metaclass and helper functions used to create
14protocol message classes from Descriptor objects at runtime.
15
16Recall that a metaclass is the "type" of a class.
17(A class is to a metaclass what an instance is to a class.)
18
19In this case, we use the GeneratedProtocolMessageType metaclass
20to inject all the useful functionality into the classes
21output by the protocol compiler at compile-time.
22
23The upshot of all this is that the real implementation
24details for ALL pure-Python protocol buffers are *here in
25this file*.
26"""
27
28__author__ = 'robinson@google.com (Will Robinson)'
29
30import datetime
31from io import BytesIO
32import struct
33import sys
34import warnings
35import weakref
36
37from google.protobuf import descriptor as descriptor_mod
38from google.protobuf import message as message_mod
39from google.protobuf import text_format
40# We use "as" to avoid name collisions with variables.
41from google.protobuf.internal import api_implementation
42from google.protobuf.internal import containers
43from google.protobuf.internal import decoder
44from google.protobuf.internal import encoder
45from google.protobuf.internal import enum_type_wrapper
46from google.protobuf.internal import extension_dict
47from google.protobuf.internal import message_listener as message_listener_mod
48from google.protobuf.internal import type_checkers
49from google.protobuf.internal import well_known_types
50from google.protobuf.internal import wire_format
51
52_FieldDescriptor = descriptor_mod.FieldDescriptor
53_AnyFullTypeName = 'google.protobuf.Any'
54_StructFullTypeName = 'google.protobuf.Struct'
55_ListValueFullTypeName = 'google.protobuf.ListValue'
56_ExtensionDict = extension_dict._ExtensionDict
57
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Key in the class dictionary under which the message's Descriptor is stored.
  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class; passed through unchanged to type.__new__.
      bases: Base classes of the class we're constructing (normally
        message.Message).  When the descriptor's full name is a key of
        well_known_types.WKTBASES, the matching mixin is appended.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.  Nested-extension attributes and a '__slots__' entry are
        added to it in place.

    Returns:
      Newly-allocated class, or the class already recorded in the
      descriptor's `_concrete_class` attribute if there is one.

    Raises:
      RuntimeError: if dictionary[_DESCRIPTOR_KEY] is a str, meaning the
        generated code targets the python cpp extension but the pure-python
        runtime is in use.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    # Well-known types (Any, Timestamp, Struct, ...) get extra helper methods
    # through a mixin base class.
    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    # __slots__ must be in the dictionary before type.__new__ runs.
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class; passed through unchanged to type.__init__.
      bases: Base classes of the class we're constructing
        (normally message.Message); passed through to type.__init__.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    # Per-class lookup tables keyed by wire tag.  _fields_by_tag is filled by
    # _AttachFieldHelpers below.
    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    # Messages using the MessageSet wire format decode their items with a
    # dedicated decoder registered under MESSAGE_SET_ITEM_TAG.
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor),
          None,
      )

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    # Extensions the descriptor's pool already knows about get the same
    # helpers as regular fields.
    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      extensions = descriptor.file.pool.FindAllExtensions(descriptor)
      for ext in extensions:
        _AttachFieldHelpers(cls, ext)

    # Record the class on the descriptor; __new__ and the sub-message default
    # constructors look it up through `_concrete_class`.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)
195
196
197# Stateless helpers for GeneratedProtocolMessageType below.
198# Outside clients should not access these directly.
199#
200# I opted not to make any of these methods on the metaclass, to make it more
201# clear that I'm not really using any state there and to keep clients from
202# thinking that they have direct access to these construction helpers.
203
204
205def _PropertyName(proto_field_name):
206 """Returns the name of the public property attribute which
207 clients can use to get and (in some cases) set the value
208 of a protocol message field.
209
210 Args:
211 proto_field_name: The protocol message field name, exactly
212 as it appears (or would appear) in a .proto file.
213 """
214 # TODO: Escape Python keywords (e.g., yield), and test this support.
215 # nnorwitz makes my day by writing:
216 # """
217 # FYI. See the keyword module in the stdlib. This could be as simple as:
218 #
219 # if keyword.iskeyword(proto_field_name):
220 # return proto_field_name + "_"
221 # return proto_field_name
222 # """
223 # Kenton says: The above is a BAD IDEA. People rely on being able to use
224 # getattr() and setattr() to reflectively manipulate field values. If we
225 # rename the properties, then every such user has to also make sure to apply
226 # the same transformation. Note that currently if you name a field "yield",
227 # you can still access it just fine using getattr/setattr -- it's not even
228 # that cumbersome to do so.
229 # TODO: Remove this method entirely if/when everyone agrees with my
230 # position.
231 return proto_field_name
232
233
234def _AddSlots(message_descriptor, dictionary):
235 """Adds a __slots__ entry to dictionary, containing the names of all valid
236 attributes for this message type.
237
238 Args:
239 message_descriptor: A Descriptor instance describing this message type.
240 dictionary: Class dictionary to which we'll add a '__slots__' entry.
241 """
242 dictionary['__slots__'] = ['_cached_byte_size',
243 '_cached_byte_size_dirty',
244 '_fields',
245 '_unknown_fields',
246 '_is_present_in_parent',
247 '_listener',
248 '_listener_for_children',
249 '__weakref__',
250 '_oneofs']
251
252
def _IsMessageSetExtension(field):
  """Returns whether `field` is an optional message extension of a MessageSet.

  Such fields are serialized with the MessageSet item encoder/sizer rather
  than the per-type ones.

  Args:
    field: FieldDescriptor to inspect.
  """
  if not field.is_extension:
    return False
  extendee = field.containing_type
  if not extendee.has_options:
    return False
  if not extendee.GetOptions().message_set_wire_format:
    return False
  is_message = field.type == _FieldDescriptor.TYPE_MESSAGE
  return is_message and field.label == _FieldDescriptor.LABEL_OPTIONAL
259
260
def _IsMapField(field):
  """Returns whether `field` is a message field whose type is a map entry."""
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  return field.message_type._is_map_entry
264
265
def _IsMessageMapField(field):
  """Returns whether the map field `field` has message-typed values.

  `field` must be a map field: its message type is expected to declare a
  field named 'value' (a KeyError is raised otherwise).
  """
  entry_fields = field.message_type.fields_by_name
  value_cpp_type = entry_fields['value'].cpp_type
  return value_cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
269
def _AttachFieldHelpers(cls, field_descriptor):
  """Caches the default factory on a field and registers it by wire tag.

  Sets `field_descriptor._default_constructor` and records
  `(field_descriptor, is_packed)` in `cls._fields_by_tag` under the encoded
  tag bytes for each wire encoding the field may arrive in.

  Args:
    cls: Message class under construction; must already have a
      `_fields_by_tag` dict.
    field_descriptor: FieldDescriptor of a regular field or an extension.
  """
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor
  )

  # (wire type, is_packed) pairs this field is reachable under.
  encodings = [
      (type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False),
  ]
  if (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED and
      wire_format.IsTypePackable(field_descriptor.type)):
    # Accept the packed encoding for every packable repeated field,
    # whatever its declared options, so that adding `packed = true` stays
    # wire compatible.
    encodings.append((wire_format.WIRETYPE_LENGTH_DELIMITED, True))

  for wire_type, packed in encodings:
    tag = encoder.TagBytes(field_descriptor.number, wire_type)
    cls._fields_by_tag[tag] = (field_descriptor, packed)
288
289
def _MaybeAddEncoder(cls, field_descriptor):
  """Attaches `_encoder` and `_sizer` to `field_descriptor` on first use.

  Returns immediately if the descriptor already carries an `_encoder`.  Map
  fields and MessageSet extensions use dedicated encoder/sizer pairs; every
  other field picks its pair from type_checkers by field type.

  Args:
    cls: The message class (not used by this helper).
    field_descriptor: FieldDescriptor to annotate.
  """
  if hasattr(field_descriptor, '_encoder'):
    return

  repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  map_field = _IsMapField(field_descriptor)
  packed = field_descriptor.is_packed
  number = field_descriptor.number

  if map_field:
    new_encoder = encoder.MapEncoder(field_descriptor)
    new_sizer = encoder.MapSizer(
        field_descriptor, _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    new_encoder = encoder.MessageSetItemEncoder(number)
    new_sizer = encoder.MessageSetItemSizer(number)
  else:
    field_type = field_descriptor.type
    new_encoder = type_checkers.TYPE_TO_ENCODER[field_type](
        number, repeated, packed)
    new_sizer = type_checkers.TYPE_TO_SIZER[field_type](
        number, repeated, packed)

  # _encoder is assigned last because its presence is the "already done"
  # marker checked above.
  field_descriptor._sizer = new_sizer
  field_descriptor._encoder = new_encoder
312
313
def _MaybeAddDecoder(cls, field_descriptor):
  """Attaches wire-format decoders to `field_descriptor` on first use.

  Returns immediately if the descriptor already has `_decoders`.  Otherwise
  stores a dict on `field_descriptor._decoders` keyed by `is_packed`: the
  False entry always exists, and packable repeated fields also get a True
  entry for the packed encoding.

  Args:
    cls: The message class (not used by this helper).
    field_descriptor: FieldDescriptor to annotate.  Its `_default_constructor`
      must already be set (see _AttachFieldHelpers).
  """
  if hasattr(field_descriptor, '_decoders'):
    return

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_map_entry = _IsMapField(field_descriptor)
  helper_decoders = {}

  def AddDecoder(is_packed):
    decode_type = field_descriptor.type
    # Open (non-closed) enums are decoded as plain int32 values.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    # The trailing `not field_descriptor.has_presence` argument presumably
    # tells the decoder to drop default values for fields without presence,
    # mirroring the scalar setter -- NOTE(review): confirm in decoder.py.
    if is_map_entry:
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          _IsMessageMapField(field_descriptor))
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)

    helper_decoders[is_packed] = field_decoder

  AddDecoder(False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(True)

  field_descriptor._decoders = helper_decoders
364
365
366def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
367 extensions = descriptor.extensions_by_name
368 for extension_name, extension_field in extensions.items():
369 assert extension_name not in dictionary
370 dictionary[extension_name] = extension_field
371
372
def _AddEnumValues(descriptor, cls):
  """Publishes the enums nested in a message as class attributes.

  For each nested enum type, the class gets an EnumTypeWrapper under the
  enum's name, plus one integer attribute per enum value (named after the
  value) holding its number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  attributes = []
  for enum_desc in descriptor.enum_types:
    attributes.append(
        (enum_desc.name, enum_type_wrapper.EnumTypeWrapper(enum_desc)))
    attributes.extend(
        (value_desc.name, value_desc.number) for value_desc in enum_desc.values)

  # Applied in the same order as the enums and values were listed.
  for attr_name, attr_value in attributes:
    setattr(cls, attr_name, attr_value)
386
387
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for `field`.

  The factory takes the owning message and returns a MessageMap when the
  map's values are messages, or a ScalarMap (with a value type checker)
  otherwise.  Both are attached to the message's `_listener_for_children`.

  Args:
    field: FieldDescriptor of a map field.

  Raises:
    ValueError: if `field` is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)

    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)

  return MakeMessageMapDefault
409
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field:
    * map fields: an empty map container (see _GetInitializeDefaultForMap);
    * repeated message fields: an empty RepeatedCompositeFieldContainer;
    * repeated scalar fields: an empty RepeatedScalarFieldContainer;
    * singular message fields: a fresh sub-message wired to the parent's
      listener (or to a oneof listener for oneof members);
    * singular scalar fields: `field.default_value`.
  Container/sub-message defaults refer back to |message| through its
  `_listener_for_children` (or a _OneofListener).

  Raises:
    ValueError: if a repeated field declares a non-empty default value.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # field.message_type is therefore read lazily inside the factory.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized; building the class via
      # the message factory sets it.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
466
467
468def _ReraiseTypeErrorWithFieldName(message_name, field_name):
469 """Re-raise the currently-handled TypeError with the field name added."""
470 exc = sys.exc_info()[1]
471 if len(exc.args) == 1 and type(exc) is TypeError:
472 # simple TypeError; add field name to exception message
473 exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
474
475 # re-raise possibly-amended exception with original traceback:
476 raise exc.with_traceback(sys.exc_info()[2])
477
478
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets all per-instance bookkeeping and then applies
  each keyword argument to the field of the same name.  Keyword values of
  None are ignored.

  Args:
    message_descriptor: Descriptor of the message type whose fields may be
      passed as keyword arguments.
    cls: The class being constructed; its __init__ is replaced.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: if `value` is a string that names no value of enum_type.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError as exc:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value)) from exc
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Supplying any field up front invalidates the cached size.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      # _GetFieldByName raises ValueError for unknown names.
      field = _GetFieldByName(message_descriptor, field_name)
      if field is None:
        # Defensive only: the lookup above raises instead of returning None.
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        field_copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                field_copy[key].MergeFrom(field_value[key])
            else:
              field_copy.update(field_value)
          else:
            # Each element may be a message or a dict of constructor kwargs.
            for val in field_value:
              if isinstance(val, dict):
                field_copy.add(**val)
              else:
                field_copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          field_copy.extend(field_value)
        self._fields[field] = field_copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        field_copy = field._default_constructor(self)
        new_val = None
        if isinstance(field_value, message_mod.Message):
          new_val = field_value
        elif isinstance(field_value, dict):
          if field.message_type.full_name == _StructFullTypeName:
            field_copy.Clear()
            if len(field_value) == 1 and 'fields' in field_value:
              # Ambiguous: {'fields': ...} is first treated as Struct content
              # with a key named "fields"; if that fails it is treated as
              # constructor kwargs (presumably the Struct's own `fields`).
              try:
                field_copy.update(field_value)
              except Exception:  # pylint: disable=broad-except
                # Fall back to init normal message field
                field_copy.Clear()
                new_val = field.message_type._concrete_class(**field_value)
            else:
              field_copy.update(field_value)
          else:
            new_val = field.message_type._concrete_class(**field_value)
        elif hasattr(field_copy, '_internal_assign'):
          # Sub-messages that define _internal_assign (presumably well-known
          # types) accept native Python values directly.
          field_copy._internal_assign(field_value)
        else:
          raise TypeError(
              'Message field {0}.{1} must be initialized with a '
              'dict or instance of same class, got {2}.'.format(
                  message_descriptor.name,
                  field_name,
                  type(field_value).__name__,
              )
          )

        if new_val:
          try:
            field_copy.MergeFrom(new_val)
          except TypeError:
            _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = field_copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        # Go through the property setter so type checking and oneof
        # bookkeeping apply.
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
588
589
590def _GetFieldByName(message_descriptor, field_name):
591 """Returns a field descriptor by field name.
592
593 Args:
594 message_descriptor: A Descriptor describing all fields in message.
595 field_name: The name of the field to retrieve.
596 Returns:
597 The field descriptor associated with the field name.
598 """
599 try:
600 return message_descriptor.fields_by_name[field_name]
601 except KeyError:
602 raise ValueError('Protocol message %s has no "%s" field.' %
603 (message_descriptor.name, field_name))
604
605
def _AddPropertiesForFields(descriptor, cls):
  """Installs one property per declared field, plus `Extensions` if needed.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class being constructed.
  """
  for field_desc in descriptor.fields:
    _AddPropertiesForField(field_desc, cls)

  if not descriptor.is_extendable:
    return

  def _ExtensionsGetter(self):
    # _ExtensionDict is a stateless adaptor, so building a new one on every
    # access is fine.
    return _ExtensionDict(self)

  cls.Extensions = property(_ExtensionsGetter)
615
616
def _AddPropertiesForField(field, cls):
  """Adds the public property and `<NAME>_FIELD_NUMBER` constant for a field.

  Repeated fields, singular message fields and singular scalar fields each
  get their own kind of property (see the dedicated helpers).

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Trips if new C++ types are introduced without revisiting the dispatch
  # below.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    install = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    install = _AddPropertiesForNonRepeatedCompositeField
  else:
    install = _AddPropertiesForNonRepeatedScalarField
  install(field, cls)
640
641
642class _FieldProperty(property):
643 __slots__ = ('DESCRIPTOR',)
644
645 def __init__(self, descriptor, getter, setter, doc):
646 property.__init__(self, getter, setter, doc=doc)
647 self.DESCRIPTOR = descriptor
648
649
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a read-only property for a repeated field.

  Reading the property returns the field's container (a
  RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer, created
  on first access).  Assigning to it raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  field_name = field.name
  attr_name = _PropertyName(field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    # Build the container lazily.  setdefault() is relied on to be atomic
    # (true in CPython): if another thread stored a container first, that
    # one is returned and ours is discarded.  The same assumption is made
    # elsewhere in this module.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  # The setter exists only to give a clearer error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % field_name
  setattr(cls, attr_name, _FieldProperty(field, getter, setter, doc=doc))
692
693
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a read/write property for a singular scalar field.

  Reads return the stored value or the field's default.  Writes are
  type-checked; for fields without presence a falsy value is dropped
  instead of stored.  The size cache is invalidated, and oneof members also
  update the oneof bookkeeping.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  field_name = field.name
  attr_name = _PropertyName(field_name)
  checker = type_checkers.GetTypeChecker(field)
  fallback = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, fallback)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked = checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Without presence, any falsy value (0, 0.0, enum 0, False, ...) counts as
    # the default and is removed rather than stored.
    if field.has_presence or checked:
      self._fields[field] = checked
    else:
      self._fields.pop(field, None)
    # Test the dirty flag inline: scalar setters are hot, so skip the call
    # when the cache is already dirty.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      # Presumably records this field as the oneof's active member; the
      # helper is defined elsewhere.
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % field_name

  doc = 'Magic attribute generated for "%s" proto field.' % field_name
  setattr(cls, attr_name, _FieldProperty(field, getter, setter, doc=doc))
748
749
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a property for a singular message ("composite") field.

  Reading the property returns the sub-message, creating it on first access.
  Assignment is only supported for Timestamp, Duration, Struct and ListValue
  fields, whose sub-message is filled from the assigned Python value; any
  other assignment raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  field_name = field.name
  attr_name = _PropertyName(field_name)

  def getter(self):
    submessage = self._fields.get(field)
    if submessage is not None:
      return submessage
    # Create lazily; setdefault() is relied on to be atomic (true in CPython)
    # so a concurrently stored sub-message wins and ours is discarded.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  def setter(self, new_value):
    type_name = field.message_type.full_name
    if type_name not in ('google.protobuf.Timestamp',
                         'google.protobuf.Duration',
                         _StructFullTypeName,
                         _ListValueFullTypeName):
      raise AttributeError(
          'Assignment not allowed to composite field '
          '"%s" in protocol message object.' % field_name
      )
    target = getter(self)
    if type_name == 'google.protobuf.Timestamp':
      target.FromDatetime(new_value)
    elif type_name == 'google.protobuf.Duration':
      target.FromTimedelta(new_value)
    elif type_name == _StructFullTypeName:
      target.Clear()
      target.update(new_value)
    else:  # ListValue
      target.Clear()
      target.extend(new_value)

  doc = 'Magic attribute generated for "%s" proto field.' % field_name
  setattr(cls, attr_name, _FieldProperty(field, getter, setter, doc=doc))
809
810
811def _AddPropertiesForExtensions(descriptor, cls):
812 """Adds properties for all fields in this protocol message type."""
813 extensions = descriptor.extensions_by_name
814 for extension_name, extension_field in extensions.items():
815 constant_name = extension_name.upper() + '_FIELD_NUMBER'
816 setattr(cls, constant_name, extension_field.number)
817
818 # TODO: Migrate all users of these attributes to functions like
819 # pool.FindExtensionByNumber(descriptor).
820 if descriptor.file is not None:
821 # TODO: Use cls.MESSAGE_FACTORY.pool when available.
822 pool = descriptor.file.pool
823
824def _AddStaticMethods(cls):
825 def FromString(s):
826 message = cls()
827 message.MergeFromString(s)
828 return message
829 cls.FromString = staticmethod(FromString)
830
831
def _IsPresent(item):
  """Decides whether a `_fields` entry belongs in ListFields() output.

  Args:
    item: A (FieldDescriptor, value) pair taken from a message's `_fields`.

  Returns:
    For repeated fields, whether the container is non-empty; for singular
    message fields, the sub-message's `_is_present_in_parent` flag; for
    singular scalars, always True.
  """
  field, value = item[0], item[1]
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
842
843
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def _FieldNumber(item):
    return item[0].number

  def ListFields(self):
    # Present entries only, ordered by field number (sorted() is stable).
    return sorted(
        (item for item in self._fields.items() if _IsPresent(item)),
        key=_FieldNumber)

  cls.ListFields = ListFields
853
854
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs HasField() on cls. Only singular fields that track presence and
  oneof names are accepted; any other name raises ValueError.
  """
  # Singular fields with presence (for proto3: sub-messages and fields inside
  # a oneof), followed by every oneof of the message.
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED and field.has_presence
  }
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is set exactly when its currently recorded member is set.
      try:
        active_field = self._oneofs[field]
      except KeyError:
        return False
      return HasField(self, active_field.name)

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    sub_message = self._fields.get(field)
    return sub_message is not None and sub_message._is_present_in_parent

  cls.HasField = HasField
892
893
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs ClearField() on cls. ClearField accepts either a field name or a
  oneof name; for a oneof name it clears whichever field is currently
  recorded as active in that oneof. If no member of the oneof is active it
  returns immediately, without calling _Modified().

  Raises (from ClearField):
    ValueError: if field_name names neither a field nor a oneof.
  """
  def ClearField(self, field_name):
    # Resolve field_name: first as a regular field, then as a oneof name.
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          # Clear the member that is currently active in the oneof.
          field = self._oneofs[field]
        else:
          # Nothing is set in this oneof: return without marking modified.
          return
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us. That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size. Big deal.
      del self._fields[field]

      # Forget the oneof's active member if it is the field just cleared.
      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
930
931
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  Installs ClearExtension(), which drops any stored value for the extension
  and always marks the message as modified.
  """

  def ClearExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    # Mirrors ClearField(): remove the stored value (if any), then mark dirty.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
942
943
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  Installs HasExtension(); repeated extensions are rejected with KeyError.
  """

  def HasExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields
    # A message extension counts only once it is present in this parent.
    stored = self._fields.get(field_descriptor)
    return stored is not None and stored._is_present_in_parent

  cls.HasExtension = HasExtension
957
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and needs to find the message type in the descriptor
  pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None if msg has an empty type_url or the named
    type cannot be found in the default descriptor pool.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO: For now we just strip the hostname. Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The pool signals an unknown type with KeyError; report "cannot unpack"
    # so callers such as __eq__ fall back to comparing the raw Any fields.
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
996
997
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs __eq__. Besides message-to-message comparison, a ListValue compares
  against a Python list and a Struct against a Python dict. Messages of a
  different type yield NotImplemented.
  """
  def __eq__(self, other):
    if self.DESCRIPTOR.full_name == _ListValueFullTypeName and isinstance(
        other, list
    ):
      return self._internal_compare(other)
    if self.DESCRIPTOR.full_name == _StructFullTypeName and isinstance(
        other, dict
    ):
      return self._internal_compare(other)

    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return NotImplemented

    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      any_a = _InternalUnpackAny(self)
      any_b = _InternalUnpackAny(other)
      # Test for None explicitly: a message's truthiness is not a presence
      # check (container-like well-known types may define __len__, making an
      # empty unpacked payload falsy), and _InternalUnpackAny signals
      # "cannot unpack" with None.
      if any_a is not None and any_b is not None:
        return any_a == any_b

    if self.ListFields() != other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    # Unknown fields are compared order-insensitively.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1035
1036
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Makes str(message) return the message rendered in protobuf text format.
  """

  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
1042
1043
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Makes repr(message) return the same text-format rendering as str().
  """

  def __repr__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
1049
1050
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs __unicode__, which returns the text-format rendering of the
  message (requested with as_utf8=True) as a str.
  """

  def __unicode__(self):
    text = text_format.MessageToString(self, as_utf8=True)
    # On Python 3 MessageToString() already returns str, which has no
    # .decode(); only decode when an implementation hands back bytes.
    if isinstance(text, bytes):
      text = text.decode('utf-8')
    return text

  cls.__unicode__ = __unicode__
1057
1058
def _AddContainsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs __contains__. For Struct it tests key membership in `fields`, for
  ListValue it tests value membership in items(), and for every other message
  `name in msg` is equivalent to msg.HasField(name).
  """
  full_name = message_descriptor.full_name
  if full_name == _StructFullTypeName:
    def __contains__(self, key):
      return key in self.fields
  elif full_name == _ListValueFullTypeName:
    def __contains__(self, value):
      return value in self.items()
  else:
    def __contains__(self, field):
      return self.HasField(field)

  cls.__contains__ = __contains__
1072
1073
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The size computed by the sizer registered for field_type.

  Raises:
    message_mod.EncodeError: If no sizer is registered for field_type.
  """
  # Only the table lookup is guarded, so a KeyError raised inside the sizer
  # itself is not misreported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError(
        'Unrecognized field type: %d' % field_type) from None
  return fn(field_number, value)
1092
1093
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs ByteSize(), which returns the serialized size of the message and
  caches it until the message is marked dirty again.
  """

  def ByteSize(self):
    if self._cached_byte_size_dirty:
      descriptor = self.DESCRIPTOR
      total = 0
      if descriptor._is_map_entry:
        # Both fields of a map entry are always serialized, key first.
        for entry_name in ('key', 'value'):
          entry_field = descriptor.fields_by_name[entry_name]
          _MaybeAddEncoder(cls, entry_field)
          total += entry_field._sizer(getattr(self, entry_name))
      else:
        for field_descriptor, field_value in self.ListFields():
          _MaybeAddEncoder(cls, field_descriptor)
          total += field_descriptor._sizer(field_value)
        total += sum(len(tag) + len(payload)
                     for tag, payload in self._unknown_fields)

      self._cached_byte_size = total
      self._cached_byte_size_dirty = False
      self._listener_for_children.dirty = False
      return total
    return self._cached_byte_size

  cls.ByteSize = ByteSize
1124
1125
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializeToString(), which refuses to serialize a message with
  missing required fields and otherwise defers to SerializePartialToString().
  """

  def SerializeToString(self, **kwargs):
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name,
            ','.join(self.FindInitializationErrors())))

  cls.SerializeToString = SerializeToString
1137
1138
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() and the _InternalSerialize() worker it
  uses. Neither checks that required fields are set.
  """

  def SerializePartialToString(self, **kwargs):
    out_stream = BytesIO()
    self._InternalSerialize(out_stream.write, **kwargs)
    return out_stream.getvalue()

  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes, deterministic=None):
    # None means "use the process-wide default"; anything else is coerced.
    if deterministic is not None:
      deterministic = bool(deterministic)
    else:
      deterministic = (
          api_implementation.IsPythonDefaultSerializationDeterministic())

    descriptor = self.DESCRIPTOR
    if not descriptor._is_map_entry:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        field_descriptor._encoder(write_bytes, field_value, deterministic)
      for tag_bytes, value_bytes in self._unknown_fields:
        write_bytes(tag_bytes)
        write_bytes(value_bytes)
      return

    # Both fields of a map entry are always written, key first.
    for entry_name in ('key', 'value'):
      entry_field = descriptor.fields_by_name[entry_name]
      _MaybeAddEncoder(cls, entry_field)
      entry_field._encoder(write_bytes, getattr(self, entry_name),
                           deterministic)

  cls._InternalSerialize = InternalSerialize
1172
1173
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the _InternalParse() worker it drives.
  """
  def MergeFromString(self, serialized):
    # Wrap once in a memoryview so _InternalParse can slice without copying.
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      # NOTE(review): this also turns any unrelated TypeError raised while
      # parsing into "Truncated message.".
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length  # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bound once per class so the parse loop below only touches locals.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses serialized bytes in buffer[pos:end] into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position at which parsing stopped. This equals `end` unless
      an unknown field could not be decoded/skipped (the decoder helpers
      returned -1), in which case the position of that field's tag is
      returned.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and it's easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      # Tags with a dedicated decoder in message_set_decoders_by_tag
      # (presumably MessageSet items) are dispatched first.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        continue
      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is None:
        # Unknown tag: validate it and keep its raw bytes in _unknown_fields.
        # _unknown_fields may be an empty tuple (see _Clear), so make sure it
        # is a list before appending.
        if not self._unknown_fields:  # pylint: disable=protected-access
          self._unknown_fields = []  # pylint: disable=protected-access
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO: remove old_pos.
        old_pos = new_pos
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        # TODO: remove _unknown_fields.
        # The field is scanned a second time from old_pos to locate the end
        # of its raw bytes; `data` decoded above is not used here.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        # Known field: decoders are attached lazily on first use.
        _MaybeAddDecoder(cls, field_des)
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
    return pos
  cls._InternalParse = InternalParse
1252
1253
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  # Computed once per class: the fields declared LABEL_REQUIRED.
  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    # Pass 1: every required field of this message must be present.
    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Pass 2: recurse into stored sub-messages (singular and repeated).
    # NOTE(review): map fields are skipped here, whereas
    # FindInitializationErrors() below does inspect message-valued map
    # entries; confirm this asymmetry is intended.
    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          if (field.message_type._is_map_entry):
            continue
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = []  # simplify things

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    # Recurse into present sub-messages, prefixing their error paths with
    # this field's name (extensions are written as "(full.name)").
    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.is_extension:
          name = '(%s)' % field.full_name
        else:
          name = field.name

        if _IsMapField(field):
          if _IsMessageMapField(field):
            for key in value:
              element = value[key]
              prefix = '%s[%s].' % (name, key)
              sub_errors = element.FindInitializationErrors()
              errors += [prefix + error for error in sub_errors]
          else:
            # ScalarMaps can't have any initialization errors.
            pass
        elif field.label == _FieldDescriptor.LABEL_REPEATED:
          for i in range(len(value)):
            element = value[i]
            prefix = '%s[%d].' % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + '.'
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1346
1347
def _FullyQualifiedClassName(klass):
  """Returns '<module>.<qualified name>' for klass.

  Builtin classes (module None, 'builtins' or '__builtin__') get just their
  name. __qualname__ is preferred, falling back to __name__.
  """
  name = getattr(klass, '__qualname__', klass.__name__)
  module = klass.__module__
  if module not in (None, 'builtins', '__builtin__'):
    return module + '.' + name
  return name
1354
1355
def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): installs MergeFrom() on cls."""
  # Hoisted so the per-field loop in MergeFrom reads closure variables.
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the stored fields and unknown fields of msg into self.

    Repeated fields are merged through the container's MergeFrom(), singular
    sub-messages are merged recursively only if present in msg, and singular
    scalar values overwrite the current value.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    # Self-merge is unsupported; note this is only checked when assertions
    # are enabled (not under `python -O`).
    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            # Construct a new object to represent this field.
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
      else:
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      # _unknown_fields may be an empty tuple (see _Clear); replace it with a
      # list before extending.
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1399
1400
def _AddWhichOneofMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs WhichOneof() on cls."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1417
1418
def _Clear(self):
  """Message.Clear(): drops all fields, unknown fields and oneof state.

  The message is then marked as modified.
  """
  self._fields, self._unknown_fields, self._oneofs = {}, (), {}
  self._Modified()
1426
1427
def _UnknownFields(self):
  """Always raises; unknown fields are exposed via unknown_fields instead.

  Raises:
    NotImplementedError: always, pointing callers at
      unknown_fields.UnknownFieldSet(message).
  """
  raise NotImplementedError('Please use the add-on feature '
                            'unknown_fields.UnknownFieldSet(message) in '
                            'unknown_fields.py instead.')
1432
1433
def _DiscardUnknownFields(self):
  """Message.DiscardUnknownFields(): drops unknown fields recursively.

  Clears this message's unknown fields, then recurses into every present
  sub-message (singular, repeated, and values of message-valued maps).
  """
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Only maps whose values are messages hold anything to visit.
      if _IsMessageMapField(field):
        for map_key in value:
          value[map_key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
1447
1448
def _SetListener(self, listener):
  """Stores listener as self._listener, substituting a null listener for None."""
  self._listener = (
      message_listener_mod.NullMessageListener()
      if listener is None else listener)
1454
1455
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)
  if message_descriptor.is_extendable:
    # Extension accessors are only installed on extendable messages.
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)
  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddContainsMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)
  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)
  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1480
1481
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls.

  Installs _Modified (also exposed as SetInParent) and _UpdateOneofState.
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note: Some callers check _cached_byte_size_dirty before calling
    # _Modified() as an extra optimization. So, if this method is ever
    # changed such that it does stuff even when _cached_byte_size_dirty is
    # already true, the callers need to be updated.
    if not self._cached_byte_size_dirty:
      self._cached_byte_size_dirty = True
      self._listener_for_children.dirty = True
      self._is_present_in_parent = True
      self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    other_field = self._oneofs.setdefault(field.containing_oneof, field)
    if other_field is not field:
      # The previously recorded member may already be absent from _fields
      # (e.g. presumably when a sub-message that was displaced from the oneof
      # is modified later through a reference the caller kept), so tolerate
      # a missing key instead of raising KeyError.
      self._fields.pop(other_field, None)
      self._oneofs[field.containing_oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1514
1515
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # Keep only a weak back reference from child to parent so a finished
    # parent tree does not turn into cyclic garbage. A value that is already
    # a weakref proxy is stored as-is instead of being wrapped again.
    already_proxy = isinstance(parent_message, weakref.ProxyType)
    self._parent_message_weakref = (
        parent_message if already_proxy else weakref.proxy(parent_message))

    # Cached "parent is already dirty" flag: lets Modified() skip walking up
    # the tree in the common case.
    self.dirty = False

  def Modified(self):
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # The parent has been garbage-collected while a client kept a
        # reference to this child and keeps mutating it; not an error.
        pass
1560
1561
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Record this field as the oneof's active member, then run the usual
      # dirty propagation.
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      pass