Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/descriptor_pool.py: 15%
469 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:37 +0000
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides DescriptorPool to use as a container for proto2 descriptors.

The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
a collection of protocol buffer descriptors for use when dynamically creating
message types at runtime.

For most applications protocol buffers should be used via modules generated by
the protocol buffer compiler tool. This should only be used when the type of
protocol buffers used in an application or library cannot be predetermined.

Below is a straightforward example of how to use this class::

  pool = DescriptorPool()
  file_descriptor_protos = [ ... ]
  for file_descriptor_proto in file_descriptor_protos:
    pool.Add(file_descriptor_proto)
  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')

The message descriptor can be used in conjunction with the message_factory
module in order to create a protocol buffer class that can be encoded and
decoded.

If you want to get a Python class for the specified proto, use the
helper functions inside google.protobuf.message_factory
directly instead of this class.
"""
# Module authorship metadata.
__author__ = 'matthewtoia@google.com (Matt Toia)'
60import collections
61import warnings
63from google.protobuf import descriptor
64from google.protobuf import descriptor_database
65from google.protobuf import text_encoding
66from google.protobuf.internal import python_message
# Mirrors descriptor._USE_C_DESCRIPTORS. When true, DescriptorPool.__new__
# returns descriptor._message.DescriptorPool instead of a pure-Python pool.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
def _Deprecated(func):
  """Mark functions as deprecated.

  Wraps ``func`` so that every call first emits a DeprecationWarning pointing
  callers at Add() / AddSerializedFile(), then forwards all arguments to
  ``func`` and returns its result.

  Args:
    func: The callable to wrap.

  Returns:
    A wrapper carrying ``func``'s name, docstring and attributes (via
    functools.wraps, which also records ``func`` as ``__wrapped__``).
  """
  # pylint: disable=g-import-not-at-top
  import functools

  @functools.wraps(func)
  def NewFunc(*args, **kwargs):
    warnings.warn(
        'Call to deprecated function %s(). Note: Do add unlinked descriptors '
        'to descriptor_pool is wrong. Please use Add() or AddSerializedFile() '
        'instead. This function will be removed soon.' % func.__name__,
        category=DeprecationWarning,
        # Attribute the warning to the code that called the deprecated API,
        # not to this wrapper.
        stacklevel=2)
    return func(*args, **kwargs)

  return NewFunc
def _NormalizeFullyQualifiedName(name):
  """Strips leading periods from a fully-qualified type name.

  Because of b/13860351 in descriptor_database.py, types in the root namespace
  are produced with a leading period; dropping it yields the key used by the
  pool's lookup tables.

  Args:
    name (str): The fully-qualified symbol name.

  Returns:
    str: ``name`` with every leading '.' removed.
  """
  normalized = name.lstrip('.')
  return normalized
def _OptionsOrNone(descriptor_proto):
  """Returns ``descriptor_proto.options`` when that field is set, else None."""
  has_options = descriptor_proto.HasField('options')
  return descriptor_proto.options if has_options else None
def _IsMessageSetExtension(field):
  """Tells whether ``field`` is an extension of a MessageSet-format message.

  True only for an extension whose extended message has options with
  ``message_set_wire_format`` set, and whose own type is an optional message.
  """
  if not field.is_extension:
    return False
  container = field.containing_type
  if not container.has_options:
    return False
  if not container.GetOptions().message_set_wire_format:
    return False
  return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
          field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)
118class DescriptorPool(object):
119 """A collection of protobufs dynamically constructed by descriptor protos."""
121 if _USE_C_DESCRIPTORS:
123 def __new__(cls, descriptor_db=None):
124 # pylint: disable=protected-access
125 return descriptor._message.DescriptorPool(descriptor_db)
127 def __init__(
128 self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
129 ):
130 """Initializes a Pool of proto buffs.
132 The descriptor_db argument to the constructor is provided to allow
133 specialized file descriptor proto lookup code to be triggered on demand. An
134 example would be an implementation which will read and compile a file
135 specified in a call to FindFileByName() and not require the call to Add()
136 at all. Results from this database will be cached internally here as well.
138 Args:
139 descriptor_db: A secondary source of file descriptors.
140 use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
141 C++.
142 """
144 self._internal_db = descriptor_database.DescriptorDatabase()
145 self._descriptor_db = descriptor_db
146 self._descriptors = {}
147 self._enum_descriptors = {}
148 self._service_descriptors = {}
149 self._file_descriptors = {}
150 self._toplevel_extensions = {}
151 self._top_enum_values = {}
152 # We store extensions in two two-level mappings: The first key is the
153 # descriptor of the message being extended, the second key is the extension
154 # full name or its tag number.
155 self._extensions_by_name = collections.defaultdict(dict)
156 self._extensions_by_number = collections.defaultdict(dict)
158 def _CheckConflictRegister(self, desc, desc_name, file_name):
159 """Check if the descriptor name conflicts with another of the same name.
161 Args:
162 desc: Descriptor of a message, enum, service, extension or enum value.
163 desc_name (str): the full name of desc.
164 file_name (str): The file name of descriptor.
165 """
166 for register, descriptor_type in [
167 (self._descriptors, descriptor.Descriptor),
168 (self._enum_descriptors, descriptor.EnumDescriptor),
169 (self._service_descriptors, descriptor.ServiceDescriptor),
170 (self._toplevel_extensions, descriptor.FieldDescriptor),
171 (self._top_enum_values, descriptor.EnumValueDescriptor)]:
172 if desc_name in register:
173 old_desc = register[desc_name]
174 if isinstance(old_desc, descriptor.EnumValueDescriptor):
175 old_file = old_desc.type.file.name
176 else:
177 old_file = old_desc.file.name
179 if not isinstance(desc, descriptor_type) or (
180 old_file != file_name):
181 error_msg = ('Conflict register for file "' + file_name +
182 '": ' + desc_name +
183 ' is already defined in file "' +
184 old_file + '". Please fix the conflict by adding '
185 'package name on the proto file, or use different '
186 'name for the duplication.')
187 if isinstance(desc, descriptor.EnumValueDescriptor):
188 error_msg += ('\nNote: enum values appear as '
189 'siblings of the enum type instead of '
190 'children of it.')
192 raise TypeError(error_msg)
194 return
196 def Add(self, file_desc_proto):
197 """Adds the FileDescriptorProto and its types to this pool.
199 Args:
200 file_desc_proto (FileDescriptorProto): The file descriptor to add.
201 """
203 self._internal_db.Add(file_desc_proto)
205 def AddSerializedFile(self, serialized_file_desc_proto):
206 """Adds the FileDescriptorProto and its types to this pool.
208 Args:
209 serialized_file_desc_proto (bytes): A bytes string, serialization of the
210 :class:`FileDescriptorProto` to add.
212 Returns:
213 FileDescriptor: Descriptor for the added file.
214 """
216 # pylint: disable=g-import-not-at-top
217 from google.protobuf import descriptor_pb2
218 file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
219 serialized_file_desc_proto)
220 file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
221 file_desc.serialized_pb = serialized_file_desc_proto
222 return file_desc
224 # Add Descriptor to descriptor pool is deprecated. Please use Add()
225 # or AddSerializedFile() to add a FileDescriptorProto instead.
226 @_Deprecated
227 def AddDescriptor(self, desc):
228 self._AddDescriptor(desc)
230 # Never call this method. It is for internal usage only.
231 def _AddDescriptor(self, desc):
232 """Adds a Descriptor to the pool, non-recursively.
234 If the Descriptor contains nested messages or enums, the caller must
235 explicitly register them. This method also registers the FileDescriptor
236 associated with the message.
238 Args:
239 desc: A Descriptor.
240 """
241 if not isinstance(desc, descriptor.Descriptor):
242 raise TypeError('Expected instance of descriptor.Descriptor.')
244 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
246 self._descriptors[desc.full_name] = desc
247 self._AddFileDescriptor(desc.file)
249 # Never call this method. It is for internal usage only.
250 def _AddEnumDescriptor(self, enum_desc):
251 """Adds an EnumDescriptor to the pool.
253 This method also registers the FileDescriptor associated with the enum.
255 Args:
256 enum_desc: An EnumDescriptor.
257 """
259 if not isinstance(enum_desc, descriptor.EnumDescriptor):
260 raise TypeError('Expected instance of descriptor.EnumDescriptor.')
262 file_name = enum_desc.file.name
263 self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
264 self._enum_descriptors[enum_desc.full_name] = enum_desc
266 # Top enum values need to be indexed.
267 # Count the number of dots to see whether the enum is toplevel or nested
268 # in a message. We cannot use enum_desc.containing_type at this stage.
269 if enum_desc.file.package:
270 top_level = (enum_desc.full_name.count('.')
271 - enum_desc.file.package.count('.') == 1)
272 else:
273 top_level = enum_desc.full_name.count('.') == 0
274 if top_level:
275 file_name = enum_desc.file.name
276 package = enum_desc.file.package
277 for enum_value in enum_desc.values:
278 full_name = _NormalizeFullyQualifiedName(
279 '.'.join((package, enum_value.name)))
280 self._CheckConflictRegister(enum_value, full_name, file_name)
281 self._top_enum_values[full_name] = enum_value
282 self._AddFileDescriptor(enum_desc.file)
284 # Add ServiceDescriptor to descriptor pool is deprecated. Please use Add()
285 # or AddSerializedFile() to add a FileDescriptorProto instead.
286 @_Deprecated
287 def AddServiceDescriptor(self, service_desc):
288 self._AddServiceDescriptor(service_desc)
290 # Never call this method. It is for internal usage only.
291 def _AddServiceDescriptor(self, service_desc):
292 """Adds a ServiceDescriptor to the pool.
294 Args:
295 service_desc: A ServiceDescriptor.
296 """
298 if not isinstance(service_desc, descriptor.ServiceDescriptor):
299 raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
301 self._CheckConflictRegister(service_desc, service_desc.full_name,
302 service_desc.file.name)
303 self._service_descriptors[service_desc.full_name] = service_desc
305 # Add ExtensionDescriptor to descriptor pool is deprecated. Please use Add()
306 # or AddSerializedFile() to add a FileDescriptorProto instead.
307 @_Deprecated
308 def AddExtensionDescriptor(self, extension):
309 self._AddExtensionDescriptor(extension)
311 # Never call this method. It is for internal usage only.
312 def _AddExtensionDescriptor(self, extension):
313 """Adds a FieldDescriptor describing an extension to the pool.
315 Args:
316 extension: A FieldDescriptor.
318 Raises:
319 AssertionError: when another extension with the same number extends the
320 same message.
321 TypeError: when the specified extension is not a
322 descriptor.FieldDescriptor.
323 """
324 if not (isinstance(extension, descriptor.FieldDescriptor) and
325 extension.is_extension):
326 raise TypeError('Expected an extension descriptor.')
328 if extension.extension_scope is None:
329 self._CheckConflictRegister(
330 extension, extension.full_name, extension.file.name)
331 self._toplevel_extensions[extension.full_name] = extension
333 try:
334 existing_desc = self._extensions_by_number[
335 extension.containing_type][extension.number]
336 except KeyError:
337 pass
338 else:
339 if extension is not existing_desc:
340 raise AssertionError(
341 'Extensions "%s" and "%s" both try to extend message type "%s" '
342 'with field number %d.' %
343 (extension.full_name, existing_desc.full_name,
344 extension.containing_type.full_name, extension.number))
346 self._extensions_by_number[extension.containing_type][
347 extension.number] = extension
348 self._extensions_by_name[extension.containing_type][
349 extension.full_name] = extension
351 # Also register MessageSet extensions with the type name.
352 if _IsMessageSetExtension(extension):
353 self._extensions_by_name[extension.containing_type][
354 extension.message_type.full_name] = extension
356 if hasattr(extension.containing_type, '_concrete_class'):
357 python_message._AttachFieldHelpers(
358 extension.containing_type._concrete_class, extension)
360 @_Deprecated
361 def AddFileDescriptor(self, file_desc):
362 self._InternalAddFileDescriptor(file_desc)
364 # Never call this method. It is for internal usage only.
365 def _InternalAddFileDescriptor(self, file_desc):
366 """Adds a FileDescriptor to the pool, non-recursively.
368 If the FileDescriptor contains messages or enums, the caller must explicitly
369 register them.
371 Args:
372 file_desc: A FileDescriptor.
373 """
375 self._AddFileDescriptor(file_desc)
377 def _AddFileDescriptor(self, file_desc):
378 """Adds a FileDescriptor to the pool, non-recursively.
380 If the FileDescriptor contains messages or enums, the caller must explicitly
381 register them.
383 Args:
384 file_desc: A FileDescriptor.
385 """
387 if not isinstance(file_desc, descriptor.FileDescriptor):
388 raise TypeError('Expected instance of descriptor.FileDescriptor.')
389 self._file_descriptors[file_desc.name] = file_desc
391 def FindFileByName(self, file_name):
392 """Gets a FileDescriptor by file name.
394 Args:
395 file_name (str): The path to the file to get a descriptor for.
397 Returns:
398 FileDescriptor: The descriptor for the named file.
400 Raises:
401 KeyError: if the file cannot be found in the pool.
402 """
404 try:
405 return self._file_descriptors[file_name]
406 except KeyError:
407 pass
409 try:
410 file_proto = self._internal_db.FindFileByName(file_name)
411 except KeyError as error:
412 if self._descriptor_db:
413 file_proto = self._descriptor_db.FindFileByName(file_name)
414 else:
415 raise error
416 if not file_proto:
417 raise KeyError('Cannot find a file named %s' % file_name)
418 return self._ConvertFileProtoToFileDescriptor(file_proto)
420 def FindFileContainingSymbol(self, symbol):
421 """Gets the FileDescriptor for the file containing the specified symbol.
423 Args:
424 symbol (str): The name of the symbol to search for.
426 Returns:
427 FileDescriptor: Descriptor for the file that contains the specified
428 symbol.
430 Raises:
431 KeyError: if the file cannot be found in the pool.
432 """
434 symbol = _NormalizeFullyQualifiedName(symbol)
435 try:
436 return self._InternalFindFileContainingSymbol(symbol)
437 except KeyError:
438 pass
440 try:
441 # Try fallback database. Build and find again if possible.
442 self._FindFileContainingSymbolInDb(symbol)
443 return self._InternalFindFileContainingSymbol(symbol)
444 except KeyError:
445 raise KeyError('Cannot find a file containing %s' % symbol)
447 def _InternalFindFileContainingSymbol(self, symbol):
448 """Gets the already built FileDescriptor containing the specified symbol.
450 Args:
451 symbol (str): The name of the symbol to search for.
453 Returns:
454 FileDescriptor: Descriptor for the file that contains the specified
455 symbol.
457 Raises:
458 KeyError: if the file cannot be found in the pool.
459 """
460 try:
461 return self._descriptors[symbol].file
462 except KeyError:
463 pass
465 try:
466 return self._enum_descriptors[symbol].file
467 except KeyError:
468 pass
470 try:
471 return self._service_descriptors[symbol].file
472 except KeyError:
473 pass
475 try:
476 return self._top_enum_values[symbol].type.file
477 except KeyError:
478 pass
480 try:
481 return self._toplevel_extensions[symbol].file
482 except KeyError:
483 pass
485 # Try fields, enum values and nested extensions inside a message.
486 top_name, _, sub_name = symbol.rpartition('.')
487 try:
488 message = self.FindMessageTypeByName(top_name)
489 assert (sub_name in message.extensions_by_name or
490 sub_name in message.fields_by_name or
491 sub_name in message.enum_values_by_name)
492 return message.file
493 except (KeyError, AssertionError):
494 raise KeyError('Cannot find a file containing %s' % symbol)
496 def FindMessageTypeByName(self, full_name):
497 """Loads the named descriptor from the pool.
499 Args:
500 full_name (str): The full name of the descriptor to load.
502 Returns:
503 Descriptor: The descriptor for the named type.
505 Raises:
506 KeyError: if the message cannot be found in the pool.
507 """
509 full_name = _NormalizeFullyQualifiedName(full_name)
510 if full_name not in self._descriptors:
511 self._FindFileContainingSymbolInDb(full_name)
512 return self._descriptors[full_name]
514 def FindEnumTypeByName(self, full_name):
515 """Loads the named enum descriptor from the pool.
517 Args:
518 full_name (str): The full name of the enum descriptor to load.
520 Returns:
521 EnumDescriptor: The enum descriptor for the named type.
523 Raises:
524 KeyError: if the enum cannot be found in the pool.
525 """
527 full_name = _NormalizeFullyQualifiedName(full_name)
528 if full_name not in self._enum_descriptors:
529 self._FindFileContainingSymbolInDb(full_name)
530 return self._enum_descriptors[full_name]
532 def FindFieldByName(self, full_name):
533 """Loads the named field descriptor from the pool.
535 Args:
536 full_name (str): The full name of the field descriptor to load.
538 Returns:
539 FieldDescriptor: The field descriptor for the named field.
541 Raises:
542 KeyError: if the field cannot be found in the pool.
543 """
544 full_name = _NormalizeFullyQualifiedName(full_name)
545 message_name, _, field_name = full_name.rpartition('.')
546 message_descriptor = self.FindMessageTypeByName(message_name)
547 return message_descriptor.fields_by_name[field_name]
549 def FindOneofByName(self, full_name):
550 """Loads the named oneof descriptor from the pool.
552 Args:
553 full_name (str): The full name of the oneof descriptor to load.
555 Returns:
556 OneofDescriptor: The oneof descriptor for the named oneof.
558 Raises:
559 KeyError: if the oneof cannot be found in the pool.
560 """
561 full_name = _NormalizeFullyQualifiedName(full_name)
562 message_name, _, oneof_name = full_name.rpartition('.')
563 message_descriptor = self.FindMessageTypeByName(message_name)
564 return message_descriptor.oneofs_by_name[oneof_name]
566 def FindExtensionByName(self, full_name):
567 """Loads the named extension descriptor from the pool.
569 Args:
570 full_name (str): The full name of the extension descriptor to load.
572 Returns:
573 FieldDescriptor: The field descriptor for the named extension.
575 Raises:
576 KeyError: if the extension cannot be found in the pool.
577 """
578 full_name = _NormalizeFullyQualifiedName(full_name)
579 try:
580 # The proto compiler does not give any link between the FileDescriptor
581 # and top-level extensions unless the FileDescriptorProto is added to
582 # the DescriptorDatabase, but this can impact memory usage.
583 # So we registered these extensions by name explicitly.
584 return self._toplevel_extensions[full_name]
585 except KeyError:
586 pass
587 message_name, _, extension_name = full_name.rpartition('.')
588 try:
589 # Most extensions are nested inside a message.
590 scope = self.FindMessageTypeByName(message_name)
591 except KeyError:
592 # Some extensions are defined at file scope.
593 scope = self._FindFileContainingSymbolInDb(full_name)
594 return scope.extensions_by_name[extension_name]
596 def FindExtensionByNumber(self, message_descriptor, number):
597 """Gets the extension of the specified message with the specified number.
599 Extensions have to be registered to this pool by calling :func:`Add` or
600 :func:`AddExtensionDescriptor`.
602 Args:
603 message_descriptor (Descriptor): descriptor of the extended message.
604 number (int): Number of the extension field.
606 Returns:
607 FieldDescriptor: The descriptor for the extension.
609 Raises:
610 KeyError: when no extension with the given number is known for the
611 specified message.
612 """
613 try:
614 return self._extensions_by_number[message_descriptor][number]
615 except KeyError:
616 self._TryLoadExtensionFromDB(message_descriptor, number)
617 return self._extensions_by_number[message_descriptor][number]
619 def FindAllExtensions(self, message_descriptor):
620 """Gets all the known extensions of a given message.
622 Extensions have to be registered to this pool by build related
623 :func:`Add` or :func:`AddExtensionDescriptor`.
625 Args:
626 message_descriptor (Descriptor): Descriptor of the extended message.
628 Returns:
629 list[FieldDescriptor]: Field descriptors describing the extensions.
630 """
631 # Fallback to descriptor db if FindAllExtensionNumbers is provided.
632 if self._descriptor_db and hasattr(
633 self._descriptor_db, 'FindAllExtensionNumbers'):
634 full_name = message_descriptor.full_name
635 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
636 for number in all_numbers:
637 if number in self._extensions_by_number[message_descriptor]:
638 continue
639 self._TryLoadExtensionFromDB(message_descriptor, number)
641 return list(self._extensions_by_number[message_descriptor].values())
643 def _TryLoadExtensionFromDB(self, message_descriptor, number):
644 """Try to Load extensions from descriptor db.
646 Args:
647 message_descriptor: descriptor of the extended message.
648 number: the extension number that needs to be loaded.
649 """
650 if not self._descriptor_db:
651 return
652 # Only supported when FindFileContainingExtension is provided.
653 if not hasattr(
654 self._descriptor_db, 'FindFileContainingExtension'):
655 return
657 full_name = message_descriptor.full_name
658 file_proto = self._descriptor_db.FindFileContainingExtension(
659 full_name, number)
661 if file_proto is None:
662 return
664 try:
665 self._ConvertFileProtoToFileDescriptor(file_proto)
666 except:
667 warn_msg = ('Unable to load proto file %s for extension number %d.' %
668 (file_proto.name, number))
669 warnings.warn(warn_msg, RuntimeWarning)
671 def FindServiceByName(self, full_name):
672 """Loads the named service descriptor from the pool.
674 Args:
675 full_name (str): The full name of the service descriptor to load.
677 Returns:
678 ServiceDescriptor: The service descriptor for the named service.
680 Raises:
681 KeyError: if the service cannot be found in the pool.
682 """
683 full_name = _NormalizeFullyQualifiedName(full_name)
684 if full_name not in self._service_descriptors:
685 self._FindFileContainingSymbolInDb(full_name)
686 return self._service_descriptors[full_name]
688 def FindMethodByName(self, full_name):
689 """Loads the named service method descriptor from the pool.
691 Args:
692 full_name (str): The full name of the method descriptor to load.
694 Returns:
695 MethodDescriptor: The method descriptor for the service method.
697 Raises:
698 KeyError: if the method cannot be found in the pool.
699 """
700 full_name = _NormalizeFullyQualifiedName(full_name)
701 service_name, _, method_name = full_name.rpartition('.')
702 service_descriptor = self.FindServiceByName(service_name)
703 return service_descriptor.methods_by_name[method_name]
705 def _FindFileContainingSymbolInDb(self, symbol):
706 """Finds the file in descriptor DB containing the specified symbol.
708 Args:
709 symbol (str): The name of the symbol to search for.
711 Returns:
712 FileDescriptor: The file that contains the specified symbol.
714 Raises:
715 KeyError: if the file cannot be found in the descriptor database.
716 """
717 try:
718 file_proto = self._internal_db.FindFileContainingSymbol(symbol)
719 except KeyError as error:
720 if self._descriptor_db:
721 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
722 else:
723 raise error
724 if not file_proto:
725 raise KeyError('Cannot find a file containing %s' % symbol)
726 return self._ConvertFileProtoToFileDescriptor(file_proto)
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool.

    The file is built only the first time its name is seen. On every call,
    cached ones included, the following extensions are passed to
    _AddExtensionDescriptor:

    * the file's top-level extensions;
    * extensions declared inside its messages, recursing into nested
      messages.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.
    """
    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined outside this view; presumably it
      # yields the (transitive) dependency FileDescriptors — confirm.
      built_deps = list(self._GetDeps(file_proto.dependency))
      # Direct dependencies are resolved (and built if needed) by file name.
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency holds indices into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps,
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key)
      # Symbol table used while building this file; entries added here are
      # keyed by dot-prefixed full names.
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # Create the top-level messages (which also registers them in scope).
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      # Top-level enums; top_level=True also indexes their values.
      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope, True))

      # File-scope extensions: the extended message and the field type are
      # resolved through the scope built so far.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)

      # Runs only after every message, enum and extension of this file is in
      # scope.
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # Re-populates message_types_by_name from scope.
      # NOTE(review): presumably these are the same objects stored in the
      # first pass — confirm against _GetTypeFromScope.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool. This runs for cached files too.
    def AddExtensionForNested(message_type):
      for nested in message_type.nested_types:
        AddExtensionForNested(nested)
      for extension in message_type.extensions:
        self._AddExtensionDescriptor(extension)

    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
      self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
      AddExtensionForNested(message_type)

    return file_desc
