1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
"""Provides a factory for generating dynamic message classes.

The easiest way to use this module is when you have the FileDescriptorProto
messages describing the types you want to create. You can then do the
following:

message_classes = message_factory.GetMessages(iterable_of_file_descriptor_protos)
my_proto_instance = message_classes['some.proto.package.MessageName']()
"""
16
17__author__ = 'matthewtoia@google.com (Matt Toia)'
18
19import warnings
20
21from google.protobuf.internal import api_implementation
22from google.protobuf import descriptor_pool
23from google.protobuf import message
24
25if api_implementation.Type() == 'python':
26 from google.protobuf.internal import python_message as message_impl
27else:
28 from google.protobuf.pyext import cpp_message as message_impl # pylint: disable=g-import-not-at-top
29
30
# The type of all Message classes. It is taken from whichever backend module
# was imported above as message_impl (the pure-python or the C++ extension
# implementation, selected by api_implementation.Type()), and is called with
# (name, bases, namespace) to build each dynamic message class.
_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType
33
34
def GetMessageClass(descriptor):
  """Returns the proto2 message class for the given descriptor.

  If a class is already attached to the descriptor (its ``_concrete_class``
  attribute is set and truthy), that class is returned unchanged; otherwise a
  new class is built from the descriptor.

  Args:
    descriptor: The descriptor to build from.

  Returns:
    A class describing the passed in descriptor.
  """
  existing = getattr(descriptor, '_concrete_class', None)
  return existing if existing else _InternalCreateMessageClass(descriptor)
51
52
def GetMessageClassesForFiles(files, pool):
  """Gets all the messages from specified files.

  This will find and resolve dependencies, failing if the descriptor
  pool cannot satisfy them.

  Args:
    files: The file names to extract messages from.
    pool: The descriptor pool to find the files including the dependent
      files.

  Returns:
    A dictionary mapping the full name of each message type listed in the
    given files' ``message_types_by_name`` to its message class. Classes
    needed by extensions are built as a side effect but are only included
    when they are themselves declared in one of ``files``.

  Raises:
    ValueError: With a non-python api implementation, if an extension
      declared in one of the files is not the extension that ``pool``
      resolves for the same containing type and field number.
  """
  result = {}
  for file_name in files:
    file_desc = pool.FindFileByName(file_name)
    for desc in file_desc.message_types_by_name.values():
      result[desc.full_name] = GetMessageClass(desc)

    # While the extension FieldDescriptors are created by the descriptor pool,
    # python classes still have to be built for the messages they extend and
    # for message-typed extension values, which is done below.
    #
    # With a non-python implementation, each extension is also checked against
    # the pool: if the pool resolves a different extension for the same
    # containing type and number, a ValueError is raised.
    for extension in file_desc.extensions_by_name.values():
      # Called only so that a class exists for the extended message; the
      # returned class itself is not needed here.
      GetMessageClass(extension.containing_type)
      if api_implementation.Type() != 'python':
        # TODO: Remove this check here. Duplicate extension
        # register check should be in descriptor_pool.
        if extension is not pool.FindExtensionByNumber(
            extension.containing_type, extension.number
        ):
          raise ValueError('Double registration of Extensions')
      # Recursively load protos for extension field, in order to be able to
      # fully represent the extension. This matches the behavior for regular
      # fields too.
      if extension.message_type:
        GetMessageClass(extension.message_type)
  return result
97
98
def _InternalCreateMessageClass(descriptor):
  """Builds a proto2 message class based on the passed in descriptor.

  This always creates a new class. Callers should normally use
  GetMessageClass(), which returns the class already attached to the
  descriptor when there is one.

  Besides creating the class, this obtains (via GetMessageClass) the classes
  of every message-typed field, of each extension's containing type, and of
  each message-typed extension declared in the descriptor.

  Args:
    descriptor: The descriptor to build from.

  Returns:
    A class describing the passed in descriptor.

  Raises:
    ValueError: With a non-python api implementation, if an extension declared
      in this descriptor is not the extension that its containing type's pool
      resolves for the same field number.
  """
  result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
      descriptor.name,
      (message.Message,),
      {
          'DESCRIPTOR': descriptor,
          # If module not set, it wrongly points to message_factory module.
          '__module__': None,
      })
  # Make sure classes exist for the message types referenced by fields.
  for field in descriptor.fields:
    if field.message_type:
      GetMessageClass(field.message_type)
  for extension in result_class.DESCRIPTOR.extensions:
    # Called only so that a class exists for the extended message; the
    # returned class itself is not needed here.
    GetMessageClass(extension.containing_type)
    if api_implementation.Type() != 'python':
      # TODO: Remove this check here. Duplicate extension
      # register check should be in descriptor_pool.
      pool = extension.containing_type.file.pool
      if extension is not pool.FindExtensionByNumber(
          extension.containing_type, extension.number
      ):
        raise ValueError('Double registration of Extensions')
    if extension.message_type:
      GetMessageClass(extension.message_type)
  return result_class
