Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/python_message.py: 67%

732 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:57 +0000

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# https://developers.google.com/protocol-buffers/ 

4# 

5# Redistribution and use in source and binary forms, with or without 

6# modification, are permitted provided that the following conditions are 

7# met: 

8# 

9# * Redistributions of source code must retain the above copyright 

10# notice, this list of conditions and the following disclaimer. 

11# * Redistributions in binary form must reproduce the above 

12# copyright notice, this list of conditions and the following disclaimer 

13# in the documentation and/or other materials provided with the 

14# distribution. 

15# * Neither the name of Google Inc. nor the names of its 

16# contributors may be used to endorse or promote products derived from 

17# this software without specific prior written permission. 

18# 

19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 

20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 

21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 

22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 

23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 

24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 

25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 

26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 

27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 

28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

31# This code targets Python 3 only (it relies on Python 3 APIs such as BaseException.with_traceback). 

32# 

33# TODO(robinson): Helpers for verbose, common checks like seeing if a 

34# descriptor's cpp_type is CPPTYPE_MESSAGE. 

35 

36"""Contains a metaclass and helper functions used to create 

37protocol message classes from Descriptor objects at runtime. 

38 

39Recall that a metaclass is the "type" of a class. 

40(A class is to a metaclass what an instance is to a class.) 

41 

42In this case, we use the GeneratedProtocolMessageType metaclass 

43to inject all the useful functionality into the classes 

44output by the protocol compiler at compile-time. 

45 

46The upshot of all this is that the real implementation 

47details for ALL pure-Python protocol buffers are *here in 

48this file*. 

49""" 

50 

51__author__ = 'robinson@google.com (Will Robinson)' 

52 

53from io import BytesIO 

54import struct 

55import sys 

56import weakref 

57 

58# We use "as" to avoid name collisions with variables. 

59from google.protobuf.internal import api_implementation 

60from google.protobuf.internal import containers 

61from google.protobuf.internal import decoder 

62from google.protobuf.internal import encoder 

63from google.protobuf.internal import enum_type_wrapper 

64from google.protobuf.internal import extension_dict 

65from google.protobuf.internal import message_listener as message_listener_mod 

66from google.protobuf.internal import type_checkers 

67from google.protobuf.internal import well_known_types 

68from google.protobuf.internal import wire_format 

69from google.protobuf import descriptor as descriptor_mod 

70from google.protobuf import message as message_mod 

71from google.protobuf import text_format 

72 

# Short module-private aliases used by the helpers below.
_AnyFullTypeName = 'google.protobuf.Any'  # Full name of the Any well-known type.
_FieldDescriptor = descriptor_mod.FieldDescriptor
_ExtensionDict = extension_dict._ExtensionDict

76 

class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

    mydescriptor = Descriptor(.....)
    factory = symbol_database.Default()
    factory.pool.AddDescriptor(mydescriptor)
    MyProtoClass = factory.GetPrototype(mydescriptor)
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class; forwarded unchanged to type.__new__.
      bases: Base classes of the class we're constructing (normally
        message.Message).  If the descriptor's full name is listed in
        well_known_types.WKTBASES, the matching mixin class is appended.
      dictionary: The class dictionary of the class we're constructing.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.  Nested extensions and a
        '__slots__' entry are added to it in place.

    Returns:
      The newly-allocated class, or the class already recorded on the
      descriptor as `_concrete_class` if one exists.

    Raises:
      RuntimeError: The descriptor is a string, i.e. the generated code only
        works with the Python C++ extension but the pure-Python runtime is
        in use.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist
    # with the existing class.
    #
    # The C++ implementation appears to have its own internal
    # `PyMessageFactory` to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors
    # from a custom pool; it calls symbol_database.Global().getPrototype() on
    # a descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.

    We add enum getters, an __init__ method, implementations of all Message
    methods, and properties for all fields in the protocol type.

    Args:
      name: Name of the class; forwarded unchanged to type.__init__.
      bases: Base classes of the class we're constructing (normally
        message.Message); forwarded unchanged to type.__init__.
      dictionary: The class dictionary of the class we're constructing.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    # Maps encoded tag bytes to (decoder, oneof field or None); filled in by
    # _AttachFieldHelpers for every field.
    cls._decoders_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)

206 

207 

208# Stateless helpers for GeneratedProtocolMessageType below. 

209# Outside clients should not access these directly. 

210# 

211# I opted not to make any of these methods on the metaclass, to make it more 

212# clear that I'm not really using any state there and to keep clients from 

213# thinking that they have direct access to these construction helpers. 

214 

215 

def _PropertyName(proto_field_name):
  """Maps a .proto field name to the name of its public Python attribute.

  The mapping is intentionally the identity.  Callers rely on getattr() and
  setattr() with the exact .proto name to manipulate field values
  reflectively, so names are never mangled.  That holds even for names that
  collide with Python keywords such as "yield"; those fields stay reachable
  through getattr/setattr.

  Args:
    proto_field_name: The field name exactly as it appears (or would appear)
      in a .proto file.

  Returns:
    The attribute name clients use to get and (for some fields) set the value.
  """
  # TODO(kenton): Remove this helper entirely once everyone agrees that field
  # names should never be escaped.
  attribute_name = proto_field_name
  return attribute_name

