1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Provides DescriptorPool to use as a container for proto2 descriptors.
9
The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
11a collection of protocol buffer descriptors for use when dynamically creating
12message types at runtime.
13
14For most applications protocol buffers should be used via modules generated by
15the protocol buffer compiler tool. This should only be used when the type of
16protocol buffers used in an application or library cannot be predetermined.
17
Below is a straightforward example of how to use this class::

  pool = DescriptorPool()
  file_descriptor_protos = [ ... ]
  for file_descriptor_proto in file_descriptor_protos:
    pool.Add(file_descriptor_proto)
  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
25
26The message descriptor can be used in conjunction with the message_factory
27module in order to create a protocol buffer class that can be encoded and
28decoded.
29
30If you want to get a Python class for the specified proto, use the
31helper functions inside google.protobuf.message_factory
32directly instead of this class.
33"""
34
# Module author attribution.
__author__ = 'matthewtoia@google.com (Matt Toia)'
36
37import collections
38import threading
39import warnings
40
41from google.protobuf import descriptor
42from google.protobuf import descriptor_database
43from google.protobuf import text_encoding
44from google.protobuf.internal import python_edition_defaults
45from google.protobuf.internal import python_message
46
# Mirrors descriptor._USE_C_DESCRIPTORS. When truthy, DescriptorPool defines a
# __new__ that returns descriptor._message.DescriptorPool instead of building
# the pure-Python pool implemented in this module.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
48
49
50def _NormalizeFullyQualifiedName(name):
51 """Remove leading period from fully-qualified type name.
52
53 Due to b/13860351 in descriptor_database.py, types in the root namespace are
54 generated with a leading period. This function removes that prefix.
55
56 Args:
57 name (str): The fully-qualified symbol name.
58
59 Returns:
60 str: The normalized fully-qualified symbol name.
61 """
62 return name.lstrip('.')
63
64
65def _OptionsOrNone(descriptor_proto):
66 """Returns the value of the field `options`, or None if it is not set."""
67 if descriptor_proto.HasField('options'):
68 return descriptor_proto.options
69 else:
70 return None
71
72
73def _IsMessageSetExtension(field):
74 return (field.is_extension and
75 field.containing_type.has_options and
76 field.containing_type.GetOptions().message_set_wire_format and
77 field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
78 not field.is_required and
79 not field.is_repeated)
80
# Module-level lock shared by every DescriptorPool instance. It guards the lazy
# parsing of `_edition_defaults` in _CreateDefaultFeatures and the
# check-then-insert on `_feature_cache` in _InternFeatures.
_edition_defaults_lock = threading.Lock()
82
83
84class DescriptorPool(object):
85 """A collection of protobufs dynamically constructed by descriptor protos."""
86
87 if _USE_C_DESCRIPTORS:
88
89 def __new__(cls, descriptor_db=None):
90 # pylint: disable=protected-access
91 return descriptor._message.DescriptorPool(descriptor_db)
92
93 def __init__(
94 self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
95 ):
96 """Initializes a Pool of proto buffs.
97
98 The descriptor_db argument to the constructor is provided to allow
99 specialized file descriptor proto lookup code to be triggered on demand. An
100 example would be an implementation which will read and compile a file
101 specified in a call to FindFileByName() and not require the call to Add()
102 at all. Results from this database will be cached internally here as well.
103
104 Args:
105 descriptor_db: A secondary source of file descriptors.
106 use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
107 C++.
108 """
109
110 self._internal_db = descriptor_database.DescriptorDatabase()
111 self._descriptor_db = descriptor_db
112 self._descriptors = {}
113 self._enum_descriptors = {}
114 self._service_descriptors = {}
115 self._file_descriptors = {}
116 self._toplevel_extensions = {}
117 self._top_enum_values = {}
118 # We store extensions in two two-level mappings: The first key is the
119 # descriptor of the message being extended, the second key is the extension
120 # full name or its tag number.
121 self._extensions_by_name = collections.defaultdict(dict)
122 self._extensions_by_number = collections.defaultdict(dict)
123 self._serialized_edition_defaults = (
124 python_edition_defaults._PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS
125 )
126 self._edition_defaults = None
127 self._feature_cache = dict()
128
129 def _CheckConflictRegister(self, desc, desc_name, file_name):
130 """Check if the descriptor name conflicts with another of the same name.
131
132 Args:
133 desc: Descriptor of a message, enum, service, extension or enum value.
134 desc_name (str): the full name of desc.
135 file_name (str): The file name of descriptor.
136 """
137 for register, descriptor_type in [
138 (self._descriptors, descriptor.Descriptor),
139 (self._enum_descriptors, descriptor.EnumDescriptor),
140 (self._service_descriptors, descriptor.ServiceDescriptor),
141 (self._toplevel_extensions, descriptor.FieldDescriptor),
142 (self._top_enum_values, descriptor.EnumValueDescriptor)]:
143 if desc_name in register:
144 old_desc = register[desc_name]
145 if isinstance(old_desc, descriptor.EnumValueDescriptor):
146 old_file = old_desc.type.file.name
147 else:
148 old_file = old_desc.file.name
149
150 if not isinstance(desc, descriptor_type) or (
151 old_file != file_name):
152 error_msg = ('Conflict register for file "' + file_name +
153 '": ' + desc_name +
154 ' is already defined in file "' +
155 old_file + '". Please fix the conflict by adding '
156 'package name on the proto file, or use different '
157 'name for the duplication.')
158 if isinstance(desc, descriptor.EnumValueDescriptor):
159 error_msg += ('\nNote: enum values appear as '
160 'siblings of the enum type instead of '
161 'children of it.')
162
163 raise TypeError(error_msg)
164
165 return
166
167 def Add(self, file_desc_proto):
168 """Adds the FileDescriptorProto and its types to this pool.
169
170 Args:
171 file_desc_proto (FileDescriptorProto): The file descriptor to add.
172 """
173
174 self._internal_db.Add(file_desc_proto)
175
176 def AddSerializedFile(self, serialized_file_desc_proto):
177 """Adds the FileDescriptorProto and its types to this pool.
178
179 Args:
180 serialized_file_desc_proto (bytes): A bytes string, serialization of the
181 :class:`FileDescriptorProto` to add.
182
183 Returns:
184 FileDescriptor: Descriptor for the added file.
185 """
186
187 # pylint: disable=g-import-not-at-top
188 from google.protobuf import descriptor_pb2
189 file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
190 serialized_file_desc_proto)
191 file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
192 file_desc.serialized_pb = serialized_file_desc_proto
193 return file_desc
194
195 # Never call this method. It is for internal usage only.
196 def _AddDescriptor(self, desc):
197 """Adds a Descriptor to the pool, non-recursively.
198
199 If the Descriptor contains nested messages or enums, the caller must
200 explicitly register them. This method also registers the FileDescriptor
201 associated with the message.
202
203 Args:
204 desc: A Descriptor.
205 """
206 if not isinstance(desc, descriptor.Descriptor):
207 raise TypeError('Expected instance of descriptor.Descriptor.')
208
209 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
210
211 self._descriptors[desc.full_name] = desc
212 self._AddFileDescriptor(desc.file)
213
214 # Never call this method. It is for internal usage only.
215 def _AddEnumDescriptor(self, enum_desc):
216 """Adds an EnumDescriptor to the pool.
217
218 This method also registers the FileDescriptor associated with the enum.
219
220 Args:
221 enum_desc: An EnumDescriptor.
222 """
223
224 if not isinstance(enum_desc, descriptor.EnumDescriptor):
225 raise TypeError('Expected instance of descriptor.EnumDescriptor.')
226
227 file_name = enum_desc.file.name
228 self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
229 self._enum_descriptors[enum_desc.full_name] = enum_desc
230
231 # Top enum values need to be indexed.
232 # Count the number of dots to see whether the enum is toplevel or nested
233 # in a message. We cannot use enum_desc.containing_type at this stage.
234 if enum_desc.file.package:
235 top_level = (enum_desc.full_name.count('.')
236 - enum_desc.file.package.count('.') == 1)
237 else:
238 top_level = enum_desc.full_name.count('.') == 0
239 if top_level:
240 file_name = enum_desc.file.name
241 package = enum_desc.file.package
242 for enum_value in enum_desc.values:
243 full_name = _NormalizeFullyQualifiedName(
244 '.'.join((package, enum_value.name)))
245 self._CheckConflictRegister(enum_value, full_name, file_name)
246 self._top_enum_values[full_name] = enum_value
247 self._AddFileDescriptor(enum_desc.file)
248
249 # Never call this method. It is for internal usage only.
250 def _AddServiceDescriptor(self, service_desc):
251 """Adds a ServiceDescriptor to the pool.
252
253 Args:
254 service_desc: A ServiceDescriptor.
255 """
256
257 if not isinstance(service_desc, descriptor.ServiceDescriptor):
258 raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
259
260 self._CheckConflictRegister(service_desc, service_desc.full_name,
261 service_desc.file.name)
262 self._service_descriptors[service_desc.full_name] = service_desc
263
264 # Never call this method. It is for internal usage only.
265 def _AddExtensionDescriptor(self, extension):
266 """Adds a FieldDescriptor describing an extension to the pool.
267
268 Args:
269 extension: A FieldDescriptor.
270
271 Raises:
272 AssertionError: when another extension with the same number extends the
273 same message.
274 TypeError: when the specified extension is not a
275 descriptor.FieldDescriptor.
276 """
277 if not (isinstance(extension, descriptor.FieldDescriptor) and
278 extension.is_extension):
279 raise TypeError('Expected an extension descriptor.')
280
281 if extension.extension_scope is None:
282 self._CheckConflictRegister(
283 extension, extension.full_name, extension.file.name)
284 self._toplevel_extensions[extension.full_name] = extension
285
286 try:
287 existing_desc = self._extensions_by_number[
288 extension.containing_type][extension.number]
289 except KeyError:
290 pass
291 else:
292 if extension is not existing_desc:
293 raise AssertionError(
294 'Extensions "%s" and "%s" both try to extend message type "%s" '
295 'with field number %d.' %
296 (extension.full_name, existing_desc.full_name,
297 extension.containing_type.full_name, extension.number))
298
299 self._extensions_by_number[extension.containing_type][
300 extension.number] = extension
301 self._extensions_by_name[extension.containing_type][
302 extension.full_name] = extension
303
304 # Also register MessageSet extensions with the type name.
305 if _IsMessageSetExtension(extension):
306 self._extensions_by_name[extension.containing_type][
307 extension.message_type.full_name] = extension
308
309 if hasattr(extension.containing_type, '_concrete_class'):
310 python_message._AttachFieldHelpers(
311 extension.containing_type._concrete_class, extension)
312
313 # Never call this method. It is for internal usage only.
314 def _InternalAddFileDescriptor(self, file_desc):
315 """Adds a FileDescriptor to the pool, non-recursively.
316
317 If the FileDescriptor contains messages or enums, the caller must explicitly
318 register them.
319
320 Args:
321 file_desc: A FileDescriptor.
322 """
323
324 self._AddFileDescriptor(file_desc)
325
326 def _AddFileDescriptor(self, file_desc):
327 """Adds a FileDescriptor to the pool, non-recursively.
328
329 If the FileDescriptor contains messages or enums, the caller must explicitly
330 register them.
331
332 Args:
333 file_desc: A FileDescriptor.
334 """
335
336 if not isinstance(file_desc, descriptor.FileDescriptor):
337 raise TypeError('Expected instance of descriptor.FileDescriptor.')
338 self._file_descriptors[file_desc.name] = file_desc
339
340 def FindFileByName(self, file_name):
341 """Gets a FileDescriptor by file name.
342
343 Args:
344 file_name (str): The path to the file to get a descriptor for.
345
346 Returns:
347 FileDescriptor: The descriptor for the named file.
348
349 Raises:
350 KeyError: if the file cannot be found in the pool.
351 """
352
353 try:
354 return self._file_descriptors[file_name]
355 except KeyError:
356 pass
357
358 try:
359 file_proto = self._internal_db.FindFileByName(file_name)
360 except KeyError as error:
361 if self._descriptor_db:
362 file_proto = self._descriptor_db.FindFileByName(file_name)
363 else:
364 raise error
365 if not file_proto:
366 raise KeyError('Cannot find a file named %s' % file_name)
367 return self._ConvertFileProtoToFileDescriptor(file_proto)
368
369 def FindFileContainingSymbol(self, symbol):
370 """Gets the FileDescriptor for the file containing the specified symbol.
371
372 Args:
373 symbol (str): The name of the symbol to search for.
374
375 Returns:
376 FileDescriptor: Descriptor for the file that contains the specified
377 symbol.
378
379 Raises:
380 KeyError: if the file cannot be found in the pool.
381 """
382
383 symbol = _NormalizeFullyQualifiedName(symbol)
384 try:
385 return self._InternalFindFileContainingSymbol(symbol)
386 except KeyError:
387 pass
388
389 try:
390 # Try fallback database. Build and find again if possible.
391 self._FindFileContainingSymbolInDb(symbol)
392 return self._InternalFindFileContainingSymbol(symbol)
393 except KeyError:
394 raise KeyError('Cannot find a file containing %s' % symbol)
395
396 def _InternalFindFileContainingSymbol(self, symbol):
397 """Gets the already built FileDescriptor containing the specified symbol.
398
399 Args:
400 symbol (str): The name of the symbol to search for.
401
402 Returns:
403 FileDescriptor: Descriptor for the file that contains the specified
404 symbol.
405
406 Raises:
407 KeyError: if the file cannot be found in the pool.
408 """
409 try:
410 return self._descriptors[symbol].file
411 except KeyError:
412 pass
413
414 try:
415 return self._enum_descriptors[symbol].file
416 except KeyError:
417 pass
418
419 try:
420 return self._service_descriptors[symbol].file
421 except KeyError:
422 pass
423
424 try:
425 return self._top_enum_values[symbol].type.file
426 except KeyError:
427 pass
428
429 try:
430 return self._toplevel_extensions[symbol].file
431 except KeyError:
432 pass
433
434 # Try fields, enum values and nested extensions inside a message.
435 top_name, _, sub_name = symbol.rpartition('.')
436 try:
437 message = self.FindMessageTypeByName(top_name)
438 assert (sub_name in message.extensions_by_name or
439 sub_name in message.fields_by_name or
440 sub_name in message.enum_values_by_name)
441 return message.file
442 except (KeyError, AssertionError):
443 raise KeyError('Cannot find a file containing %s' % symbol)
444
445 def FindMessageTypeByName(self, full_name):
446 """Loads the named descriptor from the pool.
447
448 Args:
449 full_name (str): The full name of the descriptor to load.
450
451 Returns:
452 Descriptor: The descriptor for the named type.
453
454 Raises:
455 KeyError: if the message cannot be found in the pool.
456 """
457
458 full_name = _NormalizeFullyQualifiedName(full_name)
459 if full_name not in self._descriptors:
460 self._FindFileContainingSymbolInDb(full_name)
461 return self._descriptors[full_name]
462
463 def FindEnumTypeByName(self, full_name):
464 """Loads the named enum descriptor from the pool.
465
466 Args:
467 full_name (str): The full name of the enum descriptor to load.
468
469 Returns:
470 EnumDescriptor: The enum descriptor for the named type.
471
472 Raises:
473 KeyError: if the enum cannot be found in the pool.
474 """
475
476 full_name = _NormalizeFullyQualifiedName(full_name)
477 if full_name not in self._enum_descriptors:
478 self._FindFileContainingSymbolInDb(full_name)
479 return self._enum_descriptors[full_name]
480
481 def FindFieldByName(self, full_name):
482 """Loads the named field descriptor from the pool.
483
484 Args:
485 full_name (str): The full name of the field descriptor to load.
486
487 Returns:
488 FieldDescriptor: The field descriptor for the named field.
489
490 Raises:
491 KeyError: if the field cannot be found in the pool.
492 """
493 full_name = _NormalizeFullyQualifiedName(full_name)
494 message_name, _, field_name = full_name.rpartition('.')
495 message_descriptor = self.FindMessageTypeByName(message_name)
496 return message_descriptor.fields_by_name[field_name]
497
498 def FindOneofByName(self, full_name):
499 """Loads the named oneof descriptor from the pool.
500
501 Args:
502 full_name (str): The full name of the oneof descriptor to load.
503
504 Returns:
505 OneofDescriptor: The oneof descriptor for the named oneof.
506
507 Raises:
508 KeyError: if the oneof cannot be found in the pool.
509 """
510 full_name = _NormalizeFullyQualifiedName(full_name)
511 message_name, _, oneof_name = full_name.rpartition('.')
512 message_descriptor = self.FindMessageTypeByName(message_name)
513 return message_descriptor.oneofs_by_name[oneof_name]
514
515 def FindExtensionByName(self, full_name):
516 """Loads the named extension descriptor from the pool.
517
518 Args:
519 full_name (str): The full name of the extension descriptor to load.
520
521 Returns:
522 FieldDescriptor: The field descriptor for the named extension.
523
524 Raises:
525 KeyError: if the extension cannot be found in the pool.
526 """
527 full_name = _NormalizeFullyQualifiedName(full_name)
528 try:
529 # The proto compiler does not give any link between the FileDescriptor
530 # and top-level extensions unless the FileDescriptorProto is added to
531 # the DescriptorDatabase, but this can impact memory usage.
532 # So we registered these extensions by name explicitly.
533 return self._toplevel_extensions[full_name]
534 except KeyError:
535 pass
536 message_name, _, extension_name = full_name.rpartition('.')
537 try:
538 # Most extensions are nested inside a message.
539 scope = self.FindMessageTypeByName(message_name)
540 except KeyError:
541 # Some extensions are defined at file scope.
542 scope = self._FindFileContainingSymbolInDb(full_name)
543 return scope.extensions_by_name[extension_name]
544
545 def FindExtensionByNumber(self, message_descriptor, number):
546 """Gets the extension of the specified message with the specified number.
547
548 Extensions have to be registered to this pool by calling :func:`Add` or
549 :func:`AddExtensionDescriptor`.
550
551 Args:
552 message_descriptor (Descriptor): descriptor of the extended message.
553 number (int): Number of the extension field.
554
555 Returns:
556 FieldDescriptor: The descriptor for the extension.
557
558 Raises:
559 KeyError: when no extension with the given number is known for the
560 specified message.
561 """
562 try:
563 return self._extensions_by_number[message_descriptor][number]
564 except KeyError:
565 self._TryLoadExtensionFromDB(message_descriptor, number)
566 return self._extensions_by_number[message_descriptor][number]
567
568 def FindAllExtensions(self, message_descriptor):
569 """Gets all the known extensions of a given message.
570
571 Extensions have to be registered to this pool by build related
572 :func:`Add` or :func:`AddExtensionDescriptor`.
573
574 Args:
575 message_descriptor (Descriptor): Descriptor of the extended message.
576
577 Returns:
578 list[FieldDescriptor]: Field descriptors describing the extensions.
579 """
580 # Fallback to descriptor db if FindAllExtensionNumbers is provided.
581 if self._descriptor_db and hasattr(
582 self._descriptor_db, 'FindAllExtensionNumbers'):
583 full_name = message_descriptor.full_name
584 try:
585 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
586 except:
587 pass
588 else:
589 if isinstance(all_numbers, list):
590 for number in all_numbers:
591 if number in self._extensions_by_number[message_descriptor]:
592 continue
593 self._TryLoadExtensionFromDB(message_descriptor, number)
594 else:
595 warnings.warn(
596 'FindAllExtensionNumbers() on fall back DB must return a list,'
597 ' not {0}'.format(type(all_numbers))
598 )
599
600 return list(self._extensions_by_number[message_descriptor].values())
601
602 def _TryLoadExtensionFromDB(self, message_descriptor, number):
603 """Try to Load extensions from descriptor db.
604
605 Args:
606 message_descriptor: descriptor of the extended message.
607 number: the extension number that needs to be loaded.
608 """
609 if not self._descriptor_db:
610 return
611 # Only supported when FindFileContainingExtension is provided.
612 if not hasattr(
613 self._descriptor_db, 'FindFileContainingExtension'):
614 return
615
616 full_name = message_descriptor.full_name
617 file_proto = None
618 try:
619 file_proto = self._descriptor_db.FindFileContainingExtension(
620 full_name, number
621 )
622 except:
623 return
624
625 if file_proto is None:
626 return
627
628 try:
629 self._ConvertFileProtoToFileDescriptor(file_proto)
630 except:
631 warn_msg = ('Unable to load proto file %s for extension number %d.' %
632 (file_proto.name, number))
633 warnings.warn(warn_msg, RuntimeWarning)
634
635 def FindServiceByName(self, full_name):
636 """Loads the named service descriptor from the pool.
637
638 Args:
639 full_name (str): The full name of the service descriptor to load.
640
641 Returns:
642 ServiceDescriptor: The service descriptor for the named service.
643
644 Raises:
645 KeyError: if the service cannot be found in the pool.
646 """
647 full_name = _NormalizeFullyQualifiedName(full_name)
648 if full_name not in self._service_descriptors:
649 self._FindFileContainingSymbolInDb(full_name)
650 return self._service_descriptors[full_name]
651
652 def FindMethodByName(self, full_name):
653 """Loads the named service method descriptor from the pool.
654
655 Args:
656 full_name (str): The full name of the method descriptor to load.
657
658 Returns:
659 MethodDescriptor: The method descriptor for the service method.
660
661 Raises:
662 KeyError: if the method cannot be found in the pool.
663 """
664 full_name = _NormalizeFullyQualifiedName(full_name)
665 service_name, _, method_name = full_name.rpartition('.')
666 service_descriptor = self.FindServiceByName(service_name)
667 return service_descriptor.methods_by_name[method_name]
668
669 def SetFeatureSetDefaults(self, defaults):
670 """Sets the default feature mappings used during the build.
671
672 Args:
673 defaults: a FeatureSetDefaults message containing the new mappings.
674 """
675 if self._edition_defaults is not None:
676 raise ValueError(
677 "Feature set defaults can't be changed once the pool has started"
678 ' building!'
679 )
680
681 # pylint: disable=g-import-not-at-top
682 from google.protobuf import descriptor_pb2
683
684 if not isinstance(defaults, descriptor_pb2.FeatureSetDefaults):
685 raise TypeError('SetFeatureSetDefaults called with invalid type')
686
687 if defaults.minimum_edition > defaults.maximum_edition:
688 raise ValueError(
689 'Invalid edition range %s to %s'
690 % (
691 descriptor_pb2.Edition.Name(defaults.minimum_edition),
692 descriptor_pb2.Edition.Name(defaults.maximum_edition),
693 )
694 )
695
696 prev_edition = descriptor_pb2.Edition.EDITION_UNKNOWN
697 for d in defaults.defaults:
698 if d.edition == descriptor_pb2.Edition.EDITION_UNKNOWN:
699 raise ValueError('Invalid edition EDITION_UNKNOWN specified')
700 if prev_edition >= d.edition:
701 raise ValueError(
702 'Feature set defaults are not strictly increasing. %s is greater'
703 ' than or equal to %s'
704 % (
705 descriptor_pb2.Edition.Name(prev_edition),
706 descriptor_pb2.Edition.Name(d.edition),
707 )
708 )
709 prev_edition = d.edition
710 self._edition_defaults = defaults
711
712 def _CreateDefaultFeatures(self, edition):
713 """Creates a FeatureSet message with defaults for a specific edition.
714
715 Args:
716 edition: the edition to generate defaults for.
717
718 Returns:
719 A FeatureSet message with defaults for a specific edition.
720 """
721 # pylint: disable=g-import-not-at-top
722 from google.protobuf import descriptor_pb2
723
724 with _edition_defaults_lock:
725 if not self._edition_defaults:
726 self._edition_defaults = descriptor_pb2.FeatureSetDefaults()
727 self._edition_defaults.ParseFromString(
728 self._serialized_edition_defaults
729 )
730
731 if edition < self._edition_defaults.minimum_edition:
732 raise TypeError(
733 'Edition %s is earlier than the minimum supported edition %s!'
734 % (
735 descriptor_pb2.Edition.Name(edition),
736 descriptor_pb2.Edition.Name(
737 self._edition_defaults.minimum_edition
738 ),
739 )
740 )
741 if edition > self._edition_defaults.maximum_edition:
742 raise TypeError(
743 'Edition %s is later than the maximum supported edition %s!'
744 % (
745 descriptor_pb2.Edition.Name(edition),
746 descriptor_pb2.Edition.Name(
747 self._edition_defaults.maximum_edition
748 ),
749 )
750 )
751 found = None
752 for d in self._edition_defaults.defaults:
753 if d.edition > edition:
754 break
755 found = d
756 if found is None:
757 raise TypeError(
758 'No valid default found for edition %s!'
759 % descriptor_pb2.Edition.Name(edition)
760 )
761
762 defaults = descriptor_pb2.FeatureSet()
763 defaults.CopyFrom(found.fixed_features)
764 defaults.MergeFrom(found.overridable_features)
765 return defaults
766
767 def _InternFeatures(self, features):
768 serialized = features.SerializeToString()
769 with _edition_defaults_lock:
770 cached = self._feature_cache.get(serialized)
771 if cached is None:
772 self._feature_cache[serialized] = features
773 cached = features
774 return cached
775
776 def _FindFileContainingSymbolInDb(self, symbol):
777 """Finds the file in descriptor DB containing the specified symbol.
778
779 Args:
780 symbol (str): The name of the symbol to search for.
781
782 Returns:
783 FileDescriptor: The file that contains the specified symbol.
784
785 Raises:
786 KeyError: if the file cannot be found in the descriptor database.
787 """
788 try:
789 file_proto = self._internal_db.FindFileContainingSymbol(symbol)
790 except KeyError as error:
791 if self._descriptor_db:
792 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
793 else:
794 raise error
795 if not file_proto:
796 raise KeyError('Cannot find a file containing %s' % symbol)
797 return self._ConvertFileProtoToFileDescriptor(file_proto)
798
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool. The extensions
    of the file, and of all its nested messages, are (re)registered on every
    call, including cache hits.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.
    """
    # Build only once per file name; later calls reuse the cached descriptor.
    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined outside this view; presumably it
      # yields the (possibly transitive) dependency FileDescriptors — confirm.
      built_deps = list(self._GetDeps(file_proto.dependency))
      # Resolves (building if needed) every direct dependency by file name.
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency entries are indexes into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      # pylint: disable=g-import-not-at-top
      from google.protobuf import descriptor_pb2

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          edition=descriptor_pb2.Edition.Name(file_proto.edition),
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps,
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key,
      )
      # Name -> message/enum descriptor map used to resolve type references
      # while building this file. Keys come from _PrefixWithDot (presumably
      # '.'-prefixed full names — confirm).
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # First pass: build every top-level message (recursing into nested
      # types); each one also records itself in `scope`.
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope, True))

      # File-scope extensions: resolve the extendee and the field type
      # through `scope`.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)

      # Now that every type of this file is in `scope`, resolve field types.
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # Re-point each top-level message entry at the object recorded in
      # `scope` under its package-qualified name.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool
    # This runs on cache hits too. _AddExtensionDescriptor accepts re-adding
    # the same descriptor object; it only raises for a *different* extension
    # with the same number.
    def AddExtensionForNested(message_type):
      for nested in message_type.nested_types:
        AddExtensionForNested(nested)
      for extension in message_type.extensions:
        self._AddExtensionDescriptor(extension)

    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
      self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
      AddExtensionForNested(message_type)

    return file_desc
