1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Contains container classes to represent different protocol buffer types.
9
This file defines container classes which represent categories of protocol
buffer field types which need extra maintenance. Currently these categories
are:

- Repeated scalar fields - These are all repeated fields which aren't
  composite (e.g. they are of simple types like int32, string, etc).
- Repeated composite fields - Repeated fields which are composite. This
  includes groups and nested messages.
- Map fields - ScalarMap holds maps with scalar values and MessageMap holds
  maps with message values.
- Unknown fields - UnknownFieldSet is a container of parsed unknown fields.
18"""
19
import collections.abc
import copy
import functools
import pickle
from typing import (
    Any,
    Iterable,
    Iterator,
    List,
    MutableMapping,
    MutableSequence,
    NoReturn,
    Optional,
    Sequence,
    TypeVar,
    Union,
    overload,
)
37
38
# Generic type parameters: _T for sequence elements, _K/_V for map keys and
# map values.
_T, _K, _V = TypeVar('_T'), TypeVar('_K'), TypeVar('_V')
42
43from google.protobuf.descriptor import FieldDescriptor
44
class BaseContainer(Sequence[_T]):
  """Base container class.

  Stores elements in a plain Python list (``_values``) and provides the
  read-only sequence operations shared by the repeated-field containers.
  Subclasses supply the mutating methods and ``__eq__``.
  """

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_message_listener', '_values']

  def __init__(self, message_listener: Any) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        Subclasses call this object's Modified() method when the container
        is modified.
    """
    self._message_listener = message_listener
    self._values = []

  @overload
  def __getitem__(self, key: int) -> _T:
    ...

  @overload
  def __getitem__(self, key: slice) -> List[_T]:
    ...

  def __getitem__(self, key):
    """Retrieves item by the specified key."""
    return self._values[key]

  def __len__(self) -> int:
    """Returns the number of elements in the container."""
    return len(self._values)

  def __ne__(self, other: Any) -> bool:
    """Checks if another instance isn't equal to this one."""
    # The concrete classes should define __eq__.
    return not self == other

  # The contents can change in place (e.g. sort/reverse), so instances must
  # not be hashable.
  __hash__ = None

  def __repr__(self) -> str:
    return repr(self._values)

  def sort(self, *args, **kwargs) -> None:
    """Sorts the elements in place. Similar to list.sort().

    The legacy ``sort_function`` keyword (and the Python 2 style ``cmp``
    keyword) is still accepted. Python 3's list.sort() has no ``cmp``
    parameter, so such a two-argument comparison function is wrapped with
    functools.cmp_to_key() and passed on as ``key``.

    Raises:
      TypeError: if a comparison function is combined with a ``key``.
    """
    # NOTE(review): sort() and reverse() do not call Modified(); confirm the
    # listener does not need to be notified of a pure reordering.
    # Continue to support the old sort_function keyword argument.
    # This is expected to be a rare occurrence, so use LBYL to avoid
    # the overhead of actually catching KeyError.
    if 'sort_function' in kwargs:
      kwargs['cmp'] = kwargs.pop('sort_function')
    if 'cmp' in kwargs:
      cmp_func = kwargs.pop('cmp')
      if cmp_func is not None:
        if kwargs.get('key') is not None:
          raise TypeError(
              'sort() cannot combine a comparison function with key')
        kwargs['key'] = functools.cmp_to_key(cmp_func)
    self._values.sort(*args, **kwargs)

  def reverse(self) -> None:
    """Reverses the elements in place. Similar to list.reverse()."""
    self._values.reverse()


