Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/google/protobuf/internal/decoder.py: 13%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

541 statements  

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# 

4# Use of this source code is governed by a BSD-style 

5# license that can be found in the LICENSE file or at 

6# https://developers.google.com/open-source/licenses/bsd 

7 

"""Code for decoding protocol buffer primitives.

This code is very similar to encoder.py -- read the docs for that module first.

A "decoder" is a function with the signature:
  Decode(buffer, pos, end, message, field_dict)
The arguments are:
  buffer:     The encoded message bytes.  The decoders in this module call
              .tobytes() on slices of it, so in practice it is a memoryview.
  pos:        The current position in the buffer.
  end:        The position in the buffer where the current message ends.  May
              be less than len(buffer) if we're reading a sub-message.
  message:    The message object into which we're parsing.
  field_dict: message._fields (avoids a hashtable lookup).
Decoders defined here also accept an optional trailing current_depth argument
(default 0); the group and message decoders use it to enforce the recursion
limit.
The decoder reads the field and stores it into field_dict, returning the new
buffer position.  A decoder for a repeated field may proactively decode all of
the elements of that field, if they appear consecutively.

Note that decoders may throw any of the following:
  IndexError:  Indicates a truncated message.
  struct.error:  Unpacking of a fixed-width field failed.
  message.DecodeError:  Other errors.

Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
"end" as long as you are sure that someone else will notice and throw an
exception later on.

Something up the call stack is expected to catch IndexError and struct.error
and convert them to message.DecodeError.

Decoders are constructed using decoder constructors with the signature:
  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
The arguments are:
  field_number:  The field number of the field we want to decode.
  is_repeated:   Is the field a repeated field? (bool)
  is_packed:     Is the field a packed field? (bool)
  key:           The key to use when looking up the field within field_dict.
                 (This is actually the FieldDescriptor but nothing in this
                 file should depend on that.)
  new_default:   A function which takes a message object as a parameter and
                 returns a new instance of the default value for this field.
                 (This is called for repeated fields and sub-messages, when an
                 instance does not already exist.)
Most scalar constructors additionally accept clear_if_default=False; when it is
True, a decoded default value removes the field from field_dict instead of
being stored.

As with encoders, we define a decoder constructor for every type of field.
Then, for every field of every message class we construct an actual decoder.
That decoder goes into a dict indexed by tag, so when we decode a message
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
"""

__author__ = 'kenton@google.com (Kenton Varda)'

59 

60import math 

61import numbers 

62import struct 

63 

64from google.protobuf import message 

65from google.protobuf.internal import containers 

66from google.protobuf.internal import encoder 

67from google.protobuf.internal import wire_format 

68 

69 

# Module-level alias for message.DecodeError, the exception every decoder in
# this file raises for malformed input.
# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError

73 

74 

def IsDefaultScalarValue(value):
  """Reports whether a scalar equals the default (zero/empty) value of its type.

  Meant for deciding presence of implicit-presence fields, which never carry a
  custom default.

  Args:
    value: The scalar to test.

  Returns:
    True if the value is equivalent to its type's default, False otherwise.
  """
  # Negative zero is falsy, so truthiness alone would wrongly call it the
  # default; any number whose sign bit is set is therefore rejected first.
  if isinstance(value, numbers.Number):
    if math.copysign(1.0, value) < 0:
      return False

  # Every other case: falsy values (0, 0.0, False, '', b'', ...) are defaults.
  return not bool(value)

94 

95 

