1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Contains container classes to represent different protocol buffer types.
9
10This file defines container classes which represent categories of protocol
11buffer field types which need extra maintenance. Currently these categories
12are:
13
14- Repeated scalar fields - These are all repeated fields which aren't
15 composite (e.g. they are of simple types like int32, string, etc).
16- Repeated composite fields - Repeated fields which are composite. This
17 includes groups and nested messages.
18"""
19
20import collections.abc
21import copy
22import pickle
23from typing import (
24 Any,
25 Iterable,
26 Iterator,
27 List,
28 MutableMapping,
29 MutableSequence,
30 NoReturn,
31 Optional,
32 Sequence,
33 TypeVar,
34 Union,
35 overload,
36)
37
38
# Generic type parameters used in the annotations below: _T is the element
# type of the repeated-field containers (and the type of the `default`
# argument in the map get() overloads); _K and _V are the key and value types
# of the map containers.
_T = TypeVar('_T')
_K = TypeVar('_K')
_V = TypeVar('_V')
42
43
class BaseContainer(Sequence[_T]):
  """Base class for the list-like repeated-field containers.

  Holds the backing list of values and the listener object. Only read access
  and in-place reordering (sort/reverse) are implemented here; __eq__ and the
  mutating list operations are expected to come from concrete subclasses.
  """

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_message_listener', '_values']

  def __init__(self, message_listener: Any) -> None:
    """Initializes an empty container.

    Args:
      message_listener: A MessageListener implementation. It is only stored
        here; this base class never calls it.
    """
    self._message_listener = message_listener
    self._values = []

  @overload
  def __getitem__(self, key: int) -> _T:
    ...

  @overload
  def __getitem__(self, key: slice) -> List[_T]:
    ...

  def __getitem__(self, key):
    """Retrieves the item (int key) or list of items (slice key)."""
    return self._values[key]

  def __len__(self) -> int:
    """Returns the number of elements in the container."""
    return len(self._values)

  def __ne__(self, other: Any) -> bool:
    """Checks if another instance isn't equal to this one."""
    # The concrete classes should define __eq__.
    return not self == other

  # The container is mutable, so it is explicitly made unhashable.
  __hash__ = None

  def __repr__(self) -> str:
    """Returns the repr of the underlying list of values."""
    return repr(self._values)

  def sort(self, *args, **kwargs) -> None:
    """Sorts the values in place. Similar to list.sort().

    Args:
      *args: Forwarded unchanged to list.sort().
      **kwargs: Forwarded to list.sort(). The legacy keyword `sort_function`
        (a two-argument comparison function returning a negative, zero or
        positive number) is still accepted. If `key` is given as well, the
        comparison function is applied to the extracted keys.
    """
    # Continue to support the old sort_function keyword argument.
    # This is expected to be a rare occurrence, so use LBYL to avoid
    # the overhead of actually catching KeyError.
    if 'sort_function' in kwargs:
      import functools  # Local import: only needed on this legacy path.

      # list.sort() no longer accepts cmp=, so wrap the comparison function
      # into a key function instead.
      cmp_key = functools.cmp_to_key(kwargs.pop('sort_function'))
      user_key = kwargs.get('key')
      if user_key is None:
        kwargs['key'] = cmp_key
      else:
        kwargs['key'] = lambda item: cmp_key(user_key(item))
    self._values.sort(*args, **kwargs)

  def reverse(self) -> None:
    """Reverses the values in place. Similar to list.reverse()."""
    self._values.reverse()
