1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Provides a factory class for generating dynamic messages.
9
The easiest way to use this module, if you have the FileDescriptorProto
messages describing the types you want to create, is the following:

message_classes = message_factory.GetMessages(iterable_of_file_descriptor_protos)
my_proto_instance = message_classes['some.proto.package.MessageName']()
15"""
16
17__author__ = 'matthewtoia@google.com (Matt Toia)'
18
19import warnings
20
21from google.protobuf import descriptor_pool
22from google.protobuf import message
23from google.protobuf.internal import api_implementation
24
25if api_implementation.Type() == 'python':
26 from google.protobuf.internal import python_message as message_impl
27else:
28 from google.protobuf.pyext import cpp_message as message_impl # pylint: disable=g-import-not-at-top
29
30
# Metaclass (the "type") of every message class this module builds; it is
# called as (name, bases, namespace) in _InternalCreateMessageClass. It comes
# from python_message when api_implementation.Type() is 'python', and from
# pyext.cpp_message otherwise.
_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType
33
34
def GetMessageClass(descriptor):
  """Returns the proto2 message class for the given descriptor.

  If the descriptor already carries a class in its ``_concrete_class``
  attribute, that class is returned unchanged. Otherwise a new class is
  built with _InternalCreateMessageClass(). Passing a descriptor with a fully
  qualified name matching a previous invocation will cause the same class to
  be returned.

  Args:
    descriptor: The message descriptor to obtain a class for.

  Returns:
    The message class describing ``descriptor``.
  """
  existing = getattr(descriptor, '_concrete_class', None)
  if not existing:
    # Nothing usable is attached to the descriptor yet; build a class.
    return _InternalCreateMessageClass(descriptor)
  return existing
51
52
def GetMessageClassesForFiles(files, pool):
  """Gets all the messages from specified files.

  This will find and resolve dependencies, failing if the descriptor
  pool cannot satisfy them.

  Args:
    files: The file names to extract messages from.
    pool: The descriptor pool to find the files including the dependent files.

  Returns:
    A dictionary mapping each message's full name to its message class. Only
    the messages listed in each file's ``message_types_by_name`` become keys;
    classes built as a side effect (for extensions' containing types and
    message-typed extension values) are not added to the result.

  Raises:
    ValueError: If, on a non-pure-Python implementation, the pool resolves an
      extension's (containing type, number) to a different extension
      descriptor than the one declared in the file.
  """
  # The implementation type does not change while we iterate, so look it up
  # once instead of once per extension.
  check_duplicate_extensions = api_implementation.Type() != 'python'
  result = {}
  for file_name in files:
    file_desc = pool.FindFileByName(file_name)
    for desc in file_desc.message_types_by_name.values():
      result[desc.full_name] = GetMessageClass(desc)

    # The extension FieldDescriptors already exist in the descriptor pool.
    # What is done here is to make sure the message class each extension
    # extends has been built, and (outside the pure-Python implementation)
    # that the pool maps the extension's number back to this same descriptor.
    for extension in file_desc.extensions_by_name.values():
      GetMessageClass(extension.containing_type)
      if check_duplicate_extensions:
        # TODO: Remove this check here. Duplicate extension
        # register check should be in descriptor_pool.
        # Any other object registered under this number is treated as a
        # double registration.
        if extension is not pool.FindExtensionByNumber(
            extension.containing_type, extension.number
        ):
          raise ValueError('Double registration of Extensions')
      # Recursively load protos for extension field, in order to be able to
      # fully represent the extension. This matches the behavior for regular
      # fields too.
      if extension.message_type:
        GetMessageClass(extension.message_type)
  return result
96
97
def _InternalCreateMessageClass(descriptor):
  """Builds a proto2 message class based on the passed in descriptor.

  This always creates a new class; callers normally go through
  GetMessageClass(), which reuses a class already attached to the descriptor.
  Classes for the message types of this message's fields, and for each
  extension in ``DESCRIPTOR.extensions`` (its containing type and, if
  message-typed, its value type), are obtained as well.

  Args:
    descriptor: The descriptor to build from.

  Returns:
    A class describing the passed in descriptor.

  Raises:
    ValueError: If, on a non-pure-Python implementation, the descriptor pool
      resolves one of the extensions' (containing type, number) to a
      different extension descriptor.
  """
  result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
      descriptor.name,
      (message.Message,),
      {
          'DESCRIPTOR': descriptor,
          # If module not set, it wrongly points to message_factory module.
          '__module__': None,
      },
  )
  # Make sure classes exist for every message-typed field.
  for field in descriptor.fields:
    if field.message_type:
      GetMessageClass(field.message_type)

  # Loop-invariant: evaluate once rather than once per extension.
  check_duplicate_extensions = api_implementation.Type() != 'python'
  for extension in result_class.DESCRIPTOR.extensions:
    # Only the side effect matters here (the extended class gets built);
    # the returned class itself is not needed.
    GetMessageClass(extension.containing_type)
    if check_duplicate_extensions:
      # TODO: Remove this check here. Duplicate extension
      # register check should be in descriptor_pool.
      pool = extension.containing_type.file.pool
      if extension is not pool.FindExtensionByNumber(
          extension.containing_type, extension.number
      ):
        raise ValueError('Double registration of Extensions')
    if extension.message_type:
      GetMessageClass(extension.message_type)
  return result_class
134
135
# Deprecated: use the module-level GetMessageClass() or
# GetMessageClassesForFiles() functions instead.
class MessageFactory(object):
  """Deprecated factory that builds message classes from a descriptor pool.

  Each method emits a warning and then forwards to a module-level function.
  """

  def __init__(self, pool=None):
    """Initializes a new factory.

    Args:
      pool: Descriptor pool consulted by GetMessages(). A falsy value (such as
        the default None) is replaced by a new descriptor_pool.DescriptorPool.
    """
    self.pool = pool if pool else descriptor_pool.DescriptorPool()

  def GetPrototype(self, descriptor):
    """Deprecated; warns and then delegates to GetMessageClass().

    Passing a descriptor with a fully qualified name matching a previous
    invocation will cause the same class to be returned.

    Args:
      descriptor: The descriptor to build from.

    Returns:
      A class describing the passed in descriptor.
    """
    deprecation_message = (
        'MessageFactory class is deprecated. Please use '
        'GetMessageClass() instead of MessageFactory.GetPrototype. '
        'MessageFactory class will be removed after 2024.'
    )
    warnings.warn(deprecation_message, stacklevel=2)
    return GetMessageClass(descriptor)

  def CreatePrototype(self, descriptor):
    """Deprecated; warns and then always builds a brand-new message class.

    Don't call this function directly, it always creates a new class. Call
    GetMessageClass() instead.

    Args:
      descriptor: The descriptor to build from.

    Returns:
      A class describing the passed in descriptor.
    """
    misuse_message = (
        'Directly call CreatePrototype is wrong. Please use '
        'GetMessageClass() method instead. Directly use '
        'CreatePrototype will raise error after July 2023.'
    )
    warnings.warn(misuse_message, stacklevel=2)
    return _InternalCreateMessageClass(descriptor)

  def GetMessages(self, files):
    """Deprecated; warns and delegates to GetMessageClassesForFiles().

    The named files are looked up in this factory's pool; dependency
    resolution fails if the pool cannot satisfy it.

    Args:
      files: The file names to extract messages from.

    Returns:
      The dictionary of message full names to message classes produced by
      GetMessageClassesForFiles(files, self.pool).
    """
    deprecation_message = (
        'MessageFactory class is deprecated. Please use '
        'GetMessageClassesForFiles() instead of '
        'MessageFactory.GetMessages(). MessageFactory class '
        'will be removed after 2024.'
    )
    warnings.warn(deprecation_message, stacklevel=2)
    return GetMessageClassesForFiles(files, self.pool)
207
208
def GetMessages(file_protos, pool=None):
  """Builds a dictionary of all the messages available in a set of files.

  Args:
    file_protos: Iterable of FileDescriptorProto to build messages out of.
      Any iterable is accepted, including a one-shot generator; it is read
      exactly once.
    pool: The descriptor pool to add the file protos. If falsy (the default
      None), a new DescriptorPool is created.

  Returns:
    The dictionary of message full names to message classes produced by
    GetMessageClassesForFiles() for the names of ``file_protos``.
  """
  # The input is needed twice: once to add the files to the pool and once to
  # list their names afterwards. Materialize it so a generator is not
  # exhausted by the first pass (which would silently yield an empty result).
  file_protos = list(file_protos)
  des_pool = pool or descriptor_pool.DescriptorPool()
  file_by_name = {file_proto.name: file_proto for file_proto in file_protos}

  def _AddFile(file_proto):
    """Adds file_proto after any of its dependencies still pending."""
    for dependency in file_proto.dependency:
      if dependency in file_by_name:
        # Remove from elements to be visited, in order to cut cycles.
        _AddFile(file_by_name.pop(dependency))
    des_pool.Add(file_proto)

  # The cpp implementation of the protocol buffer library requires files to
  # be added in topological order of the dependency graph.
  while file_by_name:
    _AddFile(file_by_name.popitem()[1])
  return GetMessageClassesForFiles(
      [file_proto.name for file_proto in file_protos], des_pool
  )