Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/decoder.py: 15%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

521 statements  

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# 

4# Use of this source code is governed by a BSD-style 

5# license that can be found in the LICENSE file or at 

6# https://developers.google.com/open-source/licenses/bsd 

7 

8"""Code for decoding protocol buffer primitives. 

9 

10This code is very similar to encoder.py -- read the docs for that module first. 

11 

12A "decoder" is a function with the signature: 

13 Decode(buffer, pos, end, message, field_dict) 

14The arguments are: 

15 buffer: The string containing the encoded message. 

16 pos: The current position in the string. 

17 end: The position in the string where the current message ends. May be 

18 less than len(buffer) if we're reading a sub-message. 

19 message: The message object into which we're parsing. 

20 field_dict: message._fields (avoids a hashtable lookup). 

21The decoder reads the field and stores it into field_dict, returning the new 

22buffer position. A decoder for a repeated field may proactively decode all of 

23the elements of that field, if they appear consecutively. 

24 

25Note that decoders may throw any of the following: 

26 IndexError: Indicates a truncated message. 

27 struct.error: Unpacking of a fixed-width field failed. 

28 message.DecodeError: Other errors. 

29 

30Decoders are expected to raise an exception if they are called with pos > end. 

31This allows callers to be lax about bounds checking: it's fine to read past 

32"end" as long as you are sure that someone else will notice and throw an 

33exception later on. 

34 

35Something up the call stack is expected to catch IndexError and struct.error 

36and convert them to message.DecodeError. 

37 

38Decoders are constructed using decoder constructors with the signature: 

39 MakeDecoder(field_number, is_repeated, is_packed, key, new_default) 

40The arguments are: 

41 field_number: The field number of the field we want to decode. 

42 is_repeated: Is the field a repeated field? (bool) 

43 is_packed: Is the field a packed field? (bool) 

44 key: The key to use when looking up the field within field_dict. 

45 (This is actually the FieldDescriptor but nothing in this 

46 file should depend on that.) 

47 new_default: A function which takes a message object as a parameter and 

48 returns a new instance of the default value for this field. 

49 (This is called for repeated fields and sub-messages, when an 

50 instance does not already exist.) 

51 

52As with encoders, we define a decoder constructor for every type of field. 

53Then, for every field of every message class we construct an actual decoder. 

54That decoder goes into a dict indexed by tag, so when we decode a message 

55we repeatedly read a tag, look up the corresponding decoder, and invoke it. 

56""" 

57 

58__author__ = 'kenton@google.com (Kenton Varda)' 

59 

60import math 

61import struct 

62 

63from google.protobuf import message 

64from google.protobuf.internal import containers 

65from google.protobuf.internal import encoder 

66from google.protobuf.internal import wire_format 

67 

68 

# Module-level alias for message.DecodeError. It exists to avoid name clashes
# (many decoders below take a parameter called "message", which would shadow
# the imported module), not as a speed optimization.
_DecodeError = message.DecodeError

72 

73 

74def _VarintDecoder(mask, result_type): 

75 """Return an encoder for a basic varint value (does not include tag). 

76 

77 Decoded values will be bitwise-anded with the given mask before being 

78 returned, e.g. to limit them to 32 bits. The returned decoder does not 

79 take the usual "end" parameter -- the caller is expected to do bounds checking 

80 after the fact (often the caller can defer such checking until later). The 

81 decoder returns a (value, new_pos) pair. 

82 """ 

83 

84 def DecodeVarint(buffer, pos: int=None): 

85 result = 0 

86 shift = 0 

87 while 1: 

88 if pos is None: 

89 # Read from BytesIO 

90 try: 

91 b = buffer.read(1)[0] 

92 except IndexError as e: 

93 if shift == 0: 

94 # End of BytesIO. 

95 return None 

96 else: 

97 raise ValueError('Fail to read varint %s' % str(e)) 

98 else: 

99 b = buffer[pos] 

