Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/internal/python_message.py: 11%

772 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:40 +0000

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# 

4# Use of this source code is governed by a BSD-style 

5# license that can be found in the LICENSE file or at 

6# https://developers.google.com/open-source/licenses/bsd 

7 

8# This code is meant to work on Python 2.4 and above only. 

9# 

10# TODO: Helpers for verbose, common checks like seeing if a 

11# descriptor's cpp_type is CPPTYPE_MESSAGE. 

12 

13"""Contains a metaclass and helper functions used to create 

14protocol message classes from Descriptor objects at runtime. 

15 

16Recall that a metaclass is the "type" of a class. 

17(A class is to a metaclass what an instance is to a class.) 

18 

19In this case, we use the GeneratedProtocolMessageType metaclass 

20to inject all the useful functionality into the classes 

21output by the protocol compiler at compile-time. 

22 

23The upshot of all this is that the real implementation 

24details for ALL pure-Python protocol buffers are *here in 

25this file*. 

26""" 

27 

28__author__ = 'robinson@google.com (Will Robinson)' 

29 

30from io import BytesIO 

31import struct 

32import sys 

33import warnings 

34import weakref 

35 

36from google.protobuf import descriptor as descriptor_mod 

37from google.protobuf import message as message_mod 

38from google.protobuf import text_format 

39# We use "as" to avoid name collisions with variables. 

40from google.protobuf.internal import api_implementation 

41from google.protobuf.internal import containers 

42from google.protobuf.internal import decoder 

43from google.protobuf.internal import encoder 

44from google.protobuf.internal import enum_type_wrapper 

45from google.protobuf.internal import extension_dict 

46from google.protobuf.internal import message_listener as message_listener_mod 

47from google.protobuf.internal import type_checkers 

48from google.protobuf.internal import well_known_types 

49from google.protobuf.internal import wire_format 

50 

# Module-level shorthands used throughout this file.
_AnyFullTypeName = 'google.protobuf.Any'
_FieldDescriptor = descriptor_mod.FieldDescriptor
_ExtensionDict = extension_dict._ExtensionDict

54 

class GeneratedProtocolMessageType(type):

  """Metaclass that builds protocol message classes from Descriptors at runtime.

  Every method described by the Message interface is implemented here, along
  with one property per field of the message. The generated classes also get
  __slots__, so assigning to a misspelled or nonexistent field raises instead
  of silently creating an attribute that would never be serialized.

  The protocol compiler relies on this metaclass, and clients may use it
  directly to build message classes at runtime, for example:

    mydescriptor = Descriptor(.....)
    factory = symbol_database.Default()
    factory.pool.AddDescriptor(mydescriptor)
    MyProtoClass = factory.GetPrototype(mydescriptor)
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates (or reuses) the class object for a message descriptor.

    __slots__ must already be in the class dictionary when the class object is
    created, which is why that work happens here rather than in __init__.

    Args:
      name: Name of the class (ignored, but required by the metaclass
        protocol).
      bases: Base classes of the class being constructed (expected to be
        message.Message). A well-known-type base is appended when the
        descriptor's full name appears in well_known_types.WKTBASES.
      dictionary: The class dictionary. dictionary[_DESCRIPTOR_KEY] must hold
        the Descriptor describing this message type.

    Returns:
      A newly allocated class, or the class already recorded in the
      descriptor's ``_concrete_class`` attribute.

    Raises:
      RuntimeError: if the descriptor is a str, meaning generated code that
        only works with the C++ extension is loaded by the pure-Python runtime.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # Reuse the class already built for this descriptor: a second class would
    # break every message that is an instance of the first one. This most
    # commonly happens in text_format.py with descriptors from a custom pool,
    # where symbol_database.Global().getPrototype() is called on a descriptor
    # that already has a concrete class. (The C++ implementation has its own
    # internal PyMessageFactory for the same purpose.)
    cached_class = getattr(descriptor, '_concrete_class', None)
    if cached_class:
      return cached_class

    full_name = descriptor.full_name
    if full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates a freshly allocated message class.

    Attaches per-field helpers, then adds enum constants, __init__, field and
    extension properties/constants, static methods, message methods and
    private helper methods to the class.

    Args:
      name: Name of the class (ignored, but required by the metaclass
        protocol).
      bases: Base classes of the class being constructed (ignored here, but
        required by the metaclass protocol).
      dictionary: The class dictionary. dictionary[_DESCRIPTOR_KEY] must hold
        the Descriptor describing this message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # type.__call__ still invokes __init__ when __new__ handed back the cached
    # class; that class is already fully initialized, so stop here.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    uses_message_set_format = (
        descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format)
    if uses_message_set_format:
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor),
          None,
      )

    # Decorate every field (and every known extension) descriptor with the
    # helpers used for fast lookup while parsing.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)
    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      for extension in descriptor.file.pool.FindAllExtensions(descriptor):
        _AttachFieldHelpers(cls, extension)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super().__init__(name, bases, dictionary)

194# Stateless helpers for GeneratedProtocolMessageType below. 

195# Outside clients should not access these directly. 

196# 

197# I opted not to make any of these methods on the metaclass, to make it more 

198# clear that I'm not really using any state there and to keep clients from 

199# thinking that they have direct access to these construction helpers. 

200 

201 

202def _PropertyName(proto_field_name): 

203 """Returns the name of the public property attribute which 

