Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/decoder.py: 15%

527 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 07:30 +0000

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# https://developers.google.com/protocol-buffers/ 

4# 

5# Redistribution and use in source and binary forms, with or without 

6# modification, are permitted provided that the following conditions are 

7# met: 

8# 

9# * Redistributions of source code must retain the above copyright 

10# notice, this list of conditions and the following disclaimer. 

11# * Redistributions in binary form must reproduce the above 

12# copyright notice, this list of conditions and the following disclaimer 

13# in the documentation and/or other materials provided with the 

14# distribution. 

15# * Neither the name of Google Inc. nor the names of its 

16# contributors may be used to endorse or promote products derived from 

17# this software without specific prior written permission. 

18# 

19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 

20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 

21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 

22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 

23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 

24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 

25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 

26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 

27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 

28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

31"""Code for decoding protocol buffer primitives. 

32 

33This code is very similar to encoder.py -- read the docs for that module first. 

34 

35A "decoder" is a function with the signature: 

36 Decode(buffer, pos, end, message, field_dict) 

37The arguments are: 

38 buffer: The memoryview containing the encoded message. 

39 pos: The current position in the string. 

40 end: The position in the string where the current message ends. May be 

41 less than len(buffer) if we're reading a sub-message. 

42 message: The message object into which we're parsing. 

43 field_dict: message._fields (avoids a hashtable lookup). 

44The decoder reads the field and stores it into field_dict, returning the new 

45buffer position. A decoder for a repeated field may proactively decode all of 

46the elements of that field, if they appear consecutively. 

47 

48Note that decoders may throw any of the following: 

49 IndexError: Indicates a truncated message. 

50 struct.error: Unpacking of a fixed-width field failed. 

51 message.DecodeError: Other errors. 

52 

53Decoders are expected to raise an exception if they are called with pos > end. 

54This allows callers to be lax about bounds checking: it's fine to read past 

55"end" as long as you are sure that someone else will notice and throw an 

56exception later on. 

57 

58Something up the call stack is expected to catch IndexError and struct.error 

59and convert them to message.DecodeError. 

60 

61Decoders are constructed using decoder constructors with the signature: 

62 MakeDecoder(field_number, is_repeated, is_packed, key, new_default) 

63The arguments are: 

64 field_number: The field number of the field we want to decode. 

65 is_repeated: Is the field a repeated field? (bool) 

66 is_packed: Is the field a packed field? (bool) 

67 key: The key to use when looking up the field within field_dict. 

68 (This is actually the FieldDescriptor but nothing in this 

69 file should depend on that.) 

70 new_default: A function which takes a message object as a parameter and 

71 returns a new instance of the default value for this field. 

72 (This is called for repeated fields and sub-messages, when an 

73 instance does not already exist.) 

74 

75As with encoders, we define a decoder constructor for every type of field. 

76Then, for every field of every message class we construct an actual decoder. 

77That decoder goes into a dict indexed by tag, so when we decode a message 

78we repeatedly read a tag, look up the corresponding decoder, and invoke it. 

79""" 

80 

81__author__ = 'kenton@google.com (Kenton Varda)' 

82 

83import math 

84import struct 

85 

86from google.protobuf.internal import containers 

87from google.protobuf.internal import encoder 

88from google.protobuf.internal import wire_format 

89from google.protobuf import message 

90 

91 

92# This is not for optimization, but rather to avoid conflicts with local 

93# variables named "message". 

94_DecodeError = message.DecodeError 

95 

96 

97def _VarintDecoder(mask, result_type): 

98 """Return an encoder for a basic varint value (does not include tag). 

99 

100 Decoded values will be bitwise-anded with the given mask before being 

101 returned, e.g. to limit them to 32 bits. The returned decoder does not 

102 take the usual "end" parameter -- the caller is expected to do bounds checking 

103 after the fact (often the caller can defer such checking until later). The 

104 decoder returns a (value, new_pos) pair. 

105 """ 

106 

107 def DecodeVarint(buffer, pos): 

108 result = 0 

109 shift = 0 