def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits.  The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds checking
  after the fact (often the caller can defer such checking until later).

  The returned DecodeVarint works in two modes:
    * DecodeVarint(buffer, pos): reads from an indexable buffer (e.g. a
      memoryview) starting at pos and returns a (value, new_pos) pair.
    * DecodeVarint(stream): when pos is omitted (None), reads one byte at a
      time via stream.read(1) (e.g. an io.BytesIO) and returns just the value,
      or None if the stream is exhausted before the first byte.

  Args:
    mask: Bit mask applied to the decoded value.
    result_type: Callable used to convert the masked value (e.g. int).

  Returns:
    The DecodeVarint function described above.  It raises:
      IndexError: buffer mode, when the varint runs past the buffer.
      ValueError: stream mode, when the stream ends mid-varint.
      DecodeError: when the varint is longer than 10 bytes.
  """

  def DecodeVarint(buffer, pos: 'int | None' = None):
    result = 0
    shift = 0
    while 1:
      if pos is None:
        # Read from BytesIO
        try:
          b = buffer.read(1)[0]
        except IndexError as e:
          if shift == 0:
            # End of BytesIO.
            return None
          else:
            # Keep the original IndexError as the cause for debugging.
            raise ValueError('Fail to read varint %s' % str(e)) from e
      else:
        b = buffer[pos]
        pos += 1
      result |= ((b & 0x7f) << shift)
      if not (b & 0x80):
        # High bit clear: this was the final byte of the varint.
        result &= mask
        result = result_type(result)
        return result if pos is None else (result, pos)
      shift += 7
      # A 64-bit value needs at most 10 bytes (shift reaches 70 after 10).
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')

  return DecodeVarint

133 

134 

def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes signed (two's complement) values.

  Args:
    bits: Width of the signed integer; the raw varint is truncated to this
      many bits and then sign-extended.
    result_type: Callable applied to the sign-extended value.

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).  It raises
    DecodeError if the varint is longer than 10 bytes.
  """
  sign_bit = 1 << (bits - 1)
  value_mask = (1 << bits) - 1

  def DecodeVarint(buffer, pos):
    accumulated = 0
    # At most 10 bytes are accepted: shifts 0, 7, ..., 63.
    for shift in range(0, 64, 7):
      byte = buffer[pos]
      pos += 1
      accumulated |= (byte & 0x7f) << shift
      if (byte & 0x80) == 0:
        truncated = accumulated & value_mask
        # Flip-and-subtract sign-extends the top bit of the truncated value.
        return (result_type((truncated ^ sign_bit) - sign_bit), pos)
    raise _DecodeError('Too many bytes when decoding varint.')

  return DecodeVarint

157 

# All 32-bit and 64-bit values are represented as int.
# Unsigned: keeps the low 64 bits of the decoded varint.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, int)
# Signed: keeps the low 64 bits, then sign-extends them (two's complement).
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Use these versions for values which must be limited to 32 bits.
# Same as above, but truncated to (and, for the signed one, sign-extended
# from) 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)

165 

166 

def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as raw, still-encoded bytes rather than as a decoded
  number.  Those bytes serve directly as the key for looking up the matching
  decoder, which swaps some pure-Python work (decoding a varint) for work done
  in C (hashing a byte string).  In a low-level language decoding the varint
  would be cheaper, but not in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_end = pos
  # Skip every byte whose continuation (high) bit is set; the first byte
  # without it is the tag's last byte.
  while buffer[tag_end] & 0x80:
    tag_end += 1
  tag_end += 1
  return (buffer[pos:tag_end].tobytes(), tag_end)

191 

192 

def DecodeTag(tag_bytes):
  """Split raw tag bytes into their field number and wire type.

  Args:
    tag_bytes: the varint-encoded tag, e.g. as returned by ReadTag().

  Returns:
    Tuple[int, int] of the tag field number and wire type.
  """
  packed_tag = _DecodeVarint(tag_bytes, 0)[0]
  return wire_format.UnpackTag(packed_tag)

204 

205 

206# -------------------------------------------------------------------- 

207 

208 