204 clients can use to get and (in some cases) set the value 

205 of a protocol message field. 

206 

207 Args: 

208 proto_field_name: The protocol message field name, exactly 

209 as it appears (or would appear) in a .proto file. 

210 """ 

211 # TODO: Escape Python keywords (e.g., yield), and test this support. 

212 # nnorwitz makes my day by writing: 

213 # """ 

214 # FYI. See the keyword module in the stdlib. This could be as simple as: 

215 # 

216 # if keyword.iskeyword(proto_field_name): 

217 # return proto_field_name + "_" 

218 # return proto_field_name 

219 # """ 

220 # Kenton says: The above is a BAD IDEA. People rely on being able to use 

221 # getattr() and setattr() to reflectively manipulate field values. If we 

222 # rename the properties, then every such user has to also make sure to apply 

223 # the same transformation. Note that currently if you name a field "yield", 

224 # you can still access it just fine using getattr/setattr -- it's not even 

225 # that cumbersome to do so. 

226 # TODO: Remove this method entirely if/when everyone agrees with my 

227 # position. 

228 return proto_field_name 

229 

230 

231def _AddSlots(message_descriptor, dictionary): 

232 """Adds a __slots__ entry to dictionary, containing the names of all valid 

233 attributes for this message type. 

234 

235 Args: 

236 message_descriptor: A Descriptor instance describing this message type. 

237 dictionary: Class dictionary to which we'll add a '__slots__' entry. 

238 """ 

239 dictionary['__slots__'] = ['_cached_byte_size', 

240 '_cached_byte_size_dirty', 

241 '_fields', 

242 '_unknown_fields', 

243 '_unknown_field_set', 

244 '_is_present_in_parent', 

245 '_listener', 

246 '_listener_for_children', 

247 '__weakref__', 

248 '_oneofs'] 

249 

250 

def _IsMessageSetExtension(field):
  """Tells whether ``field`` is an optional message extension of a MessageSet.

  A MessageSet container is a message type whose options set
  message_set_wire_format.
  """
  # Short-circuit on is_extension first: containing_type is only inspected for
  # extensions.
  extends_message_set = (
      field.is_extension and
      field.containing_type.has_options and
      field.containing_type.GetOptions().message_set_wire_format)
  return (extends_message_set and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)

257 

258 

def _IsMapField(field):
  """Tells whether ``field`` is a map field (a field of a map-entry message)."""
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  return field.message_type._is_map_entry  # pylint: disable=protected-access

262 

263 

def _IsMessageMapField(field):
  """Tells whether a map field's 'value' entry field is message-typed."""
  value_cpp_type = field.message_type.fields_by_name['value'].cpp_type
  return value_cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

267 