243 

244 

def _AddSlots(message_descriptor, dictionary):
  """Declares the fixed attribute set (__slots__) of a message class.

  With slots in place, assigning to a misspelled or nonexistent field raises
  instead of silently creating an attribute that would never be serialized.

  Args:
    message_descriptor: Descriptor for the message type.  It is not consulted;
      every message class gets the same slots.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_unknown_field_set',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names

263 

264 

def _IsMessageSetExtension(field):
  """Tells whether `field` is an extension that must use MessageSet encoding.

  That is the case for an optional, message-typed extension of a message
  whose options enable `message_set_wire_format`.

  Args:
    field: FieldDescriptor to inspect.

  Returns:
    True if MessageSet item encoding applies to `field`, False otherwise.
  """
  if not field.is_extension:
    return False
  extended_type = field.containing_type
  if not (extended_type.has_options and
          extended_type.GetOptions().message_set_wire_format):
    return False
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)

271 

272 

def _IsMapField(field):
  """Tells whether `field` is a protobuf map field.

  A map field is message-typed, and its message type has options with
  `map_entry` set.

  Args:
    field: FieldDescriptor to inspect.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry

277 

278 

def _IsMessageMapField(field):
  """Tells whether a map field's values are messages (rather than scalars).

  Args:
    field: FieldDescriptor of a map field.  Its entry message type must
      define a field named 'value'.
  """
  entry_fields = field.message_type.fields_by_name
  value_cpp_type = entry_fields['value'].cpp_type
  return value_cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

282 

283 

def _AttachFieldHelpers(cls, field_descriptor):
  """Wires encoding and decoding helpers for one field.

  Sets `_encoder`, `_sizer` and `_default_constructor` on `field_descriptor`.
  It also registers one decoder in `cls._decoders_by_tag`, keyed by the
  encoded tag bytes.  Repeated packable fields get a second decoder for the
  packed wire format.

  Args:
    cls: The message class under construction.  It must already have a
      `_decoders_by_tag` dict.
    field_descriptor: FieldDescriptor (regular field or extension) to wire up.
  """
  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_map_entry = _IsMapField(field_descriptor)
  is_packed = field_descriptor.is_packed

  # Maps and MessageSet extensions have dedicated encoder/sizer
  # implementations.  Everything else is looked up by the field's type.
  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_type = field_descriptor.type
    field_number = field_descriptor.number
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_type](
        field_number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_type](
        field_number, is_repeated, is_packed)

  # pylint: disable=protected-access
  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    """Registers a decoder for `field_descriptor` under `wiretype`."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)

    # Enums that are not closed are decoded as plain int32 values.
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    # The second element of the registered tuple is the field itself for
    # oneof members and None otherwise.
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor
    else:
      oneof_descriptor = None

    field_number = field_descriptor.number
    default_constructor = field_descriptor._default_constructor
    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_number, is_repeated, is_packed, field_descriptor,
          default_constructor, not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_number, is_repeated, is_packed, field_descriptor,
          default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_number, is_repeated, is_packed, field_descriptor,
          default_constructor, not field_descriptor.has_presence)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)

349 

350 

def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes extensions declared inside a message as class attributes.

  Args:
    descriptor: Descriptor of the message being constructed.
    dictionary: Class dictionary to populate.  An extension's name must not
      already be a key in it (checked with an assert).
  """
  nested_extensions = descriptor.extensions_by_name
  for ext_name, ext_descriptor in nested_extensions.items():
    assert ext_name not in dictionary
    dictionary[ext_name] = ext_descriptor

356 

357 

def _AddEnumValues(descriptor, cls):
  """Publishes the message's nested enums as class attributes.

  For each enum type nested in the message, an EnumTypeWrapper is set under
  the enum's name.  Every value of that enum is also set as an int attribute
  under the value's name.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_desc in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_desc)
    setattr(cls, enum_desc.name, wrapper)
    for value_desc in enum_desc.values:
      setattr(cls, value_desc.name, value_desc.number)

371 

372 

def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds the empty container for a map field.

  Args:
    field: FieldDescriptor of a map field.

  Returns:
    A callable that takes the owning message and returns either a
    containers.MessageMap (message-typed values) or a containers.ScalarMap
    (scalar values).  The container is built with the message's
    `_listener_for_children`.

  Raises:
    ValueError: `field` is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault

394 

def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
  value may refer back to |message| via a weak reference.

  Raises:
    ValueError: A repeated field declares a non-empty default value, or (via
      _GetInitializeDefaultForMap) a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have been
      # set (depends on the order in which we initialize the classes), so the
      # element type is only read from the descriptor when the container is
      # built.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      # Oneof members get a _OneofListener; other sub-messages report to
      # the parent's child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault

451 

452 

def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception currently being handled, naming the field.

  Must be called from inside an `except` block.  A plain TypeError with a
  single argument is replaced by a new TypeError whose message gets the
  suffix " for field <message_name>.<field_name>".  Any other exception,
  including TypeError subclasses and multi-argument TypeErrors, is re-raised
  unchanged.  Either way, the handled exception's traceback is attached.

  Args:
    message_name: Name of the message type, used in the amended message.
    field_name: Name of the offending field.
  """
  _, active_exc, active_tb = sys.exc_info()
  if len(active_exc.args) == 1 and type(active_exc) is TypeError:
    # Simple TypeError: rebuild it with the field name appended.
    active_exc = TypeError(
        '%s for field %s.%s' % (str(active_exc), message_name, field_name))
  raise active_exc.with_traceback(active_tb)

