Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/decoder.py: 15%
527 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 07:30 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 07:30 +0000
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31"""Code for decoding protocol buffer primitives.
33This code is very similar to encoder.py -- read the docs for that module first.
35A "decoder" is a function with the signature:
36 Decode(buffer, pos, end, message, field_dict)
37The arguments are:
38 buffer: The string containing the encoded message.
39 pos: The current position in the string.
40 end: The position in the string where the current message ends. May be
41 less than len(buffer) if we're reading a sub-message.
42 message: The message object into which we're parsing.
43 field_dict: message._fields (avoids a hashtable lookup).
44The decoder reads the field and stores it into field_dict, returning the new
45buffer position. A decoder for a repeated field may proactively decode all of
46the elements of that field, if they appear consecutively.
48Note that decoders may throw any of the following:
49 IndexError: Indicates a truncated message.
50 struct.error: Unpacking of a fixed-width field failed.
51 message.DecodeError: Other errors.
53Decoders are expected to raise an exception if they are called with pos > end.
54This allows callers to be lax about bounds checking: it's fine to read past
55"end" as long as you are sure that someone else will notice and throw an
56exception later on.
58Something up the call stack is expected to catch IndexError and struct.error
59and convert them to message.DecodeError.
61Decoders are constructed using decoder constructors with the signature:
62 MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
63The arguments are:
64 field_number: The field number of the field we want to decode.
65 is_repeated: Is the field a repeated field? (bool)
66 is_packed: Is the field a packed field? (bool)
67 key: The key to use when looking up the field within field_dict.
68 (This is actually the FieldDescriptor but nothing in this
69 file should depend on that.)
70 new_default: A function which takes a message object as a parameter and
71 returns a new instance of the default value for this field.
72 (This is called for repeated fields and sub-messages, when an
73 instance does not already exist.)
75As with encoders, we define a decoder constructor for every type of field.
76Then, for every field of every message class we construct an actual decoder.
77That decoder goes into a dict indexed by tag, so when we decode a message
78we repeatedly read a tag, look up the corresponding decoder, and invoke it.
79"""
81__author__ = 'kenton@google.com (Kenton Varda)'
83import math
84import struct
86from google.protobuf.internal import containers
87from google.protobuf.internal import encoder
88from google.protobuf.internal import wire_format
89from google.protobuf import message
92# This is not for optimization, but rather to avoid conflicts with local
93# variables named "message".
94_DecodeError = message.DecodeError
97def _VarintDecoder(mask, result_type):
98 """Return an encoder for a basic varint value (does not include tag).
100 Decoded values will be bitwise-anded with the given mask before being
101 returned, e.g. to limit them to 32 bits. The returned decoder does not
102 take the usual "end" parameter -- the caller is expected to do bounds checking
103 after the fact (often the caller can defer such checking until later). The
104 decoder returns a (value, new_pos) pair.
105 """
107 def DecodeVarint(buffer, pos):
108 result = 0
109 shift = 0
110 while 1:
111 b = buffer[pos]
112 result |= ((b & 0x7f) << shift)
113 pos += 1
114 if not (b & 0x80):
115 result &= mask
116 result = result_type(result)
117 return (result, pos)
118 shift += 7
119 if shift >= 64:
120 raise _DecodeError('Too many bytes when decoding varint.')
121 return DecodeVarint
124def _SignedVarintDecoder(bits, result_type):
125 """Like _VarintDecoder() but decodes signed values."""
127 signbit = 1 << (bits - 1)
128 mask = (1 << bits) - 1
130 def DecodeVarint(buffer, pos):
131 result = 0
132 shift = 0
133 while 1:
134 b = buffer[pos]
135 result |= ((b & 0x7f) << shift)
136 pos += 1
137 if not (b & 0x80):
138 result &= mask
139 result = (result ^ signbit) - signbit
140 result = result_type(result)
141 return (result, pos)
142 shift += 7
143 if shift >= 64:
144 raise _DecodeError('Too many bytes when decoding varint.')
145 return DecodeVarint
# Every 32-bit and 64-bit value is surfaced as a plain Python int.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Variants for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
def ReadTag(buffer, pos):
  """Read a tag from the memoryview and return it undecoded.

  The raw bytes of the tag are returned rather than the decoded varint, so
  that they can be used directly to look up the matching decoder.  This
  trades work that would be done in pure Python (decoding a varint) for work
  done in C (hashing a byte string), which is the cheaper option in Python.

  Args:
    buffer: memoryview object of the encoded bytes.
    pos: int of the current position to start from.

  Returns:
    Tuple[bytes, int] of the tag data and the new position.
  """
  tag_start = pos
  # Walk over every byte that has its continuation bit set.
  while buffer[pos] & 0x80:
    pos += 1
  # Step past the final byte of the tag as well.
  tag_end = pos + 1
  return buffer[tag_start:tag_end].tobytes(), tag_end
