1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Provides DescriptorPool to use as a container for proto2 descriptors.
9
The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
11a collection of protocol buffer descriptors for use when dynamically creating
12message types at runtime.
13
14For most applications protocol buffers should be used via modules generated by
15the protocol buffer compiler tool. This should only be used when the type of
16protocol buffers used in an application or library cannot be predetermined.
17
Below is a straightforward example of how to use this class::
19
20 pool = DescriptorPool()
21 file_descriptor_protos = [ ... ]
22 for file_descriptor_proto in file_descriptor_protos:
23 pool.Add(file_descriptor_proto)
24 my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
25
26The message descriptor can be used in conjunction with the message_factory
27module in order to create a protocol buffer class that can be encoded and
28decoded.
29
30If you want to get a Python class for the specified proto, use the
31helper functions inside google.protobuf.message_factory
32directly instead of this class.
33"""
34
# Module author attribution metadata.
__author__ = 'matthewtoia@google.com (Matt Toia)'
36
37import collections
38import threading
39import warnings
40
41from google.protobuf import descriptor
42from google.protobuf import descriptor_database
43from google.protobuf import text_encoding
44from google.protobuf.internal import python_edition_defaults
45from google.protobuf.internal import python_message
46
# Mirrors descriptor._USE_C_DESCRIPTORS. When true, constructing a
# DescriptorPool is delegated to descriptor._message.DescriptorPool instead of
# building the pure-Python pool defined in this module.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
48
49
def _NormalizeFullyQualifiedName(name):
  """Strips every leading '.' from a fully-qualified symbol name.

  Some generated names for types in the root namespace come with a leading
  period (see b/13860351 in descriptor_database.py). This turns, for example,
  '.pkg.Msg' into 'pkg.Msg'.

  Args:
    name (str): The fully-qualified symbol name.

  Returns:
    str: The name without any leading periods.
  """
  normalized = name.lstrip('.')
  return normalized
63
64
def _OptionsOrNone(descriptor_proto):
  """Returns `descriptor_proto.options` when that field is set, else None.

  Args:
    descriptor_proto: A message exposing HasField() and an `options` field.

  Returns:
    The options submessage, or None when it has not been set.
  """
  has_options = descriptor_proto.HasField('options')
  return descriptor_proto.options if has_options else None
71
72
def _IsMessageSetExtension(field):
  """Tells whether `field` is an optional message extension of a MessageSet.

  The conditions are checked lazily, in order: the field must be an
  extension, its extended type must have options with
  message_set_wire_format set, and the field must be an optional message
  field. The first falsy condition is returned as-is.

  Args:
    field: A FieldDescriptor.

  Returns:
    A truthy value when all conditions hold, otherwise the first falsy one.
  """
  return (
      field.is_extension
      and field.containing_type.has_options
      and field.containing_type.GetOptions().message_set_wire_format
      and field.type == descriptor.FieldDescriptor.TYPE_MESSAGE
      and field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL
  )
79
# Module-wide lock shared by every pool instance. It guards the lazy parsing of
# edition defaults in DescriptorPool._CreateDefaultFeatures and access to the
# interned-feature cache in DescriptorPool._InternFeatures.
_edition_defaults_lock = threading.Lock()
81
82
83class DescriptorPool(object):
84 """A collection of protobufs dynamically constructed by descriptor protos."""
85
86 if _USE_C_DESCRIPTORS:
87
88 def __new__(cls, descriptor_db=None):
89 # pylint: disable=protected-access
90 return descriptor._message.DescriptorPool(descriptor_db)
91
92 def __init__(
93 self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
94 ):
95 """Initializes a Pool of proto buffs.
96
97 The descriptor_db argument to the constructor is provided to allow
98 specialized file descriptor proto lookup code to be triggered on demand. An
99 example would be an implementation which will read and compile a file
100 specified in a call to FindFileByName() and not require the call to Add()
101 at all. Results from this database will be cached internally here as well.
102
103 Args:
104 descriptor_db: A secondary source of file descriptors.
105 use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
106 C++.
107 """
108
109 self._internal_db = descriptor_database.DescriptorDatabase()
110 self._descriptor_db = descriptor_db
111 self._descriptors = {}
112 self._enum_descriptors = {}
113 self._service_descriptors = {}
114 self._file_descriptors = {}
115 self._toplevel_extensions = {}
116 self._top_enum_values = {}
117 # We store extensions in two two-level mappings: The first key is the
118 # descriptor of the message being extended, the second key is the extension
119 # full name or its tag number.
120 self._extensions_by_name = collections.defaultdict(dict)
121 self._extensions_by_number = collections.defaultdict(dict)
122 self._serialized_edition_defaults = (
123 python_edition_defaults._PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS
124 )
125 self._edition_defaults = None
126 self._feature_cache = dict()
127
128 def _CheckConflictRegister(self, desc, desc_name, file_name):
129 """Check if the descriptor name conflicts with another of the same name.
130
131 Args:
132 desc: Descriptor of a message, enum, service, extension or enum value.
133 desc_name (str): the full name of desc.
134 file_name (str): The file name of descriptor.
135 """
136 for register, descriptor_type in [
137 (self._descriptors, descriptor.Descriptor),
138 (self._enum_descriptors, descriptor.EnumDescriptor),
139 (self._service_descriptors, descriptor.ServiceDescriptor),
140 (self._toplevel_extensions, descriptor.FieldDescriptor),
141 (self._top_enum_values, descriptor.EnumValueDescriptor)]:
142 if desc_name in register:
143 old_desc = register[desc_name]
144 if isinstance(old_desc, descriptor.EnumValueDescriptor):
145 old_file = old_desc.type.file.name
146 else:
147 old_file = old_desc.file.name
148
149 if not isinstance(desc, descriptor_type) or (
150 old_file != file_name):
151 error_msg = ('Conflict register for file "' + file_name +
152 '": ' + desc_name +
153 ' is already defined in file "' +
154 old_file + '". Please fix the conflict by adding '
155 'package name on the proto file, or use different '
156 'name for the duplication.')
157 if isinstance(desc, descriptor.EnumValueDescriptor):
158 error_msg += ('\nNote: enum values appear as '
159 'siblings of the enum type instead of '
160 'children of it.')
161
162 raise TypeError(error_msg)
163
164 return
165
166 def Add(self, file_desc_proto):
167 """Adds the FileDescriptorProto and its types to this pool.
168
169 Args:
170 file_desc_proto (FileDescriptorProto): The file descriptor to add.
171 """
172
173 self._internal_db.Add(file_desc_proto)
174
175 def AddSerializedFile(self, serialized_file_desc_proto):
176 """Adds the FileDescriptorProto and its types to this pool.
177
178 Args:
179 serialized_file_desc_proto (bytes): A bytes string, serialization of the
180 :class:`FileDescriptorProto` to add.
181
182 Returns:
183 FileDescriptor: Descriptor for the added file.
184 """
185
186 # pylint: disable=g-import-not-at-top
187 from google.protobuf import descriptor_pb2
188 file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
189 serialized_file_desc_proto)
190 file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
191 file_desc.serialized_pb = serialized_file_desc_proto
192 return file_desc
193
194 # Never call this method. It is for internal usage only.
195 def _AddDescriptor(self, desc):
196 """Adds a Descriptor to the pool, non-recursively.
197
198 If the Descriptor contains nested messages or enums, the caller must
199 explicitly register them. This method also registers the FileDescriptor
200 associated with the message.
201
202 Args:
203 desc: A Descriptor.
204 """
205 if not isinstance(desc, descriptor.Descriptor):
206 raise TypeError('Expected instance of descriptor.Descriptor.')
207
208 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
209
210 self._descriptors[desc.full_name] = desc
211 self._AddFileDescriptor(desc.file)
212
213 # Never call this method. It is for internal usage only.
214 def _AddEnumDescriptor(self, enum_desc):
215 """Adds an EnumDescriptor to the pool.
216
217 This method also registers the FileDescriptor associated with the enum.
218
219 Args:
220 enum_desc: An EnumDescriptor.
221 """
222
223 if not isinstance(enum_desc, descriptor.EnumDescriptor):
224 raise TypeError('Expected instance of descriptor.EnumDescriptor.')
225
226 file_name = enum_desc.file.name
227 self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
228 self._enum_descriptors[enum_desc.full_name] = enum_desc
229
230 # Top enum values need to be indexed.
231 # Count the number of dots to see whether the enum is toplevel or nested
232 # in a message. We cannot use enum_desc.containing_type at this stage.
233 if enum_desc.file.package:
234 top_level = (enum_desc.full_name.count('.')
235 - enum_desc.file.package.count('.') == 1)
236 else:
237 top_level = enum_desc.full_name.count('.') == 0
238 if top_level:
239 file_name = enum_desc.file.name
240 package = enum_desc.file.package
241 for enum_value in enum_desc.values:
242 full_name = _NormalizeFullyQualifiedName(
243 '.'.join((package, enum_value.name)))
244 self._CheckConflictRegister(enum_value, full_name, file_name)
245 self._top_enum_values[full_name] = enum_value
246 self._AddFileDescriptor(enum_desc.file)
247
248 # Never call this method. It is for internal usage only.
249 def _AddServiceDescriptor(self, service_desc):
250 """Adds a ServiceDescriptor to the pool.
251
252 Args:
253 service_desc: A ServiceDescriptor.
254 """
255
256 if not isinstance(service_desc, descriptor.ServiceDescriptor):
257 raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
258
259 self._CheckConflictRegister(service_desc, service_desc.full_name,
260 service_desc.file.name)
261 self._service_descriptors[service_desc.full_name] = service_desc
262
263 # Never call this method. It is for internal usage only.
264 def _AddExtensionDescriptor(self, extension):
265 """Adds a FieldDescriptor describing an extension to the pool.
266
267 Args:
268 extension: A FieldDescriptor.
269
270 Raises:
271 AssertionError: when another extension with the same number extends the
272 same message.
273 TypeError: when the specified extension is not a
274 descriptor.FieldDescriptor.
275 """
276 if not (isinstance(extension, descriptor.FieldDescriptor) and
277 extension.is_extension):
278 raise TypeError('Expected an extension descriptor.')
279
280 if extension.extension_scope is None:
281 self._CheckConflictRegister(
282 extension, extension.full_name, extension.file.name)
283 self._toplevel_extensions[extension.full_name] = extension
284
285 try:
286 existing_desc = self._extensions_by_number[
287 extension.containing_type][extension.number]
288 except KeyError:
289 pass
290 else:
291 if extension is not existing_desc:
292 raise AssertionError(
293 'Extensions "%s" and "%s" both try to extend message type "%s" '
294 'with field number %d.' %
295 (extension.full_name, existing_desc.full_name,
296 extension.containing_type.full_name, extension.number))
297
298 self._extensions_by_number[extension.containing_type][
299 extension.number] = extension
300 self._extensions_by_name[extension.containing_type][
301 extension.full_name] = extension
302
303 # Also register MessageSet extensions with the type name.
304 if _IsMessageSetExtension(extension):
305 self._extensions_by_name[extension.containing_type][
306 extension.message_type.full_name] = extension
307
308 if hasattr(extension.containing_type, '_concrete_class'):
309 python_message._AttachFieldHelpers(
310 extension.containing_type._concrete_class, extension)
311
312 # Never call this method. It is for internal usage only.
313 def _InternalAddFileDescriptor(self, file_desc):
314 """Adds a FileDescriptor to the pool, non-recursively.
315
316 If the FileDescriptor contains messages or enums, the caller must explicitly
317 register them.
318
319 Args:
320 file_desc: A FileDescriptor.
321 """
322
323 self._AddFileDescriptor(file_desc)
324
325 def _AddFileDescriptor(self, file_desc):
326 """Adds a FileDescriptor to the pool, non-recursively.
327
328 If the FileDescriptor contains messages or enums, the caller must explicitly
329 register them.
330
331 Args:
332 file_desc: A FileDescriptor.
333 """
334
335 if not isinstance(file_desc, descriptor.FileDescriptor):
336 raise TypeError('Expected instance of descriptor.FileDescriptor.')
337 self._file_descriptors[file_desc.name] = file_desc
338
339 def FindFileByName(self, file_name):
340 """Gets a FileDescriptor by file name.
341
342 Args:
343 file_name (str): The path to the file to get a descriptor for.
344
345 Returns:
346 FileDescriptor: The descriptor for the named file.
347
348 Raises:
349 KeyError: if the file cannot be found in the pool.
350 """
351
352 try:
353 return self._file_descriptors[file_name]
354 except KeyError:
355 pass
356
357 try:
358 file_proto = self._internal_db.FindFileByName(file_name)
359 except KeyError as error:
360 if self._descriptor_db:
361 file_proto = self._descriptor_db.FindFileByName(file_name)
362 else:
363 raise error
364 if not file_proto:
365 raise KeyError('Cannot find a file named %s' % file_name)
366 return self._ConvertFileProtoToFileDescriptor(file_proto)
367
368 def FindFileContainingSymbol(self, symbol):
369 """Gets the FileDescriptor for the file containing the specified symbol.
370
371 Args:
372 symbol (str): The name of the symbol to search for.
373
374 Returns:
375 FileDescriptor: Descriptor for the file that contains the specified
376 symbol.
377
378 Raises:
379 KeyError: if the file cannot be found in the pool.
380 """
381
382 symbol = _NormalizeFullyQualifiedName(symbol)
383 try:
384 return self._InternalFindFileContainingSymbol(symbol)
385 except KeyError:
386 pass
387
388 try:
389 # Try fallback database. Build and find again if possible.
390 self._FindFileContainingSymbolInDb(symbol)
391 return self._InternalFindFileContainingSymbol(symbol)
392 except KeyError:
393 raise KeyError('Cannot find a file containing %s' % symbol)
394
395 def _InternalFindFileContainingSymbol(self, symbol):
396 """Gets the already built FileDescriptor containing the specified symbol.
397
398 Args:
399 symbol (str): The name of the symbol to search for.
400
401 Returns:
402 FileDescriptor: Descriptor for the file that contains the specified
403 symbol.
404
405 Raises:
406 KeyError: if the file cannot be found in the pool.
407 """
408 try:
409 return self._descriptors[symbol].file
410 except KeyError:
411 pass
412
413 try:
414 return self._enum_descriptors[symbol].file
415 except KeyError:
416 pass
417
418 try:
419 return self._service_descriptors[symbol].file
420 except KeyError:
421 pass
422
423 try:
424 return self._top_enum_values[symbol].type.file
425 except KeyError:
426 pass
427
428 try:
429 return self._toplevel_extensions[symbol].file
430 except KeyError:
431 pass
432
433 # Try fields, enum values and nested extensions inside a message.
434 top_name, _, sub_name = symbol.rpartition('.')
435 try:
436 message = self.FindMessageTypeByName(top_name)
437 assert (sub_name in message.extensions_by_name or
438 sub_name in message.fields_by_name or
439 sub_name in message.enum_values_by_name)
440 return message.file
441 except (KeyError, AssertionError):
442 raise KeyError('Cannot find a file containing %s' % symbol)
443
444 def FindMessageTypeByName(self, full_name):
445 """Loads the named descriptor from the pool.
446
447 Args:
448 full_name (str): The full name of the descriptor to load.
449
450 Returns:
451 Descriptor: The descriptor for the named type.
452
453 Raises:
454 KeyError: if the message cannot be found in the pool.
455 """
456
457 full_name = _NormalizeFullyQualifiedName(full_name)
458 if full_name not in self._descriptors:
459 self._FindFileContainingSymbolInDb(full_name)
460 return self._descriptors[full_name]
461
462 def FindEnumTypeByName(self, full_name):
463 """Loads the named enum descriptor from the pool.
464
465 Args:
466 full_name (str): The full name of the enum descriptor to load.
467
468 Returns:
469 EnumDescriptor: The enum descriptor for the named type.
470
471 Raises:
472 KeyError: if the enum cannot be found in the pool.
473 """
474
475 full_name = _NormalizeFullyQualifiedName(full_name)
476 if full_name not in self._enum_descriptors:
477 self._FindFileContainingSymbolInDb(full_name)
478 return self._enum_descriptors[full_name]
479
480 def FindFieldByName(self, full_name):
481 """Loads the named field descriptor from the pool.
482
483 Args:
484 full_name (str): The full name of the field descriptor to load.
485
486 Returns:
487 FieldDescriptor: The field descriptor for the named field.
488
489 Raises:
490 KeyError: if the field cannot be found in the pool.
491 """
492 full_name = _NormalizeFullyQualifiedName(full_name)
493 message_name, _, field_name = full_name.rpartition('.')
494 message_descriptor = self.FindMessageTypeByName(message_name)
495 return message_descriptor.fields_by_name[field_name]
496
497 def FindOneofByName(self, full_name):
498 """Loads the named oneof descriptor from the pool.
499
500 Args:
501 full_name (str): The full name of the oneof descriptor to load.
502
503 Returns:
504 OneofDescriptor: The oneof descriptor for the named oneof.
505
506 Raises:
507 KeyError: if the oneof cannot be found in the pool.
508 """
509 full_name = _NormalizeFullyQualifiedName(full_name)
510 message_name, _, oneof_name = full_name.rpartition('.')
511 message_descriptor = self.FindMessageTypeByName(message_name)
512 return message_descriptor.oneofs_by_name[oneof_name]
513
514 def FindExtensionByName(self, full_name):
515 """Loads the named extension descriptor from the pool.
516
517 Args:
518 full_name (str): The full name of the extension descriptor to load.
519
520 Returns:
521 FieldDescriptor: The field descriptor for the named extension.
522
523 Raises:
524 KeyError: if the extension cannot be found in the pool.
525 """
526 full_name = _NormalizeFullyQualifiedName(full_name)
527 try:
528 # The proto compiler does not give any link between the FileDescriptor
529 # and top-level extensions unless the FileDescriptorProto is added to
530 # the DescriptorDatabase, but this can impact memory usage.
531 # So we registered these extensions by name explicitly.
532 return self._toplevel_extensions[full_name]
533 except KeyError:
534 pass
535 message_name, _, extension_name = full_name.rpartition('.')
536 try:
537 # Most extensions are nested inside a message.
538 scope = self.FindMessageTypeByName(message_name)
539 except KeyError:
540 # Some extensions are defined at file scope.
541 scope = self._FindFileContainingSymbolInDb(full_name)
542 return scope.extensions_by_name[extension_name]
543
544 def FindExtensionByNumber(self, message_descriptor, number):
545 """Gets the extension of the specified message with the specified number.
546
547 Extensions have to be registered to this pool by calling :func:`Add` or
548 :func:`AddExtensionDescriptor`.
549
550 Args:
551 message_descriptor (Descriptor): descriptor of the extended message.
552 number (int): Number of the extension field.
553
554 Returns:
555 FieldDescriptor: The descriptor for the extension.
556
557 Raises:
558 KeyError: when no extension with the given number is known for the
559 specified message.
560 """
561 try:
562 return self._extensions_by_number[message_descriptor][number]
563 except KeyError:
564 self._TryLoadExtensionFromDB(message_descriptor, number)
565 return self._extensions_by_number[message_descriptor][number]
566
567 def FindAllExtensions(self, message_descriptor):
568 """Gets all the known extensions of a given message.
569
570 Extensions have to be registered to this pool by build related
571 :func:`Add` or :func:`AddExtensionDescriptor`.
572
573 Args:
574 message_descriptor (Descriptor): Descriptor of the extended message.
575
576 Returns:
577 list[FieldDescriptor]: Field descriptors describing the extensions.
578 """
579 # Fallback to descriptor db if FindAllExtensionNumbers is provided.
580 if self._descriptor_db and hasattr(
581 self._descriptor_db, 'FindAllExtensionNumbers'):
582 full_name = message_descriptor.full_name
583 try:
584 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
585 except:
586 pass
587 else:
588 if isinstance(all_numbers, list):
589 for number in all_numbers:
590 if number in self._extensions_by_number[message_descriptor]:
591 continue
592 self._TryLoadExtensionFromDB(message_descriptor, number)
593 else:
594 warnings.warn(
595 'FindAllExtensionNumbers() on fall back DB must return a list,'
596 ' not {0}'.format(type(all_numbers))
597 )
598
599 return list(self._extensions_by_number[message_descriptor].values())
600
601 def _TryLoadExtensionFromDB(self, message_descriptor, number):
602 """Try to Load extensions from descriptor db.
603
604 Args:
605 message_descriptor: descriptor of the extended message.
606 number: the extension number that needs to be loaded.
607 """
608 if not self._descriptor_db:
609 return
610 # Only supported when FindFileContainingExtension is provided.
611 if not hasattr(
612 self._descriptor_db, 'FindFileContainingExtension'):
613 return
614
615 full_name = message_descriptor.full_name
616 file_proto = None
617 try:
618 file_proto = self._descriptor_db.FindFileContainingExtension(
619 full_name, number
620 )
621 except:
622 return
623
624 if file_proto is None:
625 return
626
627 try:
628 self._ConvertFileProtoToFileDescriptor(file_proto)
629 except:
630 warn_msg = ('Unable to load proto file %s for extension number %d.' %
631 (file_proto.name, number))
632 warnings.warn(warn_msg, RuntimeWarning)
633
634 def FindServiceByName(self, full_name):
635 """Loads the named service descriptor from the pool.
636
637 Args:
638 full_name (str): The full name of the service descriptor to load.
639
640 Returns:
641 ServiceDescriptor: The service descriptor for the named service.
642
643 Raises:
644 KeyError: if the service cannot be found in the pool.
645 """
646 full_name = _NormalizeFullyQualifiedName(full_name)
647 if full_name not in self._service_descriptors:
648 self._FindFileContainingSymbolInDb(full_name)
649 return self._service_descriptors[full_name]
650
651 def FindMethodByName(self, full_name):
652 """Loads the named service method descriptor from the pool.
653
654 Args:
655 full_name (str): The full name of the method descriptor to load.
656
657 Returns:
658 MethodDescriptor: The method descriptor for the service method.
659
660 Raises:
661 KeyError: if the method cannot be found in the pool.
662 """
663 full_name = _NormalizeFullyQualifiedName(full_name)
664 service_name, _, method_name = full_name.rpartition('.')
665 service_descriptor = self.FindServiceByName(service_name)
666 return service_descriptor.methods_by_name[method_name]
667
668 def SetFeatureSetDefaults(self, defaults):
669 """Sets the default feature mappings used during the build.
670
671 Args:
672 defaults: a FeatureSetDefaults message containing the new mappings.
673 """
674 if self._edition_defaults is not None:
675 raise ValueError(
676 "Feature set defaults can't be changed once the pool has started"
677 ' building!'
678 )
679
680 # pylint: disable=g-import-not-at-top
681 from google.protobuf import descriptor_pb2
682
683 if not isinstance(defaults, descriptor_pb2.FeatureSetDefaults):
684 raise TypeError('SetFeatureSetDefaults called with invalid type')
685
686 if defaults.minimum_edition > defaults.maximum_edition:
687 raise ValueError(
688 'Invalid edition range %s to %s'
689 % (
690 descriptor_pb2.Edition.Name(defaults.minimum_edition),
691 descriptor_pb2.Edition.Name(defaults.maximum_edition),
692 )
693 )
694
695 prev_edition = descriptor_pb2.Edition.EDITION_UNKNOWN
696 for d in defaults.defaults:
697 if d.edition == descriptor_pb2.Edition.EDITION_UNKNOWN:
698 raise ValueError('Invalid edition EDITION_UNKNOWN specified')
699 if prev_edition >= d.edition:
700 raise ValueError(
701 'Feature set defaults are not strictly increasing. %s is greater'
702 ' than or equal to %s'
703 % (
704 descriptor_pb2.Edition.Name(prev_edition),
705 descriptor_pb2.Edition.Name(d.edition),
706 )
707 )
708 prev_edition = d.edition
709 self._edition_defaults = defaults
710
711 def _CreateDefaultFeatures(self, edition):
712 """Creates a FeatureSet message with defaults for a specific edition.
713
714 Args:
715 edition: the edition to generate defaults for.
716
717 Returns:
718 A FeatureSet message with defaults for a specific edition.
719 """
720 # pylint: disable=g-import-not-at-top
721 from google.protobuf import descriptor_pb2
722
723 with _edition_defaults_lock:
724 if not self._edition_defaults:
725 self._edition_defaults = descriptor_pb2.FeatureSetDefaults()
726 self._edition_defaults.ParseFromString(
727 self._serialized_edition_defaults
728 )
729
730 if edition < self._edition_defaults.minimum_edition:
731 raise TypeError(
732 'Edition %s is earlier than the minimum supported edition %s!'
733 % (
734 descriptor_pb2.Edition.Name(edition),
735 descriptor_pb2.Edition.Name(
736 self._edition_defaults.minimum_edition
737 ),
738 )
739 )
740 if edition > self._edition_defaults.maximum_edition:
741 raise TypeError(
742 'Edition %s is later than the maximum supported edition %s!'
743 % (
744 descriptor_pb2.Edition.Name(edition),
745 descriptor_pb2.Edition.Name(
746 self._edition_defaults.maximum_edition
747 ),
748 )
749 )
750 found = None
751 for d in self._edition_defaults.defaults:
752 if d.edition > edition:
753 break
754 found = d
755 if found is None:
756 raise TypeError(
757 'No valid default found for edition %s!'
758 % descriptor_pb2.Edition.Name(edition)
759 )
760
761 defaults = descriptor_pb2.FeatureSet()
762 defaults.CopyFrom(found.fixed_features)
763 defaults.MergeFrom(found.overridable_features)
764 return defaults
765
766 def _InternFeatures(self, features):
767 serialized = features.SerializeToString()
768 with _edition_defaults_lock:
769 cached = self._feature_cache.get(serialized)
770 if cached is None:
771 self._feature_cache[serialized] = features
772 cached = features
773 return cached
774
775 def _FindFileContainingSymbolInDb(self, symbol):
776 """Finds the file in descriptor DB containing the specified symbol.
777
778 Args:
779 symbol (str): The name of the symbol to search for.
780
781 Returns:
782 FileDescriptor: The file that contains the specified symbol.
783
784 Raises:
785 KeyError: if the file cannot be found in the descriptor database.
786 """
787 try:
788 file_proto = self._internal_db.FindFileContainingSymbol(symbol)
789 except KeyError as error:
790 if self._descriptor_db:
791 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
792 else:
793 raise error
794 if not file_proto:
795 raise KeyError('Cannot find a file containing %s' % symbol)
796 return self._ConvertFileProtoToFileDescriptor(file_proto)
797
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool.

    The first time a file name is seen, its dependencies are resolved (through
    _GetDeps and FindFileByName). Its messages, enums, extensions and services
    are then built against a scope of symbols collected from those
    dependencies. Whether or not the file was already cached, every extension
    it declares (top-level and nested in messages) is then passed to
    _AddExtensionDescriptor.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.
    """
    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined outside this view; presumably it
      # yields the (transitive) dependencies as built FileDescriptors — confirm.
      built_deps = list(self._GetDeps(file_proto.dependency))
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency holds indexes into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      # pylint: disable=g-import-not-at-top
      from google.protobuf import descriptor_pb2

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          edition=descriptor_pb2.Edition.Name(file_proto.edition),
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps,
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key,
      )
      # Maps dot-prefixed full names to message/enum descriptors visible
      # while building this file.
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # Build top-level messages (recursively adds them to `scope`).
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope, True))

      # File-scope extensions: resolve the extended message from `scope`,
      # then the extension's own field type.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)

      # Message field types are resolved only now, after every message and
      # enum of this file has been placed in `scope` (presumably so fields can
      # refer to types declared later in the file).
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # NOTE(review): this re-reads each top-level message from `scope` and
      # overwrites the entry set above; presumably the same object — confirm
      # why this refresh is needed.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool
    # This runs on every call, including for cached files. Re-adding the same
    # extension objects is accepted by _AddExtensionDescriptor, which only
    # rejects a *different* extension with the same number.
    def AddExtensionForNested(message_type):
      # Recurse depth-first through nested message types.
      for nested in message_type.nested_types:
        AddExtensionForNested(nested)
      for extension in message_type.extensions:
        self._AddExtensionDescriptor(extension)

    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
      self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
      AddExtensionForNested(message_type)

    return file_desc