# TODO: Remove this. BaseContainer does *not* conform to
# MutableSequence, only its subclasses do. Registering it as a virtual
# subclass makes isinstance(x, MutableSequence) true even for a bare
# BaseContainer, which has no __setitem__/__delitem__/insert.
collections.abc.MutableSequence.register(BaseContainer)
102
103
class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, type-checked, list-like container for holding repeated scalars."""

  # Disallows assignment to other attributes.
  __slots__ = ['_type_checker', '_field']

  def __init__(
      self,
      message_listener: Any,
      type_checker: Any,
      field: Any = None,
  ) -> None:
    """Args:

      message_listener: A MessageListener implementation. The
        RepeatedScalarFieldContainer will call this object's Modified() method
        when it is modified.
      type_checker: A type_checkers.ValueChecker instance to run on elements
        inserted into this container.
      field: Optional FieldDescriptor of this repeated field. __array__() uses
        it to choose a default NumPy dtype; __deepcopy__() passes it along.
    """
    super().__init__(message_listener)
    self._type_checker = type_checker
    self._field = field

  def append(self, value: _T) -> None:
    """Appends an item to the list. Similar to list.append()."""
    self._values.append(self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position. Similar to list.insert()."""
    self._values.insert(key, self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given iterable. Similar to list.extend()."""
    # Every element is checked before anything is stored, so a rejected
    # element leaves the container unchanged.
    new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq]
    if new_values:
      self._values.extend(new_values)
    self._message_listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one. We do not check the types of the individual fields.
    """
    if isinstance(other, RepeatedScalarFieldContainer):
      # Extend from the backing list rather than the container. The container
      # iterates by index (Sequence.__iter__), so extending a field with
      # itself would otherwise chase a list that grows on every append and
      # never terminate. list.extend() handles its own list argument safely.
      other = other._values  # pylint: disable=protected-access
    self._values.extend(other)
    self._message_listener.Modified()

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value) -> None:
    """Sets the item on the specified position."""
    if isinstance(key, slice):
      if key.step is not None:
        raise ValueError('Extended slices not supported')
      self._values[key] = map(self._type_checker.CheckValue, value)
      self._message_listener.Modified()
    else:
      self._values[key] = self._type_checker.CheckValue(value)
      self._message_listener.Modified()

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __deepcopy__(
      self,
      unused_memo: Any = None,
  ) -> 'RepeatedScalarFieldContainer[_T]':
    clone = RepeatedScalarFieldContainer(
        copy.deepcopy(self._message_listener), self._type_checker, self._field
    )
    clone.MergeFrom(self)
    return clone

  def __reduce__(self, **kwargs) -> NoReturn:
    raise pickle.PickleError(
        "Can't pickle repeated scalar fields, convert to list first")

  def __array__(self, dtype=None, copy=None):
    """Converts the values to a NumPy array (NumPy __array__ protocol).

    Args:
      dtype: Requested dtype. When None, it is derived from the field's C++
        type (bytes fields map to 'S', string fields to str). If the container
        has no field descriptor, NumPy infers the dtype from the values.
      copy: Ignored; the result is always a freshly allocated array.

    Returns:
      A new numpy.ndarray holding a copy of the values.

    Raises:
      SystemError: if the field is a message field (should be impossible).
    """
    import numpy as np  # pylint: disable=g-import-not-at-top

    if dtype is None and self._field is not None:
      by_cpp_type = {
          FieldDescriptor.CPPTYPE_INT32: np.int32,
          FieldDescriptor.CPPTYPE_INT64: np.int64,
          FieldDescriptor.CPPTYPE_UINT32: np.uint32,
          FieldDescriptor.CPPTYPE_UINT64: np.uint64,
          FieldDescriptor.CPPTYPE_DOUBLE: np.float64,
          FieldDescriptor.CPPTYPE_FLOAT: np.float32,
          # np.bool_ exists in every NumPy release; the np.bool alias is
          # missing from NumPy 1.24-1.26.
          FieldDescriptor.CPPTYPE_BOOL: np.bool_,
          FieldDescriptor.CPPTYPE_ENUM: np.int32,
      }
      cpp_type = self._field.cpp_type
      if cpp_type in by_cpp_type:
        dtype = by_cpp_type[cpp_type]
      elif self._field.type == FieldDescriptor.TYPE_BYTES:
        dtype = 'S'
      elif self._field.type == FieldDescriptor.TYPE_STRING:
        dtype = str
      else:
        raise SystemError(
            'Code should never reach here: message type detected in'
            ' RepeatedScalarFieldContainer'
        )
    return np.array(self._values, dtype=dtype, copy=True)
