1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Provides DescriptorPool to use as a container for proto2 descriptors.
9
The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
11a collection of protocol buffer descriptors for use when dynamically creating
12message types at runtime.
13
14For most applications protocol buffers should be used via modules generated by
15the protocol buffer compiler tool. This should only be used when the type of
16protocol buffers used in an application or library cannot be predetermined.
17
Below is a straightforward example of how to use this class::
19
20 pool = DescriptorPool()
21 file_descriptor_protos = [ ... ]
22 for file_descriptor_proto in file_descriptor_protos:
23 pool.Add(file_descriptor_proto)
24 my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
25
26The message descriptor can be used in conjunction with the message_factory
27module in order to create a protocol buffer class that can be encoded and
28decoded.
29
30If you want to get a Python class for the specified proto, use the
31helper functions inside google.protobuf.message_factory
32directly instead of this class.
33"""
34
# Module authorship metadata.
__author__ = 'matthewtoia@google.com (Matt Toia)'
36
37import collections
38import threading
39import warnings
40
41from google.protobuf import descriptor
42from google.protobuf import descriptor_database
43from google.protobuf import text_encoding
44from google.protobuf.internal import python_edition_defaults
45from google.protobuf.internal import python_message
46
# Mirrors descriptor._USE_C_DESCRIPTORS. When true, constructing a
# DescriptorPool returns a native descriptor._message.DescriptorPool instead of
# an instance of the pure-Python class defined below.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
48
49
50def _NormalizeFullyQualifiedName(name):
51 """Remove leading period from fully-qualified type name.
52
53 Due to b/13860351 in descriptor_database.py, types in the root namespace are
54 generated with a leading period. This function removes that prefix.
55
56 Args:
57 name (str): The fully-qualified symbol name.
58
59 Returns:
60 str: The normalized fully-qualified symbol name.
61 """
62 return name.lstrip('.')
63
64
65def _OptionsOrNone(descriptor_proto):
66 """Returns the value of the field `options`, or None if it is not set."""
67 if descriptor_proto.HasField('options'):
68 return descriptor_proto.options
69 else:
70 return None
71
72
def _IsMessageSetExtension(field):
  """Reports whether ``field`` is a MessageSet extension.

  That is: an extension whose extended message has the
  ``message_set_wire_format`` option set, whose type is TYPE_MESSAGE, and
  which is neither required nor repeated.

  Args:
    field: A FieldDescriptor.

  Returns:
    bool: True if all of the above conditions hold.
  """
  if not field.is_extension:
    return False
  extendee = field.containing_type
  if not (extendee.has_options and
          extendee.GetOptions().message_set_wire_format):
    return False
  if field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
    return False
  return not (field.is_required or field.is_repeated)
80
# Module-level lock, shared by every DescriptorPool instance. It guards the lazy
# parse of edition defaults in _CreateDefaultFeatures and the feature intern
# cache in _InternFeatures.
_edition_defaults_lock = threading.Lock()
82
83
84class DescriptorPool(object):
85 """A collection of protobufs dynamically constructed by descriptor protos."""
86
  if _USE_C_DESCRIPTORS:

    def __new__(cls, descriptor_db=None):
      # With the native implementation active, DescriptorPool(...) returns a
      # descriptor._message.DescriptorPool. That object is not an instance of
      # this class, so Python does not run the pure-Python __init__ below.
      # NOTE(review): __init__ also accepts
      # use_deprecated_legacy_json_field_conflicts, but this __new__ does not,
      # so passing that argument raises TypeError on this path. Confirm this
      # is intended.
      # pylint: disable=protected-access
      return descriptor._message.DescriptorPool(descriptor_db)
92
93 def __init__(
94 self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
95 ):
96 """Initializes a Pool of proto buffs.
97
98 The descriptor_db argument to the constructor is provided to allow
99 specialized file descriptor proto lookup code to be triggered on demand. An
100 example would be an implementation which will read and compile a file
101 specified in a call to FindFileByName() and not require the call to Add()
102 at all. Results from this database will be cached internally here as well.
103
104 Args:
105 descriptor_db: A secondary source of file descriptors.
106 use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
107 C++.
108 """
109
110 self._internal_db = descriptor_database.DescriptorDatabase()
111 self._descriptor_db = descriptor_db
112 self._descriptors = {}
113 self._enum_descriptors = {}
114 self._service_descriptors = {}
115 self._file_descriptors = {}
116 self._toplevel_extensions = {}
117 self._top_enum_values = {}
118 # We store extensions in two two-level mappings: The first key is the
119 # descriptor of the message being extended, the second key is the extension
120 # full name or its tag number.
121 self._extensions_by_name = collections.defaultdict(dict)
122 self._extensions_by_number = collections.defaultdict(dict)
123 self._serialized_edition_defaults = (
124 python_edition_defaults._PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS
125 )
126 self._edition_defaults = None
127 self._feature_cache = dict()
128
129 def _CheckConflictRegister(self, desc, desc_name, file_name):
130 """Check if the descriptor name conflicts with another of the same name.
131
132 Args:
133 desc: Descriptor of a message, enum, service, extension or enum value.
134 desc_name (str): the full name of desc.
135 file_name (str): The file name of descriptor.
136 """
137 for register, descriptor_type in [
138 (self._descriptors, descriptor.Descriptor),
139 (self._enum_descriptors, descriptor.EnumDescriptor),
140 (self._service_descriptors, descriptor.ServiceDescriptor),
141 (self._toplevel_extensions, descriptor.FieldDescriptor),
142 (self._top_enum_values, descriptor.EnumValueDescriptor)]:
143 if desc_name in register:
144 old_desc = register[desc_name]
145 if isinstance(old_desc, descriptor.EnumValueDescriptor):
146 old_file = old_desc.type.file.name
147 else:
148 old_file = old_desc.file.name
149
150 if not isinstance(desc, descriptor_type) or (
151 old_file != file_name):
152 error_msg = ('Conflict register for file "' + file_name +
153 '": ' + desc_name +
154 ' is already defined in file "' +
155 old_file + '". Please fix the conflict by adding '
156 'package name on the proto file, or use different '
157 'name for the duplication.')
158 if isinstance(desc, descriptor.EnumValueDescriptor):
159 error_msg += ('\nNote: enum values appear as '
160 'siblings of the enum type instead of '
161 'children of it.')
162
163 raise TypeError(error_msg)
164
165 return
166
167 def Add(self, file_desc_proto):
168 """Adds the FileDescriptorProto and its types to this pool.
169
170 Args:
171 file_desc_proto (FileDescriptorProto): The file descriptor to add.
172 """
173
174 self._internal_db.Add(file_desc_proto)
175
176 def AddSerializedFile(self, serialized_file_desc_proto):
177 """Adds the FileDescriptorProto and its types to this pool.
178
179 Args:
180 serialized_file_desc_proto (bytes): A bytes string, serialization of the
181 :class:`FileDescriptorProto` to add.
182
183 Returns:
184 FileDescriptor: Descriptor for the added file.
185 """
186
187 # pylint: disable=g-import-not-at-top
188 from google.protobuf import descriptor_pb2
189 file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
190 serialized_file_desc_proto)
191 file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
192 file_desc.serialized_pb = serialized_file_desc_proto
193 return file_desc
194
195 # Never call this method. It is for internal usage only.
196 def _AddDescriptor(self, desc):
197 """Adds a Descriptor to the pool, non-recursively.
198
199 If the Descriptor contains nested messages or enums, the caller must
200 explicitly register them. This method also registers the FileDescriptor
201 associated with the message.
202
203 Args:
204 desc: A Descriptor.
205 """
206 if not isinstance(desc, descriptor.Descriptor):
207 raise TypeError('Expected instance of descriptor.Descriptor.')
208
209 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
210
211 self._descriptors[desc.full_name] = desc
212 self._AddFileDescriptor(desc.file)
213
214 # Never call this method. It is for internal usage only.
215 def _AddEnumDescriptor(self, enum_desc):
216 """Adds an EnumDescriptor to the pool.
217
218 This method also registers the FileDescriptor associated with the enum.
219
220 Args:
221 enum_desc: An EnumDescriptor.
222 """
223
224 if not isinstance(enum_desc, descriptor.EnumDescriptor):
225 raise TypeError('Expected instance of descriptor.EnumDescriptor.')
226
227 file_name = enum_desc.file.name
228 self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
229 self._enum_descriptors[enum_desc.full_name] = enum_desc
230
231 # Top enum values need to be indexed.
232 # Count the number of dots to see whether the enum is toplevel or nested
233 # in a message. We cannot use enum_desc.containing_type at this stage.
234 if enum_desc.file.package:
235 top_level = (enum_desc.full_name.count('.')
236 - enum_desc.file.package.count('.') == 1)
237 else:
238 top_level = enum_desc.full_name.count('.') == 0
239 if top_level:
240 file_name = enum_desc.file.name
241 package = enum_desc.file.package
242 for enum_value in enum_desc.values:
243 full_name = _NormalizeFullyQualifiedName(
244 '.'.join((package, enum_value.name)))
245 self._CheckConflictRegister(enum_value, full_name, file_name)
246 self._top_enum_values[full_name] = enum_value
247 self._AddFileDescriptor(enum_desc.file)
248
249 # Never call this method. It is for internal usage only.
250 def _AddServiceDescriptor(self, service_desc):
251 """Adds a ServiceDescriptor to the pool.
252
253 Args:
254 service_desc: A ServiceDescriptor.
255 """
256
257 if not isinstance(service_desc, descriptor.ServiceDescriptor):
258 raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
259
260 self._CheckConflictRegister(service_desc, service_desc.full_name,
261 service_desc.file.name)
262 self._service_descriptors[service_desc.full_name] = service_desc
263
264 # Never call this method. It is for internal usage only.
265 def _AddExtensionDescriptor(self, extension):
266 """Adds a FieldDescriptor describing an extension to the pool.
267
268 Args:
269 extension: A FieldDescriptor.
270
271 Raises:
272 AssertionError: when another extension with the same number extends the
273 same message.
274 TypeError: when the specified extension is not a
275 descriptor.FieldDescriptor.
276 """
277 if not (isinstance(extension, descriptor.FieldDescriptor) and
278 extension.is_extension):
279 raise TypeError('Expected an extension descriptor.')
280
281 if extension.extension_scope is None:
282 self._CheckConflictRegister(
283 extension, extension.full_name, extension.file.name)
284 self._toplevel_extensions[extension.full_name] = extension
285
286 try:
287 existing_desc = self._extensions_by_number[
288 extension.containing_type][extension.number]
289 except KeyError:
290 pass
291 else:
292 if extension is not existing_desc:
293 raise AssertionError(
294 'Extensions "%s" and "%s" both try to extend message type "%s" '
295 'with field number %d.' %
296 (extension.full_name, existing_desc.full_name,
297 extension.containing_type.full_name, extension.number))
298
299 self._extensions_by_number[extension.containing_type][
300 extension.number] = extension
301 self._extensions_by_name[extension.containing_type][
302 extension.full_name] = extension
303
304 # Also register MessageSet extensions with the type name.
305 if _IsMessageSetExtension(extension):
306 self._extensions_by_name[extension.containing_type][
307 extension.message_type.full_name] = extension
308
309 if hasattr(extension.containing_type, '_concrete_class'):
310 python_message._AttachFieldHelpers(
311 extension.containing_type._concrete_class, extension)
312
313 # Never call this method. It is for internal usage only.
314 def _InternalAddFileDescriptor(self, file_desc):
315 """Adds a FileDescriptor to the pool, non-recursively.
316
317 If the FileDescriptor contains messages or enums, the caller must explicitly
318 register them.
319
320 Args:
321 file_desc: A FileDescriptor.
322 """
323
324 self._AddFileDescriptor(file_desc)
325
326 def _AddFileDescriptor(self, file_desc):
327 """Adds a FileDescriptor to the pool, non-recursively.
328
329 If the FileDescriptor contains messages or enums, the caller must explicitly
330 register them.
331
332 Args:
333 file_desc: A FileDescriptor.
334 """
335
336 if not isinstance(file_desc, descriptor.FileDescriptor):
337 raise TypeError('Expected instance of descriptor.FileDescriptor.')
338 self._file_descriptors[file_desc.name] = file_desc
339
340 def FindFileByName(self, file_name):
341 """Gets a FileDescriptor by file name.
342
343 Args:
344 file_name (str): The path to the file to get a descriptor for.
345
346 Returns:
347 FileDescriptor: The descriptor for the named file.
348
349 Raises:
350 KeyError: if the file cannot be found in the pool.
351 """
352
353 try:
354 return self._file_descriptors[file_name]
355 except KeyError:
356 pass
357
358 try:
359 file_proto = self._internal_db.FindFileByName(file_name)
360 except KeyError as error:
361 if self._descriptor_db:
362 file_proto = self._descriptor_db.FindFileByName(file_name)
363 else:
364 raise error
365 if not file_proto:
366 raise KeyError('Cannot find a file named %s' % file_name)
367 return self._ConvertFileProtoToFileDescriptor(file_proto)
368
369 def FindFileContainingSymbol(self, symbol):
370 """Gets the FileDescriptor for the file containing the specified symbol.
371
372 Args:
373 symbol (str): The name of the symbol to search for.
374
375 Returns:
376 FileDescriptor: Descriptor for the file that contains the specified
377 symbol.
378
379 Raises:
380 KeyError: if the file cannot be found in the pool.
381 """
382
383 symbol = _NormalizeFullyQualifiedName(symbol)
384 try:
385 return self._InternalFindFileContainingSymbol(symbol)
386 except KeyError:
387 pass
388
389 try:
390 # Try fallback database. Build and find again if possible.
391 self._FindFileContainingSymbolInDb(symbol)
392 return self._InternalFindFileContainingSymbol(symbol)
393 except KeyError:
394 raise KeyError('Cannot find a file containing %s' % symbol)
395
396 def _InternalFindFileContainingSymbol(self, symbol):
397 """Gets the already built FileDescriptor containing the specified symbol.
398
399 Args:
400 symbol (str): The name of the symbol to search for.
401
402 Returns:
403 FileDescriptor: Descriptor for the file that contains the specified
404 symbol.
405
406 Raises:
407 KeyError: if the file cannot be found in the pool.
408 """
409 try:
410 return self._descriptors[symbol].file
411 except KeyError:
412 pass
413
414 try:
415 return self._enum_descriptors[symbol].file
416 except KeyError:
417 pass
418
419 try:
420 return self._service_descriptors[symbol].file
421 except KeyError:
422 pass
423
424 try:
425 return self._top_enum_values[symbol].type.file
426 except KeyError:
427 pass
428
429 try:
430 return self._toplevel_extensions[symbol].file
431 except KeyError:
432 pass
433
434 # Try fields, enum values and nested extensions inside a message.
435 top_name, _, sub_name = symbol.rpartition('.')
436 try:
437 message = self.FindMessageTypeByName(top_name)
438 assert (sub_name in message.extensions_by_name or
439 sub_name in message.fields_by_name or
440 sub_name in message.enum_values_by_name)
441 return message.file
442 except (KeyError, AssertionError):
443 raise KeyError('Cannot find a file containing %s' % symbol)
444
445 def FindMessageTypeByName(self, full_name):
446 """Loads the named descriptor from the pool.
447
448 Args:
449 full_name (str): The full name of the descriptor to load.
450
451 Returns:
452 Descriptor: The descriptor for the named type.
453
454 Raises:
455 KeyError: if the message cannot be found in the pool.
456 """
457
458 full_name = _NormalizeFullyQualifiedName(full_name)
459 if full_name not in self._descriptors:
460 self._FindFileContainingSymbolInDb(full_name)
461 return self._descriptors[full_name]
462
463 def FindEnumTypeByName(self, full_name):
464 """Loads the named enum descriptor from the pool.
465
466 Args:
467 full_name (str): The full name of the enum descriptor to load.
468
469 Returns:
470 EnumDescriptor: The enum descriptor for the named type.
471
472 Raises:
473 KeyError: if the enum cannot be found in the pool.
474 """
475
476 full_name = _NormalizeFullyQualifiedName(full_name)
477 if full_name not in self._enum_descriptors:
478 self._FindFileContainingSymbolInDb(full_name)
479 return self._enum_descriptors[full_name]
480
481 def FindFieldByName(self, full_name):
482 """Loads the named field descriptor from the pool.
483
484 Args:
485 full_name (str): The full name of the field descriptor to load.
486
487 Returns:
488 FieldDescriptor: The field descriptor for the named field.
489
490 Raises:
491 KeyError: if the field cannot be found in the pool.
492 """
493 full_name = _NormalizeFullyQualifiedName(full_name)
494 message_name, _, field_name = full_name.rpartition('.')
495 message_descriptor = self.FindMessageTypeByName(message_name)
496 return message_descriptor.fields_by_name[field_name]
497
498 def FindOneofByName(self, full_name):
499 """Loads the named oneof descriptor from the pool.
500
501 Args:
502 full_name (str): The full name of the oneof descriptor to load.
503
504 Returns:
505 OneofDescriptor: The oneof descriptor for the named oneof.
506
507 Raises:
508 KeyError: if the oneof cannot be found in the pool.
509 """
510 full_name = _NormalizeFullyQualifiedName(full_name)
511 message_name, _, oneof_name = full_name.rpartition('.')
512 message_descriptor = self.FindMessageTypeByName(message_name)
513 return message_descriptor.oneofs_by_name[oneof_name]
514
515 def FindExtensionByName(self, full_name):
516 """Loads the named extension descriptor from the pool.
517
518 Args:
519 full_name (str): The full name of the extension descriptor to load.
520
521 Returns:
522 FieldDescriptor: The field descriptor for the named extension.
523
524 Raises:
525 KeyError: if the extension cannot be found in the pool.
526 """
527 full_name = _NormalizeFullyQualifiedName(full_name)
528 try:
529 # The proto compiler does not give any link between the FileDescriptor
530 # and top-level extensions unless the FileDescriptorProto is added to
531 # the DescriptorDatabase, but this can impact memory usage.
532 # So we registered these extensions by name explicitly.
533 return self._toplevel_extensions[full_name]
534 except KeyError:
535 pass
536 message_name, _, extension_name = full_name.rpartition('.')
537 try:
538 # Most extensions are nested inside a message.
539 scope = self.FindMessageTypeByName(message_name)
540 except KeyError:
541 # Some extensions are defined at file scope.
542 scope = self._FindFileContainingSymbolInDb(full_name)
543 return scope.extensions_by_name[extension_name]
544
545 def FindExtensionByNumber(self, message_descriptor, number):
546 """Gets the extension of the specified message with the specified number.
547
548 Extensions have to be registered to this pool by calling :func:`Add` or
549 :func:`AddExtensionDescriptor`.
550
551 Args:
552 message_descriptor (Descriptor): descriptor of the extended message.
553 number (int): Number of the extension field.
554
555 Returns:
556 FieldDescriptor: The descriptor for the extension.
557
558 Raises:
559 KeyError: when no extension with the given number is known for the
560 specified message.
561 """
562 try:
563 return self._extensions_by_number[message_descriptor][number]
564 except KeyError:
565 self._TryLoadExtensionFromDB(message_descriptor, number)
566 return self._extensions_by_number[message_descriptor][number]
567
568 def FindAllExtensions(self, message_descriptor):
569 """Gets all the known extensions of a given message.
570
571 Extensions have to be registered to this pool by build related
572 :func:`Add` or :func:`AddExtensionDescriptor`.
573
574 Args:
575 message_descriptor (Descriptor): Descriptor of the extended message.
576
577 Returns:
578 list[FieldDescriptor]: Field descriptors describing the extensions.
579 """
580 # Fallback to descriptor db if FindAllExtensionNumbers is provided.
581 if self._descriptor_db and hasattr(
582 self._descriptor_db, 'FindAllExtensionNumbers'):
583 full_name = message_descriptor.full_name
584 try:
585 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
586 except:
587 pass
588 else:
589 if isinstance(all_numbers, list):
590 for number in all_numbers:
591 if number in self._extensions_by_number[message_descriptor]:
592 continue
593 self._TryLoadExtensionFromDB(message_descriptor, number)
594 else:
595 warnings.warn(
596 'FindAllExtensionNumbers() on fall back DB must return a list,'
597 ' not {0}'.format(type(all_numbers))
598 )
599
600 return list(self._extensions_by_number[message_descriptor].values())
601
602 def _TryLoadExtensionFromDB(self, message_descriptor, number):
603 """Try to Load extensions from descriptor db.
604
605 Args:
606 message_descriptor: descriptor of the extended message.
607 number: the extension number that needs to be loaded.
608 """
609 if not self._descriptor_db:
610 return
611 # Only supported when FindFileContainingExtension is provided.
612 if not hasattr(
613 self._descriptor_db, 'FindFileContainingExtension'):
614 return
615
616 full_name = message_descriptor.full_name
617 file_proto = None
618 try:
619 file_proto = self._descriptor_db.FindFileContainingExtension(
620 full_name, number
621 )
622 except:
623 return
624
625 if file_proto is None:
626 return
627
628 try:
629 self._ConvertFileProtoToFileDescriptor(file_proto)
630 except:
631 warn_msg = ('Unable to load proto file %s for extension number %d.' %
632 (file_proto.name, number))
633 warnings.warn(warn_msg, RuntimeWarning)
634
635 def FindServiceByName(self, full_name):
636 """Loads the named service descriptor from the pool.
637
638 Args:
639 full_name (str): The full name of the service descriptor to load.
640
641 Returns:
642 ServiceDescriptor: The service descriptor for the named service.
643
644 Raises:
645 KeyError: if the service cannot be found in the pool.
646 """
647 full_name = _NormalizeFullyQualifiedName(full_name)
648 if full_name not in self._service_descriptors:
649 self._FindFileContainingSymbolInDb(full_name)
650 return self._service_descriptors[full_name]
651
652 def FindMethodByName(self, full_name):
653 """Loads the named service method descriptor from the pool.
654
655 Args:
656 full_name (str): The full name of the method descriptor to load.
657
658 Returns:
659 MethodDescriptor: The method descriptor for the service method.
660
661 Raises:
662 KeyError: if the method cannot be found in the pool.
663 """
664 full_name = _NormalizeFullyQualifiedName(full_name)
665 service_name, _, method_name = full_name.rpartition('.')
666 service_descriptor = self.FindServiceByName(service_name)
667 return service_descriptor.methods_by_name[method_name]
668
669 def SetFeatureSetDefaults(self, defaults):
670 """Sets the default feature mappings used during the build.
671
672 Args:
673 defaults: a FeatureSetDefaults message containing the new mappings.
674 """
675 if self._edition_defaults is not None:
676 raise ValueError(
677 "Feature set defaults can't be changed once the pool has started"
678 ' building!'
679 )
680
681 # pylint: disable=g-import-not-at-top
682 from google.protobuf import descriptor_pb2
683
684 if not isinstance(defaults, descriptor_pb2.FeatureSetDefaults):
685 raise TypeError('SetFeatureSetDefaults called with invalid type')
686
687 if defaults.minimum_edition > defaults.maximum_edition:
688 raise ValueError(
689 'Invalid edition range %s to %s'
690 % (
691 descriptor_pb2.Edition.Name(defaults.minimum_edition),
692 descriptor_pb2.Edition.Name(defaults.maximum_edition),
693 )
694 )
695
696 prev_edition = descriptor_pb2.Edition.EDITION_UNKNOWN
697 for d in defaults.defaults:
698 if d.edition == descriptor_pb2.Edition.EDITION_UNKNOWN:
699 raise ValueError('Invalid edition EDITION_UNKNOWN specified')
700 if prev_edition >= d.edition:
701 raise ValueError(
702 'Feature set defaults are not strictly increasing. %s is greater'
703 ' than or equal to %s'
704 % (
705 descriptor_pb2.Edition.Name(prev_edition),
706 descriptor_pb2.Edition.Name(d.edition),
707 )
708 )
709 prev_edition = d.edition
710 self._edition_defaults = defaults
711
712 def _CreateDefaultFeatures(self, edition):
713 """Creates a FeatureSet message with defaults for a specific edition.
714
715 Args:
716 edition: the edition to generate defaults for.
717
718 Returns:
719 A FeatureSet message with defaults for a specific edition.
720 """
721 # pylint: disable=g-import-not-at-top
722 from google.protobuf import descriptor_pb2
723
724 with _edition_defaults_lock:
725 if not self._edition_defaults:
726 self._edition_defaults = descriptor_pb2.FeatureSetDefaults()
727 self._edition_defaults.ParseFromString(
728 self._serialized_edition_defaults
729 )
730
731 if edition < self._edition_defaults.minimum_edition:
732 raise TypeError(
733 'Edition %s is earlier than the minimum supported edition %s!'
734 % (
735 descriptor_pb2.Edition.Name(edition),
736 descriptor_pb2.Edition.Name(
737 self._edition_defaults.minimum_edition
738 ),
739 )
740 )
741 if (
742 edition > self._edition_defaults.maximum_edition
743 and edition != descriptor_pb2.EDITION_UNSTABLE
744 ):
745 raise TypeError(
746 'Edition %s is later than the maximum supported edition %s!'
747 % (
748 descriptor_pb2.Edition.Name(edition),
749 descriptor_pb2.Edition.Name(
750 self._edition_defaults.maximum_edition
751 ),
752 )
753 )
754 found = None
755 for d in self._edition_defaults.defaults:
756 if d.edition > edition:
757 break
758 found = d
759 if found is None:
760 raise TypeError(
761 'No valid default found for edition %s!'
762 % descriptor_pb2.Edition.Name(edition)
763 )
764
765 defaults = descriptor_pb2.FeatureSet()
766 defaults.CopyFrom(found.fixed_features)
767 defaults.MergeFrom(found.overridable_features)
768 return defaults
769
770 def _InternFeatures(self, features):
771 serialized = features.SerializeToString()
772 with _edition_defaults_lock:
773 cached = self._feature_cache.get(serialized)
774 if cached is None:
775 self._feature_cache[serialized] = features
776 cached = features
777 return cached
778
779 def _FindFileContainingSymbolInDb(self, symbol):
780 """Finds the file in descriptor DB containing the specified symbol.
781
782 Args:
783 symbol (str): The name of the symbol to search for.
784
785 Returns:
786 FileDescriptor: The file that contains the specified symbol.
787
788 Raises:
789 KeyError: if the file cannot be found in the descriptor database.
790 """
791 try:
792 file_proto = self._internal_db.FindFileContainingSymbol(symbol)
793 except KeyError as error:
794 if self._descriptor_db:
795 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
796 else:
797 raise error
798 if not file_proto:
799 raise KeyError('Cannot find a file containing %s' % symbol)
800 return self._ConvertFileProtoToFileDescriptor(file_proto)
801
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool.

    When the file is not yet in the pool:
      1. Its direct dependencies are resolved via FindFileByName.
      2. A scope of dependency message/enum types is assembled.
      3. Its messages, enums, extensions and services are converted against
         that scope.
    Whether newly built or cached, every extension declared in the file
    (including ones nested in messages) is passed to _AddExtensionDescriptor
    before returning.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.
    """
    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined outside this view; presumably it
      # yields the (transitive) dependency FileDescriptors. Confirm.
      built_deps = list(self._GetDeps(file_proto.dependency))
      # FindFileByName may itself build each dependency on demand.
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency holds indices into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      # pylint: disable=g-import-not-at-top
      from google.protobuf import descriptor_pb2

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          edition=descriptor_pb2.Edition.Name(file_proto.edition),
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps,
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key,
      )
      # Maps dotted names (built with _PrefixWithDot) to message/enum
      # descriptors that are visible while converting this file.
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # Convert top-level messages; _ConvertMessageDescriptor also records
      # each message (and its nested types) in `scope`.
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope, True))

      # File-scoped extensions: resolve the extended message and the field
      # type against the scope built so far.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)

      # Field types are set only after all messages are in scope, presumably
      # so that fields can reference types declared later in the file.
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # NOTE(review): this re-reads each top-level message from `scope`; it
      # appears to store the same objects recorded by the loop above.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool
    def AddExtensionForNested(message_type):
      # Depth-first: nested messages' extensions before this message's own.
      for nested in message_type.nested_types:
        AddExtensionForNested(nested)
      for extension in message_type.extensions:
        self._AddExtensionDescriptor(extension)

    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
      self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
      AddExtensionForNested(message_type)

    return file_desc