901
  def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
                                scope=None, syntax=None):
    """Adds the proto to the pool in the specified package.

    Nested messages and enums are converted recursively first. Then the
    Descriptor is created, back-links (containing_type, containing_oneof) are
    wired, and the message is recorded in `scope` and registered in the pool.

    Args:
      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
      package: The package the proto should be located in.
      file_desc: The file containing this message.
      scope: Dict mapping short and full symbols to message and enum types.
      syntax: string indicating syntax of the file ("proto2" or "proto3");
        here it is only forwarded to the recursive calls for nested messages.

    Returns:
      The added descriptor.

    Raises:
      TypeError: from _CheckConflictRegister if the message's full name
        conflicts with an existing registration.
    """

    # For nested messages, `package` is the enclosing message's full name.
    if package:
      desc_name = '.'.join((package, desc_proto.name))
    else:
      desc_name = desc_proto.name

    if file_desc is None:
      file_name = None
    else:
      file_name = file_desc.name

    if scope is None:
      scope = {}

    nested = [
        self._ConvertMessageDescriptor(
            nested, desc_name, file_desc, scope, syntax)
        for nested in desc_proto.nested_type]
    enums = [
        self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
                                    scope, False)
        for enum in desc_proto.enum_type]
    fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
              for index, field in enumerate(desc_proto.field)]
    extensions = [
        self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
                                  is_extension=True)
        for index, extension in enumerate(desc_proto.extension)]
    # Oneofs start with an empty field list; members are appended below.
    oneofs = [
        # pylint: disable=g-complex-comprehension
        descriptor.OneofDescriptor(
            desc.name,
            '.'.join((desc_name, desc.name)),
            index,
            None,
            [],
            _OptionsOrNone(desc),
            # pylint: disable=protected-access
            create_key=descriptor._internal_create_key)
        for index, desc in enumerate(desc_proto.oneof_decl)
    ]
    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
    # A message is extendable iff it declares at least one extension range.
    if extension_ranges:
      is_extendable = True
    else:
      is_extendable = False
    desc = descriptor.Descriptor(
        name=desc_proto.name,
        full_name=desc_name,
        filename=file_name,
        containing_type=None,
        fields=fields,
        oneofs=oneofs,
        nested_types=nested,
        enum_types=enums,
        extensions=extensions,
        options=_OptionsOrNone(desc_proto),
        is_extendable=is_extendable,
        extension_ranges=extension_ranges,
        file=file_desc,
        serialized_start=None,
        serialized_end=None,
        is_map_entry=desc_proto.options.map_entry,
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key,
    )
    # Back-link children to this message; this needs `desc` to exist first.
    for nested in desc.nested_types:
      nested.containing_type = desc
    for enum in desc.enum_types:
      enum.containing_type = desc
    # Wire oneof membership in both directions; `fields` and desc_proto.field
    # share the same indexes.
    for field_index, field_desc in enumerate(desc_proto.field):
      if field_desc.HasField('oneof_index'):
        oneof_index = field_desc.oneof_index
        oneofs[oneof_index].fields.append(fields[field_index])
        fields[field_index].containing_oneof = oneofs[oneof_index]

    # Make the message resolvable by later type references in this file.
    scope[_PrefixWithDot(desc_name)] = desc
    # NOTE(review): desc.file.name presumably fails when file_desc is None,
    # even though None is the default; confirm callers always pass file_desc.
    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
    self._descriptors[desc_name] = desc
    return desc
