Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/python_message.py: 11%

736 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:37 +0000

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# https://developers.google.com/protocol-buffers/ 

4# 

5# Redistribution and use in source and binary forms, with or without 

6# modification, are permitted provided that the following conditions are 

7# met: 

8# 

9# * Redistributions of source code must retain the above copyright 

10# notice, this list of conditions and the following disclaimer. 

11# * Redistributions in binary form must reproduce the above 

12# copyright notice, this list of conditions and the following disclaimer 

13# in the documentation and/or other materials provided with the 

14# distribution. 

15# * Neither the name of Google Inc. nor the names of its 

16# contributors may be used to endorse or promote products derived from 

17# this software without specific prior written permission. 

18# 

19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 

20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 

21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 

22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 

23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 

24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 

25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 

26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 

27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 

28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

# This code requires Python 3 (for example, it uses BaseException.with_traceback).

32# 

33# TODO(robinson): Helpers for verbose, common checks like seeing if a 

34# descriptor's cpp_type is CPPTYPE_MESSAGE. 

35 

36"""Contains a metaclass and helper functions used to create 

37protocol message classes from Descriptor objects at runtime. 

38 

39Recall that a metaclass is the "type" of a class. 

40(A class is to a metaclass what an instance is to a class.) 

41 

42In this case, we use the GeneratedProtocolMessageType metaclass 

43to inject all the useful functionality into the classes 

44output by the protocol compiler at compile-time. 

45 

46The upshot of all this is that the real implementation 

47details for ALL pure-Python protocol buffers are *here in 

48this file*. 

49""" 

50 

# Module author metadata.
__author__ = 'robinson@google.com (Will Robinson)'

52 

53from io import BytesIO 

54import struct 

55import sys 

56import weakref 

57 

58# We use "as" to avoid name collisions with variables. 

59from google.protobuf.internal import api_implementation 

60from google.protobuf.internal import containers 

61from google.protobuf.internal import decoder 

62from google.protobuf.internal import encoder 

63from google.protobuf.internal import enum_type_wrapper 

64from google.protobuf.internal import extension_dict 

65from google.protobuf.internal import message_listener as message_listener_mod 

66from google.protobuf.internal import type_checkers 

67from google.protobuf.internal import well_known_types 

68from google.protobuf.internal import wire_format 

69from google.protobuf import descriptor as descriptor_mod 

70from google.protobuf import message as message_mod 

71from google.protobuf import text_format 

72 

# Short module-level aliases for names imported from other protobuf modules.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Full name of the google.protobuf.Any well-known message type.
_AnyFullTypeName = 'google.protobuf.Any'
_ExtensionDict = extension_dict._ExtensionDict

76 

class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  # Name of the class-namespace entry that holds the message Descriptor.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class (required by the metaclass protocol; forwarded
        unchanged to type.__new__).
      bases: Base classes of the class we're constructing
        (should be message.Message).  For well-known types, a helper base
        from well_known_types.WKTBASES is appended before class creation.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.  Nested-extension attributes and '__slots__' are added to it.

    Returns:
      Newly-allocated class, or the class already recorded in the
      descriptor's _concrete_class attribute.

    Raises:
      RuntimeError: if dictionary[_DESCRIPTOR_KEY] is a str, i.e. the
        generated code only works with the python cpp extension.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # Per the error message below, a str in place of a Descriptor means the
    # generated module targets the C++ extension, not this pure-Python runtime.
    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    # Well-known types get an extra helper base class mixed in.
    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    # Both helpers write into |dictionary|, which type.__new__ turns into the
    # class namespace; __slots__ only takes effect if present at creation.
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class (required by the metaclass protocol; forwarded
        to type.__init__).
      bases: Base classes of the class we're constructing
        (should be message.Message).  Forwarded to type.__init__.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    # Maps encoded tag bytes -> (decoder, field descriptor or None); populated
    # below and by _AttachFieldHelpers.
    cls._decoders_by_tag = {}
    # Messages using the MessageSet wire format also need a decoder for
    # MessageSet items.
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    # Extensions the descriptor's pool already knows about get helpers too,
    # but only when the file actually exposes a pool.
    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      extensions = descriptor.file.pool.FindAllExtensions(descriptor)
      for ext in extensions:
        _AttachFieldHelpers(cls, ext)

    # Record the class on the descriptor; __new__ uses this to hand back the
    # same class instead of building a duplicate.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)

