Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/protobuf/descriptor_pool.py: 15%

469 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:40 +0000

1# Protocol Buffers - Google's data interchange format 

2# Copyright 2008 Google Inc. All rights reserved. 

3# 

4# Use of this source code is governed by a BSD-style 

5# license that can be found in the LICENSE file or at 

6# https://developers.google.com/open-source/licenses/bsd 

7 

8"""Provides DescriptorPool to use as a container for proto2 descriptors. 

9 

10The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain 

11a collection of protocol buffer descriptors for use when dynamically creating 

12message types at runtime. 

13 

14For most applications protocol buffers should be used via modules generated by 

15the protocol buffer compiler tool. This should only be used when the type of 

16protocol buffers used in an application or library cannot be predetermined. 

17 

18Below is a straightforward example on how to use this class:: 

19 

20 pool = DescriptorPool() 

21 file_descriptor_protos = [ ... ] 

22 for file_descriptor_proto in file_descriptor_protos: 

23 pool.Add(file_descriptor_proto) 

24 my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType') 

25 

26The message descriptor can be used in conjunction with the message_factory 

27module in order to create a protocol buffer class that can be encoded and 

28decoded. 

29 

30If you want to get a Python class for the specified proto, use the 

31helper functions inside google.protobuf.message_factory 

32directly instead of this class. 

33""" 

34 

35__author__ = 'matthewtoia@google.com (Matt Toia)' 

36 

37import collections 

38import warnings 

39 

40from google.protobuf import descriptor 

41from google.protobuf import descriptor_database 

42from google.protobuf import text_encoding 

43from google.protobuf.internal import python_message 

44 

45_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access 

46 

47 

48def _Deprecated(func): 

49 """Mark functions as deprecated.""" 

50 

51 def NewFunc(*args, **kwargs): 

52 warnings.warn( 

53 'Call to deprecated function %s(). Note: Do add unlinked descriptors ' 

54 'to descriptor_pool is wrong. Please use Add() or AddSerializedFile() ' 

55 'instead. This function will be removed soon.' % func.__name__, 

56 category=DeprecationWarning) 

57 return func(*args, **kwargs) 

58 NewFunc.__name__ = func.__name__ 

59 NewFunc.__doc__ = func.__doc__ 

60 NewFunc.__dict__.update(func.__dict__) 

61 return NewFunc 

62 

63 

64def _NormalizeFullyQualifiedName(name): 

65 """Remove leading period from fully-qualified type name. 

66 

67 Due to b/13860351 in descriptor_database.py, types in the root namespace are 

68 generated with a leading period. This function removes that prefix. 

69 

70 Args: 

71 name (str): The fully-qualified symbol name. 

72 

73 Returns: 

74 str: The normalized fully-qualified symbol name. 

75 """ 

76 return name.lstrip('.') 

77 

78 

79def _OptionsOrNone(descriptor_proto): 

80 """Returns the value of the field `options`, or None if it is not set.""" 

81 if descriptor_proto.HasField('options'): 

82 return descriptor_proto.options 

83 else: 

84 return None 

85 

86 

87def _IsMessageSetExtension(field): 

88 return (field.is_extension and 

89 field.containing_type.has_options and 

90 field.containing_type.GetOptions().message_set_wire_format and 

91 field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and 

92 field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) 

93 

94 

95class DescriptorPool(object): 

96 """A collection of protobufs dynamically constructed by descriptor protos.""" 

97 

if _USE_C_DESCRIPTORS:

  def __new__(cls, descriptor_db=None):
    """Builds the pool from `descriptor._message.DescriptorPool` instead.

    Only defined when `_USE_C_DESCRIPTORS` is set.

    Args:
      descriptor_db: Forwarded unchanged to the `descriptor._message` pool.

    Returns:
      The object created by `descriptor._message.DescriptorPool`.
    """
    # pylint: disable=protected-access
    pool_factory = descriptor._message.DescriptorPool
    return pool_factory(descriptor_db)

103 

def __init__(
    self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
):
  """Sets up an empty pool.

  `descriptor_db` lets specialized file descriptor proto lookup code be
  triggered on demand, for example an implementation that reads and compiles
  a file named in a call to FindFileByName() so that Add() is never needed.
  Results obtained from that database are cached in this pool as well.

  Args:
    descriptor_db: A secondary source of file descriptors.
    use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
      C++.
  """
  self._internal_db = descriptor_database.DescriptorDatabase()
  self._descriptor_db = descriptor_db

  # One name-keyed registry per kind of symbol; each starts out empty.
  for registry_attr in ('_descriptors',
                        '_enum_descriptors',
                        '_service_descriptors',
                        '_file_descriptors',
                        '_toplevel_extensions',
                        '_top_enum_values'):
    setattr(self, registry_attr, {})

  # Extensions live in two two-level mappings. The outer key is the descriptor
  # of the extended message; the inner key is the extension's full name or its
  # tag number respectively.
  self._extensions_by_name = collections.defaultdict(dict)
  self._extensions_by_number = collections.defaultdict(dict)

134 

