Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/decoder.py: 15%

527 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:40 +0000

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# 

4# Use of this source code is governed by a BSD-style 

5# license that can be found in the LICENSE file or at 

6# https://developers.google.com/open-source/licenses/bsd 

7 

"""Code for decoding protocol buffer primitives.

This code is very similar to encoder.py -- read the docs for that module first.

A "decoder" is a function with the signature:
  Decode(buffer, pos, end, message, field_dict)
The arguments are:
  buffer:     The memoryview containing the encoded message.
  pos:        The current position in the buffer.
  end:        The position in the buffer where the current message ends.  May
              be less than len(buffer) if we're reading a sub-message.
  message:    The message object into which we're parsing.
  field_dict: message._fields (avoids a hashtable lookup).
The decoder reads the field and stores it into field_dict, returning the new
buffer position.  A decoder for a repeated field may proactively decode all of
the elements of that field, if they appear consecutively.

Note that decoders may throw any of the following:
  IndexError:  Indicates a truncated message.
  struct.error:  Unpacking of a fixed-width field failed.
  message.DecodeError:  Other errors.

Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
"end" as long as you are sure that someone else will notice and throw an
exception later on.

Something up the call stack is expected to catch IndexError and struct.error
and convert them to message.DecodeError.

Decoders are constructed using decoder constructors with the signature:
  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
The arguments are:
  field_number:  The field number of the field we want to decode.
  is_repeated:   Is the field a repeated field? (bool)
  is_packed:     Is the field a packed field? (bool)
  key:           The key to use when looking up the field within field_dict.
                 (This is actually the FieldDescriptor.  Most of this file does
                 not depend on that, but EnumDecoder reads key.enum_type and
                 StringDecoder reads key.full_name.)
  new_default:   A function which takes a message object as a parameter and
                 returns a new instance of the default value for this field.
                 (This is called for repeated fields and sub-messages, when an
                 instance does not already exist.)
The scalar, enum, string and bytes constructors also accept an optional
clear_if_default argument: when true, a singular field whose decoded value is
falsy is removed from field_dict instead of being stored.

As with encoders, we define a decoder constructor for every type of field.
Then, for every field of every message class we construct an actual decoder.
That decoder goes into a dict indexed by tag, so when we decode a message
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
"""

__author__ = 'kenton@google.com (Kenton Varda)'

59 

60import math 

61import struct 

62 

63from google.protobuf.internal import containers 

64from google.protobuf.internal import encoder 

65from google.protobuf.internal import wire_format 

66from google.protobuf import message 

67 

68 

# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".  The decoder closures in this file take a
# parameter called `message` (the message being parsed), which shadows the
# imported `message` module inside them, so they reach the exception class
# through this module-level alias instead.
_DecodeError = message.DecodeError

72 

73 

74def _VarintDecoder(mask, result_type): 

75 """Return an encoder for a basic varint value (does not include tag). 

76 

77 Decoded values will be bitwise-anded with the given mask before being 

78 returned, e.g. to limit them to 32 bits. The returned decoder does not 

79 take the usual "end" parameter -- the caller is expected to do bounds checking 

80 after the fact (often the caller can defer such checking until later). The 

81 decoder returns a (value, new_pos) pair. 

82 """ 

83 

84 def DecodeVarint(buffer, pos): 

85 result = 0 

86 shift = 0 

87 while 1: 

88 b = buffer[pos] 

89 result |= ((b & 0x7f) << shift) 

90 pos += 1 

91 if not (b & 0x80): 

92 result &= mask 

93 result = result_type(result) 

94 return (result, pos) 

95 shift += 7 

96 if shift >= 64: 

97 raise _DecodeError('Too many bytes when decoding varint.') 

98 return DecodeVarint 

99 

100 

101def _SignedVarintDecoder(bits, result_type): 

102 """Like _VarintDecoder() but decodes signed values.""" 

103 

104 signbit = 1 << (bits - 1) 

105 mask = (1 << bits) - 1 

106 