211 

212 

213# Stateless helpers for GeneratedProtocolMessageType below. 

214# Outside clients should not access these directly. 

215# 

216# I opted not to make any of these methods on the metaclass, to make it more 

217# clear that I'm not really using any state there and to keep clients from 

218# thinking that they have direct access to these construction helpers. 

219 

220 

def _PropertyName(proto_field_name):
  """Maps a .proto field name to the name of its public Python property.

  The mapping is the identity: the property carries exactly the name the
  field has (or would have) in the .proto file.

  Args:
    proto_field_name: The field name as written in the .proto file.

  Returns:
    The attribute name clients use to get (and, for some fields, set) the
    field's value.
  """
  # Python keywords (e.g. "yield") are deliberately NOT escaped.  Callers rely
  # on getattr()/setattr() with the raw proto field name for reflection, and
  # renaming properties would force each of them to apply the same rename.
  # A field named after a keyword remains reachable through getattr/setattr.
  # TODO(kenton): Remove this method entirely if/when everyone agrees with my
  # position.
  property_name = proto_field_name
  return property_name

248 

249 

def _AddSlots(message_descriptor, dictionary):
  """Stores the fixed list of instance attribute names as '__slots__'.

  Every generated message type uses the same slot list, so
  |message_descriptor| is not consulted.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class namespace dict; its '__slots__' entry is (re)assigned.
  """
  slot_names = (
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_unknown_field_set',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  )
  dictionary['__slots__'] = list(slot_names)

268 

269 

def _IsMessageSetExtension(field):
  """Returns whether |field| is an optional message extension of a MessageSet.

  Args:
    field: A FieldDescriptor.

  Returns:
    True only if the field is an extension, its extendee uses the MessageSet
    wire format, and the field is an optional message-typed field.
  """
  if not field.is_extension:
    return False
  extendee = field.containing_type
  return (extendee.has_options and
          extendee.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)

276 

277 

def _IsMapField(field):
  """Returns whether |field| is a map field.

  A map field is message-typed, and its message type sets the map_entry
  option.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry

282 

283 

def _IsMessageMapField(field):
  """Returns whether a map field's entry type has a message-typed 'value'.

  Args:
    field: FieldDescriptor of a map field; its message_type must define a
      field named 'value'.
  """
  entry_fields = field.message_type.fields_by_name
  return (entry_fields['value'].cpp_type ==
          _FieldDescriptor.CPPTYPE_MESSAGE)

287 

288 

def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches serialization helpers for one field to its descriptor and |cls|.

  Sets field_descriptor._encoder, ._sizer and ._default_constructor.  It then
  registers a decoder in cls._decoders_by_tag, keyed by encoded tag bytes.
  Packable repeated fields get a second decoder for the packed encoding.

  Args:
    cls: The message class being built; must already have a
      _decoders_by_tag dict.
    field_descriptor: FieldDescriptor of a regular field or an extension.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_map_entry = _IsMapField(field_descriptor)
  is_packed = field_descriptor.is_packed

  # Pick the encoder/sizer pair.  Map fields and MessageSet extensions have
  # dedicated implementations; everything else is looked up by field type.
  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  # Must be set before AddDecoder runs: the decoders built there are handed
  # field_descriptor._default_constructor.
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  # The |is_packed| parameter deliberately shadows the outer variable.  The
  # calls below register the field once with its natural wire type
  # (is_packed=False) and, for packable repeated fields, once more with the
  # length-delimited wire type (is_packed=True).
  def AddDecoder(wiretype, is_packed):
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    # Open (not closed) enums are decoded as plain int32 values.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    # NOTE: despite its name, this holds the field's own descriptor (not a
    # OneofDescriptor).  It is non-None only for fields inside a oneof.
    oneof_descriptor = None
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      # Map decoders receive the factory that builds the map container.
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      # The trailing flag is True for fields without presence.  How the
      # decoder uses it is defined in the decoder module (verify there).
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Message-typed decoders are not passed the has_presence flag.
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)

354 

355 

def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside |descriptor| as a class attribute.

  Args:
    descriptor: Descriptor of the message type being created.
    dictionary: Class namespace dict, updated in place with one
      extension-name -> extension FieldDescriptor entry per nested extension.
  """
  for name, ext_field in descriptor.extensions_by_name.items():
    # A nested extension must never clobber an existing namespace entry.
    assert name not in dictionary
    dictionary[name] = ext_field

361 

362 

def _AddEnumValues(descriptor, cls):
  """Publishes the message's nested enums on |cls|.

  For each nested enum type, sets a class attribute named after the enum and
  holding an EnumTypeWrapper, which can name the enum's values.  It also sets
  one integer class attribute per enum value, named after the value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_desc in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_desc)
    setattr(cls, enum_desc.name, wrapper)
    for value_desc in enum_desc.values:
      setattr(cls, value_desc.name, value_desc.number)

376 

377 

def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for a map field.

  Args:
    field: FieldDescriptor of a map field; it must be repeated.

  Returns:
    A callable taking the owning message and returning a new
    containers.MessageMap (message-valued maps) or containers.ScalarMap.

  Raises:
    ValueError: if |field| is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    # Scalar-valued map: values are type-checked as well as keys.
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault

399 

def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if a repeated field declares a non-empty default value, or a
      map field is not repeated (raised by _GetInitializeDefaultForMap).
  """

  # Map fields get a dedicated map-container factory.
  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # field.message_type is therefore only read when a container is built.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      # One type checker is resolved here and shared by every container.
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized.
      if not hasattr(message_type, '_concrete_class'):
        # Imported lazily.  Building the class is presumably what sets
        # message_type._concrete_class, which is read right after.
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Fields inside a oneof get an _OneofListener rather than the parent's
      # plain child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault

456 

457 

def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception currently being handled, naming the field.

  Must be called while an exception is being handled.  A plain TypeError
  with a single argument is replaced by a new TypeError whose text is
  suffixed with " for field <message_name>.<field_name>".  Any other
  exception (a TypeError subclass, or one with several args) is re-raised
  unchanged.  Either way, the original traceback is kept.

  Args:
    message_name: Name of the message type, used in the amended text.
    field_name: Name of the field being set, used in the amended text.
  """
  _, active, tb = sys.exc_info()
  if len(active.args) == 1 and type(active) is TypeError:
    # Simple TypeError: enrich its message with the field's location.
    active = TypeError(
        '%s for field %s.%s' % (str(active), message_name, field_name))
  raise active.with_traceback(tb)

467 

468 

def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets per-instance state and then initializes
  fields from keyword arguments (field name -> value).

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: if value is a str that names no value of enum_type.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    # Reset the per-instance attributes.
    self._cached_byte_size = 0
    # Marked dirty whenever keyword arguments were passed (even all-None ones).
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None      # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName raises ValueError for unknown names and
      # never returns None.  This TypeError branch therefore appears
      # unreachable, and unknown keywords surface as ValueError.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        # Build the field's empty container and fill it from field_value.
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              # Message-valued map: merge each value into the entry that
              # indexing the container yields for that key.
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            # Repeated messages accept dicts (used as constructor kwargs)
            # or message instances (merged into a newly added element).
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          # Enum values may be given by label; convert them to numbers.
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        # Singular sub-message: a dict is first turned into a message of the
        # field's concrete class, then merged into a fresh default instance.
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        # Singular scalar: go through the generated property setter.
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init

554 

555 

def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor of |message_descriptor| by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: if the message type has no field with that name.
  """
  by_name = message_descriptor.fields_by_name
  try:
    found = by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found

570 

571 

def _AddPropertiesForFields(descriptor, cls):
  """Adds one property per field, plus an Extensions property if extendable.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """
  for field_desc in descriptor.fields:
    _AddPropertiesForField(field_desc, cls)

  if not descriptor.is_extendable:
    return
  # _ExtensionDict is a stateless adaptor, so a fresh one is created on every
  # access instead of being cached.
  cls.Extensions = property(lambda self: _ExtensionDict(self))

581 

582 

def _AddPropertiesForField(field, cls):
  """Adds the public property and field-number constant for one field.

  Sets <FIELD_NAME>_FIELD_NUMBER on |cls|, then installs a property whose
  kind depends on the field: repeated, singular composite, or singular
  scalar.  Only non-repeated scalar properties accept assignment.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)

606 

607 

# A property that also records the FieldDescriptor of the field it exposes,
# as the DESCRIPTOR attribute.  (Intentionally no class docstring: for
# property subclasses, a class-level __doc__ would change what instances
# report as __doc__.)
class _FieldProperty(property):
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    super().__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor

614 

615 

def _AddPropertiesForRepeatedField(field, cls):
  """Installs the read-only public property for a repeated field on |cls|.

  Reading the property returns the field's container: a
  RepeatedScalarFieldContainer or a RepeatedCompositeFieldContainer.  The
  container is built on first access by the field's default constructor.
  When clients add values, repeated scalars are type-checked and any
  necessary "has" bits are set as a side effect.  Assigning to the property
  raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    # Build a container, then publish it with setdefault().  If another
    # thread stored one first, we return that and drop ours.
    # WARNING: this relies on dict.setdefault() being atomic, which holds in
    # CPython but has not been investigated for other implementations.
    fresh = field._default_constructor(self)
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to fail with a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)

