Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/descriptor_pool.py: 15%
469 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:45 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:45 +0000
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
"""Provides DescriptorPool to use as a container for proto2 descriptors.

The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
a collection of protocol buffer descriptors for use when dynamically creating
message types at runtime.

For most applications protocol buffers should be used via modules generated by
the protocol buffer compiler tool. This should only be used when the type of
protocol buffers used in an application or library cannot be predetermined.

Below is a straightforward example of how to use this class::

  pool = DescriptorPool()
  file_descriptor_protos = [ ... ]
  for file_descriptor_proto in file_descriptor_protos:
    pool.Add(file_descriptor_proto)
  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')

The message descriptor can be used in conjunction with the message_factory
module in order to create a protocol buffer class that can be encoded and
decoded.

If you want to get a Python class for the specified proto, use the
helper functions inside google.protobuf.message_factory
directly instead of this class.
"""

__author__ = 'matthewtoia@google.com (Matt Toia)'

import collections
import warnings

from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding
from google.protobuf.internal import python_message

# Mirrors the descriptor module's backend switch. When true, constructing a
# DescriptorPool returns the C-implemented pool (see DescriptorPool.__new__).
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
48def _Deprecated(func):
49 """Mark functions as deprecated."""
51 def NewFunc(*args, **kwargs):
52 warnings.warn(
53 'Call to deprecated function %s(). Note: Do add unlinked descriptors '
54 'to descriptor_pool is wrong. Please use Add() or AddSerializedFile() '
55 'instead. This function will be removed soon.' % func.__name__,
56 category=DeprecationWarning)
57 return func(*args, **kwargs)
58 NewFunc.__name__ = func.__name__
59 NewFunc.__doc__ = func.__doc__
60 NewFunc.__dict__.update(func.__dict__)
61 return NewFunc
64def _NormalizeFullyQualifiedName(name):
65 """Remove leading period from fully-qualified type name.
67 Due to b/13860351 in descriptor_database.py, types in the root namespace are
68 generated with a leading period. This function removes that prefix.
70 Args:
71 name (str): The fully-qualified symbol name.
73 Returns:
74 str: The normalized fully-qualified symbol name.
75 """
76 return name.lstrip('.')
79def _OptionsOrNone(descriptor_proto):
80 """Returns the value of the field `options`, or None if it is not set."""
81 if descriptor_proto.HasField('options'):
82 return descriptor_proto.options
83 else:
84 return None
def _IsMessageSetExtension(field):
  """Returns whether ``field`` is an extension of a MessageSet message.

  True only when ``field`` is an extension, the extended message has options
  with ``message_set_wire_format`` set, and the field is an optional
  message-typed field.
  """
  if not field.is_extension:
    return False
  extendee = field.containing_type
  if not extendee.has_options:
    return False
  if not extendee.GetOptions().message_set_wire_format:
    return False
  field_types = descriptor.FieldDescriptor
  return (field.type == field_types.TYPE_MESSAGE and
          field.label == field_types.LABEL_OPTIONAL)