826 def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
827 scope=None, syntax=None):
828 """Adds the proto to the pool in the specified package.
830 Args:
831 desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
832 package: The package the proto should be located in.
833 file_desc: The file containing this message.
834 scope: Dict mapping short and full symbols to message and enum types.
835 syntax: string indicating syntax of the file ("proto2" or "proto3")
837 Returns:
838 The added descriptor.
839 """
841 if package:
842 desc_name = '.'.join((package, desc_proto.name))
843 else:
844 desc_name = desc_proto.name
846 if file_desc is None:
847 file_name = None
848 else:
849 file_name = file_desc.name
851 if scope is None:
852 scope = {}
854 nested = [
855 self._ConvertMessageDescriptor(
856 nested, desc_name, file_desc, scope, syntax)
857 for nested in desc_proto.nested_type]
858 enums = [
859 self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
860 scope, False)
861 for enum in desc_proto.enum_type]
862 fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
863 for index, field in enumerate(desc_proto.field)]
864 extensions = [
865 self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
866 is_extension=True)
867 for index, extension in enumerate(desc_proto.extension)]
868 oneofs = [
869 # pylint: disable=g-complex-comprehension
870 descriptor.OneofDescriptor(
871 desc.name,
872 '.'.join((desc_name, desc.name)),
873 index,
874 None,
875 [],
876 _OptionsOrNone(desc),
877 # pylint: disable=protected-access
878 create_key=descriptor._internal_create_key)
879 for index, desc in enumerate(desc_proto.oneof_decl)
880 ]
881 extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
882 if extension_ranges:
883 is_extendable = True
884 else:
885 is_extendable = False
886 desc = descriptor.Descriptor(
887 name=desc_proto.name,
888 full_name=desc_name,
889 filename=file_name,
890 containing_type=None,
891 fields=fields,
892 oneofs=oneofs,
893 nested_types=nested,
894 enum_types=enums,
895 extensions=extensions,
896 options=_OptionsOrNone(desc_proto),
897 is_extendable=is_extendable,
898 extension_ranges=extension_ranges,
899 file=file_desc,
900 serialized_start=None,
901 serialized_end=None,
902 syntax=syntax,
903 # pylint: disable=protected-access
904 create_key=descriptor._internal_create_key)
905 for nested in desc.nested_types:
906 nested.containing_type = desc
907 for enum in desc.enum_types:
908 enum.containing_type = desc
909 for field_index, field_desc in enumerate(desc_proto.field):
910 if field_desc.HasField('oneof_index'):
911 oneof_index = field_desc.oneof_index
912 oneofs[oneof_index].fields.append(fields[field_index])
913 fields[field_index].containing_oneof = oneofs[oneof_index]
915 scope[_PrefixWithDot(desc_name)] = desc
916 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
917 self._descriptors[desc_name] = desc
918 return desc
def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
                           containing_type=None, scope=None, top_level=False):
  """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.

  The new descriptor is stored in ``scope`` under its dot-prefixed full name,
  passed to ``_CheckConflictRegister`` and recorded in
  ``self._enum_descriptors``. For a top-level enum, every value is also
  registered in ``self._top_enum_values`` under the package-qualified value
  name (not under the enum's name).

  Args:
    enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
    package: Optional package (or enclosing scope name) used to qualify the
      enum's full name and, for top-level enums, its values' names.
    file_desc: The file containing the enum descriptor, or None.
    containing_type: The type containing this enum.
    scope: Scope containing available types; the new descriptor is added to
      it.
    top_level: If True, the enum is a top level symbol. If False, the enum
      is defined inside a message.

  Returns:
    The added descriptor
  """
  if package:
    enum_name = '.'.join((package, enum_proto.name))
  else:
    enum_name = enum_proto.name

  file_name = None if file_desc is None else file_desc.name

  values = [self._MakeEnumValueDescriptor(value, index)
            for index, value in enumerate(enum_proto.value)]
  desc = descriptor.EnumDescriptor(name=enum_proto.name,
                                   full_name=enum_name,
                                   filename=file_name,
                                   file=file_desc,
                                   values=values,
                                   containing_type=containing_type,
                                   options=_OptionsOrNone(enum_proto),
                                   # pylint: disable=protected-access
                                   create_key=descriptor._internal_create_key)
  scope['.%s' % enum_name] = desc
  # desc.file is the file_desc passed above, so desc.file.name would raise
  # AttributeError when file_desc is None; use the None-safe file_name, as the
  # enum-value registration below already does.
  self._CheckConflictRegister(desc, desc.full_name, file_name)
  self._enum_descriptors[enum_name] = desc

  # Add top level enum values.
  if top_level:
    for value in values:
      # `package or ''` treats the default package=None like an empty package
      # instead of raising TypeError inside str.join.
      full_name = _NormalizeFullyQualifiedName(
          '.'.join((package or '', value.name)))
      self._CheckConflictRegister(value, full_name, file_name)
      self._top_enum_values[full_name] = value

  return desc
