Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/decoder.py: 15%
527 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:40 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:40 +0000
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
8"""Code for decoding protocol buffer primitives.
10This code is very similar to encoder.py -- read the docs for that module first.
12A "decoder" is a function with the signature:
13 Decode(buffer, pos, end, message, field_dict)
14The arguments are:
15 buffer: The string containing the encoded message.
16 pos: The current position in the string.
17 end: The position in the string where the current message ends. May be
18 less than len(buffer) if we're reading a sub-message.
19 message: The message object into which we're parsing.
20 field_dict: message._fields (avoids a hashtable lookup).
21The decoder reads the field and stores it into field_dict, returning the new
22buffer position. A decoder for a repeated field may proactively decode all of
23the elements of that field, if they appear consecutively.
25Note that decoders may throw any of the following:
26 IndexError: Indicates a truncated message.
27 struct.error: Unpacking of a fixed-width field failed.
28 message.DecodeError: Other errors.
30Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking: it's fine to read past
32"end" as long as you are sure that someone else will notice and throw an
33exception later on.
35Something up the call stack is expected to catch IndexError and struct.error
36and convert them to message.DecodeError.
38Decoders are constructed using decoder constructors with the signature:
39 MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
40The arguments are:
41 field_number: The field number of the field we want to decode.
42 is_repeated: Is the field a repeated field? (bool)
43 is_packed: Is the field a packed field? (bool)
44 key: The key to use when looking up the field within field_dict.
45 (This is actually the FieldDescriptor but nothing in this
46 file should depend on that.)
47 new_default: A function which takes a message object as a parameter and
48 returns a new instance of the default value for this field.
49 (This is called for repeated fields and sub-messages, when an
50 instance does not already exist.)
52As with encoders, we define a decoder constructor for every type of field.
53Then, for every field of every message class we construct an actual decoder.
54That decoder goes into a dict indexed by tag, so when we decode a message
55we repeatedly read a tag, look up the corresponding decoder, and invoke it.
56"""
# Module authorship metadata.
__author__ = 'kenton@google.com (Kenton Varda)'
60import math
61import struct
63from google.protobuf.internal import containers
64from google.protobuf.internal import encoder
65from google.protobuf.internal import wire_format
66from google.protobuf import message
# This is not for optimization, but rather to avoid conflicts with local
# variables named "message": many decoder closures below take a parameter
# called `message`, which would shadow the imported `message` module inside
# their bodies, so the exception class is bound to a module-level alias.
_DecodeError = message.DecodeError
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits. The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds
  checking after the fact (often the caller can defer such checking until
  later).

  Args:
    mask: int mask applied to the accumulated value, e.g. (1 << 32) - 1.
    result_type: callable used to convert the masked value (e.g. int).

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos). It raises
    IndexError if the buffer ends in the middle of a varint, and DecodeError
    if the 10th byte still has its continuation bit set.
  """

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = buffer[pos]
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        # High bit clear: this was the last byte of the varint.
        result &= mask
        result = result_type(result)
        return (result, pos)
      shift += 7
      # 64 bits fit in at most 10 groups of 7 bits; a continuation bit on
      # the 10th byte means the input is malformed.
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint
def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes signed (two's complement) values.

  The accumulated varint is truncated to its low `bits` bits and then
  sign-extended, so for example the 10-byte encoding of -1 decodes to -1.
  """

  sign_bit = 1 << (bits - 1)
  bit_mask = (1 << bits) - 1

  def DecodeVarint(buffer, pos):
    value = 0
    offset = 0
    while True:
      byte = buffer[pos]
      pos += 1
      value |= (byte & 0x7f) << offset
      if (byte & 0x80) == 0:
        # Final byte: truncate to `bits`, then sign-extend via xor/subtract.
        truncated = value & bit_mask
        return (result_type((truncated ^ sign_bit) - sign_bit), pos)
      offset += 7
      if offset >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')

  return DecodeVarint
