1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Code for decoding protocol buffer primitives.
9
10This code is very similar to encoder.py -- read the docs for that module first.
11
12A "decoder" is a function with the signature:
13 Decode(buffer, pos, end, message, field_dict)
14The arguments are:
15 buffer: The string containing the encoded message.
16 pos: The current position in the string.
17 end: The position in the string where the current message ends. May be
18 less than len(buffer) if we're reading a sub-message.
19 message: The message object into which we're parsing.
20 field_dict: message._fields (avoids a hashtable lookup).
21The decoder reads the field and stores it into field_dict, returning the new
22buffer position. A decoder for a repeated field may proactively decode all of
23the elements of that field, if they appear consecutively.
24
25Note that decoders may throw any of the following:
26 IndexError: Indicates a truncated message.
27 struct.error: Unpacking of a fixed-width field failed.
28 message.DecodeError: Other errors.
29
30Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking: it's fine to read past
32"end" as long as you are sure that someone else will notice and throw an
33exception later on.
34
35Something up the call stack is expected to catch IndexError and struct.error
36and convert them to message.DecodeError.
37
38Decoders are constructed using decoder constructors with the signature:
39 MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
40The arguments are:
41 field_number: The field number of the field we want to decode.
42 is_repeated: Is the field a repeated field? (bool)
43 is_packed: Is the field a packed field? (bool)
44 key: The key to use when looking up the field within field_dict.
45 (This is actually the FieldDescriptor but nothing in this
46 file should depend on that.)
47 new_default: A function which takes a message object as a parameter and
48 returns a new instance of the default value for this field.
49 (This is called for repeated fields and sub-messages, when an
50 instance does not already exist.)
51
52As with encoders, we define a decoder constructor for every type of field.
53Then, for every field of every message class we construct an actual decoder.
54That decoder goes into a dict indexed by tag, so when we decode a message
55we repeatedly read a tag, look up the corresponding decoder, and invoke it.
56"""
57
58__author__ = 'kenton@google.com (Kenton Varda)'
59
60import math
61import struct
62
63from google.protobuf import message
64from google.protobuf.internal import containers
65from google.protobuf.internal import encoder
66from google.protobuf.internal import wire_format
67
68
# Alias for message.DecodeError, used throughout this module.
# This is not for optimization, but rather to avoid conflicts with local
# variables named "message": the decoder closures below take a parameter
# called `message`, which shadows the imported module inside them.
_DecodeError = message.DecodeError
72
73
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic (unsigned) varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits. The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds
  checking after the fact (often the caller can defer such checking until
  later).

  The returned decoder has two calling modes:

    * DecodeVarint(buffer, pos): ``buffer`` is indexable (e.g. a memoryview
      or bytes) and the decoder returns a (value, new_pos) pair. Running off
      the end of ``buffer`` raises IndexError.
    * DecodeVarint(stream): ``pos`` is omitted and ``stream`` is a binary
      file-like object (e.g. io.BytesIO) read one byte at a time with
      ``stream.read(1)``. Only the value is returned. Returns None if the
      stream is exhausted before the first byte, and raises ValueError if it
      ends in the middle of a varint.

  Args:
    mask: Bit mask applied to the decoded value.
    result_type: Callable used to convert the masked value (e.g. int).

  Returns:
    The DecodeVarint function described above. In both modes it raises
    message.DecodeError if the varint is longer than 10 bytes.
  """

  def _DecodeFromStream(stream):
    # Stream mode: consume the varint one byte at a time.
    result = 0
    shift = 0
    while 1:
      try:
        b = stream.read(1)[0]
      except IndexError as e:
        if shift == 0:
          # End of stream before any byte of a new varint.
          return None
        raise ValueError('Fail to read varint %s' % str(e)) from e
      result |= ((b & 0x7f) << shift)
      if not (b & 0x80):
        result &= mask
        return result_type(result)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')

  def DecodeVarint(buffer, pos: 'int | None' = None):
    if pos is None:
      return _DecodeFromStream(buffer)
    # Indexed mode (the hot path): no per-byte mode check.
    result = 0
    shift = 0
    while 1:
      b = buffer[pos]
      pos += 1
      result |= ((b & 0x7f) << shift)
      if not (b & 0x80):
        result &= mask
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')

  return DecodeVarint
