Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/decoder.py: 16%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

514 statements  

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# 

4# Use of this source code is governed by a BSD-style 

5# license that can be found in the LICENSE file or at 

6# https://developers.google.com/open-source/licenses/bsd 

7 

8"""Code for decoding protocol buffer primitives. 

9 

10This code is very similar to encoder.py -- read the docs for that module first. 

11 

12A "decoder" is a function with the signature: 

13 Decode(buffer, pos, end, message, field_dict) 

14The arguments are: 

15 buffer: The string containing the encoded message. 

16 pos: The current position in the string. 

17 end: The position in the string where the current message ends. May be 

18 less than len(buffer) if we're reading a sub-message. 

19 message: The message object into which we're parsing. 

20 field_dict: message._fields (avoids a hashtable lookup). 

21The decoder reads the field and stores it into field_dict, returning the new 

22buffer position. A decoder for a repeated field may proactively decode all of 

23the elements of that field, if they appear consecutively. 

24 

25Note that decoders may throw any of the following: 

26 IndexError: Indicates a truncated message. 

27 struct.error: Unpacking of a fixed-width field failed. 

28 message.DecodeError: Other errors. 

29 

30Decoders are expected to raise an exception if they are called with pos > end. 

31This allows callers to be lax about bounds checking: it's fine to read past 

32"end" as long as you are sure that someone else will notice and throw an 

33exception later on. 

34 

35Something up the call stack is expected to catch IndexError and struct.error 

36and convert them to message.DecodeError. 

37 

38Decoders are constructed using decoder constructors with the signature: 

39 MakeDecoder(field_number, is_repeated, is_packed, key, new_default) 

40The arguments are: 

41 field_number: The field number of the field we want to decode. 

42 is_repeated: Is the field a repeated field? (bool) 

43 is_packed: Is the field a packed field? (bool) 

44 key: The key to use when looking up the field within field_dict. 

45 (This is actually the FieldDescriptor but nothing in this 

46 file should depend on that.) 

47 new_default: A function which takes a message object as a parameter and 

48 returns a new instance of the default value for this field. 

49 (This is called for repeated fields and sub-messages, when an 

50 instance does not already exist.) 

51 

52As with encoders, we define a decoder constructor for every type of field. 

53Then, for every field of every message class we construct an actual decoder. 

54That decoder goes into a dict indexed by tag, so when we decode a message 

55we repeatedly read a tag, look up the corresponding decoder, and invoke it. 

56""" 

57 

# Original author of this module.
__author__ = "kenton@google.com (Kenton Varda)"

59 

60import math 

61import struct 

62 

63from google.protobuf.internal import containers 

64from google.protobuf.internal import encoder 

65from google.protobuf.internal import wire_format 

66from google.protobuf import message 

67 

68 

# This alias is not an optimization: the decoder functions below take a
# parameter named "message", which shadows the imported ``message`` module, so
# they raise ``_DecodeError`` (== message.DecodeError) instead.
_DecodeError = message.DecodeError

72 

73 

def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits. The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds checking
  after the fact (often the caller can defer such checking until later). The
  decoder returns a (value, new_pos) pair.

  Args:
    mask: int bit mask applied to the accumulated value once the last byte of
      the varint has been read.
    result_type: callable applied to the masked value before it is returned.

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos). It raises
    _DecodeError if the varint is longer than ten bytes, and IndexError (from
    indexing ``buffer``) if the buffer ends before the final byte.
  """

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = buffer[pos]
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        # High bit clear: this was the last byte of the varint.
        result &= mask
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        # Ten bytes already carry 70 payload bits; a valid 64-bit varint
        # never needs an eleventh byte.
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint

99 

100 

def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes signed values.

  The accumulated value is truncated to ``bits`` bits and then reinterpreted
  as a two's-complement integer of that width.

  Args:
    bits: width in bits of the signed result (e.g. 32 or 64).
    result_type: callable applied to the signed value before it is returned.

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).
  """

  sign_bit = 1 << (bits - 1)
  value_mask = (1 << bits) - 1

  def DecodeVarint(buffer, pos):
    accum = 0
    # At most ten bytes (shifts 0, 7, ..., 63) may belong to one varint.
    for shift in range(0, 64, 7):
      byte = buffer[pos]
      accum |= (byte & 0x7f) << shift
      pos += 1
      if not byte & 0x80:
        accum &= value_mask
        # Flip and subtract the sign bit to sign-extend the truncated value.
        return (result_type((accum ^ sign_bit) - sign_bit), pos)
    raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint

123 

