Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/descriptor_pool.py: 15%
465 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 07:30 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 07:30 +0000
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides DescriptorPool to use as a container for proto2 descriptors.

The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
a collection of protocol buffer descriptors for use when dynamically creating
message types at runtime.

For most applications protocol buffers should be used via modules generated by
the protocol buffer compiler tool. This should only be used when the type of
protocol buffers used in an application or library cannot be predetermined.

Below is a straightforward example of how to use this class::

  pool = DescriptorPool()
  file_descriptor_protos = [ ... ]
  for file_descriptor_proto in file_descriptor_protos:
    pool.Add(file_descriptor_proto)
  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')

Add() only stores the FileDescriptorProto in the pool's internal database; the
corresponding descriptors are built the first time a Find* lookup needs them.

The message descriptor can be used in conjunction with the message_factory
module in order to create a protocol buffer class that can be encoded and
decoded.

If you want to get a Python class for the specified proto, use the
helper functions inside google.protobuf.message_factory
directly instead of this class.
"""

__author__ = 'matthewtoia@google.com (Matt Toia)'
60import collections
61import warnings
63from google.protobuf import descriptor
64from google.protobuf import descriptor_database
65from google.protobuf import text_encoding
# Mirrors descriptor._USE_C_DESCRIPTORS. When true, DescriptorPool.__new__
# returns a descriptor._message.DescriptorPool instead of an instance of the
# pure-Python pool class defined in this module.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
def _Deprecated(func):
    """Mark functions as deprecated.

    Wraps ``func`` so that each call first emits a DeprecationWarning and then
    delegates to ``func`` with the same arguments, returning its result.

    Args:
      func: The callable to wrap.

    Returns:
      The wrapping callable. ``functools.wraps`` copies ``func``'s metadata
      (``__name__``, ``__qualname__``, ``__doc__``, ``__module__``,
      ``__dict__``) onto it and sets ``__wrapped__`` to ``func``.
    """
    # pylint: disable=g-import-not-at-top
    import functools

    @functools.wraps(func)
    def NewFunc(*args, **kwargs):
        # stacklevel=2 attributes the warning to the code that called the
        # deprecated function, not to this wrapper.
        warnings.warn(
            'Call to deprecated function %s(). Note: Do add unlinked descriptors '
            'to descriptor_pool is wrong. Use Add() or AddSerializedFile() '
            'instead.' % func.__name__,
            category=DeprecationWarning,
            stacklevel=2)
        return func(*args, **kwargs)

    return NewFunc
def _NormalizeFullyQualifiedName(name):
    """Strip every leading '.' from a fully-qualified symbol name.

    Symbols in the root namespace can arrive with a leading period (see
    b/13860351 in descriptor_database.py). All leading periods are removed;
    periods elsewhere in the name are kept.

    Args:
      name (str): The fully-qualified symbol name.

    Returns:
      str: The name without its leading periods.
    """
    start = 0
    while start < len(name) and name[start] == '.':
        start += 1
    return name[start:]
def _OptionsOrNone(descriptor_proto):
    """Return ``descriptor_proto.options`` if that field is set, else None.

    Args:
      descriptor_proto: A proto message exposing ``HasField`` and an
        ``options`` field.

    Returns:
      The ``options`` value when ``HasField('options')`` is truthy, else None.
    """
    options_present = descriptor_proto.HasField('options')
    return descriptor_proto.options if options_present else None
def _IsMessageSetExtension(field):
    """Tell whether ``field`` is a MessageSet-style extension.

    That means an optional, message-typed extension whose extended type has
    options with ``message_set_wire_format`` set. Such extensions are also
    registered under their message type's full name by the pool.

    Args:
      field: A FieldDescriptor.

    Returns:
      The value of the short-circuiting ``and`` chain below (truthy when all
      conditions hold).
    """
    field_descriptor_cls = descriptor.FieldDescriptor
    return (
        field.is_extension
        and field.containing_type.has_options
        and field.containing_type.GetOptions().message_set_wire_format
        and field.type == field_descriptor_cls.TYPE_MESSAGE
        and field.label == field_descriptor_cls.LABEL_OPTIONAL
    )
118class DescriptorPool(object):
119 """A collection of protobufs dynamically constructed by descriptor protos."""
if _USE_C_DESCRIPTORS:

    def __new__(cls, descriptor_db=None):
        """Build a ``descriptor._message.DescriptorPool`` from descriptor_db.

        Only defined when _USE_C_DESCRIPTORS is true; the returned object comes
        from ``descriptor._message`` rather than from this class.
        """
        # pylint: disable=protected-access
        c_pool_class = descriptor._message.DescriptorPool
        return c_pool_class(descriptor_db)
def __init__(
    self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
):
    """Creates an empty pool, optionally backed by a fallback database.

    Protos passed to Add() are kept in an internal DescriptorDatabase. When a
    file lookup misses that database and ``descriptor_db`` is set, the lookup
    is retried against ``descriptor_db``. Descriptors built from either source
    are cached in this pool.

    Args:
      descriptor_db: A secondary source of file descriptors, or None.
      use_deprecated_legacy_json_field_conflicts: Ignored; accepted only for
        compatibility with C++.
    """
    # Sources of FileDescriptorProtos.
    self._internal_db = descriptor_database.DescriptorDatabase()
    self._descriptor_db = descriptor_db

    # Built descriptors keyed by fully-qualified name (files: by file name).
    self._descriptors = {}
    self._enum_descriptors = {}
    self._service_descriptors = {}
    self._file_descriptors = {}
    self._toplevel_extensions = {}
    self._top_enum_values = {}

    # Two-level extension indexes. Outer key: descriptor of the extended
    # message. Inner key: extension full name (plus the message type name for
    # MessageSet extensions) for the first map, field number for the second.
    self._extensions_by_name, self._extensions_by_number = (
        collections.defaultdict(dict), collections.defaultdict(dict))
def _CheckConflictRegister(self, desc, desc_name, file_name):
    """Raise TypeError if desc_name is already taken by an incompatible entry.

    Every symbol table of the pool is consulted. An existing entry is accepted
    only if it was defined in the same file and ``desc`` is an instance of the
    descriptor class associated with that table.

    Args:
      desc: Descriptor of a message, enum, service, extension or enum value.
      desc_name (str): the full name of desc.
      file_name (str): The file name of descriptor.

    Raises:
      TypeError: if desc_name is already registered from another file, or as
        a different kind of descriptor.
    """
    tables = (
        (self._descriptors, descriptor.Descriptor),
        (self._enum_descriptors, descriptor.EnumDescriptor),
        (self._service_descriptors, descriptor.ServiceDescriptor),
        (self._toplevel_extensions, descriptor.FieldDescriptor),
        (self._top_enum_values, descriptor.EnumValueDescriptor),
    )
    for table, expected_class in tables:
        if desc_name not in table:
            continue
        existing = table[desc_name]
        # Enum values reach their file through the owning enum type.
        if isinstance(existing, descriptor.EnumValueDescriptor):
            existing_file = existing.type.file.name
        else:
            existing_file = existing.file.name

        if isinstance(desc, expected_class) and existing_file == file_name:
            continue  # Same kind, same file: re-registration is fine.

        message = ('Conflict register for file "' + file_name +
                   '": ' + desc_name +
                   ' is already defined in file "' +
                   existing_file + '". Please fix the conflict by adding '
                   'package name on the proto file, or use different '
                   'name for the duplication.')
        if isinstance(desc, descriptor.EnumValueDescriptor):
            message += ('\nNote: enum values appear as '
                        'siblings of the enum type instead of '
                        'children of it.')
        raise TypeError(message)
def Add(self, file_desc_proto):
    """Stores a FileDescriptorProto in this pool's internal database.

    No descriptors are built here. They are built later, when a Find* lookup
    falls back to the internal database.

    Args:
      file_desc_proto (FileDescriptorProto): The file descriptor to add.
    """
    internal_db = self._internal_db
    internal_db.Add(file_desc_proto)
def AddSerializedFile(self, serialized_file_desc_proto):
    """Parses a serialized FileDescriptorProto and builds it into this pool.

    Args:
      serialized_file_desc_proto (bytes): A bytes string, serialization of the
        :class:`FileDescriptorProto` to add.

    Returns:
      FileDescriptor: Descriptor for the added file, whose ``serialized_pb``
      is set to the exact bytes that were passed in.
    """
    # pylint: disable=g-import-not-at-top
    from google.protobuf import descriptor_pb2
    parse_file_proto = descriptor_pb2.FileDescriptorProto.FromString
    file_desc = self._ConvertFileProtoToFileDescriptor(
        parse_file_proto(serialized_file_desc_proto))
    # Keep the caller's bytes instead of the re-serialization stored during
    # conversion.
    file_desc.serialized_pb = serialized_file_desc_proto
    return file_desc
# Deprecated entry point: adding individual descriptors bypasses the file-based
# registration done by Add() / AddSerializedFile(); use those instead.
@_Deprecated
def AddDescriptor(self, desc):
    """Deprecated: registers one message Descriptor via _AddDescriptor."""
    self._AddDescriptor(desc)
# Never call this method. It is for internal usage only.
def _AddDescriptor(self, desc):
    """Registers a message Descriptor (not its nested types) and its file.

    Nested messages and enums are not walked; callers must register them
    explicitly.

    Args:
      desc: A Descriptor.

    Raises:
      TypeError: if desc is not a descriptor.Descriptor, or its name conflicts
        with an existing registration.
    """
    if not isinstance(desc, descriptor.Descriptor):
        raise TypeError('Expected instance of descriptor.Descriptor.')

    name = desc.full_name
    owner_file = desc.file
    self._CheckConflictRegister(desc, name, owner_file.name)
    self._descriptors[name] = desc
    self._AddFileDescriptor(owner_file)
# Deprecated entry point: use Add() or AddSerializedFile() with a
# FileDescriptorProto instead of adding EnumDescriptors one by one.
@_Deprecated
def AddEnumDescriptor(self, enum_desc):
    """Deprecated: registers one EnumDescriptor via _AddEnumDescriptor."""
    self._AddEnumDescriptor(enum_desc)
# Never call this method. It is for internal usage only.
def _AddEnumDescriptor(self, enum_desc):
    """Registers an EnumDescriptor, its top-level values, and its file.

    Args:
      enum_desc: An EnumDescriptor.

    Raises:
      TypeError: if enum_desc is not a descriptor.EnumDescriptor, or a name it
        introduces conflicts with an existing registration.
    """
    if not isinstance(enum_desc, descriptor.EnumDescriptor):
        raise TypeError('Expected instance of descriptor.EnumDescriptor.')

    owner_file = enum_desc.file
    self._CheckConflictRegister(enum_desc, enum_desc.full_name, owner_file.name)
    self._enum_descriptors[enum_desc.full_name] = enum_desc

    # Values of a top-level enum are indexed in the package namespace.
    # containing_type cannot be relied on at this stage, so nesting is inferred
    # by counting dots: a top-level enum has exactly one dot more than its
    # package, or none at all when there is no package.
    package = owner_file.package
    expected_dots = package.count('.') + 1 if package else 0
    if enum_desc.full_name.count('.') == expected_dots:
        for value in enum_desc.values:
            value_full_name = _NormalizeFullyQualifiedName(
                '.'.join((package, value.name)))
            self._CheckConflictRegister(value, value_full_name, owner_file.name)
            self._top_enum_values[value_full_name] = value
    self._AddFileDescriptor(owner_file)
# Deprecated entry point: use Add() or AddSerializedFile() with a
# FileDescriptorProto instead of adding ServiceDescriptors one by one.
@_Deprecated
def AddServiceDescriptor(self, service_desc):
    """Deprecated: registers one ServiceDescriptor via _AddServiceDescriptor."""
    self._AddServiceDescriptor(service_desc)
# Never call this method. It is for internal usage only.
def _AddServiceDescriptor(self, service_desc):
    """Registers a ServiceDescriptor under its full name.

    Unlike _AddDescriptor/_AddEnumDescriptor, this does not register the
    service's FileDescriptor.

    Args:
      service_desc: A ServiceDescriptor.

    Raises:
      TypeError: if service_desc is not a descriptor.ServiceDescriptor, or its
        name conflicts with an existing registration.
    """
    if isinstance(service_desc, descriptor.ServiceDescriptor):
        name = service_desc.full_name
        self._CheckConflictRegister(service_desc, name, service_desc.file.name)
        self._service_descriptors[name] = service_desc
    else:
        raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
# Deprecated entry point: use Add() or AddSerializedFile() with a
# FileDescriptorProto instead of adding extension descriptors one by one.
@_Deprecated
def AddExtensionDescriptor(self, extension):
    """Deprecated: registers one extension via _AddExtensionDescriptor."""
    self._AddExtensionDescriptor(extension)
# Never call this method. It is for internal usage only.
def _AddExtensionDescriptor(self, extension):
    """Indexes an extension FieldDescriptor by number and by name.

    File-scope extensions (no extension_scope) are also registered as
    top-level symbols. MessageSet extensions are additionally indexed under
    their message type's full name.

    Args:
      extension: A FieldDescriptor.

    Raises:
      AssertionError: when another extension with the same number extends the
        same message.
      TypeError: when the specified extension is not a
        descriptor.FieldDescriptor.
    """
    looks_like_extension = (
        isinstance(extension, descriptor.FieldDescriptor) and
        extension.is_extension)
    if not looks_like_extension:
        raise TypeError('Expected an extension descriptor.')

    if extension.extension_scope is None:
        self._CheckConflictRegister(
            extension, extension.full_name, extension.file.name)
        self._toplevel_extensions[extension.full_name] = extension

    extended_type = extension.containing_type
    by_number = self._extensions_by_number[extended_type]
    if extension.number in by_number:
        existing_desc = by_number[extension.number]
        if existing_desc is not extension:
            raise AssertionError(
                'Extensions "%s" and "%s" both try to extend message type "%s" '
                'with field number %d.' %
                (extension.full_name, existing_desc.full_name,
                 extension.containing_type.full_name, extension.number))

    by_number[extension.number] = extension
    by_name = self._extensions_by_name[extended_type]
    by_name[extension.full_name] = extension

    if _IsMessageSetExtension(extension):
        by_name[extension.message_type.full_name] = extension
@_Deprecated
def AddFileDescriptor(self, file_desc):
    """Deprecated: registers one FileDescriptor via _InternalAddFileDescriptor."""
    self._InternalAddFileDescriptor(file_desc)
# Never call this method. It is for internal usage only.
def _InternalAddFileDescriptor(self, file_desc):
    """Forwards to _AddFileDescriptor (non-recursive file registration).

    Messages or enums inside the file are not registered; callers must add
    them explicitly.

    Args:
      file_desc: A FileDescriptor.
    """
    register_file = self._AddFileDescriptor
    register_file(file_desc)
def _AddFileDescriptor(self, file_desc):
    """Records file_desc in the pool's file table, keyed by its name.

    Nothing declared inside the file is registered here; callers must add
    messages, enums, etc. explicitly.

    Args:
      file_desc: A FileDescriptor.

    Raises:
      TypeError: if file_desc is not a descriptor.FileDescriptor.
    """
    if isinstance(file_desc, descriptor.FileDescriptor):
        self._file_descriptors[file_desc.name] = file_desc
        return
    raise TypeError('Expected instance of descriptor.FileDescriptor.')
def FindFileByName(self, file_name):
    """Gets a FileDescriptor by file name, building it on first request.

    Lookup order: already-built files, then the internal database, then (only
    when the internal database raises KeyError and a fallback database is
    configured) the fallback database.

    Args:
      file_name (str): The path to the file to get a descriptor for.

    Returns:
      FileDescriptor: The descriptor for the named file.

    Raises:
      KeyError: if the file cannot be found in the pool.
    """
    if file_name in self._file_descriptors:
        return self._file_descriptors[file_name]

    try:
        file_proto = self._internal_db.FindFileByName(file_name)
    except KeyError:
        fallback_db = self._descriptor_db
        if not fallback_db:
            raise
        file_proto = fallback_db.FindFileByName(file_name)

    if not file_proto:
        raise KeyError('Cannot find a file named %s' % file_name)
    return self._ConvertFileProtoToFileDescriptor(file_proto)
def FindFileContainingSymbol(self, symbol):
    """Gets the FileDescriptor for the file that defines ``symbol``.

    Already-built descriptors are searched first. On a miss, the defining file
    is loaded from the databases and the search is repeated once.

    Args:
      symbol (str): The name of the symbol to search for.

    Returns:
      FileDescriptor: Descriptor for the file that contains the specified
      symbol.

    Raises:
      KeyError: if the file cannot be found in the pool.
    """
    symbol = _NormalizeFullyQualifiedName(symbol)
    try:
        return self._InternalFindFileContainingSymbol(symbol)
    except KeyError:
        pass

    # Not built yet: load the defining file from the databases and retry.
    try:
        self._FindFileContainingSymbolInDb(symbol)
        found = self._InternalFindFileContainingSymbol(symbol)
    except KeyError:
        raise KeyError('Cannot find a file containing %s' % symbol)
    return found
def _InternalFindFileContainingSymbol(self, symbol):
    """Gets the already built FileDescriptor containing the specified symbol.

    Looks the symbol up, in order, among messages, enums, services, top-level
    enum values and top-level extensions. Otherwise it treats the last
    dotted component as a member (field, nested extension or enum value) of
    the message named by the rest.

    Args:
      symbol (str): The name of the symbol to search for.

    Returns:
      FileDescriptor: Descriptor for the file that contains the specified
      symbol.

    Raises:
      KeyError: if the file cannot be found in the pool.
    """
    for registry in (self._descriptors, self._enum_descriptors,
                     self._service_descriptors):
        if symbol in registry:
            return registry[symbol].file

    if symbol in self._top_enum_values:
        # The owning file of an enum value is reached through its enum type.
        return self._top_enum_values[symbol].type.file

    if symbol in self._toplevel_extensions:
        return self._toplevel_extensions[symbol].file

    # Try fields, enum values and nested extensions inside a message.
    top_name, _, sub_name = symbol.rpartition('.')
    try:
        message = self.FindMessageTypeByName(top_name)
    except (KeyError, AssertionError):
        # FindMessageTypeByName may load files from the database, and loading
        # can raise AssertionError on extension-number clashes
        # (_AddExtensionDescriptor). Both are reported as "not found", as before.
        message = None

    # An explicit check instead of `assert`, which `python -O` strips: with it
    # gone, any "Message.<anything>" used to resolve to the message's file.
    if message is not None and (sub_name in message.extensions_by_name or
                                sub_name in message.fields_by_name or
                                sub_name in message.enum_values_by_name):
        return message.file
    raise KeyError('Cannot find a file containing %s' % symbol)
def FindMessageTypeByName(self, full_name):
    """Returns the message Descriptor for ``full_name``, loading it if needed.

    Args:
      full_name (str): The full name of the descriptor to load; a leading
        '.' is ignored.

    Returns:
      Descriptor: The descriptor for the named type.

    Raises:
      KeyError: if the message cannot be found in the pool.
    """
    full_name = _NormalizeFullyQualifiedName(full_name)
    try:
        return self._descriptors[full_name]
    except KeyError:
        pass
    # Not built yet: build the defining file, then look again.
    self._FindFileContainingSymbolInDb(full_name)
    return self._descriptors[full_name]
def FindEnumTypeByName(self, full_name):
    """Returns the EnumDescriptor for ``full_name``, loading it if needed.

    Args:
      full_name (str): The full name of the enum descriptor to load; a
        leading '.' is ignored.

    Returns:
      EnumDescriptor: The enum descriptor for the named type.

    Raises:
      KeyError: if the enum cannot be found in the pool.
    """
    full_name = _NormalizeFullyQualifiedName(full_name)
    try:
        return self._enum_descriptors[full_name]
    except KeyError:
        pass
    # Not built yet: build the defining file, then look again.
    self._FindFileContainingSymbolInDb(full_name)
    return self._enum_descriptors[full_name]
def FindFieldByName(self, full_name):
    """Returns the FieldDescriptor named ``<message full name>.<field>``.

    Args:
      full_name (str): The full name of the field descriptor to load.

    Returns:
      FieldDescriptor: The field descriptor for the named field.

    Raises:
      KeyError: if the field cannot be found in the pool.
    """
    qualified = _NormalizeFullyQualifiedName(full_name)
    owner_name, _, field_name = qualified.rpartition('.')
    owner = self.FindMessageTypeByName(owner_name)
    return owner.fields_by_name[field_name]
def FindOneofByName(self, full_name):
    """Returns the OneofDescriptor named ``<message full name>.<oneof>``.

    Args:
      full_name (str): The full name of the oneof descriptor to load.

    Returns:
      OneofDescriptor: The oneof descriptor for the named oneof.

    Raises:
      KeyError: if the oneof cannot be found in the pool.
    """
    qualified = _NormalizeFullyQualifiedName(full_name)
    owner_name, _, oneof_name = qualified.rpartition('.')
    owner = self.FindMessageTypeByName(owner_name)
    return owner.oneofs_by_name[oneof_name]
def FindExtensionByName(self, full_name):
    """Returns the extension FieldDescriptor with the given full name.

    Top-level extensions are looked up directly by name. Otherwise the last
    component is looked up among the extensions of the enclosing message or,
    failing that, of the file that defines ``full_name``.

    Args:
      full_name (str): The full name of the extension descriptor to load.

    Returns:
      FieldDescriptor: The field descriptor for the named extension.

    Raises:
      KeyError: if the extension cannot be found in the pool.
    """
    full_name = _NormalizeFullyQualifiedName(full_name)
    # Top-level extensions are registered by name explicitly: the compiler only
    # links them to their FileDescriptor when the FileDescriptorProto is kept
    # in the DescriptorDatabase, which costs memory.
    if full_name in self._toplevel_extensions:
        return self._toplevel_extensions[full_name]

    scope_name, _, extension_name = full_name.rpartition('.')
    try:
        # Most extensions are nested inside a message.
        scope = self.FindMessageTypeByName(scope_name)
    except KeyError:
        # Some extensions are defined at file scope.
        scope = self._FindFileContainingSymbolInDb(full_name)
    return scope.extensions_by_name[extension_name]
def FindExtensionByNumber(self, message_descriptor, number):
    """Gets the extension of the specified message with the specified number.

    Extensions have to be registered to this pool by calling :func:`Add` or
    :func:`AddExtensionDescriptor`. On a miss, one attempt is made to load the
    extension from the fallback database before giving up.

    Args:
      message_descriptor (Descriptor): descriptor of the extended message.
      number (int): Number of the extension field.

    Returns:
      FieldDescriptor: The descriptor for the extension.

    Raises:
      KeyError: when no extension with the given number is known for the
        specified message.
    """
    by_number = self._extensions_by_number
    if number not in by_number[message_descriptor]:
        self._TryLoadExtensionFromDB(message_descriptor, number)
    return by_number[message_descriptor][number]
def FindAllExtensions(self, message_descriptor):
    """Gets all the known extensions of a given message.

    Extensions have to be registered to this pool by build related
    :func:`Add` or :func:`AddExtensionDescriptor`. If the fallback database
    provides ``FindAllExtensionNumbers``, any numbers it reports that are not
    yet known are loaded first.

    Args:
      message_descriptor (Descriptor): Descriptor of the extended message.

    Returns:
      list[FieldDescriptor]: Field descriptors describing the extensions.
    """
    fallback_db = self._descriptor_db
    if fallback_db and hasattr(fallback_db, 'FindAllExtensionNumbers'):
        known = self._extensions_by_number[message_descriptor]
        reported = fallback_db.FindAllExtensionNumbers(
            message_descriptor.full_name)
        for number in reported:
            if number not in known:
                self._TryLoadExtensionFromDB(message_descriptor, number)

    return list(self._extensions_by_number[message_descriptor].values())
def _TryLoadExtensionFromDB(self, message_descriptor, number):
    """Best-effort load of the file defining an extension from the fallback DB.

    Does nothing when there is no fallback database, when it lacks
    ``FindFileContainingExtension``, or when that method returns None. If
    building the returned file proto fails with an ordinary exception, a
    RuntimeWarning is emitted instead of raising.

    Args:
      message_descriptor: descriptor of the extended message.
      number: the extension number that needs to be loaded.
    """
    fallback_db = self._descriptor_db
    if not fallback_db:
        return
    # Only supported when FindFileContainingExtension is provided.
    if not hasattr(fallback_db, 'FindFileContainingExtension'):
        return

    file_proto = fallback_db.FindFileContainingExtension(
        message_descriptor.full_name, number)
    if file_proto is None:
        return

    try:
        self._ConvertFileProtoToFileDescriptor(file_proto)
    except Exception:  # pylint: disable=broad-except
        # Deliberately best-effort, but catching Exception (not a bare
        # except) lets KeyboardInterrupt and SystemExit propagate.
        warn_msg = ('Unable to load proto file %s for extension number %d.' %
                    (file_proto.name, number))
        warnings.warn(warn_msg, RuntimeWarning)
def FindServiceByName(self, full_name):
    """Returns the ServiceDescriptor for ``full_name``, loading it if needed.

    Args:
      full_name (str): The full name of the service descriptor to load; a
        leading '.' is ignored.

    Returns:
      ServiceDescriptor: The service descriptor for the named service.

    Raises:
      KeyError: if the service cannot be found in the pool.
    """
    full_name = _NormalizeFullyQualifiedName(full_name)
    try:
        return self._service_descriptors[full_name]
    except KeyError:
        pass
    # Not built yet: build the defining file, then look again.
    self._FindFileContainingSymbolInDb(full_name)
    return self._service_descriptors[full_name]
def FindMethodByName(self, full_name):
    """Returns the MethodDescriptor named ``<service full name>.<method>``.

    Args:
      full_name (str): The full name of the method descriptor to load.

    Returns:
      MethodDescriptor: The method descriptor for the service method.

    Raises:
      KeyError: if the method cannot be found in the pool.
    """
    qualified = _NormalizeFullyQualifiedName(full_name)
    service_name, _, method_name = qualified.rpartition('.')
    service = self.FindServiceByName(service_name)
    return service.methods_by_name[method_name]
def _FindFileContainingSymbolInDb(self, symbol):
    """Builds the file that defines ``symbol``, found via the databases.

    The internal database is queried first; the fallback database is used
    only when the internal one raises KeyError and a fallback is configured.

    Args:
      symbol (str): The name of the symbol to search for.

    Returns:
      FileDescriptor: The file that contains the specified symbol.

    Raises:
      KeyError: if the file cannot be found in the descriptor database.
    """
    try:
        file_proto = self._internal_db.FindFileContainingSymbol(symbol)
    except KeyError:
        fallback_db = self._descriptor_db
        if not fallback_db:
            raise
        file_proto = fallback_db.FindFileContainingSymbol(symbol)

    if not file_proto:
        raise KeyError('Cannot find a file containing %s' % symbol)
    return self._ConvertFileProtoToFileDescriptor(file_proto)
def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    On first conversion, the file's dependencies are resolved through
    FindFileByName, and its messages, enums, extensions and services are
    built and registered in the pool. Every call (cached or not) then
    (re)registers the file's extensions, including those declared inside
    nested message types, in the pool's extension indexes.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.

    Raises:
      KeyError: if a dependency cannot be found (via FindFileByName).
      TypeError: if a symbol conflicts with one already in the pool.
      AssertionError: if two extensions of one message share a field number.
    """
    if file_proto.name not in self._file_descriptors:
        built_deps = list(self._GetDeps(file_proto.dependency))
        direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
        # public_dependency holds indexes into the dependency list.
        public_deps = [direct_deps[i] for i in file_proto.public_dependency]

        file_descriptor = descriptor.FileDescriptor(
            pool=self,
            name=file_proto.name,
            package=file_proto.package,
            syntax=file_proto.syntax,
            options=_OptionsOrNone(file_proto),
            serialized_pb=file_proto.SerializeToString(),
            dependencies=direct_deps,
            public_dependencies=public_deps,
            # pylint: disable=protected-access
            create_key=descriptor._internal_create_key)
        scope = {}

        # This loop extracts all the message and enum types from all the
        # dependencies of the file_proto. This is necessary to create the
        # scope of available message types when defining the passed in
        # file proto.
        for dependency in built_deps:
            scope.update(self._ExtractSymbols(
                dependency.message_types_by_name.values()))
            scope.update((_PrefixWithDot(enum.full_name), enum)
                         for enum in dependency.enum_types_by_name.values())

        for message_type in file_proto.message_type:
            message_desc = self._ConvertMessageDescriptor(
                message_type, file_proto.package, file_descriptor, scope,
                file_proto.syntax)
            file_descriptor.message_types_by_name[message_desc.name] = (
                message_desc)

        for enum_type in file_proto.enum_type:
            file_descriptor.enum_types_by_name[enum_type.name] = (
                self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                            file_descriptor, None, scope, True))

        for index, extension_proto in enumerate(file_proto.extension):
            extension_desc = self._MakeFieldDescriptor(
                extension_proto, file_proto.package, index, file_descriptor,
                is_extension=True)
            extension_desc.containing_type = self._GetTypeFromScope(
                file_descriptor.package, extension_proto.extendee, scope)
            self._SetFieldType(extension_proto, extension_desc,
                               file_descriptor.package, scope)
            file_descriptor.extensions_by_name[extension_desc.name] = (
                extension_desc)

        # Field types can only be resolved once every message of the file is
        # in scope.
        for desc_proto in file_proto.message_type:
            self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

        if file_proto.package:
            desc_proto_prefix = _PrefixWithDot(file_proto.package)
        else:
            desc_proto_prefix = ''

        # NOTE(review): re-resolves top-level messages through scope; this
        # appears to yield the objects stored above -- confirm before removing.
        for desc_proto in file_proto.message_type:
            desc = self._GetTypeFromScope(
                desc_proto_prefix, desc_proto.name, scope)
            file_descriptor.message_types_by_name[desc_proto.name] = desc

        for index, service_proto in enumerate(file_proto.service):
            file_descriptor.services_by_name[service_proto.name] = (
                self._MakeServiceDescriptor(service_proto, index, scope,
                                            file_proto.package, file_descriptor))

        self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool. message_types_by_name only lists top-level
    # messages, so recurse into nested types as well; otherwise extensions
    # declared inside nested messages never reach the extension indexes.
    def _AddMessageExtensions(message_type):
        for extension in message_type.extensions:
            self._AddExtensionDescriptor(extension)
        for nested_type in message_type.nested_types:
            _AddMessageExtensions(nested_type)

    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
        self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
        _AddMessageExtensions(message_type)

    return file_desc
