1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Code for decoding protocol buffer primitives.
9
10This code is very similar to encoder.py -- read the docs for that module first.
11
12A "decoder" is a function with the signature:
13 Decode(buffer, pos, end, message, field_dict)
14The arguments are:
15 buffer: The string containing the encoded message.
16 pos: The current position in the string.
17 end: The position in the string where the current message ends. May be
18 less than len(buffer) if we're reading a sub-message.
19 message: The message object into which we're parsing.
20 field_dict: message._fields (avoids a hashtable lookup).
21The decoder reads the field and stores it into field_dict, returning the new
22buffer position. A decoder for a repeated field may proactively decode all of
23the elements of that field, if they appear consecutively.
24
25Note that decoders may throw any of the following:
26 IndexError: Indicates a truncated message.
27 struct.error: Unpacking of a fixed-width field failed.
28 message.DecodeError: Other errors.
29
30Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking: it's fine to read past
32"end" as long as you are sure that someone else will notice and throw an
33exception later on.
34
35Something up the call stack is expected to catch IndexError and struct.error
36and convert them to message.DecodeError.
37
38Decoders are constructed using decoder constructors with the signature:
39 MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
40The arguments are:
41 field_number: The field number of the field we want to decode.
42 is_repeated: Is the field a repeated field? (bool)
43 is_packed: Is the field a packed field? (bool)
44 key: The key to use when looking up the field within field_dict.
                 (This is actually the FieldDescriptor. Most decoders treat it
                 as an opaque dict key, but EnumDecoder and StringDecoder read
                 its enum_type and full_name attributes.)
