1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
"""Provides DescriptorDatabase, an in-memory index of FileDescriptorProtos.

Registered files can be looked up by file name, or by the fully qualified
name of a symbol they define.
"""

__author__ = 'matthewtoia@google.com (Matt Toia)'
11
12from typing import Dict, Iterator, Optional
13import warnings
14
15
class Error(Exception):
  """Base class of the exceptions defined in this module."""
18
19
class DescriptorDatabaseConflictingDefinitionError(Error):
  """A file name was re-added with a definition differing from the stored one."""
22
23
class DescriptorDatabase(object):
  """A container accepting FileDescriptorProtos and maps DescriptorProtos.

  Each added file is indexed by its ``name``. The following are also indexed
  by fully qualified symbol name, so that the defining file can be found from
  a symbol:

  * top-level messages, their nested messages, and enums declared in them;
  * top-level enums and their values;
  * top-level extensions;
  * services.
  """

  def __init__(self) -> None:
    # File name -> the FileDescriptorProto registered under that name.
    self._file_desc_protos_by_file: Dict[
        str, 'descriptor_pb2.FileDescriptorProto'
    ] = {}
    # Fully qualified symbol name -> the FileDescriptorProto defining it.
    self._file_desc_protos_by_symbol: Dict[
        str, 'descriptor_pb2.FileDescriptorProto'
    ] = {}

  def Add(self, file_desc_proto: 'descriptor_pb2.FileDescriptorProto') -> None:
    """Adds the FileDescriptorProto and its types to this database.

    Re-adding a proto that compares equal to the one already stored under the
    same file name is a no-op (its symbols are not re-indexed).

    Args:
      file_desc_proto: The FileDescriptorProto to add.

    Raises:
      DescriptorDatabaseConflictingDefinitionError: if an attempt is made to
        add a proto with the same name but different definition than an
        existing proto in the database.
    """
    proto_name = file_desc_proto.name
    if proto_name not in self._file_desc_protos_by_file:
      self._file_desc_protos_by_file[proto_name] = file_desc_proto
    elif self._file_desc_protos_by_file[proto_name] != file_desc_proto:
      raise DescriptorDatabaseConflictingDefinitionError(
          '%s already added, but with different descriptor.' % proto_name)
    else:
      return

    package = file_desc_proto.package

    def qualify(name: str) -> str:
      # Symbols in the default (empty) package get no leading '.'.
      return '.'.join((package, name)) if package else name

    # Add all the top-level descriptors to the index.
    for message in file_desc_proto.message_type:
      for name in _ExtractSymbols(message, package):
        self._AddSymbol(name, file_desc_proto)
    for enum in file_desc_proto.enum_type:
      self._AddSymbol(qualify(enum.name), file_desc_proto)
      # Enum values are keyed directly under the package, not under the enum
      # name. FindFileContainingSymbol's parent fallback would strip them to
      # the bare package, so they must be indexed explicitly. Unlike other
      # symbols, they bypass _AddSymbol and overwrite without a warning.
      for enum_value in enum.value:
        self._file_desc_protos_by_symbol[qualify(enum_value.name)] = (
            file_desc_proto)
    for extension in file_desc_proto.extension:
      self._AddSymbol(qualify(extension.name), file_desc_proto)
    for service in file_desc_proto.service:
      self._AddSymbol(qualify(service.name), file_desc_proto)

  def FindFileByName(self, name: str) -> 'descriptor_pb2.FileDescriptorProto':
    """Finds the file descriptor proto by file name.

    Typically the file name is a relative path ending to a .proto file. The
    proto with the given name will have to have been added to this database
    using the Add method or else an error will be raised.

    Args:
      name: The file name to find.

    Returns:
      The file descriptor proto matching the name.

    Raises:
      KeyError: if no file by the given name was added.
    """
    return self._file_desc_protos_by_file[name]

  def FindFileContainingSymbol(
      self, symbol: str
  ) -> 'descriptor_pb2.FileDescriptorProto':
    """Finds the file descriptor proto containing the specified symbol.

    The symbol should be a fully qualified name including the file descriptor's
    package and any containing messages. Some examples:

      'some.package.name.Message'
      'some.package.name.Message.NestedEnum'
      'some.package.name.Message.some_field'

    If the symbol itself is not indexed, the file defining its immediate parent
    (everything before the last '.') is returned instead. Fields, nested
    extensions, values of nested enums, and service methods are found this
    way. As in protobuf C++, a non-existent member of an indexed parent also
    resolves to the parent's file.

    A single leading '.' on a symbol that contains no other '.' (e.g.
    '.Message') is stripped, and a RuntimeWarning is issued; that form is
    deprecated.

    The file descriptor proto containing the specified symbol must be added to
    this database using the Add method or else an error will be raised.

    Args:
      symbol: The fully qualified symbol name.

    Returns:
      The file descriptor proto containing the symbol.

    Raises:
      KeyError: if no file contains the specified symbol. The exception's
        argument is the requested symbol (after any leading-'.' stripping).
    """
    if symbol.count('.') == 1 and symbol[0] == '.':
      symbol = symbol.lstrip('.')
      warnings.warn(
          'Please remove the leading "." when '
          'FindFileContainingSymbol, this will turn to error '
          'in 2026 Jan.',
          RuntimeWarning,
      )
    try:
      return self._file_desc_protos_by_symbol[symbol]
    except KeyError:
      pass

    # Fields, enum values of nested enums, nested extensions, and service
    # methods are not in _file_desc_protos_by_symbol, so try the enclosing
    # descriptor. A non-existent nested symbol under a valid parent is also
    # found this way; the behavior is the same as protobuf C++.
    top_level, _, _ = symbol.rpartition('.')
    try:
      return self._file_desc_protos_by_symbol[top_level]
    except KeyError:
      # Report the originally requested symbol, and suppress the implicit
      # chaining to the internal lookup failures, which only add noise.
      raise KeyError(symbol) from None

  def FindFileContainingExtension(
      self, extendee_name: str, extension_number: int  # pylint: disable=unused-argument
  ) -> Optional['descriptor_pb2.FileDescriptorProto']:
    """Not implemented: always returns None, regardless of the arguments."""
    # TODO: implement this API.
    return None

  def FindAllExtensionNumbers(self, extendee_name: str) -> list[int]:  # pylint: disable=unused-argument
    """Not implemented: always returns an empty list."""
    # TODO: implement this API.
    return []

  def _AddSymbol(
      self, name: str, file_desc_proto: 'descriptor_pb2.FileDescriptorProto'
  ) -> None:
    """Maps `name` to `file_desc_proto` in the symbol index.

    If `name` is already mapped, a RuntimeWarning naming both files is issued,
    and the new file replaces the previous entry.
    """
    if name in self._file_desc_protos_by_symbol:
      warn_msg = ('Conflict register for file "' + file_desc_proto.name +
                  '": ' + name +
                  ' is already defined in file "' +
                  self._file_desc_protos_by_symbol[name].name + '"')
      warnings.warn(warn_msg, RuntimeWarning)
    self._file_desc_protos_by_symbol[name] = file_desc_proto
165
166
def _ExtractSymbols(
    desc_proto: 'descriptor_pb2.DescriptorProto', package: str
) -> Iterator[str]:
  """Yields the fully qualified symbol names a message descriptor declares.

  The message's own name comes first. Next come, depth first, each nested
  message together with all of its own symbols. Last come the enums declared
  directly in the message. Fields, enum values, and nested extensions are not
  yielded.

  Args:
    desc_proto: The DescriptorProto whose symbols are wanted.
    package: The enclosing scope (a package, or the full name of an enclosing
      message); empty for the default package.

  Yields:
    Each fully qualified symbol name.
  """
  if package:
    scope = package + '.' + desc_proto.name
  else:
    scope = desc_proto.name
  yield scope
  for child in desc_proto.nested_type:
    # The nested message's scope is this message's full name.
    yield from _ExtractSymbols(child, scope)
  yield from (
      '.'.join((scope, enum_desc.name)) for enum_desc in desc_proto.enum_type
  )