462 

463 

def _AddInitMethod(message_descriptor, cls):
  """Installs a keyword-argument __init__ on the message class `cls`.

  The installed __init__ resets the per-instance bookkeeping attributes and
  then assigns each keyword argument to the field of the same name:

    * Values of None are skipped, as if the argument were absent.
    * Repeated fields are copied into a fresh container.  Enum labels are
      translated to numbers, message elements may be given as dicts, and map
      fields are merged or updated.
    * Singular message fields accept a message or a dict (used as keyword
      arguments for the field's concrete class) and are merged into a fresh
      sub-message.
    * Singular scalars go through the normal property setter.  Enum fields
      also accept a value name as a string.

  Unknown names are rejected by _GetFieldByName.  A single-argument TypeError
  raised while assigning a singular field is re-raised with the field name
  appended.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The class under construction.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)
    """
    if not isinstance(value, str):
      return value
    try:
      return enum_type.values_by_name[value].number
    except KeyError:
      raise ValueError('Enum type %s: unknown label "%s"' % (
          enum_type.full_name, value))

  # pylint: disable=protected-access
  def _InitRepeatedField(self, field, field_value):
    # Build and fill a new container, then store it on the message.
    container = field._default_constructor(self)
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      # Repeated scalars; enum labels are converted to numbers first.
      if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
        field_value = [_GetIntegerEnumValue(field.enum_type, val)
                       for val in field_value]
      container.extend(field_value)
    elif not _IsMapField(field):
      # Repeated messages; dict elements become keyword arguments to add().
      for val in field_value:
        if isinstance(val, dict):
          container.add(**val)
        else:
          container.add().MergeFrom(val)
    elif _IsMessageMapField(field):
      for key in field_value:
        container[key].MergeFrom(field_value[key])
    else:
      container.update(field_value)
    self._fields[field] = container

  def _InitCompositeField(self, field, field_name, field_value):
    # Merge the given message (or dict of its fields) into a new sub-message.
    sub_message = field._default_constructor(self)
    if isinstance(field_value, dict):
      field_value = field.message_type._concrete_class(**field_value)
    try:
      sub_message.MergeFrom(field_value)
    except TypeError:
      _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
    self._fields[field] = sub_message

  def _InitScalarField(self, field, field_name, field_value):
    # Assign through the field's property so its type checks run.
    if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
      field_value = _GetIntegerEnumValue(field.enum_type, field_value)
    try:
      setattr(self, field_name, field_value)
    except TypeError:
      _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # The serialized size is stale as soon as any field is assigned.
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)

    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      if field is None:
        # Defensive only: _GetFieldByName raises ValueError for unknown names.
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        _InitRepeatedField(self, field, field_value)
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        _InitCompositeField(self, field, field_name, field_value)
      else:
        _InitScalarField(self, field, field_name, field_value)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init

549 

550 

def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor of a message by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in the message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with `field_name`.

  Raises:
    ValueError: The message has no field called `field_name`.
  """
  fields = message_descriptor.fields_by_name
  try:
    found = fields[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found

565 

566 

def _AddPropertiesForFields(descriptor, cls):
  """Adds a property per message field, plus `Extensions` when extendable.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class under construction.
  """
  for field_desc in descriptor.fields:
    _AddPropertiesForField(field_desc, cls)

  if not descriptor.is_extendable:
    return
  # _ExtensionDict is a stateless adaptor, so a fresh one is built on every
  # access instead of being cached.
  cls.Extensions = property(lambda self: _ExtensionDict(self))

576 

577 

def _AddPropertiesForField(field, cls):
  """Adds the public property and field-number constant for one field.

  The constant is named `<FIELD NAME IN UPPER CASE>_FIELD_NUMBER`.  The
  property lets clients read the field and, for non-repeated scalar fields
  only, assign it directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Trip if new C++ types are added that may need special handling here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)

601 

602 

class _FieldProperty(property):
  # A property that also remembers the FieldDescriptor it exposes, available
  # as the DESCRIPTOR attribute.
  # NOTE(review): intentionally no class docstring.  On a property subclass
  # the class-level __doc__ may shadow the per-property doc; confirm before
  # adding one.
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    """Builds the property and records `descriptor` as DESCRIPTOR."""
    super(_FieldProperty, self).__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor

609 

610 