# Full-width varint decoders; every result is a plain Python int.
_DecodeVarint = _VarintDecoder(mask=0xFFFFFFFFFFFFFFFF, result_type=int)
_DecodeSignedVarint = _SignedVarintDecoder(bits=64, result_type=int)

# Variants whose results are truncated to 32 bits.
_DecodeVarint32 = _VarintDecoder(mask=0xFFFFFFFF, result_type=int)
_DecodeSignedVarint32 = _SignedVarintDecoder(bits=32, result_type=int)

131 

132 

def ReadTag(buffer, pos):
  """Return the raw bytes of the tag starting at ``pos`` and the new position.

  The tag varint is deliberately not decoded. Its raw bytes are used as-is to
  look up the field's decoder, which trades pure-Python varint decoding for a
  hash-table lookup on a short byte string done in C. In a low-level language
  decoding the varint would be cheaper, but not in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  stop = pos
  # Every byte of the tag except the last has its continuation bit (0x80) set.
  while buffer[stop] & 0x80:
    stop += 1
  stop += 1
  return buffer[pos:stop].tobytes(), stop

157 

158 

159# -------------------------------------------------------------------- 

160 

161 

def _SimpleDecoder(wire_type, decode_value):
  """Build a decoder constructor for fields of one scalar type.

  Args:
    wire_type: The wire type used to encode values of this type.
    decode_value: Callable ``(buffer, pos) -> (value, new_pos)`` decoding one
      value, e.g. _DecodeVarint().

  Returns:
    A constructor with the MakeDecoder signature described in the module
    docstring, plus an optional ``clear_if_default`` flag for singular fields.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if not (is_packed or is_repeated):
      # Singular field: decode one value and store (or clear) it.
      def DecodeField(buffer, pos, end, message, field_dict):
        decoded, after = decode_value(buffer, pos)
        if after > end:
          raise _DecodeError('Truncated message.')
        if not clear_if_default or decoded:
          field_dict[key] = decoded
        else:
          # A falsy value with clear_if_default removes the entry instead.
          field_dict.pop(key, None)
        return after
      return DecodeField

    if is_packed:
      read_length = _DecodeVarint
      def DecodePackedField(buffer, pos, end, message, field_dict):
        items = field_dict.get(key)
        if items is None:
          items = field_dict.setdefault(key, new_default(message))
        length, cursor = read_length(buffer, pos)
        limit = cursor + length
        if limit > end:
          raise _DecodeError('Truncated message.')
        while cursor < limit:
          element, cursor = decode_value(buffer, cursor)
          items.append(element)
        if cursor > limit:
          # The final element overran the packed payload; discard it.
          del items[-1]
          raise _DecodeError('Packed element was truncated.')
        return cursor
      return DecodePackedField

    tag_bytes = encoder.TagBytes(field_number, wire_type)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      items = field_dict.get(key)
      if items is None:
        items = field_dict.setdefault(key, new_default(message))
      while True:
        element, after = decode_value(buffer, pos)
        items.append(element)
        # Guess that the next tag is another element of this same field.
        pos = after + tag_len
        if buffer[after:pos] != tag_bytes or after >= end:
          # Guess was wrong (or input is exhausted): hand control back.
          if after > end:
            raise _DecodeError('Truncated message.')
          return after
    return DecodeRepeatedField

  return SpecificDecoder

223 

224 

def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but passes every decoded value through modify_value.

  modify_value is, for example, wire_format.ZigZagDecode or bool.
  """
  # Wrapping _SimpleDecoder costs a little speed compared with duplicating its
  # code, but not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw, next_pos = decode_value(buffer, pos)
    return (modify_value(raw), next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)

237 

238 

def _StructPackDecoder(wire_type, format):
  """Return a decoder constructor for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: struct format string describing a single value, e.g. '<I'.
  """

  width = struct.calcsize(format)
  unpack = struct.unpack

  # A short read makes struct.unpack raise struct.error. That is deliberately
  # left for callers further up the stack to translate into a DecodeError, so
  # no try/except has to be set up around every single value.
  def InnerDecode(buffer, pos):
    stop = pos + width
    return (unpack(format, buffer[pos:stop])[0], stop)

  return _SimpleDecoder(wire_type, InnerDecode)

262 

263 

def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values are recognised and produced by hand instead of
  going through struct.unpack, working around a historical struct.unpack bug
  with such values.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode a serialized float, returning it with the new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, offset in the memoryview where the float starts.

    Returns:
      Tuple[float, int] of the decoded value and the offset just past it.
    """
    new_pos = pos + 4
    raw = buffer[pos:new_pos].tobytes()
    # Little-endian float32: byte 3 holds the sign bit plus the top seven
    # exponent bits; the high bit of byte 2 is the lowest exponent bit.
    high = raw[3:4]
    all_exponent_bits_set = high in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if all_exponent_bits_set:
      if raw[0:3] != b'\x00\x00\x80':
        # At least one significand bit is set: not a number.
        return (math.nan, new_pos)
      # Significand is zero: an infinity whose sign comes from byte 3.
      return (-math.inf if high == b'\xFF' else math.inf, new_pos)

    # struct.error on a short buffer is left for callers up-stack to convert
    # into a DecodeError.
    return (unpack('<f', raw)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)