658 

659 

def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  # Resolved once per field and captured by the setter closure below.
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # Unset fields read as the descriptor's default_value.
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    # Validate via the type checker; its return value replaces new_value.
    # Type errors are re-raised naming the field, with the offending value's
    # repr truncated to 1024 characters.
    # Testing the value for truthiness (below) captures all of the proto3
    # defaults (0, 0.0, enum 0, and False, as well as empty strings/bytes).
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Fields without presence do not store falsy (default) values; setting
    # one removes the stored entry instead.
    if not field.has_presence and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  # Fields inside a oneof also call self._UpdateOneofState(field) after the
  # value has been stored.
  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))

714 

715 

def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Installs the read-only public property for a singular composite field.

  A composite field is a "group" or "message" field.  Reading the property
  returns the sub-message, creating it on first access with the field's
  default constructor.  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is not None:
      return sub_message
    # Create the sub-message, then publish it via setdefault() so that an
    # object stored first by a racing thread wins and ours is discarded.
    # WARNING: relies on dict.setdefault() being atomic (true in CPython,
    # not investigated for other implementations).
    candidate = field._default_constructor(self)
    return self._fields.setdefault(field, candidate)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # Defined only so that assignment fails with a descriptive message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)

758 

759 

def _AddPropertiesForExtensions(descriptor, cls):
  """Adds <NAME>_FIELD_NUMBER class constants for nested extensions.

  For every extension declared inside this message type, sets a class
  attribute named after the upper-cased extension name plus '_FIELD_NUMBER'.
  Its value is the extension's field number.

  Nothing is read from descriptor.file here.  Files without a `pool`
  attribute, which class creation explicitly tolerates, therefore work too.

  Args:
    descriptor: Descriptor for the message type being built.
    cls: The class we're constructing.
  """
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)