def _AttachFieldHelpers(cls, field_descriptor):
  """Prepares ``field_descriptor`` for use by the message class ``cls``.

  Caches a default-value constructor on the descriptor and records the
  field's tag bytes in ``cls._fields_by_tag`` as (descriptor, is_packed).
  The tag for the field's natural wire type is always registered. Repeated
  fields of a packable type also get a length-delimited (packed) tag, so
  adding packed=true remains wire compatible regardless of the field's options.
  """
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  field_type = field_descriptor.type
  tag_entries = [(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type], False)]
  if (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED and
      wire_format.IsTypePackable(field_type)):
    tag_entries.append((wire_format.WIRETYPE_LENGTH_DELIMITED, True))

  for wire_type, is_packed in tag_entries:
    tag_bytes = encoder.TagBytes(field_descriptor.number, wire_type)
    cls._fields_by_tag[tag_bytes] = (field_descriptor, is_packed)

286 

287 

def _MaybeAddEncoder(cls, field_descriptor):
  """Lazily attaches ``_encoder`` and ``_sizer`` to ``field_descriptor``.

  The encoder and sizer are chosen by field kind:
    * map fields use the map encoder/sizer;
    * MessageSet extensions use the MessageSet item encoder/sizer;
    * every other field uses the encoder and sizer that type_checkers
      registers for its type.
  Does nothing if the descriptor already has an ``_encoder``.

  Args:
    cls: Not used by this function.
    field_descriptor: The FieldDescriptor to decorate.
  """
  if hasattr(field_descriptor, '_encoder'):
    return

  repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  map_field = _IsMapField(field_descriptor)
  packed = field_descriptor.is_packed

  if map_field:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    number = field_descriptor.number
    field_encoder = encoder.MessageSetItemEncoder(number)
    sizer = encoder.MessageSetItemSizer(number)
  else:
    number = field_descriptor.number
    field_type = field_descriptor.type
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_type](
        number, repeated, packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_type](number, repeated, packed)

  field_descriptor._sizer = sizer
  # Assign _encoder last: its presence is what the early return above checks.
  field_descriptor._encoder = field_encoder

310 

311 

def _MaybeAddDecoder(cls, field_descriptor):
  """Lazily attaches a ``_decoders`` dict to ``field_descriptor``.

  The dict is keyed by ``is_packed`` (bool). An unpacked decoder is always
  registered. Repeated fields of a packable type additionally get a packed
  decoder, so both wire encodings parse regardless of the field's options.
  Does nothing if the descriptor already has ``_decoders``.

  Args:
    cls: Not used by this function.
    field_descriptor: The FieldDescriptor to decorate.
  """
  if hasattr(field_descriptor, '_decoders'):
    return

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_map_entry = _IsMapField(field_descriptor)
  helper_decoders = {}

  def AddDecoder(is_packed):
    """Builds the decoder for one packing mode and stores it in helper_decoders."""
    decode_type = field_descriptor.type
    # Enums that are not closed are decoded as plain int32 values.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      # The trailing flag is True when the field has no presence.
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)

    helper_decoders[is_packed] = field_decoder

  AddDecoder(False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(True)

  # Assign _decoders last: its presence is what the early return above checks.
  field_descriptor._decoders = helper_decoders

362 

363 

364def _AddClassAttributesForNestedExtensions(descriptor, dictionary): 

365 extensions = descriptor.extensions_by_name 

366 for extension_name, extension_field in extensions.items(): 

367 assert extension_name not in dictionary 

368 dictionary[extension_name] = extension_field 

369 

370 

def _AddEnumValues(descriptor, cls):
  """Exposes the enums nested in a message type as class attributes.

  For every enum type nested in the message:
    * ``cls.<EnumName>`` is set to an EnumTypeWrapper, which can name values;
    * ``cls.<VALUE_NAME>`` is set to each value's number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for nested_enum in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(nested_enum)
    setattr(cls, nested_enum.name, wrapper)
    for value_descriptor in nested_enum.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)

384 

385 

def _GetInitializeDefaultForMap(field):
  """Returns a factory that creates an empty map container for ``field``.

  The factory takes the owning message. It returns a MessageMap when the
  map's value type is a message, and a ScalarMap otherwise.

  Raises:
    ValueError: if ``field`` is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)

    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)

  return MakeMessageMapDefault