100 pos += 1 

101 result |= ((b & 0x7f) << shift) 

102 if not (b & 0x80): 

103 result &= mask 

104 result = result_type(result) 

105 return result if pos is None else (result, pos) 

106 shift += 7 

107 if shift >= 64: 

108 raise _DecodeError('Too many bytes when decoding varint.') 

109 

110 return DecodeVarint 

111 

112 

113def _SignedVarintDecoder(bits, result_type): 

114 """Like _VarintDecoder() but decodes signed values.""" 

115 

116 signbit = 1 << (bits - 1) 

117 mask = (1 << bits) - 1 

118 

119 def DecodeVarint(buffer, pos): 

120 result = 0 

121 shift = 0 

122 while 1: 

123 b = buffer[pos] 

124 result |= ((b & 0x7f) << shift) 

125 pos += 1 

126 if not (b & 0x80): 

127 result &= mask 

128 result = (result ^ signbit) - signbit 

129 result = result_type(result) 

130 return (result, pos) 

131 shift += 7 

132 if shift >= 64: 

133 raise _DecodeError('Too many bytes when decoding varint.') 

134 return DecodeVarint 

135 

# 64-bit varint decoders. Both 32-bit and 64-bit values come back as int.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# 32-bit varint decoders, for values that must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)

143 

144 

def ReadTag(buffer, pos):
  """Read a tag and return it undecoded as a ``(tag_bytes, new_pos)`` tuple.

  The raw tag bytes are returned instead of the decoded varint so callers can
  use them directly to look up the proper decoder. In Python, a hash lookup
  on a bytes key (done in C) is cheaper than decoding the varint in Python.

  Args:
    buffer: memoryview object of the encoded bytes.
    pos: int, index of the first byte of the tag.

  Returns:
    Tuple[bytes, int]: the tag's bytes (through the first byte without the
    0x80 continuation bit) and the position just after them.
  """
  end_of_tag = pos
  # Advance over every byte whose continuation bit is set ...
  while buffer[end_of_tag] & 0x80:
    end_of_tag += 1
  # ... and then past the final byte of the tag.
  end_of_tag += 1
  return buffer[pos:end_of_tag].tobytes(), end_of_tag

169 

170 

171# -------------------------------------------------------------------- 

172 

173 

174def _SimpleDecoder(wire_type, decode_value): 

175 """Return a constructor for a decoder for fields of a particular type. 

176 

177 Args: 

178 wire_type: The field's wire type. 

179 decode_value: A function which decodes an individual value, e.g. 

180 _DecodeVarint() 

181 """ 

182 

183 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default, 

184 clear_if_default=False): 

185 if is_packed: 

186 local_DecodeVarint = _DecodeVarint 

187 def DecodePackedField(buffer, pos, end, message, field_dict): 

188 value = field_dict.get(key) 

189 if value is None: 

190 value = field_dict.setdefault(key, new_default(message)) 

191 (endpoint, pos) = local_DecodeVarint(buffer, pos) 

192 endpoint += pos 

193 if endpoint > end: 

194 raise _DecodeError('Truncated message.') 

195 while pos < endpoint: 

196 (element, pos) = decode_value(buffer, pos) 

197 value.append(element) 

198 if pos > endpoint: 

199 del value[-1] # Discard corrupt value. 

200 raise _DecodeError('Packed element was truncated.') 

201 return pos 

202 return DecodePackedField 

203 elif is_repeated: 

204 tag_bytes = encoder.TagBytes(field_number, wire_type) 

205 tag_len = len(tag_bytes) 

206 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 

207 value = field_dict.get(key) 

208 if value is None: 

209 value = field_dict.setdefault(key, new_default(message)) 

210 while 1: 

211 (element, new_pos) = decode_value(buffer, pos) 

212 value.append(element) 

213 # Predict that the next tag is another copy of the same repeated 

214 # field. 

215 pos = new_pos + tag_len 

216 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 