772 

def _AddStaticMethods(cls):
  """Attaches the RegisterExtension and FromString static methods to cls."""

  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(field_descriptor):
    field_descriptor.containing_type = cls.DESCRIPTOR
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    owning_pool = cls.DESCRIPTOR.file.pool
    owning_pool._AddExtensionDescriptor(field_descriptor)  # pylint: disable=protected-access
    _AttachFieldHelpers(cls, field_descriptor)

  def FromString(s):
    # Build an empty instance and merge the serialized bytes into it.
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)

788 

789 

def _IsPresent(item):
  """Reports whether a (FieldDescriptor, value) pair from _fields is set.

  Decides which entries ListFields() returns: repeated fields count only when
  non-empty, sub-messages only when marked present in their parent, and any
  other stored value always counts.
  """
  field_desc = item[0]
  value = item[1]
  if field_desc.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field_desc.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True

800 

801 

def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ListFields."""

  def ListFields(self):
    # Only entries that count as set, ordered by ascending field number.
    present = filter(_IsPresent, self._fields.items())
    return sorted(present, key=lambda pair: pair[0].number)

  cls.ListFields = ListFields

811 

812 

def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.HasField.

  HasField() accepts the name of a singular field that tracks presence, or
  the name of a oneof; any other name raises ValueError.
  """
  # Singular fields with presence (for proto3 only submessages and fields
  # inside a oneof have it), keyed by name, then every oneof descriptor.
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED and field.has_presence
  }
  for oneof in message_descriptor.oneofs:
    hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is set iff its currently active member is set.
      try:
        active_member = self._oneofs[field]
        return HasField(self, active_member.name)
      except KeyError:
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField

850 

851 

def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ClearField.

  ClearField() accepts a field name or a oneof name (which clears the oneof's
  active member) and raises ValueError for anything else.
  """

  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))
      if oneof not in self._oneofs:
        # No member of this oneof is set: nothing to clear, nothing modified.
        return
      field = self._oneofs[oneof]

    if field in self._fields:
      current = self._fields[field]
      # To match the C++ implementation, clearing a map field invalidates
      # any iterators over it.
      if hasattr(current, 'InvalidateIterators'):
        current.InvalidateIterators()
      # A removed sub-message keeps its listener pointing at us; the worst it
      # can do is call _Modified() and invalidate our byte size.
      del self._fields[field]
      containing = field.containing_oneof
      if self._oneofs.get(containing, None) is field:
        del self._oneofs[containing]

    # ClearField() is a mutating call, so it always marks this message as
    # modified (and thus present in its parent), even if nothing changed.
    self._Modified()

  cls.ClearField = ClearField

888 

889 

def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.ClearExtension."""

  def ClearExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    # Mirrors ClearField(): drop any stored value, then always mark the
    # message as modified.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension

900 

901 

def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.HasExtension.

  Repeated extensions have no presence, so asking about one raises KeyError.
  """

  def HasExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields
    submessage = self._fields.get(field_descriptor)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension

915 

def _InternalUnpackAny(msg):
  """Unpacks an Any message into a message of its packed type, if known.

  This internal method is different from the public Any Unpack method, which
  takes the target message as an argument. _InternalUnpackAny has no target
  message, so it looks the packed type up by name in the default symbol
  database's descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None if msg has an empty type_url or its packed
    type cannot be found in the descriptor pool.
  """
  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url
  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname. Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The pool raises KeyError for unknown types rather than returning None.
    # Report "cannot unpack" so callers such as __eq__ fall back to comparing
    # the raw Any fields instead of propagating the KeyError.
    descriptor = None
  if descriptor is None:
    return None

  message = factory.GetPrototype(descriptor)()
  message.ParseFromString(msg.value)
  return message

954 

955 

def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__eq__."""

  def __eq__(self, other):
    # Only messages sharing the same descriptor can compare equal.
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False
    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      # Compare Any messages by their unpacked payloads when both unpack;
      # otherwise fall through to the field-wise comparison below.
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not (self.ListFields() == other.ListFields()):
      return False

    # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    # Unknown fields are compared irrespective of their order.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__

984 

985 

def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): makes str() render the text format."""

  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__

991 

992 

def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): makes repr() render the text format."""

  def __repr__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__

998 

999 

def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__unicode__."""

  def __unicode__(self):
    # Render with UTF-8 output enabled, then decode the result as UTF-8.
    encoded = text_format.MessageToString(self, as_utf8=True)
    return encoded.decode('utf-8')

  cls.__unicode__ = __unicode__

1006 

1007 

def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The size computed by the byte-size function registered for field_type.

  Raises:
    message_mod.EncodeError: if no byte-size function is registered for
      field_type.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError as exc:
    raise message_mod.EncodeError(
        'Unrecognized field type: %d' % field_type) from exc
  # Called outside the try so that a KeyError raised while sizing the value
  # is not misreported as an unrecognized field type.
  return fn(field_number, value)

1026 

1027 

def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ByteSize.

  The computed size is cached on the message and reused until the
  _cached_byte_size_dirty flag is set again.
  """

  def ByteSize(self):
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      total = (by_name['key']._sizer(self.key) +
               by_name['value']._sizer(self.value))
    else:
      total = sum(fd._sizer(value) for fd, value in self.ListFields())
      total += sum(len(tag) + len(payload)
                   for tag, payload in self._unknown_fields)

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize

1053 

1054 

