Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/google/protobuf/internal/decoder.py: 13%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

548 statements  

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# 

4# Use of this source code is governed by a BSD-style 

5# license that can be found in the LICENSE file or at 

6# https://developers.google.com/open-source/licenses/bsd 

7 

8"""Code for decoding protocol buffer primitives. 

9 

10This code is very similar to encoder.py -- read the docs for that module first. 

11 

12A "decoder" is a function with the signature: 

13 Decode(buffer, pos, end, message, field_dict) 

14The arguments are: 

15 buffer: The string containing the encoded message. 

16 pos: The current position in the string. 

17 end: The position in the string where the current message ends. May be 

18 less than len(buffer) if we're reading a sub-message. 

19 message: The message object into which we're parsing. 

20 field_dict: message._fields (avoids a hashtable lookup). 

21The decoder reads the field and stores it into field_dict, returning the new 

22buffer position. A decoder for a repeated field may proactively decode all of 

23the elements of that field, if they appear consecutively. 

24 

25Note that decoders may throw any of the following: 

26 IndexError: Indicates a truncated message. 

27 struct.error: Unpacking of a fixed-width field failed. 

28 message.DecodeError: Other errors. 

29 

30Decoders are expected to raise an exception if they are called with pos > end. 

31This allows callers to be lax about bounds checking: it's fine to read past 

32"end" as long as you are sure that someone else will notice and throw an 

33exception later on. 

34 

35Something up the call stack is expected to catch IndexError and struct.error 

36and convert them to message.DecodeError. 

37 

38Decoders are constructed using decoder constructors with the signature: 

39 MakeDecoder(field_number, is_repeated, is_packed, key, new_default) 

40The arguments are: 

41 field_number: The field number of the field we want to decode. 

42 is_repeated: Is the field a repeated field? (bool) 

43 is_packed: Is the field a packed field? (bool) 

44 key: The key to use when looking up the field within field_dict. 

45 (This is actually the FieldDescriptor but nothing in this 

46 file should depend on that.) 

47 new_default: A function which takes a message object as a parameter and 

48 returns a new instance of the default value for this field. 

49 (This is called for repeated fields and sub-messages, when an 

50 instance does not already exist.) 

51 

52As with encoders, we define a decoder constructor for every type of field. 

53Then, for every field of every message class we construct an actual decoder. 

54That decoder goes into a dict indexed by tag, so when we decode a message 

55we repeatedly read a tag, look up the corresponding decoder, and invoke it. 

56""" 

57 

58__author__ = 'kenton@google.com (Kenton Varda)' 

59 

60import math 

61import numbers 

62import struct 

63 

64from google.protobuf import message 

65from google.protobuf.internal import containers 

66from google.protobuf.internal import encoder 

67from google.protobuf.internal import wire_format 

68 

69 

# Module-level alias for message.DecodeError. This is not for optimization,
# but rather to avoid conflicts with local variables named "message": many
# decoders below take a parameter called `message`, which shadows the
# imported `message` module inside their bodies.
_DecodeError = message.DecodeError

73 

74 

def IsDefaultScalarValue(value):
  """Reports whether a scalar equals the default value of its type.

  Intended for deciding presence of implicit-presence fields. Such fields
  cannot declare custom defaults, so "default" always means the type's
  zero/empty value.

  Args:
    value: The scalar value to inspect.

  Returns:
    True if `value` is equivalent to its type's default value, else False.
  """
  is_number = isinstance(value, numbers.Number)
  # A number carrying a negative sign is never the default. This matters for
  # -0.0, which is falsy and would otherwise be mistaken for the default 0.0.
  if is_number and math.copysign(1.0, value) < 0:
    return False
  # For everything else, Python falsiness (0, 0.0, False, '', b'') is exactly
  # "is the default value".
  return not bool(value)

94 

95 

96def _VarintDecoder(mask, result_type): 