217 # Prediction failed. Return. 

218 if new_pos > end: 

219 raise _DecodeError('Truncated message.') 

220 return new_pos 

221 return DecodeRepeatedField 

222 else: 

223 def DecodeField(buffer, pos, end, message, field_dict): 

224 (new_value, pos) = decode_value(buffer, pos) 

225 if pos > end: 

226 raise _DecodeError('Truncated message.') 

227 if clear_if_default and not new_value: 

228 field_dict.pop(key, None) 

229 else: 

230 field_dict[key] = new_value 

231 return pos 

232 return DecodeField 

233 

234 return SpecificDecoder 

235 

236 

def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but post-processes every decoded value.

  ``modify_value`` is applied to each decoded value before it is stored. In
  this module it is wire_format.ZigZagDecode or bool.

  Wrapping _SimpleDecoder is slightly slower than duplicating its code, but
  not enough to matter.
  """

  def InnerDecode(buffer, pos):
    raw, new_pos = decode_value(buffer, pos)
    return modify_value(raw), new_pos

  return _SimpleDecoder(wire_type, InnerDecode)

249 

250 

def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The struct format string for one value, passed to
      struct.unpack().
  """
  value_size = struct.calcsize(format)
  local_unpack = struct.unpack

  # Truncated input makes struct.unpack raise struct.error. A caller further
  # up the stack converts that to _DecodeError, which avoids a try/except
  # block on every value.
  def InnerDecode(buffer, pos):
    new_pos = pos + value_size
    decoded = local_unpack(format, buffer[pos:new_pos])[0]
    return decoded, new_pos

  return _SimpleDecoder(wire_type, InnerDecode)

274 

275 

