1# Protocol Buffers - Google's data interchange format 
    2# Copyright 2008 Google Inc.  All rights reserved. 
    3# 
    4# Use of this source code is governed by a BSD-style 
    5# license that can be found in the LICENSE file or at 
    6# https://developers.google.com/open-source/licenses/bsd 
    7 
    8"""A database of Python protocol buffer generated symbols. 
    9 
    10SymbolDatabase is the MessageFactory for messages generated at compile time, 
    11and makes it easy to create new instances of a registered type, given only the 
    12type's protocol buffer symbol name. 
    13 
    14Example usage:: 
    15 
    16  db = symbol_database.SymbolDatabase() 
    17 
    18  # Register symbols of interest, from one or multiple files. 
    19  db.RegisterFileDescriptor(my_proto_pb2.DESCRIPTOR) 
    20  db.RegisterMessage(my_proto_pb2.MyMessage) 
    21  db.RegisterEnumDescriptor(my_proto_pb2.MyEnum.DESCRIPTOR) 
    22 
    23  # The database can be used as a MessageFactory, to generate types based on 
    24  # their name: 
    25  types = db.GetMessages(['my_proto.proto']) 
    26  my_message_instance = types['MyMessage']() 
    27 
    28  # The database's underlying descriptor pool can be queried, so it's not 
    29  # necessary to know a type's filename to be able to generate it: 
  filename = db.pool.FindFileContainingSymbol('MyMessage').name
    31  my_message_instance = db.GetMessages([filename])['MyMessage']() 
    32 
    33  # This functionality is also provided directly via a convenience method: 
    34  my_message_instance = db.GetSymbol('MyMessage')() 
    35""" 
    36 
    37import warnings 
    38 
    39from google.protobuf.internal import api_implementation 
    40from google.protobuf import descriptor_pool 
    41from google.protobuf import message_factory 
    42 
    43 
    44class SymbolDatabase(): 
    45  """A database of Python generated symbols.""" 
    46 
    47  # local cache of registered classes. 
    48  _classes = {} 
    49 
    50  def __init__(self, pool=None): 
    51    """Initializes a new SymbolDatabase.""" 
    52    self.pool = pool or descriptor_pool.DescriptorPool() 
    53 
    54  def RegisterMessage(self, message): 
    55    """Registers the given message type in the local database. 
    56 
    57    Calls to GetSymbol() and GetMessages() will return messages registered here. 
    58 
    59    Args: 
    60      message: A :class:`google.protobuf.message.Message` subclass (or 
    61        instance); its descriptor will be registered. 
    62 
    63    Returns: 
    64      The provided message. 
    65    """ 
    66 
    67    desc = message.DESCRIPTOR 
    68    self._classes[desc] = message 
    69    self.RegisterMessageDescriptor(desc) 
    70    return message 
    71 
    72  def RegisterMessageDescriptor(self, message_descriptor): 
    73    """Registers the given message descriptor in the local database. 
    74 
    75    Args: 
    76      message_descriptor (Descriptor): the message descriptor to add. 
    77    """ 
    78    if api_implementation.Type() == 'python': 
    79      # pylint: disable=protected-access 
    80      self.pool._AddDescriptor(message_descriptor) 
    81 
    82  def RegisterEnumDescriptor(self, enum_descriptor): 
    83    """Registers the given enum descriptor in the local database. 
    84 
    85    Args: 
    86      enum_descriptor (EnumDescriptor): The enum descriptor to register. 
    87 
    88    Returns: 
    89      EnumDescriptor: The provided descriptor. 
    90    """ 
    91    if api_implementation.Type() == 'python': 
    92      # pylint: disable=protected-access 
    93      self.pool._AddEnumDescriptor(enum_descriptor) 
    94    return enum_descriptor 
    95 
    96  def RegisterServiceDescriptor(self, service_descriptor): 
    97    """Registers the given service descriptor in the local database. 
    98 
    99    Args: 
    100      service_descriptor (ServiceDescriptor): the service descriptor to 
    101        register. 
    102    """ 
    103    if api_implementation.Type() == 'python': 
    104      # pylint: disable=protected-access 
    105      self.pool._AddServiceDescriptor(service_descriptor) 
    106 
    107  def RegisterFileDescriptor(self, file_descriptor): 
    108    """Registers the given file descriptor in the local database. 
    109 
    110    Args: 
    111      file_descriptor (FileDescriptor): The file descriptor to register. 
    112    """ 
    113    if api_implementation.Type() == 'python': 
    114      # pylint: disable=protected-access 
    115      self.pool._InternalAddFileDescriptor(file_descriptor) 
    116 
    117  def GetSymbol(self, symbol): 
    118    """Tries to find a symbol in the local database. 
    119 
    120    Currently, this method only returns message.Message instances, however, if 
    121    may be extended in future to support other symbol types. 
    122 
    123    Args: 
    124      symbol (str): a protocol buffer symbol. 
    125 
    126    Returns: 
    127      A Python class corresponding to the symbol. 
    128 
    129    Raises: 
    130      KeyError: if the symbol could not be found. 
    131    """ 
    132 
    133    return self._classes[self.pool.FindMessageTypeByName(symbol)] 
    134 
    135  def GetMessages(self, files): 
    136    # TODO: Fix the differences with MessageFactory. 
    137    """Gets all registered messages from a specified file. 
    138 
    139    Only messages already created and registered will be returned; (this is the 
    140    case for imported _pb2 modules) 
    141    But unlike MessageFactory, this version also returns already defined nested 
    142    messages, but does not register any message extensions. 
    143 
    144    Args: 
    145      files (list[str]): The file names to extract messages from. 
    146 
    147    Returns: 
    148      A dictionary mapping proto names to the message classes. 
    149 
    150    Raises: 
    151      KeyError: if a file could not be found. 
    152    """ 
    153 
    154    def _GetAllMessages(desc): 
    155      """Walk a message Descriptor and recursively yields all message names.""" 
    156      yield desc 
    157      for msg_desc in desc.nested_types: 
    158        for nested_desc in _GetAllMessages(msg_desc): 
    159          yield nested_desc 
    160 
    161    result = {} 
    162    for file_name in files: 
    163      file_desc = self.pool.FindFileByName(file_name) 
    164      for msg_desc in file_desc.message_types_by_name.values(): 
    165        for desc in _GetAllMessages(msg_desc): 
    166          try: 
    167            result[desc.full_name] = self._classes[desc] 
    168          except KeyError: 
    169            # This descriptor has no registered class, skip it. 
    170            pass 
    171    return result 
    172 
    173 
    174_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default()) 
    175 
    176 
    177def Default(): 
    178  """Returns the default SymbolDatabase.""" 
    179  return _DEFAULT