97 """Return an encoder for a basic varint value (does not include tag). 

98 

99 Decoded values will be bitwise-anded with the given mask before being 

100 returned, e.g. to limit them to 32 bits. The returned decoder does not 

101 take the usual "end" parameter -- the caller is expected to do bounds checking 

102 after the fact (often the caller can defer such checking until later). The 

103 decoder returns a (value, new_pos) pair. 

104 """ 

105 

106 def DecodeVarint(buffer, pos: int=None): 

107 result = 0 

108 shift = 0 

109 while 1: 

110 if pos is None: 

111 # Read from BytesIO 

112 try: 

113 b = buffer.read(1)[0] 

114 except IndexError as e: 

115 if shift == 0: 

116 # End of BytesIO. 

117 return None 

118 else: 

119 raise ValueError('Fail to read varint %s' % str(e)) 

120 else: 

121 b = buffer[pos] 

122 pos += 1 

123 result |= ((b & 0x7f) << shift) 

124 if not (b & 0x80): 

125 result &= mask 

126 result = result_type(result) 

127 return result if pos is None else (result, pos) 

128 shift += 7 

129 if shift >= 64: 

130 raise _DecodeError('Too many bytes when decoding varint.') 

131 

132 return DecodeVarint 

133 

134 

135def _SignedVarintDecoder(bits, result_type): 

136 """Like _VarintDecoder() but decodes signed values.""" 

137 

138 signbit = 1 << (bits - 1) 

139 mask = (1 << bits) - 1 

140 

141 def DecodeVarint(buffer, pos): 

142 result = 0 

143 shift = 0 

144 while 1: 

145 b = buffer[pos] 

146 result |= ((b & 0x7f) << shift) 

147 pos += 1 

148 if not (b & 0x80): 

149 result &= mask 

150 result = (result ^ signbit) - signbit 

151 result = result_type(result) 

152 return (result, pos) 

153 shift += 7 

154 if shift >= 64: 

155 raise _DecodeError('Too many bytes when decoding varint.') 

156 return DecodeVarint 

157 

# Shared varint decoders. Both 32-bit and 64-bit values are represented as
# Python int; the width only selects the mask (and, for the signed variants,
# the bit that is sign-extended).
_DecodeVarint = _VarintDecoder(mask=(1 << 64) - 1, result_type=int)
_DecodeSignedVarint = _SignedVarintDecoder(bits=64, result_type=int)

# Variants for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(mask=(1 << 32) - 1, result_type=int)
_DecodeSignedVarint32 = _SignedVarintDecoder(bits=32, result_type=int)

165 

166 

def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as its raw bytes rather than decoded. The raw bytes can
  be used directly as a lookup key for the proper decoder. In Python, a
  hash-table lookup on a byte string (done in C) is cheaper than decoding the
  varint in pure Python. In a low-level language the trade-off would go the
  other way.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_end = pos
  # Bytes with the continuation (high) bit set are part of the tag. The first
  # byte without it is the tag's final byte and is included as well.
  while buffer[tag_end] & 0x80:
    tag_end += 1
  tag_end += 1
  return buffer[pos:tag_end].tobytes(), tag_end

191 

192 

def DecodeTag(tag_bytes):
  """Decodes raw tag bytes into a field number and a wire type.

  Args:
    tag_bytes: The raw bytes of a tag, e.g. as returned by ReadTag().

  Returns:
    Tuple[int, int] of the tag field number and wire type.
  """
  tag_value = _DecodeVarint(tag_bytes, 0)[0]
  return wire_format.UnpackTag(tag_value)

204 

205 

206# -------------------------------------------------------------------- 

207 

208 

