Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_embedding_v1.py: 21%

142 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings without Embedding Accelerator."""

from typing import Any, Dict, Iterable, Optional, Text, Union

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import tpu_embedding_base
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("tpu.experimental.embedding.TPUEmbeddingV0")
class TPUEmbeddingV0(tpu_embedding_base.TPUEmbeddingBase):
39 """The TPUEmbedding mid level API running on TPU without Embedding accelerator. 

40 

41 NOTE: This mid level API is not intended for large embedding table lookup. 

42 Embedding tables will be replicated across devices rather than sharding 

43 across them. To do large embedding table lookup, please use the 

44 `tpu.experimental.embedding.TPUEmbedding` class. This class is an alternative 

45 way to do embedding lookups when the TPU doesn't support any version of 

46 embedding feature. See 

47 `tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed 

48 explanation. 

49 

50 This class has to be created under the `TPUStrategy`, Otherwise a RuntimeError 

51 will be raised. 

52 ```python 

53 strategy = tf.distribute.TPUStrategy(...) 

54 with strategy.scope(): 

55 embedding = tf.tpu.experimental.embedding.TPUEmbeddingV0( 

56 feature_config=feature_config, 

57 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 

58 ``` 

  When creating a distributed dataset that is to be passed to the lookup
  operation, a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False)))
  dataset_iterator = iter(distributed_dataset)
  ```

  Below is an example of a training and evaluation step:

  ```python
  optimizer = tf.keras.optimizers.SGD(0.1)

  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(embedding_features):
      with tf.GradientTape() as tape:
        tape.watch(embedding.embedding_tables.values())
        activations = embedding(embedding_features)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(loss,
                                          embedding.embedding_tables.values())
      optimizer.apply_gradients(list(zip(embedding_gradients,
                                         embedding.embedding_tables.values())))
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      strategy.run(tpu_step, args=(next(dataset_iterator), ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(embedding_features):
      activations = embedding(embedding_features)
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      strategy.run(tpu_step, args=(next(dataset_iterator), ))
  ```

  NOTE: The optimizer used here is a Keras optimizer. In order to keep slot
  variable creation consistent between the Keras optimizer and the embedding
  optimizer, the `slot_variable_creation_fn` argument of the embedding
  optimizer has to create the slots via the Keras optimizer's `add_slot`
  method. Also note that the slot names might be slightly different between
  them.

  ```python
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)

  def slot_variable_creation_fn(table, slot_names, slot_initializers):
    slots = {}
    for slot, initializer in zip(slot_names, slot_initializers):
      slots[slot] = optimizer.add_slot(table, slot, initializer)
    return slots

  embedding_optimizer = tf.tpu.experimental.embedding.Adagrad(
      learning_rate=0.1,
      slot_variable_creation_fn=slot_variable_creation_fn)

  # Use the embedding optimizer to create the mid level API and the Keras
  # optimizer to apply gradients.
  ```
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer]):  # pylint:disable=protected-access
    super(TPUEmbeddingV0, self).__init__(feature_config, optimizer)
    self._strategy = distribute_lib.get_strategy()
    if not isinstance(self._strategy,
                      (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)):
      raise RuntimeError(
          "TPUEmbeddingV0 should be created under TPUStrategy but found {}."
          .format(self._strategy))
    self._built = False

  @property
  def embedding_tables(
      self) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    self._maybe_build()
    # Only return the tables and not the slot variables.
    return {
        table: self._variables[table.name]["parameters"]
        for table in self._table_config
    }

  def _create_variables_and_slots(
      self) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
    """Create variables for TPU embeddings.

    Note that this will always ensure that the variable is created under the
    TPUStrategy.

    Returns:
      A dict of dicts. The outer dict is keyed by the table names and the inner
      dicts are keyed by 'parameters' and the slot variable names.
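
    For example, with a single table named "video" trained with Adagrad, the
    returned structure is roughly of the form below (a sketch only; the exact
    slot names depend on the optimizer):

      {"video": {"parameters": <tf.Variable ...>,
                 "accumulators": <tf.Variable ...>}}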

163 """ 

164 variables = {} 

165 for table in self._table_config: 

166 # created TPUDistributedVariable. 

167 variables[table.name] = self._create_variables(table, trainable=True) 

168 return variables 

  def _maybe_build(self):
    if not self._built:
      # This can be called while tracing a function, so we wrap the
      # initialization code with init_scope so it runs eagerly. This means it
      # will not be included in the function graph generated by tracing, so
      # we can be sure that we only initialize the TPU for embeddings exactly
      # once.
      with ops.init_scope():
        self.build()

  def _apply_combiner_to_embeddings(
      self,
      embeddings: ops.Tensor,
      weight: ops.Tensor,
      combiner: Optional[Text] = None) -> ops.Tensor:
    """Apply the combiner to the embedding lookup result on second to last axis.

    Args:
      embeddings: A Tensor of the embedding lookup result.
      weight: A Tensor of weights with the same shape as the embeddings.
      combiner: One of "mean", "sum", "sqrtn". Defaults to "mean".

    Raises:
      ValueError: If the combiner is not one of 'mean', 'sqrtn' or 'sum'.

    Returns:
      A Tensor.
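
    For example, the "mean" combiner sums over the second-to-last axis and then
    normalizes by the summed weights. A minimal eager sketch of that
    computation (illustrative only, using the public TF API rather than the
    internal ops below):

    ```python
    import tensorflow as tf

    embeddings = tf.ones([2, 3, 4])   # [batch, sequence, dim]
    weight = tf.ones([2, 3, 1])       # weights broadcastable to `embeddings`
    summed = tf.reduce_sum(embeddings, axis=-2)            # shape [2, 4]
    weight_sum = tf.reduce_sum(weight, axis=-2)            # shape [2, 1]
    combined = tf.math.divide_no_nan(summed, weight_sum)   # "mean" combiner
    ```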

196 """ 

197 if combiner is None: 

198 combiner = "mean" 

199 if combiner == "sum": 

200 embeddings = math_ops.reduce_sum(embeddings, axis=-2) 

201 elif combiner == "mean": 

202 embeddings = math_ops.reduce_sum(embeddings, axis=-2) 

203 weight_sum = math_ops.reduce_sum(weight, axis=-2) 

204 embeddings = math_ops.div_no_nan(embeddings, weight_sum) 

205 elif combiner == "sqrtn": 

206 embeddings = math_ops.reduce_sum(embeddings, axis=-2) 

207 weight_squared = math_ops.pow(weight, 2) 

208 weight_sum = math_ops.reduce_sum(weight_squared, axis=-2) 

209 weight_sum_sqrt = math_ops.sqrt(weight_sum) 

210 embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt) 

211 else: 

212 raise ValueError( 

213 f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}") 

214 return embeddings 

  def _pad_or_truncate_with_sequence_length(self, embeddings: ops.Tensor,
                                            sequence_length: int) -> ops.Tensor:
    """Pad or truncate the embedding lookup result based on the sequence length.

    Args:
      embeddings: A rank 3 Tensor of the embedding lookup result.
      sequence_length: The max sequence length set in the feature config.

    Returns:
      A Tensor with the second to last axis padded or truncated.
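
    For example, a [2, 3, 4] input padded to `sequence_length=5` becomes
    [2, 5, 4], while truncating to `sequence_length=2` yields [2, 2, 4]. A
    minimal sketch of the padding case with the public TF API (illustrative
    only):

    ```python
    import tensorflow as tf

    embeddings = tf.ones([2, 3, 4])
    padded = tf.pad(embeddings, paddings=[[0, 0], [0, 5 - 3], [0, 0]])
    print(padded.shape)  # (2, 5, 4)
    ```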

227 """ 

228 original_sequence_length = embeddings.shape[1] 

229 if original_sequence_length > sequence_length: 

230 embeddings = array_ops.slice( 

231 embeddings, begin=[0, 0, 0], size=[-1, sequence_length, -1]) 

232 else: 

233 embeddings = array_ops.pad( 

234 embeddings, 

235 paddings=[[0, 0], [0, sequence_length - original_sequence_length], 

236 [0, 0]]) 

237 return embeddings 

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Apply embedding lookup on TPUs using TensorCore.

    Note that all the sparse and ragged tensors will be converted to dense
    tensors on CPU and then passed to the TPU to do the embedding lookup. Large
    embedding lookups are not supported by this API; use the TPUEmbedding mid
    level API instead.

    Args:
      features: a nested structure of Tensors, SparseTensors or RaggedTensors.
      weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
        None for no weights. If not None, structure must match that of inputs,
        but entries are allowed to be None.

    Returns:
      A nested structure of Tensors with the same structure as inputs.
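
    For example (a sketch, assuming `embedding` was created under a
    `TPUStrategy` as shown in the class docstring, and that this function is
    called inside `strategy.run`):

    ```python
    def tpu_step(embedding_features):
      # `embedding_features` is a nested structure matching `feature_config`,
      # e.g. a dict of dense, sparse or ragged tensors.
      activations = embedding.embedding_lookup(embedding_features)
      # `activations` has the same nested structure, with dense Tensors.
      return activations
    ```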

257 """ 

258 if not self._built: 

259 self.build() 

260 nest.assert_same_structure(features, self._feature_config) 

261 

262 flat_inputs = nest.flatten(features) 

263 flat_weights = [None] * len(flat_inputs) 

264 if weights is not None: 

265 nest.assert_same_structure(features, weights) 

266 flat_weights = nest.flatten(weights) 

267 flat_features = nest.flatten_with_joined_string_paths(self._feature_config) 

268 

269 outputs = [] 

270 for inp, weight, (path, feature) in zip(flat_inputs, flat_weights, 

271 flat_features): 

272 table = self.embedding_tables[feature.table] 

273 

274 if weight is not None: 

275 if isinstance(inp, ops.Tensor): 

276 raise ValueError( 

277 "Weight specified for {}, but input is dense.".format(path)) 

278 elif type(weight) is not type(inp): 

279 raise ValueError( 

280 "Weight for {} is of type {} but it does not match type of the " 

281 "input which is {}.".format(path, type(weight), type(inp))) 

282 elif feature.max_sequence_length > 0: 

283 raise ValueError("Weight specified for {}, but this is a sequence " 

284 "feature.".format(path)) 

285 

286 if isinstance(inp, ops.Tensor): 

287 if feature.max_sequence_length > 0: 

288 raise ValueError( 

289 "Feature {} is a sequence feature but a dense tensor " 

290 "was passed.".format(path)) 

291 outputs.append(embedding_ops.embedding_lookup_v2(table, inp)) 

292 

293 elif isinstance(inp, sparse_tensor.SparseTensor): 

294 outputs.append( 

295 self._embedding_lookup_for_sparse_tensor(inp, weight, table, 

296 feature)) 

297 elif isinstance(inp, ragged_tensor.RaggedTensor): 

298 outputs.append( 

299 self._embedding_lookup_for_ragged_tensor(inp, weight, table, 

300 feature)) 

301 else: 

302 raise ValueError("Input {} is type {}. Tensor, SparseTensor or " 

303 "RaggedTensor expected.".format(path, type(inp))) 

304 return nest.pack_sequence_as(self._feature_config, outputs) 

  def _embedding_lookup_for_sparse_tensor(
      self, inp: sparse_tensor.SparseTensor,
      weight: Optional[sparse_tensor.SparseTensor],
      table: tf_variables.Variable,
      feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
    """Embedding lookup for sparse tensor based on its feature config.

    Args:
      inp: a single SparseTensor input.
      weight: None or a SparseTensor with the same shape as the input.
      table: a table variable.
      feature: a feature config.

    Returns:
      Embedding lookup result.
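
    For example, the sparse input is densified (padded with zeros) before the
    lookup, and missing positions get a zero weight. A minimal sketch of that
    conversion with the public TF API (illustrative only):

    ```python
    import tensorflow as tf

    inp = tf.sparse.SparseTensor(
        indices=[[0, 0], [1, 0], [1, 1]], values=[3, 1, 4], dense_shape=[2, 2])
    dense_ids = tf.sparse.to_dense(inp)          # [[3, 0], [1, 4]]
    weight = tf.sparse.SparseTensor(
        indices=inp.indices,
        values=tf.ones_like(inp.values, dtype=tf.float32),
        dense_shape=inp.dense_shape)
    dense_weight = tf.sparse.to_dense(weight)    # [[1., 0.], [1., 1.]]
    ```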

321 """ 

322 

323 # This computation needs to placed outside of tpu as the size of the 

324 # indices and values can change for different batch which can cause 

325 # the program to re-compile. 

326 def sparse_to_dense_computation(inp, weight): 

327 if weight is None: 

328 weight = sparse_tensor.SparseTensor( 

329 inp.indices, 

330 array_ops.ones_like(inp.values, dtype=dtypes.float32), 

331 dense_shape=inp.dense_shape) 

332 # Pad the sparse tensor to be dense tensor. 

333 inp = sparse_ops.sparse_tensor_to_dense(inp) 

334 weight = sparse_ops.sparse_tensor_to_dense(weight) 

335 return inp, weight 

336 

337 inp, weight = tpu_replication.outside_compilation( 

338 sparse_to_dense_computation, inp=inp, weight=weight) 

339 

340 embeddings = embedding_ops.embedding_lookup_v2(table, inp) 

341 weight = array_ops.expand_dims(weight, -1) 

342 embeddings *= weight 

343 if not feature.output_shape and feature.max_sequence_length > 0: 

344 embeddings = self._pad_or_truncate_with_sequence_length( 

345 embeddings, feature.max_sequence_length) 

346 else: 

347 embeddings = self._apply_combiner_to_embeddings(embeddings, weight, 

348 feature.table.combiner) 

349 return embeddings 

  def _embedding_lookup_for_ragged_tensor(
      self, inp: ragged_tensor.RaggedTensor,
      weight: Optional[ragged_tensor.RaggedTensor],
      table: tf_variables.Variable,
      feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
    """Embedding lookup for ragged tensor based on its feature config.

    Args:
      inp: a single rank 2 RaggedTensor input.
      weight: None or a RaggedTensor with the same shape as the input.
      table: a table variable.
      feature: a feature config.

    Returns:
      Embedding lookup result.

    Raises:
      ValueError: if the input ragged tensor is not rank 2, or the output shape
        set in the feature config doesn't match the first dim size of the
        input.
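
    For example, when `feature.output_shape` implies a larger batch than the
    input (and the input batch size divides it evenly), the quotient is used as
    a sequence length and the ragged rows are padded or truncated to it, with
    weights and combiner ignored. A minimal sketch of that reshaping with the
    public TF API (illustrative only):

    ```python
    import tensorflow as tf

    inp = tf.ragged.constant([[3, 1, 4], [1], [5, 9], [2]])  # batch_size 4
    # output_shape == [4, 2] -> output batch size 8, sequence length 8 // 4 == 2.
    dense = inp.to_tensor(shape=(4, 2))  # [[3, 1], [1, 0], [5, 9], [2, 0]]
    ```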

370 """ 

371 if inp.shape.rank != 2: 

372 raise ValueError( 

373 "Only rank 2 ragged tensor is supported, but got rank {}".format( 

374 inp.shape.rank)) 

375 batch_size = inp.shape[0] 

376 

377 # This computation needs to placed outside of tpu as the size of the row 

378 # splits and values can change for different batch which can cause 

379 # the program to re-compile. 

380 def ragged_to_dense_outside_compilation(inp, weight, batch_size, feature): 

381 if weight is None: 

382 weight = ragged_tensor.RaggedTensor.from_row_splits( 

383 array_ops.ones_like(inp.values, dtype=dtypes.float32), 

384 inp.row_splits) 

385 if not feature.output_shape and feature.max_sequence_length > 0: 

386 inp = inp.to_tensor(shape=(batch_size, feature.max_sequence_length)) 

387 # Ignore weight if it is a sequence feature. 

388 weight = array_ops.ones_like(inp, dtype=dtypes.float32) 

389 elif feature.output_shape: 

390 # Eagerly run the following op as the result as to be a number in 

391 # order to use it as part of the output shape. 

392 with ops.init_scope(): 

393 output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy() 

394 # If the output batch size matches the data batch size, treat it as 

395 # normal ragged input. 

396 if output_batch_size == batch_size: 

397 inp, weight = inp.to_tensor(), weight.to_tensor() 

398 # If the data batch size is a factor of the output batch size, the 

399 # divide result will be the sequence length. Ignore the weights and 

400 # combiner. 

401 elif output_batch_size > batch_size and output_batch_size % batch_size == 0: 

402 # Pad or truncate in the sequence dimension 

403 seq_length = output_batch_size // batch_size 

404 inp = inp.to_tensor(shape=(batch_size, seq_length)) 

405 # Ignore weight if it is a sequence feature. 

406 weight = array_ops.ones_like(inp, dtype=dtypes.float32) 

407 else: 

408 raise ValueError( 

409 "Output shape set in the FeatureConfig should be the factor of " 

410 "the input data batch size. But instead got output shape {}, " 

411 "input data batch size {}".format(feature.output_shape, 

412 batch_size)) 

413 else: 

414 inp, weight = inp.to_tensor(), weight.to_tensor() 

415 return inp, weight 

    inp, weight = tpu_replication.outside_compilation(
        ragged_to_dense_outside_compilation,
        inp=inp,
        weight=weight,
        batch_size=batch_size,
        feature=feature)

    embeddings = embedding_ops.embedding_lookup_v2(table, inp)
    weight = array_ops.expand_dims(weight, -1)
    embeddings *= weight

    if feature.output_shape:
      with ops.init_scope():
        output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy()
      if output_batch_size == batch_size:
        embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
                                                        feature.table.combiner)
      embeddings = array_ops.reshape(
          embeddings, shape=feature.output_shape + [feature.table.dim])
    else:
      if feature.max_sequence_length == 0:
        embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
                                                        feature.table.combiner)
    return embeddings