1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16import inspect
17import logging
18
19from google.protobuf import descriptor_pb2
20from google.protobuf import descriptor_pool
21from google.protobuf import message
22from google.protobuf import reflection
23
24from proto.marshal.rules.message import MessageRule
25
# Logger used by _FileInfo. It is named after the helper class rather than the
# module path, so logging configuration must target the "_FileInfo" logger.
log = logging.getLogger(name="_FileInfo")
27
28
class _FileInfo(
    collections.namedtuple(
        "_FileInfo",
        ["descriptor", "messages", "enums", "name", "nested", "nested_enum"],
    )
):
    """In-progress state for one proto file assembled from proto-plus types.

    Attributes:
        descriptor: The ``FileDescriptorProto`` being assembled. Its ``name``
            is rewritten (salted) by :meth:`generate_file_pb`.
        messages: Ordered mapping of full proto name -> proto-plus message.
        enums: Ordered mapping of full proto name -> proto-plus enum.
        name: The key under which this object is stored in ``registry``.
        nested: Nested message descriptors that have not yet been attached
            to the descriptor that should contain them.
        nested_enum: Nested enum descriptors that have not yet been attached
            to the descriptor that should contain them.
    """

    # Mapping[str, '_FileInfo']: filename -> file info still being built.
    # This lives on the class, so it is shared by every _FileInfo.
    registry = {}

    @classmethod
    def maybe_add_descriptor(cls, filename, package):
        """Return the registered file info for ``filename``, creating it if needed.

        Args:
            filename (str): The proto filename, also used as the registry key.
            package (str): The proto package for a newly created descriptor.

        Returns:
            _FileInfo: The existing registry entry. If there is none, a new
            entry is created and registered: proto3 syntax, empty
            message/enum tables and empty nested tables.
        """
        descriptor = cls.registry.get(filename)
        if descriptor is None:
            descriptor = cls.registry[filename] = cls(
                descriptor=descriptor_pb2.FileDescriptorProto(
                    name=filename,
                    package=package,
                    syntax="proto3",
                ),
                enums=collections.OrderedDict(),
                messages=collections.OrderedDict(),
                name=filename,
                nested={},
                nested_enum={},
            )

        return descriptor

    @staticmethod
    def proto_file_name(name):
        """Convert a dotted module path into a ``.proto`` file path.

        For example, ``"a.b.c"`` becomes ``"a/b/c.proto"``.
        """
        return "{0}.proto".format(name.replace(".", "/"))

    @staticmethod
    def _salted_file_name(filename, salt):
        """Insert ``salt`` before the ``.proto`` suffix of ``filename``.

        ``"a/b.proto"`` with salt ``"x"`` becomes ``"a/b_x.proto"``. An empty
        salt leaves the base name untouched, including any trailing
        underscores it legitimately ends with.
        """
        suffix = ".proto"
        base = filename[: -len(suffix)] if filename.endswith(suffix) else filename
        if salt:
            base = "{base}_{salt}".format(base=base, salt=salt)
        return base + suffix

    def _get_manifest(self, new_class):
        """Return the ``__protobuf__.manifest`` of ``new_class``'s module.

        Returns an empty frozenset when the module declares no
        ``__protobuf__`` (or the module cannot be determined).
        """
        module = inspect.getmodule(new_class)
        if hasattr(module, "__protobuf__"):
            return frozenset(module.__protobuf__.manifest)

        return frozenset()

    def _get_remaining_manifest(self, new_class):
        """Return the manifest names other than ``new_class``'s own name.

        ``new_class`` is excluded, presumably because it is not yet bound on
        its module while it is still being created.
        """
        return self._get_manifest(new_class) - {new_class.__name__}

    def _calculate_salt(self, new_class, fallback):
        """Return the salt to append to this file's name.

        Classes listed in their module's manifest get no salt. Any other
        class gets ``fallback`` lower-cased, or ``""`` if ``fallback`` is
        falsy. A warning is logged when the module declares a manifest that
        omits the class.
        """
        manifest = self._get_manifest(new_class)
        if new_class.__name__ in manifest:
            return ""

        if manifest:
            log.warning(
                "proto-plus module %s has a declared manifest but %s is not in it",
                inspect.getmodule(new_class).__name__,
                new_class.__name__,
            )

        return (fallback or "").lower()

    def generate_file_pb(self, new_class, fallback_salt=""):
        """Generate the descriptors for all protos in the file.

        This method takes the file descriptor attached to the parent
        message and generates the immutable descriptors for all of the
        messages in the file descriptor. (This must be done in one fell
        swoop for immutability and to resolve proto cross-referencing.)

        It is meant to be called once :meth:`ready` reports that everything
        the file needs has been declared, including every name in the
        module's ``__protobuf__`` manifest.

        Args:
            new_class (~.MessageMeta): The proto-plus class currently being
                created. Its module's manifest decides the filename salt.
            fallback_salt (str): Salt used for the filename when
                ``new_class`` is not listed in its module's manifest.

        Side effects:
            - Renames ``self.descriptor`` and adds it to the default
              descriptor pool.
            - Attaches a generated protobuf class to each proto-plus message
              and registers that class with the message's marshal.
            - Attaches the enum descriptor to each proto-plus enum.
            - Removes this entry from ``registry``.
        """
        pool = descriptor_pool.Default()

        # Salt the filename in the descriptor.
        # This allows re-use of the filename by other proto messages if
        # needed, e.g. when the module declares no manifest or the class is
        # missing from it.
        salt = self._calculate_salt(new_class, fallback_salt)
        self.descriptor.name = self._salted_file_name(self.descriptor.name, salt)

        # Add the file descriptor.
        pool.Add(self.descriptor)

        # Adding the file descriptor to the pool created a descriptor for
        # each message; go back through our wrapper messages and associate
        # them with the internal protobuf version.
        for full_name, proto_plus_message in self.messages.items():
            # Get the descriptor from the pool, and create the protobuf
            # message based on it.
            descriptor = pool.FindMessageTypeByName(full_name)
            pb_message = reflection.GeneratedProtocolMessageType(
                descriptor.name,
                (message.Message,),
                {"DESCRIPTOR": descriptor, "__module__": None},
            )

            # Register the message with the marshal so it is wrapped
            # appropriately.
            #
            # We do this here (rather than at class creation) because it
            # is not until this point that we have an actual protobuf
            # message subclass, which is what we need to use.
            proto_plus_message._meta._pb = pb_message
            proto_plus_message._meta.marshal.register(
                pb_message, MessageRule(pb_message, proto_plus_message)
            )

            # Iterate over any fields on the message and, if their type
            # is a message still referenced as a string, resolve the reference.
            for field in proto_plus_message._meta.fields.values():
                if field.message and isinstance(field.message, str):
                    field.message = self.messages[field.message]
                elif field.enum and isinstance(field.enum, str):
                    field.enum = self.enums[field.enum]

        # Same thing for enums.
        for full_name, proto_plus_enum in self.enums.items():
            descriptor = pool.FindEnumTypeByName(full_name)
            proto_plus_enum._meta.pb = descriptor

        # We no longer need to track this file's info; remove it from the
        # class-level registry.
        self.registry.pop(self.name)

    def ready(self, new_class):
        """Return True if a file descriptor may be added, False otherwise.

        This determines whether all the messages that we plan to create
        have been created, as best as we are able.

        Since messages depend on one another, we create descriptor protos
        (which reference each other using strings) and wait until we have
        built everything that is going to be in the module, and then
        use the descriptor protos to instantiate the actual descriptors in
        one fell swoop.

        Args:
            new_class (~.MessageMeta): The new class currently undergoing
                creation.

        Returns:
            bool: True only if all of the following hold:
                - no nested descriptors are pending;
                - every string-referenced message or enum is declared in
                  this file;
                - every other name in the module's manifest exists on the
                  module.
        """
        # If there are any nested descriptors that have not been assigned to
        # the descriptors that should contain them, then we are not ready.
        if self.nested or self.nested_enum:
            return False

        # If there are any unresolved fields (fields with a composite message
        # declared as a string), ensure that the corresponding message is
        # declared.
        for field in self.unresolved_fields:
            if (field.message and field.message not in self.messages) or (
                field.enum and field.enum not in self.enums
            ):
                return False

        # If the module in which this class is defined provides a
        # __protobuf__ property, it may have a manifest.
        #
        # Do not generate the file descriptor until every member of the
        # manifest has been populated.
        module = inspect.getmodule(new_class)
        manifest = self._get_remaining_manifest(new_class)

        # We are ready if all members have been populated.
        return all(hasattr(module, i) for i in manifest)

    @property
    def unresolved_fields(self):
        """Yield fields whose message or enum type is still given as a string."""
        for proto_plus_message in self.messages.values():
            for field in proto_plus_message._meta.fields.values():
                if (field.message and isinstance(field.message, str)) or (
                    field.enum and isinstance(field.enum, str)
                ):
                    yield field