def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
                              scope=None, syntax=None):
    """Builds a message Descriptor (recursively) and registers it in the pool.

    Nested messages and enums are converted first. Fields, extensions and
    oneofs are created, and oneof membership is wired up. The result is
    stored in ``scope`` under its dot-prefixed full name and in the pool
    under its full name.

    Args:
      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
      package: The package (or enclosing message full name) the proto should
        be located in.
      file_desc: The file containing this message, or None.
      scope: Dict mapping short and full symbols to message and enum types;
        a fresh dict is used when None.
      syntax: string indicating syntax of the file ("proto2" or "proto3")

    Returns:
      The added descriptor.

    Raises:
      TypeError: if the message name conflicts with an existing registration.
    """
    if package:
        desc_name = '.'.join((package, desc_proto.name))
    else:
        desc_name = desc_proto.name

    file_name = None if file_desc is None else file_desc.name

    if scope is None:
        scope = {}

    nested = [
        self._ConvertMessageDescriptor(
            nested_proto, desc_name, file_desc, scope, syntax)
        for nested_proto in desc_proto.nested_type]
    enums = [
        self._ConvertEnumDescriptor(enum_proto, desc_name, file_desc, None,
                                    scope, False)
        for enum_proto in desc_proto.enum_type]
    fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
              for index, field in enumerate(desc_proto.field)]
    extensions = [
        self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
                                  is_extension=True)
        for index, extension in enumerate(desc_proto.extension)]
    oneofs = [
        # pylint: disable=g-complex-comprehension
        descriptor.OneofDescriptor(
            oneof_proto.name,
            '.'.join((desc_name, oneof_proto.name)),
            index,
            None,
            [],
            _OptionsOrNone(oneof_proto),
            # pylint: disable=protected-access
            create_key=descriptor._internal_create_key)
        for index, oneof_proto in enumerate(desc_proto.oneof_decl)
    ]
    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
    desc = descriptor.Descriptor(
        name=desc_proto.name,
        full_name=desc_name,
        filename=file_name,
        containing_type=None,
        fields=fields,
        oneofs=oneofs,
        nested_types=nested,
        enum_types=enums,
        extensions=extensions,
        options=_OptionsOrNone(desc_proto),
        # A message is extendable exactly when it declares extension ranges.
        is_extendable=bool(extension_ranges),
        extension_ranges=extension_ranges,
        file=file_desc,
        serialized_start=None,
        serialized_end=None,
        syntax=syntax,
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    for nested_desc in desc.nested_types:
        nested_desc.containing_type = desc
    for enum_desc in desc.enum_types:
        enum_desc.containing_type = desc
    for field_index, field_proto in enumerate(desc_proto.field):
        if field_proto.HasField('oneof_index'):
            oneof = oneofs[field_proto.oneof_index]
            oneof.fields.append(fields[field_index])
            fields[field_index].containing_oneof = oneof

    scope[_PrefixWithDot(desc_name)] = desc
    # Use the file_name computed above instead of desc.file.name: desc was
    # built with file=file_desc, so the attribute chain fails when file_desc
    # is None even though this method accepts that default.
    self._CheckConflictRegister(desc, desc.full_name, file_name)
    self._descriptors[desc_name] = desc
    return desc
def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
                           containing_type=None, scope=None, top_level=False):
    """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.

    The descriptor is stored in ``scope`` under its dot-prefixed full name and
    registered in the pool. For top-level enums, each value is also
    registered as a sibling of the enum in the package namespace.

    Args:
      enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
      package: Optional package name for the new message EnumDescriptor.
      file_desc: The file containing the enum descriptor, or None.
      containing_type: The type containing this enum.
      scope: Scope containing available types; a fresh dict is used when None.
      top_level: If True, the enum is a top level symbol. If False, the enum
        is defined inside a message.

    Returns:
      The added descriptor

    Raises:
      TypeError: if the enum or one of its top-level values conflicts with an
        existing registration.
    """
    if package:
        enum_name = '.'.join((package, enum_proto.name))
    else:
        enum_name = enum_proto.name

    file_name = None if file_desc is None else file_desc.name

    if scope is None:
        # Same defaulting as _ConvertMessageDescriptor; without it the scope
        # assignment below raised TypeError for the default scope=None.
        scope = {}

    values = [self._MakeEnumValueDescriptor(value, index)
              for index, value in enumerate(enum_proto.value)]
    desc = descriptor.EnumDescriptor(name=enum_proto.name,
                                     full_name=enum_name,
                                     filename=file_name,
                                     file=file_desc,
                                     values=values,
                                     containing_type=containing_type,
                                     options=_OptionsOrNone(enum_proto),
                                     # pylint: disable=protected-access
                                     create_key=descriptor._internal_create_key)
    scope['.%s' % enum_name] = desc
    # file_name rather than desc.file.name, so file_desc=None does not fail.
    self._CheckConflictRegister(desc, desc.full_name, file_name)
    self._enum_descriptors[enum_name] = desc

    # Add top level enum values.
    if top_level:
        for value in values:
            # An empty or None package yields the bare value name; the
            # original '.'.join raised TypeError for package=None.
            if package:
                value_name = '.'.join((package, value.name))
            else:
                value_name = value.name
            full_name = _NormalizeFullyQualifiedName(value_name)
            self._CheckConflictRegister(value, full_name, file_name)
            self._top_enum_values[full_name] = value

    return desc