110 while 1: 

111 b = buffer[pos] 

112 result |= ((b & 0x7f) << shift) 

113 pos += 1 

114 if not (b & 0x80): 

115 result &= mask 

116 result = result_type(result) 

117 return (result, pos) 

118 shift += 7 

119 if shift >= 64: 

120 raise _DecodeError('Too many bytes when decoding varint.') 

121 return DecodeVarint 

122 

123 

124def _SignedVarintDecoder(bits, result_type): 

125 """Like _VarintDecoder() but decodes signed values.""" 

126 

127 signbit = 1 << (bits - 1) 

128 mask = (1 << bits) - 1 

129 

130 def DecodeVarint(buffer, pos): 

131 result = 0 

132 shift = 0 

133 while 1: 

134 b = buffer[pos] 

135 result |= ((b & 0x7f) << shift) 

136 pos += 1 

137 if not (b & 0x80): 

138 result &= mask 

139 result = (result ^ signbit) - signbit 

140 result = result_type(result) 

141 return (result, pos) 

142 shift += 7 

143 if shift >= 64: 

144 raise _DecodeError('Too many bytes when decoding varint.') 

145 return DecodeVarint 

146 

# Every decoded 32-bit and 64-bit integer is a plain Python int.
# 64-bit decoders: keep all 64 bits (unsigned) or sign-extend them.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# 32-bit decoders: values are truncated to 32 bits before conversion.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)

154 

155 

def ReadTag(buffer, pos):
  """Read a tag from the memoryview and return ``(tag_bytes, new_pos)``.

  The tag is returned as raw bytes rather than a decoded number. Those bytes
  are used directly as a lookup key for the field's decoder. In pure Python,
  hashing a short byte string (done in C) is cheaper than decoding a varint
  in interpreted code. A lower-level language would decode instead.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_start = pos
  # Varint bytes with the high bit set are followed by more bytes; the tag
  # ends with the first byte whose high bit is clear.
  while True:
    byte = buffer[pos]
    pos += 1
    if not byte & 0x80:
      break
  return buffer[tag_start:pos].tobytes(), pos

180 

181 

182# -------------------------------------------------------------------- 

183 

184 

185def _SimpleDecoder(wire_type, decode_value): 

186 """Return a constructor for a decoder for fields of a particular type. 

187 

188 Args: 

189 wire_type: The field's wire type. 

190 decode_value: A function which decodes an individual value, e.g. 

191 _DecodeVarint() 

192 """ 

193 

194 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default, 

195 clear_if_default=False): 

196 if is_packed: 

197 local_DecodeVarint = _DecodeVarint 

198 def DecodePackedField(buffer, pos, end, message, field_dict): 

199 value = field_dict.get(key) 

200 if value is None: 

201 value = field_dict.setdefault(key, new_default(message)) 

202 (endpoint, pos) = local_DecodeVarint(buffer, pos) 

203 endpoint += pos 

204 if endpoint > end: 

205 raise _DecodeError('Truncated message.') 

206 while pos < endpoint: 

207 (element, pos) = decode_value(buffer, pos) 

208 value.append(element) 

209 if pos > endpoint: 

210 del value[-1] # Discard corrupt value. 

211 raise _DecodeError('Packed element was truncated.') 

212 return pos 

213 return DecodePackedField 

214 elif is_repeated: 

215 tag_bytes = encoder.TagBytes(field_number, wire_type) 

216 tag_len = len(tag_bytes) 

217 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 

218 value = field_dict.get(key) 

219 if value is None: 

220 value = field_dict.setdefault(key, new_default(message)) 

221 while 1: 

222 (element, new_pos) = decode_value(buffer, pos) 

223 value.append(element) 

224 # Predict that the next tag is another copy of the same repeated 

225 # field. 

226 pos = new_pos + tag_len 

227 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 

228 # Prediction failed. Return. 

229 if new_pos > end: 

230 raise _DecodeError('Truncated message.') 

231 return new_pos 

232 return DecodeRepeatedField 

233 else: 

234 def DecodeField(buffer, pos, end, message, field_dict): 

235 (new_value, pos) = decode_value(buffer, pos) 

236 if pos > end: 

237 raise _DecodeError('Truncated message.') 

238 if clear_if_default and not new_value: 

239 field_dict.pop(key, None) 

240 else: 

241 field_dict[key] = new_value 

242 return pos 

243 return DecodeField 

244 

245 return SpecificDecoder 

246 

247 

def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder(), but post-processes every decoded value.

  Each value produced by ``decode_value`` is passed through ``modify_value``
  before being stored (for example ``wire_format.ZigZagDecode`` or ``bool``).

  Args:
    wire_type: The field's wire type.
    decode_value: Callable ``(buffer, pos) -> (value, new_pos)``.
    modify_value: Callable applied to each decoded value.

  Returns:
    A decoder constructor, as returned by _SimpleDecoder().
  """
  # Delegating to _SimpleDecoder through a wrapper is slightly slower than
  # duplicating its code, but not enough to matter.

  def InnerDecode(buffer, pos):
    raw_value, next_pos = decode_value(buffer, pos)
    converted = modify_value(raw_value)
    return (converted, next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)