904
  def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
                                scope=None, syntax=None):
    """Adds the proto to the pool in the specified package.

    Nested messages are converted (and registered) recursively before the
    enclosing message. Enums, fields, extensions and oneofs are created next,
    then the Descriptor itself. The new descriptor is recorded in ``scope``
    under its dotted full name and registered in the pool.

    Args:
      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
      package: The package the proto should be located in.
      file_desc: The file containing this message.
      scope: Dict mapping short and full symbols to message and enum types.
        Mutated: this message and its nested messages are added to it.
      syntax: string indicating syntax of the file ("proto2" or "proto3").
        Only forwarded to nested-message conversion here.

    Returns:
      The added descriptor.

    Raises:
      TypeError: From _CheckConflictRegister, if the message's full name
        conflicts with an already-registered symbol.
    """

    if package:
      desc_name = '.'.join((package, desc_proto.name))
    else:
      desc_name = desc_proto.name

    if file_desc is None:
      file_name = None
    else:
      file_name = file_desc.name

    if scope is None:
      scope = {}

    # Nested messages use this message's full name as their "package".
    nested = [
        self._ConvertMessageDescriptor(
            nested, desc_name, file_desc, scope, syntax)
        for nested in desc_proto.nested_type]
    enums = [
        self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
                                    scope, False)
        for enum in desc_proto.enum_type]
    fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
              for index, field in enumerate(desc_proto.field)]
    extensions = [
        self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
                                  is_extension=True)
        for index, extension in enumerate(desc_proto.extension)]
    # Oneofs start with an empty field list; fields are attached below once
    # the Descriptor exists.
    oneofs = [
        # pylint: disable=g-complex-comprehension
        descriptor.OneofDescriptor(
            desc.name,
            '.'.join((desc_name, desc.name)),
            index,
            None,
            [],
            _OptionsOrNone(desc),
            # pylint: disable=protected-access
            create_key=descriptor._internal_create_key)
        for index, desc in enumerate(desc_proto.oneof_decl)
    ]
    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
    if extension_ranges:
      is_extendable = True
    else:
      is_extendable = False
    desc = descriptor.Descriptor(
        name=desc_proto.name,
        full_name=desc_name,
        filename=file_name,
        containing_type=None,
        fields=fields,
        oneofs=oneofs,
        nested_types=nested,
        enum_types=enums,
        extensions=extensions,
        options=_OptionsOrNone(desc_proto),
        is_extendable=is_extendable,
        extension_ranges=extension_ranges,
        file=file_desc,
        serialized_start=None,
        serialized_end=None,
        is_map_entry=desc_proto.options.map_entry,
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key,
    )
    # Back-link nested types and enums now that the parent exists.
    for nested in desc.nested_types:
      nested.containing_type = desc
    for enum in desc.enum_types:
      enum.containing_type = desc
    # Wire each field that declares oneof_index into its oneof, in both
    # directions.
    for field_index, field_desc in enumerate(desc_proto.field):
      if field_desc.HasField('oneof_index'):
        oneof_index = field_desc.oneof_index
        oneofs[oneof_index].fields.append(fields[field_index])
        fields[field_index].containing_oneof = oneofs[oneof_index]

    scope[_PrefixWithDot(desc_name)] = desc
    # NOTE(review): file_desc defaults to None, but this dereferences
    # desc.file.name, which presumably fails without a file. Confirm that
    # callers always pass file_desc.
    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
    self._descriptors[desc_name] = desc
    return desc