def _CheckConflictRegister(self, desc, desc_name, file_name):
  """Raises if `desc_name` is already registered in a conflicting way.

  Re-registering a name is accepted only when the earlier entry sits in the
  register matching the type of `desc` and came from the same file.

  Args:
    desc: Descriptor of a message, enum, service, extension or enum value.
    desc_name (str): the full name of desc.
    file_name (str): The file name of descriptor.

  Raises:
    TypeError: if the name is already taken by a descriptor of another kind
      or by a descriptor defined in another file.
  """
  registers = (
      (self._descriptors, descriptor.Descriptor),
      (self._enum_descriptors, descriptor.EnumDescriptor),
      (self._service_descriptors, descriptor.ServiceDescriptor),
      (self._toplevel_extensions, descriptor.FieldDescriptor),
      (self._top_enum_values, descriptor.EnumValueDescriptor),
  )
  for register, expected_type in registers:
    if desc_name not in register:
      continue

    existing = register[desc_name]
    # Enum values reach their file through the enum type that owns them.
    if isinstance(existing, descriptor.EnumValueDescriptor):
      existing_file = existing.type.file.name
    else:
      existing_file = existing.file.name

    if isinstance(desc, expected_type) and existing_file == file_name:
      # Same kind of symbol from the same file: nothing to report.
      return

    error_msg = ('Conflict register for file "' + file_name +
                 '": ' + desc_name +
                 ' is already defined in file "' +
                 existing_file + '". Please fix the conflict by adding '
                 'package name on the proto file, or use different '
                 'name for the duplication.')
    if isinstance(desc, descriptor.EnumValueDescriptor):
      error_msg += ('\nNote: enum values appear as '
                    'siblings of the enum type instead of '
                    'children of it.')
    raise TypeError(error_msg)

172 

def Add(self, file_desc_proto):
  """Hands the FileDescriptorProto to this pool's internal database.

  Args:
    file_desc_proto (FileDescriptorProto): The file descriptor to add.
  """
  internal_db = self._internal_db
  internal_db.Add(file_desc_proto)

181 

def AddSerializedFile(self, serialized_file_desc_proto):
  """Parses a serialized FileDescriptorProto and builds it into this pool.

  Args:
    serialized_file_desc_proto (bytes): A bytes string, serialization of the
      :class:`FileDescriptorProto` to add.

  Returns:
    FileDescriptor: Descriptor for the added file.
  """
  # pylint: disable=g-import-not-at-top
  from google.protobuf import descriptor_pb2

  parsed_proto = descriptor_pb2.FileDescriptorProto.FromString(
      serialized_file_desc_proto)
  built_file = self._ConvertFileProtoToFileDescriptor(parsed_proto)
  # Keep the caller's exact bytes on the descriptor.
  built_file.serialized_pb = serialized_file_desc_proto
  return built_file

200 

201 # Add Descriptor to descriptor pool is deprecated. Please use Add() 

202 # or AddSerializedFile() to add a FileDescriptorProto instead. 

@_Deprecated
def AddDescriptor(self, desc):
  """Deprecated entry point that forwards `desc` to `_AddDescriptor`."""
  register = self._AddDescriptor
  register(desc)

206 

207 # Never call this method. It is for internal usage only. 

def _AddDescriptor(self, desc):
  """Registers a message Descriptor in the pool, non-recursively.

  Nested messages and enums are not registered here; the caller must add them
  explicitly. The FileDescriptor of the message is registered as well.

  Args:
    desc: A Descriptor.

  Raises:
    TypeError: if `desc` is not a descriptor.Descriptor, or if its name
      conflicts with an existing registration.
  """
  if not isinstance(desc, descriptor.Descriptor):
    raise TypeError('Expected instance of descriptor.Descriptor.')

  message_name = desc.full_name
  self._CheckConflictRegister(desc, message_name, desc.file.name)
  self._descriptors[message_name] = desc
  self._AddFileDescriptor(desc.file)

225 

226 # Never call this method. It is for internal usage only. 

def _AddEnumDescriptor(self, enum_desc):
  """Registers an EnumDescriptor, and its values when the enum is top level.

  The FileDescriptor associated with the enum is registered as well.

  Args:
    enum_desc: An EnumDescriptor.

  Raises:
    TypeError: if `enum_desc` is not a descriptor.EnumDescriptor, or if a
      registered name conflicts with an existing registration.
  """
  if not isinstance(enum_desc, descriptor.EnumDescriptor):
    raise TypeError('Expected instance of descriptor.EnumDescriptor.')

  enum_full_name = enum_desc.full_name
  file_name = enum_desc.file.name
  self._CheckConflictRegister(enum_desc, enum_full_name, file_name)
  self._enum_descriptors[enum_full_name] = enum_desc

  # Top enum values need to be indexed.
  # Whether the enum is top level or nested in a message is decided by
  # counting dots, because enum_desc.containing_type cannot be used yet.
  package = enum_desc.file.package
  name_dots = enum_full_name.count('.')
  if package:
    is_top_level = name_dots - package.count('.') == 1
  else:
    is_top_level = name_dots == 0

  if is_top_level:
    for enum_value in enum_desc.values:
      value_full_name = _NormalizeFullyQualifiedName(
          '.'.join((package, enum_value.name)))
      self._CheckConflictRegister(enum_value, value_full_name, file_name)
      self._top_enum_values[value_full_name] = enum_value

  self._AddFileDescriptor(enum_desc.file)

260 

261 # Add ServiceDescriptor to descriptor pool is deprecated. Please use Add() 

262 # or AddSerializedFile() to add a FileDescriptorProto instead. 

@_Deprecated
def AddServiceDescriptor(self, service_desc):
  """Deprecated entry point forwarding to `_AddServiceDescriptor`."""
  register = self._AddServiceDescriptor
  register(service_desc)

266 

267 # Never call this method. It is for internal usage only. 

def _AddServiceDescriptor(self, service_desc):
  """Registers a ServiceDescriptor under its full name.

  Args:
    service_desc: A ServiceDescriptor.

  Raises:
    TypeError: if `service_desc` is not a descriptor.ServiceDescriptor, or if
      its name conflicts with an existing registration.
  """
  if not isinstance(service_desc, descriptor.ServiceDescriptor):
    raise TypeError('Expected instance of descriptor.ServiceDescriptor.')

  service_name = service_desc.full_name
  self._CheckConflictRegister(
      service_desc, service_name, service_desc.file.name)
  self._service_descriptors[service_name] = service_desc