def _SimpleDecoder(wire_type, decode_value):
  """Build a decoder constructor for one scalar field type.

  Args:
    wire_type: The field's wire type.  It is used to predict the tag of the
      next element of a repeated, non-packed field.
    decode_value: A function (buffer, pos) -> (value, new_pos) that decodes a
      single element, e.g. _DecodeVarint().

  Returns:
    A constructor SpecificDecoder(field_number, is_repeated, is_packed, key,
    new_default, clear_if_default=False) returning the matching field decoder.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if is_packed:
      read_length = _DecodeVarint

      def DecodePackedField(
          buffer, pos, end, message, field_dict, current_depth=0
      ):
        del current_depth  # unused
        container = field_dict.get(key)
        if container is None:
          container = field_dict.setdefault(key, new_default(message))
        # A packed field is one length prefix followed by back-to-back items.
        (payload_len, pos) = read_length(buffer, pos)
        payload_end = pos + payload_len
        if payload_end > end:
          raise _DecodeError('Truncated message.')
        while pos < payload_end:
          (item, pos) = decode_value(buffer, pos)
          container.append(item)
        if pos > payload_end:
          # The final item ran past the payload; drop it before failing.
          del container[-1]
          raise _DecodeError('Packed element was truncated.')
        return pos

      return DecodePackedField

    if is_repeated:
      expected_tag = encoder.TagBytes(field_number, wire_type)
      expected_tag_len = len(expected_tag)

      def DecodeRepeatedField(
          buffer, pos, end, message, field_dict, current_depth=0
      ):
        del current_depth  # unused
        container = field_dict.get(key)
        if container is None:
          container = field_dict.setdefault(key, new_default(message))
        while True:
          (item, after_item) = decode_value(buffer, pos)
          container.append(item)
          # Optimistically assume the next tag repeats this same field.
          pos = after_item + expected_tag_len
          same_field_follows = (buffer[after_item:pos] == expected_tag
                                and after_item < end)
          if not same_field_follows:
            if after_item > end:
              raise _DecodeError('Truncated message.')
            return after_item

      return DecodeRepeatedField

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      del current_depth  # unused
      (decoded, pos) = decode_value(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and IsDefaultScalarValue(decoded):
        # A default value is not stored (presumably so that implicit-presence
        # fields read as unset).
        field_dict.pop(key, None)
      else:
        field_dict[key] = decoded
      return pos

    return DecodeField

  return SpecificDecoder

281 

282 

def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but post-processes every decoded value.

  Each value produced by decode_value is passed through modify_value before
  it is stored.  Usually modify_value is ZigZagDecode.
  """

  # Wrapping decode_value and delegating to _SimpleDecoder costs a little
  # speed compared with duplicating its code, but not enough to matter.
  def InnerDecode(buffer, pos):
    raw, next_pos = decode_value(buffer, pos)
    return modify_value(raw), next_pos

  return _SimpleDecoder(wire_type, InnerDecode)

295 

296 

