1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Code for decoding protocol buffer primitives.
9
10This code is very similar to encoder.py -- read the docs for that module first.
11
12A "decoder" is a function with the signature:
13 Decode(buffer, pos, end, message, field_dict)
14The arguments are:
15 buffer: The string containing the encoded message.
16 pos: The current position in the string.
17 end: The position in the string where the current message ends. May be
18 less than len(buffer) if we're reading a sub-message.
19 message: The message object into which we're parsing.
20 field_dict: message._fields (avoids a hashtable lookup).
21The decoder reads the field and stores it into field_dict, returning the new
22buffer position. A decoder for a repeated field may proactively decode all of
23the elements of that field, if they appear consecutively.
24
25Note that decoders may throw any of the following:
26 IndexError: Indicates a truncated message.
27 struct.error: Unpacking of a fixed-width field failed.
28 message.DecodeError: Other errors.
29
30Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking: it's fine to read past
32"end" as long as you are sure that someone else will notice and throw an
33exception later on.
34
35Something up the call stack is expected to catch IndexError and struct.error
36and convert them to message.DecodeError.
37
38Decoders are constructed using decoder constructors with the signature:
39 MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
40The arguments are:
41 field_number: The field number of the field we want to decode.
42 is_repeated: Is the field a repeated field? (bool)
43 is_packed: Is the field a packed field? (bool)
44 key: The key to use when looking up the field within field_dict.
45 (This is actually the FieldDescriptor but nothing in this
46 file should depend on that.)
47 new_default: A function which takes a message object as a parameter and
48 returns a new instance of the default value for this field.
49 (This is called for repeated fields and sub-messages, when an
50 instance does not already exist.)
51
52As with encoders, we define a decoder constructor for every type of field.
53Then, for every field of every message class we construct an actual decoder.
54That decoder goes into a dict indexed by tag, so when we decode a message
55we repeatedly read a tag, look up the corresponding decoder, and invoke it.
56"""
57
# Module authorship metadata.
__author__ = 'kenton@google.com (Kenton Varda)'
59
60import math
61import struct
62
63from google.protobuf.internal import containers
64from google.protobuf.internal import encoder
65from google.protobuf.internal import wire_format
66from google.protobuf import message
67
68
# Private alias for message.DecodeError. This is not for optimization, but
# rather to avoid conflicts with local variables named "message": the decoders
# below take a `message` parameter that would shadow the imported module.
_DecodeError = message.DecodeError
72
73
74def _VarintDecoder(mask, result_type):
75 """Return an encoder for a basic varint value (does not include tag).
76
77 Decoded values will be bitwise-anded with the given mask before being
78 returned, e.g. to limit them to 32 bits. The returned decoder does not
79 take the usual "end" parameter -- the caller is expected to do bounds checking
80 after the fact (often the caller can defer such checking until later). The
81 decoder returns a (value, new_pos) pair.
82 """
83
84 def DecodeVarint(buffer, pos):
85 result = 0
86 shift = 0
87 while 1:
88 b = buffer[pos]
89 result |= ((b & 0x7f) << shift)
90 pos += 1
91 if not (b & 0x80):
92 result &= mask
93 result = result_type(result)
94 return (result, pos)
95 shift += 7
96 if shift >= 64:
97 raise _DecodeError('Too many bytes when decoding varint.')
98 return DecodeVarint
99
100
101def _SignedVarintDecoder(bits, result_type):
102 """Like _VarintDecoder() but decodes signed values."""
103
104 signbit = 1 << (bits - 1)
105 mask = (1 << bits) - 1
106
107 def DecodeVarint(buffer, pos):
108 result = 0
109 shift = 0
110 while 1:
111 b = buffer[pos]
112 result |= ((b & 0x7f) << shift)
113 pos += 1
114 if not (b & 0x80):
115 result &= mask
116 result = (result ^ signbit) - signbit
117 result = result_type(result)
118 return (result, pos)
119 shift += 7
120 if shift >= 64:
121 raise _DecodeError('Too many bytes when decoding varint.')
122 return DecodeVarint
123
# All 32-bit and 64-bit values are represented as Python int. The 64-bit
# decoders keep the low 64 bits of the varint.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
131
132
def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The raw bytes of the tag are returned undecoded so they can be used
  directly as a lookup key for the proper decoder. This trades work that
  would be done in pure Python (decoding a varint) for work that is done in C
  (hashing a byte string). In a low-level language it would be cheaper to
  decode the varint, but not in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_start = pos
  # Step over every byte with its continuation (high) bit set; the first byte
  # without it is the tag's final byte.
  while buffer[pos] & 0x80:
    pos += 1
  tag_end = pos + 1
  return buffer[tag_start:tag_end].tobytes(), tag_end
157
158
159# --------------------------------------------------------------------
160
161
162def _SimpleDecoder(wire_type, decode_value):
163 """Return a constructor for a decoder for fields of a particular type.
164
165 Args:
166 wire_type: The field's wire type.
167 decode_value: A function which decodes an individual value, e.g.
168 _DecodeVarint()
169 """
170
171 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
172 clear_if_default=False):
173 if is_packed:
174 local_DecodeVarint = _DecodeVarint
175 def DecodePackedField(buffer, pos, end, message, field_dict):
176 value = field_dict.get(key)
177 if value is None:
178 value = field_dict.setdefault(key, new_default(message))
179 (endpoint, pos) = local_DecodeVarint(buffer, pos)
180 endpoint += pos
181 if endpoint > end:
182 raise _DecodeError('Truncated message.')
183 while pos < endpoint:
184 (element, pos) = decode_value(buffer, pos)
185 value.append(element)
186 if pos > endpoint:
187 del value[-1] # Discard corrupt value.
188 raise _DecodeError('Packed element was truncated.')
189 return pos
190 return DecodePackedField
191 elif is_repeated:
192 tag_bytes = encoder.TagBytes(field_number, wire_type)
193 tag_len = len(tag_bytes)
194 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
195 value = field_dict.get(key)
196 if value is None:
197 value = field_dict.setdefault(key, new_default(message))
198 while 1:
199 (element, new_pos) = decode_value(buffer, pos)
200 value.append(element)
201 # Predict that the next tag is another copy of the same repeated
202 # field.
203 pos = new_pos + tag_len
204 if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
205 # Prediction failed. Return.
206 if new_pos > end:
207 raise _DecodeError('Truncated message.')
208 return new_pos
209 return DecodeRepeatedField
210 else:
211 def DecodeField(buffer, pos, end, message, field_dict):
212 (new_value, pos) = decode_value(buffer, pos)
213 if pos > end:
214 raise _DecodeError('Truncated message.')
215 if clear_if_default and not new_value:
216 field_dict.pop(key, None)
217 else:
218 field_dict[key] = new_value
219 return pos
220 return DecodeField
221
222 return SpecificDecoder
223
224
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder but additionally invokes modify_value on every value
  before storing it. Usually modify_value is ZigZagDecode.

  Args:
    wire_type: The field's wire type.
    decode_value: function (buffer, pos) -> (raw_value, new_pos).
    modify_value: function applied to each raw value.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw, next_pos = decode_value(buffer, pos)
    return modify_value(raw), next_pos

  return _SimpleDecoder(wire_type, InnerDecode)
237
238
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type: The field's wire type.
      format: The format string to pass to struct.unpack().
  """

  # Precompile the format once; Struct.size matches struct.calcsize(format).
  packer = struct.Struct(format)
  value_size = packer.size
  unpack = packer.unpack

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  # Note that we expect someone up-stack to catch struct.error and convert
  # it to _DecodeError -- this way we don't have to set up exception-
  # handling blocks every time we parse one value.

  def InnerDecode(buffer, pos):
    end_pos = pos + value_size
    return (unpack(buffer[pos:end_pos])[0], end_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
262
263
def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.

  Returns:
    The decoder constructor built by _SimpleDecoder for WIRETYPE_FIXED32,
    whose decoded values are Python floats.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    # Bits are numbered from the most significant end, so the sign bit and the
    # top 7 exponent bits are in float_bytes[3] (the last byte).
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    # float_bytes[3] being 0x7F or 0xFF means the top 7 exponent bits are set
    # (with either sign); float_bytes[2] >= 0x80 means the last one is too.
    # NOTE(review): when fewer than 4 bytes remain, float_bytes[3:4] is b''
    # (which is `in` any bytes), so this branch may return without raising;
    # the callers' `pos > end` checks are what reject such truncated input.
    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
      # If at least one significand bit is set...
      # (The significand is bytes 0-1 plus the low 7 bits of byte 2, so it is
      # zero only when bytes 0-2 are exactly 00 00 80.)
      if float_bytes[0:3] != b'\x00\x00\x80':
        return (math.nan, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b'\xFF':
        return (-math.inf, new_pos)
      return (math.inf, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
307
308
def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.

  Returns:
    The decoder constructor built by _SimpleDecoder for WIRETYPE_FIXED64,
    whose decoded values are Python floats.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    # Bits are numbered from the most significant end: the sign bit and top 7
    # exponent bits are in double_bytes[7], and the low 4 exponent bits are the
    # high nibble of double_bytes[6].
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number. In Python 2.4, struct.unpack will treat it
    # as inf or -inf. To avoid that, we treat it specially.
    # Infinities (all exponent bits set, zero significand) fall through to
    # struct.unpack below.
    if ((double_bytes[7:8] in b'\x7F\xFF')
        and (double_bytes[6:7] >= b'\xF0')
        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
      return (math.nan, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
347
348
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values found in key.enum_type.values_by_number are stored in field_dict.
  Other values are not stored; their raw varint bytes are appended to
  message._unknown_fields, paired with a VARINT tag for field_number.

  Args:
    field_number: The field number of the field we want to decode.
    is_repeated: Is the field a repeated field? (bool)
    is_packed: Is the field a packed field? (bool)
    key: The field's FieldDescriptor (its enum_type is consulted).
    new_default: Function (message) -> new default container for the field.
    clear_if_default: For singular fields, remove the key from field_dict when
      the decoded value is 0 instead of storing it.

  Returns:
    A decoder function (buffer, pos, end, message, field_dict) -> new_pos.
  """
  enum_type = key.enum_type
  # Unrecognized values are always preserved under a plain (non-packed) VARINT
  # tag for this field, however they were encoded on the wire.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  def _StoreUnknown(message, raw_bytes):
    # Record an unrecognized enum value's raw bytes as an unknown field.
    # pylint: disable=protected-access
    if not message._unknown_fields:
      message._unknown_fields = []
    message._unknown_fields.append((tag_bytes, raw_bytes))
    # pylint: enable=protected-access

  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _StoreUnknown(message, buffer[value_start_pos:pos].tobytes())
      if pos > endpoint:
        # The last element overran the packed payload: undo whichever
        # container it was appended to.
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]  # pylint: disable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _StoreUnknown(message, buffer[pos:new_pos].tobytes())
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed. Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        _StoreUnknown(message, buffer[value_start_pos:pos].tobytes())
      return pos
    return DecodeField
472
473
474# --------------------------------------------------------------------
475
476
# Varint-encoded integer types: Int32/Int64 decode two's-complement signed
# values, UInt32/UInt64 decode unsigned values.
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

# sint32/sint64: decode the raw unsigned varint, then map it back to a signed
# value with wire_format.ZigZagDecode.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# Bools are varints converted with bool(), so any nonzero value is True.
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
504
505
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Strings are length-delimited and decoded as UTF-8. A UnicodeDecodeError is
  re-raised with the field's full name appended to its reason.
  """

  read_size = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Decode the bytes viewed by memview as UTF-8."""
    raw = memview.tobytes()
    try:
      return str(raw, 'utf-8')
    except UnicodeDecodeError as err:
      # add more information to the error message and re-raise it.
      err.reason = '%s in field: %s' % (err, key.full_name)
      raise

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = read_size(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:stop])
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_size(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(_ConvertToUnicode(buffer[pos:stop]))
      # Predict that the next tag is another copy of the same repeated field.
      pos = stop + tag_len
      if stop == end or buffer[stop:pos] != tag_bytes:
        # Prediction failed. Return.
        return stop
  return DecodeRepeatedField
557
558
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Values are length-delimited and stored as bytes copies of the buffer slice.
  """

  read_size = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, start) = read_size(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[start:stop].tobytes()
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      (size, start) = read_size(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(buffer[start:stop].tobytes())
      # Predict that the next tag is another copy of the same repeated field.
      pos = stop + tag_len
      if stop == end or buffer[stop:pos] != tag_bytes:
        # Prediction failed. Return.
        return stop
  return DecodeRepeatedField
598
599
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group's contents are parsed straight from the enclosing buffer by the
  sub-message's _InternalParse, which returns early when it reaches an
  end-group tag; the decoder then checks that the tag at that position is this
  field's END_GROUP tag and consumes it.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Look the container up once; it does not change while the loop runs.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
645
646
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Sub-messages are length-delimited. Each one is parsed by _InternalParse
  over exactly its payload and must consume all of it.
  """

  read_size = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = read_size(buffer, pos)
      payload_end = pos + size
      if payload_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message. The only reason _InternalParse would return early is
      # if it encountered an end-group tag.
      if submessage._InternalParse(buffer, pos, payload_end) != payload_end:
        raise _DecodeError('Unexpected end-group tag.')
      return payload_end
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (size, pos) = read_size(buffer, pos)
      payload_end = pos + size
      if payload_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message. The only reason _InternalParse would return early is
      # if it encountered an end-group tag.
      if container.add()._InternalParse(buffer, pos, payload_end) != payload_end:
        raise _DecodeError('Unexpected end-group tag.')
      # Predict that the next tag is another copy of the same repeated field.
      pos = payload_end + tag_len
      if payload_end == end or buffer[payload_end:pos] != tag_bytes:
        # Prediction failed. Return.
        return payload_end
  return DecodeRepeatedField
695
696
697# --------------------------------------------------------------------
698
# Tag bytes that open a MessageSet item group (field 1, START_GROUP). Used as
# the tag when an unrecognized item is preserved in _unknown_fields.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
700
def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor (it is not used by the returned
  decoder).

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Known extensions (looked up by type_id) are parsed into field_dict; an
    unknown type_id causes the whole item to be kept in
    message._unknown_fields under MESSAGE_SET_ITEM_TAG.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    message_set_item_start = pos
    # -1 means "not seen yet".
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not imported at module level in this file, so
          # import it here; previously this line raised NameError.
          from google.protobuf import message_factory  # pylint: disable=g-import-not-at-top
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
    # pylint: enable=protected-access

    return pos

  return DecodeItem
789
790
def UnknownMessageSetItemDecoder():
  """Returns a decoder for a Unknown MessageSet item.

  The returned DecodeUnknownItem(buffer) parses one MessageSet item body,
  starting at offset 0 of buffer, and returns (type_id, message_bytes).
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    """Decode a MessageSet item.

    Args:
      buffer: the serialized item. Slices of it must support .tobytes()
        (NOTE(review): presumably always a memoryview; confirm with callers).

    Returns:
      Tuple[int, bytes] of the item's type_id and its message payload.

    Raises:
      _DecodeError: on a stray end-group tag, truncation, or a missing
        type_id/message. IndexError if the buffer ends before the item's end
        tag.
    """
    pos = 0
    end = len(buffer)
    # -1 means "not seen yet". type_id must be initialized as well; otherwise
    # an item without a type_id raised UnboundLocalError instead of
    # _DecodeError below.
    type_id = -1
    message_start = -1
    message_end = -1
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem
828
829# --------------------------------------------------------------------
830
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry is a length-delimited sub-message with `key` and `value`
  fields. Entries are parsed into one reused entry message and then copied
  into the map container (via CopyFrom for message-valued maps).
  """

  key = field_descriptor
  tag_bytes = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  read_size = _DecodeVarint
  # Can't read _concrete_class yet; might not be initialized.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    entry = entry_type._concrete_class()
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (size, pos) = read_size(buffer, pos)
      entry_end = pos + size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      # Read the entry into the reused message. The only reason _InternalParse
      # would return early is if it encountered an end-group tag.
      entry.Clear()
      if entry._InternalParse(buffer, pos, entry_end) != entry_end:
        raise _DecodeError('Unexpected end-group tag.')

      if not is_message_map:
        container[entry.key] = entry.value
      else:
        container[entry.key].CopyFrom(entry.value)

      # Predict that the next tag is another copy of the same repeated field.
      pos = entry_end + tag_len
      if buffer[entry_end:pos] != tag_bytes or entry_end == end:
        # Prediction failed. Return.
        return entry_end

  return DecodeMap
872
873# --------------------------------------------------------------------
874# Optimization is not as heavy here because calls to SkipField() are rare,
875# except for handling end-group tags.
876
877def _SkipVarint(buffer, pos, end):
878 """Skip a varint value. Returns the new position."""
879 # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
880 # With this code, ord(b'') raises TypeError. Both are handled in
881 # python_message.py to generate a 'Truncated message' error.
882 while ord(buffer[pos:pos+1].tobytes()) & 0x80:
883 pos += 1
884 pos += 1
885 if pos > end:
886 raise _DecodeError('Truncated message.')
887 return pos
888
889def _SkipFixed64(buffer, pos, end):
890 """Skip a fixed64 value. Returns the new position."""
891
892 pos += 8
893 if pos > end:
894 raise _DecodeError('Truncated message.')
895 return pos
896
897
898def _DecodeFixed64(buffer, pos):
899 """Decode a fixed64."""
900 new_pos = pos + 8
901 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
902
903
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value. Returns the new position."""
  size, payload_start = _DecodeVarint(buffer, pos)
  new_pos = payload_start + size
  if new_pos > end:
    raise _DecodeError('Truncated message.')
  return new_pos
912
913
def _SkipGroup(buffer, pos, end):
  """Skip sub-group. Returns the new position.

  Reads and skips fields until SkipField reports an END_GROUP tag (by
  returning -1); the result is the position just past that tag.
  NOTE(review): the END_GROUP tag's field number is not compared with the
  group's START_GROUP field number.
  """
  while True:
    tag_bytes, after_tag = ReadTag(buffer, pos)
    skipped_to = SkipField(buffer, after_tag, end, tag_bytes)
    if skipped_to == -1:
      return after_tag
    pos = skipped_to
923
924
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.

  Decoding stops at end_pos (when given) or right after the first END_GROUP
  tag, whichever comes first.
  """
  field_set = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    tag_bytes, pos = ReadTag(buffer, pos)
    tag_value, _ = _DecodeVarint(tag_bytes, 0)
    field_number, wire_type = wire_format.UnpackTag(tag_value)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    data, pos = _DecodeUnknownField(buffer, pos, wire_type)
    field_set._add(field_number, wire_type, data)  # pylint: disable=protected-access
  return (field_set, pos)
940
941
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode a unknown field. Returns the UnknownField and new position.

  Returns (0, -1) for an END_GROUP wire type and raises _DecodeError for an
  invalid one.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    return _DecodeVarint(buffer, pos)
  if wire_type == wire_format.WIRETYPE_FIXED64:
    return _DecodeFixed64(buffer, pos)
  if wire_type == wire_format.WIRETYPE_FIXED32:
    return _DecodeFixed32(buffer, pos)
  if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    size, start = _DecodeVarint(buffer, pos)
    stop = start + size
    return (buffer[start:stop].tobytes(), stop)
  if wire_type == wire_format.WIRETYPE_START_GROUP:
    return _DecodeUnknownFieldSet(buffer, pos)
  if wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  raise _DecodeError('Wrong wire type in tag.')
963
964
def _EndGroup(buffer, pos, end):
  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.

  The arguments are unused; they exist only to match the skipper signature
  used by SkipField.
  """

  return -1
969
970
971def _SkipFixed32(buffer, pos, end):
972 """Skip a fixed32 value. Returns the new position."""
973
974 pos += 4
975 if pos > end:
976 raise _DecodeError('Truncated message.')
977 return pos
978
979
980def _DecodeFixed32(buffer, pos):
981 """Decode a fixed32."""
982
983 new_pos = pos + 4
984 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
985
986
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types. Raises an exception.

  Raises:
    _DecodeError: always. The arguments only match the skipper signature.
  """

  raise _DecodeError('Tag had invalid wire type.')
991
def _FieldSkipper():
  """Constructs the SkipField function.

  Returns:
    SkipField(buffer, pos, end, tag_bytes), which dispatches on the wire type
    found in the low bits of the tag's first byte.
  """

  # Skip handlers indexed by wire type; types 6 and 7 are invalid.
  skippers = (
      _SkipVarint,            # 0
      _SkipFixed64,           # 1
      _SkipLengthDelimited,   # 2
      _SkipGroup,             # 3
      _EndGroup,              # 4
      _SkipFixed32,           # 5
      _RaiseInvalidWireType,  # 6
      _RaiseInvalidWireType,  # 7
  )
  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
      The new position (after the tag value), or -1 if the tag is an end-group
      tag (in which case the calling loop should break).
    """
    # The wire type is always in the first byte since varints are little-endian.
    skip = skippers[ord(tag_bytes[0:1]) & type_mask]
    return skip(buffer, pos, end)

  return SkipField
1023
# Module-level skipper: SkipField(buffer, pos, end, tag_bytes) returns the
# position after the field's value, or -1 for an END_GROUP tag.
SkipField = _FieldSkipper()