1# Copyright 2017 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for :mod:`protobuf`."""
16
17import collections
18import collections.abc
19import copy
20import inspect
21
22from google.protobuf import field_mask_pb2
23from google.protobuf import message
24from google.protobuf import wrappers_pb2
25
26
# Unique marker used as the "no default supplied" value of ``get``. It is
# compared by identity, which leaves ``None`` available as a real default.
_SENTINEL = object()
# Wrapper message classes from ``wrappers_pb2``. ``_is_wrapper`` tests the
# exact type of a value for membership in this tuple.
_WRAPPER_TYPES = (
    wrappers_pb2.BoolValue,
    wrappers_pb2.BytesValue,
    wrappers_pb2.DoubleValue,
    wrappers_pb2.FloatValue,
    wrappers_pb2.Int32Value,
    wrappers_pb2.Int64Value,
    wrappers_pb2.StringValue,
    wrappers_pb2.UInt32Value,
    wrappers_pb2.UInt64Value,
)
39
40
def from_any_pb(pb_type, any_pb):
    """Unpack an ``Any`` protobuf into a new instance of ``pb_type``.

    Args:
        pb_type (type): The message class that ``any_pb`` is expected to
            hold. If the class exposes a callable ``pb`` attribute
            (proto-plus style), the object returned by ``pb_type.pb(...)``
            is the one that gets populated.
        any_pb (google.protobuf.any_pb2.Any): The packed message to convert.

    Returns:
        pb_type: The newly created ``pb_type`` instance.

    Raises:
        TypeError: If ``any_pb.Unpack`` reports that the conversion failed.
    """
    result = pb_type()

    # proto-plus classes wrap a raw protobuf; unpack into that raw message
    # while still handing the wrapper back to the caller.
    unwrap = getattr(pb_type, "pb", None)
    target = unwrap(result) if callable(unwrap) else result

    if any_pb.Unpack(target):
        return result

    raise TypeError(
        f"Could not convert `{any_pb.TypeName()}` with underlying type `google.protobuf.any_pb2.Any` to `{target.DESCRIPTOR.full_name}`"
    )
71
72
def check_oneof(**kwargs):
    """Ensure that at most one of the given keyword arguments is set.

    A keyword argument counts as "set" when its value is not ``None``;
    other falsy values such as ``0`` or ``""`` still count as set.

    Args:
        kwargs (dict): The keyword arguments to inspect.

    Raises:
        ValueError: If two or more values in ``kwargs`` are not ``None``.
            The message lists every keyword name in sorted order.
    """
    populated = sum(1 for value in kwargs.values() if value is not None)
    if populated <= 1:
        # Covers the "no arguments at all" case as well.
        return

    field_names = ", ".join(sorted(kwargs))
    raise ValueError("Only one of {fields} should be set.".format(fields=field_names))
93
94
def get_messages(module):
    """Collect every protobuf Message class exposed by a module.

    Args:
        module (module): The module to scan. Each name reported by
            :func:`dir` is looked up on the module and kept if it is a
            class deriving from ``google.protobuf.message.Message``.

    Returns:
        dict[str, google.protobuf.message.Message]: An ordered mapping of
            attribute name to Message subclass, in the order :func:`dir`
            reports the names.
    """
    members = ((name, getattr(module, name)) for name in dir(module))
    return collections.OrderedDict(
        (name, member)
        for name, member in members
        if inspect.isclass(member) and issubclass(member, message.Message)
    )
113
114
def _resolve_subkeys(key, separator="."):
    """Split a potentially nested key into its first segment and the rest.

    If the key contains the ``separator`` (e.g. ``.``) it is split on the
    first occurrence of that separator::

        >>> _resolve_subkeys('a.b.c')
        ('a', 'b.c')
        >>> _resolve_subkeys('d|e|f', separator='|')
        ('d', 'e|f')

    If not, the subkey is :data:`None`::

        >>> _resolve_subkeys('foo')
        ('foo', None)

    Args:
        key (str): A string that may or may not contain the separator.
        separator (str): The namespace separator. Defaults to ``.``.

    Returns:
        Tuple[str, Optional[str]]: The leading key and the remaining
            subkey(s), or :data:`None` when there is no separator. The
            result is always a tuple, regardless of which case applies.
    """
    # ``partition`` always yields a 3-tuple; the middle element is empty
    # exactly when the separator was not found.
    head, found, remainder = key.partition(separator)
    return head, (remainder if found else None)