281 

282 # Add ExtensionDescriptor to descriptor pool is deprecated. Please use Add() 

283 # or AddSerializedFile() to add a FileDescriptorProto instead. 

@_Deprecated
def AddExtensionDescriptor(self, extension):
  """Deprecated entry point forwarding to `_AddExtensionDescriptor`."""
  register = self._AddExtensionDescriptor
  register(extension)

287 

288 # Never call this method. It is for internal usage only. 

def _AddExtensionDescriptor(self, extension):
  """Registers a FieldDescriptor that describes an extension.

  Args:
    extension: A FieldDescriptor.

  Raises:
    AssertionError: when another extension with the same number extends the
      same message.
    TypeError: when the specified extension is not a
      descriptor.FieldDescriptor.
  """
  is_extension_field = (isinstance(extension, descriptor.FieldDescriptor) and
                        extension.is_extension)
  if not is_extension_field:
    raise TypeError('Expected an extension descriptor.')

  # Extensions without an extension scope are also indexed by full name.
  if extension.extension_scope is None:
    self._CheckConflictRegister(
        extension, extension.full_name, extension.file.name)
    self._toplevel_extensions[extension.full_name] = extension

  extended_type = extension.containing_type
  by_number = self._extensions_by_number[extended_type]
  if extension.number in by_number:
    registered = by_number[extension.number]
    # Registering the very same descriptor again is harmless.
    if extension is not registered:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" '
          'with field number %d.' %
          (extension.full_name, registered.full_name,
           extended_type.full_name, extension.number))

  by_number[extension.number] = extension
  by_name = self._extensions_by_name[extended_type]
  by_name[extension.full_name] = extension

  # Also register MessageSet extensions with the type name.
  if _IsMessageSetExtension(extension):
    by_name[extension.message_type.full_name] = extension

  if hasattr(extended_type, '_concrete_class'):
    python_message._AttachFieldHelpers(
        extended_type._concrete_class, extension)

336 

@_Deprecated
def AddFileDescriptor(self, file_desc):
  """Deprecated entry point forwarding to `_InternalAddFileDescriptor`."""
  register = self._InternalAddFileDescriptor
  register(file_desc)

340 

341 # Never call this method. It is for internal usage only. 

342 def _InternalAddFileDescriptor(self, file_desc): 

343 """Adds a FileDescriptor to the pool, non-recursively. 

344 

345 If the FileDescriptor contains messages or enums, the caller must explicitly 

346 register them. 

347 

348 Args: 

349 file_desc: A FileDescriptor. 

350 """ 

351 

352 self._AddFileDescriptor(file_desc) 

353 

def _AddFileDescriptor(self, file_desc):
  """Stores a FileDescriptor under its file name, non-recursively.

  Messages and enums contained in the file are not registered here; the
  caller must add them explicitly.

  Args:
    file_desc: A FileDescriptor.

  Raises:
    TypeError: if `file_desc` is not a descriptor.FileDescriptor.
  """
  if isinstance(file_desc, descriptor.FileDescriptor):
    self._file_descriptors[file_desc.name] = file_desc
    return
  raise TypeError('Expected instance of descriptor.FileDescriptor.')

367 

def FindFileByName(self, file_name):
  """Gets a FileDescriptor by file name.

  Lookup order: already built descriptors, then the internal database, then
  the optional fallback `descriptor_db`.

  Args:
    file_name (str): The path to the file to get a descriptor for.

  Returns:
    FileDescriptor: The descriptor for the named file.

  Raises:
    KeyError: if the file cannot be found in the pool.
  """
  built_files = self._file_descriptors
  if file_name in built_files:
    return built_files[file_name]

  try:
    file_proto = self._internal_db.FindFileByName(file_name)
  except KeyError:
    fallback_db = self._descriptor_db
    if not fallback_db:
      # No secondary source: report the internal database's miss.
      raise
    file_proto = fallback_db.FindFileByName(file_name)

  if not file_proto:
    raise KeyError('Cannot find a file named %s' % file_name)
  return self._ConvertFileProtoToFileDescriptor(file_proto)

396 

def FindFileContainingSymbol(self, symbol):
  """Gets the FileDescriptor for the file containing the specified symbol.

  Args:
    symbol (str): The name of the symbol to search for.

  Returns:
    FileDescriptor: Descriptor for the file that contains the specified
    symbol.

  Raises:
    KeyError: if the file cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(symbol)
  try:
    return self._InternalFindFileContainingSymbol(normalized)
  except KeyError:
    # Not built yet; fall through to the database lookup below.
    pass

  try:
    # Try fallback database. Build and find again if possible.
    self._FindFileContainingSymbolInDb(normalized)
    containing_file = self._InternalFindFileContainingSymbol(normalized)
  except KeyError:
    raise KeyError('Cannot find a file containing %s' % normalized)
  return containing_file

423 

424 def _InternalFindFileContainingSymbol(self, symbol): 

425 """Gets the already built FileDescriptor containing the specified symbol. 

426 

427 Args: 

428 symbol (str): The name of the symbol to search for. 

429 

430 Returns: 

431 FileDescriptor: Descriptor for the file that contains the specified 

432 symbol. 

433 

434 Raises: 

435 KeyError: if the file cannot be found in the pool. 