def _AddPropertiesForRepeatedField(field, cls):
  """Adds the read-only public property for a repeated field.

  Reading the property returns the field's container.  The container is
  built on first access by the field's `_default_constructor`, which yields
  a RepeatedScalarFieldContainer, a RepeatedCompositeFieldContainer, or a map
  container.  Assigning to the property raises AttributeError; clients
  mutate the container instead.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing
    # Lazily create the container.  setdefault() keeps whichever object got
    # there first: if another thread raced us, its object wins and ours is
    # discarded.  WARNING: this relies on dict.setdefault() being atomic,
    # which holds in CPython but hasn't been checked elsewhere.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to raise a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))

653 

654 

def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a read/write public property for a singular scalar field.

  Reading returns the stored value, or the field's default when unset.
  Writing type-checks the value, updates the stored value, and marks the
  message modified.  For fields in a oneof, writing also updates the oneof
  state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # For a field without presence, a falsy value (0, 0.0, enum 0, False,
    # empty string/bytes) is the default.  It is stored by dropping the
    # entry rather than recording it explicitly.
    if field.has_presence or checked_value:
      self._fields[field] = checked_value
    else:
      self._fields.pop(field, None)
    # Scalar setters are called frequently, so check the dirty flag inline
    # and skip _Modified() when it is already set.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))

709 

710 

def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds the read-only public property for a singular message/group field.

  Reading the property returns the sub-message, which is created on first
  access by the field's `_default_constructor`.  Assigning to the property
  raises AttributeError; clients mutate the sub-message instead.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing
    # Lazily create the sub-message.  setdefault() keeps whichever object got
    # there first: if another thread raced us, its object wins and ours is
    # discarded.  WARNING: this relies on dict.setdefault() being atomic,
    # which holds in CPython but hasn't been checked elsewhere.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to raise a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))

753 

754 

def _AddPropertiesForExtensions(descriptor, cls):
  """Adds extension field-number constants and extension lookup tables.

  For every extension declared inside this message, sets a class constant
  named `<EXTENSION NAME IN UPPER CASE>_FIELD_NUMBER` holding its field
  number.  When the descriptor belongs to a file, `cls._extensions_by_number`
  and `cls._extensions_by_name` are also bound to the file pool's tables for
  this message type.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class under construction.
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)

  # TODO(amauryfa): Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  if descriptor.file is not None:
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    pool = descriptor.file.pool
    # pylint: disable=protected-access
    cls._extensions_by_number = pool._extensions_by_number[descriptor]
    cls._extensions_by_name = pool._extensions_by_name[descriptor]

769 

770def _AddStaticMethods(cls): 

771 # TODO(robinson): This probably needs to be thread-safe(?) 

772 def RegisterExtension(field_descriptor): 

773 field_descriptor.containing_type = cls.DESCRIPTOR 

774 # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. 

775 # pylint: disable=protected-access 

776 cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(field_descriptor) 

777 _AttachFieldHelpers(cls, field_descriptor) 

778 cls.RegisterExtension = staticmethod(RegisterExtension) 

779 

780 def FromString(s): 

781 message = cls() 

782 message.MergeFromString(s) 

783 return message 

784 cls.FromString = staticmethod(FromString) 

785 

786 