260 

261 

def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The struct format string describing one value (e.g. '<I').

  Returns:
    A decoder constructor, as returned by _SimpleDecoder().
  """
  # Precompile the format once; its size is the byte width of one value.
  packer = struct.Struct(format)
  width = packer.size
  unpack = packer.unpack

  # A struct.error from a short read is deliberately not caught here: code
  # further up the stack converts it to _DecodeError, which keeps the
  # per-value path free of exception-handling setup.

  def InnerDecode(buffer, pos):
    stop = pos + width
    decoded = unpack(buffer[pos:stop])[0]
    return (decoded, stop)

  return _SimpleDecoder(wire_type, InnerDecode)

285 

286 

def _FloatDecoder():
  """Return a decoder constructor for 32-bit float fields.

  Non-finite values are recognized from their bit pattern and mapped
  directly to math.nan, math.inf or -math.inf rather than going through
  struct.unpack. This works around historical struct.unpack behaviour (seen
  in Python 2.4) that turned non-finite 32-bit values into finite ones.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode one little-endian IEEE 754 single at ``pos``.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, offset of the first of the four value bytes.

    Returns:
      Tuple[float, int] of the decoded value and the offset just past it.
    """
    stop = pos + 4
    raw = buffer[pos:stop].tobytes()

    # Little-endian layout: raw[3] holds the sign bit and the top 7 exponent
    # bits; bit 0x80 of raw[2] is the lowest exponent bit, and the rest of
    # raw[2] plus raw[0:2] form the significand.
    high_byte = raw[3:4]
    exponent_all_ones = high_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if exponent_all_ones:
      if raw[0:3] != b'\x00\x00\x80':
        # A non-zero significand means not-a-number.
        return (math.nan, stop)
      # Zero significand: an infinity whose sign comes from the sign bit.
      if high_byte == b'\xFF':
        return (-math.inf, stop)
      return (math.inf, stop)

    # A struct.error raised for a short buffer is left for code further up
    # the stack to convert into _DecodeError.
    return (unpack('<f', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)

330 

331 

def _DoubleDecoder():
  """Return a decoder constructor for 64-bit double fields.

  NaN bit patterns are detected explicitly and decoded as math.nan. This
  works around historical struct.unpack behaviour (seen in Python 2.4) that
  treated them as inf or -inf.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode one little-endian IEEE 754 double at ``pos``.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, offset of the first of the eight value bytes.

    Returns:
      Tuple[float, int] of the decoded value and the offset just past it.
    """
    stop = pos + 8
    raw = buffer[pos:stop].tobytes()

    # raw[7] holds the sign bit and the top 7 exponent bits; the high nibble
    # of raw[6] holds the remaining 4 exponent bits. All exponent bits set
    # together with any significand bit set means NaN.
    is_nan = (raw[7:8] in b'\x7F\xFF'
              and raw[6:7] >= b'\xF0'
              and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')
    if is_nan:
      return (math.nan, stop)

    # A struct.error raised for a short buffer is left for code further up
    # the stack to convert into _DecodeError.
    return (unpack('<d', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)

370 

371 

def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values found in ``key.enum_type.values_by_number`` are stored in the
  field. Any other value is recorded in the message's unknown fields
  (``message._unknown_fields`` and ``message._unknown_field_set``) under wire
  type VARINT instead.

  Args:
    field_number: The enum field's number.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the repeated field uses packed encoding.
    key: The field descriptor; used as the ``field_dict`` key and for its
      ``enum_type``.
    new_default: Callable taking the message and returning a new repeated
      container (used for repeated fields).
    clear_if_default: For singular fields, remove the key instead of storing
      a zero value.

  Returns:
    A decoder ``(buffer, pos, end, message, field_dict) -> new_pos``.
  """
  enum_type = key.enum_type
  # Unknown values are always recorded under the plain varint tag, even when
  # they arrive inside a packed payload. Computed once instead of once per
  # unknown value.
  varint_tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  def _RecordUnknown(message, raw_bytes, enum_value):
    """Append one unrecognized enum value to the message's unknown fields."""
    # pylint: disable=protected-access
    if not message._unknown_fields:
      message._unknown_fields = []
    message._unknown_fields.append((varint_tag_bytes, raw_bytes))
    if message._unknown_field_set is None:
      message._unknown_field_set = containers.UnknownFieldSet()
    message._unknown_field_set._add(
        field_number, wire_format.WIRETYPE_VARINT, enum_value)
    # pylint: enable=protected-access

  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _RecordUnknown(message, buffer[value_start_pos:pos].tobytes(),
                         element)
      if pos > endpoint:
        # The last element overran the packed payload: undo whatever was
        # recorded for it before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          # pylint: disable=protected-access
          del message._unknown_fields[-1]
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_bytes = varint_tag_bytes
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _RecordUnknown(message, buffer[pos:new_pos].tobytes(), element)
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed. Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode a serialized singular enum and return the new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        _RecordUnknown(message, buffer[value_start_pos:pos].tobytes(),
                       enum_value)
      return pos
    return DecodeField