436 """ 

437 try: 

438 return self._descriptors[symbol].file 

439 except KeyError: 

440 pass 

441 

442 try: 

443 return self._enum_descriptors[symbol].file 

444 except KeyError: 

445 pass 

446 

447 try: 

448 return self._service_descriptors[symbol].file 

449 except KeyError: 

450 pass 

451 

452 try: 

453 return self._top_enum_values[symbol].type.file 

454 except KeyError: 

455 pass 

456 

457 try: 

458 return self._toplevel_extensions[symbol].file 

459 except KeyError: 

460 pass 

461 

462 # Try fields, enum values and nested extensions inside a message. 

463 top_name, _, sub_name = symbol.rpartition('.') 

464 try: 

465 message = self.FindMessageTypeByName(top_name) 

466 assert (sub_name in message.extensions_by_name or 

467 sub_name in message.fields_by_name or 

468 sub_name in message.enum_values_by_name) 

469 return message.file 

470 except (KeyError, AssertionError): 

471 raise KeyError('Cannot find a file containing %s' % symbol) 

472 

def FindMessageTypeByName(self, full_name):
  """Loads the named message descriptor from the pool.

  Args:
    full_name (str): The full name of the descriptor to load.

  Returns:
    Descriptor: The descriptor for the named type.

  Raises:
    KeyError: if the message cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(full_name)
  already_built = normalized in self._descriptors
  if not already_built:
    # Building the containing file is what registers the message.
    self._FindFileContainingSymbolInDb(normalized)
  return self._descriptors[normalized]

490 

def FindEnumTypeByName(self, full_name):
  """Loads the named enum descriptor from the pool.

  Args:
    full_name (str): The full name of the enum descriptor to load.

  Returns:
    EnumDescriptor: The enum descriptor for the named type.

  Raises:
    KeyError: if the enum cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(full_name)
  already_built = normalized in self._enum_descriptors
  if not already_built:
    # Building the containing file is what registers the enum.
    self._FindFileContainingSymbolInDb(normalized)
  return self._enum_descriptors[normalized]

508 

def FindFieldByName(self, full_name):
  """Loads the named field descriptor from the pool.

  Args:
    full_name (str): The full name of the field descriptor to load.

  Returns:
    FieldDescriptor: The field descriptor for the named field.

  Raises:
    KeyError: if the field cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(full_name)
  # Everything before the last '.' names the message owning the field.
  name_parts = normalized.rpartition('.')
  owner = self.FindMessageTypeByName(name_parts[0])
  return owner.fields_by_name[name_parts[2]]

525 

def FindOneofByName(self, full_name):
  """Loads the named oneof descriptor from the pool.

  Args:
    full_name (str): The full name of the oneof descriptor to load.

  Returns:
    OneofDescriptor: The oneof descriptor for the named oneof.

  Raises:
    KeyError: if the oneof cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(full_name)
  # Everything before the last '.' names the message owning the oneof.
  name_parts = normalized.rpartition('.')
  owner = self.FindMessageTypeByName(name_parts[0])
  return owner.oneofs_by_name[name_parts[2]]

542 

def FindExtensionByName(self, full_name):
  """Loads the named extension descriptor from the pool.

  Args:
    full_name (str): The full name of the extension descriptor to load.

  Returns:
    FieldDescriptor: The field descriptor for the named extension.

  Raises:
    KeyError: if the extension cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(full_name)

  # The proto compiler does not give any link between the FileDescriptor
  # and top-level extensions unless the FileDescriptorProto is added to
  # the DescriptorDatabase, but this can impact memory usage.
  # So these extensions were registered by name explicitly.
  top_level = self._toplevel_extensions
  if normalized in top_level:
    return top_level[normalized]

  scope_name, _, short_name = normalized.rpartition('.')
  try:
    # Most extensions are nested inside a message.
    owner = self.FindMessageTypeByName(scope_name)
  except KeyError:
    # Some extensions are defined at file scope.
    owner = self._FindFileContainingSymbolInDb(normalized)
  return owner.extensions_by_name[short_name]

572 

def FindExtensionByNumber(self, message_descriptor, number):
  """Gets the extension of the specified message with the specified number.

  Extensions have to be registered to this pool by calling :func:`Add` or
  :func:`AddExtensionDescriptor`.

  Args:
    message_descriptor (Descriptor): descriptor of the extended message.
    number (int): Number of the extension field.

  Returns:
    FieldDescriptor: The descriptor for the extension.

  Raises:
    KeyError: when no extension with the given number is known for the
      specified message.
  """
  registry = self._extensions_by_number
  try:
    return registry[message_descriptor][number]
  except KeyError:
    # Unknown so far: let the fallback database try to load it, then retry.
    self._TryLoadExtensionFromDB(message_descriptor, number)
  return registry[message_descriptor][number]

595 

def FindAllExtensions(self, message_descriptor):
  """Gets all the known extensions of a given message.

  Extensions have to be registered to this pool by build related
  :func:`Add` or :func:`AddExtensionDescriptor`.

  Args:
    message_descriptor (Descriptor): Descriptor of the extended message.

  Returns:
    list[FieldDescriptor]: Field descriptors describing the extensions.
  """
  fallback_db = self._descriptor_db
  # Fallback to descriptor db if FindAllExtensionNumbers is provided.
  if fallback_db and hasattr(fallback_db, 'FindAllExtensionNumbers'):
    extension_numbers = fallback_db.FindAllExtensionNumbers(
        message_descriptor.full_name)
    for number in extension_numbers:
      # Only numbers this pool has not registered yet need loading.
      if number not in self._extensions_by_number[message_descriptor]:
        self._TryLoadExtensionFromDB(message_descriptor, number)

  known_extensions = self._extensions_by_number[message_descriptor]
  return list(known_extensions.values())

619 

620 def _TryLoadExtensionFromDB(self, message_descriptor, number): 

621 """Try to Load extensions from descriptor db. 

622 

623 Args: 

624 message_descriptor: descriptor of the extended message. 

625 number: the extension number that needs to be loaded. 

626 """ 

627 if not self._descriptor_db: 

628 return 

629 # Only supported when FindFileContainingExtension is provided. 

630 if not hasattr( 

631 self._descriptor_db, 'FindFileContainingExtension'): 

632 return 