def _IsPresent(item):
  """Tells ListFields() whether a (FieldDescriptor, value) pair should be listed.

  Repeated fields count only when non-empty, singular sub-messages only once
  they are marked present in their parent, and any other stored value always
  counts.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True

797 

798 

799def _AddListFieldsMethod(message_descriptor, cls): 

800 """Helper for _AddMessageMethods().""" 

801 

802 def ListFields(self): 

803 all_fields = [item for item in self._fields.items() if _IsPresent(item)] 

804 all_fields.sort(key = lambda item: item[0].number) 

805 return all_fields 

806 

807 cls.ListFields = ListFields 

808 

809 

810def _AddHasFieldMethod(message_descriptor, cls): 

811 """Helper for _AddMessageMethods().""" 

812 

813 hassable_fields = {} 

814 for field in message_descriptor.fields: 

815 if field.label == _FieldDescriptor.LABEL_REPEATED: 

816 continue 

817 # For proto3, only submessages and fields inside a oneof have presence. 

818 if not field.has_presence: 

819 continue 

820 hassable_fields[field.name] = field 

821 

822 # Has methods are supported for oneof descriptors. 

823 for oneof in message_descriptor.oneofs: 

824 hassable_fields[oneof.name] = oneof 

825 

826 def HasField(self, field_name): 

827 try: 

828 field = hassable_fields[field_name] 

829 except KeyError as exc: 

830 raise ValueError('Protocol message %s has no non-repeated field "%s" ' 

831 'nor has presence is not available for this field.' % ( 

832 message_descriptor.full_name, field_name)) from exc 

833 

834 if isinstance(field, descriptor_mod.OneofDescriptor): 

835 try: 

836 return HasField(self, self._oneofs[field].name) 

837 except KeyError: 

838 return False 

839 else: 

840 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 

841 value = self._fields.get(field) 

842 return value is not None and value._is_present_in_parent 

843 else: 

844 return field in self._fields 

845 

846 cls.HasField = HasField 

847 

848 

849def _AddClearFieldMethod(message_descriptor, cls): 

850 """Helper for _AddMessageMethods().""" 

851 def ClearField(self, field_name): 

852 try: 

853 field = message_descriptor.fields_by_name[field_name] 

854 except KeyError: 

855 try: 

856 field = message_descriptor.oneofs_by_name[field_name] 

857 if field in self._oneofs: 

858 field = self._oneofs[field] 

859 else: 

860 return 

861 except KeyError: 

862 raise ValueError('Protocol message %s has no "%s" field.' % 

863 (message_descriptor.name, field_name)) 

864 

865 if field in self._fields: 

866 # To match the C++ implementation, we need to invalidate iterators 

867 # for map fields when ClearField() happens. 

868 if hasattr(self._fields[field], 'InvalidateIterators'): 

869 self._fields[field].InvalidateIterators() 

870 

871 # Note: If the field is a sub-message, its listener will still point 

872 # at us. That's fine, because the worst than can happen is that it 

873 # will call _Modified() and invalidate our byte size. Big deal. 

874 del self._fields[field] 

875 

876 if self._oneofs.get(field.containing_oneof, None) is field: 

877 del self._oneofs[field.containing_oneof] 

878 

879 # Always call _Modified() -- even if nothing was changed, this is 

880 # a mutating method, and thus calling it should cause the field to become 

881 # present in the parent message. 

882 self._Modified() 

883 

884 cls.ClearField = ClearField 

885 

886 

887def _AddClearExtensionMethod(cls): 

888 """Helper for _AddMessageMethods().""" 

889 def ClearExtension(self, field_descriptor): 

890 extension_dict._VerifyExtensionHandle(self, field_descriptor) 

891 

892 # Similar to ClearField(), above. 

893 if field_descriptor in self._fields: 

894 del self._fields[field_descriptor] 

895 self._Modified() 

896 cls.ClearExtension = ClearExtension 

897 

898 

899def _AddHasExtensionMethod(cls): 

900 """Helper for _AddMessageMethods().""" 

901 def HasExtension(self, field_descriptor): 

902 extension_dict._VerifyExtensionHandle(self, field_descriptor) 

903 if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: 

904 raise KeyError('"%s" is repeated.' % field_descriptor.full_name) 

905 

906 if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 

907 value = self._fields.get(field_descriptor) 

908 return value is not None and value._is_present_in_parent 

909 else: 

910 return field_descriptor in self._fields 

911 cls.HasExtension = HasExtension 

912 

def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg has no type_url or its type cannot
    be found in the default descriptor pool.
  """
  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # FindMessageTypeByName() reports an unknown type by raising KeyError,
    # not by returning None. Treat that as "cannot unpack" so callers such as
    # __eq__ fall back to comparing the Any message's own fields.
    descriptor = None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message

951 

952 

953def _AddEqualsMethod(message_descriptor, cls): 

954 """Helper for _AddMessageMethods().""" 

955 def __eq__(self, other): 

956 if (not isinstance(other, message_mod.Message) or 

957 other.DESCRIPTOR != self.DESCRIPTOR): 

958 return False 

959 

960 if self is other: 

961 return True 

962 

963 if self.DESCRIPTOR.full_name == _AnyFullTypeName: 

964 any_a = _InternalUnpackAny(self) 

965 any_b = _InternalUnpackAny(other) 

966 if any_a and any_b: 

967 return any_a == any_b 

968 

969 if not self.ListFields() == other.ListFields(): 

970 return False 

971 

972 # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions, 

973 # then use it for the comparison. 

974 unknown_fields = list(self._unknown_fields) 

975 unknown_fields.sort() 

976 other_unknown_fields = list(other._unknown_fields) 

977 other_unknown_fields.sort() 

978 return unknown_fields == other_unknown_fields 

979 

980 cls.__eq__ = __eq__ 

981 

982 

983def _AddStrMethod(message_descriptor, cls): 

984 """Helper for _AddMessageMethods().""" 

985 def __str__(self): 

986 return text_format.MessageToString(self) 

987 cls.__str__ = __str__ 

988 

989 

990def _AddReprMethod(message_descriptor, cls): 

991 """Helper for _AddMessageMethods().""" 

992 def __repr__(self): 

993 return text_format.MessageToString(self) 

994 cls.__repr__ = __repr__ 

995 

996 

997def _AddUnicodeMethod(unused_message_descriptor, cls): 

998 """Helper for _AddMessageMethods().""" 

999 

1000 def __unicode__(self): 

1001 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 

1002 cls.__unicode__ = __unicode__ 

1003 

1004 