def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.SerializeToString."""

  def SerializeToString(self, **kwargs):
    # Unlike SerializePartialToString(), refuse to serialize a message that
    # is missing required fields.
    if not self.IsInitialized():
      message_name = self.DESCRIPTOR.full_name
      missing = ','.join(self.FindInitializationErrors())
      raise message_mod.EncodeError(
          'Message %s is missing required fields: %s' % (message_name, missing))
    return self.SerializePartialToString(**kwargs)

  cls.SerializeToString = SerializeToString

1066 

1067 

def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializePartialToString and
  the _InternalSerialize writer it delegates to."""

  def SerializePartialToString(self, **kwargs):
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  def InternalSerialize(self, write_bytes, deterministic=None):
    # None selects the process-wide default; any other value is coerced to
    # a plain bool.
    if deterministic is not None:
      deterministic = bool(deterministic)
    else:
      deterministic = (
          api_implementation.IsPythonDefaultSerializationDeterministic())

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      by_name['key']._encoder(write_bytes, self.key, deterministic)
      by_name['value']._encoder(write_bytes, self.value, deterministic)
    else:
      for fd, value in self.ListFields():
        fd._encoder(write_bytes, value, deterministic)
      # Unknown fields are re-emitted verbatim: tag bytes, then payload.
      for tag, payload in self._unknown_fields:
        write_bytes(tag)
        write_bytes(payload)

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize

1098 

1099 