407 

def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field:
    * map fields get an empty map container;
    * repeated message/scalar fields get an empty repeated container;
    * singular message fields get a new sub-message listening to the parent;
    * singular scalar fields get the descriptor's default_value.
  The default value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if a repeated field declares a non-empty default value, or
      (via _GetInitializeDefaultForMap) if a map field is not repeated.
  """
  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set. (Depends on order in which we initialize the classes.)
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized; have message_factory
      # build the class for message_type on demand.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Members of a oneof listen through an _OneofListener (defined elsewhere
      # in this file); other sub-messages use the parent's child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault

464 

465 

466def _ReraiseTypeErrorWithFieldName(message_name, field_name): 

467 """Re-raise the currently-handled TypeError with the field name added.""" 

468 exc = sys.exc_info()[1] 

469 if len(exc.args) == 1 and type(exc) is TypeError: 

470 # simple TypeError; add field name to exception message 

471 exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) 

472 

473 # re-raise possibly-amended exception with original traceback: 

474 raise exc.with_traceback(sys.exc_info()[2]) 

475 

476 

def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets the message's internal bookkeeping, then
  assigns each keyword argument to the field of the same name:
    * repeated message fields: each item is added, either via add(**item)
      for dicts or by merging a message;
    * map fields: message maps merge each value, scalar maps use update();
    * repeated scalar fields: values are extended into the container;
    * singular message fields: merged from a message, or built from a dict;
    * singular scalar fields: go through the normal property setter.
  Singular and repeated scalar enum fields accept either the integer value or
  the value's name. A keyword whose value is None is ignored.

  The generated __init__ raises:
    ValueError: if a keyword does not name a field of the message, or an enum
      label is unknown.
    TypeError: (annotated with the field name) if a message or singular
      scalar value has the wrong type.

  Args:
    message_descriptor: Descriptor for the message type.
    cls: The class being constructed.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name. If the value is not a string, it's
    returned as-is. (No conversion or bounds-checking is done.)
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        # Suppress the internal KeyError; it adds nothing for the caller.
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value)) from None
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None  # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      # _GetFieldByName raises ValueError for names that are not fields of
      # this message, so `field` is never None below.
      field = _GetFieldByName(message_descriptor, field_name)
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init

562 

563 

564def _GetFieldByName(message_descriptor, field_name): 

565 """Returns a field descriptor by field name. 

566 

567 Args: 

568 message_descriptor: A Descriptor describing all fields in message. 

569 field_name: The name of the field to retrieve. 

570 Returns: 

571 The field descriptor associated with the field name. 

572 """ 

573 try: 

574 return message_descriptor.fields_by_name[field_name] 

575 except KeyError: 

576 raise ValueError('Protocol message %s has no "%s" field.' % 

577 (message_descriptor.name, field_name)) 

578 

579 

580def _AddPropertiesForFields(descriptor, cls): 

581 """Adds properties for all fields in this protocol message type.""" 

582 for field in descriptor.fields: 

583 _AddPropertiesForField(field, cls) 

584 

585 if descriptor.is_extendable: 

586 # _ExtensionDict is just an adaptor with no state so we allocate a new one 

587 # every time it is accessed. 

588 cls.Extensions = property(lambda self: _ExtensionDict(self)) 

589 

590 

def _AddPropertiesForField(field, cls):
  """Adds the public property and field-number constant for one field.

  Sets ``cls.<FIELD_NAME_UPPER>_FIELD_NUMBER`` to the field number, then
  installs a property matching the field's kind: repeated, singular
  composite, or singular scalar.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)

614 

615 

616class _FieldProperty(property): 

617 __slots__ = ('DESCRIPTOR',) 

618 

619 def __init__(self, descriptor, getter, setter, doc): 

620 property.__init__(self, getter, setter, doc=doc) 

621 self.DESCRIPTOR = descriptor 

622 

623 