def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The serialized size in bytes as computed by the size function registered
    for field_type.

  Raises:
    message_mod.EncodeError: if field_type has no registered size function.
  """
  # Only the table lookup is guarded: a KeyError raised by the size function
  # itself must not be misreported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError as exc:
    raise message_mod.EncodeError(
        'Unrecognized field type: %d' % field_type) from exc
  return fn(field_number, value)

1023 

1024 

1025def _AddByteSizeMethod(message_descriptor, cls): 

1026 """Helper for _AddMessageMethods().""" 

1027 

1028 def ByteSize(self): 

1029 if not self._cached_byte_size_dirty: 

1030 return self._cached_byte_size 

1031 

1032 size = 0 

1033 descriptor = self.DESCRIPTOR 

1034 if descriptor.GetOptions().map_entry: 

1035 # Fields of map entry should always be serialized. 

1036 size = descriptor.fields_by_name['key']._sizer(self.key) 

1037 size += descriptor.fields_by_name['value']._sizer(self.value) 

1038 else: 

1039 for field_descriptor, field_value in self.ListFields(): 

1040 size += field_descriptor._sizer(field_value) 

1041 for tag_bytes, value_bytes in self._unknown_fields: 

1042 size += len(tag_bytes) + len(value_bytes) 

1043 

1044 self._cached_byte_size = size 

1045 self._cached_byte_size_dirty = False 

1046 self._listener_for_children.dirty = False 

1047 return size 

1048 

1049 cls.ByteSize = ByteSize 

1050 

1051 

1052def _AddSerializeToStringMethod(message_descriptor, cls): 

1053 """Helper for _AddMessageMethods().""" 

1054 

1055 def SerializeToString(self, **kwargs): 

1056 # Check if the message has all of its required fields set. 

1057 if not self.IsInitialized(): 

1058 raise message_mod.EncodeError( 

1059 'Message %s is missing required fields: %s' % ( 

1060 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 

1061 return self.SerializePartialToString(**kwargs) 

1062 cls.SerializeToString = SerializeToString 

1063 

1064 

1065def _AddSerializePartialToStringMethod(message_descriptor, cls): 

1066 """Helper for _AddMessageMethods().""" 

1067 

1068 def SerializePartialToString(self, **kwargs): 

1069 out = BytesIO() 

1070 self._InternalSerialize(out.write, **kwargs) 

1071 return out.getvalue() 

1072 cls.SerializePartialToString = SerializePartialToString 

1073 

1074 def InternalSerialize(self, write_bytes, deterministic=None): 

1075 if deterministic is None: 

1076 deterministic = ( 

1077 api_implementation.IsPythonDefaultSerializationDeterministic()) 

1078 else: 

1079 deterministic = bool(deterministic) 

1080 

1081 descriptor = self.DESCRIPTOR 

1082 if descriptor.GetOptions().map_entry: 

1083 # Fields of map entry should always be serialized. 

1084 descriptor.fields_by_name['key']._encoder( 

1085 write_bytes, self.key, deterministic) 

1086 descriptor.fields_by_name['value']._encoder( 

1087 write_bytes, self.value, deterministic) 

1088 else: 

1089 for field_descriptor, field_value in self.ListFields(): 

1090 field_descriptor._encoder(write_bytes, field_value, deterministic) 

1091 for tag_bytes, value_bytes in self._unknown_fields: 

1092 write_bytes(tag_bytes) 

1093 write_bytes(value_bytes) 

1094 cls._InternalSerialize = InternalSerialize 

1095 

1096 

def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs MergeFromString and _InternalParse.

  MergeFromString() is the public entry point; _InternalParse() is the
  tag-by-tag parse loop that it (and nested decoders) drive.
  """
  def MergeFromString(self, serialized):
    """Parses serialized bytes and merges them into self.

    Args:
      serialized: A bytes-like object holding the wire-format message.

    Returns:
      The number of bytes consumed (always the full input length).

    Raises:
      message_mod.DecodeError: on a stray end-group tag, truncated input, or a
        struct unpacking failure.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length  # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind hot-path lookups once per class rather than on every parse.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Create a message from serialized bytes.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position where parsing stopped. This is `end` when the whole
      range was consumed; an earlier position is returned when the
      unknown-field helpers report -1 (per the note in MergeFromString, this
      means an end-group tag was hit).
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # Tag not known to this message type: keep it as an unknown field.
        # The empty default for _unknown_fields may be immutable, so swap in a
        # list before appending.
        if not self._unknown_fields:  # pylint: disable=protected-access
          self._unknown_fields = []  # pylint: disable=protected-access
        if unknown_field_set is None:
          # pylint: disable=protected-access
          self._unknown_field_set = containers.UnknownFieldSet()
          # pylint: disable=protected-access
          unknown_field_set = self._unknown_field_set
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO(jieluo): remove old_pos.
        old_pos = new_pos
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        # pylint: disable=protected-access
        unknown_field_set._add(field_number, wire_type, data)
        # TODO(jieluo): remove _unknown_fields.
        # The same field is skipped a second time from old_pos so that its
        # raw bytes can also be recorded in the legacy _unknown_fields list.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        # Known field: the decoder stores the value into field_dict and
        # returns the position just past it.
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse

1175 

1176 

1177def _AddIsInitializedMethod(message_descriptor, cls): 

1178 """Adds the IsInitialized and FindInitializationError methods to the 

1179 protocol message class.""" 

1180 

1181 required_fields = [field for field in message_descriptor.fields 

1182 if field.label == _FieldDescriptor.LABEL_REQUIRED] 

1183 

1184 def IsInitialized(self, errors=None): 

1185 """Checks if all required fields of a message are set. 

1186 

1187 Args: 