96
97
# Registers BaseContainer as a virtual subclass of MutableSequence, so that
# isinstance()/issubclass() checks against MutableSequence succeed for it even
# though it does not implement the mutating methods itself.
# TODO: Remove this. BaseContainer does *not* conform to
# MutableSequence, only its subclasses do.
collections.abc.MutableSequence.register(BaseContainer)
101
102
class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, type-checked, list-like container for holding repeated scalars."""

  # Disallows assignment to other attributes.
  __slots__ = ['_type_checker']

  def __init__(
      self,
      message_listener: Any,
      type_checker: Any,
  ) -> None:
    """Initializes an empty container.

    Args:
      message_listener: A MessageListener implementation. The
        RepeatedScalarFieldContainer will call this object's Modified() method
        when it is modified.
      type_checker: A type_checkers.ValueChecker instance to run on elements
        inserted into this container.
    """
    super().__init__(message_listener)
    self._type_checker = type_checker

  def append(self, value: _T) -> None:
    """Appends an item to the list. Similar to list.append()."""
    self._values.append(self._type_checker.CheckValue(value))
    # Skip the notification when the listener already reports itself dirty.
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position. Similar to list.insert()."""
    self._values.insert(key, self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given iterable. Similar to list.extend()."""
    elem_seq_iter = iter(elem_seq)
    # Check every element before touching the stored values, so that a
    # rejected element leaves the container unchanged.
    new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
    if new_values:
      self._values.extend(new_values)
    self._message_listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one. We do not check the types of the individual fields.
    """
    self._values.extend(other)
    self._message_listener.Modified()

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    # Delegate to __delitem__ so the listener is notified.
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value) -> None:
    """Sets the item (or plain slice of items) at the specified position.

    Raises:
      ValueError: If `key` is a slice with an explicit step.
    """
    if isinstance(key, slice):
      if key.step is not None:
        raise ValueError('Extended slices not supported')
      self._values[key] = map(self._type_checker.CheckValue, value)
      self._message_listener.Modified()
    else:
      self._values[key] = self._type_checker.CheckValue(value)
      self._message_listener.Modified()

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __deepcopy__(
      self,
      unused_memo: Any = None,
  ) -> 'RepeatedScalarFieldContainer[_T]':
    """Returns a new container with a deep-copied listener and the same values.

    Args:
      unused_memo: The memo dictionary supplied by copy.deepcopy(). Despite
        its legacy name it is forwarded to the nested deepcopy() call so that
        objects shared within one copy operation stay shared.
    """
    clone = RepeatedScalarFieldContainer(
        copy.deepcopy(self._message_listener, unused_memo), self._type_checker)
    clone.MergeFrom(self)
    return clone

  def __reduce__(self, **kwargs) -> NoReturn:
    """Always raises pickle.PickleError: this container cannot be pickled."""
    raise pickle.PickleError(
        "Can't pickle repeated scalar fields, convert to list first")
212
213
214# TODO: Constrain T to be a subtype of Message.
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """List-like container holding the messages of a repeated composite field."""

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """Initializes an empty container.

    A descriptor is passed in rather than the generated class itself, since
    at construction time the contained type has not necessarily been
    initialized yet.

    Args:
      message_listener: A MessageListener implementation. The
        RepeatedCompositeFieldContainer will call this object's Modified()
        method when it is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container. Its _concrete_class
        attribute is used to create new elements.
    """
    super().__init__(message_listener)
    self._message_descriptor = message_descriptor

  def add(self, **kwargs: Any) -> _T:
    """Creates a new element, appends it to the list and returns it.

    Keyword arguments are passed to the element's constructor.
    """
    created = self._message_descriptor._concrete_class(**kwargs)
    listener = self._message_listener
    created._SetListener(listener)
    self._values.append(created)
    # Skip the notification when the listener already reports itself dirty.
    if not listener.dirty:
      listener.Modified()
    return created

  def append(self, value: _T) -> None:
    """Appends a copy of the given message to the list."""
    copied = self._message_descriptor._concrete_class()
    listener = self._message_listener
    copied._SetListener(listener)
    copied.CopyFrom(value)
    self._values.append(copied)
    if not listener.dirty:
      listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts a copy of the given message at the specified position."""
    copied = self._message_descriptor._concrete_class()
    listener = self._message_listener
    copied._SetListener(listener)
    copied.CopyFrom(value)
    self._values.insert(key, copied)
    if not listener.dirty:
      listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Appends a copy of every message in the given iterable.

    The elements are expected to be of the same type as this container's
    elements.
    """
    make_message = self._message_descriptor._concrete_class
    notify_target = self._message_listener
    storage = self._values
    for source in elem_seq:
      copied = make_message()
      copied._SetListener(notify_target)
      copied.MergeFrom(source)
      storage.append(copied)
    notify_target.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends copies of the messages of another repeated field of the same
    type to this one.
    """
    self.extend(other)

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: Optional[int] = -1) -> _T:
    """Removes and returns the item at the given index. Similar to list.pop()."""
    removed = self._values[key]
    # Going through __delitem__ takes care of notifying the listener.
    del self[key]
    return removed

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value):
    # Exists only so that RepeatedCompositeFieldContainer is structurally
    # compatible with typing.MutableSequence. Item assignment is not
    # supported and always raises.
    raise TypeError(
        f'{self.__class__.__name__} object does not support item assignment')

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item (or slice of items) at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares this container with another container of the same class.

    Raises:
      TypeError: If `other` is not an instance of this container's class.
    """
    if self is other:
      return True
    if isinstance(other, self.__class__):
      return self._values == other._values
    raise TypeError('Can only compare repeated composite fields against '
                    'other repeated composite fields.')
