Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_embedding_base.py: 27%
62 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base Class for TPU Embeddings Mid level APIs."""
17import functools
18from typing import Any, Dict, Iterable, Optional, Union, Text
20from tensorflow.python.framework import dtypes
21from tensorflow.python.ops import variables as tf_variables
22from tensorflow.python.tpu import tpu_embedding_v2_utils
23from tensorflow.python.trackable import autotrackable
24from tensorflow.python.util import nest
class TPUEmbeddingBase(autotrackable.AutoTrackable):
  """The TPUEmbedding Base class.

  This class only contains the basic logic to check the feature config and table
  config for the tpu embedding mid level APIs.
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer] = None):  # pylint:disable=protected-access
    """Creates the TPUEmbeddingBase object.

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` configs.
      optimizer: An optimizer instance from
        `tf.tpu.experimental.embedding`, used as the default for any table
        whose config does not specify its own optimizer. May be None.

    Raises:
      ValueError: If a table's (possibly defaulted) optimizer is not an
        instance of the supported optimizer classes, or if two tables end up
        with the same name.
    """
    self._feature_config = feature_config
    # Collect output shapes and table order in a single pass over the
    # flattened feature config (the original code flattened twice). Table
    # order is set to the order of the first occurrence of the table in a
    # feature provided by the user. The order of this struct must be fixed
    # to provide the user with deterministic behavior over multiple
    # instantiations.
    self._output_shapes = []
    self._table_config = []
    for feature in nest.flatten(feature_config):
      self._output_shapes.append(feature.output_shape)
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = set()  # set gives O(1) duplicate-name detection.
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if (table.optimizer is not None and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.add(table.name)

    self._built = False

  @property
  def embedding_tables(self):
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    raise NotImplementedError

  def _create_variables(self, table: tpu_embedding_v2_utils.TableConfig,
                        trainable: bool) -> Dict[Text, tf_variables.Variable]:
    """Create all variables including table variables and slot variables.

    Args:
      table: The table config to create variables for.
      trainable: Whether the main table variable is trainable; slot variables
        are always created non-trainable.

    Returns:
      A dict mapping variable name to `tf.Variable`; always contains the key
      "parameters" for the table itself, plus any optimizer slot variables.
    """
    variable_shape = (table.vocabulary_size, table.dim)

    def getter(name, shape, dtype, initializer, trainable):
      del shape
      # _add_variable_with_custom_getter clears the shape sometimes, so we
      # take the global shape from outside the getter.
      initial_value = functools.partial(
          initializer, variable_shape, dtype=dtype)
      return tf_variables.Variable(
          name=name,
          initial_value=initial_value,
          shape=variable_shape,
          dtype=dtype,
          trainable=trainable)

    def variable_creator(name, initializer, trainable=True):
      # Use add_variable_with_custom_getter here so that we take advantage of
      # the checkpoint loading to allow restore before the variables get
      # created which avoids double initialization.
      return self._add_variable_with_custom_getter(
          name=name,
          initializer=initializer,
          shape=variable_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)

    parameters = variable_creator(
        table.name, table.initializer, trainable=trainable)

    def slot_creator(name, initializer):
      # Slot variables are namespaced under the table and never trainable.
      return variable_creator(table.name + "/" + name, initializer, False)

    if table.optimizer is not None:
      slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
    else:
      slot_vars = {}
    slot_vars["parameters"] = parameters
    return slot_vars

  def _create_variables_and_slots(self):
    """Create variables and slots variables for TPU embeddings."""
    raise NotImplementedError

  def build(self):
    """Create variables and slots variables for TPU embeddings."""
    # Idempotent: repeated calls after the first are no-ops.
    if self._built:
      return
    self._variables = self._create_variables_and_slots()
    self._built = True

  def __call__(self, features: Any, weights: Optional[Any] = None) -> Any:
    """Call the mid level api to do embedding lookup."""
    # Lazily build on first use so callers need not call build() explicitly.
    if not self._built:
      self.build()
    return self.embedding_lookup(features, weights)

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Lookup the embedding table using the input features."""
    raise NotImplementedError