def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values are special-cased to work around a bug in
  struct.unpack for them (seen in Python 2.4).
  """
  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos].tobytes()
    high_byte = float_bytes[3:4]

    # Little-endian float32: byte 3 holds the sign bit and the top 7 exponent
    # bits, and the high bit of byte 2 is the lowest exponent bit. All
    # exponent bits set means inf or NaN.
    if high_byte in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80':
      if float_bytes[0:3] != b'\x00\x00\x80':
        # At least one significand bit is set: NaN.
        return (math.nan, new_pos)
      # Zero significand: infinity, with the sign taken from the sign bit.
      return (-math.inf if high_byte == b'\xFF' else math.inf, new_pos)

    # struct.error on truncated input is converted to _DecodeError further
    # up the stack.
    return (local_unpack('<f', float_bytes)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)

319 

320 

def _DoubleDecoder():
  """Returns a decoder for a double field.

  NaN is special-cased to work around a bug in struct.unpack for
  not-a-number values (seen in Python 2.4).
  """
  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos].tobytes()

    # Little-endian float64: byte 7 holds the sign bit and the top 7 exponent
    # bits, and the high nibble of byte 6 holds the remaining 4 exponent bits.
    exponent_all_ones = (double_bytes[7:8] in b'\x7F\xFF'
                         and double_bytes[6:7] >= b'\xF0')
    # All-ones exponent with an all-zero significand is +/-infinity, which
    # struct.unpack handles; anything else with that exponent is NaN.
    is_infinity = double_bytes[0:7] == b'\x00\x00\x00\x00\x00\x00\xF0'
    if exponent_all_ones and not is_infinity:
      return (math.nan, new_pos)

    # struct.error on truncated input is converted to _DecodeError further
    # up the stack.
    return (local_unpack('<d', double_bytes)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)

359 

360 

def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values whose number is in ``key.enum_type.values_by_number`` are stored in
  the field. Other values are not stored. Instead they are appended to
  ``message._unknown_fields`` as ``(tag_bytes, raw_value_bytes)`` pairs,
  where ``tag_bytes`` is this field's VARINT tag (also for packed fields).
  """
  enum_type = key.enum_type
  # This field's VARINT tag. It is used both for repeated-field prediction
  # and as the tag recorded with unknown enum values. It is computed once
  # here rather than once per unknown value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  if is_packed:
    local_DecodeVarint = _DecodeVarint

    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.

      Raises:
        _DecodeError: if the packed payload runs past ``end`` or its last
          element runs past the payload.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      if pos > endpoint:
        # The last element ran past the packed payload. Undo whichever list
        # it was appended to.
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField

  if is_repeated:
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed. Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos

    return DecodeRepeatedField

  def DecodeField(buffer, pos, end, message, field_dict):
    """Decode a serialized singular enum to its value and a new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    value_start_pos = pos
    (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
    if pos > end:
      raise _DecodeError('Truncated message.')
    if clear_if_default and not enum_value:
      field_dict.pop(key, None)
      return pos
    # pylint: disable=protected-access
    if enum_value in enum_type.values_by_number:
      field_dict[key] = enum_value
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (tag_bytes, buffer[value_start_pos:pos].tobytes()))
    # pylint: enable=protected-access
    return pos

  return DecodeField

484 

485 

486# -------------------------------------------------------------------- 

487 

488 

# Varint-encoded integer types. The 32-bit variants are limited to 32 bits.
# The sint* variants additionally undo ZigZag encoding.
Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint)

SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint32,
                                 wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint,
                                 wire_format.ZigZagDecode)

# Fixed-width types. With the '<' prefix, struct uses standard sizes and
# little-endian byte order on every platform, instead of sizes that depend on
# the C compiler's native types.
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# A bool is a varint; bool() maps any non-zero value to True.
BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint,
                               bool)

516 

517 

def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Each value is a varint length followed by that many bytes, decoded as
  UTF-8. String fields are never packed.

  Args:
    field_number: The field's number, used to predict repeated-field tags.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False.
    key: The field's key in field_dict. ``key.full_name`` is added to UTF-8
      decoding errors.
    new_default: Callable taking the message and returning a new container
      for repeated values.
    clear_if_default: For singular fields, an empty string removes the entry
      from field_dict instead of being stored.
  """
  local_DecodeVarint = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Decode a slice of the buffer as UTF-8 text."""
    raw = memview.tobytes()
    try:
      return str(raw, 'utf-8')
    except UnicodeDecodeError as e:
      # Name the offending field in the error message, then re-raise.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      size, pos = local_DecodeVarint(buffer, pos)
      value_end = pos + size
      if value_end > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:value_end])
      return value_end
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    strings = field_dict.get(key)
    if strings is None:
      strings = field_dict.setdefault(key, new_default(message))
    while True:
      size, pos = local_DecodeVarint(buffer, pos)
      value_end = pos + size
      if value_end > end:
        raise _DecodeError('Truncated string.')
      strings.append(_ConvertToUnicode(buffer[pos:value_end]))
      # Guess that another element of this field follows immediately.
      pos = value_end + tag_len
      if buffer[value_end:pos] != tag_bytes or value_end == end:
        return value_end

  return DecodeRepeatedField

569 

570 

def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Each value is a varint length followed by that many raw bytes, stored as
  a ``bytes`` copy of the buffer slice. Bytes fields are never packed.

  Args:
    field_number: The field's number, used to predict repeated-field tags.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False.
    key: The field's key in field_dict.
    new_default: Callable taking the message and returning a new container
      for repeated values.
    clear_if_default: For singular fields, an empty value removes the entry
      from field_dict instead of being stored.
  """
  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      size, pos = local_DecodeVarint(buffer, pos)
      value_end = pos + size
      if value_end > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[pos:value_end].tobytes()
      return value_end
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    blobs = field_dict.get(key)
    if blobs is None:
      blobs = field_dict.setdefault(key, new_default(message))
    while True:
      size, pos = local_DecodeVarint(buffer, pos)
      value_end = pos + size
      if value_end > end:
        raise _DecodeError('Truncated string.')
      blobs.append(buffer[pos:value_end].tobytes())
      # Guess that another element of this field follows immediately.
      pos = value_end + tag_len
      if buffer[value_end:pos] != tag_bytes or value_end == end:
        return value_end

  return DecodeRepeatedField