633 

634 full_name = message_descriptor.full_name 

635 file_proto = self._descriptor_db.FindFileContainingExtension( 

636 full_name, number) 

637 

638 if file_proto is None: 

639 return 

640 

641 try: 

642 self._ConvertFileProtoToFileDescriptor(file_proto) 

643 except: 

644 warn_msg = ('Unable to load proto file %s for extension number %d.' % 

645 (file_proto.name, number)) 

646 warnings.warn(warn_msg, RuntimeWarning) 

647 

def FindServiceByName(self, full_name):
  """Loads the named service descriptor from the pool.

  Args:
    full_name (str): The full name of the service descriptor to load.

  Returns:
    ServiceDescriptor: The service descriptor for the named service.

  Raises:
    KeyError: if the service cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(full_name)
  already_built = normalized in self._service_descriptors
  if not already_built:
    # Building the containing file is what registers the service.
    self._FindFileContainingSymbolInDb(normalized)
  return self._service_descriptors[normalized]

664 

def FindMethodByName(self, full_name):
  """Loads the named service method descriptor from the pool.

  Args:
    full_name (str): The full name of the method descriptor to load.

  Returns:
    MethodDescriptor: The method descriptor for the service method.

  Raises:
    KeyError: if the method cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(full_name)
  # Everything before the last '.' names the service owning the method.
  name_parts = normalized.rpartition('.')
  owner = self.FindServiceByName(name_parts[0])
  return owner.methods_by_name[name_parts[2]]

681 

682 def _FindFileContainingSymbolInDb(self, symbol): 

683 """Finds the file in descriptor DB containing the specified symbol. 

684 

685 Args: 

686 symbol (str): The name of the symbol to search for. 

687 

688 Returns: 

689 FileDescriptor: The file that contains the specified symbol. 

690 

691 Raises: 

692 KeyError: if the file cannot be found in the descriptor database. 

693 """ 

694 try: 

695 file_proto = self._internal_db.FindFileContainingSymbol(symbol) 

696 except KeyError as error: 

697 if self._descriptor_db: 

698 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) 

699 else: 

700 raise error 

701 if not file_proto: 

702 raise KeyError('Cannot find a file containing %s' % symbol) 

703 return self._ConvertFileProtoToFileDescriptor(file_proto) 

704 