182# --------------------------------------------------------------------
185def _SimpleDecoder(wire_type, decode_value):
186 """Return a constructor for a decoder for fields of a particular type.
188 Args:
189 wire_type: The field's wire type.
190 decode_value: A function which decodes an individual value, e.g.
191 _DecodeVarint()
192 """
194 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
195 clear_if_default=False):
196 if is_packed:
197 local_DecodeVarint = _DecodeVarint
198 def DecodePackedField(buffer, pos, end, message, field_dict):
199 value = field_dict.get(key)
200 if value is None:
201 value = field_dict.setdefault(key, new_default(message))
202 (endpoint, pos) = local_DecodeVarint(buffer, pos)
203 endpoint += pos
204 if endpoint > end:
205 raise _DecodeError('Truncated message.')
206 while pos < endpoint:
207 (element, pos) = decode_value(buffer, pos)
208 value.append(element)
209 if pos > endpoint:
210 del value[-1] # Discard corrupt value.
211 raise _DecodeError('Packed element was truncated.')
212 return pos
213 return DecodePackedField
214 elif is_repeated:
215 tag_bytes = encoder.TagBytes(field_number, wire_type)
216 tag_len = len(tag_bytes)
217 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
218 value = field_dict.get(key)
219 if value is None:
220 value = field_dict.setdefault(key, new_default(message))
221 while 1:
222 (element, new_pos) = decode_value(buffer, pos)
223 value.append(element)
224 # Predict that the next tag is another copy of the same repeated
225 # field.
226 pos = new_pos + tag_len
227 if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
228 # Prediction failed. Return.
229 if new_pos > end:
230 raise _DecodeError('Truncated message.')
231 return new_pos
232 return DecodeRepeatedField
233 else:
234 def DecodeField(buffer, pos, end, message, field_dict):
235 (new_value, pos) = decode_value(buffer, pos)
236 if pos > end:
237 raise _DecodeError('Truncated message.')
238 if clear_if_default and not new_value:
239 field_dict.pop(key, None)
240 else:
241 field_dict[key] = new_value
242 return pos
243 return DecodeField
245 return SpecificDecoder
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but passes every value through modify_value.

  Args:
    wire_type: The field's wire type.
    decode_value: A function decode_value(buffer, pos) -> (value, new_pos).
    modify_value: Callable applied to each decoded value before it is
      stored.  Usually this is ZigZagDecode.

  Returns:
    The decoder constructor produced by _SimpleDecoder.
  """

  # Delegating to _SimpleDecoder is slightly slower than duplicating its
  # code, but not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw_value, end_pos = decode_value(buffer, pos)
    return (modify_value(raw_value), end_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The format string to pass to struct.unpack().

  Returns:
    The decoder constructor produced by _SimpleDecoder.
  """
  width = struct.calcsize(format)
  unpack = struct.unpack

  # Delegating to _SimpleDecoder is slightly slower than duplicating its
  # code, but not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    # struct.error is deliberately not caught here: something up-stack is
    # expected to convert it to _DecodeError, so no exception-handling
    # block is set up for every single value.
    stop = pos + width
    fields = unpack(format, buffer[pos:stop])
    return (fields[0], stop)

  return _SimpleDecoder(wire_type, InnerDecode)
