# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
import copy
import enum
import math
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.tpu.feature_column import _is_running_on_cpu
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
from tensorflow.python.tpu.feature_column import _SUPPORTED_CATEGORICAL_COLUMNS_V2
from tensorflow.python.tpu.feature_column import _SUPPORTED_SEQUENCE_COLUMNS
from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access

_ALLOWED_DEVICES = ['cpu', 'tpu_tensor_core', 'tpu_embedding_core']
_TENSOR_CORE_MASK_KEY_SUFFIX = '__TENSOR_CORE_MASK'


class EmbeddingDevice(enum.Enum):
  CPU = 1
  TPU_TENSOR_CORE = 2
  TPU_EMBEDDING_CORE = 3


@tf_export(v1=['tpu.experimental.embedding_column'])
def embedding_column_v2(categorical_column,
                        dimension,
                        combiner='mean',
                        initializer=None,
                        max_sequence_length=0,
                        learning_rate_fn=None,
                        embedding_lookup_device=None,
                        tensor_core_shape=None,
                        use_safe_embedding_lookup=True):
  """TPU version of `tf.compat.v1.feature_column.embedding_column`.

  Note that the interface for `tf.tpu.experimental.embedding_column` is
  different from that of `tf.compat.v1.feature_column.embedding_column`: The
  following arguments are NOT supported: `ckpt_to_load_from`,
  `tensor_name_in_ckpt`, `max_norm` and `trainable`.

  Use this function in place of `tf.compat.v1.feature_column.embedding_column`
  when you want to use the TPU to accelerate your embedding lookups via TPU
  embeddings.

  ```
  column = tf.feature_column.categorical_column_with_identity(...)
  tpu_column = tf.tpu.experimental.embedding_column(column, 10)
  ...
  def model_fn(features):
    dense_feature = tf.keras.layers.DenseFeatures(tpu_column)
    embedded_feature = dense_feature(features)
    ...

  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          column=[tpu_column],
          ...))
  ```

  Args:
    categorical_column: A categorical column returned from
      `categorical_column_with_identity`, `weighted_categorical_column`,
      `categorical_column_with_vocabulary_file`,
      `categorical_column_with_vocabulary_list`,
      `sequence_categorical_column_with_identity`,
      `sequence_categorical_column_with_vocabulary_file`,
      `sequence_categorical_column_with_vocabulary_list`
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    max_sequence_length: A non-negative integer specifying the max sequence
      length. Any sequence shorter than this will be padded with 0 embeddings
      and any sequence longer will be truncated. This must be positive for
      sequence features and 0 for non-sequence features.
    learning_rate_fn: A function that takes global step and returns learning
      rate for the embedding table. If you intend to use the same learning
      rate for multiple embedding tables, please ensure that you pass the
      exact same python function to all calls of embedding_column, otherwise
      performance may suffer.
    embedding_lookup_device: The device on which to run the embedding lookup.
      Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core".
      If specifying "tpu_tensor_core", a tensor_core_shape must be supplied.
      If not specified, the default behavior is embedding lookup on
      "tpu_embedding_core" for training and "cpu" for inference.
      Valid options for training : ["tpu_embedding_core", "tpu_tensor_core"]
      Valid options for serving : ["cpu", "tpu_tensor_core"]
      For training, tpu_embedding_core is good for large embedding vocab (>1M),
      otherwise, tpu_tensor_core is often sufficient.
      For serving, doing embedding lookup on tpu_tensor_core during serving is
      a way to reduce host cpu usage in cases where that is a bottleneck.
    tensor_core_shape: If supplied, a list of integers which specifies the
      intended dense shape to run embedding lookup for this feature on
      TensorCore. The batch dimension can be left None or -1 to indicate a
      dynamic shape. Only rank 2 shapes are currently supported.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error though the output result might be 0 or omitted.

  Returns:
    A `_TPUEmbeddingColumnV2`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
  """

  if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2):
    raise TypeError(
        'categorical_column for tpu '
        'embedding_column must be type {}, got {}.'.format(' or '.join([
            cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS_V2
        ]), type(categorical_column)))
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if tensor_core_shape and len(tensor_core_shape) != 2:
    raise ValueError(
        'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  if (embedding_lookup_device and
      embedding_lookup_device not in _ALLOWED_DEVICES):
    raise ValueError(
        f'If set, embedding_lookup_device must be in {_ALLOWED_DEVICES}')

  if embedding_lookup_device == 'cpu':
    embedding_lookup_device = EmbeddingDevice.CPU
  elif embedding_lookup_device == 'tpu_tensor_core':
    embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE
  elif embedding_lookup_device == 'tpu_embedding_core':
    embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE

  if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
    if not tensor_core_shape:
      raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                       'tensor_core_shape to be set.')
    if isinstance(categorical_column, _SUPPORTED_SEQUENCE_COLUMNS):
      raise ValueError('embedding_lookup_device=tpu_tensor_core currently does '
                       'not support sequence columns.')

  if not embedding_lookup_device:
    return _TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
  else:
    return _TPUDeviceSpecificEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        embedding_lookup_device=embedding_lookup_device,
        tensor_core_shape=tensor_core_shape,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
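
# Illustrative sketch: the same factory covers both the default
# embedding-core path and TensorCore serving. The key name and sizes below
# are hypothetical, and the helper itself is a sketch, not an exported symbol.
def _example_tensor_core_embedding_column():
  """Builds a column whose lookups are served on TensorCore."""
  categorical = fc_lib.categorical_column_with_identity(
      key='item_id', num_buckets=1000)
  # A rank-2 tensor_core_shape is required here; the batch dimension may be
  # left dynamic with None.
  return embedding_column_v2(
      categorical,
      dimension=16,
      embedding_lookup_device='tpu_tensor_core',
      tensor_core_shape=[None, 8])
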


@tf_export(v1=['tpu.experimental.shared_embedding_columns'])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                max_sequence_lengths=None,
                                learning_rate_fn=None,
                                embedding_lookup_device=None,
                                tensor_core_shape=None,
                                use_safe_embedding_lookup=True):

216 """TPU version of `tf.compat.v1.feature_column.shared_embedding_columns`. 

217 

218 Note that the interface for `tf.tpu.experimental.shared_embedding_columns` is 

219 different from that of `tf.compat.v1.feature_column.shared_embedding_columns`: 

220 The following arguments are NOT supported: `ckpt_to_load_from`, 

221 `tensor_name_in_ckpt`, `max_norm` and `trainable`. 

222 

223 Use this function in place of 

224 tf.compat.v1.feature_column.shared_embedding_columns` when you want to use the 

225 TPU to accelerate your embedding lookups via TPU embeddings. 

226 

227 ``` 

228 column_a = tf.feature_column.categorical_column_with_identity(...) 

229 column_b = tf.feature_column.categorical_column_with_identity(...) 

230 tpu_columns = tf.tpu.experimental.shared_embedding_columns( 

231 [column_a, column_b], 10) 

232 ... 

233 def model_fn(features): 

234 dense_feature = tf.keras.layers.DenseFeature(tpu_columns) 

235 embedded_feature = dense_feature(features) 

236 ... 

237 

238 estimator = tf.estimator.tpu.TPUEstimator( 

239 model_fn=model_fn, 

240 ... 

241 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( 

242 column=tpu_columns, 

243 ...)) 

244 ``` 

245 

246 Args: 

247 categorical_columns: A list of categorical columns returned from 

248 `categorical_column_with_identity`, `weighted_categorical_column`, 

249 `categorical_column_with_vocabulary_file`, 

250 `categorical_column_with_vocabulary_list`, 

251 `sequence_categorical_column_with_identity`, 

252 `sequence_categorical_column_with_vocabulary_file`, 

253 `sequence_categorical_column_with_vocabulary_list` 

254 dimension: An integer specifying dimension of the embedding, must be > 0. 

255 combiner: A string specifying how to reduce if there are multiple entries in 

256 a single row for a non-sequence column. For more information, see 

257 `tf.feature_column.embedding_column`. 

258 initializer: A variable initializer function to be used in embedding 

259 variable initialization. If not specified, defaults to 

260 `tf.truncated_normal_initializer` with mean `0.0` and standard deviation 

261 `1/sqrt(dimension)`. 

262 shared_embedding_collection_name: Optional name of the collection where 

263 shared embedding weights are added. If not given, a reasonable name will 

264 be chosen based on the names of `categorical_columns`. This is also used 

265 in `variable_scope` when creating shared embedding weights. 

266 max_sequence_lengths: An list of non-negative integers, either None or empty 

267 or the same length as the argument categorical_columns. Entries 

268 corresponding to non-sequence columns must be 0 and entries corresponding 

269 to sequence columns specify the max sequence length for the column. Any 

270 sequence shorter then this will be padded with 0 embeddings and any 

271 sequence longer will be truncated. 

272 learning_rate_fn: A function that takes global step and returns learning 

273 rate for the embedding table. If you intend to use the same learning rate 

274 for multiple embedding tables, please ensure that you pass the exact same 

275 python function to all calls of shared_embedding_columns, otherwise 

276 performence may suffer. 

277 embedding_lookup_device: The device on which to run the embedding lookup. 

278 Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If 

279 specifying "tpu_tensor_core", a tensor_core_shape must be supplied. 

280 Defaults to "cpu". If not specified, the default behavior is embedding 

281 lookup on "tpu_embedding_core" for training and "cpu" for inference. 

282 Valid options for training : ["tpu_embedding_core", "tpu_tensor_core"] 

283 Valid options for serving : ["cpu", "tpu_tensor_core"] 

284 For training, tpu_embedding_core is good for large embedding vocab (>1M), 

285 otherwise, tpu_tensor_core is often sufficient. 

286 For serving, doing embedding lookup on tpu_tensor_core during serving is 

287 a way to reduce host cpu usage in cases where that is a bottleneck. 

288 tensor_core_shape: If supplied, a list of integers which specifies the 

289 intended dense shape to run embedding lookup for this feature on 

290 TensorCore. The batch dimension can be left None or -1 to indicate a 

291 dynamic shape. Only rank 2 shapes currently supported. 

292 use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse 

293 instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures 

294 there are no empty rows and all weights and ids are positive at the 

295 expense of extra compute cost. This only applies to rank 2 (NxM) shaped 

296 input tensors. Defaults to true, consider turning off if the above checks 

297 are not needed. Note that having empty rows will not trigger any error 

298 though the output result might be 0 or omitted. 

299 

300 Returns: 

301 A list of `_TPUSharedEmbeddingColumnV2`. 

302 

303 Raises: 

304 ValueError: if `dimension` not > 0. 

305 ValueError: if `initializer` is specified but not callable. 

306 ValueError: if `max_sequence_lengths` is specified and not the same length 

307 as `categorical_columns`. 

308 ValueError: if `max_sequence_lengths` is positive for a non sequence column 

309 or 0 for a sequence column. 

310 """ 

311 

  for categorical_column in categorical_columns:
    if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2):
      raise TypeError(
          'categorical_column for tpu '
          'shared_embedding_columns must be type {}, got {}.'.format(
              ' or '.join(
                  [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS_V2]),
              type(categorical_column)))

  if not max_sequence_lengths:
    max_sequence_lengths = [0] * len(categorical_columns)
  if len(max_sequence_lengths) != len(categorical_columns):
    raise ValueError('max_sequence_lengths and categorical_columns must be of '
                     'the same length. len(max_sequence_lengths)={} '
                     'len(categorical_columns)={}.'.format(
                         len(max_sequence_lengths), len(categorical_columns)))

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if tensor_core_shape and len(tensor_core_shape) != 2:
    raise ValueError(
        'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
  num_buckets = sorted_columns[0]._num_buckets  # pylint: disable=protected-access

  for c in sorted_columns[1:]:
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              sorted_columns[0], num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  tpu_columns = []

  column_creator = fc_lib.SharedEmbeddingColumnCreator(
      dimension=dimension, initializer=initializer, ckpt_to_load_from=None,
      tensor_name_in_ckpt=None, num_buckets=num_buckets, trainable=None,
      name=shared_embedding_collection_name)

  if (embedding_lookup_device and
      embedding_lookup_device not in _ALLOWED_DEVICES):
    raise ValueError(
        f'If set, embedding_lookup_device must be in {_ALLOWED_DEVICES}')

  if embedding_lookup_device == 'cpu':
    embedding_lookup_device = EmbeddingDevice.CPU
  elif embedding_lookup_device == 'tpu_tensor_core':
    embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE
  elif embedding_lookup_device == 'tpu_embedding_core':
    embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE

  if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
    if not tensor_core_shape:
      raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                       'tensor_core_shape to be set.')
    for c in sorted_columns:
      if isinstance(c, _SUPPORTED_SEQUENCE_COLUMNS):
        raise ValueError('embedding_lookup_device=tpu_tensor_core currently '
                         'does not support sequence columns.')

  # Create the state (_SharedEmbeddingColumnLayer) here.
  for categorical_column, max_sequence_length in zip(
      categorical_columns, max_sequence_lengths):
    if not embedding_lookup_device:
      column = _TPUSharedEmbeddingColumnV2(
          categorical_column=categorical_column,
          shared_embedding_column_creator=column_creator,
          combiner=combiner,
          initializer=initializer,
          shared_embedding_collection_name=shared_embedding_collection_name,
          max_sequence_length=max_sequence_length,
          learning_rate_fn=learning_rate_fn,
          use_safe_embedding_lookup=use_safe_embedding_lookup)
    else:
      column = _TPUSharedDeviceSpecificEmbeddingColumnV2(
          categorical_column=categorical_column,
          shared_embedding_column_creator=column_creator,
          combiner=combiner,
          initializer=initializer,
          shared_embedding_collection_name=shared_embedding_collection_name,
          max_sequence_length=max_sequence_length,
          learning_rate_fn=learning_rate_fn,
          embedding_lookup_device=embedding_lookup_device,
          tensor_core_shape=tensor_core_shape,
          use_safe_embedding_lookup=use_safe_embedding_lookup)
    tpu_columns.append(column)

  return tpu_columns
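
# Illustrative sketch: two identity columns with the same bucket count
# sharing a single embedding table. Key names and sizes are hypothetical;
# the helper below is a usage sketch, not an exported symbol.
def _example_shared_embedding_columns():
  column_a = fc_lib.categorical_column_with_identity(
      key='query_token', num_buckets=5000)
  column_b = fc_lib.categorical_column_with_identity(
      key='doc_token', num_buckets=5000)
  # Both returned columns perform lookups into one shared (5000, 32) table.
  return shared_embedding_columns_v2([column_a, column_b], dimension=32)
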


class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
  """Core Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              initializer=None,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True,
              bypass_scope_validation=False):
    del bypass_scope_validation
    # pylint: disable=redundant-keyword-arg
    return fc_lib.EmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        initializer=initializer,
        ckpt_to_load_from=None,
        tensor_name_in_ckpt=None,
        max_norm=None,
        trainable=True,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __getnewargs__(self):
    return (self._tpu_categorical_column, self.dimension, self.combiner,
            self.initializer, self._max_sequence_length,
            self._learning_rate_fn, self.use_safe_embedding_lookup,
            self._bypass_scope_validation)

  def __deepcopy__(self, memo):
    return _TPUEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()))

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               initializer=None,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True,
               bypass_scope_validation=False):
    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None
    # If true, scope validation is skipped to allow the same column to be used
    # in multiple variable scopes. By default, this is False, and we expect a
    # 1:1 mapping between feature columns and scopes.
    self._bypass_scope_validation = bypass_scope_validation

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self.categorical_column.name

  def get_initializer(self):
    return self.initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return tensor

  def create_state(self, state_manager):
    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.create_state(
          self, state_manager)

    # Under feature column V2, create_state is where the EmbeddingColumn
    # creates its embedding variables. On TPU we skip creating them and only
    # record the variable scope and name here.
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

  def get_dense_tensor(self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn.get_dense_tensor(
            self, transformation_cache, state_manager)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.get_dense_tensor(
          self, transformation_cache, state_manager)

    # TPU mode
    # Get the embeddings from the FeatureTransformationCache.
    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)

    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # inputs is a _LazyBuilder and for rank 1 tensors, it calls expand_dims(-1).
    # We need to undo this to match the standard CPU sequence embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn.get_sequence_dense_tensor(
            self, transformation_cache, state_manager)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.get_sequence_dense_tensor(
          self, transformation_cache, state_manager)

    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)
    tensor_lengths = transformation_cache.get(
        self.get_sequence_length_feature_key_name(),
        state_manager)

    # FeatureTransformationCache expands rank 1 tensors (like sequence length)
    # to rank 2. We need to undo this to match the standard CPU sequence
    # embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)
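
# Illustrative sketch: for a weighted categorical column, the feature key is
# the wrapped column's name while the weight key names the weight feature.
# Key names and sizes are hypothetical; the helper is not an exported symbol.
def _example_weighted_feature_keys():
  weighted = fc_lib.weighted_categorical_column(
      fc_lib.categorical_column_with_identity(key='terms', num_buckets=100),
      weight_feature_key='term_weights')
  column = embedding_column_v2(weighted, dimension=4)
  # -> ('terms', 'term_weights')
  return column.get_feature_key_name(), column.get_weight_key_name()
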


class _TPUSharedEmbeddingColumnV2(_TPUBaseEmbeddingColumn,
                                  fc_lib.SharedEmbeddingColumn):
  """Core Shared Embedding Column."""

  def __new__(cls,
              categorical_column,
              shared_embedding_column_creator,
              combiner='mean',
              initializer=None,
              shared_embedding_collection_name=None,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True):
    # pylint: disable=redundant-keyword-arg
    return fc_lib.SharedEmbeddingColumn.__new__(
        cls,
        categorical_column,
        combiner=combiner,
        shared_embedding_column_creator=shared_embedding_column_creator,
        max_norm=None,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __getnewargs__(self):
    return (self._tpu_categorical_column, self.shared_embedding_column_creator,
            self.combiner, self._initializer,
            self._shared_embedding_collection_name, self._max_sequence_length,
            self._learning_rate_fn)

  def __deepcopy__(self, memo):
    return _TPUSharedEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()))

  def __init__(self,
               categorical_column,
               shared_embedding_column_creator,
               combiner='mean',
               initializer=None,
               shared_embedding_collection_name=None,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True):

    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._initializer = initializer
    self._shared_embedding_collection_name = shared_embedding_collection_name

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets,
            self.shared_embedding_column_creator.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self._shared_embedding_collection_name

  def get_initializer(self):
    return self._initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor_internal(
      self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.SharedEmbeddingColumn._get_dense_tensor_internal(
            self, transformation_cache, state_manager)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.SharedEmbeddingColumn._get_dense_tensor_internal(
          self, transformation_cache, state_manager)

    # TPU mode
    # Get the embeddings from the FeatureTransformationCache.
    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    # Note that in Feature Column V2, shared embeddings have no scope.
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        self.shared_embedding_column_creator._name,
        is_shared_embedding=True)
    return tensor

  def get_sequence_dense_tensor(
      self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.SharedEmbeddingColumn.get_sequence_dense_tensor(
            self, transformation_cache, state_manager)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.SharedEmbeddingColumn.get_sequence_dense_tensor(
          self, transformation_cache, state_manager)

    tensor = self._get_dense_tensor_internal(
        transformation_cache, state_manager)
    tensor_lengths = transformation_cache.get(
        self.get_sequence_length_feature_key_name(),
        state_manager)

    # FeatureTransformationCache expands rank 1 tensors (like sequence length)
    # to rank 2. We need to undo this to match the standard CPU sequence
    # embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)


def split_sequence_columns_v2(feature_columns):
  """Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns.

  For use in a TPUEstimator model_fn function. E.g.

  def model_fn(features):
    sequence_columns, feature_columns = (
        tf.tpu.feature_column.split_sequence_columns(feature_columns))
    input = tf.feature_column.input_layer(
        features=features, feature_columns=feature_columns)
    sequence_features, sequence_lengths = (
        tf.contrib.feature_column.sequence_input_layer(
            features=features, feature_columns=sequence_columns))

  Args:
    feature_columns: A list of _TPUEmbeddingColumns to split.

  Returns:
    Two lists of _TPUEmbeddingColumns, the first is the sequence columns and
    the second is the non-sequence columns.
  """
  sequence_columns = []
  non_sequence_columns = []
  for column in feature_columns:
    if not isinstance(column, (_TPUEmbeddingColumnV2,
                               _TPUSharedEmbeddingColumnV2)):
      raise TypeError(
          'column must be a _TPUEmbeddingColumnV2 or '
          f'_TPUSharedEmbeddingColumnV2 but got {type(column)} instead.')
    if column.is_sequence_column():
      sequence_columns.append(column)
    else:
      non_sequence_columns.append(column)
  return sequence_columns, non_sequence_columns
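
# Illustrative sketch of the split: one sequence column and one context
# column built with this module's own factories. Key names and sizes are
# hypothetical; the helper is not an exported symbol.
def _example_split_sequence_columns_v2():
  sequence = embedding_column_v2(
      fc_lib.sequence_categorical_column_with_identity(
          key='clicked_ids', num_buckets=100),
      dimension=8,
      max_sequence_length=10)
  context = embedding_column_v2(
      fc_lib.categorical_column_with_identity(key='user_id', num_buckets=50),
      dimension=8)
  # Returns ([sequence], [context]).
  return split_sequence_columns_v2([sequence, context])
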


def sparse_embedding_aggregate_slice(params,
                                     values_and_values_mask,
                                     combiner='mean',
                                     name='sparse_embedding_aggregate_slice'):
  """Uses XLA's dynamic slice operations to perform embedding lookups.

  From third_party/cloud_tpu/models/movielens/tpu_embedding.py

  Args:
    params: Tensor of embedding table. Rank 2 (table_size x embedding dim)
    values_and_values_mask: is a two-tuple that contains:
      values - Tensor of embedding indices. Rank 2 (batch x n_indices)
      values_mask - Tensor of mask / weights. Rank 2 (batch x n_indices)
    combiner: The combiner to use for the embedding lookup. Currently supports
      'sum' and 'mean'.
    name: Optional name scope for created ops

  Returns:
    Rank 2 tensor of aggregated (per batch element) embedding vectors.

  Raises:
    ValueError: Combiner is not supported.
  """
  values, values_mask = values_and_values_mask  # unpack the two-tuple
  with ops.name_scope(name):
    _, embedding_dimension = params.get_shape().as_list()
    n_batch, n_indices_padded = values.get_shape().as_list()
    if not n_batch:
      n_batch = -1

    emb_lookup = array_ops.reshape(
        embedding_ops.embedding_lookup(
            params, array_ops.reshape(values, [n_batch, n_indices_padded])),
        [n_batch, n_indices_padded, embedding_dimension])

    values_mask_broadcast = array_ops.reshape(values_mask,
                                              [n_batch, n_indices_padded, 1])
    aggregate_emb = math_ops.reduce_sum(
        emb_lookup * values_mask_broadcast, axis=1)
    if combiner == 'sum':
      return aggregate_emb
    elif combiner == 'mean':
      # In the case we have an empty row, both aggregate_emb and
      # math_ops.reduce_sum(values_mask_broadcast, axis=1) will be 0. Thus,
      # we can take the max of it with a non-zero value to prevent NaNs. Note
      # that math_ops.reduce_sum(values_mask_broadcast, axis=1) will have
      # integer values so 1.0 is the smallest value.
      return aggregate_emb / math_ops.maximum(
          math_ops.reduce_sum(values_mask_broadcast, axis=1), 1.0)
    else:
      raise ValueError('Dense TPU Embedding does not support combiner '
                       'other than sum and mean.')
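
# Illustrative numeric sketch: mean-combining two padded id slots per row
# against a toy 4 x 2 table. Values are hypothetical; the helper is not an
# exported symbol.
def _example_sparse_embedding_aggregate_slice():
  params = ops.convert_to_tensor(
      [[0., 0.], [1., 1.], [2., 2.], [3., 3.]], dtype=dtypes.float32)
  values = ops.convert_to_tensor([[1, 2], [3, 0]], dtype=dtypes.int32)
  # Row 1 has one real id and one padded slot, so its mask zeroes the padding
  # before the reduction.
  mask = ops.convert_to_tensor([[1., 1.], [1., 0.]], dtype=dtypes.float32)
  # Row 0 averages table rows 1 and 2 -> [1.5, 1.5]; row 1 keeps only id 3
  # -> [3., 3.].
  return sparse_embedding_aggregate_slice(params, (values, mask))
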


def pad_sparse_embedding_lookup_indices(sparse_indices, padded_size):
  """Creates statically-sized Tensors containing indices and weights.

  From third_party/cloud_tpu/models/movielens/tpu_embedding.py

  The returned padded weight Tensor also doubles as a mask indicating which
  values in the returned padded indices Tensor are indices versus padded
  zeros.

  Args:
    sparse_indices: SparseTensor of embedding lookup indices.
    padded_size: Number of columns of the returned Tensors. Indices which fall
      out of bounds will be truncated to the padded size.

  Returns:
    (sparse_indices.values padded to the specified size,
     a mask the same size as the returned padded values in which 0s
     indicate padded locations and 1s (or values from sparse_weights)
     indicate actual values)
  """
  batch_size = sparse_indices.dense_shape[0]
  sparse_indices = sparse_ops.sparse_slice(sparse_indices, [0, 0],
                                           [batch_size, padded_size])
  indices, values = sparse_indices.indices, sparse_indices.values

  padded_values = array_ops.scatter_nd(
      indices,
      math_ops.cast(values, dtypes.int32),
      shape=(batch_size, padded_size))

  weights = array_ops.ones_like(values, dtype=dtypes.float32)
  padded_mask = array_ops.scatter_nd(
      indices, weights, shape=(batch_size, padded_size))

  return padded_values, padded_mask
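
# Illustrative sketch: padding a ragged batch of ids to a fixed width of 3.
# The second row supplies only one id, so its mask marks two padded slots.
# Values are hypothetical; the helper is not an exported symbol.
def _example_pad_sparse_embedding_lookup_indices():
  from tensorflow.python.framework import sparse_tensor  # local to the sketch

  sparse_ids = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=[7, 9, 4],
      dense_shape=[2, 3])
  # padded_values -> [[7, 9, 0], [4, 0, 0]]
  # padded_mask   -> [[1., 1., 0.], [1., 0., 0.]]
  return pad_sparse_embedding_lookup_indices(sparse_ids, padded_size=3)
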


def _check_invalid_cases(embedding_lookup_device):
  """Checks for invalid embedding_lookup_device configurations."""
  if (tpu.under_tpu_inference_context() and
      embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
    raise ValueError(
        'Using embedding_lookup_device=tpu_embedding_core during inference '
        'is not supported.')
  if embedding_lookup_device == EmbeddingDevice.CPU:
    if not tpu.under_tpu_inference_context():
      raise ValueError(
          'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" '
          'during training is not supported.')


class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2):
  """TPUEmbeddingColumn which allows serving on TensorCore."""

  def __new__(cls, *args, **kwargs):
    # For __new__, just capture the inference dense shape and call parent.
    if 'tensor_core_shape' in kwargs:
      cls._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      cls._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    return _TPUEmbeddingColumnV2.__new__(cls, *args, **kwargs)  # pytype: disable=wrong-keyword-args  # always-use-return-annotations

  def __init__(self, *args, **kwargs):
    # For __init__, just capture the inference dense shape and call parent.
    if 'tensor_core_shape' in kwargs:
      self._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      self._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    _TPUEmbeddingColumnV2.__init__(self, *args, **kwargs)

  def __deepcopy__(self, memo):
    return _TPUDeviceSpecificEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()),
        tensor_core_shape=self._tensor_core_shape,
        embedding_lookup_device=self._embedding_lookup_device)

  def create_state(self, state_manager):
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return fc_lib.EmbeddingColumn.create_state(self, state_manager)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).create_state(state_manager)

    # TPU_TENSOR_CORE case.
    return fc_lib.EmbeddingColumn.create_state(self, state_manager)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the embeddings, doing the lookup on the configured device."""
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU Case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).get_dense_tensor(transformation_cache, state_manager)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).get_dense_tensor(transformation_cache, state_manager)

    # TPU_TENSOR_CORE cases.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input
      # tensors.
      sparse_tensor = transformation_cache.get(self.categorical_column.name,
                                               state_manager)

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu_replication.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and
      # padded.
      values = transformation_cache.get(self.categorical_column.name,
                                        state_manager)
      mask = transformation_cache.get(
          self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
          state_manager)
    embedding_weights = state_manager.get_variable(
        self, name='embedding_weights')
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU Case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor(inputs, weight_collections,
                                           trainable)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor(inputs, weight_collections,
                                           trainable)

    # TPU_TENSOR_CORE cases.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input
      # tensors.
      sparse_tensor = inputs.get(self.get_feature_key_name())

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu_replication.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and
      # padded.
      values = inputs.get(self.get_feature_key_name())
      mask = inputs.get(self.get_feature_key_name() +
                        _TENSOR_CORE_MASK_KEY_SUFFIX)

    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
    if (weight_collections and
        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
    embedding_weights = variable_scope.get_variable(
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        initializer=self.initializer,
        trainable=self.trainable and trainable,
        collections=weight_collections)
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())
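
# Illustrative sketch of the training-time contract for a TensorCore column:
# the input pipeline supplies a densified id tensor plus a parallel mask
# under the reserved suffix key. Feature name and values are hypothetical;
# the helper is not an exported symbol.
def _example_tensor_core_training_features():
  ids = ops.convert_to_tensor([[7, 9, 0], [4, 0, 0]], dtype=dtypes.int32)
  mask = ops.convert_to_tensor([[1., 1., 0.], [1., 0., 0.]],
                               dtype=dtypes.float32)
  return {
      'item_id': ids,
      'item_id' + _TENSOR_CORE_MASK_KEY_SUFFIX: mask,
  }
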


class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2):
  """TPUSharedEmbeddingColumnV2 which allows serving on TensorCore."""

  def __new__(cls, *args, **kwargs):
    # For __new__, just capture the inference dense shape and call parent.
    if 'tensor_core_shape' in kwargs:
      cls._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      cls._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']

    return _TPUSharedEmbeddingColumnV2.__new__(cls, *args, **kwargs)  # pytype: disable=wrong-keyword-args  # always-use-return-annotations

  def __init__(self, *args, **kwargs):
    # For __init__, just capture the inference dense shape and call parent.
    if 'tensor_core_shape' in kwargs:
      self._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      self._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    _TPUSharedEmbeddingColumnV2.__init__(self, *args, **kwargs)

  def __deepcopy__(self, memo):
    return _TPUSharedDeviceSpecificEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()),
        tensor_core_shape=self._tensor_core_shape,
        embedding_lookup_device=self._embedding_lookup_device)

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Private method that follows _get_dense_tensor_internal."""
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU Case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor_internal(transformation_cache,
                                                    state_manager)
    # TPU_EMBEDDING_CORE case.
    if self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor_internal(transformation_cache,
                                                    state_manager)

    # TPU_TENSOR_CORE cases.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input
      # tensors.
      sparse_tensor = transformation_cache.get(self.categorical_column.name,
                                               state_manager)

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu_replication.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and
      # padded.
      values = transformation_cache.get(self.categorical_column.name,
                                        state_manager)
      mask = transformation_cache.get(
          self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
          state_manager)

    # Do a dense embedding lookup on TensorCore.
    embedding_weights = self.shared_embedding_column_creator.embedding_weights
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())