def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The format string to pass to struct.unpack().  Only the first
      unpacked item is used.
  """
  # Fixed-width: every element occupies exactly this many bytes.
  width = struct.calcsize(format)
  unpack = struct.unpack

  # Delegating to _SimpleDecoder is slightly slower than duplicating its
  # code, but not enough to matter.
  #
  # A struct.error raised for a short slice is deliberately not caught here.
  # A caller further up the stack is expected to convert it to _DecodeError,
  # which avoids setting up an exception handler for every value.

  def InnerDecode(buffer, pos):
    stop = pos + width
    decoded = unpack(format, buffer[pos:stop])
    return (decoded[0], stop)

  return _SimpleDecoder(wire_type, InnerDecode)

320 

321 

def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values (inf, -inf, NaN) are recognised from their raw
  bytes instead of being passed to struct.unpack.  This originally worked
  around a struct.unpack bug with non-finite values in old Python versions.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # Little-endian 32-bit value: byte 3 holds the sign bit and the top 7
    # exponent bits; the high bit of byte 2 is the last exponent bit.
    stop = pos + 4
    raw = buffer[pos:stop].tobytes()
    high_byte = raw[3:4]

    exponent_all_ones = high_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if exponent_all_ones:
      if raw[0:3] != b'\x00\x00\x80':
        # Some significand bit is set: not a number.
        return (math.nan, stop)
      # Zero significand: an infinity whose sign comes from byte 3.
      return (-math.inf if high_byte == b'\xFF' else math.inf, stop)

    # A struct.error for a short slice is left for a caller up-stack to
    # convert to _DecodeError.
    return (unpack('<f', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)

365 

366 

def _DoubleDecoder():
  """Returns a decoder for a double field.

  NaN bit patterns are recognised from the raw bytes rather than passed to
  struct.unpack.  This originally worked around a struct.unpack bug with NaN
  in old Python versions.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # Little-endian 64-bit value: byte 7 holds the sign bit and the top 7
    # exponent bits; the high nibble of byte 6 holds the remaining 4.
    stop = pos + 8
    raw = buffer[pos:stop].tobytes()

    # All exponent bits set, and a significand that is not all zero.
    is_nan = (raw[7:8] in b'\x7F\xFF'
              and raw[6:7] >= b'\xF0'
              and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')
    if is_nan:
      return (math.nan, stop)

    # A struct.error for a short slice is left for a caller up-stack to
    # convert to _DecodeError.
    return (unpack('<d', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)

405 

406 

def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  A decoded number found in key.enum_type.values_by_number is stored in
  field_dict.  Any other number is kept in message._unknown_fields as a
  (tag_bytes, raw_value_bytes) pair, where tag_bytes is the non-packed VARINT
  tag for field_number.

  Args:
    field_number: The field number of the field to decode.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the repeated field uses packed encoding.
    key: The field descriptor.  key.enum_type supplies the known values, and
      key is also the field_dict key.
    new_default: Callable taking the message and returning a new empty
      container (used by the repeated decoders).
    clear_if_default: Singular decoder only.  If True, a default (zero) value
      removes the key from field_dict instead of being stored.

  Returns:
    A decoder function with the module's standard decoder signature.
  """
  enum_type = key.enum_type
  # Tag recorded with every unrecognized value.  It depends only on
  # field_number, so build it once here rather than once per unknown value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      """Decode a serialized packed enum field and return the new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      if pos > endpoint:
        # The last element overran the packed payload; undo whatever was
        # recorded for it before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]  # pylint: disable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      """Decode a serialized repeated enum field and return the new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      """Decode a serialized singular enum field and return the new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      del current_depth  # unused
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and IsDefaultScalarValue(enum_value):
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
      # pylint: enable=protected-access
      return pos

    return DecodeField

541 

542 

543# -------------------------------------------------------------------- 

544 

545 

# Decoder constructors for each scalar field type.

# int32/int64: varints sign-extended from 32/64 bits.
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

# uint32/uint64: varints truncated to 32/64 bits, no sign extension.
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

# sint32/sint64: unsigned varints post-processed with ZigZagDecode.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# bool: the 64-bit varint is passed to bool(), so any non-zero value is True.
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)

573 

574 

def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  String payloads are length-delimited and decoded as UTF-8.

  Args:
    field_number: The field number of the field to decode.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; string fields cannot be packed.
    key: The field descriptor.  It is used as the field_dict key, and
      key.full_name labels UTF-8 decoding errors.
    new_default: Callable taking the message and returning a new empty
      container (used by the repeated decoder).
    clear_if_default: If True, a zero-length value removes the key from
      field_dict instead of being stored.

  Returns:
    A decoder function with the module's standard decoder signature.  Invalid
    UTF-8 raises UnicodeDecodeError with the field name appended to its
    reason.
  """

  local_DecodeVarint = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Convert byte to unicode."""
    byte_str = memview.tobytes()
    try:
      value = str(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # Add the field name to the error message and re-raise it.  Build from
      # e.reason, not str(e): str(e) already contains the codec/position
      # prefix and the reason, so using it here would duplicate that text.
      e.reason = '%s in field: %s' % (e.reason, key.full_name)
      raise

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      del current_depth  # unused
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      del current_depth  # unused
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      # A zero length means the empty (default) string.
      if clear_if_default and IsDefaultScalarValue(size):
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos

    return DecodeField

633 

634 

def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Args:
    field_number: The field number of the field to decode.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; bytes fields cannot be packed.
    key: The key under which the decoded value is stored in field_dict.
    new_default: Callable taking the message and returning a new empty
      container (used by the repeated decoder).
    clear_if_default: If True, a zero-length value removes the key from
      field_dict instead of being stored.

  Returns:
    A decoder function with the module's standard decoder signature.
  """

  read_size = _DecodeVarint

  assert not is_packed
  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      del current_depth  # unused
      (length, start) = read_size(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and IsDefaultScalarValue(length):
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[start:stop].tobytes()
      return stop

    return DecodeField

  field_tag = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  field_tag_len = len(field_tag)

  def DecodeRepeatedField(
      buffer, pos, end, message, field_dict, current_depth=0
  ):
    del current_depth  # unused
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      (length, start) = read_size(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(buffer[start:stop].tobytes())
      # Keep going only while the very next tag is this same field.
      pos = stop + field_tag_len
      if stop == end or buffer[stop:pos] != field_tag:
        return stop

  return DecodeRepeatedField

681 

682 

def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group's contents are not length-prefixed.  The sub-message is parsed up
  to the end of the enclosing range, and an END_GROUP tag for the same field
  number must follow where that parse stops.

  Args:
    field_number: The field number of the group field.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; groups cannot be packed.
    key: The key under which the group is stored in field_dict.
    new_default: Callable taking the message and returning a new instance of
      the field's default (a sub-message, or a repeated container).

  Returns:
    A decoder function with the module's standard decoder signature, including
    the trailing current_depth argument used to enforce _recursion_limit.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(
        buffer, pos, end, message, field_dict, current_depth=0
    ):
      # Fetch (or create) the container once; the loop below never reassigns
      # field_dict[key], so repeating this lookup per element is unnecessary.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        current_depth += 1
        if current_depth > _recursion_limit:
          raise _DecodeError(
              'Error parsing message: too many levels of nesting.'
          )
        pos = value.add()._InternalParse(buffer, pos, end, current_depth)
        current_depth -= 1
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      pos = value._InternalParse(buffer, pos, end, current_depth)
      current_depth -= 1
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos

    return DecodeField

743 

744 

def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each sub-message is length-prefixed and must be parsed exactly up to its
  declared end.

  Args:
    field_number: The field number of the message field.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; message fields cannot be packed.
    key: The key under which the sub-message is stored in field_dict.
    new_default: Callable taking the message and returning a new instance of
      the field's default (a sub-message, or a repeated container).

  Returns:
    A decoder function with the module's standard decoder signature, including
    the trailing current_depth argument used to enforce _recursion_limit.
  """

  read_size = _DecodeVarint

  assert not is_packed
  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      (length, start) = read_size(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated message.')
      child_depth = current_depth + 1
      if child_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      if submessage._InternalParse(buffer, start, stop, child_depth) != stop:
        # Stopping short of the declared length can only mean the parser
        # hit an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return stop

    return DecodeField

  message_tag = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
  message_tag_len = len(message_tag)

  def DecodeRepeatedField(
      buffer, pos, end, message, field_dict, current_depth=0
  ):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    child_depth = current_depth + 1
    while True:
      (length, start) = read_size(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated message.')
      if child_depth > _recursion_limit:
        raise _DecodeError(
            'Error parsing message: too many levels of nesting.'
        )
      parsed_to = container.add()._InternalParse(
          buffer, start, stop, child_depth)
      if parsed_to != stop:
        # Stopping short of the declared length can only mean the parser
        # hit an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      # Keep going only while the very next tag is this same field.
      pos = stop + message_tag_len
      if buffer[stop:pos] != message_tag or stop == end:
        return stop

  return DecodeRepeatedField

811 

812 

813# -------------------------------------------------------------------- 

814 

MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  An item whose type_id matches a known extension is parsed into that
  extension's message.  Otherwise the whole raw item is appended to
  message._unknown_fields under MESSAGE_SET_ITEM_TAG.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint

  def DecodeItem(buffer, pos, end, message, field_dict, current_depth=0):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.
      current_depth: int, nesting depth of the enclosing message.  It is
        accepted for consistency with the other decoders in this module and
        used to bound recursion into the extension message.

    Returns:
      int, new position in serialized data.
    """
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        field_number, wire_type = DecodeTag(tag_bytes)
        _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
        if pos == -1:
          raise _DecodeError('Unexpected end-group tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not among this module's visible top-level
          # imports, so import it here; binding it locally is harmless if it
          # is also defined at module level.
          # pylint: disable=g-import-not-at-top
          from google.protobuf import message_factory
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      # Track nesting like the message/group decoders do, so that chains of
      # MessageSet items cannot recurse without bound.
      current_depth += 1
      if current_depth > _recursion_limit:
        raise _DecodeError('Error parsing message: too many levels of nesting.')
      if (
          value._InternalParse(buffer, message_start, message_end,
                               current_depth)
          != message_end
      ):
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      current_depth -= 1
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
    # pylint: enable=protected-access

    return pos

  return DecodeItem

905 

906 

907def UnknownMessageSetItemDecoder(): 

908 """Returns a decoder for a Unknown MessageSet item.""" 

909 

910 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 

911 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 

912 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 

913 

914 def DecodeUnknownItem(buffer): 

915 pos = 0 

916 end = len(buffer) 

917 message_start = -1 

918 message_end = -1 

919 while 1: 

920 (tag_bytes, pos) = ReadTag(buffer, pos) 

921 if tag_bytes == type_id_tag_bytes: 

922 (type_id, pos) = _DecodeVarint(buffer, pos) 

923 elif tag_bytes == message_tag_bytes: 

924 (size, message_start) = _DecodeVarint(buffer, pos) 

925 pos = message_end = message_start + size 

926 elif tag_bytes == item_end_tag_bytes: 

927 break 

928 else: 

929 field_number, wire_type = DecodeTag(tag_bytes) 

930 _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type) 

931 if pos == -1: 

932 raise _DecodeError('Unexpected end-group tag.') 

933 

934 if pos > end: 

935 raise _DecodeError('Truncated message.') 

936 

937 if type_id == -1: 

938 raise _DecodeError('MessageSet item missing type_id.') 

939 if message_start == -1: 

940 raise _DecodeError('MessageSet item missing message.') 

941 

942 return (type_id, buffer[message_start:message_end].tobytes()) 

943 

944 return DecodeUnknownItem 

945 

946# -------------------------------------------------------------------- 

947 

948def MapDecoder(field_descriptor, new_default, is_message_map): 

949 """Returns a decoder for a map field.""" 

950 

951 key = field_descriptor 

952 tag_bytes = encoder.TagBytes(field_descriptor.number, 

953 wire_format.WIRETYPE_LENGTH_DELIMITED) 

954 tag_len = len(tag_bytes) 

955 local_DecodeVarint = _DecodeVarint 

956 # Can't read _concrete_class yet; might not be initialized. 

957 message_type = field_descriptor.message_type 

958 

959 def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0): 

960 del current_depth # Unused. 

961 submsg = message_type._concrete_class() 

962 value = field_dict.get(key) 

963 if value is None: 

964 value = field_dict.setdefault(key, new_default(message)) 

965 while 1: 

966 # Read length. 

967 (size, pos) = local_DecodeVarint(buffer, pos) 

968 new_pos = pos + size 

969 if new_pos > end: 

970 raise _DecodeError('Truncated message.') 

971 # Read sub-message. 

972 submsg.Clear() 

973 if submsg._InternalParse(buffer, pos, new_pos) != new_pos: 

974 # The only reason _InternalParse would return early is if it 

975 # encountered an end-group tag. 

976 raise _DecodeError('Unexpected end-group tag.') 

977 

978 if is_message_map: 

979 value[submsg.key].CopyFrom(submsg.value) 

980 else: 

981 value[submsg.key] = submsg.value 

982 

983 # Predict that the next tag is another copy of the same repeated field. 

984 pos = new_pos + tag_len 

985 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 

986 # Prediction failed. Return. 

987 return new_pos 

988 

989 return DecodeMap 

990 

991 

992def _DecodeFixed64(buffer, pos): 

993 """Decode a fixed64.""" 

994 new_pos = pos + 8 

995 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 

996 

997 

998def _DecodeFixed32(buffer, pos): 

999 """Decode a fixed32.""" 

1000 

1001 new_pos = pos + 4 

1002 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 

1003DEFAULT_RECURSION_LIMIT = 100 

1004_recursion_limit = DEFAULT_RECURSION_LIMIT 

1005 

1006 

1007def SetRecursionLimit(new_limit): 

1008 global _recursion_limit 

1009 _recursion_limit = new_limit 

1010 

1011 

1012def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0): 

1013 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" 

1014 

1015 unknown_field_set = containers.UnknownFieldSet() 

1016 while end_pos is None or pos < end_pos: 

1017 (tag_bytes, pos) = ReadTag(buffer, pos) 

1018 (tag, _) = _DecodeVarint(tag_bytes, 0) 

1019 field_number, wire_type = wire_format.UnpackTag(tag) 

1020 if wire_type == wire_format.WIRETYPE_END_GROUP: 

1021 break 

1022 (data, pos) = _DecodeUnknownField( 

1023 buffer, pos, end_pos, field_number, wire_type, current_depth 

1024 ) 

1025 # pylint: disable=protected-access 

1026 unknown_field_set._add(field_number, wire_type, data) 

1027 

1028 return (unknown_field_set, pos) 

1029 

1030 

1031def _DecodeUnknownField( 

1032 buffer, pos, end_pos, field_number, wire_type, current_depth=0 

1033): 

1034 """Decode a unknown field. Returns the UnknownField and new position.""" 

1035 

1036 if wire_type == wire_format.WIRETYPE_VARINT: 

1037 (data, pos) = _DecodeVarint(buffer, pos) 

1038 elif wire_type == wire_format.WIRETYPE_FIXED64: 

1039 (data, pos) = _DecodeFixed64(buffer, pos) 

1040 elif wire_type == wire_format.WIRETYPE_FIXED32: 

1041 (data, pos) = _DecodeFixed32(buffer, pos) 

1042 elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: 

1043 (size, pos) = _DecodeVarint(buffer, pos) 

1044 data = buffer[pos:pos+size].tobytes() 

1045 pos += size 

1046 elif wire_type == wire_format.WIRETYPE_START_GROUP: 

1047 end_tag_bytes = encoder.TagBytes( 

1048 field_number, wire_format.WIRETYPE_END_GROUP 

1049 ) 

1050 current_depth += 1 

1051 if current_depth >= _recursion_limit: 

1052 raise _DecodeError('Error parsing message: too many levels of nesting.') 

1053 data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos, current_depth) 

1054 current_depth -= 1 

1055 # Check end tag. 

1056 if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes: 

1057 raise _DecodeError('Missing group end tag.') 

1058 elif wire_type == wire_format.WIRETYPE_END_GROUP: 

1059 return (0, -1) 

1060 else: 

1061 raise _DecodeError('Wrong wire type in tag.') 

1062 

1063 if pos > end_pos: 

1064 raise _DecodeError('Truncated message.') 

1065 

1066 return (data, pos)