999
1000 def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
1001 containing_type=None, scope=None, top_level=False):
1002 """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
1003
1004 Args:
1005 enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
1006 package: Optional package name for the new message EnumDescriptor.
1007 file_desc: The file containing the enum descriptor.
1008 containing_type: The type containing this enum.
1009 scope: Scope containing available types.
1010 top_level: If True, the enum is a top level symbol. If False, the enum
1011 is defined inside a message.
1012
1013 Returns:
1014 The added descriptor
1015 """
1016
1017 if package:
1018 enum_name = '.'.join((package, enum_proto.name))
1019 else:
1020 enum_name = enum_proto.name
1021
1022 if file_desc is None:
1023 file_name = None
1024 else:
1025 file_name = file_desc.name
1026
1027 values = [self._MakeEnumValueDescriptor(value, index)
1028 for index, value in enumerate(enum_proto.value)]
1029 desc = descriptor.EnumDescriptor(name=enum_proto.name,
1030 full_name=enum_name,
1031 filename=file_name,
1032 file=file_desc,
1033 values=values,
1034 containing_type=containing_type,
1035 options=_OptionsOrNone(enum_proto),
1036 # pylint: disable=protected-access
1037 create_key=descriptor._internal_create_key)
1038 scope['.%s' % enum_name] = desc
1039 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1040 self._enum_descriptors[enum_name] = desc
1041
1042 # Add top level enum values.
1043 if top_level:
1044 for value in values:
1045 full_name = _NormalizeFullyQualifiedName(
1046 '.'.join((package, value.name)))
1047 self._CheckConflictRegister(value, full_name, file_name)
1048 self._top_enum_values[full_name] = value
1049
1050 return desc
1051
1052 def _MakeFieldDescriptor(self, field_proto, message_name, index,
1053 file_desc, is_extension=False):
1054 """Creates a field descriptor from a FieldDescriptorProto.
1055
1056 For message and enum type fields, this method will do a look up
1057 in the pool for the appropriate descriptor for that type. If it
1058 is unavailable, it will fall back to the _source function to
1059 create it. If this type is still unavailable, construction will
1060 fail.
1061
1062 Args:
1063 field_proto: The proto describing the field.
1064 message_name: The name of the containing message.
1065 index: Index of the field
1066 file_desc: The file containing the field descriptor.
1067 is_extension: Indication that this field is for an extension.
1068
1069 Returns:
1070 An initialized FieldDescriptor object
1071 """
1072
1073 if message_name:
1074 full_name = '.'.join((message_name, field_proto.name))
1075 else:
1076 full_name = field_proto.name
1077
1078 if field_proto.json_name:
1079 json_name = field_proto.json_name
1080 else:
1081 json_name = None
1082
1083 return descriptor.FieldDescriptor(
1084 name=field_proto.name,
1085 full_name=full_name,
1086 index=index,
1087 number=field_proto.number,
1088 type=field_proto.type,
1089 cpp_type=None,
1090 message_type=None,
1091 enum_type=None,
1092 containing_type=None,
1093 label=field_proto.label,
1094 has_default_value=False,
1095 default_value=None,
1096 is_extension=is_extension,
1097 extension_scope=None,
1098 options=_OptionsOrNone(field_proto),
1099 json_name=json_name,
1100 file=file_desc,
1101 # pylint: disable=protected-access
1102 create_key=descriptor._internal_create_key)
1103
1104 def _SetAllFieldTypes(self, package, desc_proto, scope):
1105 """Sets all the descriptor's fields's types.
1106
1107 This method also sets the containing types on any extensions.
1108
1109 Args:
1110 package: The current package of desc_proto.
1111 desc_proto: The message descriptor to update.
1112 scope: Enclosing scope of available types.
1113 """
1114
1115 package = _PrefixWithDot(package)
1116
1117 main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
1118
1119 if package == '.':
1120 nested_package = _PrefixWithDot(desc_proto.name)
1121 else:
1122 nested_package = '.'.join([package, desc_proto.name])
1123
1124 for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
1125 self._SetFieldType(field_proto, field_desc, nested_package, scope)
1126
1127 for extension_proto, extension_desc in (
1128 zip(desc_proto.extension, main_desc.extensions)):
1129 extension_desc.containing_type = self._GetTypeFromScope(
1130 nested_package, extension_proto.extendee, scope)
1131 self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
1132
1133 for nested_type in desc_proto.nested_type:
1134 self._SetAllFieldTypes(nested_package, nested_type, scope)
1135
1136 def _SetFieldType(self, field_proto, field_desc, package, scope):
1137 """Sets the field's type, cpp_type, message_type and enum_type.
1138
1139 Args:
1140 field_proto: Data about the field in proto format.
1141 field_desc: The descriptor to modify.
1142 package: The package the field's container is in.
1143 scope: Enclosing scope of available types.
1144 """
1145 if field_proto.type_name:
1146 desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
1147 else:
1148 desc = None
1149
1150 if not field_proto.HasField('type'):
1151 if isinstance(desc, descriptor.Descriptor):
1152 field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
1153 else:
1154 field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
1155
1156 field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
1157 field_proto.type)
1158
1159 if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
1160 or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
1161 field_desc.message_type = desc
1162
1163 if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1164 field_desc.enum_type = desc
1165
1166 if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1167 field_desc.has_default_value = False
1168 field_desc.default_value = []
1169 elif field_proto.HasField('default_value'):
1170 field_desc.has_default_value = True
1171 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1172 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1173 field_desc.default_value = float(field_proto.default_value)
1174 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1175 field_desc.default_value = field_proto.default_value
1176 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1177 field_desc.default_value = field_proto.default_value.lower() == 'true'
1178 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1179 field_desc.default_value = field_desc.enum_type.values_by_name[
1180 field_proto.default_value].number
1181 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1182 field_desc.default_value = text_encoding.CUnescape(
1183 field_proto.default_value)
1184 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1185 field_desc.default_value = None
1186 else:
1187 # All other types are of the "int" type.
1188 field_desc.default_value = int(field_proto.default_value)
1189 else:
1190 field_desc.has_default_value = False
1191 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1192 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1193 field_desc.default_value = 0.0
1194 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1195 field_desc.default_value = u''
1196 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1197 field_desc.default_value = False
1198 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1199 field_desc.default_value = field_desc.enum_type.values[0].number
1200 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1201 field_desc.default_value = b''
1202 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1203 field_desc.default_value = None
1204 elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP:
1205 field_desc.default_value = None
1206 else:
1207 # All other types are of the "int" type.
1208 field_desc.default_value = 0
1209
1210 field_desc.type = field_proto.type
1211
1212 def _MakeEnumValueDescriptor(self, value_proto, index):
1213 """Creates a enum value descriptor object from a enum value proto.
1214
1215 Args:
1216 value_proto: The proto describing the enum value.
1217 index: The index of the enum value.
1218
1219 Returns:
1220 An initialized EnumValueDescriptor object.
1221 """
1222
1223 return descriptor.EnumValueDescriptor(
1224 name=value_proto.name,
1225 index=index,
1226 number=value_proto.number,
1227 options=_OptionsOrNone(value_proto),
1228 type=None,
1229 # pylint: disable=protected-access
1230 create_key=descriptor._internal_create_key)
1231
1232 def _MakeServiceDescriptor(self, service_proto, service_index, scope,
1233 package, file_desc):
1234 """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
1235
1236 Args:
1237 service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
1238 service_index: The index of the service in the File.
1239 scope: Dict mapping short and full symbols to message and enum types.
1240 package: Optional package name for the new message EnumDescriptor.
1241 file_desc: The file containing the service descriptor.
1242
1243 Returns:
1244 The added descriptor.
1245 """
1246
1247 if package:
1248 service_name = '.'.join((package, service_proto.name))
1249 else:
1250 service_name = service_proto.name
1251
1252 methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
1253 scope, index)
1254 for index, method_proto in enumerate(service_proto.method)]
1255 desc = descriptor.ServiceDescriptor(
1256 name=service_proto.name,
1257 full_name=service_name,
1258 index=service_index,
1259 methods=methods,
1260 options=_OptionsOrNone(service_proto),
1261 file=file_desc,
1262 # pylint: disable=protected-access
1263 create_key=descriptor._internal_create_key)
1264 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1265 self._service_descriptors[service_name] = desc
1266 return desc
1267
1268 def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
1269 index):
1270 """Creates a method descriptor from a MethodDescriptorProto.
1271
1272 Args:
1273 method_proto: The proto describing the method.
1274 service_name: The name of the containing service.
1275 package: Optional package name to look up for types.
1276 scope: Scope containing available types.
1277 index: Index of the method in the service.
1278
1279 Returns:
1280 An initialized MethodDescriptor object.
1281 """
1282 full_name = '.'.join((service_name, method_proto.name))
1283 input_type = self._GetTypeFromScope(
1284 package, method_proto.input_type, scope)
1285 output_type = self._GetTypeFromScope(
1286 package, method_proto.output_type, scope)
1287 return descriptor.MethodDescriptor(
1288 name=method_proto.name,
1289 full_name=full_name,
1290 index=index,
1291 containing_service=None,
1292 input_type=input_type,
1293 output_type=output_type,
1294 client_streaming=method_proto.client_streaming,
1295 server_streaming=method_proto.server_streaming,
1296 options=_OptionsOrNone(method_proto),
1297 # pylint: disable=protected-access
1298 create_key=descriptor._internal_create_key)
1299
1300 def _ExtractSymbols(self, descriptors):
1301 """Pulls out all the symbols from descriptor protos.
1302
1303 Args:
1304 descriptors: The messages to extract descriptors from.
1305 Yields:
1306 A two element tuple of the type name and descriptor object.
1307 """
1308
1309 for desc in descriptors:
1310 yield (_PrefixWithDot(desc.full_name), desc)
1311 for symbol in self._ExtractSymbols(desc.nested_types):
1312 yield symbol
1313 for enum in desc.enum_types:
1314 yield (_PrefixWithDot(enum.full_name), enum)
1315
1316 def _GetDeps(self, dependencies, visited=None):
1317 """Recursively finds dependencies for file protos.
1318
1319 Args:
1320 dependencies: The names of the files being depended on.
1321 visited: The names of files already found.
1322
1323 Yields:
1324 Each direct and indirect dependency.
1325 """
1326
1327 visited = visited or set()
1328 for dependency in dependencies:
1329 if dependency not in visited:
1330 visited.add(dependency)
1331 dep_desc = self.FindFileByName(dependency)
1332 yield dep_desc
1333 public_files = [d.name for d in dep_desc.public_dependencies]
1334 yield from self._GetDeps(public_files, visited)
1335
1336 def _GetTypeFromScope(self, package, type_name, scope):
1337 """Finds a given type name in the current scope.
1338
1339 Args:
1340 package: The package the proto should be located in.
1341 type_name: The name of the type to be found in the scope.
1342 scope: Dict mapping short and full symbols to message and enum types.
1343
1344 Returns:
1345 The descriptor for the requested type.
1346 """
1347 if type_name not in scope:
1348 components = _PrefixWithDot(package).split('.')
1349 while components:
1350 possible_match = '.'.join(components + [type_name])
1351 if possible_match in scope:
1352 type_name = possible_match
1353 break
1354 else:
1355 components.pop(-1)
1356 return scope[type_name]
1357
1358
def _PrefixWithDot(name):
  """Returns ``name`` prefixed with '.', unless it already starts with one."""
  if name.startswith('.'):
    return name
  return '.%s' % name
1361
1362
# Module-level default pool. With C++ descriptors it is the C extension's
# `default_pool`; otherwise it is a pure-Python DescriptorPool built at
# import time.
# TODO: This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
_DEFAULT = (
    descriptor._message.default_pool  # pylint: disable=protected-access
    if _USE_C_DESCRIPTORS
    else DescriptorPool())
1370
1371
def Default():
  """Returns the module-level default descriptor pool.

  With C++ descriptors enabled, this is the C extension's default pool;
  otherwise it is a pure-Python DescriptorPool created at import time.
  """
  default_pool = _DEFAULT
  return default_pool