def _MakeFieldDescriptor(self, field_proto, message_name, index,
                         file_desc, is_extension=False):
  """Creates a field descriptor from a FieldDescriptorProto.

  Only attributes read directly from ``field_proto`` are filled in; no type
  lookup happens here. ``cpp_type``, ``message_type``, ``enum_type``,
  ``containing_type`` and ``extension_scope`` are passed as None,
  ``has_default_value`` as False and ``default_value`` as None. The
  type-dependent attributes (cpp_type, message_type, enum_type,
  has_default_value, default_value) are filled in later by ``_SetFieldType``.

  Args:
    field_proto: The proto describing the field.
    message_name: The full name of the containing message. If falsy, the
      field's bare name is used as its full name.
    index: Index of the field within its container.
    file_desc: The file containing the field descriptor.
    is_extension: Indication that this field is for an extension.

  Returns:
    An initialized FieldDescriptor object whose type information has not been
    resolved yet.
  """
  if message_name:
    full_name = '.'.join((message_name, field_proto.name))
  else:
    full_name = field_proto.name

  # An empty json_name is treated as "not set".
  json_name = field_proto.json_name or None

  return descriptor.FieldDescriptor(
      name=field_proto.name,
      full_name=full_name,
      index=index,
      number=field_proto.number,
      type=field_proto.type,
      cpp_type=None,
      message_type=None,
      enum_type=None,
      containing_type=None,
      label=field_proto.label,
      has_default_value=False,
      default_value=None,
      is_extension=is_extension,
      extension_scope=None,
      options=_OptionsOrNone(field_proto),
      json_name=json_name,
      file=file_desc,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _SetAllFieldTypes(self, package, desc_proto, scope):
  """Resolves field types for a message and, recursively, its nested messages.

  For every field and extension of ``desc_proto`` the matching descriptor is
  updated via ``_SetFieldType``; each extension additionally gets its
  ``containing_type`` resolved from its ``extendee``.

  Args:
    package: The current package of desc_proto.
    desc_proto: The message descriptor proto whose descriptor is updated.
    scope: Enclosing scope of available types.
  """
  dotted_package = _PrefixWithDot(package)
  message_desc = self._GetTypeFromScope(dotted_package, desc_proto.name, scope)

  # Names declared inside this message are resolved relative to its own
  # dot-prefixed full name.
  inner_scope = (_PrefixWithDot(desc_proto.name) if dotted_package == '.'
                 else '.'.join([dotted_package, desc_proto.name]))

  for f_proto, f_desc in zip(desc_proto.field, message_desc.fields):
    self._SetFieldType(f_proto, f_desc, inner_scope, scope)

  for ext_proto, ext_desc in zip(desc_proto.extension,
                                 message_desc.extensions):
    ext_desc.containing_type = self._GetTypeFromScope(
        inner_scope, ext_proto.extendee, scope)
    self._SetFieldType(ext_proto, ext_desc, inner_scope, scope)

  for child_proto in desc_proto.nested_type:
    self._SetAllFieldTypes(inner_scope, child_proto, scope)
def _SetFieldType(self, field_proto, field_desc, package, scope):
  """Fills in the type-dependent attributes of a field descriptor.

  Resolves ``field_proto.type_name`` (when present) through ``scope`` and sets
  ``field_desc.cpp_type``, ``message_type`` or ``enum_type``,
  ``has_default_value``, ``default_value`` and finally ``type``.

  Note: when ``field_proto`` has no explicit ``type``, one is inferred from the
  resolved descriptor (TYPE_MESSAGE for a ``descriptor.Descriptor``, otherwise
  TYPE_ENUM) and written back into ``field_proto`` itself.

  Args:
    field_proto: Data about the field in proto format (may be modified, see
      above).
    field_desc: The descriptor to modify.
    package: The package the field's container is in.
    scope: Enclosing scope of available types.
  """
  field_cls = descriptor.FieldDescriptor

  resolved = (self._GetTypeFromScope(package, field_proto.type_name, scope)
              if field_proto.type_name else None)

  if not field_proto.HasField('type'):
    if isinstance(resolved, descriptor.Descriptor):
      field_proto.type = field_cls.TYPE_MESSAGE
    else:
      field_proto.type = field_cls.TYPE_ENUM

  proto_type = field_proto.type
  field_desc.cpp_type = field_cls.ProtoTypeToCppProtoType(proto_type)

  if proto_type in (field_cls.TYPE_MESSAGE, field_cls.TYPE_GROUP):
    field_desc.message_type = resolved
  if proto_type == field_cls.TYPE_ENUM:
    field_desc.enum_type = resolved

  if field_proto.label == field_cls.LABEL_REPEATED:
    # Repeated fields never carry a default; they start as an empty list.
    field_desc.has_default_value = False
    field_desc.default_value = []
  elif field_proto.HasField('default_value'):
    field_desc.has_default_value = True
    raw = field_proto.default_value
    if proto_type in (field_cls.TYPE_DOUBLE, field_cls.TYPE_FLOAT):
      parsed = float(raw)
    elif proto_type == field_cls.TYPE_STRING:
      parsed = raw
    elif proto_type == field_cls.TYPE_BOOL:
      parsed = raw.lower() == 'true'
    elif proto_type == field_cls.TYPE_ENUM:
      # Enum defaults are given by value name; store the value's number.
      parsed = field_desc.enum_type.values_by_name[raw].number
    elif proto_type == field_cls.TYPE_BYTES:
      # Bytes defaults are given C-escaped.
      parsed = text_encoding.CUnescape(raw)
    elif proto_type == field_cls.TYPE_MESSAGE:
      parsed = None
    else:
      # Every remaining scalar type is integral.
      parsed = int(raw)
    field_desc.default_value = parsed
  else:
    field_desc.has_default_value = False
    if proto_type == field_cls.TYPE_ENUM:
      # The implicit enum default is the first declared value.
      field_desc.default_value = field_desc.enum_type.values[0].number
    else:
      implicit_defaults = {
          field_cls.TYPE_DOUBLE: 0.0,
          field_cls.TYPE_FLOAT: 0.0,
          field_cls.TYPE_STRING: u'',
          field_cls.TYPE_BOOL: False,
          field_cls.TYPE_BYTES: b'',
          field_cls.TYPE_MESSAGE: None,
          field_cls.TYPE_GROUP: None,
      }
      # Any type not listed above is integral and defaults to 0.
      field_desc.default_value = implicit_defaults.get(proto_type, 0)

  field_desc.type = proto_type
def _MakeEnumValueDescriptor(self, value_proto, index):
  """Builds an EnumValueDescriptor from an EnumValueDescriptorProto.

  The descriptor's ``type`` is passed as None here.

  Args:
    value_proto: The proto describing one enum value.
    index: Position of the value within its enum.

  Returns:
    A new EnumValueDescriptor object.
  """
  value_options = _OptionsOrNone(value_proto)
  return descriptor.EnumValueDescriptor(
      name=value_proto.name,
      index=index,
      number=value_proto.number,
      options=value_options,
      type=None,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _MakeServiceDescriptor(self, service_proto, service_index, scope,
                           package, file_desc):
  """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.

  Builds one MethodDescriptor per method (resolving their input and output
  types through ``scope``), creates the ServiceDescriptor, passes it to
  ``_CheckConflictRegister`` and records it in ``self._service_descriptors``.

  Args:
    service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
    service_index: The index of the service in the File.
    scope: Dict mapping short and full symbols to message and enum types.
    package: Optional package name. It qualifies the service's full name and
      is used when resolving the methods' input and output types.
    file_desc: The file containing the service descriptor. Must not be None:
      its name is read for conflict registration.

  Returns:
    The added descriptor.
  """
  if package:
    service_name = '.'.join((package, service_proto.name))
  else:
    service_name = service_proto.name

  methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
                                        scope, index)
             for index, method_proto in enumerate(service_proto.method)]
  desc = descriptor.ServiceDescriptor(
      name=service_proto.name,
      full_name=service_name,
      index=service_index,
      methods=methods,
      options=_OptionsOrNone(service_proto),
      file=file_desc,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
  self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
  self._service_descriptors[service_name] = desc
  return desc
def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
                          index):
  """Builds a MethodDescriptor for one method of a service.

  The request and response types are resolved immediately through
  ``_GetTypeFromScope``. ``containing_service`` is passed as None.

  Args:
    method_proto: The MethodDescriptorProto describing the method.
    service_name: The (package-qualified) name of the containing service.
    package: Optional package name used when resolving the input and output
      type names.
    scope: Scope containing available types.
    index: Position of the method within the service.

  Returns:
    A new MethodDescriptor object.
  """
  method_name = method_proto.name
  qualified_name = '.'.join((service_name, method_name))
  request_type = self._GetTypeFromScope(
      package, method_proto.input_type, scope)
  response_type = self._GetTypeFromScope(
      package, method_proto.output_type, scope)
  return descriptor.MethodDescriptor(
      name=method_name,
      full_name=qualified_name,
      index=index,
      containing_service=None,
      input_type=request_type,
      output_type=response_type,
      client_streaming=method_proto.client_streaming,
      server_streaming=method_proto.server_streaming,
      options=_OptionsOrNone(method_proto),
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _ExtractSymbols(self, descriptors):
  """Pulls out all the symbols from message descriptors.

  Walks each message depth-first. It yields the message itself, then
  (recursively) everything under its nested messages, then the enums declared
  directly in it.

  Args:
    descriptors: Iterable of message descriptors (objects exposing
      ``full_name``, ``nested_types`` and ``enum_types``).

  Yields:
    A two element tuple of the dot-prefixed full type name and the descriptor
    object.
  """
  for desc in descriptors:
    yield (_PrefixWithDot(desc.full_name), desc)
    yield from self._ExtractSymbols(desc.nested_types)
    for enum in desc.enum_types:
      yield (_PrefixWithDot(enum.full_name), enum)
def _GetDeps(self, dependencies, visited=None):
  """Recursively finds dependencies for file protos.

  Each named file is looked up with ``FindFileByName`` and yielded. The files
  it publicly re-exports (``public_dependencies``) are then followed
  transitively. Non-public dependencies of a dependency are not followed. A
  file name is yielded at most once per traversal.

  Args:
    dependencies: The names of the files being depended on.
    visited: Optional set of file names already found. It is updated in place
      with every name yielded, so a caller may share one set across calls.

  Yields:
    Each direct dependency and, transitively, its public dependencies.
  """
  # Compare against None explicitly: `visited or set()` would replace a
  # caller-supplied *empty* set with a fresh one, leaving the caller's set
  # untouched.
  if visited is None:
    visited = set()
  for dependency in dependencies:
    if dependency in visited:
      continue
    visited.add(dependency)
    dep_desc = self.FindFileByName(dependency)
    yield dep_desc
    public_files = [d.name for d in dep_desc.public_dependencies]
    yield from self._GetDeps(public_files, visited)
def _GetTypeFromScope(self, package, type_name, scope):
  """Looks up a type name in scope, trying enclosing packages as needed.

  An exact match for ``type_name`` wins. Otherwise ``type_name`` is appended
  to the dot-prefixed ``package`` and to each of its enclosing prefixes, from
  the most specific to the least (``.a.b.T``, ``.a.T``, ``.T``). The first
  candidate found in ``scope`` is used.

  Args:
    package: The package the proto should be located in.
    type_name: The name of the type to be found in the scope.
    scope: Dict mapping short and full symbols to message and enum types.

  Returns:
    The descriptor for the requested type.

  Raises:
    KeyError: If no candidate name is present in ``scope``.
  """
  if type_name in scope:
    return scope[type_name]
  dotted = package if package.startswith('.') else '.' + package
  parts = dotted.split('.')
  for depth in range(len(parts), 0, -1):
    candidate = '.'.join(parts[:depth] + [type_name])
    if candidate in scope:
      return scope[candidate]
  return scope[type_name]
def _PrefixWithDot(name):
  """Returns ``name`` with a leading '.', unless it already starts with one."""
  if name.startswith('.'):
    return name
  return '.' + name
# TODO(amauryfa): This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
# The pool returned by Default(): the C extension's default_pool when C
# descriptors are in use, otherwise a new pure-Python DescriptorPool.
_DEFAULT = (
    descriptor._message.default_pool  # pylint: disable=protected-access
    if _USE_C_DESCRIPTORS
    else DescriptorPool())
def Default():
  """Returns the module's default descriptor pool.

  The same module-level ``_DEFAULT`` object is returned on every call.
  """
  default_pool = _DEFAULT
  return default_pool