def _AddPropertiesForRepeatedField(field, cls):
  """Adds a read-only property for a repeated field.

  Reading the property returns the field's container: a
  RepeatedScalarFieldContainer, a RepeatedCompositeFieldContainer, or a map,
  as produced by field._default_constructor. The container is created on
  first access. Assigning to the property raises AttributeError; the
  containers themselves type-check scalar values and set "has" bits.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    # Build the container lazily. setdefault() keeps whichever object was
    # stored first if another thread raced us. WARNING: this relies on
    # setdefault() being atomic, which holds in CPython but has not been
    # verified for other implementations.
    return self._fields.setdefault(field, field._default_constructor(self))

  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Exists only to give a clearer error than a plain read-only property.
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))

666 

667 

def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a read/write property for a singular scalar field.

  Reading returns the stored value, or the field's default_value when the
  field is unset. Writing type-checks the value and, when the byte-size cache
  is clean, marks the message modified. For fields without presence, a falsy
  value clears the field instead of storing it. For oneof members, a write
  also calls _UpdateOneofState.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)

  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    if field.has_presence or new_value:
      self._fields[field] = new_value
    else:
      # Without presence, a falsy value (0, 0.0, enum 0, False, empty
      # string/bytes) is represented by having no entry in _fields.
      self._fields.pop(field, None)
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))

722 

723 

def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a read-only property for a singular composite field.

  A composite field is a "group" or "message" field. Reading the property
  returns the sub-message, which is created on first access. Assigning to the
  property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is None:
      # Create the sub-message lazily. setdefault() keeps whichever object
      # was stored first if another thread raced us. WARNING: this relies on
      # setdefault() being atomic, which holds in CPython but has not been
      # verified for other implementations.
      sub_message = self._fields.setdefault(
          field, field._default_constructor(self))
    return sub_message

  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Exists only to give a clearer error than a plain read-only property.
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))

766 

767 

def _AddPropertiesForExtensions(descriptor, cls):
  """Adds <EXTENSION>_FIELD_NUMBER class constants for declared extensions.

  For every extension in ``descriptor.extensions_by_name``, sets a class
  attribute on ``cls`` named after the upper-cased extension name plus the
  ``_FIELD_NUMBER`` suffix, holding that extension's field number.

  Args:
    descriptor: The message Descriptor whose extensions_by_name is read.
    cls: The generated message class that receives the constants.
  """
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    setattr(cls, extension_name.upper() + '_FIELD_NUMBER',
            extension_field.number)
  # NOTE: A trailing `pool = descriptor.file.pool` that bound an unused local
  # was removed. Extension lookups are expected to go through the descriptor
  # pool (e.g. pool.FindExtensionByNumber) rather than class attributes.

780 

def _AddStaticMethods(cls):
  """Attaches the RegisterExtension and FromString static methods to cls."""

  def FromString(s):
    """Builds a fresh instance of cls and merges the serialized bytes s."""
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  # TODO: This probably needs to be thread-safe(?)
  def RegisterExtension(field_descriptor):
    """Binds field_descriptor to cls.DESCRIPTOR and adds it to the file pool."""
    field_descriptor.containing_type = cls.DESCRIPTOR
    # TODO: Use cls.MESSAGE_FACTORY.pool when available.
    # pylint: disable=protected-access
    cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(field_descriptor)
    _AttachFieldHelpers(cls, field_descriptor)

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)

796 

797 

def _IsPresent(item):
  """Decides whether a (FieldDescriptor, value) pair from _fields is listed.

  Repeated fields count only when non-empty, singular sub-messages only when
  they are marked present in their parent, and every other value always.
  """
  field, value = item[0], item[1]
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True

808 

809 

def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ListFields."""

  def ListFields(self):
    # Populated (descriptor, value) pairs, ordered by field number.
    return sorted(
        (pair for pair in self._fields.items() if _IsPresent(pair)),
        key=lambda pair: pair[0].number)

  cls.ListFields = ListFields

819 

820 

