1# Protocol Buffers - Google's data interchange format 
    2# Copyright 2008 Google Inc.  All rights reserved. 
    3# 
    4# Use of this source code is governed by a BSD-style 
    5# license that can be found in the LICENSE file or at 
    6# https://developers.google.com/open-source/licenses/bsd 
    7 
    8"""Code for decoding protocol buffer primitives. 
    9 
    10This code is very similar to encoder.py -- read the docs for that module first. 
    11 
    12A "decoder" is a function with the signature: 
    13  Decode(buffer, pos, end, message, field_dict) 
    14The arguments are: 
    15  buffer:     The string containing the encoded message. 
    16  pos:        The current position in the string. 
    17  end:        The position in the string where the current message ends.  May be 
    18              less than len(buffer) if we're reading a sub-message. 
    19  message:    The message object into which we're parsing. 
    20  field_dict: message._fields (avoids a hashtable lookup). 
    21The decoder reads the field and stores it into field_dict, returning the new 
    22buffer position.  A decoder for a repeated field may proactively decode all of 
    23the elements of that field, if they appear consecutively. 
    24 
    25Note that decoders may throw any of the following: 
    26  IndexError:  Indicates a truncated message. 
    27  struct.error:  Unpacking of a fixed-width field failed. 
    28  message.DecodeError:  Other errors. 
    29 
    30Decoders are expected to raise an exception if they are called with pos > end. 
This allows callers to be lax about bounds checking:  it's fine to read past
    32"end" as long as you are sure that someone else will notice and throw an 
    33exception later on. 
    34 
    35Something up the call stack is expected to catch IndexError and struct.error 
    36and convert them to message.DecodeError. 
    37 
    38Decoders are constructed using decoder constructors with the signature: 
    39  MakeDecoder(field_number, is_repeated, is_packed, key, new_default) 
    40The arguments are: 
    41  field_number:  The field number of the field we want to decode. 
    42  is_repeated:   Is the field a repeated field? (bool) 
    43  is_packed:     Is the field a packed field? (bool) 
    44  key:           The key to use when looking up the field within field_dict. 
    45                 (This is actually the FieldDescriptor but nothing in this 
    46                 file should depend on that.) 
    47  new_default:   A function which takes a message object as a parameter and 
    48                 returns a new instance of the default value for this field. 
    49                 (This is called for repeated fields and sub-messages, when an 
    50                 instance does not already exist.) 
    51 
    52As with encoders, we define a decoder constructor for every type of field. 
    53Then, for every field of every message class we construct an actual decoder. 
    54That decoder goes into a dict indexed by tag, so when we decode a message 
    55we repeatedly read a tag, look up the corresponding decoder, and invoke it. 
    56""" 
    57 
# Module authorship metadata.
__author__ = 'kenton@google.com (Kenton Varda)'
    59 
    60import math 
    61import numbers 
    62import struct 
    63 
    64from google.protobuf import message 
    65from google.protobuf.internal import containers 
    66from google.protobuf.internal import encoder 
    67from google.protobuf.internal import wire_format 
    68 
    69 