95class DescriptorPool(object):
96 """A collection of protobufs dynamically constructed by descriptor protos."""
98 if _USE_C_DESCRIPTORS:
100 def __new__(cls, descriptor_db=None):
101 # pylint: disable=protected-access
102 return descriptor._message.DescriptorPool(descriptor_db)
104 def __init__(
105 self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
106 ):
107 """Initializes a Pool of proto buffs.
109 The descriptor_db argument to the constructor is provided to allow
110 specialized file descriptor proto lookup code to be triggered on demand. An
111 example would be an implementation which will read and compile a file
112 specified in a call to FindFileByName() and not require the call to Add()
113 at all. Results from this database will be cached internally here as well.
115 Args:
116 descriptor_db: A secondary source of file descriptors.
117 use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
118 C++.
119 """
121 self._internal_db = descriptor_database.DescriptorDatabase()
122 self._descriptor_db = descriptor_db
123 self._descriptors = {}
124 self._enum_descriptors = {}
125 self._service_descriptors = {}
126 self._file_descriptors = {}
127 self._toplevel_extensions = {}
128 self._top_enum_values = {}
129 # We store extensions in two two-level mappings: The first key is the
130 # descriptor of the message being extended, the second key is the extension
131 # full name or its tag number.
132 self._extensions_by_name = collections.defaultdict(dict)
133 self._extensions_by_number = collections.defaultdict(dict)
135 def _CheckConflictRegister(self, desc, desc_name, file_name):
136 """Check if the descriptor name conflicts with another of the same name.
138 Args:
139 desc: Descriptor of a message, enum, service, extension or enum value.
140 desc_name (str): the full name of desc.
141 file_name (str): The file name of descriptor.
142 """
143 for register, descriptor_type in [
144 (self._descriptors, descriptor.Descriptor),
145 (self._enum_descriptors, descriptor.EnumDescriptor),
146 (self._service_descriptors, descriptor.ServiceDescriptor),
147 (self._toplevel_extensions, descriptor.FieldDescriptor),
148 (self._top_enum_values, descriptor.EnumValueDescriptor)]:
149 if desc_name in register:
150 old_desc = register[desc_name]
151 if isinstance(old_desc, descriptor.EnumValueDescriptor):
152 old_file = old_desc.type.file.name
153 else:
154 old_file = old_desc.file.name
156 if not isinstance(desc, descriptor_type) or (
157 old_file != file_name):
158 error_msg = ('Conflict register for file "' + file_name +
159 '": ' + desc_name +
160 ' is already defined in file "' +
161 old_file + '". Please fix the conflict by adding '
162 'package name on the proto file, or use different '
163 'name for the duplication.')
164 if isinstance(desc, descriptor.EnumValueDescriptor):
165 error_msg += ('\nNote: enum values appear as '
166 'siblings of the enum type instead of '
167 'children of it.')
169 raise TypeError(error_msg)
171 return
173 def Add(self, file_desc_proto):
174 """Adds the FileDescriptorProto and its types to this pool.
176 Args:
177 file_desc_proto (FileDescriptorProto): The file descriptor to add.
178 """
180 self._internal_db.Add(file_desc_proto)
182 def AddSerializedFile(self, serialized_file_desc_proto):
183 """Adds the FileDescriptorProto and its types to this pool.
185 Args:
186 serialized_file_desc_proto (bytes): A bytes string, serialization of the
187 :class:`FileDescriptorProto` to add.
189 Returns:
190 FileDescriptor: Descriptor for the added file.
191 """
193 # pylint: disable=g-import-not-at-top
194 from google.protobuf import descriptor_pb2
195 file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
196 serialized_file_desc_proto)
197 file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
198 file_desc.serialized_pb = serialized_file_desc_proto
199 return file_desc
201 # Add Descriptor to descriptor pool is deprecated. Please use Add()
202 # or AddSerializedFile() to add a FileDescriptorProto instead.
203 @_Deprecated
204 def AddDescriptor(self, desc):
205 self._AddDescriptor(desc)
207 # Never call this method. It is for internal usage only.
208 def _AddDescriptor(self, desc):
209 """Adds a Descriptor to the pool, non-recursively.
211 If the Descriptor contains nested messages or enums, the caller must
212 explicitly register them. This method also registers the FileDescriptor
213 associated with the message.
215 Args:
216 desc: A Descriptor.
217 """
218 if not isinstance(desc, descriptor.Descriptor):
219 raise TypeError('Expected instance of descriptor.Descriptor.')
221 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
223 self._descriptors[desc.full_name] = desc
224 self._AddFileDescriptor(desc.file)
226 # Never call this method. It is for internal usage only.
227 def _AddEnumDescriptor(self, enum_desc):
228 """Adds an EnumDescriptor to the pool.
230 This method also registers the FileDescriptor associated with the enum.
232 Args:
233 enum_desc: An EnumDescriptor.
234 """
236 if not isinstance(enum_desc, descriptor.EnumDescriptor):
237 raise TypeError('Expected instance of descriptor.EnumDescriptor.')
239 file_name = enum_desc.file.name
240 self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
241 self._enum_descriptors[enum_desc.full_name] = enum_desc
243 # Top enum values need to be indexed.
244 # Count the number of dots to see whether the enum is toplevel or nested
245 # in a message. We cannot use enum_desc.containing_type at this stage.
246 if enum_desc.file.package:
247 top_level = (enum_desc.full_name.count('.')
248 - enum_desc.file.package.count('.') == 1)
249 else:
250 top_level = enum_desc.full_name.count('.') == 0
251 if top_level:
252 file_name = enum_desc.file.name
253 package = enum_desc.file.package
254 for enum_value in enum_desc.values:
255 full_name = _NormalizeFullyQualifiedName(
256 '.'.join((package, enum_value.name)))
257 self._CheckConflictRegister(enum_value, full_name, file_name)
258 self._top_enum_values[full_name] = enum_value
259 self._AddFileDescriptor(enum_desc.file)
261 # Add ServiceDescriptor to descriptor pool is deprecated. Please use Add()
262 # or AddSerializedFile() to add a FileDescriptorProto instead.
263 @_Deprecated
264 def AddServiceDescriptor(self, service_desc):
265 self._AddServiceDescriptor(service_desc)
267 # Never call this method. It is for internal usage only.
268 def _AddServiceDescriptor(self, service_desc):
269 """Adds a ServiceDescriptor to the pool.
271 Args:
272 service_desc: A ServiceDescriptor.
273 """
275 if not isinstance(service_desc, descriptor.ServiceDescriptor):
276 raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
278 self._CheckConflictRegister(service_desc, service_desc.full_name,
279 service_desc.file.name)
280 self._service_descriptors[service_desc.full_name] = service_desc
282 # Add ExtensionDescriptor to descriptor pool is deprecated. Please use Add()
283 # or AddSerializedFile() to add a FileDescriptorProto instead.
284 @_Deprecated
285 def AddExtensionDescriptor(self, extension):
286 self._AddExtensionDescriptor(extension)
288 # Never call this method. It is for internal usage only.
289 def _AddExtensionDescriptor(self, extension):
290 """Adds a FieldDescriptor describing an extension to the pool.
292 Args:
293 extension: A FieldDescriptor.
295 Raises:
296 AssertionError: when another extension with the same number extends the
297 same message.
298 TypeError: when the specified extension is not a
299 descriptor.FieldDescriptor.
300 """
301 if not (isinstance(extension, descriptor.FieldDescriptor) and
302 extension.is_extension):
303 raise TypeError('Expected an extension descriptor.')
305 if extension.extension_scope is None:
306 self._CheckConflictRegister(
307 extension, extension.full_name, extension.file.name)
308 self._toplevel_extensions[extension.full_name] = extension
310 try:
311 existing_desc = self._extensions_by_number[
312 extension.containing_type][extension.number]
313 except KeyError:
314 pass
315 else:
316 if extension is not existing_desc:
317 raise AssertionError(
318 'Extensions "%s" and "%s" both try to extend message type "%s" '
319 'with field number %d.' %
320 (extension.full_name, existing_desc.full_name,
321 extension.containing_type.full_name, extension.number))
323 self._extensions_by_number[extension.containing_type][
324 extension.number] = extension
325 self._extensions_by_name[extension.containing_type][
326 extension.full_name] = extension
328 # Also register MessageSet extensions with the type name.
329 if _IsMessageSetExtension(extension):
330 self._extensions_by_name[extension.containing_type][
331 extension.message_type.full_name] = extension
333 if hasattr(extension.containing_type, '_concrete_class'):
334 python_message._AttachFieldHelpers(
335 extension.containing_type._concrete_class, extension)
337 @_Deprecated
338 def AddFileDescriptor(self, file_desc):
339 self._InternalAddFileDescriptor(file_desc)
341 # Never call this method. It is for internal usage only.
342 def _InternalAddFileDescriptor(self, file_desc):
343 """Adds a FileDescriptor to the pool, non-recursively.
345 If the FileDescriptor contains messages or enums, the caller must explicitly
346 register them.
348 Args:
349 file_desc: A FileDescriptor.
350 """
352 self._AddFileDescriptor(file_desc)
354 def _AddFileDescriptor(self, file_desc):
355 """Adds a FileDescriptor to the pool, non-recursively.
357 If the FileDescriptor contains messages or enums, the caller must explicitly
358 register them.
360 Args:
361 file_desc: A FileDescriptor.
362 """
364 if not isinstance(file_desc, descriptor.FileDescriptor):
365 raise TypeError('Expected instance of descriptor.FileDescriptor.')
366 self._file_descriptors[file_desc.name] = file_desc
368 def FindFileByName(self, file_name):
369 """Gets a FileDescriptor by file name.
371 Args:
372 file_name (str): The path to the file to get a descriptor for.
374 Returns:
375 FileDescriptor: The descriptor for the named file.
377 Raises:
378 KeyError: if the file cannot be found in the pool.
379 """
381 try:
382 return self._file_descriptors[file_name]
383 except KeyError:
384 pass
386 try:
387 file_proto = self._internal_db.FindFileByName(file_name)
388 except KeyError as error:
389 if self._descriptor_db:
390 file_proto = self._descriptor_db.FindFileByName(file_name)
391 else:
392 raise error
393 if not file_proto:
394 raise KeyError('Cannot find a file named %s' % file_name)
395 return self._ConvertFileProtoToFileDescriptor(file_proto)
397 def FindFileContainingSymbol(self, symbol):
398 """Gets the FileDescriptor for the file containing the specified symbol.
400 Args:
401 symbol (str): The name of the symbol to search for.
403 Returns:
404 FileDescriptor: Descriptor for the file that contains the specified
405 symbol.
407 Raises:
408 KeyError: if the file cannot be found in the pool.
409 """
411 symbol = _NormalizeFullyQualifiedName(symbol)
412 try:
413 return self._InternalFindFileContainingSymbol(symbol)
414 except KeyError:
415 pass
417 try:
418 # Try fallback database. Build and find again if possible.
419 self._FindFileContainingSymbolInDb(symbol)
420 return self._InternalFindFileContainingSymbol(symbol)
421 except KeyError:
422 raise KeyError('Cannot find a file containing %s' % symbol)
424 def _InternalFindFileContainingSymbol(self, symbol):
425 """Gets the already built FileDescriptor containing the specified symbol.
427 Args:
428 symbol (str): The name of the symbol to search for.
430 Returns:
431 FileDescriptor: Descriptor for the file that contains the specified
432 symbol.
434 Raises:
435 KeyError: if the file cannot be found in the pool.
436 """
437 try:
438 return self._descriptors[symbol].file
439 except KeyError:
440 pass
442 try:
443 return self._enum_descriptors[symbol].file
444 except KeyError:
445 pass
447 try:
448 return self._service_descriptors[symbol].file
449 except KeyError:
450 pass
452 try:
453 return self._top_enum_values[symbol].type.file
454 except KeyError:
455 pass
457 try:
458 return self._toplevel_extensions[symbol].file
459 except KeyError:
460 pass
462 # Try fields, enum values and nested extensions inside a message.
463 top_name, _, sub_name = symbol.rpartition('.')
464 try:
465 message = self.FindMessageTypeByName(top_name)
466 assert (sub_name in message.extensions_by_name or
467 sub_name in message.fields_by_name or
468 sub_name in message.enum_values_by_name)
469 return message.file
470 except (KeyError, AssertionError):
471 raise KeyError('Cannot find a file containing %s' % symbol)
473 def FindMessageTypeByName(self, full_name):
474 """Loads the named descriptor from the pool.
476 Args:
477 full_name (str): The full name of the descriptor to load.
479 Returns:
480 Descriptor: The descriptor for the named type.
482 Raises:
483 KeyError: if the message cannot be found in the pool.
484 """
486 full_name = _NormalizeFullyQualifiedName(full_name)
487 if full_name not in self._descriptors:
488 self._FindFileContainingSymbolInDb(full_name)
489 return self._descriptors[full_name]
491 def FindEnumTypeByName(self, full_name):
492 """Loads the named enum descriptor from the pool.
494 Args:
495 full_name (str): The full name of the enum descriptor to load.
497 Returns:
498 EnumDescriptor: The enum descriptor for the named type.
500 Raises:
501 KeyError: if the enum cannot be found in the pool.
502 """
504 full_name = _NormalizeFullyQualifiedName(full_name)
505 if full_name not in self._enum_descriptors:
506 self._FindFileContainingSymbolInDb(full_name)
507 return self._enum_descriptors[full_name]
509 def FindFieldByName(self, full_name):
510 """Loads the named field descriptor from the pool.
512 Args:
513 full_name (str): The full name of the field descriptor to load.
515 Returns:
516 FieldDescriptor: The field descriptor for the named field.
518 Raises:
519 KeyError: if the field cannot be found in the pool.
520 """
521 full_name = _NormalizeFullyQualifiedName(full_name)
522 message_name, _, field_name = full_name.rpartition('.')
523 message_descriptor = self.FindMessageTypeByName(message_name)
524 return message_descriptor.fields_by_name[field_name]
526 def FindOneofByName(self, full_name):
527 """Loads the named oneof descriptor from the pool.
529 Args:
530 full_name (str): The full name of the oneof descriptor to load.
532 Returns:
533 OneofDescriptor: The oneof descriptor for the named oneof.
535 Raises:
536 KeyError: if the oneof cannot be found in the pool.
537 """
538 full_name = _NormalizeFullyQualifiedName(full_name)
539 message_name, _, oneof_name = full_name.rpartition('.')
540 message_descriptor = self.FindMessageTypeByName(message_name)
541 return message_descriptor.oneofs_by_name[oneof_name]
543 def FindExtensionByName(self, full_name):
544 """Loads the named extension descriptor from the pool.
546 Args:
547 full_name (str): The full name of the extension descriptor to load.
549 Returns:
550 FieldDescriptor: The field descriptor for the named extension.
552 Raises:
553 KeyError: if the extension cannot be found in the pool.
554 """
555 full_name = _NormalizeFullyQualifiedName(full_name)
556 try:
557 # The proto compiler does not give any link between the FileDescriptor
558 # and top-level extensions unless the FileDescriptorProto is added to
559 # the DescriptorDatabase, but this can impact memory usage.
560 # So we registered these extensions by name explicitly.
561 return self._toplevel_extensions[full_name]
562 except KeyError:
563 pass
564 message_name, _, extension_name = full_name.rpartition('.')
565 try:
566 # Most extensions are nested inside a message.
567 scope = self.FindMessageTypeByName(message_name)
568 except KeyError:
569 # Some extensions are defined at file scope.
570 scope = self._FindFileContainingSymbolInDb(full_name)
571 return scope.extensions_by_name[extension_name]
573 def FindExtensionByNumber(self, message_descriptor, number):
574 """Gets the extension of the specified message with the specified number.
576 Extensions have to be registered to this pool by calling :func:`Add` or
577 :func:`AddExtensionDescriptor`.
579 Args:
580 message_descriptor (Descriptor): descriptor of the extended message.
581 number (int): Number of the extension field.
583 Returns:
584 FieldDescriptor: The descriptor for the extension.
586 Raises:
587 KeyError: when no extension with the given number is known for the
588 specified message.
589 """
590 try:
591 return self._extensions_by_number[message_descriptor][number]
592 except KeyError:
593 self._TryLoadExtensionFromDB(message_descriptor, number)
594 return self._extensions_by_number[message_descriptor][number]
596 def FindAllExtensions(self, message_descriptor):
597 """Gets all the known extensions of a given message.
599 Extensions have to be registered to this pool by build related
600 :func:`Add` or :func:`AddExtensionDescriptor`.
602 Args:
603 message_descriptor (Descriptor): Descriptor of the extended message.
605 Returns:
606 list[FieldDescriptor]: Field descriptors describing the extensions.
607 """
608 # Fallback to descriptor db if FindAllExtensionNumbers is provided.
609 if self._descriptor_db and hasattr(
610 self._descriptor_db, 'FindAllExtensionNumbers'):
611 full_name = message_descriptor.full_name
612 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
613 for number in all_numbers:
614 if number in self._extensions_by_number[message_descriptor]:
615 continue
616 self._TryLoadExtensionFromDB(message_descriptor, number)
618 return list(self._extensions_by_number[message_descriptor].values())
620 def _TryLoadExtensionFromDB(self, message_descriptor, number):
621 """Try to Load extensions from descriptor db.
623 Args:
624 message_descriptor: descriptor of the extended message.
625 number: the extension number that needs to be loaded.
626 """
627 if not self._descriptor_db:
628 return
629 # Only supported when FindFileContainingExtension is provided.
630 if not hasattr(
631 self._descriptor_db, 'FindFileContainingExtension'):
632 return
634 full_name = message_descriptor.full_name
635 file_proto = self._descriptor_db.FindFileContainingExtension(
636 full_name, number)
638 if file_proto is None:
639 return
641 try:
642 self._ConvertFileProtoToFileDescriptor(file_proto)
643 except:
644 warn_msg = ('Unable to load proto file %s for extension number %d.' %
645 (file_proto.name, number))
646 warnings.warn(warn_msg, RuntimeWarning)
648 def FindServiceByName(self, full_name):
649 """Loads the named service descriptor from the pool.
651 Args:
652 full_name (str): The full name of the service descriptor to load.
654 Returns:
655 ServiceDescriptor: The service descriptor for the named service.
657 Raises:
658 KeyError: if the service cannot be found in the pool.
659 """
660 full_name = _NormalizeFullyQualifiedName(full_name)
661 if full_name not in self._service_descriptors:
662 self._FindFileContainingSymbolInDb(full_name)
663 return self._service_descriptors[full_name]
665 def FindMethodByName(self, full_name):
666 """Loads the named service method descriptor from the pool.
668 Args:
669 full_name (str): The full name of the method descriptor to load.
671 Returns:
672 MethodDescriptor: The method descriptor for the service method.
674 Raises:
675 KeyError: if the method cannot be found in the pool.
676 """
677 full_name = _NormalizeFullyQualifiedName(full_name)
678 service_name, _, method_name = full_name.rpartition('.')
679 service_descriptor = self.FindServiceByName(service_name)
680 return service_descriptor.methods_by_name[method_name]
682 def _FindFileContainingSymbolInDb(self, symbol):
683 """Finds the file in descriptor DB containing the specified symbol.
685 Args:
686 symbol (str): The name of the symbol to search for.
688 Returns:
689 FileDescriptor: The file that contains the specified symbol.
691 Raises:
692 KeyError: if the file cannot be found in the descriptor database.
693 """
694 try:
695 file_proto = self._internal_db.FindFileContainingSymbol(symbol)
696 except KeyError as error:
697 if self._descriptor_db:
698 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
699 else:
700 raise error
701 if not file_proto:
702 raise KeyError('Cannot find a file containing %s' % symbol)
703 return self._ConvertFileProtoToFileDescriptor(file_proto)
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool.

    The FileDescriptor is built only the first time its name is seen. On every
    call, cached or not, the extensions declared in the file (at file scope
    and, recursively, inside its messages) are passed to
    _AddExtensionDescriptor().

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.
    """
    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined elsewhere; presumably it yields the
      # built FileDescriptors of the (transitive) dependencies. Confirm.
      built_deps = list(self._GetDeps(file_proto.dependency))
      # Looking each dependency up by name builds it first if necessary.
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency entries are indexes into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps,
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key)
      # Maps '.'-prefixed full names to the message and enum descriptors that
      # type references in this file may resolve to.
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # First pass: create the message descriptors. Field types are resolved
      # later, once every type of this file is in scope.
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope, True))

      # File-scope extensions: link each one to the message it extends and
      # set its type.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)

      # Second pass: resolve field types now that the scope is complete.
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # Re-read each top-level message from scope by name.
      # NOTE(review): these appear to be the same objects stored in the first
      # pass. Confirm whether this step is still needed.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool
    def AddExtensionForNested(message_type):
      # Registers the extensions declared inside message_type, depth first.
      for nested in message_type.nested_types:
        AddExtensionForNested(nested)
      for extension in message_type.extensions:
        self._AddExtensionDescriptor(extension)

    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
      self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
      AddExtensionForNested(message_type)

    return file_desc
803 def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
804 scope=None, syntax=None):
805 """Adds the proto to the pool in the specified package.
807 Args:
808 desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
809 package: The package the proto should be located in.
810 file_desc: The file containing this message.
811 scope: Dict mapping short and full symbols to message and enum types.
812 syntax: string indicating syntax of the file ("proto2" or "proto3")
814 Returns:
815 The added descriptor.
816 """
818 if package:
819 desc_name = '.'.join((package, desc_proto.name))
820 else:
821 desc_name = desc_proto.name
823 if file_desc is None:
824 file_name = None
825 else:
826 file_name = file_desc.name
828 if scope is None:
829 scope = {}
831 nested = [
832 self._ConvertMessageDescriptor(
833 nested, desc_name, file_desc, scope, syntax)
834 for nested in desc_proto.nested_type]
835 enums = [
836 self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
837 scope, False)
838 for enum in desc_proto.enum_type]
839 fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
840 for index, field in enumerate(desc_proto.field)]
841 extensions = [
842 self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
843 is_extension=True)
844 for index, extension in enumerate(desc_proto.extension)]
845 oneofs = [
846 # pylint: disable=g-complex-comprehension
847 descriptor.OneofDescriptor(
848 desc.name,
849 '.'.join((desc_name, desc.name)),
850 index,
851 None,
852 [],
853 _OptionsOrNone(desc),
854 # pylint: disable=protected-access
855 create_key=descriptor._internal_create_key)
856 for index, desc in enumerate(desc_proto.oneof_decl)
857 ]
858 extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
859 if extension_ranges:
860 is_extendable = True
861 else:
862 is_extendable = False
863 desc = descriptor.Descriptor(
864 name=desc_proto.name,
865 full_name=desc_name,
866 filename=file_name,
867 containing_type=None,
868 fields=fields,
869 oneofs=oneofs,
870 nested_types=nested,
871 enum_types=enums,
872 extensions=extensions,
873 options=_OptionsOrNone(desc_proto),
874 is_extendable=is_extendable,
875 extension_ranges=extension_ranges,
876 file=file_desc,
877 serialized_start=None,
878 serialized_end=None,
879 syntax=syntax,
880 is_map_entry=desc_proto.options.map_entry,
881 # pylint: disable=protected-access
882 create_key=descriptor._internal_create_key)
883 for nested in desc.nested_types:
884 nested.containing_type = desc
885 for enum in desc.enum_types:
886 enum.containing_type = desc
887 for field_index, field_desc in enumerate(desc_proto.field):
888 if field_desc.HasField('oneof_index'):
889 oneof_index = field_desc.oneof_index
890 oneofs[oneof_index].fields.append(fields[field_index])
891 fields[field_index].containing_oneof = oneofs[oneof_index]
893 scope[_PrefixWithDot(desc_name)] = desc
894 self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
895 self._descriptors[desc_name] = desc
896 return desc
def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
                           containing_type=None, scope=None, top_level=False):
  """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.

  The new descriptor is stored in ``scope`` under its dot-prefixed full name
  (when a scope is supplied), checked for registration conflicts, and
  recorded in the pool's enum registry. For top-level enums each value is
  additionally registered under ``<package>.<value name>`` (or just
  ``<value name>`` when there is no package).

  Args:
    enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
    package: Optional package name for the new EnumDescriptor.
    file_desc: Optional file descriptor containing the enum descriptor.
    containing_type: The type containing this enum.
    scope: Optional dict of available types; the new descriptor is added to
      it when given.
    top_level: If True, the enum is a top level symbol. If False, the enum
      is defined inside a message.

  Returns:
    The added descriptor
  """
  if package:
    enum_name = '.'.join((package, enum_proto.name))
  else:
    enum_name = enum_proto.name

  file_name = None if file_desc is None else file_desc.name

  values = [self._MakeEnumValueDescriptor(value, index)
            for index, value in enumerate(enum_proto.value)]
  desc = descriptor.EnumDescriptor(name=enum_proto.name,
                                   full_name=enum_name,
                                   filename=file_name,
                                   file=file_desc,
                                   values=values,
                                   containing_type=containing_type,
                                   options=_OptionsOrNone(enum_proto),
                                   # pylint: disable=protected-access
                                   create_key=descriptor._internal_create_key)
  if scope is not None:
    scope['.%s' % enum_name] = desc
  # Use the precomputed file_name rather than desc.file.name: file_desc (passed
  # as ``file``) may be None, and file_name already handles that case.
  self._CheckConflictRegister(desc, desc.full_name, file_name)
  self._enum_descriptors[enum_name] = desc

  # Add top level enum values.
  if top_level:
    for value in values:
      # Joining with a None package would raise TypeError; with no package
      # the value name alone is the symbol.
      if package:
        value_path = '.'.join((package, value.name))
      else:
        value_path = value.name
      full_name = _NormalizeFullyQualifiedName(value_path)
      self._CheckConflictRegister(value, full_name, file_name)
      self._top_enum_values[full_name] = value

  return desc
def _MakeFieldDescriptor(self, field_proto, message_name, index,
                         file_desc, is_extension=False):
  """Build an unresolved FieldDescriptor from a FieldDescriptorProto.

  No type lookup happens here: ``cpp_type``, ``message_type``, ``enum_type``,
  ``containing_type`` and ``extension_scope`` are all left as None, and
  ``has_default_value``/``default_value`` start out False/None. Type and
  default information is filled in later by ``_SetFieldType``.

  Args:
    field_proto: The proto describing the field.
    message_name: Full name of the containing message; falsy for none.
    index: Index of the field.
    file_desc: The file containing the field descriptor.
    is_extension: Whether this field is an extension.

  Returns:
    An initialized FieldDescriptor object.
  """
  short_name = field_proto.name
  qualified_name = ('.'.join((message_name, short_name)) if message_name
                    else short_name)
  # An empty json_name means "not set"; the descriptor receives None then.
  explicit_json_name = field_proto.json_name or None

  return descriptor.FieldDescriptor(
      name=field_proto.name,
      full_name=qualified_name,
      index=index,
      number=field_proto.number,
      type=field_proto.type,
      cpp_type=None,
      message_type=None,
      enum_type=None,
      containing_type=None,
      label=field_proto.label,
      has_default_value=False,
      default_value=None,
      is_extension=is_extension,
      extension_scope=None,
      options=_OptionsOrNone(field_proto),
      json_name=explicit_json_name,
      file=file_desc,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _SetAllFieldTypes(self, package, desc_proto, scope):
  """Resolve the types of every field in a message, recursively.

  Fields and extensions of the message named by ``desc_proto`` are resolved
  via ``_SetFieldType``; each extension's ``containing_type`` is set to its
  resolved extendee. Nested message types are then processed the same way.

  Args:
    package: The current package of desc_proto.
    desc_proto: The message descriptor proto to update.
    scope: Enclosing scope of available types.
  """
  dotted_package = _PrefixWithDot(package)
  message_desc = self._GetTypeFromScope(dotted_package, desc_proto.name, scope)

  # Names referenced from inside this message are resolved relative to the
  # message itself, so its name becomes part of the lookup prefix.
  inner_package = (_PrefixWithDot(desc_proto.name) if dotted_package == '.'
                   else dotted_package + '.' + desc_proto.name)

  for f_proto, f_desc in zip(desc_proto.field, message_desc.fields):
    self._SetFieldType(f_proto, f_desc, inner_package, scope)

  for ext_proto, ext_desc in zip(desc_proto.extension,
                                 message_desc.extensions):
    ext_desc.containing_type = self._GetTypeFromScope(
        inner_package, ext_proto.extendee, scope)
    self._SetFieldType(ext_proto, ext_desc, inner_package, scope)

  for child_proto in desc_proto.nested_type:
    self._SetAllFieldTypes(inner_package, child_proto, scope)
def _SetFieldType(self, field_proto, field_desc, package, scope):
  """Resolve a field's type information onto its descriptor.

  Sets ``cpp_type``, ``message_type`` / ``enum_type`` (as applicable),
  ``has_default_value``, ``default_value`` and finally ``type`` on
  ``field_desc``.

  Args:
    field_proto: Data about the field in proto format. NOTE: when its
      ``type`` is unset it is filled in here (message if ``type_name``
      resolved to a message Descriptor, enum otherwise), mutating the proto.
    field_desc: The descriptor to modify.
    package: The package the field's container is in; passed to
      ``_GetTypeFromScope`` to resolve ``type_name``.
    scope: Enclosing scope of available types.
  """
  field_cls = descriptor.FieldDescriptor

  resolved = (self._GetTypeFromScope(package, field_proto.type_name, scope)
              if field_proto.type_name else None)

  # A proto with only a type_name gets its type inferred from what the name
  # resolved to; anything that is not a message Descriptor is taken as enum.
  if not field_proto.HasField('type'):
    field_proto.type = (field_cls.TYPE_MESSAGE
                        if isinstance(resolved, descriptor.Descriptor)
                        else field_cls.TYPE_ENUM)

  proto_type = field_proto.type
  field_desc.cpp_type = field_cls.ProtoTypeToCppProtoType(proto_type)

  if proto_type in (field_cls.TYPE_MESSAGE, field_cls.TYPE_GROUP):
    field_desc.message_type = resolved
  if proto_type == field_cls.TYPE_ENUM:
    field_desc.enum_type = resolved

  if field_proto.label == field_cls.LABEL_REPEATED:
    field_desc.has_default_value = False
    field_desc.default_value = []
  elif field_proto.HasField('default_value'):
    field_desc.has_default_value = True
    text = field_proto.default_value
    if proto_type in (field_cls.TYPE_DOUBLE, field_cls.TYPE_FLOAT):
      parsed = float(text)
    elif proto_type == field_cls.TYPE_STRING:
      parsed = text
    elif proto_type == field_cls.TYPE_BOOL:
      parsed = text.lower() == 'true'
    elif proto_type == field_cls.TYPE_ENUM:
      parsed = field_desc.enum_type.values_by_name[text].number
    elif proto_type == field_cls.TYPE_BYTES:
      parsed = text_encoding.CUnescape(text)
    elif proto_type == field_cls.TYPE_MESSAGE:
      parsed = None
    else:
      # Every remaining type is parsed as an integer.
      parsed = int(text)
    field_desc.default_value = parsed
  else:
    field_desc.has_default_value = False
    if proto_type == field_cls.TYPE_ENUM:
      # Without an explicit default, the first declared enum value is used.
      field_desc.default_value = field_desc.enum_type.values[0].number
    else:
      zero_values = {
          field_cls.TYPE_DOUBLE: 0.0,
          field_cls.TYPE_FLOAT: 0.0,
          field_cls.TYPE_STRING: u'',
          field_cls.TYPE_BOOL: False,
          field_cls.TYPE_BYTES: b'',
          field_cls.TYPE_MESSAGE: None,
          field_cls.TYPE_GROUP: None,
      }
      # Every remaining type defaults to the integer 0.
      field_desc.default_value = zero_values.get(proto_type, 0)

  field_desc.type = proto_type
def _MakeEnumValueDescriptor(self, value_proto, index):
  """Build an EnumValueDescriptor from an enum value proto.

  The descriptor's ``type`` back-reference is left as None here.

  Args:
    value_proto: The proto describing the enum value.
    index: Position of the value within its enum.

  Returns:
    An initialized EnumValueDescriptor object.
  """
  value_options = _OptionsOrNone(value_proto)
  return descriptor.EnumValueDescriptor(
      name=value_proto.name,
      index=index,
      number=value_proto.number,
      options=value_options,
      type=None,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _MakeServiceDescriptor(self, service_proto, service_index, scope,
                           package, file_desc):
  """Build and register a ServiceDescriptor from a ServiceDescriptorProto.

  Each method is built with ``_MakeMethodDescriptor``. The resulting service
  is checked for registration conflicts and recorded in the pool's service
  registry under its full name.

  Args:
    service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
    service_index: The index of the service in the File.
    scope: Dict mapping short and full symbols to message and enum types;
      used to resolve method input/output types.
    package: Optional package name used to qualify the service name and to
      resolve method types.
    file_desc: The file containing the service descriptor.

  Returns:
    The added descriptor.
  """
  short_name = service_proto.name
  qualified_name = '.'.join((package, short_name)) if package else short_name

  method_descs = [
      self._MakeMethodDescriptor(m_proto, qualified_name, package, scope, pos)
      for pos, m_proto in enumerate(service_proto.method)
  ]
  service_desc = descriptor.ServiceDescriptor(
      name=service_proto.name,
      full_name=qualified_name,
      index=service_index,
      methods=method_descs,
      options=_OptionsOrNone(service_proto),
      file=file_desc,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
  self._CheckConflictRegister(service_desc, service_desc.full_name,
                              service_desc.file.name)
  self._service_descriptors[qualified_name] = service_desc
  return service_desc
def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
                          index):
  """Build a MethodDescriptor from a MethodDescriptorProto.

  The input and output type names are resolved via ``_GetTypeFromScope``;
  ``containing_service`` is left as None here.

  Args:
    method_proto: The proto describing the method.
    service_name: The full name of the containing service.
    package: Optional package name to look up for types.
    scope: Scope containing available types.
    index: Index of the method in the service.

  Returns:
    An initialized MethodDescriptor object.
  """
  qualified_name = service_name + '.' + method_proto.name
  request_type, response_type = (
      self._GetTypeFromScope(package, type_name, scope)
      for type_name in (method_proto.input_type, method_proto.output_type))
  return descriptor.MethodDescriptor(
      name=method_proto.name,
      full_name=qualified_name,
      index=index,
      containing_service=None,
      input_type=request_type,
      output_type=response_type,
      client_streaming=method_proto.client_streaming,
      server_streaming=method_proto.server_streaming,
      options=_OptionsOrNone(method_proto),
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)
def _ExtractSymbols(self, descriptors):
  """Yield (dot-prefixed full name, descriptor) pairs for message symbols.

  For each message descriptor, the message itself is yielded first, then
  (recursively) its nested types, then the enums declared directly in it.

  Args:
    descriptors: The message descriptors to extract symbols from.

  Yields:
    A two element tuple of the type name and descriptor object.
  """
  for message_desc in descriptors:
    yield (_PrefixWithDot(message_desc.full_name), message_desc)
    yield from self._ExtractSymbols(message_desc.nested_types)
    for enum_desc in message_desc.enum_types:
      yield (_PrefixWithDot(enum_desc.full_name), enum_desc)
def _GetDeps(self, dependencies, visited=None):
  """Find the files that dependencies make available, following public imports.

  Each named dependency is yielded once. Beyond the direct dependencies, only
  their *public* dependencies are followed, recursively; a dependency's
  non-public imports are not yielded.

  Args:
    dependencies: The names of the files being depended on.
    visited: Optional set of file names already found. It is updated in place
      (including when an empty set is passed), so callers can share it
      across calls.

  Yields:
    The file descriptor of each direct dependency and of each file reachable
    from them through public dependencies.
  """
  # Compare against None, not truthiness: ``visited or set()`` would silently
  # replace a caller-supplied empty set, leaving the caller's set untouched.
  if visited is None:
    visited = set()
  for dependency in dependencies:
    if dependency in visited:
      continue
    visited.add(dependency)
    dep_desc = self.FindFileByName(dependency)
    yield dep_desc
    public_files = [d.name for d in dep_desc.public_dependencies]
    yield from self._GetDeps(public_files, visited)
def _GetTypeFromScope(self, package, type_name, scope):
  """Look up a type name in scope, trying enclosing packages innermost first.

  If ``type_name`` is not itself a key of ``scope``, it is qualified by
  successively shorter prefixes of the dot-prefixed ``package`` (e.g. for
  package ``a.b``: ``.a.b.T``, then ``.a.T``, then ``.T``) until a key
  matches.

  Args:
    package: The package the proto should be located in.
    type_name: The name of the type to be found in the scope.
    scope: Dict mapping short and full symbols to message and enum types.

  Returns:
    The descriptor for the requested type.

  Raises:
    KeyError: If no candidate name is present in ``scope``.
  """
  if type_name in scope:
    return scope[type_name]
  prefix_parts = _PrefixWithDot(package).split('.')
  for depth in range(len(prefix_parts), 0, -1):
    candidate = '.'.join(prefix_parts[:depth] + [type_name])
    if candidate in scope:
      return scope[candidate]
  return scope[type_name]
def _PrefixWithDot(name):
  """Return ``name`` with exactly one leading '.' added if it lacks one."""
  if name.startswith('.'):
    return name
  return '.' + name
# Module-wide default pool: taken from descriptor._message when C descriptors
# are in use, otherwise a fresh DescriptorPool created at import time.
# TODO: This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
_DEFAULT = (
    descriptor._message.default_pool  # pylint: disable=protected-access
    if _USE_C_DESCRIPTORS
    else DescriptorPool())
def Default():
  """Return the module-level default pool object (``_DEFAULT``)."""
  default_pool = _DEFAULT
  return default_pool