47 new_default: A function which takes a message object as a parameter and
48 returns a new instance of the default value for this field.
49 (This is called for repeated fields and sub-messages, when an
50 instance does not already exist.)
51
52As with encoders, we define a decoder constructor for every type of field.
53Then, for every field of every message class we construct an actual decoder.
54That decoder goes into a dict indexed by tag, so when we decode a message
55we repeatedly read a tag, look up the corresponding decoder, and invoke it.
56"""
57
58__author__ = 'kenton@google.com (Kenton Varda)'
59
60import math
61import numbers
62import struct
63
64from google.protobuf import message
65from google.protobuf.internal import containers
66from google.protobuf.internal import encoder
67from google.protobuf.internal import wire_format
68
69
# Module-level alias for message.DecodeError. This is not for optimization,
# but rather to avoid conflicts with local variables named "message": the
# decoder closures below take a parameter called `message`, which shadows the
# imported module inside their bodies.
_DecodeError = message.DecodeError
73
74
def IsDefaultScalarValue(value):
  """Reports whether a scalar holds the default (zero/empty) value of its type.

  Meant for presence detection of implicit-presence fields, where custom
  defaults are disallowed. A value counts as default when it is falsy
  (e.g. 0, 0.0, '', b'', False). The one exception is negative zero, which
  is reported as non-default.

  Args:
    value: A scalar value to check.

  Returns:
    True if the value is equivalent to a default value, False otherwise.
  """
  # -0.0 is falsy, so plain truthiness would call it a default. Looking at the
  # sign bit first rules it out (other negative numbers are truthy anyway).
  has_sign_bit = (isinstance(value, numbers.Number)
                  and math.copysign(1.0, value) < 0)
  return not has_sign_bit and not value
94
95
96def _VarintDecoder(mask, result_type):
97 """Return an encoder for a basic varint value (does not include tag).
98
99 Decoded values will be bitwise-anded with the given mask before being
100 returned, e.g. to limit them to 32 bits. The returned decoder does not
101 take the usual "end" parameter -- the caller is expected to do bounds checking
102 after the fact (often the caller can defer such checking until later). The
103 decoder returns a (value, new_pos) pair.
104 """
105
106 def DecodeVarint(buffer, pos: int=None):
107 result = 0
108 shift = 0
109 while 1:
110 if pos is None:
111 # Read from BytesIO
112 try:
113 b = buffer.read(1)[0]
114 except IndexError as e:
115 if shift == 0:
116 # End of BytesIO.
117 return None
118 else:
119 raise ValueError('Fail to read varint %s' % str(e))
120 else:
121 b = buffer[pos]
122 pos += 1
123 result |= ((b & 0x7f) << shift)
124 if not (b & 0x80):
125 result &= mask
126 result = result_type(result)
127 return result if pos is None else (result, pos)
128 shift += 7
129 if shift >= 64:
130 raise _DecodeError('Too many bytes when decoding varint.')
131
132 return DecodeVarint
133
134
135def _SignedVarintDecoder(bits, result_type):
136 """Like _VarintDecoder() but decodes signed values."""
137
138 signbit = 1 << (bits - 1)
139 mask = (1 << bits) - 1
140
141 def DecodeVarint(buffer, pos):
142 result = 0
143 shift = 0
144 while 1:
145 b = buffer[pos]
146 result |= ((b & 0x7f) << shift)
147 pos += 1
148 if not (b & 0x80):
149 result &= mask
150 result = (result ^ signbit) - signbit
151 result = result_type(result)
152 return (result, pos)
153 shift += 7
154 if shift >= 64:
155 raise _DecodeError('Too many bytes when decoding varint.')
156 return DecodeVarint
157
# Python has a single arbitrary-precision int type, so 32-bit and 64-bit
# values alike decode to int; the mask / bit width below is what truncates a
# result to the field's declared size.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
165
166
def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The tag's raw bytes are returned rather than its decoded value, so they can
  be used directly as the key for looking up the proper decoder. That trades
  a pure-Python varint decode for a C-level hash lookup of a byte string. In
  a low-level language decoding the varint would be cheaper, but not in
  Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  # A tag is a varint: it ends with the first byte whose high (continuation)
  # bit is clear.
  tag_end = pos + 1
  while buffer[tag_end - 1] & 0x80:
    tag_end += 1
  return buffer[pos:tag_end].tobytes(), tag_end
191
192
def DecodeTag(tag_bytes):
  """Split raw tag bytes into their field number and wire type.

  Args:
    tag_bytes: the bytes of the tag, e.g. as returned by ReadTag().

  Returns:
    Tuple[int, int] of the tag field number and wire type.
  """
  tag_value = _DecodeVarint(tag_bytes, 0)[0]
  return wire_format.UnpackTag(tag_value)
204
205
206# --------------------------------------------------------------------
207
208
def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  The returned SpecificDecoder(field_number, is_repeated, is_packed, key,
  new_default, clear_if_default=False) builds one of three decoders: a
  packed-repeated decoder, a non-packed repeated decoder, or a singular
  decoder. Each one stores its result in field_dict[key] and returns the new
  buffer position.

  Args:
    wire_type: The field's wire type.
    decode_value: A function which decodes an individual value, e.g.
      _DecodeVarint(); it takes (buffer, pos) and returns (value, new_pos).
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(
          buffer, pos, end, message, field_dict, current_depth=0
      ):
        del current_depth  # unused
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # A packed field is a varint byte length followed by the concatenated
        # element encodings; endpoint is where that payload stops.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        # The final element's encoding ran past endpoint, so it was decoded
        # partly from bytes outside the packed payload; drop it and fail.
        if pos > endpoint:
          del value[-1]  # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos

      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(
          buffer, pos, end, message, field_dict, current_depth=0
      ):
        del current_depth  # unused
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while 1:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field: if the next tag_len bytes equal this field's tag, keep
          # decoding here instead of returning to the caller's tag dispatch.
          pos = new_pos + tag_len
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed. Return.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos

      return DecodeRepeatedField
    else:

      def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
        del current_depth  # unused
        (new_value, pos) = decode_value(buffer, pos)
        if pos > end:
          raise _DecodeError('Truncated message.')
        # With clear_if_default, a default value (per IsDefaultScalarValue)
        # removes the key instead of being stored.
        if clear_if_default and IsDefaultScalarValue(new_value):
          field_dict.pop(key, None)
        else:
          field_dict[key] = new_value
        return pos

      return DecodeField

  return SpecificDecoder
281
282
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but passes every decoded value through modify_value.

  In this module modify_value is wire_format.ZigZagDecode (for sint32/sint64)
  or bool (for bool fields).
  """
  # Wrapping _SimpleDecoder is slightly slower than duplicating its code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw_value, new_pos = decode_value(buffer, pos)
    return modify_value(raw_value), new_pos

  return _SimpleDecoder(wire_type, InnerDecode)
