1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Contains container classes to represent different protocol buffer types.
9
This file defines container classes which represent categories of protocol
buffer field types which need extra maintenance. Currently these categories
are:

- Repeated scalar fields - These are all repeated fields which aren't
  composite (e.g. they are of simple types like int32, string, etc).
- Repeated composite fields - Repeated fields which are composite. This
  includes groups and nested messages.
- Map fields - ScalarMap holds maps with scalar values and MessageMap holds
  maps with message values.

It also defines UnknownFieldSet, a container for parsed unknown fields.
18"""
19
20import collections.abc
21import copy
22import pickle
23from typing import (
24 Any,
25 Iterable,
26 Iterator,
27 List,
28 MutableMapping,
29 MutableSequence,
30 NoReturn,
31 Optional,
32 Sequence,
33 TypeVar,
34 Union,
35 overload,
36)
37
38
_T = TypeVar('_T')
_K = TypeVar('_K')
_V = TypeVar('_V')


class BaseContainer(Sequence[_T]):
  """Base container class for repeated fields.

  Holds the elements in a private Python list and implements the read-only
  sequence operations (indexing, slicing, len, repr) plus in-place sort() and
  reverse(). Concrete subclasses add the mutating methods and define __eq__.
  """

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_message_listener', '_values']

  def __init__(self, message_listener: Any) -> None:
    """
    Args:
      message_listener: A MessageListener implementation. Subclasses call
        this object's Modified() method when the container is modified.
    """
    self._message_listener = message_listener
    self._values = []

  @overload
  def __getitem__(self, key: int) -> _T:
    ...

  @overload
  def __getitem__(self, key: slice) -> List[_T]:
    ...

  def __getitem__(self, key):
    """Retrieves item by the specified key.

    An integer key returns a single element; a slice returns a new plain list.
    """
    return self._values[key]

  def __len__(self) -> int:
    """Returns the number of elements in the container."""
    return len(self._values)

  def __ne__(self, other: Any) -> bool:
    """Checks if another instance isn't equal to this one."""
    # The concrete classes should define __eq__.
    return not self == other

  # Containers are mutable, so instances are deliberately unhashable.
  __hash__ = None

  def __repr__(self) -> str:
    return repr(self._values)

  def sort(self, *args, **kwargs) -> None:
    """Sorts the elements in place. Similar to list.sort().

    The legacy ``sort_function`` and ``cmp`` keyword arguments, which take a
    two-argument comparison function, are still accepted and converted to an
    equivalent ``key`` with functools.cmp_to_key().

    Raises:
      TypeError: if a comparison function is combined with ``key`` (or both
        legacy spellings are given at once).
    """
    # Python 3's list.sort() has no ``cmp`` parameter, so forwarding the old
    # comparison function unchanged fails; translate it into a key function.
    # This is expected to be a rare occurrence, so use LBYL to avoid
    # the overhead of actually catching KeyError.
    for legacy_name in ('sort_function', 'cmp'):
      if legacy_name in kwargs:
        if 'key' in kwargs:
          raise TypeError(
              'sort() got both a comparison function (%r) and a key function'
              % legacy_name)
        import functools  # pylint: disable=g-import-not-at-top
        kwargs['key'] = functools.cmp_to_key(kwargs.pop(legacy_name))
    self._values.sort(*args, **kwargs)

  def reverse(self) -> None:
    """Reverses the elements in place. Similar to list.reverse()."""
    self._values.reverse()