107 def DecodeVarint(buffer, pos): 

108 result = 0 

109 shift = 0 

110 while 1: 

111 b = buffer[pos] 

112 result |= ((b & 0x7f) << shift) 

113 pos += 1 

114 if not (b & 0x80): 

115 result &= mask 

116 result = (result ^ signbit) - signbit 

117 result = result_type(result) 

118 return (result, pos) 

119 shift += 7 

120 if shift >= 64: 

121 raise _DecodeError('Too many bytes when decoding varint.') 

122 return DecodeVarint 

123 

# Integers of every width are represented by Python's int type.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Variants whose results are limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)

131 

132 

def ReadTag(buffer, pos):
  """Return the raw bytes of the varint tag at ``pos`` and the position after it.

  The tag is deliberately left undecoded.  Its raw bytes serve as a key for
  looking up the right decoder, which swaps pure-Python varint decoding for a
  C-level hash lookup of a byte string.  In a lower-level language decoding
  the varint would be cheaper, but not in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_end = pos
  # Every byte of the tag except the last has its continuation bit set.
  while buffer[tag_end] & 0x80:
    tag_end += 1
  tag_end += 1
  return buffer[pos:tag_end].tobytes(), tag_end

157 

158 

159# -------------------------------------------------------------------- 

160 

161 

def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  The returned constructor, SpecificDecoder, builds one of three decoders
  depending on its flags:
    * is_packed:   one length-delimited run of back-to-back values.
    * is_repeated: one tag per element; consecutive elements carrying the same
                   tag are decoded eagerly in a single call.
    * otherwise:   a single value that overwrites any previous one.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint()
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    # clear_if_default is only consulted by the singular decoder below.
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # A varint byte length precedes the packed values.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        # The last element ran past the declared length.
        if pos > endpoint:
          del value[-1]   # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos
      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while 1:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field.
          pos = new_pos + tag_len
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed.  Return.
            # The truncation check is deferred to here, after the element
            # has already been appended.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos
      return DecodeRepeatedField
    else:
      def DecodeField(buffer, pos, end, message, field_dict):
        (new_value, pos) = decode_value(buffer, pos)
        if pos > end:
          raise _DecodeError('Truncated message.')
        # With clear_if_default, a falsy value removes the key instead of
        # being stored.
        if clear_if_default and not new_value:
          field_dict.pop(key, None)
        else:
          field_dict[key] = new_value
        return pos
      return DecodeField

  return SpecificDecoder

223 

224 

def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder(), but post-processes every decoded value.

  Each value produced by ``decode_value`` is passed through ``modify_value``
  before it is stored.  In this file it is used with
  ``wire_format.ZigZagDecode`` and ``bool``.

  Args:
    wire_type: The field's wire type.
    decode_value: Function ``(buffer, pos) -> (value, new_pos)``.
    modify_value: Function applied to each decoded value.

  Returns:
    A decoder constructor, as returned by _SimpleDecoder().
  """

  # Wrapping decode_value costs one extra call per element compared with a
  # hand-copied _SimpleDecoder, which is not a significant difference.
  def InnerDecode(buffer, pos):
    raw_value, after = decode_value(buffer, pos)
    return (modify_value(raw_value), after)

  return _SimpleDecoder(wire_type, InnerDecode)

237 

238 

def _StructPackDecoder(wire_type, format):
  """Return a decoder constructor for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The struct format string describing one value (e.g. '<I').

  Returns:
    A decoder constructor, as returned by _SimpleDecoder().
  """

  width = struct.calcsize(format)
  unpack = struct.unpack

  # A short read surfaces as struct.error.  It is intentionally not caught
  # here: a caller further up converts it to DecodeError, which avoids an
  # exception-handling block around every single value.
  def InnerDecode(buffer, pos):
    stop = pos + width
    decoded = unpack(format, buffer[pos:stop])[0]
    return (decoded, stop)

  return _SimpleDecoder(wire_type, InnerDecode)

262 

263 