# Module-level alias for message.DecodeError.  This is not for optimization,
# but rather to avoid conflicts with local variables named "message": the
# decoder closures below take a parameter called `message`, which would shadow
# the module inside their bodies.
_DecodeError = message.DecodeError
    73 
    74 
    75def IsDefaultScalarValue(value): 
    76  """Returns whether or not a scalar value is the default value of its type. 
    77 
    78  Specifically, this should be used to determine presence of implicit-presence 
    79  fields, where we disallow custom defaults. 
    80 
    81  Args: 
    82    value: A scalar value to check. 
    83 
    84  Returns: 
    85    True if the value is equivalent to a default value, False otherwise. 
    86  """ 
    87  if isinstance(value, numbers.Number) and math.copysign(1.0, value) < 0: 
    88    # Special case for negative zero, where "truthiness" fails to give the right 
    89    # answer. 
    90    return False 
    91 
    92  # Normally, we can just use Python's boolean conversion. 
    93  return not value 
    94 
    95 
    96def _VarintDecoder(mask, result_type): 
    97  """Return an encoder for a basic varint value (does not include tag). 
    98 
    99  Decoded values will be bitwise-anded with the given mask before being 
    100  returned, e.g. to limit them to 32 bits.  The returned decoder does not 
    101  take the usual "end" parameter -- the caller is expected to do bounds checking 
    102  after the fact (often the caller can defer such checking until later).  The 
    103  decoder returns a (value, new_pos) pair. 
    104  """ 
    105 
    106  def DecodeVarint(buffer, pos: int=None): 
    107    result = 0 
    108    shift = 0 
    109    while 1: 
    110      if pos is None: 
    111        # Read from BytesIO 
    112        try: 
    113          b = buffer.read(1)[0] 
    114        except IndexError as e: 
    115          if shift == 0: 
    116            # End of BytesIO. 
    117            return None 
    118          else: 
    119            raise ValueError('Fail to read varint %s' % str(e)) 
    120      else: 
    121        b = buffer[pos] 
    122        pos += 1 
    123      result |= ((b & 0x7f) << shift) 
    124      if not (b & 0x80): 
    125        result &= mask 
    126        result = result_type(result) 
    127        return result if pos is None else (result, pos) 
    128      shift += 7 
    129      if shift >= 64: 
    130        raise _DecodeError('Too many bytes when decoding varint.') 
    131 
    132  return DecodeVarint 
    133 
    134 
    135def _SignedVarintDecoder(bits, result_type): 
    136  """Like _VarintDecoder() but decodes signed values.""" 
    137 
    138  signbit = 1 << (bits - 1) 
    139  mask = (1 << bits) - 1 
    140 
    141  def DecodeVarint(buffer, pos): 
    142    result = 0 
    143    shift = 0 
    144    while 1: 
    145      b = buffer[pos] 
    146      result |= ((b & 0x7f) << shift) 
    147      pos += 1 
    148      if not (b & 0x80): 
    149        result &= mask 
    150        result = (result ^ signbit) - signbit 
    151        result = result_type(result) 
    152        return (result, pos) 
    153      shift += 7 
    154      if shift >= 64: 
    155        raise _DecodeError('Too many bytes when decoding varint.') 
    156  return DecodeVarint 
    157 
    158# All 32-bit and 64-bit values are represented as int. 
    159_DecodeVarint = _VarintDecoder((1 << 64) - 1, int) 
    160_DecodeSignedVarint = _SignedVarintDecoder(64, int) 
    161 
    162# Use these versions for values which must be limited to 32 bits. 
    163_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) 
    164_DecodeSignedVarint32 = _SignedVarintDecoder(32, int) 
    165 
    166 
    167def ReadTag(buffer, pos): 
    168  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple. 
    169 
    170  We return the raw bytes of the tag rather than decoding them.  The raw 
    171  bytes can then be used to look up the proper decoder.  This effectively allows 
    172  us to trade some work that would be done in pure-python (decoding a varint) 
    173  for work that is done in C (searching for a byte string in a hash table). 
    174  In a low-level language it would be much cheaper to decode the varint and 
    175  use that, but not in Python. 
    176 
    177  Args: 
    178    buffer: memoryview object of the encoded bytes 
    179    pos: int of the current position to start from 
    180 
    181  Returns: 
    182    Tuple[bytes, int] of the tag data and new position. 
    183  """ 
    184  start = pos 
    185  while buffer[pos] & 0x80: 
    186    pos += 1 
    187  pos += 1 
    188 
    189  tag_bytes = buffer[start:pos].tobytes() 
    190  return tag_bytes, pos 
    191 
    192 
    193def DecodeTag(tag_bytes): 
    194  """Decode a tag from the bytes. 
    195 
    196  Args: 
    197    tag_bytes: the bytes of the tag 
    198 
    199  Returns: 
    200    Tuple[int, int] of the tag field number and wire type. 
    201  """ 
    202  (tag, _) = _DecodeVarint(tag_bytes, 0) 
    203  return wire_format.UnpackTag(tag) 
    204 
    205 
    206# -------------------------------------------------------------------- 
    207 
    208 
    209def _SimpleDecoder(wire_type, decode_value): 
    210  """Return a constructor for a decoder for fields of a particular type. 
    211 
    212  Args: 
    213      wire_type:  The field's wire type. 
    214      decode_value:  A function which decodes an individual value, e.g. 
    215        _DecodeVarint() 
    216  """ 
    217 
    218  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default, 
    219                      clear_if_default=False): 
    220    if is_packed: 
    221      local_DecodeVarint = _DecodeVarint 
    222      def DecodePackedField( 
    223          buffer, pos, end, message, field_dict, current_depth=0 
    224      ): 
    225        del current_depth  # unused 
    226        value = field_dict.get(key) 
    227        if value is None: 
    228          value = field_dict.setdefault(key, new_default(message)) 
    229        (endpoint, pos) = local_DecodeVarint(buffer, pos) 
    230        endpoint += pos 
    231        if endpoint > end: 
    232          raise _DecodeError('Truncated message.') 
    233        while pos < endpoint: 
    234          (element, pos) = decode_value(buffer, pos) 
    235          value.append(element) 
    236        if pos > endpoint: 
    237          del value[-1]   # Discard corrupt value. 
    238          raise _DecodeError('Packed element was truncated.') 
    239        return pos 
    240 
    241      return DecodePackedField 
    242    elif is_repeated: 
    243      tag_bytes = encoder.TagBytes(field_number, wire_type) 
    244      tag_len = len(tag_bytes) 
    245      def DecodeRepeatedField( 
    246          buffer, pos, end, message, field_dict, current_depth=0 
    247      ): 
    248        del current_depth  # unused 
    249        value = field_dict.get(key) 
    250        if value is None: 
    251          value = field_dict.setdefault(key, new_default(message)) 
    252        while 1: 
    253          (element, new_pos) = decode_value(buffer, pos) 
    254          value.append(element) 
    255          # Predict that the next tag is another copy of the same repeated 
    256          # field. 
    257          pos = new_pos + tag_len 
    258          if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 
    259            # Prediction failed.  Return. 
    260            if new_pos > end: 
    261              raise _DecodeError('Truncated message.') 
    262            return new_pos 
    263 
    264      return DecodeRepeatedField 
    265    else: 
    266 
    267      def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): 
    268        del current_depth  # unused 
    269        (new_value, pos) = decode_value(buffer, pos) 
    270        if pos > end: 
    271          raise _DecodeError('Truncated message.') 
    272        if clear_if_default and IsDefaultScalarValue(new_value): 
    273          field_dict.pop(key, None) 
    274        else: 
    275          field_dict[key] = new_value 
    276        return pos 
    277 
    278      return DecodeField 
    279 
    280  return SpecificDecoder 
    281 
    282 
    283def _ModifiedDecoder(wire_type, decode_value, modify_value): 
    284  """Like SimpleDecoder but additionally invokes modify_value on every value 
    285  before storing it.  Usually modify_value is ZigZagDecode. 
    286  """ 
    287 
    288  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but 
    289  # not enough to make a significant difference. 
    290 
    291  def InnerDecode(buffer, pos): 
    292    (result, new_pos) = decode_value(buffer, pos) 
    293    return (modify_value(result), new_pos) 
    294  return _SimpleDecoder(wire_type, InnerDecode) 
    295 
    296 
    297def _StructPackDecoder(wire_type, format): 
    298  """Return a constructor for a decoder for a fixed-width field. 
    299 
    300  Args: 
    301      wire_type:  The field's wire type. 
    302      format:  The format string to pass to struct.unpack(). 
    303  """ 
    304 
    305  value_size = struct.calcsize(format) 
    306  local_unpack = struct.unpack 
    307 
    308  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but 
    309  # not enough to make a significant difference. 
    310 
    311  # Note that we expect someone up-stack to catch struct.error and convert 
    312  # it to _DecodeError -- this way we don't have to set up exception- 
    313  # handling blocks every time we parse one value. 
    314 
    315  def InnerDecode(buffer, pos): 
    316    new_pos = pos + value_size 
    317    result = local_unpack(format, buffer[pos:new_pos])[0] 
    318    return (result, new_pos) 
    319  return _SimpleDecoder(wire_type, InnerDecode) 
    320 
    321 
    322def _FloatDecoder(): 
    323  """Returns a decoder for a float field. 
    324 
    325  This code works around a bug in struct.unpack for non-finite 32-bit 
    326  floating-point values. 
    327  """ 
    328 
    329  local_unpack = struct.unpack 
    330 
    331  def InnerDecode(buffer, pos): 
    332    """Decode serialized float to a float and new position. 
    333 
    334    Args: 
    335      buffer: memoryview of the serialized bytes 
    336      pos: int, position in the memory view to start at. 
    337 
    338    Returns: 
    339      Tuple[float, int] of the deserialized float value and new position 
    340      in the serialized data. 
    341    """ 
    342    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign 
    343    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. 
    344    new_pos = pos + 4 
    345    float_bytes = buffer[pos:new_pos].tobytes() 
    346 
    347    # If this value has all its exponent bits set, then it's non-finite. 
    348    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. 
    349    # To avoid that, we parse it specially. 
    350    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): 
    351      # If at least one significand bit is set... 
    352      if float_bytes[0:3] != b'\x00\x00\x80': 
    353        return (math.nan, new_pos) 
    354      # If sign bit is set... 
    355      if float_bytes[3:4] == b'\xFF': 
    356        return (-math.inf, new_pos) 
    357      return (math.inf, new_pos) 
    358 
    359    # Note that we expect someone up-stack to catch struct.error and convert 
    360    # it to _DecodeError -- this way we don't have to set up exception- 
    361    # handling blocks every time we parse one value. 
    362    result = local_unpack('<f', float_bytes)[0] 
    363    return (result, new_pos) 
    364  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode) 
    365 
    366 
    367def _DoubleDecoder(): 
    368  """Returns a decoder for a double field. 
    369 
    370  This code works around a bug in struct.unpack for not-a-number. 
    371  """ 
    372 
    373  local_unpack = struct.unpack 
    374 
    375  def InnerDecode(buffer, pos): 
    376    """Decode serialized double to a double and new position. 
    377 
    378    Args: 
    379      buffer: memoryview of the serialized bytes. 
    380      pos: int, position in the memory view to start at. 
    381 
    382    Returns: 
    383      Tuple[float, int] of the decoded double value and new position 
    384      in the serialized data. 
    385    """ 
    386    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign 
    387    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. 
    388    new_pos = pos + 8 
    389    double_bytes = buffer[pos:new_pos].tobytes() 
    390 
    391    # If this value has all its exponent bits set and at least one significand 
    392    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it 
    393    # as inf or -inf.  To avoid that, we treat it specially. 
    394    if ((double_bytes[7:8] in b'\x7F\xFF') 
    395        and (double_bytes[6:7] >= b'\xF0') 
    396        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): 
    397      return (math.nan, new_pos) 
    398 
    399    # Note that we expect someone up-stack to catch struct.error and convert 
    400    # it to _DecodeError -- this way we don't have to set up exception- 
    401    # handling blocks every time we parse one value. 
    402    result = local_unpack('<d', double_bytes)[0] 
    403    return (result, new_pos) 
    404  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) 
    405 
    406 
    407def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, 
    408                clear_if_default=False): 
    409  """Returns a decoder for enum field.""" 
    410  enum_type = key.enum_type 
    411  if is_packed: 
    412    local_DecodeVarint = _DecodeVarint 
    413    def DecodePackedField( 
    414        buffer, pos, end, message, field_dict, current_depth=0 
    415    ): 
    416      """Decode serialized packed enum to its value and a new position. 
    417 
    418      Args: 
    419        buffer: memoryview of the serialized bytes. 
    420        pos: int, position in the memory view to start at. 
    421        end: int, end position of serialized data 
    422        message: Message object to store unknown fields in 
    423        field_dict: Map[Descriptor, Any] to store decoded values in. 
    424 
    425      Returns: 
    426        int, new position in serialized data. 
    427      """ 
    428      del current_depth  # unused 
    429      value = field_dict.get(key) 
    430      if value is None: 
    431        value = field_dict.setdefault(key, new_default(message)) 
    432      (endpoint, pos) = local_DecodeVarint(buffer, pos) 
    433      endpoint += pos 
    434      if endpoint > end: 
    435        raise _DecodeError('Truncated message.') 
    436      while pos < endpoint: 
    437        value_start_pos = pos 
    438        (element, pos) = _DecodeSignedVarint32(buffer, pos) 
    439        # pylint: disable=protected-access 
    440        if element in enum_type.values_by_number: 
    441          value.append(element) 
    442        else: 
    443          if not message._unknown_fields: 
    444            message._unknown_fields = [] 
    445          tag_bytes = encoder.TagBytes(field_number, 
    446                                       wire_format.WIRETYPE_VARINT) 
    447 
    448          message._unknown_fields.append( 
    449              (tag_bytes, buffer[value_start_pos:pos].tobytes())) 
    450          # pylint: enable=protected-access 
    451      if pos > endpoint: 
    452        if element in enum_type.values_by_number: 
    453          del value[-1]   # Discard corrupt value. 
    454        else: 
    455          del message._unknown_fields[-1] 
    456          # pylint: enable=protected-access 
    457        raise _DecodeError('Packed element was truncated.') 
    458      return pos 
    459 
    460    return DecodePackedField 
    461  elif is_repeated: 
    462    tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) 
    463    tag_len = len(tag_bytes) 
    464    def DecodeRepeatedField( 
    465        buffer, pos, end, message, field_dict, current_depth=0 
    466    ): 
    467      """Decode serialized repeated enum to its value and a new position. 
    468 
    469      Args: 
    470        buffer: memoryview of the serialized bytes. 
    471        pos: int, position in the memory view to start at. 
    472        end: int, end position of serialized data 
    473        message: Message object to store unknown fields in 
    474        field_dict: Map[Descriptor, Any] to store decoded values in. 
    475 
    476      Returns: 
    477        int, new position in serialized data. 
    478      """ 
    479      del current_depth  # unused 
    480      value = field_dict.get(key) 
    481      if value is None: 
    482        value = field_dict.setdefault(key, new_default(message)) 
    483      while 1: 
    484        (element, new_pos) = _DecodeSignedVarint32(buffer, pos) 
    485        # pylint: disable=protected-access 
    486        if element in enum_type.values_by_number: 
    487          value.append(element) 
    488        else: 
    489          if not message._unknown_fields: 
    490            message._unknown_fields = [] 
    491          message._unknown_fields.append( 
    492              (tag_bytes, buffer[pos:new_pos].tobytes())) 
    493        # pylint: enable=protected-access 
    494        # Predict that the next tag is another copy of the same repeated 
    495        # field. 
    496        pos = new_pos + tag_len 
    497        if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 
    498          # Prediction failed.  Return. 
    499          if new_pos > end: 
    500            raise _DecodeError('Truncated message.') 
    501          return new_pos 
    502 
    503    return DecodeRepeatedField 
    504  else: 
    505 
    506    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): 
    507      """Decode serialized repeated enum to its value and a new position. 
    508 
    509      Args: 
    510        buffer: memoryview of the serialized bytes. 
    511        pos: int, position in the memory view to start at. 
    512        end: int, end position of serialized data 
    513        message: Message object to store unknown fields in 
    514        field_dict: Map[Descriptor, Any] to store decoded values in. 
    515 
    516      Returns: 
    517        int, new position in serialized data. 
    518      """ 
    519      del current_depth  # unused 
    520      value_start_pos = pos 
    521      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) 
    522      if pos > end: 
    523        raise _DecodeError('Truncated message.') 
    524      if clear_if_default and IsDefaultScalarValue(enum_value): 
    525        field_dict.pop(key, None) 
    526        return pos 
    527      # pylint: disable=protected-access 
    528      if enum_value in enum_type.values_by_number: 
    529        field_dict[key] = enum_value 
    530      else: 
    531        if not message._unknown_fields: 
    532          message._unknown_fields = [] 
    533        tag_bytes = encoder.TagBytes(field_number, 
    534                                     wire_format.WIRETYPE_VARINT) 
    535        message._unknown_fields.append( 
    536            (tag_bytes, buffer[value_start_pos:pos].tobytes())) 
    537        # pylint: enable=protected-access 
    538      return pos 
    539 
    540    return DecodeField 
    541 
    542 
    543# -------------------------------------------------------------------- 
    544 
    545 
    546Int32Decoder = _SimpleDecoder( 
    547    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) 
    548 
    549Int64Decoder = _SimpleDecoder( 
    550    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) 
    551 
    552UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) 
    553UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) 
    554 
    555SInt32Decoder = _ModifiedDecoder( 
    556    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode) 
    557SInt64Decoder = _ModifiedDecoder( 
    558    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode) 
    559 
    560# Note that Python conveniently guarantees that when using the '<' prefix on 
    561# formats, they will also have the same size across all platforms (as opposed 
    562# to without the prefix, where their sizes depend on the C compiler's basic 
    563# type sizes). 
    564Fixed32Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') 
    565Fixed64Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') 
    566SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') 
    567SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') 
    568FloatDecoder = _FloatDecoder() 
    569DoubleDecoder = _DoubleDecoder() 
    570 
    571BoolDecoder = _ModifiedDecoder( 
    572    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) 
    573 
    574 
    575def StringDecoder(field_number, is_repeated, is_packed, key, new_default, 
    576                  clear_if_default=False): 
    577  """Returns a decoder for a string field.""" 
    578 
    579  local_DecodeVarint = _DecodeVarint 
    580 
    581  def _ConvertToUnicode(memview): 
    582    """Convert byte to unicode.""" 
    583    byte_str = memview.tobytes() 
    584    try: 
    585      value = str(byte_str, 'utf-8') 
    586    except UnicodeDecodeError as e: 
    587      # add more information to the error message and re-raise it. 
    588      e.reason = '%s in field: %s' % (e, key.full_name) 
    589      raise 
    590 
    591    return value 
    592 
    593  assert not is_packed 
    594  if is_repeated: 
    595    tag_bytes = encoder.TagBytes(field_number, 
    596                                 wire_format.WIRETYPE_LENGTH_DELIMITED) 
    597    tag_len = len(tag_bytes) 
    598    def DecodeRepeatedField( 
    599        buffer, pos, end, message, field_dict, current_depth=0 
    600    ): 
    601      del current_depth  # unused 
    602      value = field_dict.get(key) 
    603      if value is None: 
    604        value = field_dict.setdefault(key, new_default(message)) 
    605      while 1: 
    606        (size, pos) = local_DecodeVarint(buffer, pos) 
    607        new_pos = pos + size 
    608        if new_pos > end: 
    609          raise _DecodeError('Truncated string.') 
    610        value.append(_ConvertToUnicode(buffer[pos:new_pos])) 
    611        # Predict that the next tag is another copy of the same repeated field. 
    612        pos = new_pos + tag_len 
    613        if buffer[new_pos:pos] != tag_bytes or new_pos == end: 
    614          # Prediction failed.  Return. 
    615          return new_pos 
    616 
    617    return DecodeRepeatedField 
    618  else: 
    619 
    620    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): 
    621      del current_depth  # unused 
    622      (size, pos) = local_DecodeVarint(buffer, pos) 
    623      new_pos = pos + size 
    624      if new_pos > end: 
    625        raise _DecodeError('Truncated string.') 
    626      if clear_if_default and IsDefaultScalarValue(size): 
    627        field_dict.pop(key, None) 
    628      else: 
    629        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) 
    630      return new_pos 
    631 
    632    return DecodeField 
    633 
    634 
    635def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, 
    636                 clear_if_default=False): 
    637  """Returns a decoder for a bytes field.""" 
    638 
    639  local_DecodeVarint = _DecodeVarint 
    640 
    641  assert not is_packed 
    642  if is_repeated: 
    643    tag_bytes = encoder.TagBytes(field_number, 
    644                                 wire_format.WIRETYPE_LENGTH_DELIMITED) 
    645    tag_len = len(tag_bytes) 
    646    def DecodeRepeatedField( 
    647        buffer, pos, end, message, field_dict, current_depth=0 
    648    ): 
    649      del current_depth  # unused 
    650      value = field_dict.get(key) 
    651      if value is None: 
    652        value = field_dict.setdefault(key, new_default(message)) 
    653      while 1: 
    654        (size, pos) = local_DecodeVarint(buffer, pos) 
    655        new_pos = pos + size 
    656        if new_pos > end: 
    657          raise _DecodeError('Truncated string.') 
    658        value.append(buffer[pos:new_pos].tobytes()) 
    659        # Predict that the next tag is another copy of the same repeated field. 
    660        pos = new_pos + tag_len 
    661        if buffer[new_pos:pos] != tag_bytes or new_pos == end: 
    662          # Prediction failed.  Return. 
    663          return new_pos 
    664 
    665    return DecodeRepeatedField 
    666  else: 
    667 
    668    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): 
    669      del current_depth  # unused 
    670      (size, pos) = local_DecodeVarint(buffer, pos) 
    671      new_pos = pos + size 
    672      if new_pos > end: 
    673        raise _DecodeError('Truncated string.') 
    674      if clear_if_default and IsDefaultScalarValue(size): 
    675        field_dict.pop(key, None) 
    676      else: 
    677        field_dict[key] = buffer[pos:new_pos].tobytes() 
    678      return new_pos 
    679 
    680    return DecodeField 
    681 
    682 
    683def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): 
    684  """Returns a decoder for a group field.""" 
    685 
    686  end_tag_bytes = encoder.TagBytes(field_number, 
    687                                   wire_format.WIRETYPE_END_GROUP) 
    688  end_tag_len = len(end_tag_bytes) 
    689 
    690  assert not is_packed 
    691  if is_repeated: 
    692    tag_bytes = encoder.TagBytes(field_number, 
    693                                 wire_format.WIRETYPE_START_GROUP) 
    694    tag_len = len(tag_bytes) 
    695    def DecodeRepeatedField( 
    696        buffer, pos, end, message, field_dict, current_depth=0 
    697    ): 
    698      value = field_dict.get(key) 
    699      if value is None: 
    700        value = field_dict.setdefault(key, new_default(message)) 
    701      while 1: 
    702        value = field_dict.get(key) 
    703        if value is None: 
    704          value = field_dict.setdefault(key, new_default(message)) 
    705        # Read sub-message. 
    706        current_depth += 1 
    707        if current_depth > _recursion_limit: 
    708          raise _DecodeError( 
    709              'Error parsing message: too many levels of nesting.' 
    710          ) 
    711        pos = value.add()._InternalParse(buffer, pos, end, current_depth) 
    712        current_depth -= 1 
    713        # Read end tag. 
    714        new_pos = pos+end_tag_len 
    715        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 
    716          raise _DecodeError('Missing group end tag.') 
    717        # Predict that the next tag is another copy of the same repeated field. 
    718        pos = new_pos + tag_len 
    719        if buffer[new_pos:pos] != tag_bytes or new_pos == end: 
    720          # Prediction failed.  Return. 
    721          return new_pos 
    722 
    723    return DecodeRepeatedField 
    724  else: 
    725 
    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      """Decodes a single (non-repeated) group field.

      `key`, `new_default`, `end_tag_bytes` and `end_tag_len` are bound by the
      enclosing decoder constructor, which is not visible here.

      Args:
        buffer: The serialized data.
        pos: Where the group's contents begin (presumably just past its
          START_GROUP tag).
        end: Position where the enclosing message ends.
        message: The message that owns this field.
        field_dict: The owning message's field storage.
        current_depth: Current nesting depth, checked against
          _recursion_limit.

      Returns:
        The position just past the group's end tag.

      Raises:
        _DecodeError: On excessive nesting or a missing group end tag.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      # Parsing stops at the first end-group tag; that tag must be ours.
      pos = value._InternalParse(buffer, pos, end, current_depth)
      current_depth -= 1
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    741 
    742    return DecodeField 
    743 
    744 
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Args:
    field_number: The field's number. It is used only when the field is
      repeated, to build the tag predicted between consecutive elements.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; packed encoding does not apply to message
      fields.
    key: Key under which the field's value is stored in `field_dict`.
    new_default: Callable taking the owning message and returning the field's
      initial value (a container supporting `add()` when `is_repeated`).

  Returns:
    A decoder `(buffer, pos, end, message, field_dict, current_depth=0)`
    that returns the position after the data it consumed.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      """Decodes consecutive elements of a repeated message field.

      Args:
        buffer: The serialized data.
        pos: Position of the first element's length prefix (just past its
          tag).
        end: Position where the enclosing message ends.
        message: The message that owns this field.
        field_dict: The owning message's field storage.
        current_depth: Nesting depth of `message`; incremented for each
          element and checked against _recursion_limit.

      Returns:
        The position just past the last element decoded.

      Raises:
        _DecodeError: If an element is truncated, nesting is too deep, or an
          element stops early on an end-group tag.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        current_depth += 1
        if current_depth > _recursion_limit:
          raise _DecodeError(
              'Error parsing message: too many levels of nesting.'
          )
        if (
            value.add()._InternalParse(buffer, pos, new_pos, current_depth)
            != new_pos
        ):
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        current_depth -= 1
        # Predict that the next tag is another copy of the same repeated field.
        # The `new_pos == end` guard stops at the message boundary even if the
        # bytes after `end` happen to match the tag.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      """Decodes a single (non-repeated) message field, merging into it.

      Args:
        buffer: The serialized data.
        pos: Position of the field's length prefix (just past its tag).
        end: Position where the enclosing message ends.
        message: The message that owns this field.
        field_dict: The owning message's field storage.
        current_depth: Nesting depth of `message`; incremented and checked
          against _recursion_limit before parsing the sub-message.

      Returns:
        The position just past the sub-message.

      Raises:
        _DecodeError: If the sub-message is truncated, nesting is too deep,
          or parsing stops early on an end-group tag.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      current_depth -= 1
      return new_pos

    return DecodeField
    811 
    812 
    813# -------------------------------------------------------------------- 
    814 