295
296
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The struct format string (e.g. '<I') describing a single value.
  """
  packer = struct.Struct(format)
  value_size = packer.size
  unpack = packer.unpack

  # A short buffer makes unpack raise struct.error. That is deliberately not
  # caught here: someone up-stack converts it to _DecodeError, which avoids
  # setting up an exception handler for every value parsed.

  def InnerDecode(buffer, pos):
    new_pos = pos + value_size
    return (unpack(buffer[pos:new_pos])[0], new_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
320
321
def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values are detected and mapped explicitly to math.nan,
  math.inf or -math.inf. This works around a bug in old versions of
  struct.unpack for non-finite floats.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # Little-endian 32-bit value: byte 3 holds the sign bit plus the top 7
    # exponent bits, and the high bit of byte 2 is the last exponent bit.
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos].tobytes()
    top_byte = float_bytes[3:4]

    # All exponent bits set means the value is non-finite. In Python 2.4,
    # struct.unpack turned such values into finite 64-bit floats, so they are
    # decoded by hand.
    if top_byte in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80':
      if float_bytes[0:3] != b'\x00\x00\x80':
        # Some significand bit is set: not a number.
        return (math.nan, new_pos)
      return (-math.inf if top_byte == b'\xFF' else math.inf, new_pos)

    # struct.error on a short buffer is left for someone up-stack to convert
    # into _DecodeError, so no per-value exception handling is needed here.
    return (local_unpack('<f', float_bytes)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
365
366
def _DoubleDecoder():
  """Returns a decoder for a double field.

  NaN encodings are detected and mapped explicitly to math.nan. This works
  around a bug in old versions of struct.unpack for not-a-number.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # Little-endian 64-bit value: byte 7 holds the sign bit plus the top 7
    # exponent bits, and the high nibble of byte 6 holds the remaining 4.
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos].tobytes()

    # All exponent bits set plus any significand bit set is NaN. In Python
    # 2.4, struct.unpack treated such values as inf or -inf, so they are
    # special-cased.
    exponent_all_ones = (double_bytes[7:8] in b'\x7F\xFF'
                         and double_bytes[6:7] >= b'\xF0')
    if exponent_all_ones and (
        double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0'):
      return (math.nan, new_pos)

    # struct.error on a short buffer is left for someone up-stack to convert
    # into _DecodeError, so no per-value exception handling is needed here.
    return (local_unpack('<d', double_bytes)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
405
406
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Decoded numbers that are not in ``key.enum_type.values_by_number`` are not
  stored in the field. Instead they are appended to
  ``message._unknown_fields`` as a ``(tag_bytes, raw_varint_bytes)`` pair,
  using this field number's non-packed varint tag.

  Args:
    field_number: The field number of the enum field.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the field uses the packed encoding.
    key: The field's key in field_dict; must expose ``enum_type``.
    new_default: Callable taking the message and returning a new empty
      container (used by the repeated decoders).
    clear_if_default: For the singular decoder, remove the key from
      field_dict instead of storing a value that IsDefaultScalarValue()
      reports as default.

  Returns:
    A decoder function taking (buffer, pos, end, message, field_dict,
    current_depth=0) and returning the new buffer position; current_depth is
    ignored.
  """
  enum_type = key.enum_type
  # Every unknown value (packed or not) is recorded under the non-packed
  # varint tag. That tag is fixed for this decoder, so compute it once here
  # rather than once per unknown value.
  varint_tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Packed payload: varint byte length, then concatenated varints.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (varint_tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      if pos > endpoint:
        # The last element overran the packed payload; undo whichever list
        # it was recorded in before failing.
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          # pylint: disable=protected-access
          del message._unknown_fields[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField
  elif is_repeated:
    tag_bytes = varint_tag_bytes
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed. Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      """Decode a serialized singular enum value and return the new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and IsDefaultScalarValue(enum_value):
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (varint_tag_bytes, buffer[value_start_pos:pos].tobytes()))
      # pylint: enable=protected-access
      return pos

    return DecodeField
541
542
543# --------------------------------------------------------------------
544
545
# Varint-encoded types. int32/int64 are two's-complement (sign-extended);
# uint32/uint64 are plain unsigned; sint32/sint64 are additionally
# ZigZag-decoded; bool is the truthiness of the decoded varint.
Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint)
SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint32,
                                 wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint,
                                 wire_format.ZigZagDecode)
BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint,
                               bool)

# Fixed-width types. With the '<' prefix Python's struct module uses
# standard sizes and little-endian order on every platform (without it,
# sizes would depend on the C compiler's basic type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()
573
574
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Each value is a varint byte length followed by that many UTF-8 bytes,
  decoded to str. A UnicodeDecodeError is re-raised with the field's
  full_name appended to its reason. Packed encoding is not supported
  (asserted below).
  """

  local_DecodeVarint = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Decode a UTF-8 memoryview slice to str, naming the field on failure."""
    byte_str = memview.tobytes()
    try:
      value = str(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # Add the field name to the error message and re-raise it. Build on
      # e.reason rather than str(e): str(e) already embeds the reason, so
      # formatting str(e) into the reason would print the codec/position
      # prefix twice.
      e.reason = '%s in field: %s' % (e.reason, key.full_name)
      raise

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      del current_depth  # unused
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      # With clear_if_default, an empty string (size 0) is the default and
      # removes the key instead of being stored.
      if clear_if_default and IsDefaultScalarValue(size):
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos

    return DecodeField
633
634
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Each value is a varint byte length followed by that many raw bytes, copied
  out of the buffer with tobytes(). Packed encoding is not supported
  (asserted below). Truncation errors reuse the 'Truncated string.' message.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(buffer[pos:new_pos].tobytes())
        # Predict that the next tag is another copy of the same repeated field.
        # On a match (and not at end) loop again instead of returning to the
        # caller's tag dispatch.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      del current_depth  # unused
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      # With clear_if_default, an empty value (size 0) is the default and
      # removes the key instead of being stored.
      if clear_if_default and IsDefaultScalarValue(size):
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[pos:new_pos].tobytes()
      return new_pos

    return DecodeField
681
682
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group's contents are parsed in place by the sub-message's
  _InternalParse() at one more level of nesting than the enclosing message.
  That parse stops early at an end-group tag. The decoder then requires this
  field's end-group tag at that position, and raises _DecodeError if it is
  missing or if the nesting depth exceeds _recursion_limit.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      # Look the container up once. The loop below only parses into elements
      # added to it, so repeating the lookup per element is redundant.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        current_depth += 1
        if current_depth > _recursion_limit:
          raise _DecodeError(
              'Error parsing message: too many levels of nesting.'
          )
        pos = value.add()._InternalParse(buffer, pos, end, current_depth)
        current_depth -= 1
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      pos = value._InternalParse(buffer, pos, end, current_depth)
      current_depth -= 1
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos

    return DecodeField
743
744
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each sub-message is a varint byte length followed by that many bytes of
  encoded message. It is parsed with the sub-message's _InternalParse() at
  one more level of nesting than the enclosing message. Decoding raises
  _DecodeError when the depth exceeds _recursion_limit, when the length runs
  past `end`, or when the sub-message parse stops before its declared end.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        current_depth += 1
        if current_depth > _recursion_limit:
          raise _DecodeError(
              'Error parsing message: too many levels of nesting.'
          )
        if (
            value.add()._InternalParse(buffer, pos, new_pos, current_depth)
            != new_pos
        ):
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        # Restore the depth so the next sibling element is parsed at the
        # same nesting level as this one.
        current_depth -= 1
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      current_depth -= 1
      return new_pos

    return DecodeField
811
812
813# --------------------------------------------------------------------
814
# Tag (field 1, start-group) that opens every MessageSet item. Items whose
# type_id matches no known extension are stored in message._unknown_fields
# under this tag.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.
  NOTE(review): `descriptor` is not referenced in the body below.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    If type_id names an extension known to `message`, the item's payload is
    parsed into that extension's value in field_dict. Otherwise the item's
    bytes are appended to message._unknown_fields under MESSAGE_SET_ITEM_TAG.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at (presumably just
        after the item's start-group tag -- confirm against the caller).
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data (just past the item's end-group
      tag).

    Raises:
      _DecodeError: on a stray end-group tag, truncation, or a missing
        type_id / message field.
    """
    message_set_item_start = pos
    # -1 marks "not seen yet".
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only remember where the payload is; it is parsed after the loop,
        # once type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        # Skip any other field inside the item.
        field_number, wire_type = DecodeTag(tag_bytes)
        _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
        if pos == -1:
          raise _DecodeError('Unexpected end-group tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        # NOTE(review): `message_factory` is not imported in the visible part
        # of this file; confirm it is bound at module level elsewhere.
        if not hasattr(message_type, '_concrete_class'):
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      # NOTE(review): unlike MessageDecoder, no current_depth is passed here,
      # so the outer nesting depth is not carried into the extension's parse.
      if value._InternalParse(buffer, message_start,message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      # pylint: enable=protected-access

    return pos

  return DecodeItem
905
906
def UnknownMessageSetItemDecoder():
  """Returns a decoder for a Unknown MessageSet item.

  The returned ``DecodeUnknownItem(buffer)`` reads one MessageSet item's
  contents from position 0 of ``buffer`` (a memoryview): the fields that
  follow the item's start-group tag, up to and including its end-group tag.
  Fields other than type_id (2) and message (3) are skipped.

  The returned function gives ``(type_id, message_bytes)``. It raises
  _DecodeError if the item contains a stray end-group tag, is truncated, or
  lacks its type_id or message field.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    pos = 0
    end = len(buffer)
    # -1 marks "not seen yet". type_id must be initialized as well: without
    # it, an item lacking a type_id field reaches the check below with the
    # name unbound and raises UnboundLocalError instead of _DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        field_number, wire_type = DecodeTag(tag_bytes)
        _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
        if pos == -1:
          raise _DecodeError('Unexpected end-group tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem
945
946# --------------------------------------------------------------------
947
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry is encoded as a length-delimited sub-message that has
  ``key`` and ``value`` fields and is tagged with the map field's number. A
  single call to the returned decoder consumes every consecutive entry that
  carries that same tag. It writes each entry into the map container kept in
  ``field_dict``.

  Args:
    field_descriptor: Descriptor of the map field. It is also the key under
      which the container is stored in ``field_dict``.
    new_default: Callable that takes the message being parsed and returns the
      map container to store when none exists yet.
    is_message_map: True when map values are messages. Such values are
      written with ``CopyFrom``; otherwise they are assigned directly.
  """
  entry_tag = encoder.TagBytes(
      field_descriptor.number, wire_format.WIRETYPE_LENGTH_DELIMITED
  )
  entry_tag_len = len(entry_tag)
  decode_varint = _DecodeVarint
  # The entry class may not exist yet when this decoder is built, so only the
  # descriptor is captured here and ``_concrete_class`` is resolved per call.
  entry_descriptor = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
    del current_depth  # Unused.
    entry = entry_descriptor._concrete_class()
    container = field_dict.get(field_descriptor)
    if container is None:
      container = field_dict.setdefault(field_descriptor, new_default(message))
    while True:
      # Each entry is a varint length followed by that many bytes.
      entry_len, entry_start = decode_varint(buffer, pos)
      entry_end = entry_start + entry_len
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      # A single scratch entry object is cleared and reused for every element.
      entry.Clear()
      if entry._InternalParse(buffer, entry_start, entry_end) != entry_end:
        # _InternalParse only stops short when it meets an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Guess that another entry of this same map follows. If it does not,
      # hand the position back so the caller can dispatch the next tag.
      pos = entry_end + entry_tag_len
      if buffer[entry_end:pos] != entry_tag or entry_end == end:
        return entry_end

  return DecodeMap
990
991
def _DecodeFixed64(buffer, pos):
  """Decode a little-endian unsigned 64-bit integer starting at ``pos``.

  Returns:
    A ``(value, new_pos)`` tuple where ``new_pos`` is ``pos + 8``.

  Raises:
    struct.error: If fewer than 8 bytes are available at ``pos``.
  """
  stop = pos + 8
  (value,) = struct.unpack('<Q', buffer[pos:stop])
  return value, stop
996
997
def _DecodeFixed32(buffer, pos):
  """Decode a little-endian unsigned 32-bit integer starting at ``pos``.

  Returns:
    A ``(value, new_pos)`` tuple where ``new_pos`` is ``pos + 4``.

  Raises:
    struct.error: If fewer than 4 bytes are available at ``pos``.
  """
  stop = pos + 4
  (value,) = struct.unpack('<I', buffer[pos:stop])
  return value, stop


DEFAULT_RECURSION_LIMIT = 100
# Current nesting limit read by the decoders in this module (for example, when
# decoding nested groups of unknown fields). Change it via SetRecursionLimit().
_recursion_limit = DEFAULT_RECURSION_LIMIT


def SetRecursionLimit(new_limit):
  """Replaces the module-wide nesting limit used while decoding.

  Args:
    new_limit: The new limit. It is stored as given; no validation is done.
  """
  global _recursion_limit
  _recursion_limit = new_limit
1010
1011
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
  """Decode a run of unknown fields into an UnknownFieldSet.

  Fields are read starting at ``pos``. Decoding stops when ``end_pos`` is
  reached or when an end-group tag is consumed. With ``end_pos=None``, only an
  end-group tag (or an error) ends the loop.

  Args:
    buffer: The encoded data.
    pos: Position of the first tag to read.
    end_pos: Position where the enclosing data ends, or None.
    current_depth: Current nesting depth; forwarded to each field decoder.

  Returns:
    A ``(unknown_field_set, new_pos)`` tuple. If an end-group tag stopped
    decoding, ``new_pos`` is just past that tag, which is not added to the set.
  """
  field_set = containers.UnknownFieldSet()
  while end_pos is None or pos < end_pos:
    tag_bytes, pos = ReadTag(buffer, pos)
    tag_value, _ = _DecodeVarint(tag_bytes, 0)
    field_number, wire_type = wire_format.UnpackTag(tag_value)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      # Callers that opened a group (e.g. _DecodeUnknownField) check that this
      # end tag belongs to the group they started.
      break
    data, pos = _DecodeUnknownField(
        buffer, pos, end_pos, field_number, wire_type, current_depth
    )
    field_set._add(field_number, wire_type, data)  # pylint: disable=protected-access

  return field_set, pos
1029
1030
def _DecodeUnknownField(
    buffer, pos, end_pos, field_number, wire_type, current_depth=0
):
  """Decode one unknown field whose tag has already been consumed.

  Args:
    buffer: The encoded data. Slices of it must support ``.tobytes()``
      (e.g. a ``memoryview``) for length-delimited fields.
    pos: Position just past the field's tag.
    end_pos: Position where the enclosing data ends, or None. With None,
      truncation is checked against ``len(buffer)`` instead.
    field_number: Field number taken from the already-decoded tag.
    wire_type: Wire type taken from the already-decoded tag.
    current_depth: Current nesting depth, checked against the module's
      recursion limit when a group is entered.

  Returns:
    A ``(data, new_pos)`` tuple. ``data`` is the decoded value for varint and
    fixed-width fields, ``bytes`` for length-delimited fields, and the
    UnknownFieldSet built by _DecodeUnknownFieldSet for groups. For an
    end-group tag, ``(0, -1)`` is returned so that callers can detect it.

  Raises:
    _DecodeError: On truncation, too-deep group nesting, a missing group end
      tag, or an unrecognised wire type.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    # The slice is silently short if the data is truncated; the bound check
    # below catches that case.
    data = buffer[pos:pos + size].tobytes()
    pos += size
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    end_tag_bytes = encoder.TagBytes(
        field_number, wire_format.WIRETYPE_END_GROUP
    )
    nested_depth = current_depth + 1
    if nested_depth >= _recursion_limit:
      raise _DecodeError('Error parsing message: too many levels of nesting.')
    data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos, nested_depth)
    # The nested set stops just past an end-group tag (or at end_pos), so the
    # bytes right before pos must be this group's own end tag.
    # NOTE(review): if the group runs to end_pos without any end tag, trailing
    # field bytes that happen to equal end_tag_bytes still pass this check.
    if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes:
      raise _DecodeError('Missing group end tag.')
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  # _DecodeUnknownFieldSet accepts end_pos=None and forwards it here.
  # Comparing an int with None would raise TypeError, so fall back to the
  # buffer length as the bound.
  limit = len(buffer) if end_pos is None else end_pos
  if pos > limit:
    raise _DecodeError('Truncated message.')

  return (data, pos)