def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.

  NOTE(review): the bug being worked around is described below as a
  Python 2.4 issue.  The special-casing also makes every NaN bit pattern
  decode to the ``math.nan`` object (sign and payload are not preserved), so
  check NaN handling before removing it.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    # Byte 3 (most significant, little-endian) holds the sign bit plus the top
    # 7 exponent bits; 0x7F/0xFF means those exponent bits are all set.  The
    # high bit of byte 2 is the last exponent bit.
    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
      # If at least one significand bit is set...
      # (With the exponent bits set, a zero significand leaves exactly
      # 00 00 80 in bytes 0-2.)
      if float_bytes[0:3] != b'\x00\x00\x80':
        return (math.nan, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b'\xFF':
        return (-math.inf, new_pos)
      return (math.inf, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)

307 

308 

def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.

  NOTE(review): the bug is described below as a Python 2.4 issue.  The
  special case also maps every NaN bit pattern to the ``math.nan`` object;
  infinities are left to struct.unpack.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    # as inf or -inf.  To avoid that, we treat it specially.
    # Byte 7 carries the sign bit plus 7 exponent bits; the high nibble of
    # byte 6 carries the remaining 4.  Bytes 0-6 equal to 00..00 F0 would mean
    # a zero significand, i.e. an infinity, which is not special-cased.
    if ((double_bytes[7:8] in b'\x7F\xFF')
        and (double_bytes[6:7] >= b'\xF0')
        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
      return (math.nan, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)

347 

348 

def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Decoded numbers found in ``key.enum_type.values_by_number`` are stored in
  field_dict.  Any other number is not stored there.  Instead, its raw
  encoding (with a rebuilt VARINT tag) is appended to
  ``message._unknown_fields`` and the value is added to
  ``message._unknown_field_set``.

  Args:
    field_number: The field number; used to rebuild the tag of values that
      are recorded as unknown fields.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the field is packed.
    key: The field descriptor used as the field_dict key; ``key.enum_type``
      supplies the known enum numbers.
    new_default: Callable taking the message and returning a new container
      (used for repeated fields).
    clear_if_default: Singular fields only.  A decoded value of 0 removes the
      key from field_dict and returns early, before the known/unknown check.

  Returns:
    A decoder function with signature (buffer, pos, end, message, field_dict).
  """
  enum_type = key.enum_type
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Byte length of the packed run, followed by the values themselves.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        # Remember where this element started so its raw bytes can be kept
        # if the value is unknown.
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          tag_bytes = encoder.TagBytes(field_number,
                                       wire_format.WIRETYPE_VARINT)

          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
          if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
          message._unknown_field_set._add(
              field_number, wire_format.WIRETYPE_VARINT, element)
          # pylint: enable=protected-access
      # The last element overran the packed length: undo whichever record
      # (known value or unknown field) it produced, then fail.
      if pos > endpoint:
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
          # pylint: disable=protected-access
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
          if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
          message._unknown_field_set._add(
              field_number, wire_format.WIRETYPE_VARINT, element)
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_VARINT)
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        if message._unknown_field_set is None:
          message._unknown_field_set = containers.UnknownFieldSet()
        message._unknown_field_set._add(
            field_number, wire_format.WIRETYPE_VARINT, enum_value)
        # pylint: enable=protected-access
      return pos
    return DecodeField

# Plain varint integer types.
Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint)

# sint32/sint64: the unsigned varint is mapped back through ZigZagDecode.
SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint32,
                                 wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint,
                                 wire_format.ZigZagDecode)

# Fixed-width types.  With the '<' prefix, struct formats are little-endian
# and have standard sizes on every platform, independent of the C compiler's
# native type sizes.
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# bool is a varint whose decoded integer is converted with bool().
BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint,
                               bool)

518 

519 

def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Args:
    field_number: The field number of the field being decoded.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False (asserted).
    key: The field descriptor used as the field_dict key; its ``full_name``
      is appended to UTF-8 decoding error messages.
    new_default: Callable taking the message and returning a new container
      (used for repeated fields).
    clear_if_default: Singular fields only.  A zero-length value removes the
      key from field_dict instead of storing ''.

  Returns:
    A decoder function with signature (buffer, pos, end, message, field_dict).

  Raises (from the returned decoder):
    UnicodeDecodeError: if the payload is not valid UTF-8.  The error's
      ``reason`` is extended with the field's full name.
  """

  local_DecodeVarint = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Decode a memoryview slice as UTF-8, naming the field on failure."""
    byte_str = memview.tobytes()
    try:
      value = str(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # Add the field name to the reason and re-raise.  Build it from
      # e.reason, not str(e): str(e) already embeds the reason, so using it
      # would repeat the "'utf-8' codec can't decode ..." prefix in the final
      # message.
      e.reason = '%s in field: %s' % (e.reason, key.full_name)
      raise

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Each element is a varint byte length followed by UTF-8 payload.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos
    return DecodeField

571 

572 

def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Each value is copied out of the buffer with ``memoryview.tobytes()``, so the
  stored value is a standalone ``bytes`` object.

  Args:
    field_number: The field number of the field being decoded.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False (asserted).
    key: The key used to store the value in field_dict.
    new_default: Callable taking the message and returning a new container
      (used for repeated fields).
    clear_if_default: Singular fields only.  A zero-length value removes the
      key from field_dict instead of storing b''.

  Returns:
    A decoder function with signature (buffer, pos, end, message, field_dict).
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Varint byte length, then the raw payload.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(buffer[pos:new_pos].tobytes())
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[pos:new_pos].tobytes()
      return new_pos
    return DecodeField

612 

613 

def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group's fields are parsed in place by the sub-message's ``_InternalParse``.
  The decoder then expects the END_GROUP tag for ``field_number`` at the
  position ``_InternalParse`` returns, and raises DecodeError otherwise.

  Args:
    field_number: The field number of the group.
    is_repeated: Whether the group field is repeated.
    is_packed: Must be False (asserted).
    key: The key used to store the value in field_dict.
    new_default: Callable taking the message and returning a new sub-message
      (singular) or container (repeated).

  Returns:
    A decoder function with signature (buffer, pos, end, message, field_dict).
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Fetch (or create) the container once.  It is the same object for
      # every element, so there is no need to look it up again per iteration.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField

659 

660 

def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each sub-message is a varint byte length followed by its encoding.  That
  encoding is parsed with ``_InternalParse`` limited to exactly that span.

  Args:
    field_number: The field number of the message field.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False (asserted).
    key: The key used to store the value in field_dict.
    new_default: Callable taking the message and returning a new sub-message
      (singular) or container (repeated).

  Returns:
    A decoder function with signature (buffer, pos, end, message, field_dict).
  """

  local_DecodeVarint = _DecodeVarint

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      if value._InternalParse(buffer, pos, new_pos) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return new_pos
    return DecodeField

709 

710 

711# -------------------------------------------------------------------- 

712 

MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.  It is not used by the decoder
  itself.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Items whose type_id matches a known extension are parsed into that
  extension's message.  Any other item is preserved as an unknown field.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    message_set_item_start = pos
    # -1 sentinels mean "not seen yet"; both fields are required.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only record the payload's span here; it is parsed below, once the
        # type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not imported at module level, so it is
          # imported here at call time; referencing it without this import
          # raised NameError.
          # NOTE(review): kept local, presumably to avoid an import cycle
          # with the message-class machinery -- confirm before hoisting.
          from google.protobuf import message_factory  # pylint: disable=g-import-not-at-top
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start,message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # Unknown type_id: keep the whole item's bytes so it can be
      # re-serialized, and also record the payload in the unknown field set.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          type_id,
          wire_format.WIRETYPE_LENGTH_DELIMITED,
          buffer[message_start:message_end].tobytes())
      # pylint: enable=protected-access

    return pos

  return DecodeItem

809 

810 

def UnknownMessageSetItemDecoder():
  """Returns a decoder for a Unknown MessageSet item.

  The returned function takes a memoryview and reads tags from position 0
  until the item's END_GROUP tag (field 1).  The buffer is therefore expected
  to begin with the first field inside the item.  It returns a
  ``(type_id, message_bytes)`` tuple.

  Raises (from the returned decoder):
    DecodeError: if the item is truncated, lacks a type_id or message field,
      or contains a malformed group.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    pos = 0
    end = len(buffer)
    # -1 sentinels mean "not seen yet".  type_id must be initialized so that
    # an item without a type_id field is reported as DecodeError below
    # rather than failing with UnboundLocalError.
    type_id = -1
    message_start = -1
    message_end = -1
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem

848 

849# -------------------------------------------------------------------- 

850 

851def MapDecoder(field_descriptor, new_default, is_message_map): 

852 """Returns a decoder for a map field.""" 

853 

854 key = field_descriptor 

855 tag_bytes = encoder.TagBytes(field_descriptor.number, 

856 wire_format.WIRETYPE_LENGTH_DELIMITED) 

857 tag_len = len(tag_bytes) 

858 local_DecodeVarint = _DecodeVarint 

859 # Can't read _concrete_class yet; might not be initialized. 

860 message_type = field_descriptor.message_type 

861 

862 def DecodeMap(buffer, pos, end, message, field_dict): 

863 submsg = message_type._concrete_class() 

864 value = field_dict.get(key) 

865 if value is None: 

866 value = field_dict.setdefault(key, new_default(message)) 

867 while 1: 

868 # Read length. 

869 (size, pos) = local_DecodeVarint(buffer, pos) 

870 new_pos = pos + size 

871 if new_pos > end: 

872 raise _DecodeError('Truncated message.') 

873 # Read sub-message. 

874 submsg.Clear() 

875 if submsg._InternalParse(buffer, pos, new_pos) != new_pos: 

876 # The only reason _InternalParse would return early is if it 

877 # encountered an end-group tag. 

878 raise _DecodeError('Unexpected end-group tag.') 

879 

880 if is_message_map: 

881 value[submsg.key].CopyFrom(submsg.value) 

882 else: 

883 value[submsg.key] = submsg.value 

884 

885 # Predict that the next tag is another copy of the same repeated field. 

886 pos = new_pos + tag_len 

887 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 

888 # Prediction failed. Return. 

889 return new_pos 

890 

891 return DecodeMap 

892 

893# -------------------------------------------------------------------- 

894# Optimization is not as heavy here because calls to SkipField() are rare, 

895# except for handling end-group tags. 

896 

897def _SkipVarint(buffer, pos, end): 

898 """Skip a varint value. Returns the new position.""" 

899 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. 

900 # With this code, ord(b'') raises TypeError. Both are handled in 

901 # python_message.py to generate a 'Truncated message' error. 

902 while ord(buffer[pos:pos+1].tobytes()) & 0x80: 

903 pos += 1 

904 pos += 1 

905 if pos > end: 

906 raise _DecodeError('Truncated message.') 

907 return pos 

908 

909def _SkipFixed64(buffer, pos, end): 

910 """Skip a fixed64 value. Returns the new position.""" 

911 

912 pos += 8 

913 if pos > end: 

914 raise _DecodeError('Truncated message.') 

915 return pos 

916 

917 

918def _DecodeFixed64(buffer, pos): 

919 """Decode a fixed64.""" 

920 new_pos = pos + 8 

921 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 

922 

923 

924def _SkipLengthDelimited(buffer, pos, end): 

925 """Skip a length-delimited value. Returns the new position.""" 

926 

927 (size, pos) = _DecodeVarint(buffer, pos) 

928 pos += size 

929 if pos > end: 

930 raise _DecodeError('Truncated message.') 

931 return pos 

932 

933 

934def _SkipGroup(buffer, pos, end): 

935 """Skip sub-group. Returns the new position.""" 

936 

937 while 1: 

938 (tag_bytes, pos) = ReadTag(buffer, pos) 

939 new_pos = SkipField(buffer, pos, end, tag_bytes) 

940 if new_pos == -1: 

941 return pos 

942 pos = new_pos 

943 

944 

945def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): 

946 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" 

947 

948 unknown_field_set = containers.UnknownFieldSet() 

949 while end_pos is None or pos < end_pos: 

950 (tag_bytes, pos) = ReadTag(buffer, pos) 

951 (tag, _) = _DecodeVarint(tag_bytes, 0) 

952 field_number, wire_type = wire_format.UnpackTag(tag) 

953 if wire_type == wire_format.WIRETYPE_END_GROUP: 

954 break 

955 (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) 

956 # pylint: disable=protected-access 

957 unknown_field_set._add(field_number, wire_type, data) 

958 

959 return (unknown_field_set, pos) 

960 

961 

962def _DecodeUnknownField(buffer, pos, wire_type): 

963 """Decode a unknown field. Returns the UnknownField and new position.""" 

964 

965 if wire_type == wire_format.WIRETYPE_VARINT: 

966 (data, pos) = _DecodeVarint(buffer, pos) 

967 elif wire_type == wire_format.WIRETYPE_FIXED64: 

968 (data, pos) = _DecodeFixed64(buffer, pos) 

969 elif wire_type == wire_format.WIRETYPE_FIXED32: 

970 (data, pos) = _DecodeFixed32(buffer, pos) 

971 elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: 

972 (size, pos) = _DecodeVarint(buffer, pos) 

973 data = buffer[pos:pos+size].tobytes() 

974 pos += size 

975 elif wire_type == wire_format.WIRETYPE_START_GROUP: 

976 (data, pos) = _DecodeUnknownFieldSet(buffer, pos) 

977 elif wire_type == wire_format.WIRETYPE_END_GROUP: 

978 return (0, -1) 

979 else: 

980 raise _DecodeError('Wrong wire type in tag.') 

981 

982 return (data, pos) 

983 

984 

985def _EndGroup(buffer, pos, end): 

986 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" 

987 

988 return -1 

989 

990 

991def _SkipFixed32(buffer, pos, end): 

992 """Skip a fixed32 value. Returns the new position.""" 

993 

994 pos += 4 

995 if pos > end: 

996 raise _DecodeError('Truncated message.') 

997 return pos 

998 

999 

1000def _DecodeFixed32(buffer, pos): 

1001 """Decode a fixed32.""" 

1002 

1003 new_pos = pos + 4 

1004 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 

1005 

1006 

1007def _RaiseInvalidWireType(buffer, pos, end): 

1008 """Skip function for unknown wire types. Raises an exception.""" 

1009 

1010 raise _DecodeError('Tag had invalid wire type.') 

1011 

1012def _FieldSkipper(): 

1013 """Constructs the SkipField function.""" 

1014 

1015 WIRETYPE_TO_SKIPPER = [ 

1016 _SkipVarint, 

1017 _SkipFixed64, 

1018 _SkipLengthDelimited, 

1019 _SkipGroup, 

1020 _EndGroup, 

1021 _SkipFixed32, 

1022 _RaiseInvalidWireType, 

1023 _RaiseInvalidWireType, 

1024 ] 

1025 

1026 wiretype_mask = wire_format.TAG_TYPE_MASK 

1027 

1028 def SkipField(buffer, pos, end, tag_bytes): 

1029 """Skips a field with the specified tag. 

1030 

1031 |pos| should point to the byte immediately after the tag. 

1032 

1033 Returns: 

1034 The new position (after the tag value), or -1 if the tag is an end-group 

1035 tag (in which case the calling loop should break). 

1036 """ 

1037 

1038 # The wire type is always in the first byte since varints are little-endian. 

1039 wire_type = ord(tag_bytes[0:1]) & wiretype_mask 

1040 return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) 

1041 

1042 return SkipField 

1043 

1044SkipField = _FieldSkipper()