133
134
# Deprecated. Please use the GetMessageClass() or GetMessageClassesForFiles()
# functions above instead.
class MessageFactory(object):
  """Factory for creating Proto2 messages from descriptors in a pool.

  Deprecated: each method emits a warning and forwards to the module-level
  GetMessageClass(), _InternalCreateMessageClass() or
  GetMessageClassesForFiles() functions.
  """

  def __init__(self, pool=None):
    """Initializes a new factory.

    Args:
      pool: The descriptor pool that GetMessages() looks files up in. When it
        is omitted (or falsy) a new DescriptorPool is created.
    """
    self.pool = pool or descriptor_pool.DescriptorPool()

  def GetPrototype(self, descriptor):
    """Obtains a proto2 message class based on the passed in descriptor.

    Deprecated: emits a warning and returns GetMessageClass(descriptor), which
    reuses the class already attached to the descriptor when there is one.

    Args:
      descriptor: The descriptor to build from.

    Returns:
      A class describing the passed in descriptor.
    """
    warnings.warn(
        'MessageFactory class is deprecated. Please use '
        'GetMessageClass() instead of MessageFactory.GetPrototype. '
        'MessageFactory class will be removed after 2024.',
        stacklevel=2,
    )
    return GetMessageClass(descriptor)

  def CreatePrototype(self, descriptor):
    """Builds a proto2 message class based on the passed in descriptor.

    Don't call this function directly, it always creates a new class. Call
    GetMessageClass() instead.

    Args:
      descriptor: The descriptor to build from.

    Returns:
      A class describing the passed in descriptor.
    """
    warnings.warn(
        'Directly call CreatePrototype is wrong. Please use '
        'GetMessageClass() method instead. Directly use '
        'CreatePrototype will raise error after July 2023.',
        stacklevel=2,
    )
    return _InternalCreateMessageClass(descriptor)

  def GetMessages(self, files):
    """Gets all the messages from the specified files.

    This will find and resolve dependencies, failing if the descriptor
    pool cannot satisfy them.

    Deprecated: emits a warning and returns
    GetMessageClassesForFiles(files, self.pool).

    Args:
      files: The file names to extract messages from.

    Returns:
      A dictionary mapping the full names of the messages declared in the
      given files to their message classes. Messages declared in dependent
      files are not included unless those files are listed in ``files`` too.
    """
    warnings.warn(
        'MessageFactory class is deprecated. Please use '
        'GetMessageClassesForFiles() instead of '
        'MessageFactory.GetMessages(). MessageFactory class '
        'will be removed after 2024.',
        stacklevel=2,
    )
    return GetMessageClassesForFiles(files, self.pool)
206
207
def GetMessages(file_protos, pool=None):
  """Builds a dictionary of all the messages available in a set of files.

  Args:
    file_protos: Iterable of FileDescriptorProto to build messages out of. Any
      iterable is accepted, including one-shot iterators such as generators.
    pool: The descriptor pool to add the file protos to. When it is omitted
      (or falsy) a new DescriptorPool is created.

  Returns:
    A dictionary mapping the full names of the messages declared in the given
    file protos to their message classes. Messages declared in files that are
    not among ``file_protos`` are not included.
  """
  # The protos are traversed twice below (once to index them by name, once to
  # collect the names to look up), so materialize them first. Otherwise a
  # generator would be exhausted by the first pass and nothing returned.
  file_protos = list(file_protos)
  # The cpp implementation of the protocol buffer library requires files to be
  # added in topological order of the dependency graph.
  des_pool = pool or descriptor_pool.DescriptorPool()
  file_by_name = {file_proto.name: file_proto for file_proto in file_protos}

  def _AddFile(file_proto):
    """Adds file_proto to des_pool after any of its not-yet-added deps."""
    for dependency in file_proto.dependency:
      if dependency in file_by_name:
        # Remove from elements to be visited, in order to cut cycles.
        _AddFile(file_by_name.pop(dependency))
    des_pool.Add(file_proto)

  while file_by_name:
    _AddFile(file_by_name.popitem()[1])
  return GetMessageClassesForFiles(
      [file_proto.name for file_proto in file_protos], des_pool)