def _FloatDecoder():
  """Returns a decoder constructor for a float field.

  Non-finite 32-bit values are recognised by their byte pattern and mapped
  to math.nan / math.inf / -math.inf directly rather than going through
  struct.unpack (historically a workaround for a struct.unpack bug).
  """
  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode a serialized float.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and the new
      position in the serialized data.
    """
    # The value is 32 bits in little-endian byte order: bit 1 is the sign
    # bit, bits 2-9 are the exponent and bits 10-32 are the significand.
    stop = pos + 4
    raw = buffer[pos:stop].tobytes()
    top_byte = raw[3:4]

    # All exponent bits set means the value is non-finite.
    if top_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80':
      if raw[0:3] != b'\x00\x00\x80':
        # At least one significand bit is set.
        return (math.nan, stop)
      if top_byte == b'\xFF':
        # Sign bit is set.
        return (-math.inf, stop)
      return (math.inf, stop)

    # struct.error is deliberately not caught here: something up-stack is
    # expected to convert it to _DecodeError.
    return (unpack('<f', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
def _DoubleDecoder():
  """Returns a decoder constructor for a double field.

  Not-a-number values are recognised by their byte pattern and mapped to
  math.nan directly rather than going through struct.unpack (historically
  a workaround for a struct.unpack bug).
  """
  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode a serialized double.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and the new position
      in the serialized data.
    """
    # The value is 64 bits in little-endian byte order: bit 1 is the sign
    # bit, bits 2-12 are the exponent and bits 13-64 are the significand.
    stop = pos + 8
    raw = buffer[pos:stop].tobytes()

    # All exponent bits set plus at least one significand bit set means
    # not-a-number.
    looks_like_nan = (
        raw[7:8] in b'\x7F\xFF'
        and raw[6:7] >= b'\xF0'
        and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')
    if looks_like_nan:
      return (math.nan, stop)

    # struct.error is deliberately not caught here: something up-stack is
    # expected to convert it to _DecodeError.
    return (unpack('<d', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for an enum field.

  Numbers that are present in key.enum_type.values_by_number are stored in
  the field; any other number is recorded in the message's unknown-field
  structures instead.
  """
  enum_type = key.enum_type

  if is_packed:
    read_length = _DecodeVarint

    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode a serialized packed enum field.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data.
        message: Message object to store unknown fields in.
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      container = field_dict.get(key)
      if container is None:
        container = field_dict.setdefault(key, new_default(message))
      (payload_size, pos) = read_length(buffer, pos)
      limit = pos + payload_size
      if limit > end:
        raise _DecodeError('Truncated message.')
      while pos < limit:
        element_start = pos
        (number, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if number in enum_type.values_by_number:
          container.append(number)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          unknown_tag = encoder.TagBytes(field_number,
                                         wire_format.WIRETYPE_VARINT)
          message._unknown_fields.append(
              (unknown_tag, buffer[element_start:pos].tobytes()))
          if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
          message._unknown_field_set._add(
              field_number, wire_format.WIRETYPE_VARINT, number)
        # pylint: enable=protected-access
      if pos > limit:
        # The last element ran past the declared payload: undo whichever
        # record was made for it before reporting the error.
        if number in enum_type.values_by_number:
          del container[-1]
        else:
          del message._unknown_fields[-1]
          # pylint: disable=protected-access
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField

  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode a serialized repeated enum field.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data.
        message: Message object to store unknown fields in.
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      container = field_dict.get(key)
      if container is None:
        container = field_dict.setdefault(key, new_default(message))
      while True:
        (number, number_end) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if number in enum_type.values_by_number:
          container.append(number)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:number_end].tobytes()))
          if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
          message._unknown_field_set._add(
              field_number, wire_format.WIRETYPE_VARINT, number)
        # pylint: enable=protected-access
        # Guess that the next tag is another copy of this same field.
        pos = number_end + tag_len
        if buffer[number_end:pos] != tag_bytes or number_end >= end:
          # Guess was wrong (or the input is exhausted).
          if number_end > end:
            raise _DecodeError('Truncated message.')
          return number_end

    return DecodeRepeatedField

  def DecodeField(buffer, pos, end, message, field_dict):
    """Decode a serialized singular enum field.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data.
      message: Message object to store unknown fields in.
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    number_start = pos
    (number, pos) = _DecodeSignedVarint32(buffer, pos)
    if pos > end:
      raise _DecodeError('Truncated message.')
    if clear_if_default and not number:
      # A zero value removes the entry instead of storing it.
      field_dict.pop(key, None)
      return pos
    # pylint: disable=protected-access
    if number in enum_type.values_by_number:
      field_dict[key] = number
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      unknown_tag = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_VARINT)
      message._unknown_fields.append(
          (unknown_tag, buffer[number_start:pos].tobytes()))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          field_number, wire_format.WIRETYPE_VARINT, number)
    # pylint: enable=protected-access
    return pos

  return DecodeField
511# --------------------------------------------------------------------
# Decoder constructors for the varint-encoded integer field types.
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint)