307 

308 

def _DoubleDecoder():
  """Returns a decoder for a double field.

  NaN bit patterns are recognised and produced by hand, working around a
  historical struct.unpack bug with not-a-number values.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode a serialized double, returning it with the new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, offset in the memoryview where the double starts.

    Returns:
      Tuple[float, int] of the decoded value and the offset just past it.
    """
    new_pos = pos + 8
    raw = buffer[pos:new_pos].tobytes()
    # Little-endian float64: byte 7 holds the sign bit plus the top seven
    # exponent bits; the upper nibble of byte 6 holds the other four.
    exponent_all_ones = raw[7:8] in b'\x7F\xFF' and raw[6:7] >= b'\xF0'
    if exponent_all_ones and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0':
      # Exponent all ones with a non-zero significand: not a number.
      return (math.nan, new_pos)

    # struct.error on a short buffer is left for callers up-stack to convert
    # into a DecodeError.
    return (unpack('<d', raw)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)

347 

348 

def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values that are members of ``key.enum_type.values_by_number`` are stored in
  field_dict. Any other value is not stored in the field; it is appended to
  ``message._unknown_fields`` as a (varint tag bytes, raw value bytes) pair.
  """
  enum_type = key.enum_type
  # Tag recorded alongside unknown enum values (and, for repeated fields, the
  # tag predicted for the next element). It depends only on field_number, so
  # compute it once here rather than once per unknown value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      if pos > endpoint:
        # The last element overran the packed payload: undo whichever record
        # it produced before reporting the error.
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed. Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode a serialized singular enum field and return the new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
      # pylint: enable=protected-access
      return pos
    return DecodeField

472 

473 

474# -------------------------------------------------------------------- 

475 

476 

# Plain varint-encoded integer types.
Int32Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint32)
UInt64Decoder = _SimpleDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint)

# sint32/sint64: read as unsigned varints, then passed through
# wire_format.ZigZagDecode.
SInt32Decoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint32,
    modify_value=wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint,
    modify_value=wire_format.ZigZagDecode)

# Fixed-width types. With the '<' prefix, struct uses little-endian byte
# order and standard sizes, so these formats mean the same thing on every
# platform regardless of the C compiler's native type sizes.
Fixed32Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED32, format='<I')
Fixed64Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED64, format='<Q')
SFixed32Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED32, format='<i')
SFixed64Decoder = _StructPackDecoder(
    wire_type=wire_format.WIRETYPE_FIXED64, format='<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# bool: a varint converted with bool().
BoolDecoder = _ModifiedDecoder(
    wire_type=wire_format.WIRETYPE_VARINT, decode_value=_DecodeVarint,
    modify_value=bool)

504 

505 

def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Each value is a length-delimited UTF-8 payload decoded to ``str``. Packed
  encoding is not supported for this type.
  """

  read_varint = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Decode a UTF-8 memoryview slice, naming the field on failure."""
    raw = memview.tobytes()
    try:
      return str(raw, 'utf-8')
    except UnicodeDecodeError as err:
      # Enrich the error with the field's full name, then re-raise it.
      err.reason = '%s in field: %s' % (err, key.full_name)
      raise

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      size, start = read_varint(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[start:stop])
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      size, start = read_varint(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(_ConvertToUnicode(buffer[start:stop]))
      # Guess that the next tag is another element of this same field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop
  return DecodeRepeatedField

557 

558 

def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Each value is a length-delimited payload copied out of the buffer into a
  ``bytes`` object. Packed encoding is not supported for this type.
  """

  read_varint = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      size, start = read_varint(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[start:stop].tobytes()
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      size, start = read_varint(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(buffer[start:stop].tobytes())
      # Guess that the next tag is another element of this same field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop
  return DecodeRepeatedField

598 

599 

def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group's contents are parsed with ``_InternalParse`` up to ``end``. The
  parse is expected to stop at the group's end-group tag, which is then
  checked explicitly; a missing end tag raises _DecodeError.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Fetch (or create) the repeated container once; parsing a sub-message
      # appends to it but never replaces it, so there is no need to look it
      # up again on every iteration.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField

645 

646 

def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each occurrence is a length-delimited sub-message parsed in place with
  ``_InternalParse``, which must consume exactly the declared length.
  """

  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      size, start = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      # _InternalParse only returns early when it meets an end-group tag.
      if submessage._InternalParse(buffer, start, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      size, start = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      # _InternalParse only returns early when it meets an end-group tag.
      if container.add()._InternalParse(buffer, start, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')
      # Guess that the next tag is another element of this same field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop
  return DecodeRepeatedField

695 

696 

697# -------------------------------------------------------------------- 

698 

# Tag bytes that open one MessageSet item: field 1 with the start-group wire
# type (``repeated group Item = 1`` in the MessageSet layout). Items whose
# type_id matches no known extension are stored in _unknown_fields under it.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

700 

def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor (the decoder below does not
  reference it).

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    If type_id names an extension known to ``message``, the payload is parsed
    into that extension's entry in field_dict; otherwise the whole item is
    kept verbatim in message._unknown_fields under MESSAGE_SET_ITEM_TAG.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data (just past the item's end-group
      tag).

    Raises:
      _DecodeError: on truncation, a missing end-group tag, a missing
        type_id or message entry, or an unexpected end-group tag inside the
        payload.
    """
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not among this module's top-level imports, so
          # bring it into scope here, at the only place it is needed.
          from google.protobuf import message_factory  # pylint: disable=g-import-not-at-top
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
    # pylint: enable=protected-access

    return pos

  return DecodeItem

789 

790 

def UnknownMessageSetItemDecoder():
  """Returns a decoder for an unknown MessageSet item.

  The returned function DecodeUnknownItem(buffer) takes the serialized body of
  one item as a memoryview (ReadTag and ``.tobytes()`` are applied to it) and
  returns a ``(type_id, message_bytes)`` tuple.

  Raises (from the returned function):
    _DecodeError: on truncation, a missing end-group tag, or a missing
      type_id / message entry.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    pos = 0
    end = len(buffer)
    # Sentinels for "not seen yet". type_id must be initialized too, otherwise
    # an item without a type_id hits UnboundLocalError below instead of the
    # intended DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem

828 

829# -------------------------------------------------------------------- 

830 

def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry is a length-delimited entry message with ``key`` and
  ``value`` fields. A single scratch entry message is reused for every entry
  decoded in one call. For message-valued maps the value is merged in with
  CopyFrom; otherwise it is assigned directly.
  """

  key = field_descriptor
  tag_bytes = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  read_length = _DecodeVarint
  # _concrete_class may not exist yet when the decoder is built, so it is
  # looked up on each call instead.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    entry = entry_type._concrete_class()
    mapping = field_dict.get(key)
    if mapping is None:
      mapping = field_dict.setdefault(key, new_default(message))
    while True:
      size, start = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      entry.Clear()
      # _InternalParse only returns early when it meets an end-group tag.
      if entry._InternalParse(buffer, start, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        mapping[entry.key].CopyFrom(entry.value)
      else:
        mapping[entry.key] = entry.value

      # Guess that the next tag is another entry of this same map field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop

  return DecodeMap

872 

873# -------------------------------------------------------------------- 

874# Optimization is not as heavy here because calls to SkipField() are rare, 

875# except for handling end-group tags. 

876 

877def _SkipVarint(buffer, pos, end): 

878 """Skip a varint value. Returns the new position.""" 

879 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. 

880 # With this code, ord(b'') raises TypeError. Both are handled in 

881 # python_message.py to generate a 'Truncated message' error. 

882 while ord(buffer[pos:pos+1].tobytes()) & 0x80: 

883 pos += 1 

884 pos += 1 

885 if pos > end: 

886 raise _DecodeError('Truncated message.') 

887 return pos 

888 

889def _SkipFixed64(buffer, pos, end): 

890 """Skip a fixed64 value. Returns the new position.""" 

891 

892 pos += 8 

893 if pos > end: 

894 raise _DecodeError('Truncated message.') 

895 return pos 

896 

897 

898def _DecodeFixed64(buffer, pos): 

899 """Decode a fixed64.""" 

900 new_pos = pos + 8 

901 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 

902 

903 

def _SkipLengthDelimited(buffer, pos, end):
  """Steps over a varint length prefix and the payload it describes.

  Returns:
    The position immediately after the payload.

  Raises:
    _DecodeError: if the payload would extend beyond ``end``.
  """
  size, payload_start = _DecodeVarint(buffer, pos)
  payload_end = payload_start + size
  if payload_end > end:
    raise _DecodeError('Truncated message.')
  return payload_end

912 

913 

def _SkipGroup(buffer, pos, end):
  """Skips every field of a group, up to and including an end-group tag.

  ``pos`` should point just past the group's start-group tag.

  Returns:
    The position immediately after the end-group tag.
  """
  # NOTE(review): the field number of the end-group tag is not compared with
  # the one that opened the group; any end-group tag terminates the skip.
  while True:
    tag_bytes, after_tag = ReadTag(buffer, pos)
    next_pos = SkipField(buffer, after_tag, end, tag_bytes)
    if next_pos == -1:
      # SkipField signals an end-group tag with -1: the group is finished.
      return after_tag
    pos = next_pos

923 

924 

def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Collects fields into a fresh UnknownFieldSet.

  Decoding stops after an end-group tag, which is consumed but not added. It
  also stops when ``end_pos`` is given and ``pos`` has reached it.

  Returns:
    An ``(unknown_field_set, new_pos)`` tuple.
  """
  field_set = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    tag_bytes, pos = ReadTag(buffer, pos)
    field_number, wire_type = wire_format.UnpackTag(
        _DecodeVarint(tag_bytes, 0)[0])
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    data, pos = _DecodeUnknownField(buffer, pos, wire_type)
    # pylint: disable=protected-access
    field_set._add(field_number, wire_type, data)

  return field_set, pos

940 

941 

def _DecodeUnknownField(buffer, pos, wire_type):
  """Decodes the value of one unknown field.

  Args:
    buffer: the encoded data; slices of it must support ``.tobytes()``
      (e.g. a memoryview).
    pos: position just after the field's tag.
    wire_type: the wire type taken from the field's tag.

  Returns:
    A ``(data, new_pos)`` tuple. ``data`` is the decoded scalar for varint
    and fixed wire types, ``bytes`` for length-delimited fields, and an
    UnknownFieldSet for groups. For an end-group wire type, ``(0, -1)`` is
    returned so the caller can stop.

  Raises:
    _DecodeError: for an invalid wire type, or when a length-delimited
      payload extends past the end of ``buffer``.
  """

  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    new_pos = pos + size
    # Slicing clamps silently at the end of the buffer. Without this check, a
    # truncated payload would be returned as if it were complete.
    if new_pos > len(buffer):
      raise _DecodeError('Truncated message.')
    data = buffer[pos:new_pos].tobytes()
    pos = new_pos
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  return (data, pos)

963 

964 

965def _EndGroup(buffer, pos, end): 

966 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" 

967 

968 return -1 

969 

970 

971def _SkipFixed32(buffer, pos, end): 

972 """Skip a fixed32 value. Returns the new position.""" 

973 

974 pos += 4 

975 if pos > end: 

976 raise _DecodeError('Truncated message.') 

977 return pos 

978 

979 

980def _DecodeFixed32(buffer, pos): 

981 """Decode a fixed32.""" 

982 

983 new_pos = pos + 4 

984 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 

985 

986 

def _RaiseInvalidWireType(buffer, pos, end):
  """Skipper registered for wire types 6 and 7, which are not valid.

  It never returns; it always raises _DecodeError.
  """
  del buffer, pos, end  # Unused; present only to match the skipper signature.
  raise _DecodeError('Tag had invalid wire type.')

991 

def _FieldSkipper():
  """Builds the SkipField function.

  The returned closure captures a table of skippers indexed by wire type. The
  wire type is the tag's first byte masked with wire_format.TAG_TYPE_MASK.
  """

  skippers = (
      _SkipVarint,            # 0: WIRETYPE_VARINT
      _SkipFixed64,           # 1: WIRETYPE_FIXED64
      _SkipLengthDelimited,   # 2: WIRETYPE_LENGTH_DELIMITED
      _SkipGroup,             # 3: WIRETYPE_START_GROUP
      _EndGroup,              # 4: WIRETYPE_END_GROUP
      _SkipFixed32,           # 5: WIRETYPE_FIXED32
      _RaiseInvalidWireType,  # 6: invalid
      _RaiseInvalidWireType,  # 7: invalid
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
      The new position (after the tag value), or -1 if the tag is an end-group
      tag (in which case the calling loop should break).
    """
    # Varints are little-endian, so the wire type always sits in the low bits
    # of the first tag byte.
    first_byte = ord(tag_bytes[0:1])
    skipper = skippers[first_byte & type_mask]
    return skipper(buffer, pos, end)

  return SkipField


# Module-level skipper shared by the decoders in this file
# (e.g. UnknownMessageSetItemDecoder and _SkipGroup).
SkipField = _FieldSkipper()