def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint(). It is called as decode_value(buffer, pos) and must
        return a (value, new_pos) pair.

  Returns:
    A SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
    clear_if_default=False) constructor. Depending on is_packed/is_repeated,
    it returns a packed, repeated, or singular decoder. Each has the signature
    (buffer, pos, end, message, field_dict, current_depth=0) and returns the
    new position.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(
          buffer, pos, end, message, field_dict, current_depth=0
      ):
        del current_depth  # unused
        # Fetch the field's container, creating it on first use.
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # The packed payload is preceded by its byte length as a varint.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        # The loop stops as soon as pos >= endpoint, so only the last element
        # can have overrun the declared length.
        if pos > endpoint:
          del value[-1]  # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos

      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(
          buffer, pos, end, message, field_dict, current_depth=0
      ):
        del current_depth  # unused
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while 1:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field. If so, pos already points just past that tag.
          pos = new_pos + tag_len
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed.  Return.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos

      return DecodeRepeatedField
    else:
      def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
        del current_depth  # unused
        (new_value, pos) = decode_value(buffer, pos)
        if pos > end:
          raise _DecodeError('Truncated message.')
        # With clear_if_default, a default value removes the key from
        # field_dict instead of being stored.
        if clear_if_default and IsDefaultScalarValue(new_value):
          field_dict.pop(key, None)
        else:
          field_dict[key] = new_value
        return pos

      return DecodeField

  return SpecificDecoder

281 

282 

def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but post-processes every decoded value.

  Each value produced by decode_value is passed through modify_value before
  it is stored. modify_value is usually ZigZagDecode.
  """

  # Wrapping decode_value and delegating to _SimpleDecoder is slightly slower
  # than duplicating that code, but not enough to make a real difference.
  def InnerDecode(buffer, pos):
    raw_value, next_pos = decode_value(buffer, pos)
    adjusted = modify_value(raw_value)
    return (adjusted, next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)

295 

296 

def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The format string to pass to struct.unpack().
  """

  # Compile the format once. Struct.unpack behaves like struct.unpack with
  # the same format, including raising struct.error on a short buffer.
  compiled = struct.Struct(format)
  width = compiled.size
  unpack = compiled.unpack

  def InnerDecode(buffer, pos):
    stop = pos + width
    # struct.error from a truncated buffer is expected to be caught up-stack
    # and converted to _DecodeError. That avoids setting up an exception
    # handler for every value parsed.
    decoded = unpack(buffer[pos:stop])[0]
    return (decoded, stop)

  return _SimpleDecoder(wire_type, InnerDecode)

320 

321 

def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values (NaN, +inf, -inf) are recognised from the raw
  bytes before struct.unpack is reached. This works around a bug in
  struct.unpack for non-finite 32-bit floating-point values.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    end_pos = pos + 4
    raw = buffer[pos:end_pos].tobytes()

    # Little-endian layout:
    #   * raw[3] holds the sign bit and the top 7 exponent bits.
    #   * The high bit of raw[2] is the last exponent bit.
    #   * The remaining 23 bits are the significand.
    high_byte = raw[3:4]
    exponent_all_ones = high_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if not exponent_all_ones:
      # Finite value. struct.error from a truncated buffer is expected to be
      # caught up-stack and converted to _DecodeError.
      return (unpack('<f', raw)[0], end_pos)

    # All exponent bits set: non-finite. Any significand bit set means NaN.
    if raw[0:3] != b'\x00\x00\x80':
      return (math.nan, end_pos)
    # Otherwise infinity, negative when the sign bit is set.
    if high_byte == b'\xFF':
      return (-math.inf, end_pos)
    return (math.inf, end_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)

365 

366 