610 

611 

def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  The group's contents are parsed with the sub-message's ``_InternalParse``.
  The position it returns must hold this field's END_GROUP tag, and that tag
  must not run past ``end``. Otherwise _DecodeError('Missing group end tag.')
  is raised.
  """
  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      group = field_dict.get(key)
      if group is None:
        group = field_dict.setdefault(key, new_default(message))
      pos = group._InternalParse(buffer, pos, end)
      after_end_tag = pos + end_tag_len
      if buffer[pos:after_end_tag] != end_tag_bytes or after_end_tag > end:
        raise _DecodeError('Missing group end tag.')
      return after_end_tag
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_START_GROUP)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    while True:
      # Fetch the container each iteration. The first pass creates it if it
      # is missing.
      groups = field_dict.get(key)
      if groups is None:
        groups = field_dict.setdefault(key, new_default(message))
      pos = groups.add()._InternalParse(buffer, pos, end)
      after_end_tag = pos + end_tag_len
      if buffer[pos:after_end_tag] != end_tag_bytes or after_end_tag > end:
        raise _DecodeError('Missing group end tag.')
      # Guess that another element of this group follows immediately.
      pos = after_end_tag + tag_len
      if buffer[after_end_tag:pos] != tag_bytes or after_end_tag == end:
        return after_end_tag

  return DecodeRepeatedField

657 

658 

def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each sub-message is a varint size followed by that many bytes, parsed with
  the sub-message's ``_InternalParse``. Parsing must end exactly at the end
  of those bytes. Stopping early raises _DecodeError; per this module, that
  can only happen on an end-group tag.
  """
  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      size, pos = local_DecodeVarint(buffer, pos)
      sub_end = pos + size
      if sub_end > end:
        raise _DecodeError('Truncated message.')
      if submessage._InternalParse(buffer, pos, sub_end) != sub_end:
        raise _DecodeError('Unexpected end-group tag.')
      return sub_end
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      size, pos = local_DecodeVarint(buffer, pos)
      sub_end = pos + size
      if sub_end > end:
        raise _DecodeError('Truncated message.')
      if container.add()._InternalParse(buffer, pos, sub_end) != sub_end:
        raise _DecodeError('Unexpected end-group tag.')
      # Guess that another element of this field follows immediately.
      pos = sub_end + tag_len
      if buffer[sub_end:pos] != tag_bytes or sub_end == end:
        return sub_end

  return DecodeRepeatedField

707 

708 

709# -------------------------------------------------------------------- 

710 

# Tag that opens one MessageSet item: field 1 with the START_GROUP wire type.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

712 

def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor. The decoder body does not
  reference it.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  If ``type_id`` names an extension found via
  ``message.Extensions._FindExtensionByNumber``, the payload is parsed into
  that extension's value in field_dict. Otherwise the raw item bytes are
  appended to ``message._unknown_fields`` under MESSAGE_SET_ITEM_TAG.
  """
  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data (just past the item's end tag).

    Raises:
      _DecodeError: on a missing end tag, truncation, a missing type_id or
        message, or an unexpected end-group tag inside the payload.
    """
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not imported at module level in this file, so
          # the bare name raised NameError here. Import it where it is used.
          # A function-level import may also avoid an import cycle
          # (unverified).
          from google.protobuf import message_factory  # pylint: disable=g-import-not-at-top
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
    # pylint: enable=protected-access

    return pos

  return DecodeItem

801 

802 