# The sint types are read as unsigned varints and then zig-zag decoded.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Python guarantees that formats using the '<' prefix have the same size on
# every platform (without the prefix, sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# Booleans are varints converted with bool().
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  String fields are length-delimited and must not be packed.  The payload is
  decoded as UTF-8; a UnicodeDecodeError is re-raised with the field's full
  name added to its `reason`.
  """
  read_length = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Convert a memoryview of UTF-8 bytes to str."""
    raw = memview.tobytes()
    try:
      text = str(raw, 'utf-8')
    except UnicodeDecodeError as e:
      # Add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise
    return text

  assert not is_packed

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = read_length(buffer, pos)
      text_end = pos + size
      if text_end > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        # An empty string removes the entry instead of storing it.
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:text_end])
      return text_end

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_length(buffer, pos)
      text_end = pos + size
      if text_end > end:
        raise _DecodeError('Truncated string.')
      container.append(_ConvertToUnicode(buffer[pos:text_end]))
      # Guess that the next tag is another copy of this same field.
      pos = text_end + tag_len
      if buffer[text_end:pos] != tag_bytes or text_end == end:
        # Guess was wrong (or the input is exhausted).
        return text_end

  return DecodeRepeatedField
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Bytes fields are length-delimited and must not be packed.  Each value is
  copied out of the buffer as a `bytes` object.
  """
  read_length = _DecodeVarint

  assert not is_packed

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = read_length(buffer, pos)
      payload_end = pos + size
      if payload_end > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        # An empty payload removes the entry instead of storing it.
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[pos:payload_end].tobytes()
      return payload_end

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_length(buffer, pos)
      payload_end = pos + size
      if payload_end > end:
        raise _DecodeError('Truncated string.')
      container.append(buffer[pos:payload_end].tobytes())
      # Guess that the next tag is another copy of this same field.
      pos = payload_end + tag_len
      if buffer[payload_end:pos] != tag_bytes or payload_end == end:
        # Guess was wrong (or the input is exhausted).
        return payload_end

  return DecodeRepeatedField
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group is a sub-message delimited by start-group / end-group tags rather
  than by a length prefix.  The decoders built here are invoked after the
  start tag has been read; they parse the body and then require the
  matching end-group tag.  Group fields must not be packed.

  Raises (from the returned decoders):
    message.DecodeError: if the end-group tag is missing or lies past `end`.
  """
  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # The container is looked up once; it does not change while the
      # consecutive group elements are being parsed.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos

    return DecodeField
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Sub-messages are length-delimited: a varint size followed by that many
  bytes, which are parsed with the sub-message's _InternalParse.  Message
  fields must not be packed.
  """
  read_length = _DecodeVarint

  assert not is_packed

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = read_length(buffer, pos)
      body_end = pos + size
      if body_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.  The only reason _InternalParse would stop before
      # body_end is that it encountered an end-group tag.
      if submessage._InternalParse(buffer, pos, body_end) != body_end:
        raise _DecodeError('Unexpected end-group tag.')
      return body_end

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (size, pos) = read_length(buffer, pos)
      body_end = pos + size
      if body_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.  The only reason _InternalParse would stop before
      # body_end is that it encountered an end-group tag.
      if container.add()._InternalParse(buffer, pos, body_end) != body_end:
        raise _DecodeError('Unexpected end-group tag.')
      # Guess that the next tag is another copy of this same field.
      pos = body_end + tag_len
      if buffer[body_end:pos] != tag_bytes or body_end == end:
        # Guess was wrong (or the input is exhausted).
        return body_end

  return DecodeRepeatedField
734# --------------------------------------------------------------------
# Start-group tag (field number 1) that introduces each MessageSet item.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)


def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor (it is not referenced in the body
  of this function).

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.

    Raises:
      message.DecodeError: if the item is truncated, lacks a type_id or a
        message, has no group end tag, or its payload contains an
        unexpected end-group tag.
    """
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # Imported here rather than at module level: no file-level import
          # of message_factory is visible, and a function-scope import
          # avoids any import cycle at module load time.
          # pylint: disable=g-import-not-at-top
          from google.protobuf import message_factory
          # Called for its side effect; _concrete_class is read just below.
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # No extension is registered for this type_id: keep the raw item so
      # that it is not lost.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          type_id,
          wire_format.WIRETYPE_LENGTH_DELIMITED,
          buffer[message_start:message_end].tobytes())
    # pylint: enable=protected-access

    return pos

  return DecodeItem
