# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
import math

from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import tpu_replication
# pylint: disable=protected-access


_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
_SUPPORTED_SEQUENCE_COLUMNS = (fc._SequenceCategoricalColumn,
                               fc_lib.SequenceCategoricalColumn)

# For V2 columns, we support anything that inherits from CategoricalColumn
# other than those in the denylist. User-provided columns that inherit from
# CategoricalColumn may or may not be compatible; it is up to the user to
# manage TPU compatibility for custom columns.
_SUPPORTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.CategoricalColumn,)
_DENYLISTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.HashedCategoricalColumn,
                                      fc_lib.BucketizedColumn,
                                      fc_lib.CrossedColumn)
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
                                  fc._VocabularyFileCategoricalColumn,
                                  fc._VocabularyListCategoricalColumn,
                                  fc._WeightedCategoricalColumn,
                                  fc._SequenceCategoricalColumn
                                 ) + _SUPPORTED_CATEGORICAL_COLUMNS_V2
_SEQUENCE_FEATURE_LENGTH_POSTFIX = '_seq_length_'


def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     max_sequence_length=0,
                     learning_rate_fn=None,
                     use_safe_embedding_lookup=True):
  """TPU version of `tf.feature_column.embedding_column`.

  Note that the interface for the TPU embedding_column is different from the
  non-TPU version. The following args available in the non-TPU version are NOT
  supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable.

  Args:
    categorical_column: A categorical_column returned from
      categorical_column_with_identity, weighted_categorical_column,
      categorical_column_with_vocabulary_file,
      categorical_column_with_vocabulary_list,
      sequence_categorical_column_with_identity,
      sequence_categorical_column_with_vocabulary_file,
      sequence_categorical_column_with_vocabulary_list
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    max_sequence_length: A non-negative integer specifying the max sequence
      length. Any sequence shorter than this will be padded with 0 embeddings
      and any sequence longer will be truncated. This must be positive for
      sequence features and 0 for non-sequence features.
    learning_rate_fn: A function that takes the global step and returns the
      learning rate for the embedding table. If you intend to use the same
      learning rate for multiple embedding tables, please ensure that you pass
      the exact same python function to all calls of embedding_column;
      otherwise performance may suffer.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and that all weights and ids are positive, at
      the expense of extra compute cost. This only applies to rank 2 (NxM)
      shaped input tensors. Defaults to true; consider turning it off if the
      above checks are not needed. Note that having empty rows will not
      trigger any error, though the output result might be 0 or omitted.

  Returns:
    A _TPUEmbeddingColumn.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
    TypeError: if categorical_column is not a supported type.
  """
  if isinstance(categorical_column, _DENYLISTED_CATEGORICAL_COLUMNS_V2):
    raise TypeError('categorical_column for tpu embedding_column was '
                    f'denylisted type {type(categorical_column)}')
  if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
    raise TypeError(
        'categorical_column for tpu embedding_column must be type {}, '
        'got {}.'.format(
            ' or '.join(
                [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS]),
            type(categorical_column)))
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  embedding_shape = categorical_column._num_buckets, dimension  # pylint: disable=protected-access

  def _creator(weight_collections, scope):
    embedding_column_layer = fc._EmbeddingColumnLayer(
        embedding_shape=embedding_shape,
        initializer=initializer,
        weight_collections=weight_collections,
        trainable=True,
        name='embedding_column_layer')
    return embedding_column_layer(None, scope=scope)  # pylint: disable=not-callable

  column = _TPUEmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      layer_creator=_creator,
      ckpt_to_load_from=None,
      tensor_name_in_ckpt=None,
      max_norm=None,
      trainable=True,
      max_sequence_length=max_sequence_length,
      learning_rate_fn=learning_rate_fn,
      use_safe_embedding_lookup=use_safe_embedding_lookup)
  # For an embedding column, the initializer is hidden inside the creator fn,
  # which is not accessible later. So we attach it to a special field. Also
  # note that non-TPU embedding columns and non-TPU shared embedding columns
  # handle the initializer differently. See shared_embedding_columns for
  # details.
  column._tpu_initializer = initializer

  return column
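

# A minimal usage sketch for `embedding_column` (illustrative only, not part
# of the library; the feature key 'watch_id' and the bucket count are
# hypothetical):
#
#   watch_id = tf.compat.v1.feature_column.categorical_column_with_identity(
#       'watch_id', num_buckets=100000)
#   watch_embedding = embedding_column(watch_id, dimension=32)
#
# The resulting _TPUEmbeddingColumn is intended to be passed, together with
# the model's other feature columns, to TPUEstimator's embedding support.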


def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             max_sequence_lengths=None,
                             learning_rate_fn=None,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  Note that the interface for the TPU shared_embedding_columns is different
  from the non-TPU version. The following args available in the non-TPU
  version are NOT supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm
  and trainable.

  Args:
    categorical_columns: A list of categorical_columns returned from
      categorical_column_with_identity, weighted_categorical_column,
      categorical_column_with_vocabulary_file,
      categorical_column_with_vocabulary_list,
      sequence_categorical_column_with_identity,
      sequence_categorical_column_with_vocabulary_file,
      sequence_categorical_column_with_vocabulary_list
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard
      deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    max_sequence_lengths: A list of non-negative integers, either None or
      empty or the same length as the argument categorical_columns. Entries
      corresponding to non-sequence columns must be 0 and entries
      corresponding to sequence columns specify the max sequence length for
      the column. Any sequence shorter than this will be padded with 0
      embeddings and any sequence longer will be truncated.
    learning_rate_fn: A function that takes the global step and returns the
      learning rate for the embedding table. If you intend to use the same
      learning rate for multiple embedding tables, please ensure that you pass
      the exact same python function to all calls of shared_embedding_columns;
      otherwise performance may suffer.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and that all weights and ids are positive, at
      the expense of extra compute cost. This only applies to rank 2 (NxM)
      shaped input tensors. Defaults to true; consider turning it off if the
      above checks are not needed. Note that having empty rows will not
      trigger any error, though the output result might be 0 or omitted.

  Returns:
    A list of _TPUSharedEmbeddingColumns.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
    ValueError: if `max_sequence_lengths` is specified and not the same length
      as `categorical_columns`.
    ValueError: if `max_sequence_lengths` is positive for a non-sequence
      column or 0 for a sequence column.
  """
  for categorical_column in categorical_columns:
    if isinstance(categorical_column, _DENYLISTED_CATEGORICAL_COLUMNS_V2):
      raise TypeError('categorical_column for tpu '
                      'shared_embedding_columns was denylisted type '
                      f'{type(categorical_column)}')
    if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
      raise TypeError(
          'categorical_column for tpu shared_embedding_columns must be type '
          '{}, got {}.'.format(
              ' or '.join(
                  [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS]),
              type(categorical_column)))

  if not max_sequence_lengths:
    max_sequence_lengths = [0] * len(categorical_columns)
  if len(max_sequence_lengths) != len(categorical_columns):
    raise ValueError('max_sequence_lengths and categorical_columns must be of '
                     'the same length. len(max_sequence_lengths)={} '
                     'len(categorical_columns)={}.'.format(
                         len(max_sequence_lengths), len(categorical_columns)))

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
  num_buckets = sorted_columns[0]._num_buckets  # pylint: disable=protected-access

  for c in sorted_columns[1:]:
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              sorted_columns[0], num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  tpu_columns = []

  # Create the state (_SharedEmbeddingColumnLayer) here.
  for categorical_column, max_sequence_length in zip(
      categorical_columns, max_sequence_lengths):
    column = _TPUSharedEmbeddingColumn(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        shared_embedding_collection_name=shared_embedding_collection_name,
        ckpt_to_load_from=None,
        tensor_name_in_ckpt=None,
        max_norm=None,
        trainable=True,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
    tpu_columns.append(column)

  return tpu_columns
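

# A minimal usage sketch for `shared_embedding_columns` (illustrative only;
# the feature keys and bucket count are hypothetical). Both columns share one
# embedding table, so they must have the same number of buckets:
#
#   query = tf.compat.v1.feature_column.categorical_column_with_identity(
#       'query_token', num_buckets=50000)
#   title = tf.compat.v1.feature_column.categorical_column_with_identity(
#       'title_token', num_buckets=50000)
#   query_emb, title_emb = shared_embedding_columns(
#       [query, title], dimension=16)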


class _TPUBaseEmbeddingColumn(object):
  """Base class for TPU Embedding Column."""

  def __init__(self,
               categorical_column,
               max_sequence_length=0,
               learning_rate_fn=None):
    self._tpu_categorical_column = categorical_column
    self._max_sequence_length = max_sequence_length
    self._learning_rate_fn = learning_rate_fn
    if (self.is_sequence_column() and max_sequence_length < 1):
      raise ValueError('max_sequence_length must be greater than 0 for '
                       'sequence columns. Got max_sequence_length={} for '
                       'sequence column {}.'.format(max_sequence_length,
                                                    categorical_column.name))
    if (not self.is_sequence_column() and max_sequence_length != 0):
      raise ValueError('Non zero max_seq_length={} specified for non '
                       'sequence column {}.'.format(max_sequence_length,
                                                    categorical_column.name))

  def get_combiner(self):
    """Returns the embedding combiner."""
    raise NotImplementedError('not implemented')

  def get_embedding_table_size(self):
    """Returns the embedding table size, tuple of vocab size and dimension."""
    raise NotImplementedError('not implemented')

  def get_feature_key_name(self):
    """Returns the feature key name in the features dict."""
    raise NotImplementedError('not impl')

  def get_weight_key_name(self):
    """Return the key name for weights."""
    raise NotImplementedError('not impl')

  def get_embedding_var_name(self):
    """Returns the embedding variable name.

    Feature key name and embedding variable name are usually one-to-one
    mapping. But for shared embedding columns, it is many-to-one mapping.
    """
    raise NotImplementedError('not impl')

  def get_initializer(self):
    """Returns the initializer."""
    raise NotImplementedError('not impl')

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    raise NotImplementedError('not impl')

  def is_sequence_column(self):
    return isinstance(self._tpu_categorical_column, _SUPPORTED_SEQUENCE_COLUMNS)

  def get_max_sequence_length(self):
    return self._max_sequence_length

  def get_learning_rate_fn(self):
    return self._learning_rate_fn

  def get_sequence_length_feature_key_name(self):
    """Get the key for the associated sequence length feature."""
    return get_sequence_length_feature_key_name_from_feature_key_name(
        self.get_feature_key_name())


class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
  """Core Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              layer_creator=None,
              ckpt_to_load_from=None,
              tensor_name_in_ckpt=None,
              max_norm=None,
              trainable=True,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True,
              bypass_scope_validation=False):
    # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
    # are not supported on TPU. They are solely for matching the signature of
    # __new__ of parent class fc._EmbeddingColumn.
    del bypass_scope_validation
    # pylint: disable=redundant-keyword-arg
    return fc._EmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        layer_creator=layer_creator,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               layer_creator=None,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True,
               bypass_scope_validation=False):
    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None
    # If true, scope validation is skipped to allow the same column to be used
    # in multiple variable scopes. By default, this is False, and we expect a
    # 1:1 mapping between feature columns and scopes.
    self._bypass_scope_validation = bypass_scope_validation

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self.categorical_column.name

  def get_initializer(self):
    return self._tpu_initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._EmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._EmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # inputs is a _LazyBuilder and for rank 1 tensors, it calls expand_dims(-1).
    # We need to undo this to match the standard CPU sequence embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return fc._SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)


class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
                                fc._SharedEmbeddingColumn):
  """Core Shared Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              initializer=None,
              shared_embedding_collection_name=None,
              ckpt_to_load_from=None,
              tensor_name_in_ckpt=None,
              max_norm=None,
              trainable=True,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True):
    return fc._SharedEmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        initializer=initializer,
        shared_embedding_collection_name=shared_embedding_collection_name,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               initializer=None,
               shared_embedding_collection_name=None,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True):
    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self.shared_embedding_collection_name

  def get_initializer(self):
    return self.initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._SharedEmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._SharedEmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        is_shared_embedding=True)
    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)

      return tpu_replication.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        is_shared_embedding=True)

    return fc._SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)


def _record_variable_scope_and_name(embedding_var_name,
                                    embedding_var_name_in_fc,
                                    is_shared_embedding=False,
                                    bypass_scope_validation=False):
  """Add embedding variable name and scope to collection."""
  g = ops.get_default_graph()
  collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
  if not collection:
    collection.append({})

  var_def_dict = collection[0]

  captured_scope = variable_scope.get_variable_scope()
  captured_scope_name = captured_scope.name

  if embedding_var_name in var_def_dict:
    if (var_def_dict[embedding_var_name][0] != captured_scope_name and
        not is_shared_embedding and not bypass_scope_validation):
      raise ValueError(
          'For embedding var name {}, the variable scope name is different, '
          'got {}; expected {}'.format(embedding_var_name,
                                       captured_scope_name,
                                       var_def_dict[embedding_var_name][0]))
    if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
      raise ValueError(
          'For embedding var name {}, the embedding name is different, '
          'got {}; expected {}'.format(embedding_var_name,
                                       embedding_var_name_in_fc,
                                       var_def_dict[embedding_var_name][1]))
  else:
    var_def_dict[embedding_var_name] = (captured_scope_name,
                                        embedding_var_name_in_fc)


def _is_running_on_cpu():
  """Returns True if the current context is CPU mode."""
  return tpu_function.get_tpu_context().number_of_shards is None


def get_sequence_length_feature_key_name_from_feature_key_name(feature_name):
  """Gets the name of the sequence length feature from that of the base feature.

  Args:
    feature_name: The feature key of a sequence column.

  Returns:
    A string which is the feature key for the associated sequence length
    column.
  """
  return feature_name + _SEQUENCE_FEATURE_LENGTH_POSTFIX
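

# For example, with the postfix defined above, a hypothetical sequence feature
# keyed 'watch_history' gets the sequence length key 'watch_history_seq_length_':
#
#   get_sequence_length_feature_key_name_from_feature_key_name('watch_history')
#   # -> 'watch_history_seq_length_'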


def split_sequence_columns(feature_columns):
  """Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns.

  For use in a TPUEstimator model_fn function. E.g.

  def model_fn(features):
    sequence_columns, feature_columns = (
        tf.tpu.feature_column.split_sequence_columns(feature_columns))
    input = tf.feature_column.input_layer(
        features=features, feature_columns=feature_columns)
    sequence_features, sequence_lengths = (
        tf.contrib.feature_column.sequence_input_layer(
            features=features, feature_columns=sequence_columns))

  Args:
    feature_columns: A list of _TPUEmbeddingColumns to split.

  Returns:
    Two lists of _TPUEmbeddingColumns, the first is the sequence columns and
    the second is the non-sequence columns.
  """
  sequence_columns = []
  non_sequence_columns = []
  for column in feature_columns:
    if not isinstance(column, (_TPUEmbeddingColumn, _TPUSharedEmbeddingColumn)):
      raise TypeError(
          'column must be a _TPUEmbeddingColumn or _TPUSharedEmbeddingColumn '
          f'but got {type(column)} instead.')
    if column.is_sequence_column():
      sequence_columns.append(column)
    else:
      non_sequence_columns.append(column)
  return sequence_columns, non_sequence_columns
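

# A minimal end-to-end sketch tying the helpers above together (illustrative
# only; the feature keys, bucket counts and dimensions are hypothetical):
#
#   watch_id = tf.compat.v1.feature_column.categorical_column_with_identity(
#       'watch_id', num_buckets=100000)
#   history = (
#       tf.compat.v1.feature_column.sequence_categorical_column_with_identity(
#           'watch_history', num_buckets=100000))
#   columns = [
#       embedding_column(watch_id, dimension=32),
#       embedding_column(history, dimension=32, max_sequence_length=20),
#   ]
#   seq_cols, non_seq_cols = split_sequence_columns(columns)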