331
332
class ScalarMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for a map with scalar values."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
               '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      key_checker: Any,
      value_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """Initializes an empty map.

    Args:
      message_listener: A MessageListener implementation.
        The ScalarMap will call this object's Modified() method when it
        is modified.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      value_checker: A type_checkers.ValueChecker instance to run on values
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._key_checker = key_checker
    self._value_checker = value_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    """Returns the value for `key`, inserting the default value if missing.

    The insertion of a default value does not notify the listener.
    """
    try:
      return self._values[key]
    except KeyError:
      key = self._key_checker.CheckValue(key)
      val = self._value_checker.DefaultValue()
      self._values[key] = val
      return val

  def __contains__(self, item: _K) -> bool:
    # We check the key's type to match the strong-typing flavor of the API.
    # Also this makes it easier to match the behavior of the C++ implementation.
    self._key_checker.CheckValue(item)
    return item in self._values

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    """Returns the value for `key` if present, else `default` (no insertion)."""
    if key in self:
      return self[key]
    else:
      return default

  def __setitem__(self, key: _K, value: _V) -> None:
    """Stores the checked value under the checked key and notifies the listener."""
    checked_key = self._key_checker.CheckValue(key)
    checked_value = self._value_checker.CheckValue(value)
    self._values[checked_key] = checked_value
    self._message_listener.Modified()

  def __delitem__(self, key: _K) -> None:
    """Deletes the entry for `key` and notifies the listener."""
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
    """Returns the value for `key`, first storing `value` if `key` is absent.

    Raises:
      ValueError: If `value` is None (i.e. was not supplied).
    """
    # Identity check: only a genuinely missing value is rejected, regardless
    # of how the value's type implements ==.
    if value is None:
      raise ValueError('The value for scalar map setdefault must be set.')
    if key not in self._values:
      self.__setitem__(key, value)
    return self[key]

  def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
    """Copies all entries of `other` into this map, overwriting equal keys."""
    self._values.update(other._values)
    self._message_listener.Modified()

  def InvalidateIterators(self) -> None:
    """Makes iterators obtained before this call unusable."""
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    # The abandoned dict grows by one entry; the live copy is unaffected.
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    """Removes all entries and notifies the listener."""
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    """Returns the _concrete_class of the map entry descriptor."""
    return self._entry_descriptor._concrete_class
440
441
class MessageMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container with submessage values."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_values', '_message_listener',
               '_message_descriptor', '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      message_descriptor: Any,
      key_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """Initializes an empty map.

    Args:
      message_listener: A MessageListener implementation.
        The MessageMap will call this object's Modified() method when it
        is modified.
      message_descriptor: A descriptor whose _concrete_class attribute is
        called to create the value for a key that is not present yet.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        used with this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._values = {}
    self._entry_descriptor = entry_descriptor
    self._key_checker = key_checker
    self._message_descriptor = message_descriptor
    self._message_listener = message_listener

  def __getitem__(self, key: _K) -> _V:
    """Returns the submessage stored for `key`, creating it if it is missing.

    A newly created submessage is given this map's listener, and the listener
    is notified of the insertion.
    """
    checked_key = self._key_checker.CheckValue(key)
    try:
      return self._values[checked_key]
    except KeyError:
      listener = self._message_listener
      created = self._message_descriptor._concrete_class()
      created._SetListener(listener)
      self._values[checked_key] = created
      listener.Modified()
      return created

  def get_or_create(self, key: _K) -> _V:
    """get_or_create() is an alias for getitem (ie. map[key]).

    Args:
      key: The key to get or create in the map.

    This is useful in cases where you want to be explicit that the call is
    mutating the map. This can avoid lint errors for statements like this
    that otherwise would appear to be pointless statements:

      msg.my_map[key]
    """
    return self[key]

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # Overridden explicitly: the inherited implementation goes through
  # __getitem__, whose defaultdict-like behavior would always insert the key.
  def get(self, key, default=None):
    """Returns the value for `key` if present, else `default` (no insertion)."""
    return self[key] if key in self else default

  def __contains__(self, item: _K) -> bool:
    """Reports whether the (type-checked) key is present."""
    return self._key_checker.CheckValue(item) in self._values

  def __setitem__(self, key: _K, value: _V) -> NoReturn:
    """Always raises ValueError: values cannot be assigned directly."""
    raise ValueError('May not set values directly, call my_map[key].foo = 5')

  def __delitem__(self, key: _K) -> None:
    """Deletes the entry for the (type-checked) key and notifies the listener."""
    checked_key = self._key_checker.CheckValue(key)
    del self._values[checked_key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
    """Always raises NotImplementedError: values cannot be set directly."""
    raise NotImplementedError(
        'Set message map value directly is not supported, call'
        ' my_map[key].foo = 5'
    )

  def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
    """Replaces/creates an entry here for every key of `other`, copying its value."""
    # pylint: disable=protected-access
    for map_key in other._values:
      # According to documentation: "When parsing from the wire or when merging,
      # if there are duplicate map keys the last key seen is used".
      if map_key in self:
        del self[map_key]
      target = self[map_key]
      target.CopyFrom(other[map_key])
    # No explicit self._message_listener.Modified() here: creating each
    # entry through self[map_key] above already notifies the listener.

  def InvalidateIterators(self) -> None:
    """Makes iterators obtained before this call unusable."""
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    stale = self._values
    self._values = stale.copy()
    # Grow the abandoned dict by one entry; the live copy is unaffected.
    stale[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    """Removes all entries and notifies the listener."""
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    """Returns the _concrete_class of the map entry descriptor."""
    return self._entry_descriptor._concrete_class
567
568
class _UnknownField:
  """A parsed unknown field: field number, wire type and data."""

  # Disallows assignment to other attributes.
  __slots__ = ['_field_number', '_wire_type', '_data']

  def __init__(self, field_number, wire_type, data):
    """Stores the three components of the unknown field as given."""
    self._field_number = field_number
    self._wire_type = wire_type
    self._data = data

  def __lt__(self, other):
    """Orders unknown fields by field number only."""
    if not isinstance(other, _UnknownField):
      # Let Python try the reflected operation or raise TypeError itself.
      return NotImplemented
    # pylint: disable=protected-access
    return self._field_number < other._field_number

  def __eq__(self, other):
    """Two unknown fields are equal when number, wire type and data match."""
    if self is other:
      return True
    if not isinstance(other, _UnknownField):
      # Comparing against a foreign type yields "not equal" instead of an
      # AttributeError on the missing private attributes.
      return NotImplemented
    # pylint: disable=protected-access
    return (self._field_number == other._field_number and
            self._wire_type == other._wire_type and
            self._data == other._data)

  # Defining __eq__ already makes the class unhashable; state it explicitly.
  __hash__ = None
592
593
class UnknownFieldRef:
  """Read-only view of one entry of a parent container, addressed by index.

  Every property access re-validates the reference and then fetches the entry
  through the parent's _internal_get().
  """

  def __init__(self, parent, index):
    self._parent = parent
    self._index = index

  def _check_valid(self):
    """Raises ValueError if the parent is falsy or the index is out of range."""
    parent = self._parent
    if not parent or self._index >= len(parent):
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')

  def _resolve(self):
    """Validates the reference and returns the referenced entry."""
    self._check_valid()
    # pylint: disable=protected-access
    return self._parent._internal_get(self._index)

  @property
  def field_number(self):
    # pylint: disable=protected-access
    return self._resolve()._field_number

  @property
  def wire_type(self):
    # pylint: disable=protected-access
    return self._resolve()._wire_type

  @property
  def data(self):
    # pylint: disable=protected-access
    return self._resolve()._data
625
626
class UnknownFieldSet:
  """Container of unknown fields.

  The fields live in the `_values` list. _clear() sets that list to None,
  after which indexing and len() raise ValueError.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_values']

  def __init__(self):
    self._values = []

  def __getitem__(self, index):
    """Returns an UnknownFieldRef for the field at `index`.

    Negative indices count from the end, as for lists.

    Raises:
      ValueError: If this set has been cleared.
      IndexError: If `index` is out of range.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    size = len(self._values)
    # Keep the caller's value so the error message reports what was asked for.
    requested = index
    if index < 0:
      index += size
    if index < 0 or index >= size:
      raise IndexError('index %d out of range' % requested)

    return UnknownFieldRef(self, index)

  def _internal_get(self, index):
    """Returns the stored field object itself (not a reference wrapper)."""
    return self._values[index]

  def __len__(self):
    """Returns the number of fields.

    Raises:
      ValueError: If this set has been cleared.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    return len(self._values)

  def _add(self, field_number, wire_type, data):
    """Appends a new unknown field built from the arguments and returns it."""
    unknown_field = _UnknownField(field_number, wire_type, data)
    self._values.append(unknown_field)
    return unknown_field

  def __iter__(self):
    """Yields an UnknownFieldRef for every index of this set."""
    for i in range(len(self)):
      yield UnknownFieldRef(self, i)

  def _extend(self, other):
    """Appends all fields of `other` to this set; None is a no-op."""
    if other is None:
      return
    # pylint: disable=protected-access
    self._values.extend(other._values)

  def __eq__(self, other):
    """Compares the sorted fields of both sets; None equals an empty set."""
    if self is other:
      return True
    # Sort unknown fields because their order shouldn't
    # affect equality test.
    values = list(self._values)
    if other is None:
      return not values
    values.sort()
    # pylint: disable=protected-access
    other_values = sorted(other._values)
    return values == other_values

  def _clear(self):
    """Clears nested sets held as field data, then drops the value list."""
    for value in self._values:
      # pylint: disable=protected-access
      if isinstance(value._data, UnknownFieldSet):
        value._data._clear()  # pylint: disable=protected-access
    # None (rather than an empty list) is what the ValueError checks in
    # __getitem__ and __len__ look for.
    self._values = None