Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_embedding_base.py: 27%

62 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base Class for TPU Embeddings Mid level APIs.""" 

16 

17import functools 

18from typing import Any, Dict, Iterable, Optional, Union, Text 

19 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.ops import variables as tf_variables 

22from tensorflow.python.tpu import tpu_embedding_v2_utils 

23from tensorflow.python.trackable import autotrackable 

24from tensorflow.python.util import nest 

25 

26 

class TPUEmbeddingBase(autotrackable.AutoTrackable):
  """The TPUEmbedding Base class.

  This class only contains the basic logic to check the feature config and
  table config for the tpu embedding mid level APIs.
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer] = None):  # pylint:disable=protected-access
    """Creates the TPUEmbeddingBase object.

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` objects.
      optimizer: A default optimizer applied to every table whose
        `TableConfig` does not set its own optimizer. May be `None` (e.g. for
        inference-only use).

    Raises:
      ValueError: If a table's optimizer is not an instance of one of the
        supported optimizer classes, or if two tables share a name.
    """
    self._feature_config = feature_config
    # Collect output shapes and table configs in a single pass over the
    # flattened feature config. Table order is the order of the first
    # occurrence of each table in a feature provided by the user; this order
    # must be fixed to give the user deterministic behavior over multiple
    # instantiations.
    self._output_shapes = []
    self._table_config = []
    for feature in nest.flatten(feature_config):
      self._output_shapes.append(feature.output_shape)
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = set()
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if (table.optimizer is not None and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError(
            f"{type(table.optimizer)} is an unsupported optimizer class. "
            "Please pass an instance of one of the optimizer classes under "
            "tf.tpu.experimental.embedding.")
      if table.name is None:
        # Assign a deterministic default name based on table order.
        table.name = f"table_{i}"
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.add(table.name)

    # Variables are created lazily in build(); see __call__.
    self._built = False

  @property
  def embedding_tables(self):
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    raise NotImplementedError

  def _create_variables(self, table: tpu_embedding_v2_utils.TableConfig,
                        trainable: bool) -> Dict[Text, tf_variables.Variable]:
    """Create all variables including table variables and slot variables.

    Args:
      table: The table config to create variables for.
      trainable: Whether the table variable is marked trainable. Slot
        variables are always created non-trainable.

    Returns:
      A dict mapping variable names to variables; the table variable itself
      is stored under the key "parameters".
    """
    variable_shape = (table.vocabulary_size, table.dim)

    def getter(name, shape, dtype, initializer, trainable):
      del shape
      # _add_variable_with_custom_getter clears the shape sometimes, so we
      # take the global shape from outside the getter.
      initial_value = functools.partial(
          initializer, variable_shape, dtype=dtype)
      return tf_variables.Variable(
          name=name,
          initial_value=initial_value,
          shape=variable_shape,
          dtype=dtype,
          trainable=trainable)

    def variable_creator(name, initializer, trainable=True):
      # Use add_variable_with_custom_getter here so that we take advantage of
      # the checkpoint loading to allow restore before the variables get
      # created which avoids double initialization.
      return self._add_variable_with_custom_getter(
          name=name,
          initializer=initializer,
          shape=variable_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)

    parameters = variable_creator(
        table.name, table.initializer, trainable=trainable)

    def slot_creator(name, initializer):
      # Slot variables are namespaced under the table's name and never
      # trainable.
      return variable_creator(table.name + "/" + name, initializer, False)

    if table.optimizer is not None:
      slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
    else:
      slot_vars = {}
    slot_vars["parameters"] = parameters
    return slot_vars

  def _create_variables_and_slots(self):
    """Create variables and slots variables for TPU embeddings."""
    raise NotImplementedError

  def build(self):
    """Create variables and slots variables for TPU embeddings."""
    # Idempotent: repeated calls after the first are no-ops.
    if self._built:
      return
    self._variables = self._create_variables_and_slots()
    self._built = True

  def __call__(self, features: Any, weights: Optional[Any] = None) -> Any:
    """Call the mid level api to do embedding lookup."""
    # Build lazily on first use so that variables are only created when the
    # object is actually called.
    if not self._built:
      self.build()
    return self.embedding_lookup(features, weights)

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Lookup the embedding table using the input features."""
    raise NotImplementedError