111
112
def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes two's-complement signed values.

  Args:
    bits: Width of the signed result (64 or 32 in this module). The decoded
      value is truncated to this many bits and then sign-extended.
    result_type: Callable applied to the final signed value.

  Returns:
    A function DecodeVarint(buffer, pos) returning a (value, new_pos) pair.
    It raises message.DecodeError if the varint is longer than 10 bytes, and
    IndexError if it runs off the end of ``buffer``.
  """
  width_mask = (1 << bits) - 1
  sign_bit = 1 << (bits - 1)
  modulus = 1 << bits

  def DecodeVarint(buffer, pos):
    value = 0
    offset = 0
    while True:
      byte = buffer[pos]
      pos += 1
      value |= (byte & 0x7f) << offset
      if not (byte & 0x80):
        value &= width_mask
        # Sign-extend: a set top bit means the value is negative.
        if value & sign_bit:
          value -= modulus
        return (result_type(value), pos)
      offset += 7
      if offset >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')

  return DecodeVarint
135
# Full-width decoders: every 32-bit and 64-bit value is returned as a plain
# Python int, masked (unsigned) or sign-extended (signed) to 64 bits.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Narrow decoders for values that must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
143
144
def ReadTag(buffer, pos):
  """Read a tag from the memoryview, returning a (tag_bytes, new_pos) tuple.

  The tag is returned as its raw encoded bytes instead of a decoded integer.
  Decoders are looked up by those bytes, which trades a varint decode done
  in pure Python for a byte-string hash-table lookup done in C. In a
  low-level language decoding the varint would be cheaper, but not in
  Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.

  Raises:
    IndexError: if the buffer ends before the tag's final byte.
  """
  tag_end = pos
  # Skip every byte with the continuation bit set; the first byte without it
  # is the tag's last byte.
  while buffer[tag_end] & 0x80:
    tag_end += 1
  tag_end += 1
  return buffer[pos:tag_end].tobytes(), tag_end
169
170
171# --------------------------------------------------------------------
172
173
def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
    wire_type: The field's wire type; used to build the tag bytes that the
      repeated-field decoder expects between consecutive elements.
    decode_value: A function which decodes an individual value, e.g.
      _DecodeVarint(). It takes (buffer, pos) and returns (value, new_pos).

  Returns:
    A function SpecificDecoder(field_number, is_repeated, is_packed, key,
    new_default, clear_if_default=False) that builds the decoder for one
    field, following the decoder protocol described in this module's
    docstring.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if not is_packed and not is_repeated:
      def DecodeField(buffer, pos, end, message, field_dict):
        (decoded, new_pos) = decode_value(buffer, pos)
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        if clear_if_default and not decoded:
          # A falsy value with clear_if_default removes any stored value
          # instead of storing it.
          field_dict.pop(key, None)
        else:
          field_dict[key] = decoded
        return new_pos
      return DecodeField

    if not is_packed:
      expected_tag = encoder.TagBytes(field_number, wire_type)
      expected_tag_len = len(expected_tag)

      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        container = field_dict.get(key)
        if container is None:
          container = field_dict.setdefault(key, new_default(message))
        while True:
          (element, element_end) = decode_value(buffer, pos)
          container.append(element)
          # Optimistically assume the next tag repeats this same field.
          pos = element_end + expected_tag_len
          if element_end >= end or buffer[element_end:pos] != expected_tag:
            # Guess was wrong (or we are at the end); hand control back.
            if element_end > end:
              raise _DecodeError('Truncated message.')
            return element_end
      return DecodeRepeatedField

    read_length = _DecodeVarint

    def DecodePackedField(buffer, pos, end, message, field_dict):
      container = field_dict.get(key)
      if container is None:
        container = field_dict.setdefault(key, new_default(message))
      (payload_len, pos) = read_length(buffer, pos)
      payload_end = pos + payload_len
      if payload_end > end:
        raise _DecodeError('Truncated message.')
      while pos < payload_end:
        (element, pos) = decode_value(buffer, pos)
        container.append(element)
      if pos > payload_end:
        # The final element ran past the declared payload; drop it.
        del container[-1]
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField

  return SpecificDecoder
235
236
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but passes every decoded value through modify_value.

  Args:
    wire_type: The field's wire type.
    decode_value: A function (buffer, pos) -> (value, new_pos).
    modify_value: Applied to each decoded value before it is stored, e.g.
      wire_format.ZigZagDecode or bool.

  Returns:
    A decoder constructor, as returned by _SimpleDecoder.
  """
  # Wrapping decode_value and reusing _SimpleDecoder costs slightly more than
  # duplicating its code, but not enough to matter.

  def ModifiedDecode(buffer, pos):
    raw_value, next_pos = decode_value(buffer, pos)
    return (modify_value(raw_value), next_pos)

  return _SimpleDecoder(wire_type, ModifiedDecode)