def _ConvertFileProtoToFileDescriptor(self, file_proto):
  """Creates a FileDescriptor from a proto or returns a cached copy.

  As a side effect, the symbols found in the file are loaded into the
  appropriate dictionaries of the pool, and the file's extensions are
  registered on every call.

  Args:
    file_proto: The proto to convert.

  Returns:
    A FileDescriptor matching the passed in proto.
  """
  if file_proto.name not in self._file_descriptors:
    dependency_files = list(self._GetDeps(file_proto.dependency))
    direct_deps = [
        self.FindFileByName(dep_name) for dep_name in file_proto.dependency]
    # public_dependency holds indexes into the dependency list.
    public_deps = [
        direct_deps[dep_index] for dep_index in file_proto.public_dependency]

    new_file = descriptor.FileDescriptor(
        pool=self,
        name=file_proto.name,
        package=file_proto.package,
        syntax=file_proto.syntax,
        options=_OptionsOrNone(file_proto),
        serialized_pb=file_proto.SerializeToString(),
        dependencies=direct_deps,
        public_dependencies=public_deps,
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    scope = {}

    # Collect the message and enum types of the dependencies. They form the
    # scope of types available while defining the passed in file proto.
    for dep_file in dependency_files:
      scope.update(self._ExtractSymbols(
          dep_file.message_types_by_name.values()))
      scope.update((_PrefixWithDot(dep_enum.full_name), dep_enum)
                   for dep_enum in dep_file.enum_types_by_name.values())

    for message_proto in file_proto.message_type:
      converted = self._ConvertMessageDescriptor(
          message_proto, file_proto.package, new_file, scope,
          file_proto.syntax)
      new_file.message_types_by_name[converted.name] = converted

    for enum_proto in file_proto.enum_type:
      new_file.enum_types_by_name[enum_proto.name] = (
          self._ConvertEnumDescriptor(enum_proto, file_proto.package,
                                      new_file, None, scope, True))

    for position, extension_proto in enumerate(file_proto.extension):
      file_extension = self._MakeFieldDescriptor(
          extension_proto, file_proto.package, position, new_file,
          is_extension=True)
      file_extension.containing_type = self._GetTypeFromScope(
          new_file.package, extension_proto.extendee, scope)
      self._SetFieldType(extension_proto, file_extension,
                         new_file.package, scope)
      new_file.extensions_by_name[file_extension.name] = file_extension

    for message_proto in file_proto.message_type:
      self._SetAllFieldTypes(file_proto.package, message_proto, scope)

    package_prefix = (
        _PrefixWithDot(file_proto.package) if file_proto.package else '')

    # Re-point each top-level message name at the descriptor held in scope.
    for message_proto in file_proto.message_type:
      new_file.message_types_by_name[message_proto.name] = (
          self._GetTypeFromScope(package_prefix, message_proto.name, scope))

    for position, service_proto in enumerate(file_proto.service):
      new_file.services_by_name[service_proto.name] = (
          self._MakeServiceDescriptor(service_proto, position, scope,
                                      file_proto.package, new_file))

    self._file_descriptors[file_proto.name] = new_file

  # Add extensions to the pool
  def RegisterNestedExtensions(message_desc):
    # Depth-first: nested messages are handled before the message itself.
    for nested_desc in message_desc.nested_types:
      RegisterNestedExtensions(nested_desc)
    for nested_extension in message_desc.extensions:
      self._AddExtensionDescriptor(nested_extension)

  file_desc = self._file_descriptors[file_proto.name]
  for file_extension in file_desc.extensions_by_name.values():
    self._AddExtensionDescriptor(file_extension)
  for top_message in file_desc.message_types_by_name.values():
    RegisterNestedExtensions(top_message)

  return file_desc

802 

def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
                              scope=None, syntax=None):
  """Builds a message Descriptor from its proto and registers it in the pool.

  Nested messages and enums are converted first, recursively. The finished
  descriptor is stored in `scope` and in the pool's message register.

  Args:
    desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
    package: The package the proto should be located in.
    file_desc: The file containing this message.
    scope: Dict mapping short and full symbols to message and enum types.
    syntax: string indicating syntax of the file ("proto2" or "proto3")

  Returns:
    The added descriptor.
  """
  message_full_name = (
      '.'.join((package, desc_proto.name)) if package else desc_proto.name)
  file_name = None if file_desc is None else file_desc.name
  if scope is None:
    scope = {}

  nested_descs = [
      self._ConvertMessageDescriptor(
          nested_proto, message_full_name, file_desc, scope, syntax)
      for nested_proto in desc_proto.nested_type]
  enum_descs = [
      self._ConvertEnumDescriptor(enum_proto, message_full_name, file_desc,
                                  None, scope, False)
      for enum_proto in desc_proto.enum_type]
  field_descs = [
      self._MakeFieldDescriptor(
          field_proto, message_full_name, position, file_desc)
      for position, field_proto in enumerate(desc_proto.field)]
  extension_descs = [
      self._MakeFieldDescriptor(extension_proto, message_full_name, position,
                                file_desc, is_extension=True)
      for position, extension_proto in enumerate(desc_proto.extension)]
  oneof_descs = [
      # pylint: disable=g-complex-comprehension
      descriptor.OneofDescriptor(
          oneof_proto.name,
          '.'.join((message_full_name, oneof_proto.name)),
          position,
          None,
          [],
          _OptionsOrNone(oneof_proto),
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key)
      for position, oneof_proto in enumerate(desc_proto.oneof_decl)
  ]
  extension_ranges = [
      (ext_range.start, ext_range.end)
      for ext_range in desc_proto.extension_range]

  message_desc = descriptor.Descriptor(
      name=desc_proto.name,
      full_name=message_full_name,
      filename=file_name,
      containing_type=None,
      fields=field_descs,
      oneofs=oneof_descs,
      nested_types=nested_descs,
      enum_types=enum_descs,
      extensions=extension_descs,
      options=_OptionsOrNone(desc_proto),
      # A message is extendable exactly when it declares extension ranges.
      is_extendable=bool(extension_ranges),
      extension_ranges=extension_ranges,
      file=file_desc,
      serialized_start=None,
      serialized_end=None,
      syntax=syntax,
      is_map_entry=desc_proto.options.map_entry,
      # pylint: disable=protected-access
      create_key=descriptor._internal_create_key)

  # The children were built before their parent existed; link them now.
  for nested_desc in message_desc.nested_types:
    nested_desc.containing_type = message_desc
  for enum_desc in message_desc.enum_types:
    enum_desc.containing_type = message_desc

  # Attach each field that names a oneof to that oneof, in both directions.
  for position, field_proto in enumerate(desc_proto.field):
    if not field_proto.HasField('oneof_index'):
      continue
    owning_oneof = oneof_descs[field_proto.oneof_index]
    owning_oneof.fields.append(field_descs[position])
    field_descs[position].containing_oneof = owning_oneof

  scope[_PrefixWithDot(message_full_name)] = message_desc
  self._CheckConflictRegister(
      message_desc, message_desc.full_name, message_desc.file.name)
  self._descriptors[message_full_name] = message_desc
  return message_desc

897 

def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
                           containing_type=None, scope=None, top_level=False):
    """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.

    The new descriptor is recorded in ``scope`` under its dot-prefixed full
    name and in ``self._enum_descriptors`` under its full name. For top level
    enums every value is additionally recorded in ``self._top_enum_values``.

    Args:
      enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
      package: Optional package name for the new EnumDescriptor.
      file_desc: The file containing the enum descriptor, or None.
      containing_type: The type containing this enum.
      scope: Optional dict of available types; it is updated in place. When
        omitted, a private dict is used so the call still succeeds.
      top_level: If True, the enum is a top level symbol. If False, the enum
        is defined inside a message.

    Returns:
      The added descriptor.
    """
    # The signature advertises scope=None, so it must not be indexed blindly.
    if scope is None:
        scope = {}

    if package:
        enum_name = '.'.join((package, enum_proto.name))
    else:
        enum_name = enum_proto.name

    if file_desc is None:
        file_name = None
    else:
        file_name = file_desc.name

    values = [self._MakeEnumValueDescriptor(value, index)
              for index, value in enumerate(enum_proto.value)]
    desc = descriptor.EnumDescriptor(name=enum_proto.name,
                                     full_name=enum_name,
                                     filename=file_name,
                                     file=file_desc,
                                     values=values,
                                     containing_type=containing_type,
                                     options=_OptionsOrNone(enum_proto),
                                     # pylint: disable=protected-access
                                     create_key=descriptor._internal_create_key)
    scope['.%s' % enum_name] = desc
    # Use the file name computed above so that a missing file_desc is
    # tolerated here, exactly as it is for the top level values below.
    self._CheckConflictRegister(desc, desc.full_name, file_name)
    self._enum_descriptors[enum_name] = desc

    # Add top level enum values.
    if top_level:
        for value in values:
            # `package or ''` keeps the join valid when package is None; the
            # leading dot it produces is handled by the normalization call.
            full_name = _NormalizeFullyQualifiedName(
                '.'.join((package or '', value.name)))
            self._CheckConflictRegister(value, full_name, file_name)
            self._top_enum_values[full_name] = value

    return desc