996
997 def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
998 containing_type=None, scope=None, top_level=False):
999 """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
1000
1001 Args:
1002 enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
1003 package: Optional package name for the new message EnumDescriptor.
1004 file_desc: The file containing the enum descriptor.
1005 containing_type: The type containing this enum.
1006 scope: Scope containing available types.
1007 top_level: If True, the enum is a top level symbol. If False, the enum
1008 is defined inside a message.
1009
1010 Returns:
1011 The added descriptor
1012 """
1013
1014 if package:
1015 enum_name = '.'.join((package, enum_proto.name))
1016 else:
1017 enum_name = enum_proto.name
1018
1019 if file_desc is None:
1020 file_name = None
1021 else:
1022 file_name = file_desc.name
1023
1024 values = [self._MakeEnumValueDescriptor(value, index)
1025 for index, value in enumerate(enum_proto.value)]
1026 desc = descriptor.EnumDescriptor(name=enum_proto.name,
1027 full_name=enum_name,
1028 filename=file_name,
1029 file=file_desc,
1030 values=values,
1031 containing_type=containing_type,
1032 options=_OptionsOrNone(enum_proto),
1033 # pylint: disable=protected-access
1034 create_key=descriptor._internal_create_key)
1035 scope['.%s' % enum_name] = desc
1036 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1037 self._enum_descriptors[enum_name] = desc
1038
1039 # Add top level enum values.
1040 if top_level:
1041 for value in values:
1042 full_name = _NormalizeFullyQualifiedName(
1043 '.'.join((package, value.name)))
1044 self._CheckConflictRegister(value, full_name, file_name)
1045 self._top_enum_values[full_name] = value
1046
1047 return desc
1048
1049 def _MakeFieldDescriptor(self, field_proto, message_name, index,
1050 file_desc, is_extension=False):
1051 """Creates a field descriptor from a FieldDescriptorProto.
1052
1053 For message and enum type fields, this method will do a look up
1054 in the pool for the appropriate descriptor for that type. If it
1055 is unavailable, it will fall back to the _source function to
1056 create it. If this type is still unavailable, construction will
1057 fail.
1058
1059 Args:
1060 field_proto: The proto describing the field.
1061 message_name: The name of the containing message.
1062 index: Index of the field
1063 file_desc: The file containing the field descriptor.
1064 is_extension: Indication that this field is for an extension.
1065
1066 Returns:
1067 An initialized FieldDescriptor object
1068 """
1069
1070 if message_name:
1071 full_name = '.'.join((message_name, field_proto.name))
1072 else:
1073 full_name = field_proto.name
1074
1075 if field_proto.json_name:
1076 json_name = field_proto.json_name
1077 else:
1078 json_name = None
1079
1080 return descriptor.FieldDescriptor(
1081 name=field_proto.name,
1082 full_name=full_name,
1083 index=index,
1084 number=field_proto.number,
1085 type=field_proto.type,
1086 cpp_type=None,
1087 message_type=None,
1088 enum_type=None,
1089 containing_type=None,
1090 label=field_proto.label,
1091 has_default_value=False,
1092 default_value=None,
1093 is_extension=is_extension,
1094 extension_scope=None,
1095 options=_OptionsOrNone(field_proto),
1096 json_name=json_name,
1097 file=file_desc,
1098 # pylint: disable=protected-access
1099 create_key=descriptor._internal_create_key)
1100
1101 def _SetAllFieldTypes(self, package, desc_proto, scope):
1102 """Sets all the descriptor's fields's types.
1103
1104 This method also sets the containing types on any extensions.
1105
1106 Args:
1107 package: The current package of desc_proto.
1108 desc_proto: The message descriptor to update.
1109 scope: Enclosing scope of available types.
1110 """
1111
1112 package = _PrefixWithDot(package)
1113
1114 main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
1115
1116 if package == '.':
1117 nested_package = _PrefixWithDot(desc_proto.name)
1118 else:
1119 nested_package = '.'.join([package, desc_proto.name])
1120
1121 for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
1122 self._SetFieldType(field_proto, field_desc, nested_package, scope)
1123
1124 for extension_proto, extension_desc in (
1125 zip(desc_proto.extension, main_desc.extensions)):
1126 extension_desc.containing_type = self._GetTypeFromScope(
1127 nested_package, extension_proto.extendee, scope)
1128 self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
1129
1130 for nested_type in desc_proto.nested_type:
1131 self._SetAllFieldTypes(nested_package, nested_type, scope)
1132
1133 def _SetFieldType(self, field_proto, field_desc, package, scope):
1134 """Sets the field's type, cpp_type, message_type and enum_type.
1135
1136 Args:
1137 field_proto: Data about the field in proto format.
1138 field_desc: The descriptor to modify.
1139 package: The package the field's container is in.
1140 scope: Enclosing scope of available types.
1141 """
1142 if field_proto.type_name:
1143 desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
1144 else:
1145 desc = None
1146
1147 if not field_proto.HasField('type'):
1148 if isinstance(desc, descriptor.Descriptor):
1149 field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
1150 else:
1151 field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
1152
1153 field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
1154 field_proto.type)
1155
1156 if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
1157 or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
1158 field_desc.message_type = desc
1159
1160 if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1161 field_desc.enum_type = desc
1162
1163 if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1164 field_desc.has_default_value = False
1165 field_desc.default_value = []
1166 elif field_proto.HasField('default_value'):
1167 field_desc.has_default_value = True
1168 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1169 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1170 field_desc.default_value = float(field_proto.default_value)
1171 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1172 field_desc.default_value = field_proto.default_value
1173 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1174 field_desc.default_value = field_proto.default_value.lower() == 'true'
1175 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1176 field_desc.default_value = field_desc.enum_type.values_by_name[
1177 field_proto.default_value].number
1178 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1179 field_desc.default_value = text_encoding.CUnescape(
1180 field_proto.default_value)
1181 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1182 field_desc.default_value = None
1183 else:
1184 # All other types are of the "int" type.
1185 field_desc.default_value = int(field_proto.default_value)
1186 else:
1187 field_desc.has_default_value = False
1188 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1189 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1190 field_desc.default_value = 0.0
1191 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1192 field_desc.default_value = u''
1193 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1194 field_desc.default_value = False
1195 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1196 field_desc.default_value = field_desc.enum_type.values[0].number
1197 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1198 field_desc.default_value = b''
1199 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1200 field_desc.default_value = None
1201 elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP:
1202 field_desc.default_value = None
1203 else:
1204 # All other types are of the "int" type.
1205 field_desc.default_value = 0
1206
1207 field_desc.type = field_proto.type
1208
1209 def _MakeEnumValueDescriptor(self, value_proto, index):
1210 """Creates a enum value descriptor object from a enum value proto.
1211
1212 Args:
1213 value_proto: The proto describing the enum value.
1214 index: The index of the enum value.
1215
1216 Returns:
1217 An initialized EnumValueDescriptor object.
1218 """
1219
1220 return descriptor.EnumValueDescriptor(
1221 name=value_proto.name,
1222 index=index,
1223 number=value_proto.number,
1224 options=_OptionsOrNone(value_proto),
1225 type=None,
1226 # pylint: disable=protected-access
1227 create_key=descriptor._internal_create_key)
1228
1229 def _MakeServiceDescriptor(self, service_proto, service_index, scope,
1230 package, file_desc):
1231 """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
1232
1233 Args:
1234 service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
1235 service_index: The index of the service in the File.
1236 scope: Dict mapping short and full symbols to message and enum types.
1237 package: Optional package name for the new message EnumDescriptor.
1238 file_desc: The file containing the service descriptor.
1239
1240 Returns:
1241 The added descriptor.
1242 """
1243
1244 if package:
1245 service_name = '.'.join((package, service_proto.name))
1246 else:
1247 service_name = service_proto.name
1248
1249 methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
1250 scope, index)
1251 for index, method_proto in enumerate(service_proto.method)]
1252 desc = descriptor.ServiceDescriptor(
1253 name=service_proto.name,
1254 full_name=service_name,
1255 index=service_index,
1256 methods=methods,
1257 options=_OptionsOrNone(service_proto),
1258 file=file_desc,
1259 # pylint: disable=protected-access
1260 create_key=descriptor._internal_create_key)
1261 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1262 self._service_descriptors[service_name] = desc
1263 return desc
1264
1265 def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
1266 index):
1267 """Creates a method descriptor from a MethodDescriptorProto.
1268
1269 Args:
1270 method_proto: The proto describing the method.
1271 service_name: The name of the containing service.
1272 package: Optional package name to look up for types.
1273 scope: Scope containing available types.
1274 index: Index of the method in the service.
1275
1276 Returns:
1277 An initialized MethodDescriptor object.
1278 """
1279 full_name = '.'.join((service_name, method_proto.name))
1280 input_type = self._GetTypeFromScope(
1281 package, method_proto.input_type, scope)
1282 output_type = self._GetTypeFromScope(
1283 package, method_proto.output_type, scope)
1284 return descriptor.MethodDescriptor(
1285 name=method_proto.name,
1286 full_name=full_name,
1287 index=index,
1288 containing_service=None,
1289 input_type=input_type,
1290 output_type=output_type,
1291 client_streaming=method_proto.client_streaming,
1292 server_streaming=method_proto.server_streaming,
1293 options=_OptionsOrNone(method_proto),
1294 # pylint: disable=protected-access
1295 create_key=descriptor._internal_create_key)
1296
1297 def _ExtractSymbols(self, descriptors):
1298 """Pulls out all the symbols from descriptor protos.
1299
1300 Args:
1301 descriptors: The messages to extract descriptors from.
1302 Yields:
1303 A two element tuple of the type name and descriptor object.
1304 """
1305
1306 for desc in descriptors:
1307 yield (_PrefixWithDot(desc.full_name), desc)
1308 for symbol in self._ExtractSymbols(desc.nested_types):
1309 yield symbol
1310 for enum in desc.enum_types:
1311 yield (_PrefixWithDot(enum.full_name), enum)
1312
1313 def _GetDeps(self, dependencies, visited=None):
1314 """Recursively finds dependencies for file protos.
1315
1316 Args:
1317 dependencies: The names of the files being depended on.
1318 visited: The names of files already found.
1319
1320 Yields:
1321 Each direct and indirect dependency.
1322 """
1323
1324 visited = visited or set()
1325 for dependency in dependencies:
1326 if dependency not in visited:
1327 visited.add(dependency)
1328 dep_desc = self.FindFileByName(dependency)
1329 yield dep_desc
1330 public_files = [d.name for d in dep_desc.public_dependencies]
1331 yield from self._GetDeps(public_files, visited)
1332
1333 def _GetTypeFromScope(self, package, type_name, scope):
1334 """Finds a given type name in the current scope.
1335
1336 Args:
1337 package: The package the proto should be located in.
1338 type_name: The name of the type to be found in the scope.
1339 scope: Dict mapping short and full symbols to message and enum types.
1340
1341 Returns:
1342 The descriptor for the requested type.
1343 """
1344 if type_name not in scope:
1345 components = _PrefixWithDot(package).split('.')
1346 while components:
1347 possible_match = '.'.join(components + [type_name])
1348 if possible_match in scope:
1349 type_name = possible_match
1350 break
1351 else:
1352 components.pop(-1)
1353 return scope[type_name]
1354
1355
1356def _PrefixWithDot(name):
1357 return name if name.startswith('.') else '.%s' % name
1358
1359
# TODO: This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
# Process-wide default pool: the C implementation's pool when C descriptors
# are in use, otherwise a pure-Python DescriptorPool created at import time.
_DEFAULT = (
    descriptor._message.default_pool  # pylint: disable=protected-access
    if _USE_C_DESCRIPTORS
    else DescriptorPool()
)
1367
1368
def Default():
  """Returns the module-level default descriptor pool.

  This is the C implementation's default pool when C descriptors are in use,
  otherwise the pure-Python DescriptorPool created when this module was
  imported. The same object is returned on every call.
  """
  return _DEFAULT