def _DoubleDecoder():
  """Returns a decoder for a double field.

  NaN is recognised from the raw bytes before struct.unpack is reached. This
  works around a bug in struct.unpack for not-a-number.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    end_pos = pos + 8
    raw = buffer[pos:end_pos].tobytes()

    # Little-endian layout:
    #   * raw[7] holds the sign bit and the top 7 exponent bits.
    #   * The high nibble of raw[6] holds the remaining 4 exponent bits.
    #   * The remaining 52 bits are the significand.
    # All exponent bits set with a non-zero significand is NaN.
    exponent_all_ones = raw[7:8] in b'\x7F\xFF' and raw[6:7] >= b'\xF0'
    is_infinity_pattern = raw[0:7] == b'\x00\x00\x00\x00\x00\x00\xF0'
    if exponent_all_ones and not is_infinity_pattern:
      return (math.nan, end_pos)

    # struct.error from a truncated buffer is expected to be caught up-stack
    # and converted to _DecodeError.
    return (unpack('<d', raw)[0], end_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)

405 

406 

def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values found in key.enum_type.values_by_number are stored in field_dict.
  Any other value is appended to message._unknown_fields as a
  (tag_bytes, raw_value_bytes) pair. The tag uses the VARINT wire type, even
  when the field itself is packed.

  Args:
    field_number: The field number of the enum field.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the field is packed.
    key: The field's key in field_dict; its enum_type identifies known values.
    new_default: Callable taking the message and returning a new container
      for the repeated/packed field.
    clear_if_default: If true, a singular field that decodes to a default
      value is removed from field_dict instead of being stored.

  Returns:
    A decoder (buffer, pos, end, message, field_dict, current_depth=0) that
    returns the new position in the serialized data.
  """
  enum_type = key.enum_type
  # Tag under which unrecognised values are recorded as unknown fields.
  # Computed once here instead of once per unknown value.
  varint_tag_bytes = encoder.TagBytes(field_number,
                                      wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # The packed payload is preceded by its byte length as a varint.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (varint_tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      if pos > endpoint:
        # Only the last element can overrun the declared length; undo
        # whatever it recorded before failing.
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField
  elif is_repeated:
    tag_bytes = varint_tag_bytes
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos

    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and IsDefaultScalarValue(enum_value):
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (varint_tag_bytes, buffer[value_start_pos:pos].tobytes()))
      # pylint: enable=protected-access
      return pos

    return DecodeField

541 

542 

543# -------------------------------------------------------------------- 

544 

545 

# Plain varint integers. Signed types sign-extend from their width; unsigned
# types are masked to it.
Int32Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint32)
UInt64Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint)

# ZigZag-encoded signed integers: decode as unsigned, then un-zigzag.
SInt32Decoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeVarint32,
    modify_value=wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeVarint,
    modify_value=wire_format.ZigZagDecode)

# With the '<' prefix, struct formats use standard sizes that are identical
# on every platform (without it, sizes follow the C compiler's native type
# sizes).
Fixed32Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED32, format='<I')
Fixed64Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED64, format='<Q')
SFixed32Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED32, format='<i')
SFixed64Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED64, format='<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# Booleans are varints converted with bool().
BoolDecoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT,
    decode_value=_DecodeVarint,
    modify_value=bool)

573 

574 