949 

def _MakeFieldDescriptor(self, field_proto, message_name, index,
                         file_desc, is_extension=False):
    """Builds a partially initialized FieldDescriptor from a field proto.

    Only data available directly on the proto is filled in: name, full name,
    index, number, type, label, options, JSON name and file. The cpp_type,
    message_type, enum_type, containing_type and extension_scope arguments
    are passed as None, and no default value is set; _SetFieldType and
    _SetAllFieldTypes assign those attributes afterwards.

    Args:
      field_proto: The FieldDescriptorProto describing the field.
      message_name: The name of the containing message; may be empty.
      index: Index of the field.
      file_desc: The file containing the field descriptor.
      is_extension: Indication that this field is for an extension.

    Returns:
      The newly created FieldDescriptor object.
    """
    short_name = field_proto.name
    qualified_name = (
        '.'.join((message_name, short_name)) if message_name else short_name)

    return descriptor.FieldDescriptor(
        name=short_name,
        full_name=qualified_name,
        index=index,
        number=field_proto.number,
        type=field_proto.type,
        cpp_type=None,
        message_type=None,
        enum_type=None,
        containing_type=None,
        label=field_proto.label,
        has_default_value=False,
        default_value=None,
        is_extension=is_extension,
        extension_scope=None,
        options=_OptionsOrNone(field_proto),
        # An empty json_name on the proto is reported as None.
        json_name=field_proto.json_name or None,
        file=file_desc,
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)

1001 

def _SetAllFieldTypes(self, package, desc_proto, scope):
    """Resolves the types of every field of a message, recursively.

    Fields and extensions of the message are handed to _SetFieldType, each
    extension additionally gets its containing_type resolved from its
    extendee, and the same is then done for all nested message types.

    Args:
      package: The current package of desc_proto.
      desc_proto: The message descriptor proto whose descriptor is updated.
      scope: Enclosing scope of available types.
    """
    dotted_package = _PrefixWithDot(package)
    message_desc = self._GetTypeFromScope(
        dotted_package, desc_proto.name, scope)

    # Dot-prefixed name of this message; it acts as the package for
    # everything declared inside it.
    if dotted_package != '.':
        inner_package = '.'.join([dotted_package, desc_proto.name])
    else:
        inner_package = _PrefixWithDot(desc_proto.name)

    field_pairs = zip(desc_proto.field, message_desc.fields)
    for proto, field in field_pairs:
        self._SetFieldType(proto, field, inner_package, scope)

    extension_pairs = zip(desc_proto.extension, message_desc.extensions)
    for proto, extension in extension_pairs:
        extension.containing_type = self._GetTypeFromScope(
            inner_package, proto.extendee, scope)
        self._SetFieldType(proto, extension, inner_package, scope)

    for nested_proto in desc_proto.nested_type:
        self._SetAllFieldTypes(inner_package, nested_proto, scope)

1033 

def _SetFieldType(self, field_proto, field_desc, package, scope):
    """Sets the field's type, cpp_type, message_type, enum_type and default.

    When the proto carries no explicit type, one is inferred from the
    resolved type_name and written back to field_proto.type.

    Args:
      field_proto: Data about the field in proto format.
      field_desc: The descriptor to modify.
      package: The package the field's container is in.
      scope: Enclosing scope of available types.
    """
    fd = descriptor.FieldDescriptor

    resolved = None
    if field_proto.type_name:
        resolved = self._GetTypeFromScope(
            package, field_proto.type_name, scope)

    if not field_proto.HasField('type'):
        # Anything that is not a message descriptor is treated as an enum.
        if isinstance(resolved, descriptor.Descriptor):
            field_proto.type = fd.TYPE_MESSAGE
        else:
            field_proto.type = fd.TYPE_ENUM

    proto_type = field_proto.type
    field_desc.cpp_type = fd.ProtoTypeToCppProtoType(proto_type)

    if proto_type in (fd.TYPE_MESSAGE, fd.TYPE_GROUP):
        field_desc.message_type = resolved

    if proto_type == fd.TYPE_ENUM:
        field_desc.enum_type = resolved

    is_floating_point = proto_type in (fd.TYPE_DOUBLE, fd.TYPE_FLOAT)

    if field_proto.label == fd.LABEL_REPEATED:
        field_desc.has_default_value = False
        field_desc.default_value = []
    elif field_proto.HasField('default_value'):
        # Explicit default: convert its textual form to the field's type.
        field_desc.has_default_value = True
        raw_default = field_proto.default_value
        if is_floating_point:
            field_desc.default_value = float(raw_default)
        elif proto_type == fd.TYPE_STRING:
            field_desc.default_value = raw_default
        elif proto_type == fd.TYPE_BOOL:
            field_desc.default_value = raw_default.lower() == 'true'
        elif proto_type == fd.TYPE_ENUM:
            enum_values = field_desc.enum_type.values_by_name
            field_desc.default_value = enum_values[raw_default].number
        elif proto_type == fd.TYPE_BYTES:
            field_desc.default_value = text_encoding.CUnescape(raw_default)
        elif proto_type == fd.TYPE_MESSAGE:
            field_desc.default_value = None
        else:
            # All other types are of the "int" type. Groups are not
            # special-cased in this branch, so they also end up here.
            field_desc.default_value = int(raw_default)
    else:
        # No explicit default: fall back to the zero value of the type.
        field_desc.has_default_value = False
        if is_floating_point:
            field_desc.default_value = 0.0
        elif proto_type == fd.TYPE_STRING:
            field_desc.default_value = u''
        elif proto_type == fd.TYPE_BOOL:
            field_desc.default_value = False
        elif proto_type == fd.TYPE_ENUM:
            # The first declared value is the enum's default.
            field_desc.default_value = field_desc.enum_type.values[0].number
        elif proto_type == fd.TYPE_BYTES:
            field_desc.default_value = b''
        elif proto_type in (fd.TYPE_MESSAGE, fd.TYPE_GROUP):
            field_desc.default_value = None
        else:
            # All other types are of the "int" type.
            field_desc.default_value = 0

    field_desc.type = proto_type