1188 errors: A list which, if provided, will be populated with the field 

1189 paths of all missing required fields. 

1190 

1191 Returns: 

1192 True iff the specified message has all required fields set. 

1193 """ 

1194 

1195 # Performance is critical so we avoid HasField() and ListFields(). 

1196 

1197 for field in required_fields: 

1198 if (field not in self._fields or 

1199 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 

1200 not self._fields[field]._is_present_in_parent)): 

1201 if errors is not None: 

1202 errors.extend(self.FindInitializationErrors()) 

1203 return False 

1204 

1205 for field, value in list(self._fields.items()): # dict can change size! 

1206 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 

1207 if field.label == _FieldDescriptor.LABEL_REPEATED: 

1208 if (field.message_type.has_options and 

1209 field.message_type.GetOptions().map_entry): 

1210 continue 

1211 for element in value: 

1212 if not element.IsInitialized(): 

1213 if errors is not None: 

1214 errors.extend(self.FindInitializationErrors()) 

1215 return False 

1216 elif value._is_present_in_parent and not value.IsInitialized(): 

1217 if errors is not None: 

1218 errors.extend(self.FindInitializationErrors()) 

1219 return False 

1220 

1221 return True 

1222 

1223 cls.IsInitialized = IsInitialized 

1224 

1225 def FindInitializationErrors(self): 

1226 """Finds required fields which are not initialized. 

1227 

1228 Returns: 

1229 A list of strings. Each string is a path to an uninitialized field from 

1230 the top-level message, e.g. "foo.bar[5].baz". 

1231 """ 

1232 

1233 errors = [] # simplify things 

1234 

1235 for field in required_fields: 

1236 if not self.HasField(field.name): 

1237 errors.append(field.name) 

1238 

1239 for field, value in self.ListFields(): 

1240 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 

1241 if field.is_extension: 

1242 name = '(%s)' % field.full_name 

1243 else: 

1244 name = field.name 

1245 

1246 if _IsMapField(field): 

1247 if _IsMessageMapField(field): 

1248 for key in value: 

1249 element = value[key] 

1250 prefix = '%s[%s].' % (name, key) 

1251 sub_errors = element.FindInitializationErrors() 

1252 errors += [prefix + error for error in sub_errors] 

1253 else: 

1254 # ScalarMaps can't have any initialization errors. 

1255 pass 

1256 elif field.label == _FieldDescriptor.LABEL_REPEATED: 

1257 for i in range(len(value)): 

1258 element = value[i] 

1259 prefix = '%s[%d].' % (name, i) 

1260 sub_errors = element.FindInitializationErrors() 

1261 errors += [prefix + error for error in sub_errors] 

1262 else: 

1263 prefix = name + '.' 

1264 sub_errors = value.FindInitializationErrors() 

1265 errors += [prefix + error for error in sub_errors] 

1266 

1267 return errors 

1268 

1269 cls.FindInitializationErrors = FindInitializationErrors 

1270 

1271 

1272def _FullyQualifiedClassName(klass): 

1273 module = klass.__module__ 

1274 name = getattr(klass, '__qualname__', klass.__name__) 

1275 if module in (None, 'builtins', '__builtin__'): 

1276 return name 

1277 return module + '.' + name 

1278 

1279 

def _AddMergeFromMethod(cls):
  """Installs cls.MergeFrom, which merges another instance of cls into self."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def _GetOrCreate(message, fields, field):
    """Returns message's value for field, building a default one if absent."""
    current = fields.get(field)
    if current is None:
      # Construct a new object to represent this field.
      current = field._default_constructor(message)
      fields[field] = current
    return current

  def MergeFrom(self, msg):
    """Merges the set fields and unknown fields of msg into self.

    Repeated fields and set sub-messages are merged recursively; singular
    scalars from msg overwrite self's values.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        _GetOrCreate(self, fields, field).MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Sub-messages that msg never marked present are skipped.
        if value._is_present_in_parent:
          _GetOrCreate(self, fields, field).MergeFrom(value)
      else:
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)

    if not msg._unknown_fields:
      return
    # The empty default may be immutable; switch to a list before extending.
    if not self._unknown_fields:
      self._unknown_fields = []
    self._unknown_fields.extend(msg._unknown_fields)
    # pylint: disable=protected-access
    if self._unknown_field_set is None:
      self._unknown_field_set = containers.UnknownFieldSet()
    self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom

1327 

1328 

1329def _AddWhichOneofMethod(message_descriptor, cls): 

1330 def WhichOneof(self, oneof_name): 

1331 """Returns the name of the currently set field inside a oneof, or None.""" 

1332 try: 

1333 field = message_descriptor.oneofs_by_name[oneof_name] 

1334 except KeyError: 

1335 raise ValueError( 

1336 'Protocol message has no oneof "%s" field.' % oneof_name) 

1337 

1338 nested_field = self._oneofs.get(field, None) 

1339 if nested_field is not None and self.HasField(nested_field.name): 

1340 return nested_field.name 

1341 else: 

1342 return None 

1343 

1344 cls.WhichOneof = WhichOneof 

1345 

1346 

1347def _Clear(self): 

1348 # Clear fields. 

1349 self._fields = {} 

1350 self._unknown_fields = () 

1351 # pylint: disable=protected-access 

1352 if self._unknown_field_set is not None: 

1353 self._unknown_field_set._clear() 

1354 self._unknown_field_set = None 

1355 

1356 self._oneofs = {} 

1357 self._Modified() 

1358 

1359 

1360def _UnknownFields(self): 

1361 if self._unknown_field_set is None: # pylint: disable=protected-access 

1362 # pylint: disable=protected-access 

1363 self._unknown_field_set = containers.UnknownFieldSet() 

1364 return self._unknown_field_set # pylint: disable=protected-access 

1365 

1366 

1367def _DiscardUnknownFields(self): 

1368 self._unknown_fields = [] 

1369 self._unknown_field_set = None # pylint: disable=protected-access 

1370 for field, value in self.ListFields(): 

1371 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 

1372 if _IsMapField(field): 

1373 if _IsMessageMapField(field): 

1374 for key in value: 

1375 value[key].DiscardUnknownFields() 

1376 elif field.label == _FieldDescriptor.LABEL_REPEATED: 

1377 for sub_message in value: 

1378 sub_message.DiscardUnknownFields() 

1379 else: 

1380 value.DiscardUnknownFields() 

1381 

1382 

1383def _SetListener(self, listener): 

1384 if listener is None: 

1385 self._listener = message_listener_mod.NullMessageListener() 

1386 else: 

1387 self._listener = listener 

1388 

1389 

def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  # Extension accessors are only installed on extendable message types.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Methods that do not depend on cls are shared module-level functions.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener

1414 

1415 

1416def _AddPrivateHelperMethods(message_descriptor, cls): 

1417 """Adds implementation of private helper methods to cls.""" 

1418 

1419 def Modified(self): 

1420 """Sets the _cached_byte_size_dirty bit to true, 