def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Each value is a varint byte length followed by that many bytes, decoded as
  UTF-8. Invalid UTF-8 re-raises the UnicodeDecodeError with the field's
  full name appended to its reason.

  Args:
    field_number: The field number of the string field.
    is_repeated: Whether the field is repeated.
    is_packed: Must be false; strings cannot be packed.
    key: The field's key in field_dict; key.full_name is used in errors.
    new_default: Callable taking the message and returning a new container
      for the repeated field.
    clear_if_default: If true, a singular field with an empty payload is
      removed from field_dict instead of being stored.

  Returns:
    A decoder (buffer, pos, end, message, field_dict, current_depth=0) that
    returns the new position in the serialized data.
  """

  local_DecodeVarint = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Convert byte to unicode."""
    byte_str = memview.tobytes()
    try:
      value = str(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # Add the field name to the error message and re-raise it. Only the
      # reason is extended: str(e) already embeds the reason, so building the
      # new reason from str(e) would repeat the codec/position prefix.
      e.reason = '%s in field: %s' % (e.reason, key.full_name)
      raise

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos

    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      del current_depth  # unused
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      # A zero length means the empty string, i.e. the default value.
      if clear_if_default and IsDefaultScalarValue(size):
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos

    return DecodeField

633 

634 

def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Each value is a varint byte length followed by that many raw bytes, which
  are copied out of the buffer as a bytes object.

  Args:
    field_number: The field number of the bytes field.
    is_repeated: Whether the field is repeated.
    is_packed: Must be false; bytes fields cannot be packed.
    key: The field's key in field_dict.
    new_default: Callable taking the message and returning a new container
      for the repeated field.
    clear_if_default: If true, a singular field with an empty payload is
      removed from field_dict instead of being stored.

  Returns:
    A decoder (buffer, pos, end, message, field_dict, current_depth=0) that
    returns the new position in the serialized data.
  """
  read_length = _DecodeVarint

  assert not is_packed

  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      del current_depth  # unused
      (size, pos) = read_length(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and IsDefaultScalarValue(size):
        # An empty payload is the default value: drop the field.
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[pos:stop].tobytes()
      return stop

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(
      buffer, pos, end, message, field_dict, current_depth=0
  ):
    del current_depth  # unused
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_length(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      container.append(buffer[pos:stop].tobytes())
      # Guess that the next tag repeats this field; if so, continue right
      # after that tag.
      pos = stop + tag_len
      if stop == end or buffer[stop:pos] != tag_bytes:
        # Guess was wrong (or the data ended): hand back control.
        return stop

  return DecodeRepeatedField

681 

682 

def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  Group contents are parsed in place by the sub-message's _InternalParse.
  At the position it returns, the decoder requires this field's END_GROUP
  tag and consumes it. Nesting depth is tracked via current_depth and
  checked against _recursion_limit.

  Args:
    field_number: The field number of the group field.
    is_repeated: Whether the field is repeated.
    is_packed: Must be false; groups cannot be packed.
    key: The field's key in field_dict.
    new_default: Callable taking the message and returning a new instance of
      the field's value (a sub-message or a repeated container).

  Returns:
    A decoder (buffer, pos, end, message, field_dict, current_depth=0) that
    returns the new position in the serialized data.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      # Look up the repeated container once before the loop, as
      # MessageDecoder's repeated decoder does.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        current_depth += 1
        if current_depth > _recursion_limit:
          raise _DecodeError(
              'Error parsing message: too many levels of nesting.'
          )
        pos = value.add()._InternalParse(buffer, pos, end, current_depth)
        current_depth -= 1
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos

    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      pos = value._InternalParse(buffer, pos, end, current_depth)
      current_depth -= 1
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos

    return DecodeField

743 

744 

def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each value is a varint byte length followed by that many bytes. Those bytes
  are parsed into the sub-message by its _InternalParse. current_depth is
  incremented around each nested parse and checked against _recursion_limit.

  Args:
    field_number: The field number of the message field.
    is_repeated: Whether the field is repeated.
    is_packed: Must be false; message fields cannot be packed.
    key: The field's key in field_dict.
    new_default: Callable taking the message and returning a new instance of
      the field's value (a sub-message or a repeated container).

  Returns:
    A decoder (buffer, pos, end, message, field_dict, current_depth=0) that
    returns the new position in the serialized data.
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        current_depth += 1
        if current_depth > _recursion_limit:
          raise _DecodeError(
              'Error parsing message: too many levels of nesting.'
          )
        # The sub-message is bounded by new_pos; it must consume exactly
        # that many bytes.
        if (
            value.add()._InternalParse(buffer, pos, new_pos, current_depth)
            != new_pos
        ):
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        current_depth -= 1
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos

    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      current_depth -= 1
      return new_pos

    return DecodeField

811 

812 

813# -------------------------------------------------------------------- 

814 

# Tag that opens each MessageSet item: field 1 with wire type START_GROUP.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)


def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint

  def DecodeItem(buffer, pos, end, message, field_dict, current_depth=0):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Record where the payload lies; it is parsed after the loop, once
        # type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        field_number, wire_type = DecodeTag(tag_bytes)
        _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
        if pos == -1:
          raise _DecodeError('Unexpected end-group tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not among this module's top-level imports, so
          # bind it here, at call time, where it is needed.
          from google.protobuf import message_factory  # pylint: disable=g-import-not-at-top
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      if (
          value._InternalParse(
              buffer, message_start, message_end, current_depth
          )
          != message_end
      ):
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      current_depth -= 1
    else:
      # Unknown extension: keep the raw item bytes as an unknown field.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      # pylint: enable=protected-access

    return pos

  return DecodeItem

914 

915 

def UnknownMessageSetItemDecoder():
  """Returns a decoder for an unknown MessageSet item.

  Unlike the MessageSet item decoder used for known extensions, the returned
  function does not look up or parse the embedded message; it only extracts
  the item's type_id and the raw bytes of its message field.

  Returns:
    A function DecodeUnknownItem(buffer) -> (type_id, message_bytes).
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    """Extracts type_id and message bytes from one serialized MessageSet item.

    Args:
      buffer: memoryview holding the item's fields starting at offset 0; it
        must contain the item's end-group tag. Slices must support
        .tobytes(), so a plain bytes object is not accepted.

    Returns:
      Tuple (type_id, message_bytes).

    Raises:
      _DecodeError: on an unexpected end-group tag, truncated data, or a
        missing type_id / message field.
    """
    pos = 0
    end = len(buffer)
    # -1 marks "field not seen yet". type_id must be initialized here too;
    # otherwise an item lacking a type_id field would raise UnboundLocalError
    # at the check below instead of the intended DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    # type_id and message may appear in any order, so loop until the item's
    # end-group tag.
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        # Skip any other field inside the item.
        field_number, wire_type = DecodeTag(tag_bytes)
        _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
        if pos == -1:
          raise _DecodeError('Unexpected end-group tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem

954 

955# -------------------------------------------------------------------- 

956 

def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Args:
    field_descriptor: descriptor of the map field. Its message_type is the
      entry message, whose `key` and `value` fields are read after parsing.
    new_default: callable invoked with the parent message to create the
      container stored in field_dict when the field has no value yet.
    is_message_map: True if values are messages (merged into the container
      via CopyFrom); False if values are assigned directly.

  Returns:
    A function DecodeMap(buffer, pos, end, message, field_dict,
    current_depth=0) returning the position after the last entry consumed.
  """
  field_key = field_descriptor
  entry_tag = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  entry_tag_len = len(entry_tag)
  decode_varint = _DecodeVarint
  # Can't read _concrete_class yet; might not be initialized.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
    """Decodes consecutive entries of this map field starting at pos."""
    entry = entry_type._concrete_class()
    container = field_dict.get(field_key)
    if container is None:
      container = field_dict.setdefault(field_key, new_default(message))
    # Each entry is parsed one nesting level below the enclosing message.
    entry_depth = current_depth + 1
    while True:
      # Entries are length-delimited: the size varint comes first.
      size, pos = decode_varint(buffer, pos)
      entry_end = pos + size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      entry.Clear()
      if entry_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      if entry._InternalParse(buffer, pos, entry_end, entry_depth) != entry_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Guess that the next field is another entry of this same map; if the
      # guess fails (or the data ends), hand control back to the caller.
      next_pos = entry_end + entry_tag_len
      if entry_end == end or buffer[entry_end:next_pos] != entry_tag:
        return entry_end
      pos = next_pos

  return DecodeMap

1002 

1003 

def _DecodeFixed64(buffer, pos):
  """Decodes a little-endian unsigned 64-bit integer (fixed64).

  Args:
    buffer: bytes-like object containing serialized data.
    pos: int, offset of the first of the eight bytes to read.

  Returns:
    Tuple (value, new_pos) where new_pos == pos + 8.

  Raises:
    struct.error: if fewer than eight bytes are available at pos.
  """
  stop = pos + 8
  (value,) = struct.unpack('<Q', buffer[pos:stop])
  return value, stop

1008 

1009 

def _DecodeFixed32(buffer, pos):
  """Decodes a little-endian unsigned 32-bit integer (fixed32).

  Args:
    buffer: bytes-like object containing serialized data.
    pos: int, offset of the first of the four bytes to read.

  Returns:
    Tuple (value, new_pos) where new_pos == pos + 4.

  Raises:
    struct.error: if fewer than four bytes are available at pos.
  """
  stop = pos + 4
  (value,) = struct.unpack('<I', buffer[pos:stop])
  return value, stop

# Default maximum nesting depth accepted while decoding.
DEFAULT_RECURSION_LIMIT = 100
# Active limit, compared against the nesting depth by the map, MessageSet and
# unknown-group decoders in this module; replaced by SetRecursionLimit().
_recursion_limit = DEFAULT_RECURSION_LIMIT


def SetRecursionLimit(new_limit):
  """Replaces the module-wide decoding recursion limit.

  Args:
    new_limit: new maximum nesting depth; stored as-is, without validation.
  """
  global _recursion_limit
  _recursion_limit = new_limit

1022 

1023 

def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
  """Decodes a run of unknown fields into an UnknownFieldSet.

  Fields are read until pos reaches end_pos or an end-group tag is read. With
  end_pos=None, only an end-group tag ends the loop normally.

  Args:
    buffer: memoryview of the serialized bytes.
    pos: int, position of the first tag to read.
    end_pos: int or None, position at which to stop reading.
    current_depth: int, group nesting depth, forwarded to _DecodeUnknownField.

  Returns:
    Tuple (UnknownFieldSet, new_pos). If an end-group tag stopped the loop,
    new_pos is just past that tag.
  """
  fields = containers.UnknownFieldSet()
  while end_pos is None or pos < end_pos:
    tag_bytes, pos = ReadTag(buffer, pos)
    tag_value = _DecodeVarint(tag_bytes, 0)[0]
    field_number, wire_type = wire_format.UnpackTag(tag_value)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      # Closes the enclosing group; the caller verifies the field number.
      break
    data, pos = _DecodeUnknownField(
        buffer, pos, end_pos, field_number, wire_type, current_depth
    )
    fields._add(field_number, wire_type, data)  # pylint: disable=protected-access
  return fields, pos

1041 

1042 

def _DecodeUnknownField(
    buffer, pos, end_pos, field_number, wire_type, current_depth=0
):
  """Decode a unknown field.  Returns the UnknownField and new position.

  Args:
    buffer: memoryview of the serialized bytes.
    pos: int, position just past the field's tag.
    end_pos: int position the field must not extend past, or None to use the
      end of buffer (as _DecodeUnknownFieldSet's default passes through).
    field_number: int, field number from the tag; used to build the matching
      end-group tag for groups.
    wire_type: int, wire type from the tag.
    current_depth: int, group nesting depth of the enclosing field set.

  Returns:
    Tuple (data, new_pos). data is the decoded value for varint/fixed wire
    types, bytes for length-delimited fields, and an UnknownFieldSet for
    groups. For an end-group wire type, returns (0, -1) so the caller can
    detect it.

  Raises:
    _DecodeError: on an unknown wire type, excessive group nesting, a missing
      group end tag, or a field extending past the limit.
  """
  # end_pos may be None (the default of _DecodeUnknownFieldSet). Comparing an
  # int with None raises TypeError, so bound by the buffer length instead.
  limit = len(buffer) if end_pos is None else end_pos

  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    # A short slice is caught by the truncation check below.
    data = buffer[pos:pos+size].tobytes()
    pos += size
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    end_tag_bytes = encoder.TagBytes(
        field_number, wire_format.WIRETYPE_END_GROUP
    )
    current_depth += 1
    if current_depth >= _recursion_limit:
      raise _DecodeError('Error parsing message: too many levels of nesting.')
    data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos, current_depth)
    current_depth -= 1
    # The nested set stops just past an end-group tag; make sure it is the
    # one that closes this group.
    if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes:
      raise _DecodeError('Missing group end tag.')
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  if pos > limit:
    raise _DecodeError('Truncated message.')

  return (data, pos)