1# Protocol Buffers - Google's data interchange format 
    2# Copyright 2008 Google Inc.  All rights reserved. 
    3# 
    4# Use of this source code is governed by a BSD-style 
    5# license that can be found in the LICENSE file or at 
    6# https://developers.google.com/open-source/licenses/bsd 
    7 
    8"""Contains FieldMask class.""" 
    9 
    10from google.protobuf.descriptor import FieldDescriptor 
    11 
    12 
    13class FieldMask(object): 
    14  """Class for FieldMask message type.""" 
    15 
    16  __slots__ = () 
    17 
    18  def ToJsonString(self): 
    19    """Converts FieldMask to string according to proto3 JSON spec.""" 
    20    camelcase_paths = [] 
    21    for path in self.paths: 
    22      camelcase_paths.append(_SnakeCaseToCamelCase(path)) 
    23    return ','.join(camelcase_paths) 
    24 
    25  def FromJsonString(self, value): 
    26    """Converts string to FieldMask according to proto3 JSON spec.""" 
    27    if not isinstance(value, str): 
    28      raise ValueError('FieldMask JSON value not a string: {!r}'.format(value)) 
    29    self.Clear() 
    30    if value: 
    31      for path in value.split(','): 
    32        self.paths.append(_CamelCaseToSnakeCase(path)) 
    33 
    34  def IsValidForDescriptor(self, message_descriptor): 
    35    """Checks whether the FieldMask is valid for Message Descriptor.""" 
    36    for path in self.paths: 
    37      if not _IsValidPath(message_descriptor, path): 
    38        return False 
    39    return True 
    40 
    41  def AllFieldsFromDescriptor(self, message_descriptor): 
    42    """Gets all direct fields of Message Descriptor to FieldMask.""" 
    43    self.Clear() 
    44    for field in message_descriptor.fields: 
    45      self.paths.append(field.name) 
    46 
    47  def CanonicalFormFromMask(self, mask): 
    48    """Converts a FieldMask to the canonical form. 
    49 
    50    Removes paths that are covered by another path. For example, 
    51    "foo.bar" is covered by "foo" and will be removed if "foo" 
    52    is also in the FieldMask. Then sorts all paths in alphabetical order. 
    53 
    54    Args: 
    55      mask: The original FieldMask to be converted. 
    56    """ 
    57    tree = _FieldMaskTree(mask) 
    58    tree.ToFieldMask(self) 
    59 
    60  def Union(self, mask1, mask2): 
    61    """Merges mask1 and mask2 into this FieldMask.""" 
    62    _CheckFieldMaskMessage(mask1) 
    63    _CheckFieldMaskMessage(mask2) 
    64    tree = _FieldMaskTree(mask1) 
    65    tree.MergeFromFieldMask(mask2) 
    66    tree.ToFieldMask(self) 
    67 
    68  def Intersect(self, mask1, mask2): 
    69    """Intersects mask1 and mask2 into this FieldMask.""" 
    70    _CheckFieldMaskMessage(mask1) 
    71    _CheckFieldMaskMessage(mask2) 
    72    tree = _FieldMaskTree(mask1) 
    73    intersection = _FieldMaskTree() 
    74    for path in mask2.paths: 
    75      tree.IntersectPath(path, intersection) 
    76    intersection.ToFieldMask(self) 
    77 
    78  def MergeMessage( 
    79      self, source, destination, 
    80      replace_message_field=False, replace_repeated_field=False): 
    81    """Merges fields specified in FieldMask from source to destination. 
    82 
    83    Args: 
    84      source: Source message. 
    85      destination: The destination message to be merged into. 
    86      replace_message_field: Replace message field if True. Merge message 
    87          field if False. 
    88      replace_repeated_field: Replace repeated field if True. Append 
    89          elements of repeated field if False. 
    90    """ 
    91    tree = _FieldMaskTree(self) 
    92    tree.MergeMessage( 
    93        source, destination, replace_message_field, replace_repeated_field) 
    94 
    95 
    96def _IsValidPath(message_descriptor, path): 
    97  """Checks whether the path is valid for Message Descriptor.""" 
    98  parts = path.split('.') 
    99  last = parts.pop() 
    100  for name in parts: 
    101    field = message_descriptor.fields_by_name.get(name) 
    102    if (field is None or 
    103        field.is_repeated or 
    104        field.type != FieldDescriptor.TYPE_MESSAGE): 
    105      return False 
    106    message_descriptor = field.message_type 
    107  return last in message_descriptor.fields_by_name 
    108 
    109 
    110def _CheckFieldMaskMessage(message): 
    111  """Raises ValueError if message is not a FieldMask.""" 
    112  message_descriptor = message.DESCRIPTOR 
    113  if (message_descriptor.name != 'FieldMask' or 
    114      message_descriptor.file.name != 'google/protobuf/field_mask.proto'): 
    115    raise ValueError('Message {0} is not a FieldMask.'.format( 
    116        message_descriptor.full_name)) 
    117 
    118 
    119def _SnakeCaseToCamelCase(path_name): 
    120  """Converts a path name from snake_case to camelCase.""" 
    121  result = [] 
    122  after_underscore = False 
    123  for c in path_name: 
    124    if c.isupper(): 
    125      raise ValueError( 
    126          'Fail to print FieldMask to Json string: Path name ' 
    127          '{0} must not contain uppercase letters.'.format(path_name)) 
    128    if after_underscore: 
    129      if c.islower(): 
    130        result.append(c.upper()) 
    131        after_underscore = False 
    132      else: 
    133        raise ValueError( 
    134            'Fail to print FieldMask to Json string: The ' 
    135            'character after a "_" must be a lowercase letter ' 
    136            'in path name {0}.'.format(path_name)) 
    137    elif c == '_': 
    138      after_underscore = True 
    139    else: 
    140      result += c 
    141 
    142  if after_underscore: 
    143    raise ValueError('Fail to print FieldMask to Json string: Trailing "_" ' 
    144                     'in path name {0}.'.format(path_name)) 
    145  return ''.join(result) 
    146 
    147 
    148def _CamelCaseToSnakeCase(path_name): 
    149  """Converts a field name from camelCase to snake_case.""" 
    150  result = [] 
    151  for c in path_name: 
    152    if c == '_': 
    153      raise ValueError('Fail to parse FieldMask: Path name ' 
    154                       '{0} must not contain "_"s.'.format(path_name)) 
    155    if c.isupper(): 
    156      result += '_' 
    157      result += c.lower() 
    158    else: 
    159      result += c 
    160  return ''.join(result) 
    161 
    162 
    163class _FieldMaskTree(object): 
    164  """Represents a FieldMask in a tree structure. 
    165 
    166  For example, given a FieldMask "foo.bar,foo.baz,bar.baz", 
    167  the FieldMaskTree will be: 
    168      [_root] -+- foo -+- bar 
    169            |       | 
    170            |       +- baz 
    171            | 
    172            +- bar --- baz 
    173  In the tree, each leaf node represents a field path. 
    174  """ 
    175 
    176  __slots__ = ('_root',) 
    177 
    178  def __init__(self, field_mask=None): 
    179    """Initializes the tree by FieldMask.""" 
    180    self._root = {} 
    181    if field_mask: 
    182      self.MergeFromFieldMask(field_mask) 
    183 
    184  def MergeFromFieldMask(self, field_mask): 
    185    """Merges a FieldMask to the tree.""" 
    186    for path in field_mask.paths: 
    187      self.AddPath(path) 
    188 
    189  def AddPath(self, path): 
    190    """Adds a field path into the tree. 
    191 
    192    If the field path to add is a sub-path of an existing field path 
    193    in the tree (i.e., a leaf node), it means the tree already matches 
    194    the given path so nothing will be added to the tree. If the path 
    195    matches an existing non-leaf node in the tree, that non-leaf node 
    196    will be turned into a leaf node with all its children removed because 
    197    the path matches all the node's children. Otherwise, a new path will 
    198    be added. 
    199 
    200    Args: 
    201      path: The field path to add. 
    202    """ 
    203    node = self._root 
    204    for name in path.split('.'): 
    205      if name not in node: 
    206        node[name] = {} 
    207      elif not node[name]: 
    208        # Pre-existing empty node implies we already have this entire tree. 
    209        return 
    210      node = node[name] 
    211    # Remove any sub-trees we might have had. 
    212    node.clear() 
    213 
    214  def ToFieldMask(self, field_mask): 
    215    """Converts the tree to a FieldMask.""" 
    216    field_mask.Clear() 
    217    _AddFieldPaths(self._root, '', field_mask) 
    218 
    219  def IntersectPath(self, path, intersection): 
    220    """Calculates the intersection part of a field path with this tree. 
    221 
    222    Args: 
    223      path: The field path to calculates. 
    224      intersection: The out tree to record the intersection part. 
    225    """ 
    226    node = self._root 
    227    for name in path.split('.'): 
    228      if name not in node: 
    229        return 
    230      elif not node[name]: 
    231        intersection.AddPath(path) 
    232        return 
    233      node = node[name] 
    234    intersection.AddLeafNodes(path, node) 
    235 
    236  def AddLeafNodes(self, prefix, node): 
    237    """Adds leaf nodes begin with prefix to this tree.""" 
    238    if not node: 
    239      self.AddPath(prefix) 
    240    for name in node: 
    241      child_path = prefix + '.' + name 
    242      self.AddLeafNodes(child_path, node[name]) 
    243 
    244  def MergeMessage( 
    245      self, source, destination, 
    246      replace_message, replace_repeated): 
    247    """Merge all fields specified by this tree from source to destination.""" 
    248    _MergeMessage( 
    249        self._root, source, destination, replace_message, replace_repeated) 
    250 
    251 
    252def _StrConvert(value): 
    253  """Converts value to str if it is not.""" 
    254  # This file is imported by c extension and some methods like ClearField 
    255  # requires string for the field name. py2/py3 has different text 
    256  # type and may use unicode. 
    257  if not isinstance(value, str): 
    258    return value.encode('utf-8') 
    259  return value 
    260 
    261 
    262def _MergeMessage( 
    263    node, source, destination, replace_message, replace_repeated): 
    264  """Merge all fields specified by a sub-tree from source to destination.""" 
    265  source_descriptor = source.DESCRIPTOR 
    266  for name in node: 
    267    child = node[name] 
    268    field = source_descriptor.fields_by_name[name] 
    269    if field is None: 
    270      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format( 
    271          name, source_descriptor.full_name)) 
    272    if child: 
    273      # Sub-paths are only allowed for singular message fields. 
    274      if (field.is_repeated or 
    275          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE): 
    276        raise ValueError('Error: Field {0} in message {1} is not a singular ' 
    277                         'message field and cannot have sub-fields.'.format( 
    278                             name, source_descriptor.full_name)) 
    279      if source.HasField(name): 
    280        _MergeMessage( 
    281            child, getattr(source, name), getattr(destination, name), 
    282            replace_message, replace_repeated) 
    283      continue 
    284    if field.is_repeated: 
    285      if replace_repeated: 
    286        destination.ClearField(_StrConvert(name)) 
    287      repeated_source = getattr(source, name) 
    288      repeated_destination = getattr(destination, name) 
    289      repeated_destination.MergeFrom(repeated_source) 
    290    else: 
    291      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: 
    292        if replace_message: 
    293          destination.ClearField(_StrConvert(name)) 
    294        if source.HasField(name): 
    295          getattr(destination, name).MergeFrom(getattr(source, name)) 
    296      elif not field.has_presence or source.HasField(name): 
    297        setattr(destination, name, getattr(source, name)) 
    298      else: 
    299        destination.ClearField(_StrConvert(name)) 
    300 
    301 
    302def _AddFieldPaths(node, prefix, field_mask): 
    303  """Adds the field paths descended from node to field_mask.""" 
    304  if not node and prefix: 
    305    field_mask.paths.append(prefix) 
    306    return 
    307  for name in sorted(node): 
    308    if prefix: 
    309      child_path = prefix + '.' + name 
    310    else: 
    311      child_path = name 
    312    _AddFieldPaths(node[name], child_path, field_mask)