248
249
250# TODO: Constrain T to be a subtype of Message.
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, list-like container for holding repeated composite fields."""

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """
    Note that we pass in a descriptor instead of the generated class directly,
    since at the time we construct a _RepeatedCompositeFieldContainer we
    haven't yet necessarily initialized the type that will be contained in the
    container.

    Args:
      message_listener: A MessageListener implementation.
        The RepeatedCompositeFieldContainer will call this object's
        Modified() method when it is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container.  We'll use the
        _concrete_class field of this descriptor when the client calls add().
    """
    super().__init__(message_listener)
    self._message_descriptor = message_descriptor

  def add(self, **kwargs: Any) -> _T:
    """Adds a new element at the end of the list and returns it. Keyword
    arguments may be used to initialize the element.
    """
    new_element = self._message_descriptor._concrete_class(**kwargs)
    new_element._SetListener(self._message_listener)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()
    return new_element

  def append(self, value: _T) -> None:
    """Appends one element by copying the message."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position by copying."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.insert(key, new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending copies of the given messages.

    The elements should be messages of the same type as this container holds;
    each one is merged into a freshly created message rather than stored
    directly.
    """
    message_class = self._message_descriptor._concrete_class
    listener = self._message_listener
    # Build all copies before touching self._values. This keeps
    # extend(self) / MergeFrom(self) from iterating (by index) a list that
    # grows on every step, and leaves the container unchanged if a copy
    # raises part-way through.
    new_elements = []
    for message in elem_seq:
      new_element = message_class()
      new_element._SetListener(listener)
      new_element.MergeFrom(message)
      new_elements.append(new_element)
    self._values.extend(new_elements)
    listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one, copying each individual message.
    """
    self.extend(other)

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value):
    # This method is implemented to make RepeatedCompositeFieldContainer
    # structurally compatible with typing.MutableSequence. It is
    # otherwise unsupported and will always raise an error.
    raise TypeError(
        f'{self.__class__.__name__} object does not support item assignment')

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one.

    Raises:
      TypeError: if other is not a container of the same class.
    """
    if self is other:
      return True
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    return self._values == other._values
367
368
class ScalarMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for maps with scalar values."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
               '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      key_checker: Any,
      value_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The ScalarMap will call this object's Modified() method when it
        is modified.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      value_checker: A type_checkers.ValueChecker instance to run on values
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._key_checker = key_checker
    self._value_checker = value_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    """Returns the value for key, inserting the default value if missing.

    A missing key is type-checked and stored with the value checker's
    DefaultValue(), like a defaultdict. Modified() is not called for that
    insertion.
    """
    try:
      return self._values[key]
    except KeyError:
      key = self._key_checker.CheckValue(key)
      val = self._value_checker.DefaultValue()
      self._values[key] = val
      return val

  def __contains__(self, item: _K) -> bool:
    # We check the key's type to match the strong-typing flavor of the API.
    # Also this makes it easier to match the behavior of the C++ implementation.
    self._key_checker.CheckValue(item)
    return item in self._values

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __setitem__(self, key: _K, value: _V) -> None:
    """Type-checks key and value, stores them and notifies the listener."""
    checked_key = self._key_checker.CheckValue(key)
    checked_value = self._value_checker.CheckValue(value)
    self._values[checked_key] = checked_value
    self._message_listener.Modified()

  def __delitem__(self, key: _K) -> None:
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
    """Returns map[key], first storing ``value`` there if key is absent.

    Raises:
      ValueError: if ``value`` is None.
    """
    if value is None:
      raise ValueError('The value for scalar map setdefault must be set.')
    if key not in self._values:
      self.__setitem__(key, value)
    return self[key]

  def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
    """Copies all entries of other into this map; other's values win."""
    self._values.update(other._values)
    self._message_listener.Modified()

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class
476
477
class MessageMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container with submessage values."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_values', '_message_listener',
               '_message_descriptor', '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      message_descriptor: Any,
      key_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The MessageMap will call this object's Modified() method when it
        is modified.
      message_descriptor: A Descriptor instance describing the message type
        of the map values. Its _concrete_class is instantiated when a missing
        key is accessed.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._message_descriptor = message_descriptor
    self._key_checker = key_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    """Returns the submessage for key, creating an empty one if missing."""
    key = self._key_checker.CheckValue(key)
    try:
      return self._values[key]
    except KeyError:
      new_element = self._message_descriptor._concrete_class()
      new_element._SetListener(self._message_listener)
      self._values[key] = new_element
      self._message_listener.Modified()
      return new_element

  def get_or_create(self, key: _K) -> _V:
    """get_or_create() is an alias for getitem (ie. map[key]).

    Args:
      key: The key to get or create in the map.

    This is useful in cases where you want to be explicit that the call is
    mutating the map.  This can avoid lint errors for statements like this
    that otherwise would appear to be pointless statements:

      msg.my_map[key]
    """
    return self[key]

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __contains__(self, item: _K) -> bool:
    item = self._key_checker.CheckValue(item)
    return item in self._values

  def __setitem__(self, key: _K, value: _V) -> NoReturn:
    raise ValueError('May not set values directly, call my_map[key].foo = 5')

  def __delitem__(self, key: _K) -> None:
    key = self._key_checker.CheckValue(key)
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
    raise NotImplementedError(
        'Set message map value directly is not supported, call'
        ' my_map[key].foo = 5'
    )

  def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
    """Merges other's entries into this map.

    For every key in other, any existing entry here is dropped and replaced by
    a new submessage copied from other's value.
    """
    if other is self:
      # Merging a map into itself changes nothing. Without this guard each
      # entry would be deleted and then copied from its freshly created,
      # empty replacement, wiping every value.
      return
    # pylint: disable=protected-access
    for key in other._values:
      # According to documentation: "When parsing from the wire or when merging,
      # if there are duplicate map keys the last key seen is used".
      if key in self:
        del self[key]
      self[key].CopyFrom(other[key])
    # self._message_listener.Modified() not required here, because
    # mutations to submessages already propagate.

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class
603
604
class _UnknownField:
  """A parsed unknown field: its field number, wire type and data."""

  # Disallows assignment to other attributes.
  __slots__ = ['_field_number', '_wire_type', '_data']

  def __init__(self, field_number, wire_type, data):
    self._field_number = field_number
    self._wire_type = wire_type
    self._data = data

  def __lt__(self, other):
    """Orders fields by field number only (used when sorting a set)."""
    if not isinstance(other, _UnknownField):
      return NotImplemented
    # pylint: disable=protected-access
    return self._field_number < other._field_number

  def __eq__(self, other):
    """Fields are equal when number, wire type and data all match."""
    if self is other:
      return True
    if not isinstance(other, _UnknownField):
      # Let Python fall back to its default handling instead of raising
      # AttributeError on foreign objects.
      return NotImplemented
    # pylint: disable=protected-access
    return (self._field_number == other._field_number and
            self._wire_type == other._wire_type and
            self._data == other._data)