def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.HasField."""

  # Names accepted by HasField(): singular fields that track presence (for
  # proto3, only submessages and fields inside a oneof have presence), plus
  # every oneof of the message.
  hassable_fields = {
      field.name: field
      for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED and field.has_presence
  }
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    """Reports whether the field or oneof called field_name is set.

    Raises:
      ValueError: if field_name is neither a field with presence nor a oneof.
    """
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if not isinstance(field, descriptor_mod.OneofDescriptor):
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        return field in self._fields
      submessage = self._fields.get(field)
      return submessage is not None and submessage._is_present_in_parent

    # A oneof is set iff its currently active member field is set.
    try:
      return HasField(self, self._oneofs[field].name)
    except KeyError:
      return False

  cls.HasField = HasField

858 

859 

def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ClearField."""

  def ClearField(self, field_name):
    """Clears the field, or the active field of the oneof, named field_name.

    For a oneof name with no member currently set, this returns without
    marking the message modified.

    Raises:
      ValueError: if the message has neither a field nor a oneof with that
        name.
    """
    field = message_descriptor.fields_by_name.get(field_name)
    if field is None:
      oneof = message_descriptor.oneofs_by_name.get(field_name)
      if oneof is None:
        # Raised outside any except-clause so the traceback is not cluttered
        # with the unrelated KeyErrors of the two name lookups.
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))
      field = self._oneofs.get(oneof)
      if field is None:
        return

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us. That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size. Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField

896 

897 

def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.ClearExtension."""

  def ClearExtension(self, field_descriptor):
    """Drops any stored value for the extension, then marks self modified."""
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Similar to ClearField(): remove if present, always mark modified.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension

908 

909 

def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.HasExtension."""

  def HasExtension(self, field_descriptor):
    """Reports whether the singular extension field_descriptor is set.

    Raises:
      KeyError: if field_descriptor is a repeated extension.
    """
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields
    submessage = self._fields.get(field_descriptor)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension

923 

def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg has an empty type_url or its type
    cannot be found in the default descriptor pool.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO: For now we just strip the hostname. Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # DescriptorPool.FindMessageTypeByName() reports an unknown type by
    # raising KeyError rather than returning None. Treat that as "cannot
    # unpack" so that __eq__ falls back to comparing the raw Any fields
    # instead of letting the KeyError escape from ==.
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message

962 

963 

def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__eq__."""

  def __eq__(self, other):
    """Compares set fields and unknown fields of two same-typed messages."""
    # Defer to the other operand for non-messages and other message types.
    if not isinstance(other, message_mod.Message):
      return NotImplemented
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return NotImplemented
    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      # Two Any messages whose payloads both unpack are compared by payload.
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if self.ListFields() != other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__

992 

993 

def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): makes str(msg) render text format."""

  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__

999 

1000 

def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): makes repr(msg) render text format."""

  def __repr__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__

1006 

1007 

def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__unicode__."""

  def __unicode__(self):
    """Returns the UTF-8 text-format rendering of the message as str."""
    text = text_format.MessageToString(self, as_utf8=True)
    # Decode only when bytes come back. On Python 3 the text format is
    # expected to already be str, and str has no .decode() method.
    if isinstance(text, bytes):
      text = text.decode('utf-8')
    return text
  cls.__unicode__ = __unicode__

1014 

1015 

def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The serialized size reported by the size function for field_type.

  Raises:
    message_mod.EncodeError: if field_type has no registered size function.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    # Only the table lookup is guarded: a KeyError raised by the size
    # function itself is a genuine bug and must not be misreported as an
    # unrecognized field type.
    raise message_mod.EncodeError(
        'Unrecognized field type: %d' % field_type) from None
  return fn(field_number, value)

1034 

1035 

def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ByteSize."""

  def ByteSize(self):
    """Returns the serialized size, recomputing it only when marked dirty."""
    if self._cached_byte_size_dirty:
      descriptor = self.DESCRIPTOR
      size = 0
      if descriptor._is_map_entry:
        # Fields of map entry should always be serialized.
        for attr_name in ('key', 'value'):
          entry_field = descriptor.fields_by_name[attr_name]
          _MaybeAddEncoder(cls, entry_field)
          size += entry_field._sizer(getattr(self, attr_name))
      else:
        for field_descriptor, field_value in self.ListFields():
          _MaybeAddEncoder(cls, field_descriptor)
          size += field_descriptor._sizer(field_value)
        size += sum(len(tag_bytes) + len(value_bytes)
                    for tag_bytes, value_bytes in self._unknown_fields)

      self._cached_byte_size = size
      self._cached_byte_size_dirty = False
      self._listener_for_children.dirty = False
    return self._cached_byte_size

  cls.ByteSize = ByteSize

1066 

1067 

def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.SerializeToString."""

  def SerializeToString(self, **kwargs):
    # Serialization is refused unless every required field is set.
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name,
            ','.join(self.FindInitializationErrors())))
  cls.SerializeToString = SerializeToString