def UnknownMessageSetItemDecoder():
  """Returns a decoder for an unknown MessageSet item.

  The returned ``DecodeUnknownItem(buffer)`` reads tags from the start of
  ``buffer`` until the item's END_GROUP tag. It returns
  ``(type_id, message_bytes)``. ``buffer`` must support ``len()``, slicing
  and ``.tobytes()`` on slices (e.g. a memoryview).

  Raises:
    _DecodeError: if a group end tag is missing, the data is truncated, or
      the item lacks its type_id or message.
  """
  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    pos = 0
    end = len(buffer)
    # -1 marks "not seen yet". type_id must be initialized too. Otherwise an
    # item without a type_id raises UnboundLocalError at the check below
    # instead of the intended _DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem

840 

841# -------------------------------------------------------------------- 

842 

def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry on the wire is a length-delimited message with ``key`` and
  ``value`` fields. Entries are parsed into one reused scratch message
  (cleared before each entry) and then written into the map container.
  For message-valued maps the value is copied into the map slot with
  ``CopyFrom``; for other maps it is assigned.
  """
  key = field_descriptor
  tag_bytes = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  local_DecodeVarint = _DecodeVarint
  # _concrete_class may not exist yet when this factory runs, so it is only
  # looked up inside the decoder.
  message_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    entry = message_type._concrete_class()
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      size, pos = local_DecodeVarint(buffer, pos)
      entry_end = pos + size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      entry.Clear()
      if entry._InternalParse(buffer, pos, entry_end) != entry_end:
        # Parsing only stops short of entry_end on an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Guess that another map entry follows immediately.
      pos = entry_end + tag_len
      if buffer[entry_end:pos] != tag_bytes or entry_end == end:
        return entry_end

  return DecodeMap

884 

885# -------------------------------------------------------------------- 

886# Optimization is not as heavy here because calls to SkipField() are rare, 

887# except for handling end-group tags. 

888 

889def _SkipVarint(buffer, pos, end): 

890 """Skip a varint value. Returns the new position.""" 

891 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. 

892 # With this code, ord(b'') raises TypeError. Both are handled in 

893 # python_message.py to generate a 'Truncated message' error. 

894 while ord(buffer[pos:pos+1].tobytes()) & 0x80: 

895 pos += 1 

896 pos += 1 

897 if pos > end: 

898 raise _DecodeError('Truncated message.') 

899 return pos 

900 

901def _SkipFixed64(buffer, pos, end): 

902 """Skip a fixed64 value. Returns the new position.""" 

903 

904 pos += 8 

905 if pos > end: 

906 raise _DecodeError('Truncated message.') 

907 return pos 

908 

909 

910def _DecodeFixed64(buffer, pos): 

911 """Decode a fixed64.""" 

912 new_pos = pos + 8 

913 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 

914 

915 

def _SkipLengthDelimited(buffer, pos, end):
  """Skips a varint-length-prefixed payload.

  Args:
    buffer: the encoded data.
    pos: offset of the length varint.
    end: offset the payload must not extend beyond.

  Returns:
    The position just past the payload.

  Raises:
    _DecodeError: if the payload extends past ``end``.
  """
  size, payload_start = _DecodeVarint(buffer, pos)
  payload_end = payload_start + size
  if payload_end > end:
    raise _DecodeError('Truncated message.')
  return payload_end

924 

925 

def _SkipGroup(buffer, pos, end):
  """Skips the fields of a group through its end-group tag.

  Each field is skipped with SkipField until it reports an end-group tag
  (by returning -1).

  Returns:
    The position just past the end-group tag.
  """
  # NOTE(review): nested groups recurse via SkipField -> _SkipGroup with no
  # depth limit visible here; deeply nested input could exhaust the Python
  # recursion limit -- confirm whether callers bound nesting.
  while True:
    tag_bytes, after_tag = ReadTag(buffer, pos)
    skipped_to = SkipField(buffer, after_tag, end, tag_bytes)
    if skipped_to == -1:
      return after_tag
    pos = skipped_to

935 

936 

def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decodes consecutive fields into an UnknownFieldSet.

  Args:
    buffer: the encoded data.
    pos: offset of the first tag to read.
    end_pos: offset at which to stop. If None, decoding continues until an
      end-group tag is read (the mode used for the contents of a group).

  Returns:
    ``(unknown_field_set, new_pos)``. If an end-group tag stops decoding,
    ``new_pos`` is just past that tag.
  """
  field_set = containers.UnknownFieldSet()
  bounded = end_pos is not None
  while not bounded or pos < end_pos:
    tag_bytes, pos = ReadTag(buffer, pos)
    tag_value, _ = _DecodeVarint(tag_bytes, 0)
    number, wtype = wire_format.UnpackTag(tag_value)
    if wtype == wire_format.WIRETYPE_END_GROUP:
      break
    data, pos = _DecodeUnknownField(buffer, pos, wtype)
    field_set._add(number, wtype, data)  # pylint: disable=protected-access
  return field_set, pos