def _MakeFieldDescriptor(self, field_proto, message_name, index,
                         file_desc, is_extension=False):
  """Creates a field descriptor from a FieldDescriptorProto.

  Only what the proto carries directly is filled in. The descriptor is built
  with ``cpp_type``, ``message_type``, ``enum_type``, ``containing_type`` and
  ``extension_scope`` all None, and with no default value. No type lookup
  happens here; ``_SetFieldType`` in this module assigns the type-dependent
  attributes afterwards.

  Args:
    field_proto: The proto describing the field.
    message_name: Full name of the containing message. When empty, the
      field's full name is just its own name.
    index: Index of the field.
    file_desc: The file containing the field descriptor.
    is_extension: Whether this field is an extension.

  Returns:
    An initialized FieldDescriptor object.
  """
  short_name = field_proto.name
  full_name = ('.'.join((message_name, short_name)) if message_name
               else short_name)
  # An empty json_name is treated as "not set" and passed on as None.
  json_name = field_proto.json_name or None

  return descriptor.FieldDescriptor(
      name=short_name,
      full_name=full_name,
      index=index,
      number=field_proto.number,
      type=field_proto.type,
      cpp_type=None,
      message_type=None,
      enum_type=None,
      containing_type=None,
      label=field_proto.label,
      has_default_value=False,
      default_value=None,
      is_extension=is_extension,
      extension_scope=None,
      options=_OptionsOrNone(field_proto),
      json_name=json_name,
      file=file_desc,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _SetAllFieldTypes(self, package, desc_proto, scope):
  """Resolves the types of a message's fields, extensions and nested messages.

  The message descriptor for ``desc_proto`` is looked up in ``scope``. Each
  field and extension is then passed to ``_SetFieldType``, and extensions
  also get their ``containing_type`` resolved from ``extendee``. Nested
  message types are handled recursively, using the message's own full name
  as the lookup package.

  Args:
    package: The current package of desc_proto.
    desc_proto: The message descriptor proto to update.
    scope: Enclosing scope of available types.
  """
  package = _PrefixWithDot(package)
  message_desc = self._GetTypeFromScope(package, desc_proto.name, scope)

  # Names declared inside this message are resolved relative to it.
  nested_package = (_PrefixWithDot(desc_proto.name) if package == '.'
                    else '.'.join([package, desc_proto.name]))

  for fproto, fdesc in zip(desc_proto.field, message_desc.fields):
    self._SetFieldType(fproto, fdesc, nested_package, scope)

  for ext_proto, ext_desc in zip(desc_proto.extension,
                                 message_desc.extensions):
    ext_desc.containing_type = self._GetTypeFromScope(
        nested_package, ext_proto.extendee, scope)
    self._SetFieldType(ext_proto, ext_desc, nested_package, scope)

  for child_proto in desc_proto.nested_type:
    self._SetAllFieldTypes(nested_package, child_proto, scope)
def _SetFieldType(self, field_proto, field_desc, package, scope):
  """Sets the field's type, cpp_type, message_type, enum_type and default.

  ``field_proto.type_name`` (when non-empty) is resolved through
  ``self._GetTypeFromScope``. The result becomes ``message_type`` for
  message/group fields or ``enum_type`` for enum fields.

  ``has_default_value`` and ``default_value`` are then set as follows:
    * Repeated fields: no default, with an empty list as the value.
    * Fields with an explicit default: that default, parsed from its textual
      form according to the field type.
    * Other fields: the type's zero value; for enums, the number of the
      first enum value.

  Args:
    field_proto: Data about the field in proto format. If its ``type`` is
      unset it is assigned here (TYPE_MESSAGE when ``type_name`` resolves to
      a message Descriptor, otherwise TYPE_ENUM), so this argument is
      mutated.
    field_desc: The descriptor to modify.
    package: The package the field's container is in; starting point for
      resolving ``type_name``.
    scope: Enclosing scope of available types.
  """
  field_cls = descriptor.FieldDescriptor

  referenced_type = None
  if field_proto.type_name:
    referenced_type = self._GetTypeFromScope(
        package, field_proto.type_name, scope)

  if not field_proto.HasField('type'):
    # No explicit type: infer it from what type_name resolved to.
    if isinstance(referenced_type, descriptor.Descriptor):
      field_proto.type = field_cls.TYPE_MESSAGE
    else:
      field_proto.type = field_cls.TYPE_ENUM
  proto_type = field_proto.type

  field_desc.cpp_type = field_cls.ProtoTypeToCppProtoType(proto_type)
  if proto_type in (field_cls.TYPE_MESSAGE, field_cls.TYPE_GROUP):
    field_desc.message_type = referenced_type
  if proto_type == field_cls.TYPE_ENUM:
    field_desc.enum_type = referenced_type

  if field_proto.label == field_cls.LABEL_REPEATED:
    field_desc.has_default_value = False
    field_desc.default_value = []
  elif field_proto.HasField('default_value'):
    # Explicit default: convert the textual value according to the type.
    field_desc.has_default_value = True
    raw = field_proto.default_value
    if proto_type in (field_cls.TYPE_DOUBLE, field_cls.TYPE_FLOAT):
      parsed = float(raw)
    elif proto_type == field_cls.TYPE_STRING:
      parsed = raw
    elif proto_type == field_cls.TYPE_BOOL:
      parsed = raw.lower() == 'true'
    elif proto_type == field_cls.TYPE_ENUM:
      parsed = field_desc.enum_type.values_by_name[raw].number
    elif proto_type == field_cls.TYPE_BYTES:
      parsed = text_encoding.CUnescape(raw)
    elif proto_type == field_cls.TYPE_MESSAGE:
      parsed = None
    else:
      # All other types are of the "int" type.
      parsed = int(raw)
    field_desc.default_value = parsed
  else:
    # No explicit default: fall back to the zero value of the type.
    field_desc.has_default_value = False
    if proto_type == field_cls.TYPE_ENUM:
      field_desc.default_value = field_desc.enum_type.values[0].number
    else:
      zero_values = {
          field_cls.TYPE_DOUBLE: 0.0,
          field_cls.TYPE_FLOAT: 0.0,
          field_cls.TYPE_STRING: u'',
          field_cls.TYPE_BOOL: False,
          field_cls.TYPE_BYTES: b'',
          field_cls.TYPE_MESSAGE: None,
          field_cls.TYPE_GROUP: None,
      }
      # Types not listed above are all of the "int" type.
      field_desc.default_value = zero_values.get(proto_type, 0)

  field_desc.type = proto_type
def _MakeEnumValueDescriptor(self, value_proto, index):
  """Builds an EnumValueDescriptor from an EnumValueDescriptorProto.

  The descriptor's ``type`` is passed as None here. Presumably the enclosing
  EnumDescriptor fills it in; that happens outside this method, so confirm
  against descriptor.py.

  Args:
    value_proto: The proto describing the enum value.
    index: The position of the value within its enum.

  Returns:
    An initialized EnumValueDescriptor object.
  """
  value_options = _OptionsOrNone(value_proto)
  return descriptor.EnumValueDescriptor(
      name=value_proto.name,
      index=index,
      number=value_proto.number,
      options=value_options,
      type=None,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _MakeServiceDescriptor(self, service_proto, service_index, scope,
                           package, file_desc):
  """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.

  Each method is converted with ``self._MakeMethodDescriptor``. The resulting
  ServiceDescriptor is checked for name conflicts and recorded in
  ``self._service_descriptors``.

  Args:
    service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
    service_index: The index of the service in the File.
    scope: Dict mapping short and full symbols to message and enum types;
      used to resolve the methods' input and output types.
    package: Optional package name; the service's full name is
      ``package.name`` when given, otherwise just ``name``.
    file_desc: The file containing the service descriptor. It must not be
      None, because ``desc.file.name`` is read for conflict registration.

  Returns:
    The added descriptor.
  """
  short_name = service_proto.name
  service_name = '.'.join((package, short_name)) if package else short_name

  methods = [
      self._MakeMethodDescriptor(method_proto, service_name, package, scope,
                                 method_index)
      for method_index, method_proto in enumerate(service_proto.method)
  ]
  service_desc = descriptor.ServiceDescriptor(
      name=short_name,
      full_name=service_name,
      index=service_index,
      methods=methods,
      options=_OptionsOrNone(service_proto),
      file=file_desc,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
  self._CheckConflictRegister(service_desc, service_desc.full_name,
                              service_desc.file.name)
  self._service_descriptors[service_name] = service_desc
  return service_desc
def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
                          index):
  """Creates a method descriptor from a MethodDescriptorProto.

  The input and output type names are resolved relative to ``package`` via
  ``self._GetTypeFromScope``. ``containing_service`` is passed as None.

  Args:
    method_proto: The proto describing the method.
    service_name: Full name of the containing service; it prefixes the
      method's full name.
    package: Optional package name to look up for types.
    scope: Scope containing available types.
    index: Index of the method in the service.

  Returns:
    An initialized MethodDescriptor object.
  """
  method_full_name = '.'.join((service_name, method_proto.name))
  resolve = self._GetTypeFromScope
  request_type = resolve(package, method_proto.input_type, scope)
  response_type = resolve(package, method_proto.output_type, scope)
  return descriptor.MethodDescriptor(
      name=method_proto.name,
      full_name=method_full_name,
      index=index,
      containing_service=None,
      input_type=request_type,
      output_type=response_type,
      client_streaming=method_proto.client_streaming,
      server_streaming=method_proto.server_streaming,
      options=_OptionsOrNone(method_proto),
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _ExtractSymbols(self, descriptors):
  """Walks message descriptors and yields every message and enum symbol.

  Output is depth-first per message: first the message itself, then
  (recursively) its nested messages, then the enums declared directly in it.

  Args:
    descriptors: The message descriptors to extract symbols from.

  Yields:
    ``(dot-prefixed full name, descriptor)`` pairs.
  """
  for message_desc in descriptors:
    yield (_PrefixWithDot(message_desc.full_name), message_desc)
    yield from self._ExtractSymbols(message_desc.nested_types)
    for enum_desc in message_desc.enum_types:
      yield (_PrefixWithDot(enum_desc.full_name), enum_desc)
1233 def _GetDeps(self, dependencies, visited=None):
1234 """Recursively finds dependencies for file protos.
1236 Args:
1237 dependencies: The names of the files being depended on.
1238 visited: The names of files already found.
1240 Yields:
1241 Each direct and indirect dependency.
1242 """
1244 visited = visited or set()
1245 for dependency in dependencies:
1246 if dependency not in visited:
1247 visited.add(dependency)
1248 dep_desc = self.FindFileByName(dependency)
1249 yield dep_desc
1250 public_files = [d.name for d in dep_desc.public_dependencies]
1251 yield from self._GetDeps(public_files, visited)
def _GetTypeFromScope(self, package, type_name, scope):
  """Finds a given type name in the current scope.

  An exact key match is returned directly. Otherwise ``type_name`` is tried
  relative to ``package`` and then to each shorter prefix of it, innermost
  first. The last candidate is the root, ``'.' + type_name``.

  Args:
    package: The package the proto should be located in.
    type_name: The name of the type to be found in the scope.
    scope: Dict mapping short and full symbols to message and enum types.

  Returns:
    The descriptor for the requested type.

  Raises:
    KeyError: If no candidate name is present in ``scope``; the key reported
      is the original ``type_name``.
  """
  if type_name in scope:
    return scope[type_name]
  prefix_parts = _PrefixWithDot(package).split('.')
  for cut in range(len(prefix_parts), 0, -1):
    candidate = '.'.join(prefix_parts[:cut] + [type_name])
    if candidate in scope:
      return scope[candidate]
  return scope[type_name]
1276def _PrefixWithDot(name):
1277 return name if name.startswith('.') else '.%s' % name
# TODO(amauryfa): This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
# With C descriptors enabled, reuse the pool exposed by descriptor._message.
# Otherwise, create a fresh pure-Python DescriptorPool.
_DEFAULT = (
    descriptor._message.default_pool  # pylint: disable=protected-access
    if _USE_C_DESCRIPTORS else DescriptorPool())
def Default():
  """Returns the module-level default descriptor pool (``_DEFAULT``)."""
  return _DEFAULT