249
250
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The format string to pass to struct.unpack(); it also fixes how
      many bytes each value occupies (via struct.calcsize).

  Returns:
    A decoder constructor, as returned by _SimpleDecoder.
  """
  width = struct.calcsize(format)
  unpack = struct.unpack

  # Going through _SimpleDecoder costs a little compared to a hand-copied
  # decoder, but not enough to matter.

  def UnpackOne(buffer, pos):
    # A short slice makes unpack raise struct.error; callers up the stack
    # translate that into DecodeError, so no try/except is set up per value.
    next_pos = pos + width
    value = unpack(format, buffer[pos:next_pos])[0]
    return (value, next_pos)

  return _SimpleDecoder(wire_type, UnpackOne)
274
275
def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values (inf, -inf, NaN) are recognized from their bit
  pattern and mapped to the math constants by hand. This works around a bug
  in struct.unpack for such values (in Python 2.4 it converted them to a
  finite value).
  """
  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    end_pos = pos + 4
    raw = buffer[pos:end_pos].tobytes()
    # Little-endian IEEE-754 single precision: byte 3 holds the sign bit and
    # the top seven exponent bits; the high bit of byte 2 is the last
    # exponent bit. All exponent bits set means the value is inf or NaN.
    high_byte = raw[3:4]
    if high_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80':
      if raw[0:3] != b'\x00\x00\x80':
        # Some significand bit is set.
        return (math.nan, end_pos)
      if high_byte == b'\xFF':
        # Sign bit is set.
        return (-math.inf, end_pos)
      return (math.inf, end_pos)
    # A short slice makes unpack raise struct.error, which callers up the
    # stack translate into DecodeError.
    return (unpack('<f', raw)[0], end_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
319
320
def _DoubleDecoder():
  """Returns a decoder for a double field.

  NaN bit patterns are recognized by hand and mapped to math.nan. This works
  around a bug in struct.unpack for not-a-number (in Python 2.4 it produced
  inf or -inf).
  """
  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    end_pos = pos + 8
    raw = buffer[pos:end_pos].tobytes()
    # Little-endian IEEE-754 double precision: byte 7 holds the sign bit and
    # the top seven exponent bits; the high nibble of byte 6 holds the other
    # four. All exponent bits set plus a non-zero significand means NaN.
    is_nan = (raw[7:8] in b'\x7F\xFF'
              and raw[6:7] >= b'\xF0'
              and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')
    if is_nan:
      return (math.nan, end_pos)
    # A short slice makes unpack raise struct.error, which callers up the
    # stack translate into DecodeError.
    return (unpack('<d', raw)[0], end_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
359
360
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values whose number is not in ``key.enum_type.values_by_number`` are not
  stored in the field. Instead their raw varint bytes are appended, together
  with a VARINT tag for this field, to ``message._unknown_fields``.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Is the field a packed field? (bool)
    key: The field's descriptor. Its ``enum_type`` identifies known values,
      and it is the key used within field_dict.
    new_default: Called with the message to create the repeated container
      when it does not exist yet.
    clear_if_default: Singular fields only. If True, a decoded value of 0
      removes the field from field_dict instead of storing it.

  Returns:
    A decoder function following the protocol in this module's docstring.
  """
  enum_type = key.enum_type
  # Unknown values are recorded under a VARINT tag for this field in every
  # mode (including packed), so build those tag bytes once here rather than
  # once per unknown value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      if pos > endpoint:
        # The last element ran past the packed payload: undo whichever
        # append recorded it before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]  # pylint: disable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed. Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      # Checked before enum membership: with clear_if_default, 0 always
      # clears the field.
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
      # pylint: enable=protected-access
      return pos
    return DecodeField
484
485
486# --------------------------------------------------------------------
487
488
# Varint-encoded integer types. The signed variants sign-extend the decoded
# value; the sint variants additionally pass it through
# wire_format.ZigZagDecode.
Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint)
SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint32,
                                 wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint,
                                 wire_format.ZigZagDecode)

# Fixed-width types. With the '<' prefix, struct formats are little-endian
# and use standard sizes on every platform (without a prefix, sizes would
# follow the C compiler's native types).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# bool travels as a varint; any non-zero value converts to True.
BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint,
                               bool)
516
517
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Payloads are decoded as UTF-8. Invalid UTF-8 raises UnicodeDecodeError
  whose ``reason`` is extended with the field's full name.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Must be False; strings cannot be packed.
    key: The field's descriptor (its ``full_name`` is used in error messages)
      and the key used within field_dict.
    new_default: Called with the message to create the repeated container
      when it does not exist yet.
    clear_if_default: Singular fields only. If True, an empty payload removes
      the field from field_dict instead of storing ''.

  Returns:
    A decoder function following the protocol in this module's docstring.
  """

  local_DecodeVarint = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Convert a memoryview slice of UTF-8 bytes to str."""
    byte_str = memview.tobytes()
    try:
      value = str(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # Add the field name to the error message and re-raise it. Extend
      # e.reason (not str(e)): str(e) already embeds the reason, so using it
      # here would repeat the codec/position text in the final message.
      e.reason = '%s in field: %s' % (e.reason, key.full_name)
      raise

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos
    return DecodeField
569
570
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Must be False; bytes fields cannot be packed.
    key: The key to use when looking up the field within field_dict.
    new_default: Called with the message to create the repeated container
      when it does not exist yet.
    clear_if_default: Singular fields only. If True, an empty payload removes
      the field from field_dict instead of storing b''.

  Returns:
    A decoder following this module's decoder protocol. Values are stored as
    ``bytes`` copies of the payload.
  """
  read_size = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, start) = read_size(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[start:stop].tobytes()
      return stop
    return DecodeField

  expected_tag = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
  expected_tag_len = len(expected_tag)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      (size, start) = read_size(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(buffer[start:stop].tobytes())
      # Guess that the next tag repeats this field; stop if it does not.
      pos = stop + expected_tag_len
      if stop == end or buffer[stop:pos] != expected_tag:
        return stop
  return DecodeRepeatedField
610
611
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  The group's fields are parsed by the sub-message's _InternalParse. The
  position it returns must hold the END_GROUP tag for ``field_number``.

  Args:
    field_number: The group's field number.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Must be False; groups cannot be packed.
    key: The key to use when looking up the field within field_dict.
    new_default: Called with the message to create the sub-message (or the
      repeated container) when it does not exist yet.

  Returns:
    A decoder following this module's decoder protocol.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # The container fetched above stays in field_dict for the whole loop,
      # so it is not looked up again per element.
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
657
658
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each occurrence is a length-delimited payload parsed with the sub-message's
  _InternalParse, which must consume exactly the declared length.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Must be False; message fields cannot be packed.
    key: The key to use when looking up the field within field_dict.
    new_default: Called with the message to create the sub-message (or the
      repeated container) when it does not exist yet.

  Returns:
    A decoder following this module's decoder protocol.
  """
  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      # _InternalParse only stops short of `stop` when it meets an end-group
      # tag, which cannot legally appear inside a length-delimited message.
      if submessage._InternalParse(buffer, start, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')
      return stop
    return DecodeField

  expected_tag = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
  expected_tag_len = len(expected_tag)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      # See DecodeField: an early return means a stray end-group tag.
      if container.add()._InternalParse(buffer, start, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')
      # Guess that the next tag repeats this field; stop if it does not.
      pos = stop + expected_tag_len
      if stop == end or buffer[stop:pos] != expected_tag:
        return stop
  return DecodeRepeatedField
707
708
709# --------------------------------------------------------------------
710
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)


def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor; the returned decoder does not use
  it.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Items whose type_id matches an extension found via
  ``message.Extensions._FindExtensionByNumber`` are parsed into that
  extension. All other items are kept verbatim in ``message._unknown_fields``
  under MESSAGE_SET_ITEM_TAG.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.

    Raises:
      DecodeError: if the item is truncated, a foreign end-group tag appears,
        type_id or message is missing, or the embedded message ends early.
    """
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not imported at module level, so import it on
          # first use. (Presumably a top-level import would form an import
          # cycle -- TODO confirm before hoisting.)
          # pylint: disable=g-import-not-at-top
          from google.protobuf import message_factory
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
    # pylint: enable=protected-access

    return pos

  return DecodeItem
801
802
def UnknownMessageSetItemDecoder():
  """Returns a decoder for a Unknown MessageSet item.

  The returned DecodeUnknownItem(buffer) expects ``buffer`` (a memoryview)
  to hold one item's fields followed by the item's end-group tag. This is
  the form MessageSetItemDecoder stores in ``_unknown_fields``. It returns a
  (type_id, message_bytes) tuple.

  Raises (from DecodeUnknownItem):
    DecodeError: if the item is truncated, a foreign end-group tag appears,
      or the type_id or message field is missing.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    pos = 0
    end = len(buffer)
    # Sentinels for "not seen yet". type_id must be initialized too;
    # otherwise an item without a type_id field would hit an
    # UnboundLocalError below instead of the intended DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem
840
841# --------------------------------------------------------------------
842
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry is parsed into one reusable instance of the entry message
  class, and its key/value are then copied into the map container.

  Args:
    field_descriptor: The map field's descriptor; also the key used within
      field_dict.
    new_default: Called with the message to create the map container when it
      does not exist yet.
    is_message_map: True if map values are messages. Such values are copied
      with CopyFrom instead of being assigned.

  Returns:
    A decoder following this module's decoder protocol.
  """
  key = field_descriptor
  expected_tag = encoder.TagBytes(field_descriptor.number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
  expected_tag_len = len(expected_tag)
  read_length = _DecodeVarint
  # _concrete_class may not be initialized yet, so it is only read at decode
  # time.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    entry = entry_type._concrete_class()
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      entry.Clear()
      # _InternalParse only stops short of `stop` when it meets an
      # end-group tag.
      if entry._InternalParse(buffer, start, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Guess that the next tag is another entry of this map.
      pos = stop + expected_tag_len
      if stop == end or buffer[stop:pos] != expected_tag:
        return stop

  return DecodeMap
884
885# --------------------------------------------------------------------
886# Optimization is not as heavy here because calls to SkipField() are rare,
887# except for handling end-group tags.
888
def _SkipVarint(buffer, pos, end):
  """Skip a varint value. Returns the new position.

  Raises:
    TypeError: if the buffer ends before the varint's last byte (ord(b'')
      fails). python_message.py turns this into a 'Truncated message' error.
    DecodeError: if the varint extends past ``end``.
  """
  # Slicing a memoryview never raises IndexError; past the end the slice is
  # empty and ord(b'') raises TypeError instead.
  while True:
    current = ord(buffer[pos:pos + 1].tobytes())
    pos += 1
    if not current & 0x80:
      break
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos
900
def _SkipFixed64(buffer, pos, end):
  """Skip a fixed64 value (8 bytes). Returns the new position.

  Raises:
    DecodeError: if the value would extend past ``end``.
  """
  new_pos = pos + 8
  if new_pos <= end:
    return new_pos
  raise _DecodeError('Truncated message.')
908
909
def _DecodeFixed64(buffer, pos):
  """Decode a little-endian unsigned fixed64 starting at ``pos``.

  Returns:
    A (value, new_pos) tuple. struct.error is raised if fewer than 8 bytes
    remain.
  """
  new_pos = pos + 8
  (value,) = struct.unpack('<Q', buffer[pos:new_pos])
  return (value, new_pos)
914
915
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value. Returns the new position.

  Raises:
    DecodeError: if the payload would extend past ``end``.
  """
  (length, payload_start) = _DecodeVarint(buffer, pos)
  payload_end = payload_start + length
  if payload_end <= end:
    return payload_end
  raise _DecodeError('Truncated message.')
924
925
def _SkipGroup(buffer, pos, end):
  """Skip a sub-group's fields, including its end-group tag.

  Returns:
    The position just past the end-group tag.
  """
  while True:
    (tag_bytes, after_tag) = ReadTag(buffer, pos)
    after_field = SkipField(buffer, after_tag, end, tag_bytes)
    if after_field == -1:
      # SkipField reports an end-group tag with -1; its tag is already read.
      return after_tag
    pos = after_field
935
936
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.

  Fields are read until ``end_pos`` is reached (when given) or an END_GROUP
  tag is read. The END_GROUP tag itself is consumed but not recorded.
  """
  field_set = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    (tag_bytes, pos) = ReadTag(buffer, pos)
    (tag, _) = _DecodeVarint(tag_bytes, 0)
    field_number, wire_type = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
    field_set._add(field_number, wire_type, data)  # pylint: disable=protected-access
  return (field_set, pos)
952
953
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode a unknown field. Returns the UnknownField and new position.

  Returns:
    A (data, new_pos) tuple. ``data`` is an int for varint/fixed types, bytes
    for length-delimited values, and an UnknownFieldSet for groups. An
    END_GROUP wire type yields (0, -1).

  Raises:
    DecodeError: for an unrecognized wire type.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    return _DecodeVarint(buffer, pos)
  if wire_type == wire_format.WIRETYPE_FIXED64:
    return _DecodeFixed64(buffer, pos)
  if wire_type == wire_format.WIRETYPE_FIXED32:
    return _DecodeFixed32(buffer, pos)
  if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, start) = _DecodeVarint(buffer, pos)
    stop = start + size
    return (buffer[start:stop].tobytes(), stop)
  if wire_type == wire_format.WIRETYPE_START_GROUP:
    return _DecodeUnknownFieldSet(buffer, pos)
  if wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  raise _DecodeError('Wrong wire type in tag.')
975
976
def _EndGroup(buffer, pos, end):
  """Skipper for END_GROUP tags.

  Returns:
    -1, which tells the calling loop that the enclosing group has ended.
  """
  del buffer, pos, end  # Unused; kept to match the other skippers' signature.
  return -1
981
982
def _SkipFixed32(buffer, pos, end):
  """Skip a fixed32 value (4 bytes). Returns the new position.

  Raises:
    DecodeError: if the value would extend past ``end``.
  """
  new_pos = pos + 4
  if new_pos <= end:
    return new_pos
  raise _DecodeError('Truncated message.')
990
991
def _DecodeFixed32(buffer, pos):
  """Decode a little-endian unsigned fixed32 starting at ``pos``.

  Returns:
    A (value, new_pos) tuple. struct.error is raised if fewer than 4 bytes
    remain.
  """
  new_pos = pos + 4
  (value,) = struct.unpack('<I', buffer[pos:new_pos])
  return (value, new_pos)
997
998
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types. Raises an exception.

  Takes the same (buffer, pos, end) arguments as the other skippers so it can
  fill the invalid slots of SkipField's dispatch table. The arguments are
  ignored.

  Raises:
    DecodeError: always.
  """

  raise _DecodeError('Tag had invalid wire type.')
1003
def _FieldSkipper():
  """Constructs the SkipField function.

  Returns:
    SkipField(buffer, pos, end, tag_bytes), which dispatches on the wire type
    encoded in ``tag_bytes`` to the matching skip helper.
  """
  # Indexed by wire type.
  skippers = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
      The new position (after the tag value), or -1 if the tag is an end-group
      tag (in which case the calling loop should break).
    """
    # Varints are little-endian, so the wire type (the tag's low bits) always
    # sits in the first byte.
    wire_type = ord(tag_bytes[0:1]) & type_mask
    skipper = skippers[wire_type]
    return skipper(buffer, pos, end)

  return SkipField


SkipField = _FieldSkipper()