144
145
def get(msg_or_dict, key, default=_SENTINEL):
    """Look up a (possibly dotted) key on a protobuf Message or mapping.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
            object to read from.
        key (str): The key to retrieve. A dotted key such as ``"a.b"`` is
            resolved one segment at a time against the nested objects.
        default (Any): Value returned when the key is not present. A
            type-appropriate falsy default is generally recommended:
            protobuf messages almost always report default values for
            unset fields, so an unset value and a falsy one cannot always
            be told apart. When no default is given, :class:`KeyError` is
            raised for a missing key.

    Returns:
        Any: The value read from the underlying Message or mapping, or
            ``default`` when the lookup falls back to it.

    Raises:
        KeyError: If the key is not found and no default was supplied.
            Note that, for unset values, messages and dictionaries may not
            behave consistently.
        TypeError: If ``msg_or_dict`` is not a Message or Mapping.
    """
    # Only the first segment is looked up here; the rest is handled by the
    # recursive call below.
    head, remainder = _resolve_subkeys(key)

    if isinstance(msg_or_dict, message.Message):
        found = getattr(msg_or_dict, head, default)
    elif isinstance(msg_or_dict, collections.abc.Mapping):
        found = msg_or_dict.get(head, default)
    else:
        raise TypeError(
            "get() expected a dict or protobuf message, got {!r}.".format(
                type(msg_or_dict)
            )
        )

    # Getting the sentinel back means "missing, and the caller gave no
    # default".
    if found is _SENTINEL:
        raise KeyError(head)

    # Stop when there is nothing further to resolve, or when the lookup
    # already fell back to the default.
    if remainder is None or found is default:
        return found

    return get(found, remainder, default=default)
195
196
def _set_field_on_message(msg, key, value):
    """Assign ``value`` to field ``key`` of a protobuf Message.

    The assignment strategy depends on the type of ``value``:

    * mutable sequence or tuple: the existing field is emptied and refilled
      element by element;
    * mapping: each entry is applied to the current field value via
      :func:`set`;
    * Message: copied into the field with ``CopyFrom``;
    * anything else: plain attribute assignment.

    Args:
        msg: The object whose field is being set.
        key (str): The name of the field.
        value (Any): The new value.
    """
    if isinstance(value, (collections.abc.MutableSequence, tuple)):
        repeated = getattr(msg, key)

        # Drop whatever the field currently holds before writing.
        while repeated:
            repeated.pop()

        for element in value:
            if isinstance(element, collections.abc.Mapping):
                repeated.add(**element)
            else:
                # protobuf's RepeatedCompositeContainer doesn't support
                # append, so push a one-element list through extend().
                repeated.extend([element])
        return

    if isinstance(value, collections.abc.Mapping):
        # Apply each entry of the mapping to the nested field value.
        for sub_key, sub_value in value.items():
            set(getattr(msg, key), sub_key, sub_value)
        return

    if isinstance(value, message.Message):
        getattr(msg, key).CopyFrom(value)
        return

    setattr(msg, key, value)
223
224
def set(msg_or_dict, key, value):
    """Assign a (possibly dotted) key on a protobuf Message or dictionary.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
            object to modify.
        key (str): The key to set. For a dotted key such as ``"a.b"``, the
            object stored under ``a`` is fetched and ``b`` is set on it.
            On dictionaries a missing intermediate entry is created as an
            empty dict.
        value (Any): The value to set.

    Raises:
        TypeError: If ``msg_or_dict`` is not a Message or dictionary.
    """
    is_mapping = isinstance(msg_or_dict, collections.abc.MutableMapping)
    if not (is_mapping or isinstance(msg_or_dict, message.Message)):
        raise TypeError(
            "set() expected a dict or protobuf message, got {!r}.".format(
                type(msg_or_dict)
            )
        )

    head, remainder = _resolve_subkeys(key)

    if remainder is not None:
        # Nested key: make sure the intermediate container exists (dicts
        # only), then recurse into it with the rest of the key.
        if is_mapping:
            msg_or_dict.setdefault(head, {})
        set(get(msg_or_dict, head), remainder, value)
    elif is_mapping:
        msg_or_dict[key] = value
    else:
        _set_field_on_message(msg_or_dict, key, value)