628
629
class UnknownFieldRef:
  """Read-only view of one entry of an UnknownFieldSet, addressed by index.

  The entry is looked up again on every property access, and validity is
  re-checked each time, so the reference fails with ValueError once its
  index no longer fits inside the parent.
  """

  def __init__(self, parent, index):
    self._parent = parent
    self._index = index

  def _check_valid(self):
    # An empty parent, or an index at/after its end, means the referenced
    # field no longer exists.
    if not self._parent or self._index >= len(self._parent):
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')

  def _resolve(self):
    """Validates the reference and returns the underlying stored field."""
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)

  @property
  def field_number(self):
    # pylint: disable=protected-access
    return self._resolve()._field_number

  @property
  def wire_type(self):
    # pylint: disable=protected-access
    return self._resolve()._wire_type

  @property
  def data(self):
    # pylint: disable=protected-access
    return self._resolve()._data
661
662
class UnknownFieldSet:
  """UnknownField container.

  Stores _UnknownField entries in insertion order and hands out
  UnknownFieldRef views of them. After _clear() the set is dead: _values is
  None, and __getitem__, __len__ and iteration raise ValueError.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_values']

  def __init__(self):
    self._values = []

  def __getitem__(self, index):
    """Returns an UnknownFieldRef for the field at index.

    Negative indices count from the end, as for lists.

    Raises:
      ValueError: if the set has been cleared.
      IndexError: if index is out of range.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    size = len(self._values)
    position = index + size if index < 0 else index
    if position < 0 or position >= size:
      # Report the index the caller passed, not the normalized position.
      raise IndexError('index %d out of range' % index)

    return UnknownFieldRef(self, position)

  def _internal_get(self, index):
    return self._values[index]

  def __len__(self):
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    return len(self._values)

  def _add(self, field_number, wire_type, data):
    unknown_field = _UnknownField(field_number, wire_type, data)
    self._values.append(unknown_field)
    return unknown_field

  def __iter__(self):
    for i in range(len(self)):
      yield UnknownFieldRef(self, i)

  def _extend(self, other):
    if other is None:
      return
    # pylint: disable=protected-access
    self._values.extend(other._values)

  def __eq__(self, other):
    """Compares the two sets, ignoring the order of distinct field numbers.

    None compares equal to an empty set. Because sorting is stable and keys
    only on the field number, fields sharing a number must appear in the same
    relative order in both sets.
    """
    if self is other:
      return True
    # Sort unknown fields because their order shouldn't
    # affect equality test.
    values = list(self._values)
    if other is None:
      return not values
    if not isinstance(other, UnknownFieldSet):
      return NotImplemented
    values.sort()
    # pylint: disable=protected-access
    other_values = sorted(other._values)
    return values == other_values

  def _clear(self):
    for value in self._values:
      # pylint: disable=protected-access
      if isinstance(value._data, UnknownFieldSet):
        value._data._clear()  # pylint: disable=protected-access
    self._values = None