# Every varint decoder produces a plain Python int; only the mask differs.
# 64-bit variants keep the low 64 bits of the decoded value.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# 32-bit variants, for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The raw bytes of the tag are returned rather than the decoded varint. The
  raw bytes are then used directly as a key to look up the proper decoder,
  trading pure-Python varint decoding for a byte-string hash lookup done in
  C. In a low-level language decoding the varint would be cheaper, but not
  in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_end = pos
  # Advance over every byte with the continuation bit set, then past the
  # final byte of the tag.
  while buffer[tag_end] & 0x80:
    tag_end += 1
  tag_end += 1
  return buffer[pos:tag_end].tobytes(), tag_end
159# --------------------------------------------------------------------
def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
    wire_type: The field's wire type.
    decode_value: A function which decodes an individual value, e.g.
      _DecodeVarint()
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if is_packed:
      local_DecodeVarint = _DecodeVarint

      def DecodePackedField(buffer, pos, end, message, field_dict):
        values = field_dict.get(key)
        if values is None:
          values = field_dict.setdefault(key, new_default(message))
        # Packed payload: a varint byte length followed by the elements.
        (payload_size, pos) = local_DecodeVarint(buffer, pos)
        payload_end = pos + payload_size
        if payload_end > end:
          raise _DecodeError('Truncated message.')
        while pos < payload_end:
          (item, pos) = decode_value(buffer, pos)
          values.append(item)
        if pos > payload_end:
          del values[-1]  # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos

      return DecodePackedField

    if is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)

      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        values = field_dict.get(key)
        if values is None:
          values = field_dict.setdefault(key, new_default(message))
        while True:
          (item, item_end) = decode_value(buffer, pos)
          values.append(item)
          # Guess that the following tag is another copy of this field.
          pos = item_end + tag_len
          if buffer[item_end:pos] != tag_bytes or item_end >= end:
            # Guess was wrong (or the input is exhausted): stop here.
            if item_end > end:
              raise _DecodeError('Truncated message.')
            return item_end

      return DecodeRepeatedField

    def DecodeField(buffer, pos, end, message, field_dict):
      (parsed, pos) = decode_value(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not parsed:
        field_dict.pop(key, None)
        return pos
      field_dict[key] = parsed
      return pos

    return DecodeField

  return SpecificDecoder
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but passes each decoded value through modify_value.

  Usually modify_value is ZigZagDecode. Wrapping _SimpleDecoder is slightly
  slower than duplicating its code, but not enough to matter.
  """

  def InnerDecode(buffer, pos):
    raw, after = decode_value(buffer, pos)
    return (modify_value(raw), after)

  return _SimpleDecoder(wire_type, InnerDecode)
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The format string to pass to struct.unpack().
  """
  width = struct.calcsize(format)
  unpack = struct.unpack

  # A struct.error from a short slice is intentionally not caught here;
  # callers further up the stack convert it to DecodeError, which avoids
  # setting up an exception handler for every single value.
  def InnerDecode(buffer, pos):
    stop = pos + width
    return (unpack(format, buffer[pos:stop])[0], stop)

  return _SimpleDecoder(wire_type, InnerDecode)
def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values are detected and mapped explicitly, to work
  around an old (Python 2.4-era) struct.unpack bug for such values.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    new_pos = pos + 4
    raw = buffer[pos:new_pos].tobytes()

    # Little-endian: byte 3 carries the sign bit plus the top 7 exponent
    # bits, and the high bit of byte 2 is the lowest exponent bit. If all
    # exponent bits are set the value is inf or NaN.
    exponent_all_ones = raw[3:4] in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if not exponent_all_ones:
      # Any struct.error is converted to DecodeError further up the stack.
      return (local_unpack('<f', raw)[0], new_pos)
    if raw[0:3] != b'\x00\x00\x80':
      # At least one significand bit is set: not a number.
      return (math.nan, new_pos)
    if raw[3:4] == b'\xFF':
      # Sign bit set.
      return (-math.inf, new_pos)
    return (math.inf, new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
def _DoubleDecoder():
  """Returns a decoder for a double field.

  NaN encodings are detected and mapped to math.nan explicitly, to work
  around an old (Python 2.4-era) struct.unpack bug for not-a-number.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    new_pos = pos + 8
    raw = buffer[pos:new_pos].tobytes()

    # Little-endian: byte 7 holds the sign bit and top 7 exponent bits and
    # the high nibble of byte 6 holds the rest of the exponent. NaN means all
    # exponent bits set and at least one significand bit set.
    is_nan = ((raw[7:8] in b'\x7F\xFF')
              and (raw[6:7] >= b'\xF0')
              and (raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0'))
    if is_nan:
      return (math.nan, new_pos)

    # Any struct.error is converted to DecodeError further up the stack.
    return (local_unpack('<d', raw)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values found in `key.enum_type.values_by_number` are stored in the field.
  Any other value is not stored in the field; it is preserved on the message
  as an unknown VARINT field instead, both in `message._unknown_fields` (raw
  tag + value bytes) and in `message._unknown_field_set`.
  """
  enum_type = key.enum_type
  # Every unknown enum value is recorded under this VARINT tag, and the
  # repeated path also uses it to predict the next tag. Compute it once here
  # instead of once per unknown value / per decode call.
  varint_tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  def _RecordUnknownValue(message, raw_bytes, enum_value):
    """Preserve an unrecognized enum value as an unknown VARINT field."""
    # pylint: disable=protected-access
    if not message._unknown_fields:
      message._unknown_fields = []
    message._unknown_fields.append((varint_tag_bytes, raw_bytes))
    if message._unknown_field_set is None:
      message._unknown_field_set = containers.UnknownFieldSet()
    message._unknown_field_set._add(
        field_number, wire_format.WIRETYPE_VARINT, enum_value)
    # pylint: enable=protected-access

  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _RecordUnknownValue(
              message, buffer[value_start_pos:pos].tobytes(), element)
      if pos > endpoint:
        # The last element overran the packed payload: undo whichever
        # record was made for it before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          # pylint: disable=protected-access
          del message._unknown_fields[-1]
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(varint_tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _RecordUnknownValue(message, buffer[pos:new_pos].tobytes(), element)
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != varint_tag_bytes or new_pos >= end:
          # Prediction failed. Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode a serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        _RecordUnknownValue(
            message, buffer[value_start_pos:pos].tobytes(), enum_value)
      return pos
    return DecodeField
488# --------------------------------------------------------------------
# Integer types encoded as plain varints.
Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint)

# Signed integer types that apply ZigZag decoding on top of a varint.
SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint32,
                                 wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint,
                                 wire_format.ZigZagDecode)