# START_GROUP tag of a MessageSet `Item` group (field 1). Also used as the tag
# when an item with an unrecognized type_id is kept as an unknown field.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
    816 
def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.
  NOTE(review): `descriptor` is not read by this function or by the returned
  decoder; extensions are looked up on the message being parsed instead.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at (presumably just past
        the item's START_GROUP tag, since that tag is stored separately when
        the item is kept as an unknown field).
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data (just past the item's END_GROUP
      tag).

    Raises:
      _DecodeError: If the item is truncated, contains a stray end-group tag,
        or lacks its type_id or message field.
    """
    # Remembered so an unrecognized item can be stored verbatim.
    message_set_item_start = pos
    # -1 marks "not seen yet" for each required item field.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only the payload bounds are recorded; parsing is deferred until the
        # type_id (and thus the extension) is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        # Any other field inside the item is skipped.
        field_number, wire_type = DecodeTag(tag_bytes)
        _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
        if pos == -1:
          raise _DecodeError('Unexpected end-group tag.')

    # pos only grows inside the loop, so a single check after it covers every
    # field read above.
    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        # Presumably GetMessageClass populates `_concrete_class` as a side
        # effect; its return value is not used here.
        if not hasattr(message_type, '_concrete_class'):
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      # NOTE(review): no nesting depth is passed here, unlike the message
      # decoders above; confirm the recursion limit is meant to restart for
      # MessageSet payloads.
      if value._InternalParse(buffer, message_start,message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # Unknown type_id: keep the whole item (minus its start tag, which is
      # MESSAGE_SET_ITEM_TAG) as raw bytes.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      # pylint: enable=protected-access

    return pos

  return DecodeItem
    905 
    906 
    907def UnknownMessageSetItemDecoder(): 
    908  """Returns a decoder for a Unknown MessageSet item.""" 
    909 
    910  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 
    911  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 
    912  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 
    913 
    914  def DecodeUnknownItem(buffer): 
    915    pos = 0 
    916    end = len(buffer) 
    917    message_start = -1 
    918    message_end = -1 
    919    while 1: 
    920      (tag_bytes, pos) = ReadTag(buffer, pos) 
    921      if tag_bytes == type_id_tag_bytes: 
    922        (type_id, pos) = _DecodeVarint(buffer, pos) 
    923      elif tag_bytes == message_tag_bytes: 
    924        (size, message_start) = _DecodeVarint(buffer, pos) 
    925        pos = message_end = message_start + size 
    926      elif tag_bytes == item_end_tag_bytes: 
    927        break 
    928      else: 
    929        field_number, wire_type = DecodeTag(tag_bytes) 
    930        _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type) 
    931        if pos == -1: 
    932          raise _DecodeError('Unexpected end-group tag.') 
    933 
    934    if pos > end: 
    935      raise _DecodeError('Truncated message.') 
    936 
    937    if type_id == -1: 
    938      raise _DecodeError('MessageSet item missing type_id.') 
    939    if message_start == -1: 
    940      raise _DecodeError('MessageSet item missing message.') 
    941 
    942    return (type_id, buffer[message_start:message_end].tobytes()) 
    943 
    944  return DecodeUnknownItem 
    945 
    946# -------------------------------------------------------------------- 
    947 
    948def MapDecoder(field_descriptor, new_default, is_message_map): 
    949  """Returns a decoder for a map field.""" 
    950 
    951  key = field_descriptor 
    952  tag_bytes = encoder.TagBytes(field_descriptor.number, 
    953                               wire_format.WIRETYPE_LENGTH_DELIMITED) 
    954  tag_len = len(tag_bytes) 
    955  local_DecodeVarint = _DecodeVarint 
    956  # Can't read _concrete_class yet; might not be initialized. 
    957  message_type = field_descriptor.message_type 
    958 
    959  def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0): 
    960    del current_depth  # Unused. 
    961    submsg = message_type._concrete_class() 
    962    value = field_dict.get(key) 
    963    if value is None: 
    964      value = field_dict.setdefault(key, new_default(message)) 
    965    while 1: 
    966      # Read length. 
    967      (size, pos) = local_DecodeVarint(buffer, pos) 
    968      new_pos = pos + size 
    969      if new_pos > end: 
    970        raise _DecodeError('Truncated message.') 
    971      # Read sub-message. 
    972      submsg.Clear() 
    973      if submsg._InternalParse(buffer, pos, new_pos) != new_pos: 
    974        # The only reason _InternalParse would return early is if it 
    975        # encountered an end-group tag. 
    976        raise _DecodeError('Unexpected end-group tag.') 
    977 
    978      if is_message_map: 
    979        value[submsg.key].CopyFrom(submsg.value) 
    980      else: 
    981        value[submsg.key] = submsg.value 
    982 
    983      # Predict that the next tag is another copy of the same repeated field. 
    984      pos = new_pos + tag_len 
    985      if buffer[new_pos:pos] != tag_bytes or new_pos == end: 
    986        # Prediction failed.  Return. 
    987        return new_pos 
    988 
    989  return DecodeMap 
    990 
    991 
    992def _DecodeFixed64(buffer, pos): 
    993  """Decode a fixed64.""" 
    994  new_pos = pos + 8 
    995  return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 
    996 
    997 
    998def _DecodeFixed32(buffer, pos): 
    999  """Decode a fixed32.""" 
    1000 
    1001  new_pos = pos + 4 
    1002  return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 
# Default maximum sub-message/group nesting depth accepted by the decoders in
# this module; deeper input raises a _DecodeError. The active value lives in
# _recursion_limit and can be changed with SetRecursionLimit().
DEFAULT_RECURSION_LIMIT = 100
_recursion_limit = DEFAULT_RECURSION_LIMIT
    1005 
    1006 
    1007def SetRecursionLimit(new_limit): 
    1008  global _recursion_limit 
    1009  _recursion_limit = new_limit 
    1010 
    1011 
    1012def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0): 
    1013  """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position.""" 
    1014 
    1015  unknown_field_set = containers.UnknownFieldSet() 
    1016  while end_pos is None or pos < end_pos: 
    1017    (tag_bytes, pos) = ReadTag(buffer, pos) 
    1018    (tag, _) = _DecodeVarint(tag_bytes, 0) 
    1019    field_number, wire_type = wire_format.UnpackTag(tag) 
    1020    if wire_type == wire_format.WIRETYPE_END_GROUP: 
    1021      break 
    1022    (data, pos) = _DecodeUnknownField( 
    1023        buffer, pos, end_pos, field_number, wire_type, current_depth 
    1024    ) 
    1025    # pylint: disable=protected-access 
    1026    unknown_field_set._add(field_number, wire_type, data) 
    1027 
    1028  return (unknown_field_set, pos) 
    1029 
    1030 
    1031def _DecodeUnknownField( 
    1032    buffer, pos, end_pos, field_number, wire_type, current_depth=0 
    1033): 
    1034  """Decode a unknown field.  Returns the UnknownField and new position.""" 
    1035 
    1036  if wire_type == wire_format.WIRETYPE_VARINT: 
    1037    (data, pos) = _DecodeVarint(buffer, pos) 
    1038  elif wire_type == wire_format.WIRETYPE_FIXED64: 
    1039    (data, pos) = _DecodeFixed64(buffer, pos) 
    1040  elif wire_type == wire_format.WIRETYPE_FIXED32: 
    1041    (data, pos) = _DecodeFixed32(buffer, pos) 
    1042  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: 
    1043    (size, pos) = _DecodeVarint(buffer, pos) 
    1044    data = buffer[pos:pos+size].tobytes() 
    1045    pos += size 
    1046  elif wire_type == wire_format.WIRETYPE_START_GROUP: 
    1047    end_tag_bytes = encoder.TagBytes( 
    1048        field_number, wire_format.WIRETYPE_END_GROUP 
    1049    ) 
    1050    current_depth += 1 
    1051    if current_depth >= _recursion_limit: 
    1052      raise _DecodeError('Error parsing message: too many levels of nesting.') 
    1053    data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos, current_depth) 
    1054    current_depth -= 1 
    1055    # Check end tag. 
    1056    if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes: 
    1057      raise _DecodeError('Missing group end tag.') 
    1058  elif wire_type == wire_format.WIRETYPE_END_GROUP: 
    1059    return (0, -1) 
    1060  else: 
    1061    raise _DecodeError('Wrong wire type in tag.') 
    1062 
    1063  if pos > end_pos: 
    1064    raise _DecodeError('Truncated message.') 
    1065 
    1066  return (data, pos)