260
261
def setdefault(msg_or_dict, key, value):
    """Set ``key`` to ``value`` unless it already holds a truthy value.

    Protobuf Messages do not distinguish well (by design) between unset
    values and falsy ones, so any falsy current value (e.g. 0, empty list)
    is treated as a target to be overwritten, on both Messages and
    dictionaries.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
            object to modify.
        key (str): The key on the object in question.
        value (Any): The value to set.

    Raises:
        TypeError: If ``msg_or_dict`` is not a Message or dictionary.
    """
    existing = get(msg_or_dict, key, default=None)
    if existing:
        # A truthy value is already present; leave it untouched.
        return
    set(msg_or_dict, key, value)
282
283
def field_mask(original, modified):
    """Build a field mask describing how ``modified`` differs from ``original``.

    Args:
        original (~google.protobuf.message.Message): the original message.
            ``None`` is interpreted as an empty message.
        modified (~google.protobuf.message.Message): the modified message.
            ``None`` is interpreted as an empty message.

    Returns:
        google.protobuf.field_mask_pb2.FieldMask: field mask listing the
            field paths whose values differ between the two messages. The
            mask is empty when the messages are equivalent, including when
            both arguments are ``None``.

    Raises:
        ValueError: If ``original`` is not an instance of the type of
            ``modified``.
    """
    if original is None:
        if modified is None:
            return field_mask_pb2.FieldMask()
        # Stand in for the missing side with a cleared copy of the other.
        original = copy.deepcopy(modified)
        original.Clear()
    elif modified is None:
        modified = copy.deepcopy(original)
        modified.Clear()

    if isinstance(original, type(modified)):
        return field_mask_pb2.FieldMask(paths=_field_mask_helper(original, modified))

    raise ValueError(
        "expected that both original and modified should be of the "
        'same type, received "{!r}" and "{!r}".'.format(
            type(original), type(modified)
        )
    )
323
324
def _field_mask_helper(original, modified, current=""):
    """Collect the paths of fields that differ between two messages.

    Args:
        original: The message providing the baseline values. Its
            ``DESCRIPTOR.fields_by_name`` determines which fields are
            compared.
        modified: The message providing the new values.
        current (str): Path prefix of the message being compared; empty at
            the top level.

    Returns:
        list[str]: The paths of the differing fields.
    """
    paths = []

    for field_name in original.DESCRIPTOR.fields_by_name:
        path = _get_path(current, field_name)
        before = getattr(original, field_name)
        after = getattr(modified, field_name)

        involves_message = _is_message(before) or _is_message(after)
        if not (before != after):
            # Unchanged fields never appear in the mask.
            continue

        if not involves_message:
            paths.append(path)
            continue

        # Wrapper types are reported as a whole (no ``.value`` suffix), and
        # so is a message whose modified value has no fields set. Any other
        # message is descended into to find the individual changed fields.
        if _is_wrapper(before) or _is_wrapper(after) or not after.ListFields():
            paths.append(path)
        else:
            paths.extend(_field_mask_helper(before, after, path))

    return paths
351
352
def _get_path(current, name):
    """Join a field name onto a dotted path prefix.

    Trailing underscores are removed from ``name`` first.
    gapic-generator-python appends underscores to field names that collide
    with Python keywords, a trailing underscore cannot be natively defined
    on a protobuf field, and APIs reject field masks that contain one.
    See https://github.com/googleapis/python-api-core/issues/227

    Args:
        current (str): The path prefix; empty for a top-level field.
        name (str): The field name to append.

    Returns:
        str: ``name`` without trailing underscores, prefixed by
            ``current`` and a dot when ``current`` is non-empty.
    """
    leaf = name.rstrip("_")
    return f"{current}.{leaf}" if current else leaf
364
365
def _is_message(value):
    """Return ``True`` when ``value`` is a protobuf ``Message`` instance."""
    message_base = message.Message
    return isinstance(value, message_base)
368
369
def _is_wrapper(value):
    """Return ``True`` when the type of ``value`` is one of ``_WRAPPER_TYPES``.

    The check is made on ``type(value)`` itself rather than with
    ``isinstance``.
    """
    value_type = type(value)
    return value_type in _WRAPPER_TYPES