# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""This API defines FeatureColumn abstraction.
17FeatureColumns provide a high level abstraction for ingesting and representing
18features. FeatureColumns are also the primary way of encoding features for
19canned `tf.estimator.Estimator`s.
21When using FeatureColumns with `Estimators`, the type of feature column you
22should choose depends on (1) the feature type and (2) the model type.
241. Feature type:
26 * Continuous features can be represented by `numeric_column`.
27 * Categorical features can be represented by any `categorical_column_with_*`
28 column:
29 - `categorical_column_with_vocabulary_list`
30 - `categorical_column_with_vocabulary_file`
31 - `categorical_column_with_hash_bucket`
32 - `categorical_column_with_identity`
33 - `weighted_categorical_column`
352. Model type:
37 * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
39 Continuous features can be directly fed into deep neural network models.
41 age_column = numeric_column("age")
43 To feed sparse features into DNN models, wrap the column with
44 `embedding_column` or `indicator_column`. `indicator_column` is recommended
45 for features with only a few possible values. For features with many
46 possible values, to reduce the size of your model, `embedding_column` is
47 recommended.
49 embedded_dept_column = embedding_column(
50 categorical_column_with_vocabulary_list(
51 "department", ["math", "philosophy", ...]), dimension=10)
53 * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
55 Sparse features can be fed directly into linear models. They behave like an
56 indicator column but with an efficient implementation.
58 dept_column = categorical_column_with_vocabulary_list("department",
59 ["math", "philosophy", "english"])
61 It is recommended that continuous features be bucketized before being
62 fed into linear models.
64 bucketized_age_column = bucketized_column(
65 source_column=age_column,
66 boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
68 Sparse features can be crossed (also known as conjuncted or combined) in
69 order to form non-linearities, and then fed into linear models.
71 cross_dept_age_column = crossed_column(
72 columns=["department", bucketized_age_column],
73 hash_bucket_size=1000)
75Example of building canned `Estimator`s using FeatureColumns:
77 ```python
78 # Define features and transformations
79 deep_feature_columns = [age_column, embedded_dept_column]
80 wide_feature_columns = [dept_column, bucketized_age_column,
81 cross_dept_age_column]
83 # Build deep model
84 estimator = DNNClassifier(
85 feature_columns=deep_feature_columns,
86 hidden_units=[500, 250, 50])
87 estimator.train(...)
89 # Or build a wide model
90 estimator = LinearClassifier(
91 feature_columns=wide_feature_columns)
92 estimator.train(...)
94 # Or build a wide and deep model!
95 estimator = DNNLinearCombinedClassifier(
96 linear_feature_columns=wide_feature_columns,
97 dnn_feature_columns=deep_feature_columns,
98 dnn_hidden_units=[500, 250, 50])
99 estimator.train(...)
100 ```
103FeatureColumns can also be transformed into a generic input layer for
104custom models using `input_layer`.
106Example of building model using FeatureColumns, this can be used in a
107`model_fn` which is given to the {tf.estimator.Estimator}:
109 ```python
110 # Building model via layers
112 deep_feature_columns = [age_column, embedded_dept_column]
113 columns_to_tensor = parse_feature_columns_from_examples(
114 serialized=my_data,
115 feature_columns=deep_feature_columns)
116 first_layer = input_layer(
117 features=columns_to_tensor,
118 feature_columns=deep_feature_columns)
119 second_layer = fully_connected(first_layer, ...)
120 ```
122NOTE: Functions prefixed with "_" indicate experimental or private parts of
123the API subject to change, and should not be relied upon!
124"""

import abc
import collections
import math
import re

import numpy as np
import six

from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import data_structures
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls

_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
                               'deprecated. Please use the new FeatureColumn '
                               'APIs instead.')
_FEATURE_COLUMN_DEPRECATION_WARNING = """\
    Warning: tf.feature_column is not recommended for new code. Instead,
    feature preprocessing can be done directly using either [Keras preprocessing
    layers](https://www.tensorflow.org/guide/migrate/migrating_feature_columns)
    or through the one-stop utility [`tf.keras.utils.FeatureSpace`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/FeatureSpace)
    built on top of them. See the [migration guide](https://tensorflow.org/guide/migrate)
    for details.
    """
_FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING = (
    'Use Keras preprocessing layers instead, either directly or via the '
    '`tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has '
    'a functional equivalent in `tf.keras.layers` for feature preprocessing '
    'when training a Keras model.')


class StateManager(object):
  """Manages the state associated with FeatureColumns.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
  """

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.
    """
    del feature_column, name, shape, dtype, trainable, use_resource, initializer
    raise NotImplementedError('StateManager.create_variable')

  def add_variable(self, feature_column, var):
    """Adds an existing variable to the state.

    Args:
      feature_column: A `FeatureColumn` object to associate this variable with.
      var: The variable.
    """
    del feature_column, var
    raise NotImplementedError('StateManager.add_variable')

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_variable')

  def add_resource(self, feature_column, name, resource):
    """Adds a new resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
      resource: The resource.

    Returns:
      The added resource.
    """
    del feature_column, name, resource
    raise NotImplementedError('StateManager.add_resource')

  def has_resource(self, feature_column, name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.has_resource')

  def get_resource(self, feature_column, name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_resource')
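

# A minimal sketch of a concrete `StateManager`, for illustration only: it
# backs `create_variable`/`get_variable` with plain `Variable`s held in a
# dict instead of a host layer. The class name and the assumption that
# `initializer` is a callable taking `(shape, dtype)` (as `init_ops`
# initializers do) are illustrative, not part of the public API.
class _DictStateManagerSketch(StateManager):
  """Illustrative only: stores per-column variables in plain dicts."""

  def __init__(self, trainable=True):
    self._trainable = trainable
    self._vars = collections.defaultdict(dict)

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    if name in self._vars[feature_column]:
      raise ValueError('Variable already exists.')
    # Assumes `initializer` follows the `init_ops` calling convention.
    var = variables.Variable(
        initial_value=initializer(shape=shape, dtype=dtype),
        trainable=self._trainable and trainable,
        name=name)
    self._vars[feature_column][name] = var
    return var

  def get_variable(self, feature_column, name):
    if name in self._vars[feature_column]:
      return self._vars[feature_column][name]
    raise ValueError('Variable does not exist.')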


@tf_export('__internal__.feature_column.StateManager', v1=[])
class _StateManagerImpl(StateManager):
  """Manages the state of DenseFeatures and LinearLayer.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
  """

  def __init__(self, layer, trainable):
    """Creates an `_StateManagerImpl` object.

    Args:
      layer: The input layer this state manager is associated with.
      trainable: Whether, by default, variables created are trainable or not.
    """
    self._trainable = trainable
    self._layer = layer
    if self._layer is not None and not hasattr(self._layer, '_resources'):
      self._layer._resources = data_structures.Mapping()  # pylint: disable=protected-access
    self._cols_to_vars_map = collections.defaultdict(lambda: {})
    self._cols_to_resources_map = collections.defaultdict(lambda: {})

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.
    """
    if name in self._cols_to_vars_map[feature_column]:
      raise ValueError('Variable already exists.')

    # We explicitly track these variables since `name` is not guaranteed to be
    # unique and disable manual tracking that the add_weight call does.
    with trackable.no_manual_dependency_tracking_scope(self._layer):
      var = self._layer.add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          trainable=self._trainable and trainable,
          use_resource=use_resource,
          # TODO(rohanj): Get rid of this hack once we have a mechanism for
          # specifying a default partitioner for an entire layer. In that case,
          # the default getter for Layers should work.
          getter=variable_scope.get_variable)
    if isinstance(var, variables.PartitionedVariable):
      for v in var:
        part_name = name + '/' + str(v._get_save_slice_info().var_offset[0])  # pylint: disable=protected-access
        self._layer._track_trackable(v, feature_column.name + '/' + part_name)  # pylint: disable=protected-access
    else:
      if isinstance(var, trackable.Trackable):
        self._layer._track_trackable(var, feature_column.name + '/' + name)  # pylint: disable=protected-access

    self._cols_to_vars_map[feature_column][name] = var
    return var

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
    """
    if name in self._cols_to_vars_map[feature_column]:
      return self._cols_to_vars_map[feature_column][name]
    raise ValueError('Variable does not exist.')

  def add_resource(self, feature_column, resource_name, resource):
    """Adds a new resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
      resource: The resource.
    """
    self._cols_to_resources_map[feature_column][resource_name] = resource
    # pylint: disable=protected-access
    if self._layer is not None and isinstance(resource, trackable.Trackable):
      # Add trackable resources to the layer for serialization.
      if feature_column.name not in self._layer._resources:
        self._layer._resources[feature_column.name] = data_structures.Mapping()
      if resource_name not in self._layer._resources[feature_column.name]:
        self._layer._resources[feature_column.name][resource_name] = resource
    # pylint: enable=protected-access

  def has_resource(self, feature_column, resource_name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
    """
    return resource_name in self._cols_to_resources_map[feature_column]

  def get_resource(self, feature_column, resource_name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
    """
    if (feature_column not in self._cols_to_resources_map or
        resource_name not in self._cols_to_resources_map[feature_column]):
      raise ValueError('Resource does not exist.')
    return self._cols_to_resources_map[feature_column][resource_name]
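

# A small usage sketch (illustrative, not part of the API): the resource
# methods above only need the host layer for trackable resources, so a plain
# dict round-trips through the manager without touching `add_weight`. The
# `_example_` name and the dummy resource are assumptions for illustration;
# `layer` would typically be the `DenseFeatures`/Keras layer owning the state.
def _example_state_manager_resources(layer):
  """Illustrative only: round-trips a resource through _StateManagerImpl."""
  manager = _StateManagerImpl(layer, trainable=True)
  price = numeric_column('price')
  manager.add_resource(price, 'stats', {'mean': 0.0})  # plain, non-trackable
  assert manager.has_resource(price, 'stats')
  return manager.get_resource(price, 'stats')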


def _transform_features_v2(features, feature_columns, state_manager):
  """Returns transformed features based on the feature columns passed in.

  Note that you most likely will not need to use this function. Check
  `input_layer` and `linear_model` to see whether they satisfy your use case.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  transformed = transform_features(features=features, feature_columns=columns)

  assertCountEqual(columns, transformed.keys())
  ```

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor`, depending on the
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s.
    state_manager: A StateManager object that holds the FeatureColumn state.

  Returns:
    A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
  """
  feature_columns = _normalize_feature_columns(feature_columns)
  outputs = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    transformation_cache = FeatureTransformationCache(features)
    for column in feature_columns:
      with ops.name_scope(
          None,
          default_name=_sanitize_column_name_for_variable_scope(column.name)):
        outputs[column] = transformation_cache.get(column, state_manager)
  return outputs
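

# A concrete sketch of the function above (the `_example_` helper is an
# illustrative assumption, not part of the module): numeric and bucketized
# columns create no state, so `state_manager` may be `None` here.
def _example_transform_features():
  """Illustrative only: transforms a dense feature and its bucketized view."""
  from tensorflow.python.framework import constant_op
  price = numeric_column('price')
  price_buckets = bucketized_column(price, boundaries=[10., 20.])
  features = {'price': constant_op.constant([[5.], [15.], [25.]])}
  outputs = _transform_features_v2(features, [price, price_buckets], None)
  # outputs[price] holds the dense tensor; outputs[price_buckets] holds the
  # bucket indices [[0], [1], [2]] under the boundaries above.
  return outputs[price_buckets]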


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export(
    'feature_column.make_parse_example_spec',
    v1=[])
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def make_parse_example_spec_v2(feature_columns):
  """Creates a parsing spec dictionary from the input feature_columns.

  The returned dictionary can be used as the arg 'features' in
  `tf.io.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...)
  feature_b = tf.feature_column.numeric_column(...)
  feature_c_bucketized = tf.feature_column.bucketized_column(
      tf.feature_column.numeric_column("feature_c"), ...)
  feature_a_x_feature_c = tf.feature_column.crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.io.parse_example(
      serialized=serialized_examples,
      features=tf.feature_column.make_parse_example_spec(feature_columns))
  ```

  For the above example, make_parse_example_spec would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
      instance.
  """
  result = {}
  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('All feature_columns must be FeatureColumn instances. '
                       'Given: {}'.format(column))
    config = column.parse_example_spec
    for key, value in six.iteritems(config):
      if key in result and value != result[key]:
        raise ValueError('feature_columns contain different parse_spec for key '
                         '{}. Given {} and {}'.format(key, value, result[key]))
    result.update(config)
  return result
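

# A runnable sketch of the spec-merging behavior documented above (the
# `_example_` helper is illustrative, not part of the module): two columns
# reading the same raw 'price' feature collapse to a single spec entry.
def _example_make_parse_example_spec():
  """Illustrative only: specs are merged per raw feature key."""
  price = numeric_column('price')
  price_buckets = bucketized_column(price, boundaries=[10., 20.])
  spec = make_parse_example_spec_v2([price, price_buckets])
  # Both columns parse the same 'price' input, so the dict has one entry:
  # {'price': FixedLenFeature(shape=(1,), dtype=float32, ...)}
  return spec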


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export('feature_column.embedding_column')
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     ckpt_to_load_from=None,
                     tensor_name_in_ckpt=None,
                     max_norm=None,
                     trainable=True,
                     use_safe_embedding_lookup=True):
  """`DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9),...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with model_fn:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9),...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    `DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: If eager execution is enabled.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  return EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      initializer=initializer,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable,
      use_safe_embedding_lookup=use_safe_embedding_lookup)
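

# A short sketch of the default initializer scale above (the `_example_`
# helper is illustrative, not part of the module): with no `initializer`,
# the embedding weights are drawn from a truncated normal distribution with
# standard deviation `1/sqrt(dimension)`.
def _example_embedding_column_defaults():
  """Illustrative only: default embedding initializer scale."""
  keywords = categorical_column_with_hash_bucket('keywords', 10000)
  # dimension=16 -> truncated_normal_initializer(mean=0.0, stddev=0.25).
  return embedding_column(keywords, dimension=16)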


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export(v1=['feature_column.shared_embedding_columns'])
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             ckpt_to_load_from=None,
                             tensor_name_in_ckpt=None,
                             max_norm=None,
                             trainable=True,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0._num_buckets  # pylint: disable=protected-access
  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
    raise ValueError(
        'All categorical_columns must be subclasses of _CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  while isinstance(
      c0,
      (
          fc_old._WeightedCategoricalColumn,  # pylint: disable=protected-access
          WeightedCategoricalColumn,
          fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
          SequenceCategoricalColumn)):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(
        c,
        (
            fc_old._WeightedCategoricalColumn,  # pylint: disable=protected-access
            WeightedCategoricalColumn,
            fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
            SequenceCategoricalColumn)):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  result = []
  for column in categorical_columns:
    result.append(
        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
            categorical_column=column,
            initializer=initializer,
            dimension=dimension,
            combiner=combiner,
            shared_embedding_collection_name=shared_embedding_collection_name,
            ckpt_to_load_from=ckpt_to_load_from,
            tensor_name_in_ckpt=tensor_name_in_ckpt,
            max_norm=max_norm,
            trainable=trainable,
            use_safe_embedding_lookup=use_safe_embedding_lookup))

  return result


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export(
    'feature_column.shared_embeddings',
    v1=[])
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True,
                                use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional collective name of these columns.
      If not given, a reasonable name will be chosen based on the names of
      `categorical_columns`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0.num_buckets
  if not isinstance(c0, CategoricalColumn):
    raise ValueError(
        'All categorical_columns must be subclasses of CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c.num_buckets:
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c.num_buckets))

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  column_creator = SharedEmbeddingColumnCreator(
      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
      num_buckets, trainable, shared_embedding_collection_name,
      use_safe_embedding_lookup)

  result = []
  for column in categorical_columns:
    result.append(
        column_creator(
            categorical_column=column, combiner=combiner, max_norm=max_norm))

  return result
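

# A pure-Python sketch of the default naming logic above (the `_example_`
# helper is illustrative, not part of the module): columns are sorted by name
# before the shared name is derived, so the result is order-independent.
def _example_shared_embedding_default_name():
  """Illustrative only: derives the default shared embedding name."""
  watched = categorical_column_with_hash_bucket('watched_video_id', 1000)
  impression = categorical_column_with_hash_bucket('impression_video_id', 1000)
  name = '_'.join(sorted([watched.name, impression.name]))
  name += '_shared_embedding'
  # -> 'impression_video_id_watched_video_id_shared_embedding'
  return name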


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export('feature_column.numeric_column')
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def numeric_column(key,
                   shape=(1,),
                   default_value=None,
                   dtype=dtypes.float32,
                   normalizer_fn=None):
  """Represents real valued or numerical features.

  Example:

  Assume we have data with two features `a` and `b`.

  >>> data = {'a': [15, 9, 17, 19, 21, 18, 25, 30],
  ...         'b': [5.0, 6.4, 10.5, 13.6, 15.7, 19.9, 20.3, 0.0]}

  Let us represent the features `a` and `b` as numerical features.

  >>> a = tf.feature_column.numeric_column('a')
  >>> b = tf.feature_column.numeric_column('b')

  Feature columns describe a set of transformations to the inputs.

  For example, to "bucketize" feature `a`, wrap the `a` column in a
  `feature_column.bucketized_column`.
  Providing `5` bucket boundaries, the `bucketized_column` API
  will bucket this feature into a total of `6` buckets.

  >>> a_buckets = tf.feature_column.bucketized_column(a,
  ...    boundaries=[10, 15, 20, 25, 30])

  Create a `DenseFeatures` layer which will apply the transformations
  described by the set of `tf.feature_column` objects:

  >>> feature_layer = tf.keras.layers.DenseFeatures([a_buckets, b])
  >>> print(feature_layer(data))
  tf.Tensor(
  [[ 0.   0.   1.   0.   0.   0.   5. ]
   [ 1.   0.   0.   0.   0.   0.   6.4]
   [ 0.   0.   1.   0.   0.   0.  10.5]
   [ 0.   0.   1.   0.   0.   0.  13.6]
   [ 0.   0.   0.   1.   0.   0.  15.7]
   [ 0.   0.   1.   0.   0.   0.  19.9]
   [ 0.   0.   0.   0.   1.   0.  20.3]
   [ 0.   0.   0.   0.   0.   1.   0. ]], shape=(8, 7), dtype=float32)

  Args:
    key: A unique string identifying the input feature. It is used as the column
      name and the dictionary key for feature parsing configs, feature `Tensor`
      objects, and feature columns.
    shape: An iterable of integers that specifies the shape of the `Tensor`. An
      integer can be given, which means a single dimension `Tensor` with given
      width. The `Tensor` representing the column will have the shape of
      [batch_size] + `shape`.
    default_value: A single value compatible with `dtype` or an iterable of
      values compatible with `dtype` which the column takes on during
      `tf.Example` parsing if data is missing. A default value of `None` will
      cause `tf.io.parse_example` to fail if an example does not contain this
      column. If a single value is provided, the same value will be applied as
      the default value for every item. If an iterable of values is provided,
      the shape of the `default_value` should be equal to the given `shape`.
    dtype: defines the type of values. Default value is `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor` (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization, it
      can be used for any kind of TensorFlow transformation.

  Returns:
    A `NumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int.
    ValueError: if any dimension in shape is not a positive integer.
    TypeError: if `default_value` is an iterable but not compatible with
      `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = _check_shape(shape, key)
  if not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))
  default_value = fc_utils.check_default_value(shape, default_value, dtype, key)

  if normalizer_fn is not None and not callable(normalizer_fn):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  fc_utils.assert_key_is_string(key)
  return NumericColumn(
      key,
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)
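

# A small sketch of `normalizer_fn` (the `_example_` helper is illustrative,
# not part of the module): the function is applied to the dense tensor after
# parsing defaults have been filled in.
def _example_numeric_column_normalizer():
  """Illustrative only: normalizes 'price' with (x - 3.0) / 4.2."""
  from tensorflow.python.framework import constant_op
  price = numeric_column('price', normalizer_fn=lambda x: (x - 3.0) / 4.2)
  features = {'price': constant_op.constant([[3.0], [7.2]])}
  outputs = _transform_features_v2(features, [price], None)
  return outputs[price]  # -> [[0.], [1.]]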


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export('feature_column.bucketized_column')
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def bucketized_column(source_column, boundaries):
  """Represents discretized dense input bucketed by `boundaries`.

  Buckets include the left boundary, and exclude the right boundary. Namely,
  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  For example, if the inputs are

  ```python
  boundaries = [0, 10, 100]
  input tensor = [[-5, 10000]
                  [150,   10]
                  [5,    100]]
  ```

  then the output will be

  ```python
  output = [[0, 3]
            [3, 2]
            [1, 3]]
  ```

  Example:

  ```python
  price = tf.feature_column.numeric_column('price')
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  ```

  A `bucketized_column` can also be crossed with another categorical column
  using `crossed_column`:

  ```python
  price = tf.feature_column.numeric_column('price')
  # bucketized_column converts numerical feature to a categorical one.
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = tf.feature_column.crossed_column(
      [bucketized_price, 'keywords'], 50000)
  columns = [price_x_keywords, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor)
  ```

  Args:
    source_column: A one-dimensional dense column which is generated with
      `numeric_column`.
    boundaries: A sorted list or tuple of floats specifying the boundaries.

  Returns:
    A `BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is not a sorted list or tuple.
  """
  if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)):  # pylint: disable=protected-access
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError('source_column must be one-dimensional column. '
                     'Given: {}'.format(source_column))
  if not boundaries:
    raise ValueError('boundaries must not be empty.')
  if not (isinstance(boundaries, list) or isinstance(boundaries, tuple)):
    raise ValueError('boundaries must be a sorted list.')
  for i in range(len(boundaries) - 1):
    if boundaries[i] >= boundaries[i + 1]:
      raise ValueError('boundaries must be a sorted list.')
  return BucketizedColumn(source_column, tuple(boundaries))
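

# The boundary semantics above match Python's `bisect.bisect_right`; this
# standalone sketch (illustrative only, not part of the module) reproduces
# the documented example without any TensorFlow ops.
def _example_bucket_semantics():
  """Illustrative only: left-inclusive, right-exclusive buckets."""
  import bisect
  boundaries = [0, 10, 100]
  inputs = [[-5, 10000], [150, 10], [5, 100]]
  output = [[bisect.bisect_right(boundaries, v) for v in row] for row in inputs]
  assert output == [[0, 3], [3, 2], [1, 3]]
  return output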


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export('feature_column.categorical_column_with_hash_bucket')
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
  """Represents a sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing:
  output_id = Hash(input_feature_string) % bucket_size for string type input.
  For int type input, the value is converted to its string representation first
  and then hashed by the same formula.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket("keywords",
                                                                   10000)
  columns = [keywords]
  features = {'keywords': tf.constant([['Tensorflow', 'Keras', 'RNN', 'LSTM',
             'CNN'], ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'], ['CNN',
             'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
                                                               columns)

  # or
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket("keywords",
                                                                   10000)
  keywords_embedded = tf.feature_column.embedding_column(keywords, 16)
  columns = [keywords_embedded]
  features = {'keywords': tf.constant([['Tensorflow', 'Keras', 'RNN', 'LSTM',
             'CNN'], ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'], ['CNN',
             'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the column
      name and the dictionary key for feature parsing configs, feature `Tensor`
      objects, and feature columns.
    hash_bucket_size: An int >= 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is not set or is less than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  if hash_bucket_size is None:
    raise ValueError('hash_bucket_size must be set. key: {}'.format(key))

  if hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}, key: {}'.format(
                         hash_bucket_size, key))

  fc_utils.assert_key_is_string(key)
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

  return HashedCategoricalColumn(key, hash_bucket_size, dtype)
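

# A sketch of the hashing formula above (illustrative only): the public
# `string_to_hash_bucket_fast` op computes Hash(input) % num_buckets, which is
# the same family of op this column uses for string inputs; treat the exact
# ids as implementation-defined.
def _example_hash_bucket_ids():
  """Illustrative only: hashes raw strings into 10000 buckets."""
  return string_ops.string_to_hash_bucket_fast(
      ['Tensorflow', 'Keras', 'RNN'], num_buckets=10000)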


@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file'])
@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
def categorical_column_with_vocabulary_file(key,
                                            vocabulary_file,
                                            vocabulary_size=None,
                                            num_oov_buckets=0,
                                            default_value=None,
                                            dtype=dtypes.string):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state
  abbreviation. All inputs with values in that file are assigned an ID 0-49,
  corresponding to its line number. All other values are hashed and assigned an
  ID 50-54.

  ```python
  import tensorflow as tf
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states]
  features = {'states': tf.constant([['california', 'georgia', 'michigan',
              'texas', 'new york'], ['new york', 'georgia', 'california',
              'michigan', 'texas']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
                                                               columns)
  ```

  Example with `default_value`:
  File `'/us/states.txt'` contains 51 lines - the first line is 'XX', and the
  other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
  in input, and other values missing from the file, will be assigned ID 0. All
  others are assigned the corresponding line number 1-50.

  ```python
  import tensorflow as tf
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states]
  features = {'states': tf.constant([['california', 'georgia', 'michigan',
              'texas', 'new york'], ['new york', 'georgia', 'california',
              'michigan', 'texas']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
                                                               columns)
  ```

  And to make an embedding with either:

  ```python
  import tensorflow as tf
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [tf.feature_column.embedding_column(states, 3)]
  features = {'states': tf.constant([['california', 'georgia', 'michigan',
              'texas', 'new york'], ['new york', 'georgia', 'california',
              'michigan', 'texas']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the column
      name and the dictionary key for feature parsing configs, feature `Tensor`
      objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than the length of `vocabulary_file`; if less, later values are
      ignored. If None, it is set to the length of `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` cannot be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  return categorical_column_with_vocabulary_file_v2(key, vocabulary_file,
                                                    vocabulary_size, dtype,
                                                    default_value,
                                                    num_oov_buckets)
1387@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
1388@tf_export(
1389 'feature_column.categorical_column_with_vocabulary_file',
1390 v1=[])
1391@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
1392def categorical_column_with_vocabulary_file_v2(key,
1393 vocabulary_file,
1394 vocabulary_size=None,
1395 dtype=dtypes.string,
1396 default_value=None,
1397 num_oov_buckets=0,
1398 file_format=None):
1399 """A `CategoricalColumn` with a vocabulary file.
1401 Use this when your inputs are in string or integer format, and you have a
1402 vocabulary file that maps each value to an integer ID. By default,
1403 out-of-vocabulary values are ignored. Use either (but not both) of
1404 `num_oov_buckets` and `default_value` to specify how to include
1405 out-of-vocabulary values.
1407 For input dictionary `features`, `features[key]` is either `Tensor` or
1408 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1409 and `''` for string, which will be dropped by this feature column.
1411 Example with `num_oov_buckets`:
1412 File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state
1413 abbreviation. All inputs with values in that file are assigned an ID 0-49,
1414 corresponding to their line numbers. All other values are hashed and assigned
1415 an ID 50-54.
1417 ```python
1418 states = categorical_column_with_vocabulary_file(
1419 key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
1420 num_oov_buckets=5)
1421 columns = [states, ...]
1422 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1423 linear_prediction = linear_model(features, columns)
1424 ```
1426 Example with `default_value`:
1427 File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and the
1428 other 50 each have a 2-character U.S. state abbreviation. Both a literal
1429 `'XX'` in input, and other values missing from the file, will be assigned
1430 ID 0. All others are assigned the corresponding line number 1-50.
1432 ```python
1433 states = categorical_column_with_vocabulary_file(
1434 key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
1435 default_value=0)
1436 columns = [states, ...]
1437 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1438 linear_prediction, _, _ = linear_model(features, columns)
1439 ```
1441 And to make an embedding with either:
1443 ```python
1444 columns = [embedding_column(states, 3),...]
1445 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1446 dense_tensor = input_layer(features, columns)
1447 ```
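A gzipped TFRecord vocabulary is also supported, in which case `file_format`
is inferred from the file name; a minimal sketch (the file path is
hypothetical):

```python
states = categorical_column_with_vocabulary_file(
    key='states', vocabulary_file='/us/states.tfrecord.gz',
    num_oov_buckets=5)  # file_format is inferred as 'tfrecord_gzip'.
```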
1449 Args:
1450 key: A unique string identifying the input feature. It is used as the column
1451 name and the dictionary key for feature parsing configs, feature `Tensor`
1452 objects, and feature columns.
1453 vocabulary_file: The vocabulary file name.
1454 vocabulary_size: Number of elements in the vocabulary. This must be no
1455 greater than the length of `vocabulary_file`; if it is less, later values
1456 are ignored. If `None`, it is set to the length of `vocabulary_file`.
1457 dtype: The type of features. Only string and integer types are supported.
1458 default_value: The integer ID value to return for out-of-vocabulary feature
1459 values, defaults to `-1`. This cannot be specified with a positive
1460 `num_oov_buckets`.
1461 num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1462 buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1463 `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
1464 the input value. A positive `num_oov_buckets` cannot be specified with
1465 `default_value`.
1466 file_format: The format of the vocabulary file. Defaults to 'text' unless
1467 `vocabulary_file` is a string ending in 'tfrecord.gz', in which case it
1468 defaults to 'tfrecord_gzip'. The accepted alternative value is 'tfrecord_gzip'.
1470 Returns:
1471 A `CategoricalColumn` with a vocabulary file.
1473 Raises:
1474 ValueError: `vocabulary_file` is missing or cannot be opened.
1475 ValueError: `vocabulary_size` is missing or < 1.
1476 ValueError: `num_oov_buckets` is a negative integer.
1477 ValueError: `num_oov_buckets` and `default_value` are both specified.
1478 ValueError: `dtype` is neither string nor integer.
1479 """
1480 if not vocabulary_file:
1481 raise ValueError('Missing vocabulary_file in {}.'.format(key))
1483 if file_format is None and vocabulary_file.endswith('tfrecord.gz'):
1484 file_format = 'tfrecord_gzip'
1486 if vocabulary_size is None:
1487 if not gfile.Exists(vocabulary_file):
1488 raise ValueError('vocabulary_file in {} does not exist.'.format(key))
1490 if file_format == 'tfrecord_gzip':
1491 ds = readers.TFRecordDataset(vocabulary_file, 'GZIP')
1492 vocabulary_size = ds.reduce(0, lambda x, _: x + 1)
1493 if context.executing_eagerly():
1494 vocabulary_size = vocabulary_size.numpy()
1495 else:
1496 with gfile.GFile(vocabulary_file, mode='rb') as f:
1497 vocabulary_size = sum(1 for _ in f)
1498 logging.info(
1499 'vocabulary_size = %d in %s is inferred from the number of elements '
1500 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)
1502 # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
1503 if not isinstance(vocabulary_size, ops.Tensor) and vocabulary_size < 1:
1504 raise ValueError('Invalid vocabulary_size in {}.'.format(key))
1505 if num_oov_buckets:
1506 if default_value is not None:
1507 raise ValueError(
1508 'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1509 key))
1510 if num_oov_buckets < 0:
1511 raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1512 num_oov_buckets, key))
1513 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1514 fc_utils.assert_key_is_string(key)
1515 return VocabularyFileCategoricalColumn(
1516 key=key,
1517 vocabulary_file=vocabulary_file,
1518 vocabulary_size=vocabulary_size,
1519 num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
1520 default_value=-1 if default_value is None else default_value,
1521 dtype=dtype,
1522 file_format=file_format)
1525@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
1526@tf_export('feature_column.categorical_column_with_vocabulary_list')
1527@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
1528def categorical_column_with_vocabulary_list(key,
1529 vocabulary_list,
1530 dtype=None,
1531 default_value=-1,
1532 num_oov_buckets=0):
1533 """A `CategoricalColumn` with in-memory vocabulary.
1535 Use this when your inputs are in string or integer format, and you have an
1536 in-memory vocabulary mapping each value to an integer ID. By default,
1537 out-of-vocabulary values are ignored. Use either (but not both) of
1538 `num_oov_buckets` and `default_value` to specify how to include
1539 out-of-vocabulary values.
1541 For input dictionary `features`, `features[key]` is either `Tensor` or
1542 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1543 and `''` for string, which will be dropped by this feature column.
1545 Example with `num_oov_buckets`:
1546 In the following example, each input in `vocabulary_list` is assigned an ID
1547 0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
1548 inputs are hashed and assigned an ID 4-5.
1550 ```python
1551 colors = categorical_column_with_vocabulary_list(
1552 key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
1553 num_oov_buckets=2)
1554 columns = [colors, ...]
1555 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1556 linear_prediction, _, _ = linear_model(features, columns)
1557 ```
1559 Example with `default_value`:
1560 In the following example, each input in `vocabulary_list` is assigned an ID
1561 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
1562 inputs are assigned `default_value` 0.
1565 ```python
1566 colors = categorical_column_with_vocabulary_list(
1567 key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
1568 columns = [colors, ...]
1569 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1570 linear_prediction, _, _ = linear_model(features, columns)
1571 ```
1573 And to make an embedding with either:
1575 ```python
1576 columns = [embedding_column(colors, 3),...]
1577 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1578 dense_tensor = input_layer(features, columns)
1579 ```
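Under eager execution, the multi-hot encoding can also be materialized
directly; a minimal sketch (the feature values are illustrative):

```python
import tensorflow as tf
colors = tf.feature_column.categorical_column_with_vocabulary_list(
    key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), num_oov_buckets=2)
features = {'colors': tf.constant([['R', 'G'], ['B', 'W']])}
input_layer = tf.keras.layers.DenseFeatures(
    [tf.feature_column.indicator_column(colors)])
dense_tensor = input_layer(features)  # 'W' hashes into an OOV bucket (ID 4-5).
```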
1581 Args:
1582 key: A unique string identifying the input feature. It is used as the column
1583 name and the dictionary key for feature parsing configs, feature `Tensor`
1584 objects, and feature columns.
1585 vocabulary_list: An ordered iterable defining the vocabulary. Each feature
1586 is mapped to the index of its value (if present) in `vocabulary_list`.
1587 Must be castable to `dtype`.
1588 dtype: The type of features. Only string and integer types are supported. If
1589 `None`, it will be inferred from `vocabulary_list`.
1590 default_value: The integer ID value to return for out-of-vocabulary feature
1591 values, defaults to `-1`. This cannot be specified with a positive
1592 `num_oov_buckets`.
1593 num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1594 buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1595 `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
1596 hash of the input value. A positive `num_oov_buckets` cannot be specified
1597 with `default_value`.
1599 Returns:
1600 A `CategoricalColumn` with in-memory vocabulary.
1602 Raises:
1603 ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
1604 ValueError: `num_oov_buckets` is a negative integer.
1605 ValueError: `num_oov_buckets` and `default_value` are both specified.
1606 ValueError: if `dtype` is not integer or string.
1607 """
1608 if (vocabulary_list is None) or (len(vocabulary_list) < 1):
1609 raise ValueError(
1610 'vocabulary_list {} must be non-empty, column_name: {}'.format(
1611 vocabulary_list, key))
1612 if len(set(vocabulary_list)) != len(vocabulary_list):
1613 raise ValueError(
1614 'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
1615 vocabulary_list, key))
1616 vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
1617 if num_oov_buckets:
1618 if default_value != -1:
1619 raise ValueError(
1620 'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1621 key))
1622 if num_oov_buckets < 0:
1623 raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1624 num_oov_buckets, key))
1625 fc_utils.assert_string_or_int(
1626 vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
1627 if dtype is None:
1628 dtype = vocabulary_dtype
1629 elif dtype.is_integer != vocabulary_dtype.is_integer:
1630 raise ValueError(
1631 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
1632 dtype, vocabulary_dtype, key))
1633 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1634 fc_utils.assert_key_is_string(key)
1636 return VocabularyListCategoricalColumn(
1637 key=key,
1638 vocabulary_list=tuple(vocabulary_list),
1639 dtype=dtype,
1640 default_value=default_value,
1641 num_oov_buckets=num_oov_buckets)
1644@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
1645@tf_export('feature_column.categorical_column_with_identity')
1646@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
1647def categorical_column_with_identity(key, num_buckets, default_value=None):
1648 """A `CategoricalColumn` that returns identity values.
1650 Use this when your inputs are integers in the range `[0, num_buckets)`, and
1651 you want to use the input value itself as the categorical ID. Values outside
1652 this range will result in `default_value` if specified, otherwise the
1653 lookup will fail.
1655 Typically, this is used for contiguous ranges of integer indexes, but
1656 it doesn't have to be. This might be inefficient, however, if many IDs
1657 are unused. Consider `categorical_column_with_hash_bucket` in that case.
1659 For input dictionary `features`, `features[key]` is either `Tensor` or
1660 `SparseTensor`. If `Tensor`, missing values can be represented by `-1`,
1661 which will be dropped by this feature column (inputs here are integers).
1663 In the following examples, each input in the range `[0, 1000000)` is assigned
1664 an ID equal to its own value. All other inputs are assigned `default_value` 0.
1665 Note that a literal 0 in inputs will result in the same ID as the default.
1667 Linear model:
1669 ```python
1670 import tensorflow as tf
1671 video_id = tf.feature_column.categorical_column_with_identity(
1672 key='video_id', num_buckets=1000000, default_value=0)
1673 columns = [video_id]
1674 features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
1675 [33,78, 2, 73, 1]])}
1676 linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1677 columns)
1678 ```
1680 Embedding for a DNN model:
1682 ```python
1683 import tensorflow as tf
1684 video_id = tf.feature_column.categorical_column_with_identity(
1685 key='video_id', num_buckets=1000000, default_value=0)
1686 columns = [tf.feature_column.embedding_column(video_id, 9)]
1687 features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
1688 [33,78, 2, 73, 1]])}
1689 input_layer = tf.keras.layers.DenseFeatures(columns)
1690 dense_tensor = input_layer(features)
1691 ```
1693 Args:
1694 key: A unique string identifying the input feature. It is used as the column
1695 name and the dictionary key for feature parsing configs, feature `Tensor`
1696 objects, and feature columns.
1697 num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
1698 default_value: If set, values outside of range `[0, num_buckets)` will be
1699 replaced with this value. If not set, values >= num_buckets will cause a
1700 failure while values < 0 will be dropped.
1702 Returns:
1703 A `CategoricalColumn` that returns identity values.
1705 Raises:
1706 ValueError: if `num_buckets` is less than one.
1707 ValueError: if `default_value` is not in range `[0, num_buckets)`.
1708 """
1709 if num_buckets < 1:
1710 raise ValueError('num_buckets {} < 1, column_name {}'.format(
1711 num_buckets, key))
1712 if (default_value is not None) and ((default_value < 0) or
1713 (default_value >= num_buckets)):
1714 raise ValueError(
1715 'default_value {} not in range [0, {}), column_name {}'.format(
1716 default_value, num_buckets, key))
1717 fc_utils.assert_key_is_string(key)
1718 return IdentityCategoricalColumn(
1719 key=key, number_buckets=num_buckets, default_value=default_value)
1722@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
1723@tf_export('feature_column.indicator_column')
1724@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
1725def indicator_column(categorical_column):
1726 """Represents multi-hot representation of given categorical column.
1728 - For DNN model, `indicator_column` can be used to wrap any
1729 `categorical_column_*` (e.g., to feed to DNN). Consider to Use
1730 `embedding_column` if the number of buckets/unique(values) are large.
1732 - For Wide (aka linear) model, `indicator_column` is the internal
1733 representation for categorical column when passing categorical column
1734 directly (as any element in feature_columns) to `linear_model`. See
1735 `linear_model` for details.
1737 ```python
1738 name = indicator_column(categorical_column_with_vocabulary_list(
1739 'name', ['bob', 'george', 'wanda']))
1740 columns = [name, ...]
1741 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1742 dense_tensor = input_layer(features, columns)
1744 dense_tensor == [[1, 0, 0]] # If "name" bytes_list is ["bob"]
1745 dense_tensor == [[1, 0, 1]] # If "name" bytes_list is ["bob", "wanda"]
1746 dense_tensor == [[2, 0, 0]] # If "name" bytes_list is ["bob", "bob"]
1747 ```
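Under eager execution, the output can be inspected directly; a minimal
sketch (the feature values are illustrative):

```python
import tensorflow as tf
name = tf.feature_column.categorical_column_with_vocabulary_list(
    'name', ['bob', 'george', 'wanda'])
columns = [tf.feature_column.indicator_column(name)]
features = {'name': tf.constant([['bob'], ['wanda']])}
dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
# dense_tensor == [[1., 0., 0.], [0., 0., 1.]]
```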
1749 Args:
1750 categorical_column: A `CategoricalColumn` which is created by
1751 `categorical_column_with_*` or `crossed_column` functions.
1753 Returns:
1754 An `IndicatorColumn`.
1756 Raises:
1757 ValueError: If `categorical_column` is not CategoricalColumn type.
1758 """
1759 if not isinstance(categorical_column,
1760 (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access
1761 raise ValueError(
1762 'Unsupported input type. Input must be a CategoricalColumn. '
1763 'Given: {}'.format(categorical_column))
1764 return IndicatorColumn(categorical_column)
1767@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
1768@tf_export('feature_column.weighted_categorical_column')
1769@deprecation.deprecated(None, _FEATURE_COLUMN_DEPRECATION_RUNTIME_WARNING)
1770def weighted_categorical_column(categorical_column,
1771 weight_feature_key,
1772 dtype=dtypes.float32):
1773 """Applies weight values to a `CategoricalColumn`.
1775 Use this when each of your sparse inputs has both an ID and a value. For
1776 example, if you're representing text documents as a collection of word
1777 frequencies, you can provide 2 parallel sparse input features ('terms' and
1778 'frequencies' below).
1780 Example:
1782 Input `tf.Example` objects:
1784 ```proto
1785 [
1786 features {
1787 feature {
1788 key: "terms"
1789 value {bytes_list {value: "very" value: "model"}}
1790 }
1791 feature {
1792 key: "frequencies"
1793 value {float_list {value: 0.3 value: 0.1}}
1794 }
1795 },
1796 features {
1797 feature {
1798 key: "terms"
1799 value {bytes_list {value: "when" value: "course" value: "human"}}
1800 }
1801 feature {
1802 key: "frequencies"
1803 value {float_list {value: 0.4 value: 0.1 value: 0.2}}
1804 }
1805 }
1806 ]
1807 ```
1809 ```python
1810 categorical_column = categorical_column_with_hash_bucket(
1811 column_name='terms', hash_bucket_size=1000)
1812 weighted_column = weighted_categorical_column(
1813 categorical_column=categorical_column, weight_feature_key='frequencies')
1814 columns = [weighted_column, ...]
1815 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1816 linear_prediction, _, _ = linear_model(features, columns)
1817 ```
1819 This assumes the input dictionary contains a `SparseTensor` for key
1820 'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
1821 the same indices and dense shape.
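The two parallel `SparseTensor`s can also be built by hand; a minimal sketch
(the values are illustrative):

```python
import tensorflow as tf
terms = tf.feature_column.categorical_column_with_hash_bucket(
    'terms', hash_bucket_size=1000)
weighted = tf.feature_column.weighted_categorical_column(
    categorical_column=terms, weight_feature_key='frequencies')
features = {
    'terms': tf.sparse.from_dense([['very', 'model'], ['when', 'course']]),
    'frequencies': tf.sparse.from_dense([[0.3, 0.1], [0.4, 0.1]]),
}
linear_prediction = tf.compat.v1.feature_column.linear_model(
    features, [weighted])
```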
1823 Args:
1824 categorical_column: A `CategoricalColumn` created by
1825 `categorical_column_with_*` functions.
1826 weight_feature_key: String key for weight values.
1827 dtype: Type of weights, such as `tf.float32`. Only float and integer weights
1828 are supported.
1830 Returns:
1831 A `CategoricalColumn` composed of two sparse features: one represents the
1832 IDs, the other represents the weight (value) of each ID in that example.
1834 Raises:
1835 ValueError: if `dtype` is not convertible to float.
1836 """
1837 if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
1838 raise ValueError('dtype {} is not convertible to float.'.format(dtype))
1839 return WeightedCategoricalColumn(
1840 categorical_column=categorical_column,
1841 weight_feature_key=weight_feature_key,
1842 dtype=dtype)
1845@doc_controls.header(_FEATURE_COLUMN_DEPRECATION_WARNING)
1846@tf_export('feature_column.crossed_column')
1847@deprecation.deprecated(
1848 None,
1849 'Use `tf.keras.layers.experimental.preprocessing.HashedCrossing` '
1850 'instead for feature crossing when preprocessing data to train a '
1851 'Keras model.')
1852def crossed_column(keys, hash_bucket_size, hash_key=None):
1853 """Returns a column for performing crosses of categorical features.
1855 Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
1856 the transformation can be thought of as:
1857 Hash(cartesian product of features) % `hash_bucket_size`
1859 For example, if the input features are:
1861 * SparseTensor referred by first key:
1863 ```python
1864 shape = [2, 2]
1865 {
1866 [0, 0]: "a"
1867 [1, 0]: "b"
1868 [1, 1]: "c"
1869 }
1870 ```
1872 * SparseTensor referred by second key:
1874 ```python
1875 shape = [2, 1]
1876 {
1877 [0, 0]: "d"
1878 [1, 0]: "e"
1879 }
1880 ```
1882 then crossed feature will look like:
1884 ```python
1885 shape = [2, 2]
1886 {
1887 [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
1888 [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
1889 [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
1890 }
1891 ```
1893 Here is an example to create a linear model with crosses of string features:
1895 ```python
1896 keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50000)
1897 columns = [keywords_x_doc_terms, ...]
1898 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1899 linear_prediction = linear_model(features, columns)
1900 ```
1902 You could also use vocabulary lookup before crossing:
1904 ```python
1905 keywords = categorical_column_with_vocabulary_file(
1906 'keywords', '/path/to/vocabulary/file', vocabulary_size=1000)
1907 keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50000)
1908 columns = [keywords_x_doc_terms, ...]
1909 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1910 linear_prediction = linear_model(features, columns)
1911 ```
1913 If an input feature is of numeric type, you can use
1914 `categorical_column_with_identity`, or `bucketized_column`, as in the example:
1916 ```python
1917 # vertical_id is an integer categorical feature.
1918 vertical_id = categorical_column_with_identity('vertical_id', 10000)
1919 price = numeric_column('price')
1920 # bucketized_column converts numerical feature to a categorical one.
1921 bucketized_price = bucketized_column(price, boundaries=[...])
1922 vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
1923 columns = [vertical_id_x_price, ...]
1924 features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1925 linear_prediction = linear_model(features, columns)
1926 ```
1928 To use crossed column in DNN model, you need to add it in an embedding column
1929 as in this example:
1931 ```python
1932 vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
1933 vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
1934 dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
1935 ```
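As the deprecation notice suggests, new code should prefer the Keras
preprocessing replacement; a minimal sketch (the tensor values are
illustrative):

```python
import tensorflow as tf
crossing_layer = tf.keras.layers.experimental.preprocessing.HashedCrossing(
    num_bins=1000)
crossed_ids = crossing_layer(
    (tf.constant(['a', 'b', 'c']), tf.constant(['d', 'e', 'e'])))
```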
1937 Args:
1938 keys: An iterable identifying the features to be crossed. Each element can
1939 be either:
1940 * string: Will use the corresponding feature which must be of string type.
1941 * `CategoricalColumn`: Will use the transformed tensor produced by this
1942 column. Does not support hashed categorical column.
1943 hash_bucket_size: An int >= 1. The number of buckets.
1944 hash_key: Optional. Specifies the hash_key used by the `FingerprintCat64`
1945 function to combine the fingerprints of the crossed values in SparseCrossOp.
1947 Returns:
1948 A `CrossedColumn`.
1950 Raises:
1951 ValueError: If `len(keys) < 2`.
1952 ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
1953 ValueError: If any of the keys is `HashedCategoricalColumn`.
1954 ValueError: If `hash_bucket_size < 1`.
1955 """
1956 if not hash_bucket_size or hash_bucket_size < 1:
1957 raise ValueError('hash_bucket_size must be at least 1. '
1958 'hash_bucket_size: {}'.format(hash_bucket_size))
1959 if not keys or len(keys) < 2:
1960 raise ValueError(
1961 'keys must be a list with length > 1. Given: {}'.format(keys))
1962 for key in keys:
1963 if (not isinstance(key, six.string_types) and
1964 not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))): # pylint: disable=protected-access
1965 raise ValueError(
1966 'Unsupported key type. All keys must be either string, or '
1967 'categorical column except HashedCategoricalColumn. '
1968 'Given: {}'.format(key))
1969 if isinstance(key,
1970 (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)): # pylint: disable=protected-access
1971 raise ValueError(
1972 'categorical_column_with_hash_bucket is not supported for crossing. '
1973 'Hashing before crossing will increase probability of collision. '
1974 'Instead, use the feature name as a string. Given: {}'.format(key))
1975 return CrossedColumn(
1976 keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
1981# TODO(b/181853833): Add a tf.type for instance type checking.
1982@tf_export('__internal__.feature_column.FeatureColumn', v1=[])
1983@six.add_metaclass(abc.ABCMeta)
1984class FeatureColumn(object):
1985 """Represents a feature column abstraction.
1987 WARNING: Do not subclass this class unless you know what you are doing:
1988 the API is subject to future changes.
1990 To distinguish between the concept of a feature family and a specific binary
1991 feature within a family, we refer to a feature family like "country" as a
1992 feature column. For example, we can have a feature in a `tf.Example` format:
1993 {key: "country", value: [ "US" ]}
1994 In this example the value of the feature is "US" and "country" refers to the
1995 column of the feature.
1997 This class is an abstract class. Users should not create instances of this.
1998 """
2000 @abc.abstractproperty
2001 def name(self):
2002 """Returns string. Used for naming."""
2003 pass
2005 def __lt__(self, other):
2006 """Allows feature columns to be sorted in Python 3 as they are in Python 2.
2008 Feature columns need to occasionally be sortable, for example when used as
2009 keys in a features dictionary passed to a layer.
2011 In CPython, `__lt__` must be defined for all objects in the
2012 sequence being sorted.
2014 If any objects in the sequence being sorted do not have an `__lt__` method
2015 compatible with feature column objects (such as strings), then CPython will
2016 fall back to using the `__gt__` method below.
2017 https://docs.python.org/3/library/stdtypes.html#list.sort
2019 Args:
2020 other: The other object to compare to.
2022 Returns:
2023 True if the string representation of this object is lexicographically less
2024 than the string representation of `other`. For FeatureColumn objects,
2025 this looks like "<__main__.FeatureColumn object at 0xa>".
2026 """
2027 return str(self) < str(other)
2029 def __gt__(self, other):
2030 """Allows feature columns to be sorted in Python 3 as they are in Python 2.
2032 Feature columns need to occasionally be sortable, for example when used as
2033 keys in a features dictionary passed to a layer.
2035 `__gt__` is called when the "other" object being compared during the sort
2036 does not have `__lt__` defined.
2037 Example:
2038 ```
2039 # __lt__ only class
2040 class A():
2041 def __lt__(self, other): return str(self) < str(other)
2043 a = A()
2044 a < "b" # True
2045 "0" < a # Error
2047 # __lt__ and __gt__ class
2048 class B():
2049 def __lt__(self, other): return str(self) < str(other)
2050 def __gt__(self, other): return str(self) > str(other)
2052 b = B()
2053 b < "c" # True
2054 "0" < b # True
2055 ```
2057 Args:
2058 other: The other object to compare to.
2060 Returns:
2061 True if the string representation of this object is lexicographically
2062 greater than the string representation of `other`. For FeatureColumn
2063 objects, this looks like "<__main__.FeatureColumn object at 0xa>".
2064 """
2065 return str(self) > str(other)
2067 @abc.abstractmethod
2068 def transform_feature(self, transformation_cache, state_manager):
2069 """Returns intermediate representation (usually a `Tensor`).
2071 Uses `transformation_cache` to create an intermediate representation
2072 (usually a `Tensor`) that other feature columns can use.
2074 Example usage of `transformation_cache`:
2075 Let's say a Feature column depends on raw feature ('raw') and another
2076 `FeatureColumn` (input_fc). To access corresponding `Tensor`s,
2077 transformation_cache will be used as follows:
2079 ```python
2080 raw_tensor = transformation_cache.get('raw', state_manager)
2081 fc_tensor = transformation_cache.get(input_fc, state_manager)
2082 ```
2084 Args:
2085 transformation_cache: A `FeatureTransformationCache` object to access
2086 features.
2087 state_manager: A `StateManager` to create / access resources such as
2088 lookup tables.
2090 Returns:
2091 Transformed feature `Tensor`.
2092 """
2093 pass
2095 @abc.abstractproperty
2096 def parse_example_spec(self):
2097 """Returns a `tf.Example` parsing spec as dict.
2099 It is used to generate the parsing spec for `tf.io.parse_example`. The
2100 returned spec is a dict from keys ('string') to `VarLenFeature`,
2101 `FixedLenFeature`, and other supported objects. Please check the
2102 documentation of `tf.io.parse_example` for all supported spec objects.
2104 Let's say a Feature column depends on raw feature ('raw') and another
2105 `FeatureColumn` (input_fc). One possible implementation of
2106 parse_example_spec is as follows:
2108 ```python
2109 spec = {'raw': tf.io.FixedLenFeature(...)}
2110 spec.update(input_fc.parse_example_spec)
2111 return spec
2112 ```
2113 """
2114 pass
2116 def create_state(self, state_manager):
2117 """Uses the `state_manager` to create state for the FeatureColumn.
2119 Args:
2120 state_manager: A `StateManager` to create / access resources such as
2121 lookup tables and variables.
2122 """
2123 pass
2125 @abc.abstractproperty
2126 def _is_v2_column(self):
2127 """Returns whether this FeatureColumn is fully conformant to the new API.
2129 This is needed for composition type cases where an EmbeddingColumn etc.
2130 might take in old categorical columns as input and then we want to use the
2131 old API.
2132 """
2133 pass
2135 @abc.abstractproperty
2136 def parents(self):
2137 """Returns a list of immediate raw feature and FeatureColumn dependencies.
2139 For example:
2140 ```python
2141 a = numeric_column('f1')
2142 b = bucketized_column(a, boundaries=[0.])
2143 c = crossed_column([b, 'f2'], hash_bucket_size=100)
2144 # a.parents == ['f1'], b.parents == [a], c.parents == [b, 'f2']
2145 ```
2146 """
2147 pass
2149 def get_config(self):
2150 """Returns the config of the feature column.
2152 A FeatureColumn config is a Python dictionary (serializable) containing the
2153 configuration of a FeatureColumn. The same FeatureColumn can be
2154 reinstantiated later from this configuration.
2156 The config of a feature column does not include information about feature
2157 columns depending on it nor the FeatureColumn class name.
2159 Example with (de)serialization practices followed in this file:
2160 ```python
2161 class SerializationExampleFeatureColumn(
2162 FeatureColumn, collections.namedtuple(
2163 'SerializationExampleFeatureColumn',
2164 ('dimension', 'parent', 'dtype', 'normalizer_fn'))):
2166 def get_config(self):
2167 # Create a dict from the namedtuple.
2168 # Python attribute literals can be directly copied from / to the config.
2169 # For example 'dimension', assuming it is an integer literal.
2170 config = dict(zip(self._fields, self))
2172 # (De)serialization of parent FeatureColumns should use the provided
2173 # (de)serialize_feature_column() methods that take care of de-duping.
2174 config['parent'] = serialize_feature_column(self.parent)
2176 # Many objects provide custom (de)serialization e.g: for tf.DType
2177 # tf.DType.name, tf.as_dtype() can be used.
2178 config['dtype'] = self.dtype.name
2180 # Non-trivial dependencies should be Keras-(de)serializable.
2181 config['normalizer_fn'] = generic_utils.serialize_keras_object(
2182 self.normalizer_fn)
2184 return config
2186 @classmethod
2187 def from_config(cls, config, custom_objects=None, columns_by_name=None):
2188 # This should do the inverse transform from `get_config` and construct
2189 # the namedtuple.
2190 kwargs = config.copy()
2191 kwargs['parent'] = deserialize_feature_column(
2192 config['parent'], custom_objects, columns_by_name)
2193 kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
2194 kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
2195 config['normalizer_fn'], custom_objects=custom_objects)
2196 return cls(**kwargs)
2198 ```
2199 Returns:
2200 A serializable Dict that can be used to deserialize the object with
2201 from_config.
2202 """
2203 return self._get_config()
2205 def _get_config(self):
2206 raise NotImplementedError('Must be implemented in subclasses.')
2208 @classmethod
2209 def from_config(cls, config, custom_objects=None, columns_by_name=None):
2210 """Creates a FeatureColumn from its config.
2212 This method should be the reverse of `get_config`, capable of instantiating
2213 the same FeatureColumn from the config dictionary. See `get_config` for an
2214 example of common (de)serialization practices followed in this file.
2216 TODO(b/118939620): This is a private method until consensus is reached on
2217 supporting object deserialization deduping within Keras.
2219 Args:
2220 config: A Dict config acquired with `get_config`.
2221 custom_objects: Optional dictionary mapping names (strings) to custom
2222 classes or functions to be considered during deserialization.
2223 columns_by_name: A Dict[String, FeatureColumn] of existing columns in
2224 order to avoid duplication. Should be passed to any calls to
2225 deserialize_feature_column().
2227 Returns:
2228 A FeatureColumn for the input config.
2229 """
2230 return cls._from_config(config, custom_objects, columns_by_name)
2232 @classmethod
2233 def _from_config(cls, config, custom_objects=None, columns_by_name=None):
2234 raise NotImplementedError('Must be implemented in subclasses.')
2237# TODO(b/181853833): Add a tf.type for instance type checking.
2238@tf_export('__internal__.feature_column.DenseColumn', v1=[])
2239class DenseColumn(FeatureColumn):
2240 """Represents a column which can be represented as `Tensor`.
2242 Some examples of this type are: numeric_column, embedding_column,
2243 indicator_column.
2244 """
2246 @abc.abstractproperty
2247 def variable_shape(self):
2248 """`TensorShape` of `get_dense_tensor`, without batch dimension."""
2249 pass
2251 @abc.abstractmethod
2252 def get_dense_tensor(self, transformation_cache, state_manager):
2253 """Returns a `Tensor`.
2255 The output of this function will be used by model-builder-functions. For
2256 example the pseudo code of `input_layer` will be like:
2258 ```python
2259 def input_layer(features, feature_columns, ...):
2260 outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
2261 return tf.concat(outputs)
2262 ```
2264 Args:
2265 transformation_cache: A `FeatureTransformationCache` object to access
2266 features.
2267 state_manager: A `StateManager` to create / access resources such as
2268 lookup tables.
2270 Returns:
2271 `Tensor` of shape [batch_size] + `variable_shape`.
2272 """
2273 pass
2276def is_feature_column_v2(feature_columns):
2277 """Returns True if all feature columns are V2."""
2278 for feature_column in feature_columns:
2279 if not isinstance(feature_column, FeatureColumn):
2280 return False
2281 if not feature_column._is_v2_column: # pylint: disable=protected-access
2282 return False
2283 return True
2286def _create_weighted_sum(column, transformation_cache, state_manager,
2287 sparse_combiner, weight_var):
2288 """Creates a weighted sum for a dense/categorical column for linear_model."""
2289 if isinstance(column, CategoricalColumn):
2290 return _create_categorical_column_weighted_sum(
2291 column=column,
2292 transformation_cache=transformation_cache,
2293 state_manager=state_manager,
2294 sparse_combiner=sparse_combiner,
2295 weight_var=weight_var)
2296 else:
2297 return _create_dense_column_weighted_sum(
2298 column=column,
2299 transformation_cache=transformation_cache,
2300 state_manager=state_manager,
2301 weight_var=weight_var)
2304def _create_dense_column_weighted_sum(column, transformation_cache,
2305 state_manager, weight_var):
2306 """Create a weighted sum of a dense column for linear_model."""
2307 tensor = column.get_dense_tensor(transformation_cache, state_manager)
2308 num_elements = column.variable_shape.num_elements()
2309 batch_size = array_ops.shape(tensor)[0]
2310 tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
2311 return math_ops.matmul(tensor, weight_var, name='weighted_sum')
2314class CategoricalColumn(FeatureColumn):
2315 """Represents a categorical feature.
2317 A categorical feature is typically handled with a `tf.sparse.SparseTensor` of
2318 IDs.
2319 """
2321 IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name
2322 'IdWeightPair', ('id_tensor', 'weight_tensor'))
2324 @abc.abstractproperty
2325 def num_buckets(self):
2326 """Returns number of buckets in this sparse feature."""
2327 pass
2329 @abc.abstractmethod
2330 def get_sparse_tensors(self, transformation_cache, state_manager):
2331 """Returns an IdWeightPair.
2333 `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
2334 weights.
2336 `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
2337 `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
2338 `SparseTensor` of `float` or `None` to indicate all weights should be
2339 taken to be 1. If specified, `weight_tensor` must have exactly the same
2340 shape and indices as `id_tensor`. The expected `SparseTensor` is the same as
2341 the parsing output of a `VarLenFeature`, which is a ragged matrix.
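A typical consumer unpacks the pair as follows; a minimal sketch (the names
are illustrative):

```python
sparse_tensors = column.get_sparse_tensors(transformation_cache, state_manager)
ids = sparse_tensors.id_tensor          # SparseTensor of int64 IDs.
weights = sparse_tensors.weight_tensor  # SparseTensor of weights, or None.
```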
2343 Args:
2344 transformation_cache: A `FeatureTransformationCache` object to access
2345 features.
2346 state_manager: A `StateManager` to create / access resources such as
2347 lookup tables.
2348 """
2349 pass
2352def _create_categorical_column_weighted_sum(column, transformation_cache,
2353 state_manager, sparse_combiner,
2354 weight_var):
2355 # pylint: disable=g-doc-return-or-yield,g-doc-args
2356 """Create a weighted sum of a categorical column for linear_model.
2358 Note to maintainers: As an implementation detail, the weighted sum is
2359 implemented via embedding_lookup_sparse for efficiency. Mathematically,
2360 they are the same.
2362 To be specific, a categorical column can conceptually be treated as a
2363 multi-hot vector. Say:
2365 ```python
2366 x = [0 0 1] # categorical column input
2367 w = [a b c] # weights
2368 ```
2369 The weighted sum is `c` in this case, which is the same as `w[2]`.
2371 Another example is
2373 ```python
2374 x = [0 1 1] # categorical column input
2375 w = [a b c] # weights
2376 ```
2377 The weighted sum is `b + c` in this case, which is the same as `w[1] + w[2]`.
2379 For both cases, we can implement the weighted sum via embedding_lookup_sparse
2380 with sparse_combiner = "sum".
2381 """
2383 sparse_tensors = column.get_sparse_tensors(transformation_cache,
2384 state_manager)
2385 id_tensor = sparse_ops.sparse_reshape(
2386 sparse_tensors.id_tensor,
2387 [array_ops.shape(sparse_tensors.id_tensor)[0], -1])
2388 weight_tensor = sparse_tensors.weight_tensor
2389 if weight_tensor is not None:
2390 weight_tensor = sparse_ops.sparse_reshape(
2391 weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
2393 return embedding_ops.safe_embedding_lookup_sparse(
2394 weight_var,
2395 id_tensor,
2396 sparse_weights=weight_tensor,
2397 combiner=sparse_combiner,
2398 name='weighted_sum')
2401# TODO(b/181853833): Add a tf.type for instance type checking.
2402@tf_export('__internal__.feature_column.SequenceDenseColumn', v1=[])
2403class SequenceDenseColumn(FeatureColumn):
2404 """Represents dense sequence data."""
2406 TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name
2407 'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))
2409 @abc.abstractmethod
2410 def get_sequence_dense_tensor(self, transformation_cache, state_manager):
2411 """Returns a `TensorSequenceLengthPair`.
2413 Args:
2414 transformation_cache: A `FeatureTransformationCache` object to access
2415 features.
2416 state_manager: A `StateManager` to create / access resources such as
2417 lookup tables.
2418 """
2419 pass
2422@tf_export('__internal__.feature_column.FeatureTransformationCache', v1=[])
2423class FeatureTransformationCache(object):
2424 """Handles caching of transformations while building the model.
2426 `FeatureColumn` specifies how to digest an input column to the network. Some
2427 feature columns require data transformations. This class caches those
2428 transformations.
2430 Some features may be used in more than one place. For example, one can use a
2431 bucketized feature both by itself and in a cross with it. In that case we
2432 should create only one bucketization op instead of creating ops for each
2433 feature column separately. To handle re-use of transformed columns,
2434 `FeatureTransformationCache` caches all previously transformed columns.
2436 Example:
2437 We're trying to use the following `FeatureColumn`s:
2439 ```python
2440 bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
2441 keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
2442 age_X_keywords = fc.crossed_column([bucketized_age, "keywords"], ...)
2443 ... = linear_model(features,
2444 [bucketized_age, keywords, age_X_keywords])
2445 ```
2447 If we transform each column independently, then we'll get duplication of
2448 bucketization (one for cross, one for bucketization itself).
2449 The `FeatureTransformationCache` eliminates this duplication.
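A minimal sketch of direct cache usage, reusing `bucketized_age` from the
example above (the state manager can be `None` here because raw features and
bucketization need no state):

```python
cache = FeatureTransformationCache({'age': tf.constant([[25.], [61.]])})
age_tensor = cache.get('age', state_manager=None)        # raw feature
buckets = cache.get(bucketized_age, state_manager=None)  # transformed once
buckets_again = cache.get(bucketized_age, state_manager=None)  # cache hit
```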
2450 """
2452 def __init__(self, features):
2453 """Creates a `FeatureTransformationCache`.
2455 Args:
2456 features: A mapping from feature keys to objects that are `Tensor` or
2457 `SparseTensor`, or can be converted to the same via
2458 `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
2459 signifies a base feature (not-transformed). A `FeatureColumn` key means
2460 that this `Tensor` is the output of an existing `FeatureColumn` which
2461 can be reused.
2462 """
2463 self._features = features.copy()
2464 self._feature_tensors = {}
2466 def get(self, key, state_manager, training=None):
2467 """Returns a `Tensor` for the given key.
2469 A `str` key is used to access a base feature (not-transformed). When a
2470 `FeatureColumn` is passed, the transformed feature is returned if it
2471 already exists, otherwise the given `FeatureColumn` is asked to provide its
2472 transformed output, which is then cached.
2474 Args:
2475 key: a `str` or a `FeatureColumn`.
2476 state_manager: A StateManager object that holds the FeatureColumn state.
2477 training: Boolean indicating whether the column is being used in
2478 training mode. This argument is passed to the transform_feature method
2479 of any `FeatureColumn` that takes a `training` argument. For example, if
2480 a `FeatureColumn` performed dropout, it could expose a `training`
2481 argument to control whether the dropout should be applied.
2483 Returns:
2484 The transformed `Tensor` corresponding to the `key`.
2486 Raises:
2487 ValueError: if key is not found or a transformed `Tensor` cannot be
2488 computed.
2489 """
2490 if key in self._feature_tensors:
2491 # FeatureColumn is already transformed or converted.
2492 return self._feature_tensors[key]
2494 if key in self._features:
2495 feature_tensor = self._get_raw_feature_as_tensor(key)
2496 self._feature_tensors[key] = feature_tensor
2497 return feature_tensor
2499 if isinstance(key, six.string_types):
2500 raise ValueError('Feature {} is not in features dictionary.'.format(key))
2502 if not isinstance(key, FeatureColumn):
2503 raise TypeError('"key" must be either a "str" or "FeatureColumn". '
2504 'Provided: {}'.format(key))
2506 column = key
2507 logging.debug('Transforming feature_column %s.', column)
2509 # Some columns may need information about whether the transformation is
2510 # happening in training or prediction mode, but not all columns expose this
2511 # argument.
2512 try:
2513 transformed = column.transform_feature(
2514 self, state_manager, training=training)
2515 except TypeError:
2516 transformed = column.transform_feature(self, state_manager)
2517 if transformed is None:
2518 raise ValueError('Column {} is not supported.'.format(column.name))
2519 self._feature_tensors[column] = transformed
2520 return transformed
2522 def _get_raw_feature_as_tensor(self, key):
2523 """Gets the raw_feature (keyed by `key`) as `tensor`.
2525 The raw feature is converted to (sparse) tensor and maybe expand dim.
2527 For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
2528 the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will
2529 error out as it is not supported.
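For example, a rank-1 dense `Tensor` `[1, 2]` is expanded to `[[1], [2]]`, and
a rank-1 `SparseTensor` with dense shape `[3]` is reshaped to dense shape
`[3, 1]`.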
2531 Args:
2532 key: A `str` key to access the raw feature.
2534 Returns:
2535 A `Tensor` or `SparseTensor`.
2537 Raises:
2538 ValueError: if the raw feature has rank 0.
2539 """
2540 raw_feature = self._features[key]
2541 feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2542 raw_feature)
2544 def expand_dims(input_tensor):
2545 # Input_tensor must have rank 1.
2546 if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2547 return sparse_ops.sparse_reshape(input_tensor,
2548 [array_ops.shape(input_tensor)[0], 1])
2549 else:
2550 return array_ops.expand_dims(input_tensor, -1)
2552 rank = feature_tensor.get_shape().ndims
2553 if rank is not None:
2554 if rank == 0:
2555 raise ValueError(
2556 'Feature (key: {}) cannot have rank 0. Given: {}'.format(
2557 key, feature_tensor))
2558 return feature_tensor if rank != 1 else expand_dims(feature_tensor)
2560 # Handle dynamic rank.
2561 with ops.control_dependencies([
2562 check_ops.assert_positive(
2563 array_ops.rank(feature_tensor),
2564 message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
2565 key, feature_tensor))
2566 ]):
2567 return cond.cond(
2568 math_ops.equal(1, array_ops.rank(feature_tensor)),
2569 lambda: expand_dims(feature_tensor), lambda: feature_tensor)
2572# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
2573def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
2574 """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
2576 If `input_tensor` is already a `SparseTensor`, just return it.
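For example, with the default `ignore_value` for strings, a minimal sketch:

```python
dense = tf.constant([['a', ''], ['b', 'c']])
sp = _to_sparse_input_and_drop_ignore_values(dense)
# sp.indices == [[0, 0], [1, 0], [1, 1]]; sp.values == [b'a', b'b', b'c']
```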
2578 Args:
2579 input_tensor: A string or integer `Tensor`.
2580 ignore_value: Entries in `input_tensor` equal to this value will be absent
2581 from the resulting `SparseTensor`. If `None`, the default value of
2582 `input_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
2584 Returns:
2585 A `SparseTensor` with the same shape as `input_tensor`.
2587 Raises:
2588 ValueError: when `input_tensor`'s rank is `None`.
2589 """
2590 input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2591 input_tensor)
2592 if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2593 return input_tensor
2594 with ops.name_scope(None, 'to_sparse_input', (
2595 input_tensor,
2596 ignore_value,
2597 )):
2598 if ignore_value is None:
2599 if input_tensor.dtype == dtypes.string:
2600 # Exception: TF strings are converted to numpy objects by default.
2601 ignore_value = ''
2602 elif input_tensor.dtype.is_integer:
2603 ignore_value = -1 # -1 has a special meaning of missing feature
2604 else:
2605 # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
2606 # constructing a new numpy object of the given type, which yields the
2607 # default value for that type.
2608 ignore_value = input_tensor.dtype.as_numpy_dtype()
2609 ignore_value = math_ops.cast(
2610 ignore_value, input_tensor.dtype, name='ignore_value')
2611 indices = array_ops.where_v2(
2612 math_ops.not_equal(input_tensor, ignore_value), name='indices')
2613 return sparse_tensor_lib.SparseTensor(
2614 indices=indices,
2615 values=array_ops.gather_nd(input_tensor, indices, name='values'),
2616 dense_shape=array_ops.shape(
2617 input_tensor, out_type=dtypes.int64, name='dense_shape'))
2620def _normalize_feature_columns(feature_columns):
2621 """Normalizes the `feature_columns` input.
2623 This method converts `feature_columns` to a list where possible. In
2624 addition, it verifies the type and other properties of `feature_columns`
2625 that are required by downstream libraries.
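A minimal sketch of the normalization:

```python
cols = _normalize_feature_columns(numeric_column('b'))
# -> [NumericColumn(key='b', ...)]: wrapped in a list and sorted by name.
```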
2627 Args:
2628 feature_columns: The raw feature columns, usually passed by users.
2630 Returns:
2631 The normalized feature column list.
2633 Raises:
2634 ValueError: for any invalid inputs, such as empty, duplicated names, etc.
2635 """
2636 if isinstance(feature_columns, FeatureColumn):
2637 feature_columns = [feature_columns]
2639 if isinstance(feature_columns, collections_abc.Iterator):
2640 feature_columns = list(feature_columns)
2642 if isinstance(feature_columns, dict):
2643 raise ValueError('Expected feature_columns to be iterable, found dict.')
2645 for column in feature_columns:
2646 if not isinstance(column, FeatureColumn):
2647 raise ValueError('Items of feature_columns must be a FeatureColumn. '
2648 'Given (type {}): {}.'.format(type(column), column))
2649 if not feature_columns:
2650 raise ValueError('feature_columns must not be empty.')
2651 name_to_column = {}
2652 for column in feature_columns:
2653 if column.name in name_to_column:
2654 raise ValueError('Duplicate feature column name found for columns: {} '
2655 'and {}. This usually means that these columns refer to '
2656 'the same base feature. Either one must be discarded or a '
2657 'duplicated but renamed item must be inserted in the '
2658 'features dict.'.format(column,
2659 name_to_column[column.name]))
2660 name_to_column[column.name] = column
2662 return sorted(feature_columns, key=lambda x: x.name)
2665class NumericColumn(
2666 DenseColumn,
2667 fc_old._DenseColumn, # pylint: disable=protected-access
2668 collections.namedtuple(
2669 'NumericColumn',
2670 ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
2671 """see `numeric_column`."""
2673 @property
2674 def _is_v2_column(self):
2675 return True
2677 @property
2678 def name(self):
2679 """See `FeatureColumn` base class."""
2680 return self.key
2682 @property
2683 def parse_example_spec(self):
2684 """See `FeatureColumn` base class."""
2685 return {
2686 self.key:
2687 parsing_ops.FixedLenFeature(self.shape, self.dtype,
2688 self.default_value)
2689 }
2691 @property
2692 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2693 _FEATURE_COLUMN_DEPRECATION)
2694 def _parse_example_spec(self):
2695 return self.parse_example_spec
2697 def _transform_input_tensor(self, input_tensor):
2698 if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2699 raise ValueError(
2700 'The corresponding Tensor of a numeric column must be a Tensor. '
2701 'SparseTensor is not supported. key: {}'.format(self.key))
2702 if self.normalizer_fn is not None:
2703 input_tensor = self.normalizer_fn(input_tensor)
2704 return math_ops.cast(input_tensor, dtypes.float32)
2706 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2707 _FEATURE_COLUMN_DEPRECATION)
2708 def _transform_feature(self, inputs):
2709 input_tensor = inputs.get(self.key)
2710 return self._transform_input_tensor(input_tensor)
2712 def transform_feature(self, transformation_cache, state_manager):
2713 """See `FeatureColumn` base class.
2715 In this case, we apply the `normalizer_fn` to the input tensor.
2717 Args:
2718 transformation_cache: A `FeatureTransformationCache` object to access
2719 features.
2720 state_manager: A `StateManager` to create / access resources such as
2721 lookup tables.
2723 Returns:
2724 Normalized input tensor.
2725 Raises:
2726 ValueError: If a SparseTensor is passed in.
2727 """
2728 input_tensor = transformation_cache.get(self.key, state_manager)
2729 return self._transform_input_tensor(input_tensor)
2731 @property
2732 def variable_shape(self):
2733 """See `DenseColumn` base class."""
2734 return tensor_shape.TensorShape(self.shape)
2736 @property
2737 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2738 _FEATURE_COLUMN_DEPRECATION)
2739 def _variable_shape(self):
2740 return self.variable_shape
2742 def get_dense_tensor(self, transformation_cache, state_manager):
2743 """Returns dense `Tensor` representing numeric feature.
2745 Args:
2746 transformation_cache: A `FeatureTransformationCache` object to access
2747 features.
2748 state_manager: A `StateManager` to create / access resources such as
2749 lookup tables.
2751 Returns:
2752 Dense `Tensor` created within `transform_feature`.
2753 """
2754 # Feature has already been transformed. Return the intermediate
2755 # representation created by _transform_feature.
2756 return transformation_cache.get(self, state_manager)
2758 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2759 _FEATURE_COLUMN_DEPRECATION)
2760 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2761 del weight_collections
2762 del trainable
2763 return inputs.get(self)
2765 @property
2766 def parents(self):
2767 """See 'FeatureColumn` base class."""
2768 return [self.key]
2770 def get_config(self):
2771 """See 'FeatureColumn` base class."""
2772 config = dict(zip(self._fields, self))
2773 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
2774 config['normalizer_fn'] = serialization._serialize_keras_object( # pylint: disable=protected-access
2775 self.normalizer_fn)
2776 config['dtype'] = self.dtype.name
2777 return config
2779 @classmethod
2780 def from_config(cls, config, custom_objects=None, columns_by_name=None):
2781 """See 'FeatureColumn` base class."""
2782 _check_config_keys(config, cls._fields)
2783 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
2784 kwargs = _standardize_and_copy_config(config)
2785 kwargs['normalizer_fn'] = serialization._deserialize_keras_object( # pylint: disable=protected-access
2786 config['normalizer_fn'],
2787 custom_objects=custom_objects)
2788 kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
2790 return cls(**kwargs)
2793class BucketizedColumn(
2794 DenseColumn,
2795 CategoricalColumn,
2796 fc_old._DenseColumn, # pylint: disable=protected-access
2797 fc_old._CategoricalColumn, # pylint: disable=protected-access
2798 collections.namedtuple('BucketizedColumn',
2799 ('source_column', 'boundaries'))):
2800 """See `bucketized_column`."""
2802 @property
2803 def _is_v2_column(self):
2804 return (isinstance(self.source_column, FeatureColumn) and
2805 self.source_column._is_v2_column) # pylint: disable=protected-access
2807 @property
2808 def name(self):
2809 """See `FeatureColumn` base class."""
2810 return '{}_bucketized'.format(self.source_column.name)
2812 @property
2813 def parse_example_spec(self):
2814 """See `FeatureColumn` base class."""
2815 return self.source_column.parse_example_spec
2817 @property
2818 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2819 _FEATURE_COLUMN_DEPRECATION)
2820 def _parse_example_spec(self):
2821 return self.source_column._parse_example_spec # pylint: disable=protected-access
2823 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2824 _FEATURE_COLUMN_DEPRECATION)
2825 def _transform_feature(self, inputs):
2826 """Returns bucketized categorical `source_column` tensor."""
2827 source_tensor = inputs.get(self.source_column)
2828 return math_ops._bucketize( # pylint: disable=protected-access
2829 source_tensor,
2830 boundaries=self.boundaries)
2832 def transform_feature(self, transformation_cache, state_manager):
2833 """Returns bucketized categorical `source_column` tensor."""
2834 source_tensor = transformation_cache.get(self.source_column, state_manager)
2835 return math_ops._bucketize( # pylint: disable=protected-access
2836 source_tensor,
2837 boundaries=self.boundaries)
2839 @property
2840 def variable_shape(self):
2841 """See `DenseColumn` base class."""
2842 return tensor_shape.TensorShape(
2843 tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
2845 @property
2846 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2847 _FEATURE_COLUMN_DEPRECATION)
2848 def _variable_shape(self):
2849 return self.variable_shape
2851 def _get_dense_tensor_for_input_tensor(self, input_tensor):
2852 return array_ops.one_hot(
2853 indices=math_ops.cast(input_tensor, dtypes.int64),
2854 depth=len(self.boundaries) + 1,
2855 on_value=1.,
2856 off_value=0.)
2858 def get_dense_tensor(self, transformation_cache, state_manager):
2859 """Returns one hot encoded dense `Tensor`."""
2860 input_tensor = transformation_cache.get(self, state_manager)
2861 return self._get_dense_tensor_for_input_tensor(input_tensor)
2863 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2864 _FEATURE_COLUMN_DEPRECATION)
2865 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2866 del weight_collections
2867 del trainable
2868 input_tensor = inputs.get(self)
2869 return self._get_dense_tensor_for_input_tensor(input_tensor)
2871 @property
2872 def num_buckets(self):
2873 """See `CategoricalColumn` base class."""
2874 # By construction, source_column is always one-dimensional.
2875 return (len(self.boundaries) + 1) * self.source_column.shape[0]
2877 @property
2878 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2879 _FEATURE_COLUMN_DEPRECATION)
2880 def _num_buckets(self):
2881 return self.num_buckets
2883 def _get_sparse_tensors_for_input_tensor(self, input_tensor):
2884 batch_size = array_ops.shape(input_tensor)[0]
2885 # By construction, source_column is always one-dimensional.
2886 source_dimension = self.source_column.shape[0]
2888 i1 = array_ops.reshape(
2889 array_ops.tile(
2890 array_ops.expand_dims(math_ops.range(0, batch_size), 1),
2891 [1, source_dimension]), (-1,))
2892 i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
2893 # Flatten the bucket indices and offset them so they are unique across
2894 # dimensions. E.g., with k buckets, 2nd-dimension indices range from k to 2*k-1.
2895 bucket_indices = (
2896 array_ops.reshape(input_tensor,
2897 (-1,)) + (len(self.boundaries) + 1) * i2)
2899 indices = math_ops.cast(
2900 array_ops.transpose(array_ops_stack.stack((i1, i2))), dtypes.int64)
2901 dense_shape = math_ops.cast(
2902 array_ops_stack.stack([batch_size, source_dimension]), dtypes.int64)
2903 sparse_tensor = sparse_tensor_lib.SparseTensor(
2904 indices=indices, values=bucket_indices, dense_shape=dense_shape)
2905 return CategoricalColumn.IdWeightPair(sparse_tensor, None)
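# Illustrative worked example (not in the original source) for
# `_get_sparse_tensors_for_input_tensor` above: with batch_size=2,
# source_dimension=2 and boundaries=[0., 1.] (i.e. 3 buckets),
#   i1 = [0, 0, 1, 1]   (example/row index per value)
#   i2 = [0, 1, 0, 1]   (source dimension per value)
# and a bucketized input [[0, 2], [1, 0]] flattens to [0, 2, 1, 0], so
#   bucket_indices = [0, 2, 1, 0] + 3 * [0, 1, 0, 1] = [0, 5, 1, 3];
# dimension d owns the disjoint id range [3*d, 3*d + 2].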
2907 def get_sparse_tensors(self, transformation_cache, state_manager):
2908 """Converts dense inputs to SparseTensor so downstream code can use it."""
2909 input_tensor = transformation_cache.get(self, state_manager)
2910 return self._get_sparse_tensors_for_input_tensor(input_tensor)
2912 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2913 _FEATURE_COLUMN_DEPRECATION)
2914 def _get_sparse_tensors(self,
2915 inputs,
2916 weight_collections=None,
2917 trainable=None):
2918 """Converts dense inputs to SparseTensor so downstream code can use it."""
2919 del weight_collections
2920 del trainable
2921 input_tensor = inputs.get(self)
2922 return self._get_sparse_tensors_for_input_tensor(input_tensor)
2924 @property
2925 def parents(self):
2926 """See 'FeatureColumn` base class."""
2927 return [self.source_column]
2929 def get_config(self):
2930 """See 'FeatureColumn` base class."""
2931 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
2932 config = dict(zip(self._fields, self))
2933 config['source_column'] = serialize_feature_column(self.source_column)
2934 return config
2936 @classmethod
2937 def from_config(cls, config, custom_objects=None, columns_by_name=None):
2938 """See 'FeatureColumn` base class."""
2939 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
2940 _check_config_keys(config, cls._fields)
2941 kwargs = _standardize_and_copy_config(config)
2942 kwargs['source_column'] = deserialize_feature_column(
2943 config['source_column'], custom_objects, columns_by_name)
2944 return cls(**kwargs)
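# Usage sketch (illustrative, not in the original source): a
# `BucketizedColumn` is normally obtained through the public
# `bucketized_column` helper defined earlier in this module. Because it is
# both a `DenseColumn` (one-hot) and a `CategoricalColumn`, it can feed deep
# and linear models alike. Feature name and boundaries are illustrative.
#
#   age = numeric_column('age')
#   age_buckets = bucketized_column(age, boundaries=[30, 40, 50])
#   # 4 buckets: num_buckets == 4 and variable_shape == (1, 4).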
2947class EmbeddingColumn(
2948 DenseColumn,
2949 SequenceDenseColumn,
2950 fc_old._DenseColumn, # pylint: disable=protected-access
2951 fc_old._SequenceDenseColumn, # pylint: disable=protected-access
2952 collections.namedtuple(
2953 'EmbeddingColumn',
2954 ('categorical_column', 'dimension', 'combiner', 'initializer',
2955 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
2956 'use_safe_embedding_lookup'))):
2957 """See `embedding_column`."""
2959 def __new__(cls,
2960 categorical_column,
2961 dimension,
2962 combiner,
2963 initializer,
2964 ckpt_to_load_from,
2965 tensor_name_in_ckpt,
2966 max_norm,
2967 trainable,
2968 use_safe_embedding_lookup=True):
2969 return super(EmbeddingColumn, cls).__new__(
2970 cls,
2971 categorical_column=categorical_column,
2972 dimension=dimension,
2973 combiner=combiner,
2974 initializer=initializer,
2975 ckpt_to_load_from=ckpt_to_load_from,
2976 tensor_name_in_ckpt=tensor_name_in_ckpt,
2977 max_norm=max_norm,
2978 trainable=trainable,
2979 use_safe_embedding_lookup=use_safe_embedding_lookup)
2981 @property
2982 def _is_v2_column(self):
2983 return (isinstance(self.categorical_column, FeatureColumn) and
2984 self.categorical_column._is_v2_column) # pylint: disable=protected-access
2986 @property
2987 def name(self):
2988 """See `FeatureColumn` base class."""
2989 return '{}_embedding'.format(self.categorical_column.name)
2991 @property
2992 def parse_example_spec(self):
2993 """See `FeatureColumn` base class."""
2994 return self.categorical_column.parse_example_spec
2996 @property
2997 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2998 _FEATURE_COLUMN_DEPRECATION)
2999 def _parse_example_spec(self):
3000 return self.categorical_column._parse_example_spec # pylint: disable=protected-access
3002 def transform_feature(self, transformation_cache, state_manager):
3003 """Transforms underlying `categorical_column`."""
3004 return transformation_cache.get(self.categorical_column, state_manager)
3006 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3007 _FEATURE_COLUMN_DEPRECATION)
3008 def _transform_feature(self, inputs):
3009 return inputs.get(self.categorical_column)
3011 @property
3012 def variable_shape(self):
3013 """See `DenseColumn` base class."""
3014 return tensor_shape.TensorShape([self.dimension])
3016 @property
3017 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3018 _FEATURE_COLUMN_DEPRECATION)
3019 def _variable_shape(self):
3020 return self.variable_shape
3022 def create_state(self, state_manager):
3023 """Creates the embedding lookup variable."""
3024 default_num_buckets = (
3025 self.categorical_column.num_buckets
3026 if self._is_v2_column else self.categorical_column._num_buckets) # pylint: disable=protected-access
3027 num_buckets = getattr(self.categorical_column, 'num_buckets',
3028 default_num_buckets)
3029 embedding_shape = (num_buckets, self.dimension)
3030 state_manager.create_variable(
3031 self,
3032 name='embedding_weights',
3033 shape=embedding_shape,
3034 dtype=dtypes.float32,
3035 trainable=self.trainable,
3036 use_resource=True,
3037 initializer=self.initializer)
3039 def _get_dense_tensor_internal_helper(self, sparse_tensors,
3040 embedding_weights):
3041 sparse_ids = sparse_tensors.id_tensor
3042 sparse_weights = sparse_tensors.weight_tensor
3044 if self.ckpt_to_load_from is not None:
3045 to_restore = embedding_weights
3046 if isinstance(to_restore, variables.PartitionedVariable):
3047 to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
3048 checkpoint_utils.init_from_checkpoint(
3049 self.ckpt_to_load_from, {self.tensor_name_in_ckpt: to_restore})
3051 sparse_id_rank = tensor_shape.dimension_value(
3052 sparse_ids.dense_shape.get_shape()[0])
3053 embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
3054 if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
3055 sparse_id_rank <= 2):
3056 embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
3057 # Return embedding lookup result.
3058 return embedding_lookup_sparse(
3059 embedding_weights,
3060 sparse_ids,
3061 sparse_weights,
3062 combiner=self.combiner,
3063 name='%s_weights' % self.name,
3064 max_norm=self.max_norm)
3066 def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
3067 """Private method that follows the signature of get_dense_tensor."""
3068 embedding_weights = state_manager.get_variable(
3069 self, name='embedding_weights')
3070 return self._get_dense_tensor_internal_helper(sparse_tensors,
3071 embedding_weights)
3073 def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
3074 trainable):
3075 """Private method that follows the signature of _get_dense_tensor."""
3076 embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
3077 if (weight_collections and
3078 ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
3079 weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
3080 embedding_weights = variable_scope.get_variable(
3081 name='embedding_weights',
3082 shape=embedding_shape,
3083 dtype=dtypes.float32,
3084 initializer=self.initializer,
3085 trainable=self.trainable and trainable,
3086 collections=weight_collections)
3087 return self._get_dense_tensor_internal_helper(sparse_tensors,
3088 embedding_weights)
3090 def get_dense_tensor(self, transformation_cache, state_manager):
3091 """Returns tensor after doing the embedding lookup.
3093 Args:
3094 transformation_cache: A `FeatureTransformationCache` object to access
3095 features.
3096 state_manager: A `StateManager` to create / access resources such as
3097 lookup tables.
3099 Returns:
3100 Embedding lookup tensor.
3102 Raises:
3103 ValueError: `categorical_column` is SequenceCategoricalColumn.
3104 """
3105 if isinstance(self.categorical_column, SequenceCategoricalColumn):
3106 raise ValueError(
3107 'In embedding_column: {}. '
3108 'categorical_column must not be of type SequenceCategoricalColumn. '
3109 'Suggested fix A: If you wish to use DenseFeatures, use a '
3110 'non-sequence categorical_column_with_*. '
3111 'Suggested fix B: If you wish to create sequence input, use '
3112 'SequenceFeatures instead of DenseFeatures. '
3113 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3114 self.categorical_column))
3115 # Get sparse IDs and weights.
3116 sparse_tensors = self.categorical_column.get_sparse_tensors(
3117 transformation_cache, state_manager)
3118 return self._get_dense_tensor_internal(sparse_tensors, state_manager)
3120 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3121 _FEATURE_COLUMN_DEPRECATION)
3122 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
3123 if isinstance(
3124 self.categorical_column,
3125 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
3126 raise ValueError(
3127 'In embedding_column: {}. '
3128 'categorical_column must not be of type _SequenceCategoricalColumn. '
3129 'Suggested fix A: If you wish to use DenseFeatures, use a '
3130 'non-sequence categorical_column_with_*. '
3131 'Suggested fix B: If you wish to create sequence input, use '
3132 'SequenceFeatures instead of DenseFeatures. '
3133 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3134 self.categorical_column))
3135 sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
3136 inputs, weight_collections, trainable)
3137 return self._old_get_dense_tensor_internal(sparse_tensors,
3138 weight_collections, trainable)
3140 def get_sequence_dense_tensor(self, transformation_cache, state_manager):
3141 """See `SequenceDenseColumn` base class."""
3142 if not isinstance(self.categorical_column, SequenceCategoricalColumn):
3143 raise ValueError(
3144 'In embedding_column: {}. '
3145 'categorical_column must be of type SequenceCategoricalColumn '
3146 'to use SequenceFeatures. '
3147 'Suggested fix: Use one of sequence_categorical_column_with_*. '
3148 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3149 self.categorical_column))
3150 sparse_tensors = self.categorical_column.get_sparse_tensors(
3151 transformation_cache, state_manager)
3152 dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
3153 state_manager)
3154 sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3155 sparse_tensors.id_tensor)
3156 return SequenceDenseColumn.TensorSequenceLengthPair(
3157 dense_tensor=dense_tensor, sequence_length=sequence_length)
3159 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3160 _FEATURE_COLUMN_DEPRECATION)
3161 def _get_sequence_dense_tensor(self,
3162 inputs,
3163 weight_collections=None,
3164 trainable=None):
3165 if not isinstance(
3166 self.categorical_column,
3167 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
3168 raise ValueError(
3169 'In embedding_column: {}. '
3170 'categorical_column must be of type SequenceCategoricalColumn '
3171 'to use SequenceFeatures. '
3172 'Suggested fix: Use one of sequence_categorical_column_with_*. '
3173 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3174 self.categorical_column))
3175 sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
3176 dense_tensor = self._old_get_dense_tensor_internal(
3177 sparse_tensors,
3178 weight_collections=weight_collections,
3179 trainable=trainable)
3180 sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3181 sparse_tensors.id_tensor)
3182 return SequenceDenseColumn.TensorSequenceLengthPair(
3183 dense_tensor=dense_tensor, sequence_length=sequence_length)
3185 @property
3186 def parents(self):
3187 """See 'FeatureColumn` base class."""
3188 return [self.categorical_column]
3190 def get_config(self):
3191 """See 'FeatureColumn` base class."""
3192 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
3193 config = dict(zip(self._fields, self))
3194 config['categorical_column'] = serialization.serialize_feature_column(
3195 self.categorical_column)
3196 config['initializer'] = serialization._serialize_keras_object( # pylint: disable=protected-access
3197 self.initializer)
3198 return config
3200 @classmethod
3201 def from_config(cls, config, custom_objects=None, columns_by_name=None):
3202 """See 'FeatureColumn` base class."""
3203 if 'use_safe_embedding_lookup' not in config:
3204 config['use_safe_embedding_lookup'] = True
3205 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
3206 _check_config_keys(config, cls._fields)
3207 kwargs = _standardize_and_copy_config(config)
3208 kwargs['categorical_column'] = serialization.deserialize_feature_column(
3209 config['categorical_column'], custom_objects, columns_by_name)
3210 all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
3211 kwargs['initializer'] = serialization._deserialize_keras_object( # pylint: disable=protected-access
3212 config['initializer'],
3213 module_objects=all_initializers,
3214 custom_objects=custom_objects)
3215 return cls(**kwargs)
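# Usage sketch (illustrative, not in the original source): an
# `EmbeddingColumn` is normally obtained through the public
# `embedding_column` helper defined earlier in this module. Vocabulary and
# dimension are illustrative.
#
#   dept = categorical_column_with_vocabulary_list(
#       'department', ['math', 'philosophy', 'english'])
#   dept_embedded = embedding_column(dept, dimension=8)
#   # variable_shape == (8,); the (num_buckets, 8) embedding table itself is
#   # created lazily by `create_state`.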
3218def _raise_shared_embedding_column_error():
3219 raise ValueError('SharedEmbeddingColumns are not supported in '
3220 '`linear_model` or `input_layer`. Please use '
3221 '`DenseFeatures` or `LinearModel` instead.')
3224class SharedEmbeddingColumnCreator(autotrackable.AutoTrackable):
3225 """Class that creates a `SharedEmbeddingColumn`."""
3227 def __init__(self,
3228 dimension,
3229 initializer,
3230 ckpt_to_load_from,
3231 tensor_name_in_ckpt,
3232 num_buckets,
3233 trainable,
3234 name='shared_embedding_column_creator',
3235 use_safe_embedding_lookup=True):
3236 self._dimension = dimension
3237 self._initializer = initializer
3238 self._ckpt_to_load_from = ckpt_to_load_from
3239 self._tensor_name_in_ckpt = tensor_name_in_ckpt
3240 self._num_buckets = num_buckets
3241 self._trainable = trainable
3242 self._name = name
3243 self._use_safe_embedding_lookup = use_safe_embedding_lookup
3244 # Map from graph keys to embedding_weight variables.
3245 self._embedding_weights = {}
3247 def __call__(self, categorical_column, combiner, max_norm):
3248 return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm,
3249 self._use_safe_embedding_lookup)
3251 @property
3252 def embedding_weights(self):
3253 key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
3254 if key not in self._embedding_weights:
3255 embedding_shape = (self._num_buckets, self._dimension)
3256 var = variable_scope.get_variable(
3257 name=self._name,
3258 shape=embedding_shape,
3259 dtype=dtypes.float32,
3260 initializer=self._initializer,
3261 trainable=self._trainable)
3263 if self._ckpt_to_load_from is not None:
3264 to_restore = var
3265 if isinstance(to_restore, variables.PartitionedVariable):
3266 to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
3267 checkpoint_utils.init_from_checkpoint(
3268 self._ckpt_to_load_from, {self._tensor_name_in_ckpt: to_restore})
3269 self._embedding_weights[key] = var
3270 return self._embedding_weights[key]
3272 @property
3273 def dimension(self):
3274 return self._dimension
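# Usage sketch (illustrative; direct construction shown only for
# exposition): a single creator is shared by several columns so they look
# ids up in one weight matrix. Columns are normally obtained through this
# module's shared-embedding helper rather than by calling the creator
# directly; `categorical_column_a`/`categorical_column_b` assume two
# existing categorical columns over a combined vocabulary of 1000 ids.
#
#   creator = SharedEmbeddingColumnCreator(
#       dimension=8, initializer=None, ckpt_to_load_from=None,
#       tensor_name_in_ckpt=None, num_buckets=1000, trainable=True)
#   col_a = creator(categorical_column_a, combiner='mean', max_norm=None)
#   col_b = creator(categorical_column_b, combiner='mean', max_norm=None)
#   # col_a and col_b share creator.embedding_weights.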
3277class SharedEmbeddingColumn(
3278 DenseColumn,
3279 SequenceDenseColumn,
3280 fc_old._DenseColumn, # pylint: disable=protected-access
3281 fc_old._SequenceDenseColumn, # pylint: disable=protected-access
3282 collections.namedtuple(
3283 'SharedEmbeddingColumn',
3284 ('categorical_column', 'shared_embedding_column_creator', 'combiner',
3285 'max_norm', 'use_safe_embedding_lookup'))):
3286 """See `embedding_column`."""
3288 def __new__(cls,
3289 categorical_column,
3290 shared_embedding_column_creator,
3291 combiner,
3292 max_norm,
3293 use_safe_embedding_lookup=True):
3294 return super(SharedEmbeddingColumn, cls).__new__(
3295 cls,
3296 categorical_column=categorical_column,
3297 shared_embedding_column_creator=shared_embedding_column_creator,
3298 combiner=combiner,
3299 max_norm=max_norm,
3300 use_safe_embedding_lookup=use_safe_embedding_lookup)
3302 @property
3303 def _is_v2_column(self):
3304 return True
3306 @property
3307 def name(self):
3308 """See `FeatureColumn` base class."""
3309 return '{}_shared_embedding'.format(self.categorical_column.name)
3311 @property
3312 def parse_example_spec(self):
3313 """See `FeatureColumn` base class."""
3314 return self.categorical_column.parse_example_spec
3316 @property
3317 def _parse_example_spec(self):
3318 return _raise_shared_embedding_column_error()
3320 def transform_feature(self, transformation_cache, state_manager):
3321 """See `FeatureColumn` base class."""
3322 return transformation_cache.get(self.categorical_column, state_manager)
3324 def _transform_feature(self, inputs):
3325 return _raise_shared_embedding_column_error()
3327 @property
3328 def variable_shape(self):
3329 """See `DenseColumn` base class."""
3330 return tensor_shape.TensorShape(
3331 [self.shared_embedding_column_creator.dimension])
3333 @property
3334 def _variable_shape(self):
3335 return _raise_shared_embedding_column_error()
3337 def _get_dense_tensor_internal(self, transformation_cache, state_manager):
3338 """Private method that follows the signature of _get_dense_tensor."""
3339 # This method is called from a variable_scope with name _var_scope_name,
3340 # which is shared among all shared embeddings. Open a name_scope here, so
3341 # that the ops for different columns have distinct names.
3342 with ops.name_scope(None, default_name=self.name):
3343 # Get sparse IDs and weights.
3344 sparse_tensors = self.categorical_column.get_sparse_tensors(
3345 transformation_cache, state_manager)
3346 sparse_ids = sparse_tensors.id_tensor
3347 sparse_weights = sparse_tensors.weight_tensor
3349 embedding_weights = self.shared_embedding_column_creator.embedding_weights
3351 sparse_id_rank = tensor_shape.dimension_value(
3352 sparse_ids.dense_shape.get_shape()[0])
3353 embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
3354 if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
3355 sparse_id_rank <= 2):
3356 embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
3357 # Return embedding lookup result.
3358 return embedding_lookup_sparse(
3359 embedding_weights,
3360 sparse_ids,
3361 sparse_weights,
3362 combiner=self.combiner,
3363 name='%s_weights' % self.name,
3364 max_norm=self.max_norm)
3366 def get_dense_tensor(self, transformation_cache, state_manager):
3367 """Returns the embedding lookup result."""
3368 if isinstance(self.categorical_column, SequenceCategoricalColumn):
3369 raise ValueError(
3370 'In embedding_column: {}. '
3371 'categorical_column must not be of type SequenceCategoricalColumn. '
3372 'Suggested fix A: If you wish to use DenseFeatures, use a '
3373 'non-sequence categorical_column_with_*. '
3374 'Suggested fix B: If you wish to create sequence input, use '
3375 'SequenceFeatures instead of DenseFeatures. '
3376 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3377 self.categorical_column))
3378 return self._get_dense_tensor_internal(transformation_cache, state_manager)
3380 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
3381 return _raise_shared_embedding_column_error()
3383 def get_sequence_dense_tensor(self, transformation_cache, state_manager):
3384 """See `SequenceDenseColumn` base class."""
3385 if not isinstance(self.categorical_column, SequenceCategoricalColumn):
3386 raise ValueError(
3387 'In embedding_column: {}. '
3388 'categorical_column must be of type SequenceCategoricalColumn '
3389 'to use SequenceFeatures. '
3390 'Suggested fix: Use one of sequence_categorical_column_with_*. '
3391 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3392 self.categorical_column))
3393 dense_tensor = self._get_dense_tensor_internal(transformation_cache,
3394 state_manager)
3395 sparse_tensors = self.categorical_column.get_sparse_tensors(
3396 transformation_cache, state_manager)
3397 sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3398 sparse_tensors.id_tensor)
3399 return SequenceDenseColumn.TensorSequenceLengthPair(
3400 dense_tensor=dense_tensor, sequence_length=sequence_length)
3402 def _get_sequence_dense_tensor(self,
3403 inputs,
3404 weight_collections=None,
3405 trainable=None):
3406 return _raise_shared_embedding_column_error()
3408 @property
3409 def parents(self):
3410 """See 'FeatureColumn` base class."""
3411 return [self.categorical_column]
3414def _check_shape(shape, key):
3415 """Returns shape if it's valid, raises error otherwise."""
3416 assert shape is not None
3417 if not nest.is_nested(shape):
3418 shape = [shape]
3419 shape = tuple(shape)
3420 for dimension in shape:
3421 if not isinstance(dimension, int):
3422 raise TypeError('shape dimensions must be integer. '
3423 'shape: {}, key: {}'.format(shape, key))
3424 if dimension < 1:
3425 raise ValueError('shape dimensions must be greater than 0. '
3426 'shape: {}, key: {}'.format(shape, key))
3427 return shape
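# Illustrative behavior of `_check_shape` (not in the original source): a
# bare int is normalized to a 1-tuple, and non-integer or non-positive
# dimensions are rejected.
#
#   _check_shape(3, key='age')       # -> (3,)
#   _check_shape([2, 4], key='img')  # -> (2, 4)
#   _check_shape([0], key='bad')     # raises ValueError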
3430class HashedCategoricalColumn(
3431 CategoricalColumn,
3432 fc_old._CategoricalColumn, # pylint: disable=protected-access
3433 collections.namedtuple('HashedCategoricalColumn',
3434 ('key', 'hash_bucket_size', 'dtype'))):
3435 """see `categorical_column_with_hash_bucket`."""
3437 @property
3438 def _is_v2_column(self):
3439 return True
3441 @property
3442 def name(self):
3443 """See `FeatureColumn` base class."""
3444 return self.key
3446 @property
3447 def parse_example_spec(self):
3448 """See `FeatureColumn` base class."""
3449 return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3451 @property
3452 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3453 _FEATURE_COLUMN_DEPRECATION)
3454 def _parse_example_spec(self):
3455 return self.parse_example_spec
3457 def _transform_input_tensor(self, input_tensor):
3458 """Hashes the values in the feature_column."""
3459 if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
3460 raise ValueError('HashedCategoricalColumn input must be a SparseTensor.')
3462 fc_utils.assert_string_or_int(
3463 input_tensor.dtype,
3464 prefix='column_name: {} input_tensor'.format(self.key))
3466 if self.dtype.is_integer != input_tensor.dtype.is_integer:
3467 raise ValueError(
3468 'Column dtype and SparseTensors dtype must be compatible. '
3469 'key: {}, column dtype: {}, tensor dtype: {}'.format(
3470 self.key, self.dtype, input_tensor.dtype))
3472 if self.dtype == dtypes.string:
3473 sparse_values = input_tensor.values
3474 else:
3475 sparse_values = string_ops.as_string(input_tensor.values)
3477 sparse_id_values = string_ops.string_to_hash_bucket_fast(
3478 sparse_values, self.hash_bucket_size, name='lookup')
3479 return sparse_tensor_lib.SparseTensor(input_tensor.indices,
3480 sparse_id_values,
3481 input_tensor.dense_shape)
3483 def transform_feature(self, transformation_cache, state_manager):
3484 """Hashes the values in the feature_column."""
3485 input_tensor = _to_sparse_input_and_drop_ignore_values(
3486 transformation_cache.get(self.key, state_manager))
3487 return self._transform_input_tensor(input_tensor)
3489 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3490 _FEATURE_COLUMN_DEPRECATION)
3491 def _transform_feature(self, inputs):
3492 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3493 return self._transform_input_tensor(input_tensor)
3495 @property
3496 def num_buckets(self):
3497 """Returns number of buckets in this sparse feature."""
3498 return self.hash_bucket_size
3500 @property
3501 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3502 _FEATURE_COLUMN_DEPRECATION)
3503 def _num_buckets(self):
3504 return self.num_buckets
3506 def get_sparse_tensors(self, transformation_cache, state_manager):
3507 """See `CategoricalColumn` base class."""
3508 return CategoricalColumn.IdWeightPair(
3509 transformation_cache.get(self, state_manager), None)
3511 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3512 _FEATURE_COLUMN_DEPRECATION)
3513 def _get_sparse_tensors(self,
3514 inputs,
3515 weight_collections=None,
3516 trainable=None):
3517 del weight_collections
3518 del trainable
3519 return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3521 @property
3522 def parents(self):
3523 """See 'FeatureColumn` base class."""
3524 return [self.key]
3526 def get_config(self):
3527 """See 'FeatureColumn` base class."""
3528 config = dict(zip(self._fields, self))
3529 config['dtype'] = self.dtype.name
3530 return config
3532 @classmethod
3533 def from_config(cls, config, custom_objects=None, columns_by_name=None):
3534 """See 'FeatureColumn` base class."""
3535 _check_config_keys(config, cls._fields)
3536 kwargs = _standardize_and_copy_config(config)
3537 kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3538 return cls(**kwargs)
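# Usage sketch (illustrative, not in the original source): normally obtained
# through the public `categorical_column_with_hash_bucket` helper defined
# earlier in this module. Feature name and size are illustrative.
#
#   color = categorical_column_with_hash_bucket('color', hash_bucket_size=100)
#   # num_buckets == 100; values are hashed into [0, 100) with
#   # string_to_hash_bucket_fast, so distinct values may collide.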
3541class VocabularyFileCategoricalColumn(
3542 CategoricalColumn,
3543 fc_old._CategoricalColumn, # pylint: disable=protected-access
3544 collections.namedtuple(
3545 'VocabularyFileCategoricalColumn',
3546 ('key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets',
3547 'dtype', 'default_value', 'file_format'))):
3548 """See `categorical_column_with_vocabulary_file`."""
3550 def __new__(cls,
3551 key,
3552 vocabulary_file,
3553 vocabulary_size,
3554 num_oov_buckets,
3555 dtype,
3556 default_value,
3557 file_format=None):
3558 return super(VocabularyFileCategoricalColumn, cls).__new__(
3559 cls,
3560 key=key,
3561 vocabulary_file=vocabulary_file,
3562 vocabulary_size=vocabulary_size,
3563 num_oov_buckets=num_oov_buckets,
3564 dtype=dtype,
3565 default_value=default_value,
3566 file_format=file_format)
3568 @property
3569 def _is_v2_column(self):
3570 return True
3572 @property
3573 def name(self):
3574 """See `FeatureColumn` base class."""
3575 return self.key
3577 @property
3578 def parse_example_spec(self):
3579 """See `FeatureColumn` base class."""
3580 return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3582 @property
3583 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3584 _FEATURE_COLUMN_DEPRECATION)
3585 def _parse_example_spec(self):
3586 return self.parse_example_spec
3588 def _make_table_from_tfrecord_gzip_file(self, key_dtype, name):
3589 dataset = readers.TFRecordDataset(
3590 self.vocabulary_file, compression_type='GZIP')
3592 def key_dtype_fn(key):
3593 return key if key_dtype is dtypes.string else string_ops.string_to_number(
3594 key, out_type=key_dtype)
3596 return data_lookup_ops.index_table_from_dataset(
3597 dataset.map(key_dtype_fn),
3598 num_oov_buckets=self.num_oov_buckets,
3599 vocab_size=self.vocabulary_size,
3600 default_value=self.default_value,
3601 key_dtype=key_dtype,
3602 name=name)
3604 def _make_table(self, key_dtype, state_manager):
3605 name = '{}_lookup'.format(self.key)
3606 if state_manager is None or not state_manager.has_resource(self, name):
3607 with ops.init_scope():
3608 if self.file_format == 'tfrecord_gzip':
3609 table = self._make_table_from_tfrecord_gzip_file(key_dtype, name)
3610 else:
3611 table = lookup_ops.index_table_from_file(
3612 vocabulary_file=self.vocabulary_file,
3613 num_oov_buckets=self.num_oov_buckets,
3614 vocab_size=self.vocabulary_size,
3615 default_value=self.default_value,
3616 key_dtype=key_dtype,
3617 name=name)
3618 if state_manager is not None:
3619 state_manager.add_resource(self, name, table)
3620 else:
3621 # Reuse the table already created for this column.
3622 table = state_manager.get_resource(self, name)
3623 return table
3625 def _transform_input_tensor(self, input_tensor, state_manager=None):
3626 """Creates a lookup table for the vocabulary."""
3627 if self.dtype.is_integer != input_tensor.dtype.is_integer:
3628 raise ValueError(
3629 'Column dtype and SparseTensors dtype must be compatible. '
3630 'key: {}, column dtype: {}, tensor dtype: {}'.format(
3631 self.key, self.dtype, input_tensor.dtype))
3633 fc_utils.assert_string_or_int(
3634 input_tensor.dtype,
3635 prefix='column_name: {} input_tensor'.format(self.key))
3637 key_dtype = self.dtype
3638 if input_tensor.dtype.is_integer:
3639 # `index_table_from_file` requires 64-bit integer keys.
3640 key_dtype = dtypes.int64
3641 input_tensor = math_ops.cast(input_tensor, dtypes.int64)
3642 return self._make_table(key_dtype, state_manager).lookup(input_tensor)
3644 def transform_feature(self, transformation_cache, state_manager):
3645 """Creates a lookup table for the vocabulary."""
3646 input_tensor = _to_sparse_input_and_drop_ignore_values(
3647 transformation_cache.get(self.key, state_manager))
3648 return self._transform_input_tensor(input_tensor, state_manager)
3650 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3651 _FEATURE_COLUMN_DEPRECATION)
3652 def _transform_feature(self, inputs):
3653 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3654 return self._transform_input_tensor(input_tensor)
3656 @property
3657 def num_buckets(self):
3658 """Returns number of buckets in this sparse feature."""
3659 return self.vocabulary_size + self.num_oov_buckets
3661 @property
3662 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3663 _FEATURE_COLUMN_DEPRECATION)
3664 def _num_buckets(self):
3665 return self.num_buckets
3667 def get_sparse_tensors(self, transformation_cache, state_manager):
3668 """See `CategoricalColumn` base class."""
3669 return CategoricalColumn.IdWeightPair(
3670 transformation_cache.get(self, state_manager), None)
3672 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3673 _FEATURE_COLUMN_DEPRECATION)
3674 def _get_sparse_tensors(self,
3675 inputs,
3676 weight_collections=None,
3677 trainable=None):
3678 del weight_collections
3679 del trainable
3680 return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3682 @property
3683 def parents(self):
3684 """See 'FeatureColumn` base class."""
3685 return [self.key]
3687 def get_config(self):
3688 """See 'FeatureColumn` base class."""
3689 config = dict(zip(self._fields, self))
3690 config['dtype'] = self.dtype.name
3691 return config
3693 @classmethod
3694 def from_config(cls, config, custom_objects=None, columns_by_name=None):
3695 """See 'FeatureColumn` base class."""
3696 _check_config_keys(config, cls._fields)
3697 kwargs = _standardize_and_copy_config(config)
3698 kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3699 return cls(**kwargs)
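# Usage sketch (illustrative, not in the original source): normally obtained
# through the public `categorical_column_with_vocabulary_file` helper. The
# path and sizes are illustrative assumptions.
#
#   state = categorical_column_with_vocabulary_file(
#       key='state', vocabulary_file='/tmp/states.txt', vocabulary_size=50,
#       num_oov_buckets=5)
#   # num_buckets == 50 + 5: in-vocabulary ids come first, then OOV buckets.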
3702class VocabularyListCategoricalColumn(
3703 CategoricalColumn,
3704 fc_old._CategoricalColumn, # pylint: disable=protected-access
3705 collections.namedtuple(
3706 'VocabularyListCategoricalColumn',
3707 ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
3708):
3709 """See `categorical_column_with_vocabulary_list`."""
3711 @property
3712 def _is_v2_column(self):
3713 return True
3715 @property
3716 def name(self):
3717 """See `FeatureColumn` base class."""
3718 return self.key
3720 @property
3721 def parse_example_spec(self):
3722 """See `FeatureColumn` base class."""
3723 return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3725 @property
3726 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3727 _FEATURE_COLUMN_DEPRECATION)
3728 def _parse_example_spec(self):
3729 return self.parse_example_spec
3731 def _transform_input_tensor(self, input_tensor, state_manager=None):
3732 """Creates a lookup table for the vocabulary list."""
3733 if self.dtype.is_integer != input_tensor.dtype.is_integer:
3734 raise ValueError(
3735 'Column dtype and SparseTensors dtype must be compatible. '
3736 'key: {}, column dtype: {}, tensor dtype: {}'.format(
3737 self.key, self.dtype, input_tensor.dtype))
3739 fc_utils.assert_string_or_int(
3740 input_tensor.dtype,
3741 prefix='column_name: {} input_tensor'.format(self.key))
3743 key_dtype = self.dtype
3744 if input_tensor.dtype.is_integer:
3745 # `index_table_from_tensor` requires 64-bit integer keys.
3746 key_dtype = dtypes.int64
3747 input_tensor = math_ops.cast(input_tensor, dtypes.int64)
3749 name = '{}_lookup'.format(self.key)
3750 if state_manager is None or not state_manager.has_resource(self, name):
3751 with ops.init_scope():
3752 table = lookup_ops.index_table_from_tensor(
3753 vocabulary_list=tuple(self.vocabulary_list),
3754 default_value=self.default_value,
3755 num_oov_buckets=self.num_oov_buckets,
3756 dtype=key_dtype,
3757 name=name)
3758 if state_manager is not None:
3759 state_manager.add_resource(self, name, table)
3760 else:
3761 # Reuse the table already created for this column.
3762 table = state_manager.get_resource(self, name)
3763 return table.lookup(input_tensor)
3765 def transform_feature(self, transformation_cache, state_manager):
3766 """Creates a lookup table for the vocabulary list."""
3767 input_tensor = _to_sparse_input_and_drop_ignore_values(
3768 transformation_cache.get(self.key, state_manager))
3769 return self._transform_input_tensor(input_tensor, state_manager)
3771 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3772 _FEATURE_COLUMN_DEPRECATION)
3773 def _transform_feature(self, inputs):
3774 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3775 return self._transform_input_tensor(input_tensor)
3777 @property
3778 def num_buckets(self):
3779 """Returns number of buckets in this sparse feature."""
3780 return len(self.vocabulary_list) + self.num_oov_buckets
3782 @property
3783 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3784 _FEATURE_COLUMN_DEPRECATION)
3785 def _num_buckets(self):
3786 return self.num_buckets
3788 def get_sparse_tensors(self, transformation_cache, state_manager):
3789 """See `CategoricalColumn` base class."""
3790 return CategoricalColumn.IdWeightPair(
3791 transformation_cache.get(self, state_manager), None)
3793 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3794 _FEATURE_COLUMN_DEPRECATION)
3795 def _get_sparse_tensors(self,
3796 inputs,
3797 weight_collections=None,
3798 trainable=None):
3799 del weight_collections
3800 del trainable
3801 return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3803 @property
3804 def parents(self):
3805 """See 'FeatureColumn` base class."""
3806 return [self.key]
3808 def get_config(self):
3809 """See 'FeatureColumn` base class."""
3810 config = dict(zip(self._fields, self))
3811 config['dtype'] = self.dtype.name
3812 return config
3814 @classmethod
3815 def from_config(cls, config, custom_objects=None, columns_by_name=None):
3816 """See 'FeatureColumn` base class."""
3817 _check_config_keys(config, cls._fields)
3818 kwargs = _standardize_and_copy_config(config)
3819 kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3820 return cls(**kwargs)
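# Usage sketch (illustrative, not in the original source): normally obtained
# through the public `categorical_column_with_vocabulary_list` helper. The
# vocabulary is illustrative.
#
#   dept = categorical_column_with_vocabulary_list(
#       'department', ['math', 'philosophy', 'english'], num_oov_buckets=2)
#   # num_buckets == len(vocabulary_list) + num_oov_buckets == 5.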
3823class IdentityCategoricalColumn(
3824 CategoricalColumn,
3825 fc_old._CategoricalColumn, # pylint: disable=protected-access
3826 collections.namedtuple('IdentityCategoricalColumn',
3827 ('key', 'number_buckets', 'default_value'))):
3828 """See `categorical_column_with_identity`."""
3830 @property
3831 def _is_v2_column(self):
3832 return True
3834 @property
3835 def name(self):
3836 """See `FeatureColumn` base class."""
3837 return self.key
3839 @property
3840 def parse_example_spec(self):
3841 """See `FeatureColumn` base class."""
3842 return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
3844 @property
3845 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3846 _FEATURE_COLUMN_DEPRECATION)
3847 def _parse_example_spec(self):
3848 return self.parse_example_spec
3850 def _transform_input_tensor(self, input_tensor):
3851 """Returns a SparseTensor with identity values."""
3852 if not input_tensor.dtype.is_integer:
3853 raise ValueError('Invalid input, not integer. key: {} dtype: {}'.format(
3854 self.key, input_tensor.dtype))
3855 values = input_tensor.values
3856 if input_tensor.values.dtype != dtypes.int64:
3857 values = math_ops.cast(values, dtypes.int64, name='values')
3858 if self.default_value is not None:
3859 values = math_ops.cast(input_tensor.values, dtypes.int64, name='values')
3860 num_buckets = math_ops.cast(
3861 self.num_buckets, dtypes.int64, name='num_buckets')
3862 zero = math_ops.cast(0, dtypes.int64, name='zero')
3863 # Assign default for out-of-range values.
3864 values = array_ops.where_v2(
3865 math_ops.logical_or(
3866 values < zero, values >= num_buckets, name='out_of_range'),
3867 array_ops.fill(
3868 dims=array_ops.shape(values),
3869 value=math_ops.cast(self.default_value, dtypes.int64),
3870 name='default_values'), values)
3872 return sparse_tensor_lib.SparseTensor(
3873 indices=input_tensor.indices,
3874 values=values,
3875 dense_shape=input_tensor.dense_shape)
3877 def transform_feature(self, transformation_cache, state_manager):
3878 """Returns a SparseTensor with identity values."""
3879 input_tensor = _to_sparse_input_and_drop_ignore_values(
3880 transformation_cache.get(self.key, state_manager))
3881 return self._transform_input_tensor(input_tensor)
3883 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3884 _FEATURE_COLUMN_DEPRECATION)
3885 def _transform_feature(self, inputs):
3886 input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3887 return self._transform_input_tensor(input_tensor)
3889 @property
3890 def num_buckets(self):
3891 """Returns number of buckets in this sparse feature."""
3892 return self.number_buckets
3894 @property
3895 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3896 _FEATURE_COLUMN_DEPRECATION)
3897 def _num_buckets(self):
3898 return self.num_buckets
3900 def get_sparse_tensors(self, transformation_cache, state_manager):
3901 """See `CategoricalColumn` base class."""
3902 return CategoricalColumn.IdWeightPair(
3903 transformation_cache.get(self, state_manager), None)
3905 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3906 _FEATURE_COLUMN_DEPRECATION)
3907 def _get_sparse_tensors(self,
3908 inputs,
3909 weight_collections=None,
3910 trainable=None):
3911 del weight_collections
3912 del trainable
3913 return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3915 @property
3916 def parents(self):
3917 """See 'FeatureColumn` base class."""
3918 return [self.key]
3920 def get_config(self):
3921 """See 'FeatureColumn` base class."""
3922 return dict(zip(self._fields, self))
3924 @classmethod
3925 def from_config(cls, config, custom_objects=None, columns_by_name=None):
3926 """See 'FeatureColumn` base class."""
3927 _check_config_keys(config, cls._fields)
3928 kwargs = _standardize_and_copy_config(config)
3929 return cls(**kwargs)
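# Usage sketch (illustrative, not in the original source): normally obtained
# through the public `categorical_column_with_identity` helper. With
# `default_value` set, out-of-range ids are replaced instead of rejected;
# name and sizes are illustrative.
#
#   video = categorical_column_with_identity(
#       'video_id', num_buckets=1000, default_value=0)
#   # ids in [0, 1000) pass through unchanged; an id such as 1005 is mapped
#   # to 0 by `_transform_input_tensor`.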
3932class WeightedCategoricalColumn(
3933 CategoricalColumn,
3934 fc_old._CategoricalColumn, # pylint: disable=protected-access
3935 collections.namedtuple(
3936 'WeightedCategoricalColumn',
3937 ('categorical_column', 'weight_feature_key', 'dtype'))):
3938 """See `weighted_categorical_column`."""
3940 @property
3941 def _is_v2_column(self):
3942 return (isinstance(self.categorical_column, FeatureColumn) and
3943 self.categorical_column._is_v2_column) # pylint: disable=protected-access
3945 @property
3946 def name(self):
3947 """See `FeatureColumn` base class."""
3948 return '{}_weighted_by_{}'.format(self.categorical_column.name,
3949 self.weight_feature_key)
3951 @property
3952 def parse_example_spec(self):
3953 """See `FeatureColumn` base class."""
3954 config = self.categorical_column.parse_example_spec
3955 if self.weight_feature_key in config:
3956 raise ValueError('Parse config {} already exists for {}.'.format(
3957 config[self.weight_feature_key], self.weight_feature_key))
3958 config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
3959 return config
3961 @property
3962 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3963 _FEATURE_COLUMN_DEPRECATION)
3964 def _parse_example_spec(self):
3965 config = self.categorical_column._parse_example_spec # pylint: disable=protected-access
3966 if self.weight_feature_key in config:
3967 raise ValueError('Parse config {} already exists for {}.'.format(
3968 config[self.weight_feature_key], self.weight_feature_key))
3969 config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
3970 return config
3972 @property
3973 def num_buckets(self):
3974 """See `DenseColumn` base class."""
3975 return self.categorical_column.num_buckets
3977 @property
3978 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3979 _FEATURE_COLUMN_DEPRECATION)
3980 def _num_buckets(self):
3981 return self.categorical_column._num_buckets # pylint: disable=protected-access
3983 def _transform_weight_tensor(self, weight_tensor):
3984 if weight_tensor is None:
3985 raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
3986 weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
3987 weight_tensor)
3988 if self.dtype != weight_tensor.dtype.base_dtype:
3989 raise ValueError('Bad dtype, expected {}, but got {}.'.format(
3990 self.dtype, weight_tensor.dtype))
3991 if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
3992 # The weight tensor can be a regular Tensor. In this case, sparsify it.
3993 weight_tensor = _to_sparse_input_and_drop_ignore_values(
3994 weight_tensor, ignore_value=0.0)
3995 if not weight_tensor.dtype.is_floating:
3996 weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
3997 return weight_tensor
3999 def transform_feature(self, transformation_cache, state_manager):
4000 """Applies weights to tensor generated from `categorical_column`'."""
4001 weight_tensor = transformation_cache.get(self.weight_feature_key,
4002 state_manager)
4003 sparse_weight_tensor = self._transform_weight_tensor(weight_tensor)
4004 sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values(
4005 transformation_cache.get(self.categorical_column, state_manager))
4006 return (sparse_categorical_tensor, sparse_weight_tensor)
4008 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4009 _FEATURE_COLUMN_DEPRECATION)
4010 def _transform_feature(self, inputs):
4011 """Applies weights to tensor generated from `categorical_column`'."""
4012 weight_tensor = inputs.get(self.weight_feature_key)
4013 weight_tensor = self._transform_weight_tensor(weight_tensor)
4014 return (inputs.get(self.categorical_column), weight_tensor)
4016 def get_sparse_tensors(self, transformation_cache, state_manager):
4017 """See `CategoricalColumn` base class."""
4018 tensors = transformation_cache.get(self, state_manager)
4019 return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
4021 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4022 _FEATURE_COLUMN_DEPRECATION)
4023 def _get_sparse_tensors(self,
4024 inputs,
4025 weight_collections=None,
4026 trainable=None):
4027 del weight_collections
4028 del trainable
4029 tensors = inputs.get(self)
4030 return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
4032 @property
4033 def parents(self):
4034 """See 'FeatureColumn` base class."""
4035 return [self.categorical_column, self.weight_feature_key]
4037 def get_config(self):
4038 """See 'FeatureColumn` base class."""
4039 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
4040 config = dict(zip(self._fields, self))
4041 config['categorical_column'] = serialize_feature_column(
4042 self.categorical_column)
4043 config['dtype'] = self.dtype.name
4044 return config
4046 @classmethod
4047 def from_config(cls, config, custom_objects=None, columns_by_name=None):
4048 """See 'FeatureColumn` base class."""
4049 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
4050 _check_config_keys(config, cls._fields)
4051 kwargs = _standardize_and_copy_config(config)
4052 kwargs['categorical_column'] = deserialize_feature_column(
4053 config['categorical_column'], custom_objects, columns_by_name)
4054 kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
4055 return cls(**kwargs)
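# Usage sketch (illustrative, not in the original source): normally obtained
# through the public `weighted_categorical_column` helper. The second
# feature key supplies a per-id weight; names are illustrative.
#
#   terms = categorical_column_with_vocabulary_list('terms', ['a', 'b', 'c'])
#   weighted_terms = weighted_categorical_column(
#       categorical_column=terms, weight_feature_key='term_frequencies')
#   # get_sparse_tensors now returns an IdWeightPair whose weight tensor is
#   # read from the 'term_frequencies' feature.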
4058class CrossedColumn(
4059 CategoricalColumn,
4060 fc_old._CategoricalColumn, # pylint: disable=protected-access
4061 collections.namedtuple('CrossedColumn',
4062 ('keys', 'hash_bucket_size', 'hash_key'))):
4063 """See `crossed_column`."""
4065 @property
4066 def _is_v2_column(self):
4067 for key in _collect_leaf_level_keys(self):
4068 if isinstance(key, six.string_types):
4069 continue
4070 if not isinstance(key, FeatureColumn):
4071 return False
4072 if not key._is_v2_column: # pylint: disable=protected-access
4073 return False
4074 return True
4076 @property
4077 def name(self):
4078 """See `FeatureColumn` base class."""
4079 feature_names = []
4080 for key in _collect_leaf_level_keys(self):
4081 if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)): # pylint: disable=protected-access
4082 feature_names.append(key.name)
4083 else: # key must be a string
4084 feature_names.append(key)
4085 return '_X_'.join(sorted(feature_names))
4087 @property
4088 def parse_example_spec(self):
4089 """See `FeatureColumn` base class."""
4090 config = {}
4091 for key in self.keys:
4092 if isinstance(key, FeatureColumn):
4093 config.update(key.parse_example_spec)
4094 elif isinstance(key, fc_old._FeatureColumn): # pylint: disable=protected-access
4095 config.update(key._parse_example_spec) # pylint: disable=protected-access
4096 else: # key must be a string
4097 config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
4098 return config
4100 @property
4101 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4102 _FEATURE_COLUMN_DEPRECATION)
4103 def _parse_example_spec(self):
4104 return self.parse_example_spec
4106 def transform_feature(self, transformation_cache, state_manager):
4107 """Generates a hashed sparse cross from the input tensors."""
4108 feature_tensors = []
4109 for key in _collect_leaf_level_keys(self):
4110 if isinstance(key, six.string_types):
4111 feature_tensors.append(transformation_cache.get(key, state_manager))
4112 elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)): # pylint: disable=protected-access
4113 ids_and_weights = key.get_sparse_tensors(transformation_cache,
4114 state_manager)
4115 if ids_and_weights.weight_tensor is not None:
4116 raise ValueError(
4117 'crossed_column does not support weight_tensor, but the given '
4118 'column populates weight_tensor. '
4119 'Given column: {}'.format(key.name))
4120 feature_tensors.append(ids_and_weights.id_tensor)
4121 else:
4122 raise ValueError('Unsupported column type. Given: {}'.format(key))
4123 return sparse_ops.sparse_cross_hashed(
4124 inputs=feature_tensors,
4125 num_buckets=self.hash_bucket_size,
4126 hash_key=self.hash_key)
4128 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4129 _FEATURE_COLUMN_DEPRECATION)
4130 def _transform_feature(self, inputs):
4131 """Generates a hashed sparse cross from the input tensors."""
4132 feature_tensors = []
4133 for key in _collect_leaf_level_keys(self):
4134 if isinstance(key, six.string_types):
4135 feature_tensors.append(inputs.get(key))
4136 elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access
4137 ids_and_weights = key._get_sparse_tensors(inputs) # pylint: disable=protected-access
4138 if ids_and_weights.weight_tensor is not None:
4139 raise ValueError(
4140 'crossed_column does not support weight_tensor, but the given '
4141 'column populates weight_tensor. '
4142 'Given column: {}'.format(key.name))
4143 feature_tensors.append(ids_and_weights.id_tensor)
4144 else:
4145 raise ValueError('Unsupported column type. Given: {}'.format(key))
4146 return sparse_ops.sparse_cross_hashed(
4147 inputs=feature_tensors,
4148 num_buckets=self.hash_bucket_size,
4149 hash_key=self.hash_key)
4151 @property
4152 def num_buckets(self):
4153 """Returns number of buckets in this sparse feature."""
4154 return self.hash_bucket_size
4156 @property
4157 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4158 _FEATURE_COLUMN_DEPRECATION)
4159 def _num_buckets(self):
4160 return self.num_buckets
4162 def get_sparse_tensors(self, transformation_cache, state_manager):
4163 """See `CategoricalColumn` base class."""
4164 return CategoricalColumn.IdWeightPair(
4165 transformation_cache.get(self, state_manager), None)
4167 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4168 _FEATURE_COLUMN_DEPRECATION)
4169 def _get_sparse_tensors(self,
4170 inputs,
4171 weight_collections=None,
4172 trainable=None):
4173 """See `CategoricalColumn` base class."""
4174 del weight_collections
4175 del trainable
4176 return CategoricalColumn.IdWeightPair(inputs.get(self), None)
4178 @property
4179 def parents(self):
4180 """See 'FeatureColumn` base class."""
4181 return list(self.keys)
4183 def get_config(self):
4184 """See 'FeatureColumn` base class."""
4185 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
4186 config = dict(zip(self._fields, self))
4187 config['keys'] = tuple([serialize_feature_column(fc) for fc in self.keys])
4188 return config
4190 @classmethod
4191 def from_config(cls, config, custom_objects=None, columns_by_name=None):
4192 """See 'FeatureColumn` base class."""
4193 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
4194 _check_config_keys(config, cls._fields)
4195 kwargs = _standardize_and_copy_config(config)
4196 kwargs['keys'] = tuple([
4197 deserialize_feature_column(c, custom_objects, columns_by_name)
4198 for c in config['keys']
4199 ])
4200 return cls(**kwargs)
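# Usage sketch (illustrative, not in the original source): normally obtained
# through the public `crossed_column` helper. Keys may be raw string feature
# keys or other categorical columns; names are illustrative and
# `age_buckets` assumes an existing `BucketizedColumn`.
#
#   cross = crossed_column(['department', age_buckets], hash_bucket_size=1000)
#   # Each (department, age-bucket) pair is hashed into [0, 1000), so
#   # num_buckets == 1000.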
4203def _collect_leaf_level_keys(cross):
4204 """Collects base keys by expanding all nested crosses.
4206 Args:
4207 cross: A `CrossedColumn`.
4209 Returns:
4210 A list of strings or `CategoricalColumn` instances.
4211 """
4212 leaf_level_keys = []
4213 for k in cross.keys:
4214 if isinstance(k, CrossedColumn):
4215 leaf_level_keys.extend(_collect_leaf_level_keys(k))
4216 else:
4217 leaf_level_keys.append(k)
4218 return leaf_level_keys
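# Illustrative expansion example (not in the original source): nested
# crosses are expanded recursively. With
#   cross_a = crossed_column(['a', 'b'], hash_bucket_size=10)
#   cross_b = crossed_column([cross_a, 'c'], hash_bucket_size=10)
# _collect_leaf_level_keys(cross_b) returns ['a', 'b', 'c'].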
4221def _prune_invalid_ids(sparse_ids, sparse_weights):
4222 """Prune invalid IDs (< 0) from the input ids and weights."""
4223 is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
4224 if sparse_weights is not None:
4225 is_id_valid = math_ops.logical_and(
4226 is_id_valid,
4227 array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
4228 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
4229 if sparse_weights is not None:
4230 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
4231 return sparse_ids, sparse_weights
4234def _prune_invalid_weights(sparse_ids, sparse_weights):
4235 """Prune invalid weights (< 0) from the input ids and weights."""
4236 if sparse_weights is not None:
4237 is_weights_valid = math_ops.greater(sparse_weights.values, 0)
4238 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
4239 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
4240 return sparse_ids, sparse_weights
4243class IndicatorColumn(
4244 DenseColumn,
4245 SequenceDenseColumn,
4246 fc_old._DenseColumn, # pylint: disable=protected-access
4247 fc_old._SequenceDenseColumn, # pylint: disable=protected-access
4248 collections.namedtuple('IndicatorColumn', ('categorical_column'))):
4249 """Represents a one-hot column for use in deep networks.
4251 Args:
4252 categorical_column: A `CategoricalColumn` created by a
4253 `categorical_column_with_*` function.
4254 """
4256 @property
4257 def _is_v2_column(self):
4258 return (isinstance(self.categorical_column, FeatureColumn) and
4259 self.categorical_column._is_v2_column) # pylint: disable=protected-access
4261 @property
4262 def name(self):
4263 """See `FeatureColumn` base class."""
4264 return '{}_indicator'.format(self.categorical_column.name)
4266 def _transform_id_weight_pair(self, id_weight_pair, size):
4267 id_tensor = id_weight_pair.id_tensor
4268 weight_tensor = id_weight_pair.weight_tensor
4270 # If the underlying column is weighted, return the input as a dense tensor.
4271 if weight_tensor is not None:
4272 weighted_column = sparse_ops.sparse_merge(
4273 sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(size))
4274 # Remove (?, -1) index.
4275 weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
4276 weighted_column.dense_shape)
4277 # Use scatter_nd instead of sparse_tensor_to_dense so that any
4278 # duplicated indices are merged (their values summed).
4279 return array_ops.scatter_nd(weighted_column.indices,
4280 weighted_column.values,
4281 weighted_column.dense_shape)
4283 dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
4284 id_tensor, default_value=-1)
4286 # The one-hot tensor must be float so tf.concat can combine it with the
4287 # other float32 inputs to input_layer.
4288 one_hot_id_tensor = array_ops.one_hot(
4289 dense_id_tensor, depth=size, on_value=1.0, off_value=0.0)
4291 # Reduce to get a multi-hot per example.
4292 return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
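# Illustrative worked example (not in the original source) for
# `_transform_id_weight_pair` above: with unweighted id rows
# [[1, 2], [2, 2]] and size=4, the one-hot tensors are summed over the id
# axis, giving multi-hot rows [[0., 1., 1., 0.], [0., 0., 2., 0.]] --
# repeated ids add up rather than saturating at 1.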
4294 def transform_feature(self, transformation_cache, state_manager):
4295 """Returns dense `Tensor` representing feature.
4297 Args:
4298 transformation_cache: A `FeatureTransformationCache` object to access
4299 features.
4300 state_manager: A `StateManager` to create / access resources such as
4301 lookup tables.
4303 Returns:
4304 Transformed feature `Tensor`.
4306 Raises:
4307 ValueError: if input rank is not known at graph building time.
4308 """
4309 id_weight_pair = self.categorical_column.get_sparse_tensors(
4310 transformation_cache, state_manager)
4311 return self._transform_id_weight_pair(id_weight_pair,
4312 self.variable_shape[-1])
4314 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4315 _FEATURE_COLUMN_DEPRECATION)
4316 def _transform_feature(self, inputs):
4317 id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
4318 return self._transform_id_weight_pair(id_weight_pair,
4319 self._variable_shape[-1])
4321 @property
4322 def parse_example_spec(self):
4323 """See `FeatureColumn` base class."""
4324 return self.categorical_column.parse_example_spec
4326 @property
4327 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4328 _FEATURE_COLUMN_DEPRECATION)
4329 def _parse_example_spec(self):
4330 return self.categorical_column._parse_example_spec # pylint: disable=protected-access
4332 @property
4333 def variable_shape(self):
4334 """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
4335 if isinstance(self.categorical_column, FeatureColumn):
4336 return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
4337 else:
4338 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access
4340 @property
4341 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4342 _FEATURE_COLUMN_DEPRECATION)
4343 def _variable_shape(self):
4344 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access
4346 def get_dense_tensor(self, transformation_cache, state_manager):
4347 """Returns dense `Tensor` representing feature.
4349 Args:
4350 transformation_cache: A `FeatureTransformationCache` object to access
4351 features.
4352 state_manager: A `StateManager` to create / access resources such as
4353 lookup tables.
4355 Returns:
4356 Dense `Tensor` created within `transform_feature`.
4358 Raises:
4359 ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
4360 """
4361 if isinstance(self.categorical_column, SequenceCategoricalColumn):
4362 raise ValueError(
4363 'In indicator_column: {}. '
4364 'categorical_column must not be of type SequenceCategoricalColumn. '
4365 'Suggested fix A: If you wish to use DenseFeatures, use a '
4366 'non-sequence categorical_column_with_*. '
4367 'Suggested fix B: If you wish to create sequence input, use '
4368 'SequenceFeatures instead of DenseFeatures. '
4369 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
4370 self.categorical_column))
4371 # The feature has already been transformed. Return the intermediate
4372 # representation created by transform_feature.
4373 return transformation_cache.get(self, state_manager)
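
# Standalone sketch (not part of this module): driving get_dense_tensor above
# through the public API; the feature name and vocabulary are illustrative.
import tensorflow as tf

dept = tf.feature_column.categorical_column_with_vocabulary_list(
    'department', ['math', 'philosophy', 'english'])
indicator = tf.feature_column.indicator_column(dept)
layer = tf.keras.layers.DenseFeatures([indicator])
features = {'department': tf.constant([['math'], ['english']])}
dense = layer(features)  # float32 multi-hot of shape [2, 3]
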
4375 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4376 _FEATURE_COLUMN_DEPRECATION)
4377 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
4378 del weight_collections
4379 del trainable
4380 if isinstance(
4381 self.categorical_column,
4382 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
4383 raise ValueError(
4384 'In indicator_column: {}. '
4385 'categorical_column must not be of type _SequenceCategoricalColumn. '
4386 'Suggested fix A: If you wish to use DenseFeatures, use a '
4387 'non-sequence categorical_column_with_*. '
4388 'Suggested fix B: If you wish to create sequence input, use '
4389 'SequenceFeatures instead of DenseFeatures. '
4390 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
4391 self.categorical_column))
4392 # The feature has already been transformed. Return the intermediate
4393 # representation created by transform_feature.
4394 return inputs.get(self)
4396 def get_sequence_dense_tensor(self, transformation_cache, state_manager):
4397 """See `SequenceDenseColumn` base class."""
4398 if not isinstance(self.categorical_column, SequenceCategoricalColumn):
4399 raise ValueError(
4400 'In indicator_column: {}. '
4401 'categorical_column must be of type SequenceCategoricalColumn '
4402 'to use SequenceFeatures. '
4403 'Suggested fix: Use one of sequence_categorical_column_with_*. '
4404 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
4405 self.categorical_column))
4406 # The feature has already been transformed. Return the intermediate
4407 # representation created by transform_feature.
4408 dense_tensor = transformation_cache.get(self, state_manager)
4409 sparse_tensors = self.categorical_column.get_sparse_tensors(
4410 transformation_cache, state_manager)
4411 sequence_length = fc_utils.sequence_length_from_sparse_tensor(
4412 sparse_tensors.id_tensor)
4413 return SequenceDenseColumn.TensorSequenceLengthPair(
4414 dense_tensor=dense_tensor, sequence_length=sequence_length)
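
# Standalone sketch (not part of this module): the sequence path above via
# SequenceFeatures; the feature values and shapes are illustrative.
import tensorflow as tf

seq_dept = tf.feature_column.sequence_categorical_column_with_vocabulary_list(
    'department', ['math', 'philosophy', 'english'])
seq_indicator = tf.feature_column.indicator_column(seq_dept)
layer = tf.keras.experimental.SequenceFeatures([seq_indicator])
features = {
    'department': tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=['math', 'english', 'philosophy'],
        dense_shape=[2, 2])
}
dense, seq_len = layer(features)  # dense: [2, 2, 3]; seq_len: [2, 1]
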
4416 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4417 _FEATURE_COLUMN_DEPRECATION)
4418 def _get_sequence_dense_tensor(self,
4419 inputs,
4420 weight_collections=None,
4421 trainable=None):
4422 # Do nothing with weight_collections and trainable since no variables are
4423 # created in this function.
4424 del weight_collections
4425 del trainable
4426 if not isinstance(
4427 self.categorical_column,
4428 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
4429 raise ValueError(
4430 'In indicator_column: {}. '
4431 'categorical_column must be of type _SequenceCategoricalColumn '
4432 'to use SequenceFeatures. '
4433 'Suggested fix: Use one of sequence_categorical_column_with_*. '
4434 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
4435 self.categorical_column))
4436 # The feature has already been transformed. Return the intermediate
4437 # representation created by _transform_feature.
4438 dense_tensor = inputs.get(self)
4439 sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
4440 sequence_length = fc_utils.sequence_length_from_sparse_tensor(
4441 sparse_tensors.id_tensor)
4442 return SequenceDenseColumn.TensorSequenceLengthPair(
4443 dense_tensor=dense_tensor, sequence_length=sequence_length)
4445 @property
4446 def parents(self):
4447 """See 'FeatureColumn` base class."""
4448 return [self.categorical_column]
4450 def get_config(self):
4451 """See 'FeatureColumn` base class."""
4452 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
4453 config = dict(zip(self._fields, self))
4454 config['categorical_column'] = serialize_feature_column(
4455 self.categorical_column)
4456 return config
4458 @classmethod
4459 def from_config(cls, config, custom_objects=None, columns_by_name=None):
4460 """See 'FeatureColumn` base class."""
4461 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
4462 _check_config_keys(config, cls._fields)
4463 kwargs = _standardize_and_copy_config(config)
4464 kwargs['categorical_column'] = deserialize_feature_column(
4465 config['categorical_column'], custom_objects, columns_by_name)
4466 return cls(**kwargs)
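
# Standalone sketch (not part of this module): round-tripping a column
# through get_config / from_config via the serialization helpers this module
# imports; identifiers are illustrative.
import tensorflow as tf
from tensorflow.python.feature_column import serialization

dept = tf.feature_column.categorical_column_with_vocabulary_list(
    'department', ['math', 'philosophy'])
indicator = tf.feature_column.indicator_column(dept)
config = serialization.serialize_feature_column(indicator)
restored = serialization.deserialize_feature_column(config)
# restored is an IndicatorColumn equal to the original field-by-field.
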
4469def _verify_static_batch_size_equality(tensors, columns):
4470 """Verify equality between static batch sizes.
4472 Args:
4473 tensors: iterable of input tensors.
4474 columns: Corresponding feature columns.
4476 Raises:
4477 ValueError: in case of mismatched batch sizes.
4478 """
4479 # batch_size is a Dimension object.
4480 expected_batch_size = None
4481 for i in range(0, len(tensors)):
4482 batch_size = tensor_shape.Dimension(
4483 tensor_shape.dimension_value(tensors[i].shape[0]))
4484 if batch_size.value is not None:
4485 if expected_batch_size is None:
4486 batch_size_column_index = i
4487 expected_batch_size = batch_size
4488 elif not expected_batch_size.is_compatible_with(batch_size):
4489 raise ValueError(
4490 'Batch size (first dimension) of each feature must be the same. '
4491 'Batch size of columns ({}, {}): ({}, {})'.format(
4492 columns[batch_size_column_index].name, columns[i].name,
4493 expected_batch_size, batch_size))
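
# Standalone sketch (not part of this module): the Dimension compatibility
# rule the check above relies on. Known sizes must match exactly, while an
# unknown (None) batch size is compatible with anything and is skipped.
from tensorflow.python.framework import tensor_shape

assert not tensor_shape.Dimension(2).is_compatible_with(
    tensor_shape.Dimension(3))     # mismatch -> ValueError above
assert tensor_shape.Dimension(2).is_compatible_with(
    tensor_shape.Dimension(None))  # unknown sizes never fail the check
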
4496class SequenceCategoricalColumn(
4497 CategoricalColumn,
4498 fc_old._SequenceCategoricalColumn, # pylint: disable=protected-access
4499 collections.namedtuple('SequenceCategoricalColumn',
4500 ('categorical_column',))):
4501 """Represents sequences of categorical data."""
4503 @property
4504 def _is_v2_column(self):
4505 return (isinstance(self.categorical_column, FeatureColumn) and
4506 self.categorical_column._is_v2_column) # pylint: disable=protected-access
4508 @property
4509 def name(self):
4510 """See `FeatureColumn` base class."""
4511 return self.categorical_column.name
4513 @property
4514 def parse_example_spec(self):
4515 """See `FeatureColumn` base class."""
4516 return self.categorical_column.parse_example_spec
4518 @property
4519 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4520 _FEATURE_COLUMN_DEPRECATION)
4521 def _parse_example_spec(self):
4522 return self.categorical_column._parse_example_spec # pylint: disable=protected-access
4524 def transform_feature(self, transformation_cache, state_manager):
4525 """See `FeatureColumn` base class."""
4526 return self.categorical_column.transform_feature(transformation_cache,
4527 state_manager)
4529 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4530 _FEATURE_COLUMN_DEPRECATION)
4531 def _transform_feature(self, inputs):
4532 return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access
4534 @property
4535 def num_buckets(self):
4536 """Returns number of buckets in this sparse feature."""
4537 return self.categorical_column.num_buckets
4539 @property
4540 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4541 _FEATURE_COLUMN_DEPRECATION)
4542 def _num_buckets(self):
4543 return self.categorical_column._num_buckets # pylint: disable=protected-access
4545 def _get_sparse_tensors_helper(self, sparse_tensors):
4546 id_tensor = sparse_tensors.id_tensor
4547 weight_tensor = sparse_tensors.weight_tensor
4548 # Expand the third dimension, if necessary, so that embeddings are not
4549 # combined during the embedding lookup. If the tensor is already 3D,
4550 # leave it as-is.
4551 shape = array_ops.shape(id_tensor)
4552 # Compute the third dimension explicitly instead of setting it to -1,
4553 # since -1 fails for dynamically shaped tensors that have length 0 at
4554 # runtime, as happens for empty sequences.
4555 target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
4556 id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
4557 if weight_tensor is not None:
4558 weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
4559 return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
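
# Standalone sketch (not part of this module): the rank-3 expansion above on
# a [batch, max_len] id tensor with illustrative values. reduce_prod over the
# empty trailing dims yields 1, so the target shape is computed without -1,
# which would fail for sequences that are empty at runtime.
import tensorflow as tf

ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]], values=[3, 1, 2], dense_shape=[2, 2])
shape = tf.shape(ids, out_type=tf.int64)
target = [shape[0], shape[1], tf.reduce_prod(shape[2:])]
ids_3d = tf.sparse.reshape(ids, target)  # dense_shape becomes [2, 2, 1]
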
4561 def get_sparse_tensors(self, transformation_cache, state_manager):
4562 """Returns an IdWeightPair.
4564 `IdWeightPair` is a pair of `SparseTensor`s representing ids and
4565 weights.
4567 `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
4568 `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
4569 `SparseTensor` of `float` or `None`, in which case all weights are taken
4570 to be 1. If specified, `weight_tensor` must have exactly the same shape
4571 and indices as `id_tensor`. The expected `SparseTensor` is the same as the
4572 parsing output of a `VarLenFeature`, which is a ragged matrix.
4574 Args:
4575 transformation_cache: A `FeatureTransformationCache` object to access
4576 features.
4577 state_manager: A `StateManager` to create / access resources such as
4578 lookup tables.
4579 """
4580 sparse_tensors = self.categorical_column.get_sparse_tensors(
4581 transformation_cache, state_manager)
4582 return self._get_sparse_tensors_helper(sparse_tensors)
4584 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
4585 _FEATURE_COLUMN_DEPRECATION)
4586 def _get_sparse_tensors(self,
4587 inputs,
4588 weight_collections=None,
4589 trainable=None):
4590 sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
4591 return self._get_sparse_tensors_helper(sparse_tensors)
4593 @property
4594 def parents(self):
4595 """See 'FeatureColumn` base class."""
4596 return [self.categorical_column]
4598 def get_config(self):
4599 """See 'FeatureColumn` base class."""
4600 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
4601 config = dict(zip(self._fields, self))
4602 config['categorical_column'] = serialize_feature_column(
4603 self.categorical_column)
4604 return config
4606 @classmethod
4607 def from_config(cls, config, custom_objects=None, columns_by_name=None):
4608 """See 'FeatureColumn` base class."""
4609 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
4610 _check_config_keys(config, cls._fields)
4611 kwargs = _standardize_and_copy_config(config)
4612 kwargs['categorical_column'] = deserialize_feature_column(
4613 config['categorical_column'], custom_objects, columns_by_name)
4614 return cls(**kwargs)
4617def _check_config_keys(config, expected_keys):
4618 """Checks that a config has all expected_keys."""
4619 if set(config.keys()) != set(expected_keys):
4620 raise ValueError('Invalid config: {}, expected keys: {}'.format(
4621 config, expected_keys))
4624def _standardize_and_copy_config(config):
4625 """Returns a shallow copy of config with lists turned to tuples.
4627 Keras serialization uses nest to listify everything.
4628 This causes problems with the NumericColumn shape, which becomes
4629 unhashable. We could try to solve this on the Keras side, but that
4630 would require lots of tracking to avoid changing existing behavior.
4631 Instead, we ensure here that we revive correctly.
4633 Args:
4634 config: dict that will be used to revive a Feature Column
4636 Returns:
4637 Shallow copy of config with lists turned to tuples.
4638 """
4639 kwargs = config.copy()
4640 for k, v in kwargs.items():
4641 if isinstance(v, list):
4642 kwargs[k] = tuple(v)
4644 return kwargs
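
# Standalone sketch (not part of this module): why the list -> tuple pass
# above matters, using a hypothetical NumericColumn-style config. Keras nest
# serialization listifies the shape, and lists are unhashable; tuples restore
# hashability on revival.
config = {'shape': [2, 3], 'default_value': None}
kwargs = {k: tuple(v) if isinstance(v, list) else v for k, v in config.items()}
assert kwargs['shape'] == (2, 3)
hash(kwargs['shape'])  # OK for a tuple; hash([2, 3]) would raise TypeError
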
4647def _sanitize_column_name_for_variable_scope(name):
4648 """Sanitizes user-provided feature names for use as variable scopes."""
4649 invalid_char = re.compile('[^A-Za-z0-9_.\\-]')
4650 return invalid_char.sub('_', name)
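
# Standalone sketch (not part of this module): the sanitization above applied
# to hypothetical feature names. Every character outside [A-Za-z0-9_.-] is
# replaced with an underscore so the name is a valid variable-scope name.
import re

_invalid_char = re.compile('[^A-Za-z0-9_.\\-]')
assert _invalid_char.sub('_', 'user:age bucket') == 'user_age_bucket'
assert _invalid_char.sub('_', 'dept/name') == 'dept_name'
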