509 

510 

511# -------------------------------------------------------------------- 

512 

513 

# Varint-encoded integer fields. The signed decoders sign-extend; the
# unsigned ones mask to the field width.
Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint)

# sint32/sint64 values are ZigZag-decoded after reading the varint.
SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint32,
                                 wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint,
                                 wire_format.ZigZagDecode)

# Fixed-width fields. With the '<' prefix, struct formats use little-endian
# byte order and standard sizes, so they are the same on every platform
# (without it, sizes would follow the C compiler's native type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# A varint whose (64-bit masked) value is non-zero decodes as True.
BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint,
                               bool)

541 

542 

def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Return a decoder for a length-delimited UTF-8 string field.

  Args:
    field_number: The field's number.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; strings cannot be packed.
    key: The field descriptor; used as the ``field_dict`` key and, on UTF-8
      errors, for its ``full_name``.
    new_default: Callable taking the message and returning a new repeated
      container (used for repeated fields).
    clear_if_default: For singular fields, remove the key instead of storing
      an empty string.

  Returns:
    A decoder ``(buffer, pos, end, message, field_dict) -> new_pos``.
  """

  read_length = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Decode ``memview`` as UTF-8, naming the field if decoding fails."""
    raw = memview.tobytes()
    try:
      return str(raw, 'utf-8')
    except UnicodeDecodeError as e:
      # Enrich the error with the field name, then re-raise the same object.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = read_length(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:stop])
      return stop

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_length(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(_ConvertToUnicode(buffer[pos:stop]))
      # Keep going only while the same field's tag immediately follows.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop

  return DecodeRepeatedField

594 

595 

def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Return a decoder for a length-delimited ``bytes`` field.

  Args:
    field_number: The field's number.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; bytes fields cannot be packed.
    key: Key used to store the value in ``field_dict``.
    new_default: Callable taking the message and returning a new repeated
      container (used for repeated fields).
    clear_if_default: For singular fields, remove the key instead of storing
      an empty value.

  Returns:
    A decoder ``(buffer, pos, end, message, field_dict) -> new_pos``.
  """

  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = read_length(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[pos:stop].tobytes()
      return stop

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_length(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(buffer[pos:stop].tobytes())
      # Keep going only while the same field's tag immediately follows.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop

  return DecodeRepeatedField

635 

636 

def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  The group body is parsed from the enclosing buffer by the sub-message's
  _InternalParse. The decoder then requires the END_GROUP tag for
  ``field_number`` to follow, within ``end``.

  Args:
    field_number: The group's field number.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; groups cannot be packed.
    key: Key used to look the field up in ``field_dict``.
    new_default: Callable taking the parent message and returning a new
      default value (a repeated container or a sub-message).

  Returns:
    A decoder ``(buffer, pos, end, message, field_dict) -> new_pos``.

  Raises (from the returned decoder):
    _DecodeError: if the group end tag is missing or lies beyond ``end``.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Fetch (or create) the repeated container once; each loop iteration
      # only appends a new element to it.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed. Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField

682 

683 

def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Return a decoder for a length-delimited sub-message field.

  Each element is a varint byte length followed by that many bytes, which
  are parsed into a sub-message by its _InternalParse. Parsing must consume
  exactly the declared length. _InternalParse only returns early when it
  meets an end-group tag, so stopping short is reported as one.

  Args:
    field_number: The field's number.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; messages cannot be packed.
    key: Key used to look the field up in ``field_dict``.
    new_default: Callable taking the parent message and returning a new
      default value (a repeated container or a sub-message).

  Returns:
    A decoder ``(buffer, pos, end, message, field_dict) -> new_pos``.
  """

  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      sub_message = field_dict.get(key)
      if sub_message is None:
        sub_message = field_dict.setdefault(key, new_default(message))
      (size, pos) = read_length(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      if sub_message._InternalParse(buffer, pos, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')
      return stop

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_length(buffer, pos)
      stop = pos + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      if container.add()._InternalParse(buffer, pos, stop) != stop:
        raise _DecodeError('Unexpected end-group tag.')
      # Keep going only while the same field's tag immediately follows.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        return stop

  return DecodeRepeatedField

732 

733 

734# -------------------------------------------------------------------- 

735 

# Tag bytes that open a MessageSet item: field 1 with the START_GROUP type.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(
    1, wire_format.WIRETYPE_START_GROUP)

737 

def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor (the returned decoder does not
  use it).

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Items whose type_id matches a known extension are parsed into that
  extension's message. Other items are kept verbatim in the message's
  unknown fields.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    message_set_item_start = pos
    # -1 marks "not seen yet".
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only remember where the payload is; it is parsed once the
        # type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not imported at the top of this module, so it
          # is imported here, at the point of use.
          from google.protobuf import message_factory  # pylint: disable=g-import-not-at-top
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          type_id,
          wire_format.WIRETYPE_LENGTH_DELIMITED,
          buffer[message_start:message_end].tobytes())
      # pylint: enable=protected-access

    return pos

  return DecodeItem

832 

833 

def UnknownMessageSetItemDecoder():
  """Returns a decoder for an unknown MessageSet item.

  A MessageSet item is a group whose fields are a varint type id (field 2)
  and a length-delimited message payload (field 3), closed by an end-group
  tag for field 1.

  Returns:
    A function ``DecodeUnknownItem(buffer)`` returning
    ``(type_id, message_bytes)``.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    """Decodes a single MessageSet item.

    Args:
      buffer: buffer holding the item's fields starting at offset 0 and
        terminated by the item's end-group tag. Slices are converted with
        ``.tobytes()``, so a memoryview is expected.

    Returns:
      (type_id, message_bytes) tuple.

    Raises:
      _DecodeError: if an unexpected end-group tag is found, the data is
        truncated, or the type_id or message field is missing.
    """
    pos = 0
    end = len(buffer)
    # -1 marks "field not seen yet". type_id needs this sentinel as well:
    # without it, an item lacking a type_id raised UnboundLocalError at the
    # check below instead of the intended _DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    # type_id and message may appear in any order, so loop until the item's
    # end-group tag.
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          # SkipField hit an end-group tag that does not close this item.
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem

871 

872# -------------------------------------------------------------------- 

873 

def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  The returned decoder reads a run of consecutive length-delimited map-entry
  sub-messages that share this field's tag. It stores each entry's key/value
  into the container kept in ``field_dict`` under ``field_descriptor``.

  Args:
    field_descriptor: descriptor of the map field; also the key used in
      ``field_dict``.
    new_default: callable taking the parent message and returning the empty
      map container to populate.
    is_message_map: if true, values are copied with ``CopyFrom``; otherwise
      they are assigned directly.

  Returns:
    A function ``DecodeMap(buffer, pos, end, message, field_dict)`` returning
    the position just after the last entry it consumed.
  """

  field_tag = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  field_tag_len = len(field_tag)
  decode_varint = _DecodeVarint
  # Only the descriptor is captured here: its _concrete_class may not be
  # initialized yet, so it is resolved lazily inside the decoder.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    # A single scratch entry message is reused (Cleared) for every entry.
    entry = entry_type._concrete_class()
    container = field_dict.get(field_descriptor)
    if container is None:
      container = field_dict.setdefault(field_descriptor, new_default(message))
    while True:
      # Every entry is prefixed with its byte length.
      entry_size, entry_start = decode_varint(buffer, pos)
      entry_end = entry_start + entry_size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      entry.Clear()
      if entry._InternalParse(buffer, entry_start, entry_end) != entry_end:
        # _InternalParse only stops short when it meets an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Keep going only while the following bytes repeat this field's tag.
      next_pos = entry_end + field_tag_len
      if entry_end == end or buffer[entry_end:next_pos] != field_tag:
        return entry_end
      pos = next_pos

  return DecodeMap

915 

916# -------------------------------------------------------------------- 

917# Optimization is not as heavy here because calls to SkipField() are rare, 

918# except for handling end-group tags. 

919 

920def _SkipVarint(buffer, pos, end): 

921 """Skip a varint value. Returns the new position.""" 

922 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. 

923 # With this code, ord(b'') raises TypeError. Both are handled in 

924 # python_message.py to generate a 'Truncated message' error. 

925 while ord(buffer[pos:pos+1].tobytes()) & 0x80: 

926 pos += 1 

927 pos += 1 

928 if pos > end: 

929 raise _DecodeError('Truncated message.') 

930 return pos 

931 

932def _SkipFixed64(buffer, pos, end): 

933 """Skip a fixed64 value. Returns the new position.""" 

934 

935 pos += 8 

936 if pos > end: 

937 raise _DecodeError('Truncated message.') 

938 return pos 

939 

940 

941def _DecodeFixed64(buffer, pos): 

942 """Decode a fixed64.""" 

943 new_pos = pos + 8 

944 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 

945 

946 

def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value. Returns the new position.

  Reads the varint length prefix, then advances past that many payload bytes.

  Raises:
    _DecodeError: if the payload would extend past ``end``.
  """
  payload_len, payload_start = _DecodeVarint(buffer, pos)
  payload_end = payload_start + payload_len
  if payload_end <= end:
    return payload_end
  raise _DecodeError('Truncated message.')

955 

956 

def _SkipGroup(buffer, pos, end):
  """Skip sub-group. Returns the new position.

  Fields are skipped one at a time until SkipField reports an end-group tag
  (by returning -1). The returned position is just past that end-group tag.

  NOTE(review): the end-group tag's field number is not compared with the
  group's start tag. This function also does not check the returned position
  against ``end``; callers appear to do that bounds check.
  """
  cursor = pos
  while True:
    tag_bytes, after_tag = ReadTag(buffer, cursor)
    cursor = SkipField(buffer, after_tag, end, tag_bytes)
    if cursor == -1:
      return after_tag

966 

967 

def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.

  Args:
    buffer: serialized data.
    pos: position of the first tag to decode.
    end_pos: position at which to stop. If None, decoding runs until an
      END_GROUP tag is read, as when decoding the body of a group.

  Returns:
    (unknown_field_set, pos) tuple. If an END_GROUP tag stopped decoding,
    pos points just past that tag.
  """
  field_set = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    tag_bytes, pos = ReadTag(buffer, pos)
    tag, _ = _DecodeVarint(tag_bytes, 0)
    field_number, wire_type = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    data, pos = _DecodeUnknownField(buffer, pos, wire_type)
    field_set._add(field_number, wire_type, data)  # pylint: disable=protected-access

  return (field_set, pos)

983 

984 

def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode a unknown field. Returns the field's data and new position.

  Args:
    buffer: serialized data; length-delimited slices are converted with
      ``.tobytes()``, so a memoryview is expected.
    pos: position just after the field's tag.
    wire_type: wire type taken from the field's tag.

  Returns:
    (data, new_pos) tuple. data is an int for varint and fixed wire types,
    bytes for length-delimited fields, and an UnknownFieldSet for groups.
    An END_GROUP wire type yields (0, -1).

  Raises:
    _DecodeError: on an unrecognized wire type.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    value, new_pos = _DecodeVarint(buffer, pos)
    return (value, new_pos)
  if wire_type == wire_format.WIRETYPE_FIXED64:
    value, new_pos = _DecodeFixed64(buffer, pos)
    return (value, new_pos)
  if wire_type == wire_format.WIRETYPE_FIXED32:
    value, new_pos = _DecodeFixed32(buffer, pos)
    return (value, new_pos)
  if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    size, start = _DecodeVarint(buffer, pos)
    stop = start + size
    return (buffer[start:stop].tobytes(), stop)
  if wire_type == wire_format.WIRETYPE_START_GROUP:
    group, new_pos = _DecodeUnknownFieldSet(buffer, pos)
    return (group, new_pos)
  if wire_type == wire_format.WIRETYPE_END_GROUP:
    # -1 tells the caller that an end-group tag was reached.
    return (0, -1)
  raise _DecodeError('Wrong wire type in tag.')

1006 

1007 

1008def _EndGroup(buffer, pos, end): 

1009 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" 

1010 

1011 return -1 

1012 

1013 

1014def _SkipFixed32(buffer, pos, end): 

1015 """Skip a fixed32 value. Returns the new position.""" 

1016 

1017 pos += 4 

1018 if pos > end: 

1019 raise _DecodeError('Truncated message.') 

1020 return pos 

1021 

1022 

1023def _DecodeFixed32(buffer, pos): 

1024 """Decode a fixed32.""" 

1025 

1026 new_pos = pos + 4 

1027 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 

1028 

1029 

def _RaiseInvalidWireType(buffer, pos, end):
  """Skip handler for wire types that have no defined skipper.

  Raises:
    _DecodeError: always.
  """
  del buffer, pos, end  # Unused: the tag itself is already invalid.
  message = 'Tag had invalid wire type.'
  raise _DecodeError(message)

1034 

def _FieldSkipper():
  """Constructs the SkipField function.

  Returns:
    A function ``SkipField(buffer, pos, end, tag_bytes)`` that dispatches to
    a skip handler chosen by the wire type in the tag.
  """

  # Skip handlers indexed by wire type value.
  skippers = (
      _SkipVarint,            # 0
      _SkipFixed64,           # 1
      _SkipLengthDelimited,   # 2
      _SkipGroup,             # 3
      _EndGroup,              # 4
      _SkipFixed32,           # 5
      _RaiseInvalidWireType,  # 6
      _RaiseInvalidWireType,  # 7
  )
  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
      The new position (after the tag value), or -1 if the tag is an end-group
      tag (in which case the calling loop should break).
    """
    # Varints are little-endian, so the wire type bits sit in the first byte.
    handler = skippers[ord(tag_bytes[0:1]) & type_mask]
    return handler(buffer, pos, end)

  return SkipField


SkipField = _FieldSkipper()