900
901 def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
902 scope=None, syntax=None):
903 """Adds the proto to the pool in the specified package.
904
905 Args:
906 desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
907 package: The package the proto should be located in.
908 file_desc: The file containing this message.
909 scope: Dict mapping short and full symbols to message and enum types.
910 syntax: string indicating syntax of the file ("proto2" or "proto3")
911
912 Returns:
913 The added descriptor.
914 """
915
916 if package:
917 desc_name = '.'.join((package, desc_proto.name))
918 else:
919 desc_name = desc_proto.name
920
921 if file_desc is None:
922 file_name = None
923 else:
924 file_name = file_desc.name
925
926 if scope is None:
927 scope = {}
928
929 nested = [
930 self._ConvertMessageDescriptor(
931 nested, desc_name, file_desc, scope, syntax)
932 for nested in desc_proto.nested_type]
933 enums = [
934 self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
935 scope, False)
936 for enum in desc_proto.enum_type]
937 fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
938 for index, field in enumerate(desc_proto.field)]
939 extensions = [
940 self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
941 is_extension=True)
942 for index, extension in enumerate(desc_proto.extension)]
943 oneofs = [
944 # pylint: disable=g-complex-comprehension
945 descriptor.OneofDescriptor(
946 desc.name,
947 '.'.join((desc_name, desc.name)),
948 index,
949 None,
950 [],
951 _OptionsOrNone(desc),
952 # pylint: disable=protected-access
953 create_key=descriptor._internal_create_key)
954 for index, desc in enumerate(desc_proto.oneof_decl)
955 ]
956 extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
957 if extension_ranges:
958 is_extendable = True
959 else:
960 is_extendable = False
961 desc = descriptor.Descriptor(
962 name=desc_proto.name,
963 full_name=desc_name,
964 filename=file_name,
965 containing_type=None,
966 fields=fields,
967 oneofs=oneofs,
968 nested_types=nested,
969 enum_types=enums,
970 extensions=extensions,
971 options=_OptionsOrNone(desc_proto),
972 is_extendable=is_extendable,
973 extension_ranges=extension_ranges,
974 file=file_desc,
975 serialized_start=None,
976 serialized_end=None,
977 is_map_entry=desc_proto.options.map_entry,
978 # pylint: disable=protected-access
979 create_key=descriptor._internal_create_key,
980 )
981 for nested in desc.nested_types:
982 nested.containing_type = desc
983 for enum in desc.enum_types:
984 enum.containing_type = desc
985 for field_index, field_desc in enumerate(desc_proto.field):
986 if field_desc.HasField('oneof_index'):
987 oneof_index = field_desc.oneof_index
988 oneofs[oneof_index].fields.append(fields[field_index])
989 fields[field_index].containing_oneof = oneofs[oneof_index]
990
991 scope[_PrefixWithDot(desc_name)] = desc
992 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
993 self._descriptors[desc_name] = desc
994 return desc
995
996 def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
997 containing_type=None, scope=None, top_level=False):
998 """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
999
1000 Args:
1001 enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
1002 package: Optional package name for the new message EnumDescriptor.
1003 file_desc: The file containing the enum descriptor.
1004 containing_type: The type containing this enum.
1005 scope: Scope containing available types.
1006 top_level: If True, the enum is a top level symbol. If False, the enum
1007 is defined inside a message.
1008
1009 Returns:
1010 The added descriptor
1011 """
1012
1013 if package:
1014 enum_name = '.'.join((package, enum_proto.name))
1015 else:
1016 enum_name = enum_proto.name
1017
1018 if file_desc is None:
1019 file_name = None
1020 else:
1021 file_name = file_desc.name
1022
1023 values = [self._MakeEnumValueDescriptor(value, index)
1024 for index, value in enumerate(enum_proto.value)]
1025 desc = descriptor.EnumDescriptor(name=enum_proto.name,
1026 full_name=enum_name,
1027 filename=file_name,
1028 file=file_desc,
1029 values=values,
1030 containing_type=containing_type,
1031 options=_OptionsOrNone(enum_proto),
1032 # pylint: disable=protected-access
1033 create_key=descriptor._internal_create_key)
1034 scope['.%s' % enum_name] = desc
1035 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1036 self._enum_descriptors[enum_name] = desc
1037
1038 # Add top level enum values.
1039 if top_level:
1040 for value in values:
1041 full_name = _NormalizeFullyQualifiedName(
1042 '.'.join((package, value.name)))
1043 self._CheckConflictRegister(value, full_name, file_name)
1044 self._top_enum_values[full_name] = value
1045
1046 return desc
1047
1048 def _MakeFieldDescriptor(self, field_proto, message_name, index,
1049 file_desc, is_extension=False):
1050 """Creates a field descriptor from a FieldDescriptorProto.
1051
1052 For message and enum type fields, this method will do a look up
1053 in the pool for the appropriate descriptor for that type. If it
1054 is unavailable, it will fall back to the _source function to
1055 create it. If this type is still unavailable, construction will
1056 fail.
1057
1058 Args:
1059 field_proto: The proto describing the field.
1060 message_name: The name of the containing message.
1061 index: Index of the field
1062 file_desc: The file containing the field descriptor.
1063 is_extension: Indication that this field is for an extension.
1064
1065 Returns:
1066 An initialized FieldDescriptor object
1067 """
1068
1069 if message_name:
1070 full_name = '.'.join((message_name, field_proto.name))
1071 else:
1072 full_name = field_proto.name
1073
1074 if field_proto.json_name:
1075 json_name = field_proto.json_name
1076 else:
1077 json_name = None
1078
1079 return descriptor.FieldDescriptor(
1080 name=field_proto.name,
1081 full_name=full_name,
1082 index=index,
1083 number=field_proto.number,
1084 type=field_proto.type,
1085 cpp_type=None,
1086 message_type=None,
1087 enum_type=None,
1088 containing_type=None,
1089 label=field_proto.label,
1090 has_default_value=False,
1091 default_value=None,
1092 is_extension=is_extension,
1093 extension_scope=None,
1094 options=_OptionsOrNone(field_proto),
1095 json_name=json_name,
1096 file=file_desc,
1097 # pylint: disable=protected-access
1098 create_key=descriptor._internal_create_key)
1099
1100 def _SetAllFieldTypes(self, package, desc_proto, scope):
1101 """Sets all the descriptor's fields's types.
1102
1103 This method also sets the containing types on any extensions.
1104
1105 Args:
1106 package: The current package of desc_proto.
1107 desc_proto: The message descriptor to update.
1108 scope: Enclosing scope of available types.
1109 """
1110
1111 package = _PrefixWithDot(package)
1112
1113 main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
1114
1115 if package == '.':
1116 nested_package = _PrefixWithDot(desc_proto.name)
1117 else:
1118 nested_package = '.'.join([package, desc_proto.name])
1119
1120 for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
1121 self._SetFieldType(field_proto, field_desc, nested_package, scope)
1122
1123 for extension_proto, extension_desc in (
1124 zip(desc_proto.extension, main_desc.extensions)):
1125 extension_desc.containing_type = self._GetTypeFromScope(
1126 nested_package, extension_proto.extendee, scope)
1127 self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
1128
1129 for nested_type in desc_proto.nested_type:
1130 self._SetAllFieldTypes(nested_package, nested_type, scope)
1131
1132 def _SetFieldType(self, field_proto, field_desc, package, scope):
1133 """Sets the field's type, cpp_type, message_type and enum_type.
1134
1135 Args:
1136 field_proto: Data about the field in proto format.
1137 field_desc: The descriptor to modify.
1138 package: The package the field's container is in.
1139 scope: Enclosing scope of available types.
1140 """
1141 if field_proto.type_name:
1142 desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
1143 else:
1144 desc = None
1145
1146 if not field_proto.HasField('type'):
1147 if isinstance(desc, descriptor.Descriptor):
1148 field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
1149 else:
1150 field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
1151
1152 field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
1153 field_proto.type)
1154
1155 if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
1156 or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
1157 field_desc.message_type = desc
1158
1159 if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1160 field_desc.enum_type = desc
1161
1162 if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1163 field_desc.has_default_value = False
1164 field_desc.default_value = []
1165 elif field_proto.HasField('default_value'):
1166 field_desc.has_default_value = True
1167 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1168 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1169 field_desc.default_value = float(field_proto.default_value)
1170 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1171 field_desc.default_value = field_proto.default_value
1172 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1173 field_desc.default_value = field_proto.default_value.lower() == 'true'
1174 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1175 field_desc.default_value = field_desc.enum_type.values_by_name[
1176 field_proto.default_value].number
1177 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1178 field_desc.default_value = text_encoding.CUnescape(
1179 field_proto.default_value)
1180 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1181 field_desc.default_value = None
1182 else:
1183 # All other types are of the "int" type.
1184 field_desc.default_value = int(field_proto.default_value)
1185 else:
1186 field_desc.has_default_value = False
1187 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1188 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1189 field_desc.default_value = 0.0
1190 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1191 field_desc.default_value = u''
1192 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1193 field_desc.default_value = False
1194 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1195 field_desc.default_value = field_desc.enum_type.values[0].number
1196 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1197 field_desc.default_value = b''
1198 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1199 field_desc.default_value = None
1200 elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP:
1201 field_desc.default_value = None
1202 else:
1203 # All other types are of the "int" type.
1204 field_desc.default_value = 0
1205
1206 field_desc.type = field_proto.type
1207
1208 def _MakeEnumValueDescriptor(self, value_proto, index):
1209 """Creates a enum value descriptor object from a enum value proto.
1210
1211 Args:
1212 value_proto: The proto describing the enum value.
1213 index: The index of the enum value.
1214
1215 Returns:
1216 An initialized EnumValueDescriptor object.
1217 """
1218
1219 return descriptor.EnumValueDescriptor(
1220 name=value_proto.name,
1221 index=index,
1222 number=value_proto.number,
1223 options=_OptionsOrNone(value_proto),
1224 type=None,
1225 # pylint: disable=protected-access
1226 create_key=descriptor._internal_create_key)
1227
1228 def _MakeServiceDescriptor(self, service_proto, service_index, scope,
1229 package, file_desc):
1230 """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
1231
1232 Args:
1233 service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
1234 service_index: The index of the service in the File.
1235 scope: Dict mapping short and full symbols to message and enum types.
1236 package: Optional package name for the new message EnumDescriptor.
1237 file_desc: The file containing the service descriptor.
1238
1239 Returns:
1240 The added descriptor.
1241 """
1242
1243 if package:
1244 service_name = '.'.join((package, service_proto.name))
1245 else:
1246 service_name = service_proto.name
1247
1248 methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
1249 scope, index)
1250 for index, method_proto in enumerate(service_proto.method)]
1251 desc = descriptor.ServiceDescriptor(
1252 name=service_proto.name,
1253 full_name=service_name,
1254 index=service_index,
1255 methods=methods,
1256 options=_OptionsOrNone(service_proto),
1257 file=file_desc,
1258 # pylint: disable=protected-access
1259 create_key=descriptor._internal_create_key)
1260 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1261 self._service_descriptors[service_name] = desc
1262 return desc
1263
1264 def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
1265 index):
1266 """Creates a method descriptor from a MethodDescriptorProto.
1267
1268 Args:
1269 method_proto: The proto describing the method.
1270 service_name: The name of the containing service.
1271 package: Optional package name to look up for types.
1272 scope: Scope containing available types.
1273 index: Index of the method in the service.
1274
1275 Returns:
1276 An initialized MethodDescriptor object.
1277 """
1278 full_name = '.'.join((service_name, method_proto.name))
1279 input_type = self._GetTypeFromScope(
1280 package, method_proto.input_type, scope)
1281 output_type = self._GetTypeFromScope(
1282 package, method_proto.output_type, scope)
1283 return descriptor.MethodDescriptor(
1284 name=method_proto.name,
1285 full_name=full_name,
1286 index=index,
1287 containing_service=None,
1288 input_type=input_type,
1289 output_type=output_type,
1290 client_streaming=method_proto.client_streaming,
1291 server_streaming=method_proto.server_streaming,
1292 options=_OptionsOrNone(method_proto),
1293 # pylint: disable=protected-access
1294 create_key=descriptor._internal_create_key)
1295
1296 def _ExtractSymbols(self, descriptors):
1297 """Pulls out all the symbols from descriptor protos.
1298
1299 Args:
1300 descriptors: The messages to extract descriptors from.
1301 Yields:
1302 A two element tuple of the type name and descriptor object.
1303 """
1304
1305 for desc in descriptors:
1306 yield (_PrefixWithDot(desc.full_name), desc)
1307 for symbol in self._ExtractSymbols(desc.nested_types):
1308 yield symbol
1309 for enum in desc.enum_types:
1310 yield (_PrefixWithDot(enum.full_name), enum)
1311
1312 def _GetDeps(self, dependencies, visited=None):
1313 """Recursively finds dependencies for file protos.
1314
1315 Args:
1316 dependencies: The names of the files being depended on.
1317 visited: The names of files already found.
1318
1319 Yields:
1320 Each direct and indirect dependency.
1321 """
1322
1323 visited = visited or set()
1324 for dependency in dependencies:
1325 if dependency not in visited:
1326 visited.add(dependency)
1327 dep_desc = self.FindFileByName(dependency)
1328 yield dep_desc
1329 public_files = [d.name for d in dep_desc.public_dependencies]
1330 yield from self._GetDeps(public_files, visited)
1331
1332 def _GetTypeFromScope(self, package, type_name, scope):
1333 """Finds a given type name in the current scope.
1334
1335 Args:
1336 package: The package the proto should be located in.
1337 type_name: The name of the type to be found in the scope.
1338 scope: Dict mapping short and full symbols to message and enum types.
1339
1340 Returns:
1341 The descriptor for the requested type.
1342 """
1343 if type_name not in scope:
1344 components = _PrefixWithDot(package).split('.')
1345 while components:
1346 possible_match = '.'.join(components + [type_name])
1347 if possible_match in scope:
1348 type_name = possible_match
1349 break
1350 else:
1351 components.pop(-1)
1352 return scope[type_name]
1353
1354
1355def _PrefixWithDot(name):
1356 return name if name.startswith('.') else '.%s' % name
1357
1358
# The module-wide pool returned by Default(): the C extension's
# descriptor._message.default_pool when C descriptors are in use, otherwise a
# pure-Python DescriptorPool created at import time.
# TODO: This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
_DEFAULT = (
    descriptor._message.default_pool  # pylint: disable=protected-access
    if _USE_C_DESCRIPTORS
    else DescriptorPool())
1366
1367
def Default():
  """Returns the module-wide default descriptor pool.

  Returns:
    ``descriptor._message.default_pool`` when C descriptors are in use,
    otherwise the pure-Python DescriptorPool created when this module was
    imported. The same object is returned on every call.
  """
  return _DEFAULT