Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/feature_column.py: 27%

220 statements

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
import math

from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import tpu_replication
# pylint: disable=protected-access


_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
_SUPPORTED_SEQUENCE_COLUMNS = (fc._SequenceCategoricalColumn,
                               fc_lib.SequenceCategoricalColumn)


# For V2 columns, we support anything that inherits from CategoricalColumn
# other than those in the denylist. User-provided columns that inherit from
# CategoricalColumn may or may not be compatible; it is up to the user to
# manage TPU compatibility for custom columns.
_SUPPORTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.CategoricalColumn,)
_DENYLISTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.HashedCategoricalColumn,
                                      fc_lib.BucketizedColumn,
                                      fc_lib.CrossedColumn)
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
                                  fc._VocabularyFileCategoricalColumn,
                                  fc._VocabularyListCategoricalColumn,
                                  fc._WeightedCategoricalColumn,
                                  fc._SequenceCategoricalColumn
                                 ) + _SUPPORTED_CATEGORICAL_COLUMNS_V2
_SEQUENCE_FEATURE_LENGTH_POSTFIX = '_seq_length_'


def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     max_sequence_length=0,
                     learning_rate_fn=None,
                     use_safe_embedding_lookup=True):
  """TPU embedding_column for `tf.feature_column.embedding_column`.

  Note that the interface for TPU embedding_column is different from the
  non-TPU version. The following args available for the non-TPU version are
  NOT supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and
  trainable.

  Args:
    categorical_column: A categorical_column returned from
      categorical_column_with_identity, weighted_categorical_column,
      categorical_column_with_vocabulary_file,
      categorical_column_with_vocabulary_list,
      sequence_categorical_column_with_identity,
      sequence_categorical_column_with_vocabulary_file,
      sequence_categorical_column_with_vocabulary_list
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    max_sequence_length: A non-negative integer specifying the max sequence
      length. Any sequence shorter than this will be padded with 0 embeddings
      and any sequence longer will be truncated. This must be positive for
      sequence features and 0 for non-sequence features.
    learning_rate_fn: A function that takes global step and returns learning
      rate for the embedding table. If you intend to use the same learning
      rate for multiple embedding tables, please ensure that you pass the
      exact same python function to all calls of embedding_column, otherwise
      performance may suffer.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse
      ensures there are no empty rows and all weights and ids are positive
      at the expense of extra compute cost. This only applies to rank 2
      (NxM) shaped input tensors. Defaults to true, consider turning off if
      the above checks are not needed. Note that having empty rows will not
      trigger any error though the output result might be 0 or omitted.

  Returns:
    A _TPUEmbeddingColumn.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
    TypeError: if categorical_column is not a supported type.
  """
  if isinstance(categorical_column, _DENYLISTED_CATEGORICAL_COLUMNS_V2):
    raise TypeError('categorical_column for tpu embedding_column was '
                    f'denylisted type {type(categorical_column)}')
  if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
    raise TypeError(
        'categorical_column for tpu embedding_column must be type {}, '
        'got {}.'.format(
            ' or '.join(
                [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS]),
            type(categorical_column)))
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  embedding_shape = categorical_column._num_buckets, dimension  # pylint: disable=protected-access

  def _creator(weight_collections, scope):
    embedding_column_layer = fc._EmbeddingColumnLayer(
        embedding_shape=embedding_shape,
        initializer=initializer,
        weight_collections=weight_collections,
        trainable=True,
        name='embedding_column_layer')
    return embedding_column_layer(None, scope=scope)  # pylint: disable=not-callable

  column = _TPUEmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      layer_creator=_creator,
      ckpt_to_load_from=None,
      tensor_name_in_ckpt=None,
      max_norm=None,
      trainable=True,
      max_sequence_length=max_sequence_length,
      learning_rate_fn=learning_rate_fn,
      use_safe_embedding_lookup=use_safe_embedding_lookup)
  # For Embedding column, the initializer is hidden inside the creator Fn,
  # which is not accessible later. So, we attach it to a special field. Also
  # note that non-TPU Embedding column and non-TPU shared Embedding column
  # handle the initializer differently. See shared_embedding_columns for
  # details.
  column._tpu_initializer = initializer
  return column
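
# Example (illustrative sketch, not part of the original module): building a
# TPU embedding column from an identity categorical column. The feature key
# 'watched_video_id' and the sizes are hypothetical, and `tf` is assumed to
# be the `tensorflow.compat.v1` API surface.
#
#   watched = tf.feature_column.categorical_column_with_identity(
#       key='watched_video_id', num_buckets=100000)
#   video_embedding = embedding_column(watched, dimension=32)
#   # A denylisted input such as a hashed or crossed column raises TypeError.
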

def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             max_sequence_lengths=None,
                             learning_rate_fn=None,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  Note that the interface for TPU embedding_column is different from the
  non-TPU version. The following args available for the non-TPU version are
  NOT supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and
  trainable.

  Args:
    categorical_columns: A list of categorical_columns returned from
      categorical_column_with_identity, weighted_categorical_column,
      categorical_column_with_vocabulary_file,
      categorical_column_with_vocabulary_list,
      sequence_categorical_column_with_identity,
      sequence_categorical_column_with_vocabulary_file,
      sequence_categorical_column_with_vocabulary_list
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name
      will be chosen based on the names of `categorical_columns`. This is
      also used in `variable_scope` when creating shared embedding weights.
    max_sequence_lengths: A list of non-negative integers, either None or
      empty or the same length as the argument categorical_columns. Entries
      corresponding to non-sequence columns must be 0 and entries
      corresponding to sequence columns specify the max sequence length for
      the column. Any sequence shorter than this will be padded with 0
      embeddings and any sequence longer will be truncated.
    learning_rate_fn: A function that takes global step and returns learning
      rate for the embedding table. If you intend to use the same learning
      rate for multiple embedding tables, please ensure that you pass the
      exact same python function to all calls of shared_embedding_columns,
      otherwise performance may suffer.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse
      ensures there are no empty rows and all weights and ids are positive
      at the expense of extra compute cost. This only applies to rank 2
      (NxM) shaped input tensors. Defaults to true, consider turning off if
      the above checks are not needed. Note that having empty rows will not
      trigger any error though the output result might be 0 or omitted.

  Returns:
    A list of _TPUSharedEmbeddingColumns, one per element of
    `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
    ValueError: if `max_sequence_lengths` is specified and not the same
      length as `categorical_columns`.
    ValueError: if `max_sequence_lengths` is positive for a non-sequence
      column or 0 for a sequence column.
  """
  for categorical_column in categorical_columns:
    if isinstance(categorical_column, _DENYLISTED_CATEGORICAL_COLUMNS_V2):
      raise TypeError('categorical_column for tpu embedding_column was '
                      f'denylisted type {type(categorical_column)}')
    if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
      raise TypeError(
          'categorical_column for tpu shared_embedding_columns must be '
          'type {}, got {}.'.format(
              ' or '.join(
                  [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS]),
              type(categorical_column)))

  if not max_sequence_lengths:
    max_sequence_lengths = [0] * len(categorical_columns)
  if len(max_sequence_lengths) != len(categorical_columns):
    raise ValueError('max_sequence_lengths and categorical_columns must be '
                     'of the same length. len(max_sequence_lengths)={} '
                     'len(categorical_columns)={}.'.format(
                         len(max_sequence_lengths), len(categorical_columns)))

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as
  # dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
  num_buckets = sorted_columns[0]._num_buckets  # pylint: disable=protected-access

  for c in sorted_columns[1:]:
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must '
          'have the same number of buckets. Given column: {} with buckets: '
          '{} does not match column: {} with buckets: {}'.format(
              sorted_columns[0], num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(
        c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  tpu_columns = []

  # Create the state (_SharedEmbeddingColumnLayer) here.
  for categorical_column, max_sequence_length in zip(
      categorical_columns, max_sequence_lengths):
    column = _TPUSharedEmbeddingColumn(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        shared_embedding_collection_name=shared_embedding_collection_name,
        ckpt_to_load_from=None,
        tensor_name_in_ckpt=None,
        max_norm=None,
        trainable=True,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
    tpu_columns.append(column)

  return tpu_columns
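
# Example (illustrative sketch, not part of the original module): two
# categorical features sharing a single embedding table. The feature keys
# and vocabulary are hypothetical; `tf` is assumed to be the
# `tensorflow.compat.v1` API surface.
#
#   query = tf.feature_column.categorical_column_with_vocabulary_list(
#       key='query_token', vocabulary_list=('a', 'b', 'c'))
#   title = tf.feature_column.categorical_column_with_vocabulary_list(
#       key='title_token', vocabulary_list=('a', 'b', 'c'))
#   query_emb, title_emb = shared_embedding_columns([query, title],
#                                                   dimension=8)
#   # Both columns must have the same number of buckets, or a ValueError is
#   # raised above.
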

class _TPUBaseEmbeddingColumn(object):
  """Base class for TPU Embedding Column."""

  def __init__(self,
               categorical_column,
               max_sequence_length=0,
               learning_rate_fn=None):
    self._tpu_categorical_column = categorical_column
    self._max_sequence_length = max_sequence_length
    self._learning_rate_fn = learning_rate_fn
    if (self.is_sequence_column() and max_sequence_length < 1):
      raise ValueError('max_sequence_length must be greater than 0 for '
                       'sequence columns. Got max_sequence_length={} for '
                       'sequence column {}.'.format(max_sequence_length,
                                                    categorical_column.name))
    if (not self.is_sequence_column() and max_sequence_length != 0):
      raise ValueError('Non-zero max_sequence_length={} specified for '
                       'non-sequence column {}.'.format(
                           max_sequence_length, categorical_column.name))

  def get_combiner(self):
    """Returns the embedding combiner."""
    raise NotImplementedError('not implemented')

  def get_embedding_table_size(self):
    """Returns the embedding table size, tuple of vocab size and dimension."""
    raise NotImplementedError('not implemented')

  def get_feature_key_name(self):
    """Returns the feature key name in the features dict."""
    raise NotImplementedError('not implemented')

  def get_weight_key_name(self):
    """Returns the key name for weights."""
    raise NotImplementedError('not implemented')

  def get_embedding_var_name(self):
    """Returns the embedding variable name.

    Feature key name and embedding variable name are usually a one-to-one
    mapping. But for shared embedding columns, it is a many-to-one mapping.
    """
    raise NotImplementedError('not implemented')

  def get_initializer(self):
    """Returns the initializer."""
    raise NotImplementedError('not implemented')

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    raise NotImplementedError('not implemented')

  def is_sequence_column(self):
    return isinstance(self._tpu_categorical_column,
                      _SUPPORTED_SEQUENCE_COLUMNS)

  def get_max_sequence_length(self):
    return self._max_sequence_length

  def get_learning_rate_fn(self):
    return self._learning_rate_fn

  def get_sequence_length_feature_key_name(self):
    """Get the key for the associated sequence length feature."""
    return get_sequence_length_feature_key_name_from_feature_key_name(
        self.get_feature_key_name())


class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
  """Core Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              layer_creator=None,
              ckpt_to_load_from=None,
              tensor_name_in_ckpt=None,
              max_norm=None,
              trainable=True,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True,
              bypass_scope_validation=False):
    # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and
    # trainable are not supported on TPU. They are solely for matching the
    # signature of __new__ of parent class fc._EmbeddingColumn.
    del bypass_scope_validation
    # pylint: disable=redundant-keyword-arg
    return fc._EmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        layer_creator=layer_creator,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               layer_creator=None,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True,
               bypass_scope_validation=False):
    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None
    # If True, scope validation is skipped to allow the same column to be
    # used in multiple variable scopes. By default, this is False, and we
    # expect a 1:1 mapping between feature columns and scopes.
    self._bypass_scope_validation = bypass_scope_validation

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """Returns the feature key name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """Returns the weight key name, or None for unweighted columns."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """Returns the embedding variable name."""
    return self.categorical_column.name

  def get_initializer(self):
    return self._tpu_initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._EmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._EmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # inputs is a _LazyBuilder and for rank 1 tensors, it calls
    # expand_dims(-1). We need to undo this to match the standard CPU
    # sequence embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return fc._SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)
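
# Illustrative note (not part of the original module): _get_dense_tensor
# above dispatches three ways. A minimal sketch of the pattern, with
# `host_lookup` and `cpu_lookup` as hypothetical stand-ins for the parent
# class calls:
#
#   if tpu.under_tpu_inference_context():   # serving: run lookup on host
#       return tpu_replication.outside_compilation(host_lookup)
#   if _is_running_on_cpu():                # no TPU shards: plain CPU path
#       return cpu_lookup()
#   return inputs.get(feature_key)          # TPU training: activations
#                                           # already in the LazyBuilder
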

class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
                                fc._SharedEmbeddingColumn):
  """Core Shared Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              initializer=None,
              shared_embedding_collection_name=None,
              ckpt_to_load_from=None,
              tensor_name_in_ckpt=None,
              max_norm=None,
              trainable=True,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True):
    return fc._SharedEmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        initializer=initializer,
        shared_embedding_collection_name=shared_embedding_collection_name,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               initializer=None,
               shared_embedding_collection_name=None,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True):

    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """Returns the feature key name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """Returns the weight key name, or None for unweighted columns."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """Returns the embedding variable name."""
    return self.shared_embedding_collection_name

  def get_initializer(self):
    return self.initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._SharedEmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._SharedEmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        is_shared_embedding=True)
    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        is_shared_embedding=True)

    return fc._SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)


def _record_variable_scope_and_name(embedding_var_name,
                                    embedding_var_name_in_fc,
                                    is_shared_embedding=False,
                                    bypass_scope_validation=False):
  """Add embedding variable name and scope to collection."""
  g = ops.get_default_graph()
  collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
  if not collection:
    collection.append({})

  var_def_dict = collection[0]

  captured_scope = variable_scope.get_variable_scope()
  captured_scope_name = captured_scope.name

  if embedding_var_name in var_def_dict:
    if (var_def_dict[embedding_var_name][0] != captured_scope_name and
        not is_shared_embedding and not bypass_scope_validation):
      raise ValueError(
          'For embedding var name {}, the variable scope name is different, '
          'got {}; expected {}'.format(embedding_var_name,
                                       captured_scope_name,
                                       var_def_dict[embedding_var_name][0]))
    if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
      raise ValueError(
          'For embedding var name {}, the embedding name is different, '
          'got {}; expected {}'.format(embedding_var_name,
                                       embedding_var_name_in_fc,
                                       var_def_dict[embedding_var_name][1]))
  else:
    var_def_dict[embedding_var_name] = (captured_scope_name,
                                        embedding_var_name_in_fc)
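
# Illustrative note (not part of the original module): after the columns
# above record themselves, the _TPU_FC_TO_SCOPE collection holds a single
# dict mapping each embedding variable name to a
# (variable scope name, variable name in the feature column) pair, e.g.
#   {'watched_video_id': ('dnn/input_layer', 'embedding_weights')}
# where the scope and feature names here are hypothetical.
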

def _is_running_on_cpu():
  """Returns True if the current context is CPU mode."""
  return tpu_function.get_tpu_context().number_of_shards is None


def get_sequence_length_feature_key_name_from_feature_key_name(feature_name):
  """Gets the name of the sequence length feature from that of the base feature.

  Args:
    feature_name: The feature key of a sequence column.

  Returns:
    A string which is the feature key for the associated feature length
    column.
  """
  return feature_name + _SEQUENCE_FEATURE_LENGTH_POSTFIX
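
# Example (illustrative, hypothetical feature key): with the postfix defined
# above,
#   get_sequence_length_feature_key_name_from_feature_key_name('watch_ids')
# returns 'watch_ids_seq_length_'.
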

def split_sequence_columns(feature_columns):
  """Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns.

  For use in a TPUEstimator model_fn function. E.g.

    def model_fn(features):
      sequence_columns, feature_columns = (
          tf.tpu.feature_column.split_sequence_columns(feature_columns))
      input = tf.feature_column.input_layer(
          features=features, feature_columns=feature_columns)
      sequence_features, sequence_lengths = (
          tf.contrib.feature_column.sequence_input_layer(
              features=features, feature_columns=sequence_columns))

  Args:
    feature_columns: A list of _TPUEmbeddingColumns to split.

  Returns:
    Two lists of _TPUEmbeddingColumns, the first is the sequence columns and
    the second is the non-sequence columns.
  """
  sequence_columns = []
  non_sequence_columns = []
  for column in feature_columns:
    if not isinstance(column, (_TPUEmbeddingColumn, _TPUSharedEmbeddingColumn)):
      raise TypeError(
          'column must be a _TPUEmbeddingColumn or _TPUSharedEmbeddingColumn '
          f'but got {type(column)} instead.')
    if column.is_sequence_column():
      sequence_columns.append(column)
    else:
      non_sequence_columns.append(column)
  return sequence_columns, non_sequence_columns
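
# Example (illustrative sketch, not part of the original module): splitting a
# mixed list of TPU embedding columns built with the helpers above.
# `video_embedding` (non-sequence) and `history_embedding` (sequence) are
# hypothetical columns.
#
#   seq_cols, non_seq_cols = split_sequence_columns(
#       [video_embedding, history_embedding])
#   # seq_cols == [history_embedding]; non_seq_cols == [video_embedding]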