952 

953 

def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode a unknown field.  Returns the UnknownField and new position.

  Args:
    buffer: the encoded data; must support indexing, slicing, len() and
      ``.tobytes()`` on slices (e.g. a memoryview).
    pos: offset just past the field's tag.
    wire_type: the wire type extracted from the field's tag.

  Returns:
    ``(data, new_pos)``. ``data`` is the decoded integer for varint and
    fixed-width fields, the raw payload bytes for length-delimited fields, or
    an UnknownFieldSet for a group. For an end-group tag, ``(0, -1)`` is
    returned so the caller can stop.

  Raises:
    _DecodeError: for an invalid wire type, or when a length-delimited
      payload extends past the end of ``buffer``.
  """

  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    payload_end = pos + size
    if payload_end > len(buffer):
      # Slicing would silently return fewer than `size` bytes and leave pos
      # pointing past the buffer; report truncation like the skippers do.
      raise _DecodeError('Truncated message.')
    data = buffer[pos:payload_end].tobytes()
    pos = payload_end
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    # Groups have no length prefix; decode until the matching end-group tag.
    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  return (data, pos)

975 

976 

977def _EndGroup(buffer, pos, end): 

978 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" 

979 

980 return -1 

981 

982 

983def _SkipFixed32(buffer, pos, end): 

984 """Skip a fixed32 value. Returns the new position.""" 

985 

986 pos += 4 

987 if pos > end: 

988 raise _DecodeError('Truncated message.') 

989 return pos 

990 

991 

992def _DecodeFixed32(buffer, pos): 

993 """Decode a fixed32.""" 

994 

995 new_pos = pos + 4 

996 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 

997 

998 

def _RaiseInvalidWireType(buffer, pos, end):
  """Skipper installed for wire-type slots that have no valid meaning.

  Raises:
    _DecodeError: unconditionally.
  """
  del buffer, pos, end  # Unused; present only to match the skipper signature.
  raise _DecodeError('Tag had invalid wire type.')

1003 

def _FieldSkipper():
  """Builds and returns the SkipField function.

  The wire type, masked out of a tag's first byte with
  wire_format.TAG_TYPE_MASK, indexes a table of per-wire-type skippers.
  """
  # Position in this tuple is the wire type the skipper handles.
  skippers = (
      _SkipVarint,            # 0: varint
      _SkipFixed64,           # 1: 64-bit fixed
      _SkipLengthDelimited,   # 2: length-delimited
      _SkipGroup,             # 3: start group
      _EndGroup,              # 4: end group
      _SkipFixed32,           # 5: 32-bit fixed
      _RaiseInvalidWireType,  # 6: invalid
      _RaiseInvalidWireType,  # 7: invalid
  )
  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
      The new position (after the tag value), or -1 if the tag is an end-group
      tag (in which case the calling loop should break).
    """
    # Varints are little-endian, so the wire-type bits live in the first byte.
    first_byte = ord(tag_bytes[0:1])
    skip = skippers[first_byte & type_mask]
    return skip(buffer, pos, end)

  return SkipField


SkipField = _FieldSkipper()