def UnknownMessageSetItemDecoder():
  """Returns a decoder for an unknown MessageSet item.

  The returned function DecodeUnknownItem(buffer) scans a serialized
  MessageSet item from position 0 and returns a (type_id, message_bytes)
  tuple.

  Raises (from the returned decoder):
    message.DecodeError: if the item is truncated, has no group end tag, or
      lacks a type_id or a message.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    pos = 0
    end = len(buffer)
    # -1 marks "not seen yet".  type_id must be initialised here so that an
    # item without a type_id field reaches the explicit DecodeError below
    # instead of failing with UnboundLocalError.
    type_id = -1
    message_start = -1
    message_end = -1
    # type_id and message may appear in either order, hence the loop.
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem
872# --------------------------------------------------------------------
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry arrives as a length-delimited sub-message carrying `key`
  and `value` fields.  One scratch entry message is created per call and
  cleared before every entry is parsed into it.

  Args:
    field_descriptor: Descriptor of the map field; also used as the key into
      field_dict.
    new_default: Function which takes a message and returns a new, empty map
      container for this field.
    is_message_map: If true, values are merged in with CopyFrom(); otherwise
      they are assigned directly.
  """
  key = field_descriptor
  tag_bytes = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  read_length = _DecodeVarint
  # _concrete_class is read lazily inside DecodeMap; it might not be
  # initialized yet at this point.
  message_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    entry = message_type._concrete_class()
    mapping = field_dict.get(key)
    if mapping is None:
      mapping = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (size, pos) = read_length(buffer, pos)
      entry_end = pos + size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      # Read the entry.  The only reason _InternalParse would stop before
      # entry_end is that it encountered an end-group tag.
      entry.Clear()
      if entry._InternalParse(buffer, pos, entry_end) != entry_end:
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        mapping[entry.key].CopyFrom(entry.value)
      else:
        mapping[entry.key] = entry.value

      # Guess that the next tag is another entry of this same map field.
      pos = entry_end + tag_len
      if buffer[entry_end:pos] != tag_bytes or entry_end == end:
        # Guess was wrong (or the input is exhausted).
        return entry_end

  return DecodeMap
916# --------------------------------------------------------------------
917# Optimization is not as heavy here because calls to SkipField() are rare,
918# except for handling end-group tags.
920def _SkipVarint(buffer, pos, end):
921 """Skip a varint value. Returns the new position."""
922 # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
923 # With this code, ord(b'') raises TypeError. Both are handled in
924 # python_message.py to generate a 'Truncated message' error.
925 while ord(buffer[pos:pos+1].tobytes()) & 0x80:
926 pos += 1
927 pos += 1
928 if pos > end:
929 raise _DecodeError('Truncated message.')
930 return pos
932def _SkipFixed64(buffer, pos, end):
933 """Skip a fixed64 value. Returns the new position."""
935 pos += 8
936 if pos > end:
937 raise _DecodeError('Truncated message.')
938 return pos
941def _DecodeFixed64(buffer, pos):
942 """Decode a fixed64."""
943 new_pos = pos + 8
944 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
def _SkipLengthDelimited(buffer, pos, end):
  """Skips over a length-delimited value and returns the position past it.

  Raises:
    _DecodeError: If the declared payload extends beyond |end|.
  """

  # The value is a varint length prefix followed by that many payload bytes.
  (payload_size, payload_start) = _DecodeVarint(buffer, pos)
  payload_end = payload_start + payload_size
  if payload_end > end:
    raise _DecodeError('Truncated message.')
  return payload_end