# Fixed-width types. The '<' prefix makes struct use standard sizes, so the
# field widths are identical on every platform (without it, sizes would
# follow the C compiler's native type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# bool is a varint whose decoded value is passed through bool().
BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint,
                               bool)
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Values are length-delimited and decoded as UTF-8. A UnicodeDecodeError is
  re-raised with the field's full name appended to its reason.
  """

  local_DecodeVarint = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Decode a UTF-8 memoryview slice into a str."""
    raw = memview.tobytes()
    try:
      return str(raw, 'utf-8')
    except UnicodeDecodeError as e:
      # Mention the offending field in the error, then re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (length, start) = local_DecodeVarint(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not length:
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[start:stop])
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      (length, start) = local_DecodeVarint(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(_ConvertToUnicode(buffer[start:stop]))
      # Guess that the following tag is another copy of this field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop

  return DecodeRepeatedField
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field (length-delimited, stored as bytes)."""

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (length, start) = local_DecodeVarint(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not length:
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[start:stop].tobytes()
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      (length, start) = local_DecodeVarint(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(buffer[start:stop].tobytes())
      # Guess that the following tag is another copy of this field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop

  return DecodeRepeatedField
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  The returned decoder parses the group body starting at `pos` (the group's
  START_GROUP tag has already been consumed) and then requires the matching
  END_GROUP tag for `field_number`, raising DecodeError if it is missing.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Fetch (or create) the container once. Nothing in the loop rebinds
      # field_dict[key], so the per-iteration re-fetch this loop used to do
      # was redundant (MessageDecoder's repeated path likewise fetches once).
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each sub-message is length-delimited: a varint byte count followed by the
  serialized sub-message, which is parsed in place into the field's value.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      (length, pos) = local_DecodeVarint(buffer, pos)
      stop = pos + length
      if stop > end:
        raise _DecodeError('Truncated message.')
      # _InternalParse only stops short of `stop` when it hits an end-group
      # tag, which cannot legally appear here.
      if submessage._InternalParse(buffer, pos, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      (length, pos) = local_DecodeVarint(buffer, pos)
      stop = pos + length
      if stop > end:
        raise _DecodeError('Truncated message.')
      # _InternalParse only stops short of `stop` when it hits an end-group
      # tag, which cannot legally appear here.
      if container.add()._InternalParse(buffer, pos, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')
      # Guess that the following tag is another copy of this field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop

  return DecodeRepeatedField
711# --------------------------------------------------------------------
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode one serialized MessageSet item and return the new position.

    If `type_id` names a known extension of `message`, the payload is parsed
    into that extension's value in `field_dict`. Otherwise the whole item is
    preserved as an unknown field on `message`.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    message_set_item_start = pos
    # -1 marks "not seen yet".
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not imported at the top of this module, so
          # referencing it here used to raise NameError. Import it lazily,
          # only on this rare path.
          from google.protobuf import message_factory  # pylint: disable=g-import-not-at-top
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          type_id,
          wire_format.WIRETYPE_LENGTH_DELIMITED,
          buffer[message_start:message_end].tobytes())
      # pylint: enable=protected-access

    return pos

  return DecodeItem
def UnknownMessageSetItemDecoder():
  """Returns a decoder for a Unknown MessageSet item.

  The returned function takes a memoryview holding the fields of one
  MessageSet item followed by the item's END_GROUP tag. It returns a
  (type_id, message_bytes) tuple and raises DecodeError if the end tag,
  the type_id or the message is missing, or if the data is truncated.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    pos = 0
    end = len(buffer)
    # -1 marks "not seen yet". type_id must be initialized as well:
    # previously an item without a type_id hit an UnboundLocalError at the
    # `type_id == -1` check instead of raising the intended DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem
849# --------------------------------------------------------------------
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry is a length-delimited entry message with `key` and `value`
  fields. A single scratch entry instance is reused for every entry, and its
  value is either copied into the map (message values) or assigned directly.
  """

  key = field_descriptor
  tag_bytes = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  local_DecodeVarint = _DecodeVarint
  # _concrete_class cannot be read yet; it might not be initialized.
  message_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    entry = message_type._concrete_class()
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      (entry_size, pos) = local_DecodeVarint(buffer, pos)
      entry_end = pos + entry_size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      entry.Clear()
      # _InternalParse only stops early on an end-group tag, which cannot
      # legally appear inside a length-delimited entry.
      if entry._InternalParse(buffer, pos, entry_end) != entry_end:
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Guess that the following tag is another entry of this map.
      pos = entry_end + tag_len
      if buffer[entry_end:pos] != tag_bytes or entry_end == end:
        return entry_end

  return DecodeMap
893# --------------------------------------------------------------------
894# Optimization is not as heavy here because calls to SkipField() are rare,
895# except for handling end-group tags.
def _SkipVarint(buffer, pos, end):
  """Skip a varint value. Returns the new position."""
  # Slicing never raises IndexError; instead, reading past the buffer yields
  # b'', and ord(b'') raises TypeError. Both are handled in python_message.py
  # to produce a 'Truncated message' error.
  while True:
    current = ord(buffer[pos:pos + 1].tobytes())
    pos += 1
    if not (current & 0x80):
      break
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos
909def _SkipFixed64(buffer, pos, end):
910 """Skip a fixed64 value. Returns the new position."""
912 pos += 8
913 if pos > end:
914 raise _DecodeError('Truncated message.')
915 return pos
918def _DecodeFixed64(buffer, pos):
919 """Decode a fixed64."""
920 new_pos = pos + 8
921 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value. Returns the new position.

  The value is a varint length followed by that many payload bytes. Raises
  _DecodeError if the payload would extend past |end|.
  """
  size, payload_start = _DecodeVarint(buffer, pos)
  new_pos = payload_start + size
  if new_pos > end:
    raise _DecodeError('Truncated message.')
  return new_pos
def _SkipGroup(buffer, pos, end):
  """Skip sub-group. Returns the new position.

  Fields are skipped one by one until SkipField signals an end-group tag by
  returning -1. The returned position is just past that end-group tag.
  """
  # NOTE(review): a nested START_GROUP dispatches back into _SkipGroup via
  # SkipField, so nesting depth is bounded only by Python's recursion limit.
  while True:
    tag_bytes, pos = ReadTag(buffer, pos)
    skipped_to = SkipField(buffer, pos, end, tag_bytes)
    if skipped_to == -1:
      return pos
    pos = skipped_to
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.

  Args:
    buffer: memoryview holding the encoded data.
    pos: position of the first field tag to decode.
    end_pos: position at which to stop. When None (as when decoding the body
      of a group), decoding stops only at an end-group tag.

  Returns:
    A (unknown_field_set, new_pos) tuple. If decoding stopped at an
    end-group tag, new_pos is just past that tag.
  """
  field_set = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    tag_bytes, pos = ReadTag(buffer, pos)
    tag, _ = _DecodeVarint(tag_bytes, 0)
    field_number, wire_type = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    data, pos = _DecodeUnknownField(buffer, pos, wire_type)
    field_set._add(field_number, wire_type, data)  # pylint: disable=protected-access
  return (field_set, pos)
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode an unknown field. Returns the field's data and new position.

  Args:
    buffer: memoryview holding the encoded data.
    pos: position just after the field's tag.
    wire_type: the wire type taken from the field's tag.

  Returns:
    A (data, new_pos) tuple. data is the decoded integer for varint, fixed64
    and fixed32 fields; bytes for length-delimited fields; and an
    UnknownFieldSet for groups. For an END_GROUP wire type, (0, -1) is
    returned instead.

  Raises:
    _DecodeError: for an invalid wire type, or for a length-delimited field
      whose declared size runs past the end of |buffer|.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    new_pos = pos + size
    if new_pos > len(buffer):
      # Slicing past the end would silently yield a short payload and leave
      # pos beyond the buffer; fail the same way the _Skip* helpers do.
      raise _DecodeError('Truncated message.')
    data = buffer[pos:new_pos].tobytes()
    pos = new_pos
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  return (data, pos)
985def _EndGroup(buffer, pos, end):
986 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
988 return -1
991def _SkipFixed32(buffer, pos, end):
992 """Skip a fixed32 value. Returns the new position."""
994 pos += 4
995 if pos > end:
996 raise _DecodeError('Truncated message.')
997 return pos
1000def _DecodeFixed32(buffer, pos):
1001 """Decode a fixed32."""
1003 new_pos = pos + 4
1004 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for wire types 6 and 7, which are not valid.

  Always raises _DecodeError; it never returns a position.
  """
  del buffer, pos, end  # Unused; the signature matches the other skippers.
  raise _DecodeError('Tag had invalid wire type.')
def _FieldSkipper():
  """Constructs the SkipField function.

  Returns:
    A SkipField(buffer, pos, end, tag_bytes) function that dispatches on the
    wire type in |tag_bytes| to the matching skip helper.
  """
  # Indexed by wire type (0-7); types 6 and 7 are invalid.
  skippers_by_wire_type = (
      _SkipVarint,            # 0: varint
      _SkipFixed64,           # 1: fixed64
      _SkipLengthDelimited,   # 2: length-delimited
      _SkipGroup,             # 3: start group
      _EndGroup,              # 4: end group
      _SkipFixed32,           # 5: fixed32
      _RaiseInvalidWireType,  # 6: invalid
      _RaiseInvalidWireType,  # 7: invalid
  )
  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
      The new position (after the tag value), or -1 if the tag is an
      end-group tag (in which case the calling loop should break).
    """
    # Varints are little-endian, so the wire-type bits of the tag are always
    # in its first byte.
    first_byte = ord(tag_bytes[0:1])
    skip = skippers_by_wire_type[first_byte & type_mask]
    return skip(buffer, pos, end)

  return SkipField


SkipField = _FieldSkipper()