1079 

1080 

def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializePartialToString and
  _InternalSerialize."""

  def SerializePartialToString(self, **kwargs):
    # Collect everything _InternalSerialize writes into one bytes value.
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()
  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes, deterministic=None):
    deterministic = (
        api_implementation.IsPythonDefaultSerializationDeterministic()
        if deterministic is None else bool(deterministic))

    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      for attr_name in ('key', 'value'):
        entry_field = descriptor.fields_by_name[attr_name]
        _MaybeAddEncoder(cls, entry_field)
        entry_field._encoder(write_bytes, getattr(self, attr_name),
                             deterministic)
      return

    for field_descriptor, field_value in self.ListFields():
      _MaybeAddEncoder(cls, field_descriptor)
      field_descriptor._encoder(write_bytes, field_value, deterministic)
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)
  cls._InternalSerialize = InternalSerialize

1114 

1115 

def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs MergeFromString and _InternalParse.

  Args:
    message_descriptor: The message Descriptor (not read by this helper).
    cls: The generated message class. Its _fields_by_tag and
      _message_set_decoders_by_tag tables are captured when this runs.
  """
  def MergeFromString(self, serialized):
    """Parses serialized data and merges it into this message.

    Args:
      serialized: A bytes-like object; it is wrapped in a memoryview.

    Returns:
      len(serialized), kept for legacy reasons.

    Raises:
      message_mod.DecodeError: if parsing stops before the end of the data
        (an end-group tag), the data is truncated, or struct decoding fails.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length  # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind helpers and the per-class tag lookup tables to locals once, so the
  # parse loop below does not repeat these lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses buffer[pos:end] and merges the decoded fields into self.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position where parsing stopped: end when the whole range was
      consumed, or the start of the tag at which an unknown-field decode or
      skip returned -1 (MergeFromString treats this as an end-group tag).
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      # Tags with a MessageSet decoder are handled by that decoder first.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        continue
      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is None:
        # Unknown tag: record it both in the UnknownFieldSet and in the
        # legacy _unknown_fields list of (tag_bytes, raw bytes) pairs.
        if not self._unknown_fields:  # pylint: disable=protected-access
          self._unknown_fields = []  # pylint: disable=protected-access
        if unknown_field_set is None:
          # pylint: disable=protected-access
          self._unknown_field_set = containers.UnknownFieldSet()
          # pylint: disable=protected-access
          unknown_field_set = self._unknown_field_set
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO: remove old_pos.
        old_pos = new_pos
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          # Stop and report the start of this tag to the caller.
          return pos
        # pylint: disable=protected-access
        unknown_field_set._add(field_number, wire_type, data)
        # TODO: remove _unknown_fields.
        # The same field is skipped again from old_pos to find the raw byte
        # span stored in the legacy list.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        _MaybeAddDecoder(cls, field_des)
        # _decoders is indexed by the is_packed flag from fields_by_tag.
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
    return pos
  cls._InternalParse = InternalParse

1203 

1204 

def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  # Computed once per class: the fields declared with LABEL_REQUIRED.
  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          # NOTE(review): map fields are skipped entirely here, although
          # FindInitializationErrors() does inspect message-valued maps.
          if field.message_type._is_map_entry:
            continue
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings. Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = []

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.is_extension:
        name = '(%s)' % field.full_name
      else:
        name = field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            prefix = '%s[%s].' % (name, key)
            errors.extend(prefix + error
                          for error in value[key].FindInitializationErrors())
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          prefix = '%s[%d].' % (name, index)
          errors.extend(prefix + error
                        for error in element.FindInitializationErrors())
      else:
        prefix = name + '.'
        errors.extend(prefix + error
                      for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors

1297 

1298 

def _FullyQualifiedClassName(klass):
  """Returns 'module.QualName' for klass, or the bare name for builtins."""
  owner_module = klass.__module__
  qualified = getattr(klass, '__qualname__', klass.__name__)
  is_builtin = owner_module in (None, 'builtins', '__builtin__')
  return qualified if is_builtin else owner_module + '.' + qualified

1305 

1306 

def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.MergeFrom."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the set fields and unknown fields of msg into self.

    Repeated fields and present sub-messages are merged into ours. Other
    singular values overwrite ours and update oneof state.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    own_fields = self._fields

    def _GetOrCreate(field):
      # Our container/sub-message for field, constructed on first use.
      existing = own_fields.get(field)
      if existing is None:
        existing = field._default_constructor(self)
        own_fields[field] = existing
      return existing

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        _GetOrCreate(field).MergeFrom(value)
      elif field.cpp_type != CPPTYPE_MESSAGE:
        own_fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)
      elif value._is_present_in_parent:
        _GetOrCreate(field).MergeFrom(value)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)
      # pylint: disable=protected-access
      if self._unknown_field_set is None:
        self._unknown_field_set = containers.UnknownFieldSet()
      self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom

1354 

1355 

def _AddWhichOneofMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.WhichOneof."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof

1372 

1373 

def _Clear(self):
  """Resets all fields, unknown fields and oneofs, then marks self modified."""
  self._fields = {}
  self._unknown_fields = ()
  # pylint: disable=protected-access
  stale_set = self._unknown_field_set
  if stale_set is not None:
    stale_set._clear()
    self._unknown_field_set = None

  self._oneofs = {}
  self._Modified()

1385 

1386 

def _UnknownFields(self):
  """Deprecated: returns the message's UnknownFieldSet, creating it lazily.

  Emits a warning on every call. stacklevel=2 attributes the warning to the
  code that called msg.UnknownFields() rather than to this module.
  """
  warnings.warn(
      'message.UnknownFields() is deprecated. Please use the add one '
      'feature unknown_fields.UnknownFieldSet(message) in '
      'unknown_fields.py instead.',
      stacklevel=2,
  )
  if self._unknown_field_set is None:  # pylint: disable=protected-access
    # pylint: disable=protected-access
    self._unknown_field_set = containers.UnknownFieldSet()
  return self._unknown_field_set  # pylint: disable=protected-access

1397 

1398 

def _DiscardUnknownFields(self):
  """Drops unknown fields here and, recursively, in all sub-messages."""
  self._unknown_fields = []
  self._unknown_field_set = None  # pylint: disable=protected-access
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Only message-valued maps hold sub-messages to recurse into.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()

1413 

1414 

def _SetListener(self, listener):
  """Installs listener, substituting a NullMessageListener for None."""
  self._listener = (message_listener_mod.NullMessageListener()
                    if listener is None else listener)

1420 

1421 

def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)
  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)
  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)
  # Methods that do not depend on cls are shared module-level functions.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener

1446 

1447 

def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds _Modified, SetInParent and _UpdateOneofState to cls."""

  def Modified(self):
    """Marks the cached byte size dirty and tells our listener, but only on
    the clean-to-dirty transition.
    """

    # Note: Some callers check _cached_byte_size_dirty before calling
    # _Modified() as an extra optimization. So, if this method is ever
    # changed such that it does stuff even when _cached_byte_size_dirty is
    # already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is not field:
      del self._fields[previous]
      self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState

1480 

1481 

class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The back reference from child to parent is a weak proxy, so finishing
    # with the parent does not leave cyclic garbage behind. An existing proxy
    # is reused as-is.
    already_proxy = isinstance(parent_message, weakref.ProxyType)
    self._parent_message_weakref = (
        parent_message if already_proxy else weakref.proxy(parent_message))

    # Mirrors whether the parent is already dirty, so the common case can
    # skip walking up the tree.
    self.dirty = False

  def Modified(self):
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # A client kept a reference to a child whose parent has since been
        # garbage-collected and is now setting a field on it. Not an error.
        pass

1526 

1527 

class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # The parent message has already been garbage-collected.
      pass