def _SkipGroup(buffer, pos, end):
  """Skips over a sub-group and returns the position just past it."""

  while True:
    (tag_bytes, pos) = ReadTag(buffer, pos)
    after_field = SkipField(buffer, pos, end, tag_bytes)
    if after_field == -1:
      # SkipField signals an end-group tag with -1; |pos| already points
      # just past that tag, which is where the group ends.
      return pos
    pos = after_field
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decodes consecutive fields into a new UnknownFieldSet.

  Decoding stops when an end-group tag is read or, if |end_pos| is given,
  once |pos| reaches it.

  Returns:
    A tuple (UnknownFieldSet, new position).
  """

  fields = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    (tag_bytes, pos) = ReadTag(buffer, pos)
    # The raw tag bytes are themselves a varint packing number and wire type.
    (tag, _) = _DecodeVarint(tag_bytes, 0)
    (field_number, wire_type) = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
    # pylint: disable=protected-access
    fields._add(field_number, wire_type, data)

  return (fields, pos)
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decodes the value of one unknown field according to |wire_type|.

  Returns:
    A tuple (data, new position).  For an end-group wire type the result is
    (0, -1).

  Raises:
    _DecodeError: If |wire_type| is not one of the known wire types.
  """

  if wire_type == wire_format.WIRETYPE_END_GROUP:
    # Nothing to decode; -1 is the "group ended" position marker.
    return (0, -1)

  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, new_pos) = _DecodeVarint(buffer, pos)
    return (data, new_pos)

  if wire_type == wire_format.WIRETYPE_FIXED64:
    (data, new_pos) = _DecodeFixed64(buffer, pos)
    return (data, new_pos)

  if wire_type == wire_format.WIRETYPE_FIXED32:
    (data, new_pos) = _DecodeFixed32(buffer, pos)
    return (data, new_pos)

  if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    # Varint length prefix followed by the raw payload bytes.
    (size, payload_start) = _DecodeVarint(buffer, pos)
    payload_end = payload_start + size
    return (buffer[payload_start:payload_end].tobytes(), payload_end)

  if wire_type == wire_format.WIRETYPE_START_GROUP:
    # A nested group is decoded into its own UnknownFieldSet.
    (data, new_pos) = _DecodeUnknownFieldSet(buffer, pos)
    return (data, new_pos)

  raise _DecodeError('Wrong wire type in tag.')
1008def _EndGroup(buffer, pos, end):
1009 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
1011 return -1
1014def _SkipFixed32(buffer, pos, end):
1015 """Skip a fixed32 value. Returns the new position."""
1017 pos += 4
1018 if pos > end:
1019 raise _DecodeError('Truncated message.')
1020 return pos
1023def _DecodeFixed32(buffer, pos):
1024 """Decode a fixed32."""
1026 new_pos = pos + 4
1027 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for invalid wire types.

  Never returns; the arguments are accepted only to match the other skip
  functions.

  Raises:
    _DecodeError: Always.
  """

  invalid_wire_type_error = _DecodeError('Tag had invalid wire type.')
  raise invalid_wire_type_error
def _FieldSkipper():
  """Builds and returns the SkipField function."""

  # Dispatch table indexed by wire type number; the last two slots reject
  # wire types that have no skip function.
  skipper_for_wire_type = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
      The new position (after the tag value), or -1 if the tag is an end-group
      tag (in which case the calling loop should break).
    """

    # Varints are little-endian, so the wire type bits always sit in the
    # first byte of the tag.
    first_tag_byte = ord(tag_bytes[0:1])
    skip = skipper_for_wire_type[first_tag_byte & type_mask]
    return skip(buffer, pos, end)

  return SkipField


SkipField = _FieldSkipper()