# TODO: Remove this. BaseContainer does *not* conform to
# MutableSequence, only its subclasses do.
collections.abc.MutableSequence.register(BaseContainer)
101
102
class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, type-checked, list-like container for holding repeated scalars."""

  # Disallows assignment to other attributes.
  __slots__ = ['_type_checker']

  def __init__(
      self,
      message_listener: Any,
      type_checker: Any,
  ) -> None:
    """Args:

      message_listener: A MessageListener implementation. The
        RepeatedScalarFieldContainer will call this object's Modified() method
        when it is modified.
      type_checker: A type_checkers.ValueChecker instance to run on elements
        inserted into this container.
    """
    super().__init__(message_listener)
    self._type_checker = type_checker

  def append(self, value: _T) -> None:
    """Appends an item to the list. Similar to list.append()."""
    self._values.append(self._type_checker.CheckValue(value))
    # Skip the notification when the listener already reports itself dirty.
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position. Similar to list.insert()."""
    self._values.insert(key, self._type_checker.CheckValue(value))
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given iterable. Similar to list.extend().

    All elements are type-checked before any is added, so a rejected element
    leaves the container unchanged. The listener is notified only if at least
    one element was added.
    """
    new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq]
    if new_values:
      self._values.extend(new_values)
      self._message_listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one. We do not check the types of the individual fields.

    Args:
      other: Another RepeatedScalarFieldContainer (this one included) or any
        iterable of values.
    """
    if isinstance(other, RepeatedScalarFieldContainer):
      # Extend from the backing list rather than the container. Iterating a
      # container goes through Sequence.__iter__, which keeps indexing until
      # IndexError and therefore also visits the elements being appended, so
      # merging a container into itself would never terminate. list.extend()
      # handles a list being extended by itself correctly.
      other = other._values  # pylint: disable=protected-access
    self._values.extend(other)
    self._message_listener.Modified()

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: int = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value) -> None:
    """Sets the item on the specified position.

    Slice assignment type-checks every new value; extended slices (any slice
    with an explicit step) are rejected with ValueError.
    """
    if isinstance(key, slice):
      if key.step is not None:
        raise ValueError('Extended slices not supported')
      self._values[key] = map(self._type_checker.CheckValue, value)
      self._message_listener.Modified()
    else:
      self._values[key] = self._type_checker.CheckValue(value)
      self._message_listener.Modified()

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one."""
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __deepcopy__(
      self,
      unused_memo: Any = None,
  ) -> 'RepeatedScalarFieldContainer[_T]':
    """Returns a copy with a deep-copied listener and a shared type checker."""
    clone = RepeatedScalarFieldContainer(
        copy.deepcopy(self._message_listener), self._type_checker)
    clone.MergeFrom(self)
    return clone

  def __reduce__(self, **kwargs) -> NoReturn:
    raise pickle.PickleError(
        "Can't pickle repeated scalar fields, convert to list first")
212
213
214# TODO: Constrain T to be a subtype of Message.
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
  """Simple, list-like container for holding repeated composite fields."""

  # Disallows assignment to other attributes.
  __slots__ = ['_message_descriptor']

  def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
    """
    Note that we pass in a descriptor instead of the generated class directly,
    since at the time we construct a RepeatedCompositeFieldContainer we
    haven't yet necessarily initialized the type that will be contained in the
    container.

    Args:
      message_listener: A MessageListener implementation.
        The RepeatedCompositeFieldContainer will call this object's
        Modified() method when it is modified.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container. We'll use the
        _concrete_class field of this descriptor whenever a new element is
        created (add(), append(), insert(), extend()).
    """
    super().__init__(message_listener)
    self._message_descriptor = message_descriptor

  def add(self, **kwargs: Any) -> _T:
    """Adds a new element at the end of the list and returns it. Keyword
    arguments may be used to initialize the element.
    """
    new_element = self._message_descriptor._concrete_class(**kwargs)
    new_element._SetListener(self._message_listener)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()
    return new_element

  def append(self, value: _T) -> None:
    """Appends one element by copying the message.

    The container stores a new message populated via CopyFrom(value), not
    value itself.
    """
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.append(new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def insert(self, key: int, value: _T) -> None:
    """Inserts the item at the specified position by copying."""
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    new_element.CopyFrom(value)
    self._values.insert(key, new_element)
    if not self._message_listener.dirty:
      self._message_listener.Modified()

  def extend(self, elem_seq: Iterable[_T]) -> None:
    """Extends by appending the given sequence of elements of the same type
    as this one, copying each individual message.

    The listener's Modified() is called once at the end, even when elem_seq
    is empty.
    """
    if elem_seq is self:
      # Snapshot first: iterating this container while appending to it would
      # keep visiting the new elements and never terminate. This mirrors the
      # self-check in collections.abc.MutableSequence.extend.
      elem_seq = list(self._values)
    message_class = self._message_descriptor._concrete_class
    listener = self._message_listener
    values = self._values
    for message in elem_seq:
      new_element = message_class()
      new_element._SetListener(listener)
      new_element.MergeFrom(message)
      values.append(new_element)
    listener.Modified()

  def MergeFrom(
      self,
      other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
  ) -> None:
    """Appends the contents of another repeated field of the same type to this
    one, copying each individual message.
    """
    self.extend(other)

  def remove(self, elem: _T) -> None:
    """Removes an item from the list. Similar to list.remove()."""
    self._values.remove(elem)
    self._message_listener.Modified()

  def pop(self, key: int = -1) -> _T:
    """Removes and returns an item at a given index. Similar to list.pop()."""
    value = self._values[key]
    self.__delitem__(key)
    return value

  @overload
  def __setitem__(self, key: int, value: _T) -> None:
    ...

  @overload
  def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
    ...

  def __setitem__(self, key, value):
    # This method is implemented to make RepeatedCompositeFieldContainer
    # structurally compatible with typing.MutableSequence. It is
    # otherwise unsupported and will always raise an error.
    raise TypeError(
        f'{self.__class__.__name__} object does not support item assignment')

  def __delitem__(self, key: Union[int, slice]) -> None:
    """Deletes the item at the specified position."""
    del self._values[key]
    self._message_listener.Modified()

  def __eq__(self, other: Any) -> bool:
    """Compares the current instance with another one.

    Raises:
      TypeError: if other is not a container of the same class.
    """
    if self is other:
      return True
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    return self._values == other._values
331
332
class ScalarMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for scalar-valued map fields."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
               '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      key_checker: Any,
      value_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The ScalarMap will call this object's Modified() method when it
        is modified.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      value_checker: A type_checkers.ValueChecker instance to run on values
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._key_checker = key_checker
    self._value_checker = value_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    """Returns the value for key, inserting a default value if it is missing.

    Like a defaultdict, reading a missing key stores the value checker's
    DefaultValue() under the checked key. Note that this insertion does not
    call the listener's Modified() method.
    """
    try:
      return self._values[key]
    except KeyError:
      pass
    # Keys are stored in the form returned by the key checker, which may
    # differ from the caller's key (presumably e.g. bytes decoded to str for
    # string keys -- see type_checkers). Retry with the checked key before
    # inserting a default, so an existing entry is never overwritten.
    key = self._key_checker.CheckValue(key)
    try:
      return self._values[key]
    except KeyError:
      val = self._value_checker.DefaultValue()
      self._values[key] = val
      return val

  def __contains__(self, item: _K) -> bool:
    # We check the key's type to match the strong-typing flavor of the API.
    # Also this makes it easier to match the behavior of the C++ implementation.
    # The checked key is the form keys are stored in, so look that up (as
    # MessageMap.__contains__ does).
    item = self._key_checker.CheckValue(item)
    return item in self._values

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __setitem__(self, key: _K, value: _V) -> None:
    """Type-checks key and value, stores them and notifies the listener."""
    checked_key = self._key_checker.CheckValue(key)
    checked_value = self._value_checker.CheckValue(value)
    self._values[checked_key] = checked_value
    self._message_listener.Modified()

  def __delitem__(self, key: _K) -> None:
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
    """Copies every entry of other into this map, overwriting existing keys."""
    self._values.update(other._values)
    self._message_listener.Modified()

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class
433
434
class MessageMap(MutableMapping[_K, _V]):
  """Simple, type-checked, dict-like container for message-valued map fields."""

  # Disallows assignment to other attributes.
  __slots__ = ['_key_checker', '_values', '_message_listener',
               '_message_descriptor', '_entry_descriptor']

  def __init__(
      self,
      message_listener: Any,
      message_descriptor: Any,
      key_checker: Any,
      entry_descriptor: Any,
  ) -> None:
    """
    Args:
      message_listener: A MessageListener implementation.
        The MessageMap will call this object's Modified() method when it
        is modified.
      message_descriptor: A Descriptor instance describing the message type
        of the map values. Its _concrete_class is instantiated whenever a
        missing key is accessed.
      key_checker: A type_checkers.ValueChecker instance to run on keys
        inserted into this container.
      entry_descriptor: The MessageDescriptor of a map entry: key and value.
    """
    self._message_listener = message_listener
    self._message_descriptor = message_descriptor
    self._key_checker = key_checker
    self._entry_descriptor = entry_descriptor
    self._values = {}

  def __getitem__(self, key: _K) -> _V:
    """Returns the message for key, creating it if missing.

    A missing key gets a new, empty message wired to this map's listener, and
    the listener's Modified() is called.
    """
    key = self._key_checker.CheckValue(key)
    try:
      return self._values[key]
    except KeyError:
      new_element = self._message_descriptor._concrete_class()
      new_element._SetListener(self._message_listener)
      self._values[key] = new_element
      self._message_listener.Modified()
      return new_element

  def get_or_create(self, key: _K) -> _V:
    """get_or_create() is an alias for getitem (ie. map[key]).

    Args:
      key: The key to get or create in the map.

    This is useful in cases where you want to be explicit that the call is
    mutating the map. This can avoid lint errors for statements like this
    that otherwise would appear to be pointless statements:

      msg.my_map[key]
    """
    return self[key]

  @overload
  def get(self, key: _K) -> Optional[_V]:
    ...

  @overload
  def get(self, key: _K, default: _T) -> Union[_V, _T]:
    ...

  # We need to override this explicitly, because our defaultdict-like behavior
  # will make the default implementation (from our base class) always insert
  # the key.
  def get(self, key, default=None):
    if key in self:
      return self[key]
    else:
      return default

  def __contains__(self, item: _K) -> bool:
    item = self._key_checker.CheckValue(item)
    return item in self._values

  def __setitem__(self, key: _K, value: _V) -> NoReturn:
    raise ValueError('May not set values directly, call my_map[key].foo = 5')

  def __delitem__(self, key: _K) -> None:
    key = self._key_checker.CheckValue(key)
    del self._values[key]
    self._message_listener.Modified()

  def __len__(self) -> int:
    return len(self._values)

  def __iter__(self) -> Iterator[_K]:
    return iter(self._values)

  def __repr__(self) -> str:
    return repr(self._values)

  def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
    """Merges other into this map, replacing entries for shared keys.

    Each of other's values is copied into a fresh message in this map.
    """
    if other is self:
      # Merging a map into itself is a no-op. Without this check the loop
      # below would delete each entry, after which other[key] would return
      # the freshly created replacement, so the original value would be lost
      # while the dict being iterated is mutated.
      return
    # pylint: disable=protected-access
    for key, value in other._values.items():
      # According to documentation: "When parsing from the wire or when merging,
      # if there are duplicate map keys the last key seen is used".
      if key in self:
        del self[key]
      self[key].CopyFrom(value)
    # self._message_listener.Modified() not required here, because
    # mutations to submessages already propagate.

  def InvalidateIterators(self) -> None:
    # It appears that the only way to reliably invalidate iterators to
    # self._values is to ensure that its size changes.
    original = self._values
    self._values = original.copy()
    original[None] = None

  # This is defined in the abstract base, but we can do it much more cheaply.
  def clear(self) -> None:
    self._values.clear()
    self._message_listener.Modified()

  def GetEntryClass(self) -> Any:
    return self._entry_descriptor._concrete_class
554
555
class _UnknownField:
  """A parsed unknown field: its field number, wire type and data."""

  # Disallows assignment to other attributes.
  __slots__ = ['_field_number', '_wire_type', '_data']

  def __init__(self, field_number, wire_type, data):
    self._field_number = field_number
    self._wire_type = wire_type
    self._data = data

  def __lt__(self, other):
    """Orders unknown fields by field number only.

    Returns NotImplemented for other types, so comparing against them raises
    TypeError rather than AttributeError.
    """
    if not isinstance(other, _UnknownField):
      return NotImplemented
    # pylint: disable=protected-access
    return self._field_number < other._field_number

  def __eq__(self, other):
    """Equal when field number, wire type and data all match.

    Returns NotImplemented for other types, so e.g. ``field == None`` is
    False rather than raising AttributeError.
    """
    if self is other:
      return True
    if not isinstance(other, _UnknownField):
      return NotImplemented
    # pylint: disable=protected-access
    return (self._field_number == other._field_number and
            self._wire_type == other._wire_type and
            self._data == other._data)
579
580
class UnknownFieldRef:
  """A reference to one entry of an UnknownFieldSet.

  Stores only the owning set and a position. Every property access re-checks
  that the position still exists before reading the entry, so a reference
  into a set that has since shrunk raises ValueError.
  """

  def __init__(self, parent, index):
    self._parent = parent
    self._index = index

  def _check_valid(self):
    # An empty parent is falsy; either way the referenced entry is gone.
    if not self._parent or self._index >= len(self._parent):
      raise ValueError('UnknownField does not exist. '
                       'The parent message might be cleared.')

  def _entry(self):
    """Validates this reference, then returns the referenced entry."""
    self._check_valid()
    return self._parent._internal_get(self._index)  # pylint: disable=protected-access

  @property
  def field_number(self):
    return self._entry()._field_number  # pylint: disable=protected-access

  @property
  def wire_type(self):
    return self._entry()._wire_type  # pylint: disable=protected-access

  @property
  def data(self):
    return self._entry()._data  # pylint: disable=protected-access
612
613
class UnknownFieldSet:
  """UnknownField container.

  Holds _UnknownField entries in insertion order. Indexing and iteration hand
  out UnknownFieldRef objects rather than the entries themselves. After
  _clear() the set is unusable: len() and indexing raise ValueError.
  """

  # Disallows assignment to other attributes.
  __slots__ = ['_values']

  def __init__(self):
    self._values = []

  def __getitem__(self, index):
    """Returns an UnknownFieldRef for the entry at index.

    Negative indices count from the end, as for lists.

    Raises:
      ValueError: if the set has been cleared.
      IndexError: if index is out of range.
    """
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    size = len(self._values)
    position = index + size if index < 0 else index
    if position < 0 or position >= size:
      # Report the index the caller passed, not the adjusted position.
      raise IndexError('index %d out of range' % index)

    return UnknownFieldRef(self, position)

  def _internal_get(self, index):
    return self._values[index]

  def __len__(self):
    if self._values is None:
      raise ValueError('UnknownFields does not exist. '
                       'The parent message might be cleared.')
    return len(self._values)

  def _add(self, field_number, wire_type, data):
    unknown_field = _UnknownField(field_number, wire_type, data)
    self._values.append(unknown_field)
    return unknown_field

  def __iter__(self):
    for i in range(len(self)):
      yield UnknownFieldRef(self, i)

  def _extend(self, other):
    if other is None:
      return
    # pylint: disable=protected-access
    self._values.extend(other._values)

  def __eq__(self, other):
    """Compares entries irrespective of the order of different field numbers.

    None compares equal to an empty set. Other non-UnknownFieldSet operands
    return NotImplemented instead of raising AttributeError.
    """
    if self is other:
      return True
    # Sort by field number (_UnknownField.__lt__) so the order of different
    # fields does not affect equality; the stable sort keeps the relative
    # order of entries that share a field number.
    values = list(self._values)
    if other is None:
      return not values
    if not isinstance(other, UnknownFieldSet):
      return NotImplemented
    values.sort()
    # pylint: disable=protected-access
    other_values = sorted(other._values)
    return values == other_values

  def _clear(self):
    for value in self._values:
      # pylint: disable=protected-access
      if isinstance(value._data, UnknownFieldSet):
        value._data._clear()  # pylint: disable=protected-access
    self._values = None