def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs MergeFromString and the
  _InternalParse loop it drives."""

  def MergeFromString(self, serialized):
    view = memoryview(serialized)
    length = len(view)
    try:
      stopped_at = self._InternalParse(view, 0, length)
    except (IndexError, TypeError):
      # Decoders fail this way when the buffer ends mid-field; for example
      # ord(buf[p:p+1]) == ord('') raises TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    if stopped_at != length:
      # The only reason _InternalParse would return early is if it
      # encountered an end-group tag.
      raise message_mod.DecodeError('Unexpected end-group tag.')
    return length  # Return this for legacy reasons.

  cls.MergeFromString = MergeFromString

  # Bound once here so the parse loop below uses closure lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses serialized bytes, merging them into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int: end once everything is consumed, or the position of the tag at
      which parsing stopped early.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and it's easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is not None:
        # Known field: let its decoder consume the value.
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_desc:
          self._UpdateOneofState(field_desc)
        continue

      # Unknown tag: record it in both the legacy _unknown_fields list and
      # the UnknownFieldSet.
      if not self._unknown_fields:  # pylint: disable=protected-access
        self._unknown_fields = []  # pylint: disable=protected-access
      if unknown_field_set is None:
        # pylint: disable=protected-access
        self._unknown_field_set = containers.UnknownFieldSet()
        # pylint: disable=protected-access
        unknown_field_set = self._unknown_field_set
      # pylint: disable=protected-access
      (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
      field_number, wire_type = wire_format.UnpackTag(tag)
      if field_number == 0:
        raise message_mod.DecodeError('Field number 0 is illegal.')
      # TODO(jieluo): remove old_pos.
      old_pos = new_pos
      (data, new_pos) = decoder._DecodeUnknownField(
          buffer, new_pos, wire_type)  # pylint: disable=protected-access
      if new_pos == -1:
        return pos
      # pylint: disable=protected-access
      unknown_field_set._add(field_number, wire_type, data)
      # TODO(jieluo): remove _unknown_fields.
      new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
      if new_pos == -1:
        return pos
      self._unknown_fields.append(
          (tag_bytes, buffer[old_pos:new_pos].tobytes()))
      pos = new_pos
    return pos

  cls._InternalParse = InternalParse

1178 

1179 

def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def _NotInitialized(self, errors):
    # Shared failure path: fill the caller's list (if given) with every
    # missing field path, then report failure.
    if errors is not None:
      errors.extend(self.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """
    # Performance is critical so we avoid HasField() and ListFields().
    for field in required_fields:
      missing = field not in self._fields or (
          field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not self._fields[field]._is_present_in_parent)
      if missing:
        return _NotInitialized(self, errors)

    # Iterate over a snapshot: the dict can change size during the checks.
    for field, value in list(self._fields.items()):
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        message_type = field.message_type
        if message_type.has_options and message_type.GetOptions().map_entry:
          # Map fields are skipped by this check.
          continue
        if not all(element.IsInitialized() for element in value):
          return _NotInitialized(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _NotInitialized(self, errors)

    return True

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings. Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = [field.name for field in required_fields
              if not self.HasField(field.name)]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.is_extension:
        name = '(%s)' % field.full_name
      else:
        name = field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            prefix = '%s[%s].' % (name, key)
            errors.extend(
                prefix + error
                for error in value[key].FindInitializationErrors())
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          prefix = '%s[%d].' % (name, index)
          errors.extend(
              prefix + error for error in element.FindInitializationErrors())
      else:
        prefix = name + '.'
        errors.extend(
            prefix + error for error in value.FindInitializationErrors())

    return errors

  cls.IsInitialized = IsInitialized
  cls.FindInitializationErrors = FindInitializationErrors

1273 

1274 

def _FullyQualifiedClassName(klass):
  """Returns klass's dotted name, without a module prefix for builtins."""
  module = klass.__module__
  qualified = getattr(klass, '__qualname__', klass.__name__)
  if module in (None, 'builtins', '__builtin__'):
    return qualified
  return '.'.join((module, qualified))

1281 

1282 

def _AddMergeFromMethod(cls):
  """Installs cls.MergeFrom, which merges another instance of cls into self."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def _MergeIntoOwnValue(self, fields, field, value):
    # Merge value into our own container/sub-message for field, constructing
    # an empty one first if we do not have one yet.
    own_value = fields.get(field)
    if own_value is None:
      own_value = field._default_constructor(self)
      fields[field] = own_value
    own_value.MergeFrom(value)

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields
    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        _MergeIntoOwnValue(self, fields, field, value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Sub-messages not marked present in msg contribute nothing.
        if value._is_present_in_parent:
          _MergeIntoOwnValue(self, fields, field, value)
      else:
        # Scalars overwrite; a scalar oneof member becomes the active one.
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)
      # pylint: disable=protected-access
      if self._unknown_field_set is None:
        self._unknown_field_set = containers.UnknownFieldSet()
      self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom

1330 

1331 

def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs cls.WhichOneof."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof

1348 

1349 

def _Clear(self):
  """Resets all fields, unknown fields and oneofs, then marks self modified."""
  self._fields = {}
  self._unknown_fields = ()
  # pylint: disable=protected-access
  unknown_set = self._unknown_field_set
  if unknown_set is not None:
    unknown_set._clear()
    self._unknown_field_set = None
  self._oneofs = {}
  self._Modified()

1361 

1362 

def _UnknownFields(self):
  """Returns this message's UnknownFieldSet, creating an empty one lazily."""
  # pylint: disable=protected-access
  if self._unknown_field_set is None:
    self._unknown_field_set = containers.UnknownFieldSet()
  return self._unknown_field_set

1368 

1369 

def _DiscardUnknownFields(self):
  """Drops unknown fields from this message and every set sub-message."""
  self._unknown_fields = []
  self._unknown_field_set = None  # pylint: disable=protected-access
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Only message-valued maps contain sub-messages to recurse into.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for child in value:
        child.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()

1384 

1385 

def _SetListener(self, listener):
  """Installs listener, substituting a NullMessageListener for None."""
  self._listener = (message_listener_mod.NullMessageListener()
                    if listener is None else listener)

1391 

1392 

def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  # Extension accessors only make sense for extendable message types.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Methods that do not depend on cls are shared module-level functions.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener

1417 

1418 

def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit and notifies our listener.

    Only the clean -> dirty transition has any effect.
    """
    # Note: Some callers check _cached_byte_size_dirty before calling
    # _Modified() as an extra optimization. If this method ever starts doing
    # work while the flag is already set, those callers must be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is field:
      return
    del self._fields[previous]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState

1451 

1452 

class _Listener:
  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents. This helper
  class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # Keep only a weak proxy to the parent: a strong child -> parent back
    # reference would create cyclic garbage once the client drops the parent.
    # A value that is already a proxy is stored as-is.
    if isinstance(parent_message, weakref.ProxyType):
      parent_proxy = parent_message
    else:
      parent_proxy = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_proxy

    # Optimization: track the parent's dirty state here so the common case
    # does not have to walk up the tree.
    self.dirty = False

  def Modified(self):
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # A client kept a child alive and is modifying it after the parent
        # was garbage-collected. This is not an error.
        pass

1497 

1498 

class _OneofListener(_Listener):
  """Listener for composite oneof fields that also updates oneof state."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    try:
      # Make our field the parent's active oneof member before the regular
      # dirty-bit propagation.
      self._parent_message_weakref._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # The parent has been garbage-collected; nothing to update.
      pass