1421 and propagates this to our listener iff this was a state change. 

1422 """ 

1423 

1424 # Note: Some callers check _cached_byte_size_dirty before calling 

1425 # _Modified() as an extra optimization. So, if this method is ever 

1426 # changed such that it does stuff even when _cached_byte_size_dirty is 

1427 # already true, the callers need to be updated. 

1428 if not self._cached_byte_size_dirty: 

1429 self._cached_byte_size_dirty = True 

1430 self._listener_for_children.dirty = True 

1431 self._is_present_in_parent = True 

1432 self._listener.Modified() 

1433 

1434 def _UpdateOneofState(self, field): 

1435 """Sets field as the active field in its containing oneof. 

1436 

1437 Will also delete currently active field in the oneof, if it is different 

1438 from the argument. Does not mark the message as modified. 

1439 """ 

1440 other_field = self._oneofs.setdefault(field.containing_oneof, field) 

1441 if other_field is not field: 

1442 del self._fields[other_field] 

1443 self._oneofs[field.containing_oneof] = field 

1444 

1445 cls._Modified = Modified 

1446 cls.SetInParent = Modified 

1447 cls._UpdateOneofState = _UpdateOneofState 

1448 

1449 

1450class _Listener(object): 

1451 

1452 """MessageListener implementation that a parent message registers with its 

1453 child message. 

1454 

1455 In order to support semantics like: 

1456 

1457 foo.bar.baz.moo = 23 

1458 assert foo.HasField('bar') 

1459 

1460 ...child objects must have back references to their parents. 

1461 This helper class is at the heart of this support. 

1462 """ 

1463 

1464 def __init__(self, parent_message): 

1465 """Args: 

1466 parent_message: The message whose _Modified() method we should call when 

1467 we receive Modified() messages. 

1468 """ 

1469 # This listener establishes a back reference from a child (contained) object 

1470 # to its parent (containing) object. We make this a weak reference to avoid 

1471 # creating cyclic garbage when the client finishes with the 'parent' object 

1472 # in the tree. 

1473 if isinstance(parent_message, weakref.ProxyType): 

1474 self._parent_message_weakref = parent_message 

1475 else: 

1476 self._parent_message_weakref = weakref.proxy(parent_message) 

1477 

1478 # As an optimization, we also indicate directly on the listener whether 

1479 # or not the parent message is dirty. This way we can avoid traversing 

1480 # up the tree in the common case. 

1481 self.dirty = False 

1482 

1483 def Modified(self): 

1484 if self.dirty: 

1485 return 

1486 try: 

1487 # Propagate the signal to our parents iff this is the first field set. 

1488 self._parent_message_weakref._Modified() 

1489 except ReferenceError: 

1490 # We can get here if a client has kept a reference to a child object, 

1491 # and is now setting a field on it, but the child's parent has been 

1492 # garbage-collected. This is not an error. 

1493 pass 

1494 

1495 

class _OneofListener(_Listener):
  """Listener for composite oneof members that also activates the member.

  Besides the usual dirty propagation, every modification makes `field` the
  active member of its oneof in the parent message.
  """

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
    except ReferenceError:
      # The parent has been garbage-collected; there is nothing to update.
      return
    super().Modified()