1109 

def _MakeEnumValueDescriptor(self, value_proto, index):
    """Builds an EnumValueDescriptor from an enum value proto.

    The descriptor's type is passed as None here.

    Args:
      value_proto: The proto describing the enum value.
      index: The index of the enum value.

    Returns:
      The newly created EnumValueDescriptor object.
    """
    value_options = _OptionsOrNone(value_proto)
    # pylint: disable=protected-access
    creation_key = descriptor._internal_create_key
    return descriptor.EnumValueDescriptor(
        name=value_proto.name,
        index=index,
        number=value_proto.number,
        options=value_options,
        type=None,
        create_key=creation_key)

1129 

def _MakeServiceDescriptor(self, service_proto, service_index, scope,
                           package, file_desc):
    """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.

    The new descriptor is checked with _CheckConflictRegister and then
    stored in ``self._service_descriptors`` under its full name.

    Args:
      service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf
        message.
      service_index: The index of the service in the File.
      scope: Dict mapping short and full symbols to message and enum types.
      package: Optional package name for the new ServiceDescriptor.
      file_desc: The file containing the service descriptor.

    Returns:
      The added descriptor.
    """
    short_name = service_proto.name
    service_name = '.'.join((package, short_name)) if package else short_name

    methods = []
    for method_index, method_proto in enumerate(service_proto.method):
        methods.append(
            self._MakeMethodDescriptor(
                method_proto, service_name, package, scope, method_index))

    desc = descriptor.ServiceDescriptor(
        name=short_name,
        full_name=service_name,
        index=service_index,
        methods=methods,
        options=_OptionsOrNone(service_proto),
        file=file_desc,
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
    self._service_descriptors[service_name] = desc
    return desc

1165 

def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
                          index):
    """Builds a MethodDescriptor from a MethodDescriptorProto.

    The input and output types are resolved through _GetTypeFromScope; the
    containing service is passed as None.

    Args:
      method_proto: The proto describing the method.
      service_name: The name of the containing service.
      package: Optional package name to look up for types.
      scope: Scope containing available types.
      index: Index of the method in the service.

    Returns:
      The newly created MethodDescriptor object.
    """

    def resolve(type_name):
        return self._GetTypeFromScope(package, type_name, scope)

    request_type = resolve(method_proto.input_type)
    response_type = resolve(method_proto.output_type)
    method_full_name = '.'.join((service_name, method_proto.name))

    return descriptor.MethodDescriptor(
        name=method_proto.name,
        full_name=method_full_name,
        index=index,
        containing_service=None,
        input_type=request_type,
        output_type=response_type,
        client_streaming=method_proto.client_streaming,
        server_streaming=method_proto.server_streaming,
        options=_OptionsOrNone(method_proto),
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)

1197 

1198 def _ExtractSymbols(self, descriptors): 

1199 """Pulls out all the symbols from descriptor protos. 

1200 

1201 Args: 

1202 descriptors: The messages to extract descriptors from. 

1203 Yields: 

1204 A two element tuple of the type name and descriptor object. 

1205 """ 

1206 

1207 for desc in descriptors: 

1208 yield (_PrefixWithDot(desc.full_name), desc) 

1209 for symbol in self._ExtractSymbols(desc.nested_types): 

1210 yield symbol 

1211 for enum in desc.enum_types: 

1212 yield (_PrefixWithDot(enum.full_name), enum) 

1213 

1214 def _GetDeps(self, dependencies, visited=None): 

1215 """Recursively finds dependencies for file protos. 

1216 

1217 Args: 

1218 dependencies: The names of the files being depended on. 

1219 visited: The names of files already found. 

1220 

1221 Yields: 

1222 Each direct and indirect dependency. 

1223 """ 

1224 

1225 visited = visited or set() 

1226 for dependency in dependencies: 

1227 if dependency not in visited: 

1228 visited.add(dependency) 

1229 dep_desc = self.FindFileByName(dependency) 

1230 yield dep_desc 

1231 public_files = [d.name for d in dep_desc.public_dependencies] 

1232 yield from self._GetDeps(public_files, visited) 

1233 

1234 def _GetTypeFromScope(self, package, type_name, scope): 

1235 """Finds a given type name in the current scope. 

1236 

1237 Args: 

1238 package: The package the proto should be located in. 

1239 type_name: The name of the type to be found in the scope. 

1240 scope: Dict mapping short and full symbols to message and enum types. 

1241 

1242 Returns: 

1243 The descriptor for the requested type. 

1244 """ 

1245 if type_name not in scope: 

1246 components = _PrefixWithDot(package).split('.') 

1247 while components: 

1248 possible_match = '.'.join(components + [type_name]) 

1249 if possible_match in scope: 

1250 type_name = possible_match 

1251 break 

1252 else: 

1253 components.pop(-1) 

1254 return scope[type_name] 

1255 

1256 

1257def _PrefixWithDot(name): 

1258 return name if name.startswith('.') else '.%s' % name 

1259 

1260 

# Module-wide default pool: the pool exposed by descriptor._message when
# _USE_C_DESCRIPTORS is set, otherwise a new DescriptorPool instance.
# TODO: This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
# pylint: disable=protected-access
_DEFAULT = (
    descriptor._message.default_pool if _USE_C_DESCRIPTORS
    else DescriptorPool())

1268 

1269 

def Default():
    """Returns the module-level default pool stored in _DEFAULT."""
    return _DEFAULT