Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/preprocessing/hashed_crossing.py: 27%
71 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras hashed crossing preprocessing layer."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.engine import base_layer
22from keras.src.engine import base_preprocessing_layer
23from keras.src.layers.preprocessing import preprocessing_utils as utils
24from keras.src.utils import layer_utils
26# isort: off
27from tensorflow.python.util.tf_export import keras_export
# Module-level aliases for the output-mode constants exposed by
# `preprocessing_utils`. The layer below compares `output_mode` against these
# when choosing a default dtype and when validating the argument.
INT = utils.INT
ONE_HOT = utils.ONE_HOT
@keras_export(
    "keras.layers.HashedCrossing",
    "keras.layers.experimental.preprocessing.HashedCrossing",
    v1=[],
)
class HashedCrossing(base_layer.Layer):
    """A preprocessing layer which crosses features using the "hashing trick".

    This layer performs crosses of categorical features using the "hashing
    trick". Conceptually, the transformation can be thought of as:
    hash(concatenation of features) % `num_bins`.

    This layer currently only performs crosses of scalar inputs and batches of
    scalar inputs. Valid input shapes are `(batch_size, 1)`, `(batch_size,)` and
    `()`.

    For an overview and full list of preprocessing layers, see the preprocessing
    [guide](https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Args:
        num_bins: Number of hash bins.
        output_mode: Specification for the output of the layer. Values can be
            `"int"`, or `"one_hot"` configuring the layer as follows:
            - `"int"`: Return the integer bin indices directly.
            - `"one_hot"`: Encodes each individual element in the input into an
                array the same size as `num_bins`, containing a 1 at the
                input's bin index.
            Defaults to `"int"`.
        sparse: Boolean. Only applicable to `"one_hot"` mode. If True, returns a
            `SparseTensor` instead of a dense `Tensor`. Defaults to `False`.
        **kwargs: Keyword arguments to construct a layer.

    Examples:

    **Crossing two scalar features.**

    >>> layer = tf.keras.layers.HashedCrossing(
    ...     num_bins=5)
    >>> feat1 = tf.constant(['A', 'B', 'A', 'B', 'A'])
    >>> feat2 = tf.constant([101, 101, 101, 102, 102])
    >>> layer((feat1, feat2))
    <tf.Tensor: shape=(5,), dtype=int64, numpy=array([1, 4, 1, 1, 3])>

    **Crossing and one-hotting two scalar features.**

    >>> layer = tf.keras.layers.HashedCrossing(
    ...     num_bins=5, output_mode='one_hot')
    >>> feat1 = tf.constant(['A', 'B', 'A', 'B', 'A'])
    >>> feat2 = tf.constant([101, 101, 101, 102, 102])
    >>> layer((feat1, feat2))
    <tf.Tensor: shape=(5, 5), dtype=float32, numpy=
      array([[0., 1., 0., 0., 0.],
             [0., 0., 0., 0., 1.],
             [0., 1., 0., 0., 0.],
             [0., 1., 0., 0., 0.],
             [0., 0., 0., 1., 0.]], dtype=float32)>
    """

    def __init__(self, num_bins, output_mode="int", sparse=False, **kwargs):
        """Stores the crossing configuration and validates `output_mode`.

        Args:
            num_bins: Number of hash bins; stored as `self.num_bins`.
            output_mode: Either `INT` or `ONE_HOT`; checked with
                `layer_utils.validate_string_arg`.
            sparse: Stored as `self.sparse` and forwarded to the encoding
                helper in `call`.
            **kwargs: Forwarded to the base layer constructor. A missing or
                `None` `dtype` is replaced by `tf.int64` for `INT` mode and by
                `backend.floatx()` otherwise.

        Raises:
            ValueError: If `output_mode` is `INT` but the layer's compute
                dtype is not an integer type.
        """
        # By default, output int64 when output_mode="int" and floats otherwise.
        if "dtype" not in kwargs or kwargs["dtype"] is None:
            kwargs["dtype"] = (
                tf.int64 if output_mode == INT else backend.floatx()
            )

        super().__init__(**kwargs)
        base_preprocessing_layer.keras_kpl_gauge.get_cell("HashedCrossing").set(
            True
        )

        # Check dtype only after base layer parses it; dtype parsing is complex.
        if (
            output_mode == INT
            and not tf.as_dtype(self.compute_dtype).is_integer
        ):
            # `kwargs["dtype"]` is always present here: it was either supplied
            # by the caller or filled in by the default above.
            input_dtype = kwargs["dtype"]
            raise ValueError(
                "When `output_mode='int'`, `dtype` should be an integer "
                f"type. Received: dtype={input_dtype}"
            )

        # "output_mode" must be one of (INT, ONE_HOT)
        layer_utils.validate_string_arg(
            output_mode,
            allowable_strings=(INT, ONE_HOT),
            layer_name=self.__class__.__name__,
            arg_name="output_mode",
        )

        self.num_bins = num_bins
        self.output_mode = output_mode
        self.sparse = sparse

    def call(self, inputs):
        """Crosses the given features and encodes the resulting bin indices.

        Args:
            inputs: A list or tuple of at least two inputs of equal shape
                (`[]`, `[batch_size]` or `[batch_size, 1]`) with integer or
                string dtype.

        Returns:
            The output of `utils.encode_categorical_inputs` applied to the
            hashed cross, using this layer's `output_mode`, `num_bins`,
            `sparse` and compute dtype.

        Raises:
            ValueError: If the inputs fail the count, shape, density or dtype
                checks.
        """
        # Convert all inputs to tensors and check shape. This layer only
        # supports scalars and batches of scalars for the initial version.
        self._check_at_least_two_inputs(inputs)
        inputs = [utils.ensure_tensor(x) for x in inputs]
        self._check_input_shape_and_type(inputs)

        # Uprank to rank 2 for the cross_hashed op.
        rank = inputs[0].shape.rank
        if rank < 2:
            inputs = [utils.expand_dims(x, -1) for x in inputs]
        if rank < 1:
            # Scalars need a second expansion to reach rank 2.
            inputs = [utils.expand_dims(x, -1) for x in inputs]

        # Perform the cross and convert to dense
        outputs = tf.sparse.cross_hashed(inputs, self.num_bins)
        outputs = tf.sparse.to_dense(outputs)

        # Fix output shape and downrank to match input rank.
        if rank == 2:
            # tf.sparse.cross_hashed output shape will always be None on the
            # last dimension. Given our input shape restrictions, we want to
            # force shape 1 instead.
            outputs = tf.reshape(outputs, [-1, 1])
        elif rank == 1:
            outputs = tf.reshape(outputs, [-1])
        elif rank == 0:
            outputs = tf.reshape(outputs, [])

        # Encode outputs.
        return utils.encode_categorical_inputs(
            outputs,
            output_mode=self.output_mode,
            depth=self.num_bins,
            sparse=self.sparse,
            dtype=self.compute_dtype,
        )

    def compute_output_shape(self, input_shapes):
        """Computes the output shape from the shape of the first input.

        Args:
            input_shapes: A list or tuple of at least two input shapes. Only
                the first is used, since `call` requires all inputs to share
                one shape.

        Returns:
            The shape produced by `utils.compute_shape_for_encode_categorical`
            for this layer's `output_mode` and `num_bins`.

        Raises:
            ValueError: If `input_shapes` is not a list/tuple of length >= 2.
        """
        self._check_at_least_two_inputs(input_shapes)
        # The encoded shape depends on the output mode and the one-hot depth,
        # so forward the same `output_mode`/`depth` that `call` passes to
        # `encode_categorical_inputs`; the input shape alone is not enough.
        return utils.compute_shape_for_encode_categorical(
            input_shapes[0],
            output_mode=self.output_mode,
            depth=self.num_bins,
        )

    def compute_output_signature(self, input_specs):
        """Builds the output spec for the given input specs.

        Args:
            input_specs: A list or tuple of specs exposing `.shape`.

        Returns:
            A `tf.SparseTensorSpec` when `self.sparse` is set or any input
            spec is a `tf.SparseTensorSpec`, otherwise a `tf.TensorSpec`. Both
            use the shape from `compute_output_shape` and the layer's compute
            dtype.
        """
        input_shapes = [x.shape.as_list() for x in input_specs]
        output_shape = self.compute_output_shape(input_shapes)
        if self.sparse or any(
            isinstance(x, tf.SparseTensorSpec) for x in input_specs
        ):
            return tf.SparseTensorSpec(
                shape=output_shape, dtype=self.compute_dtype
            )
        return tf.TensorSpec(shape=output_shape, dtype=self.compute_dtype)

    def get_config(self):
        """Returns the base layer config extended with this layer's options."""
        config = super().get_config()
        config.update(
            {
                "num_bins": self.num_bins,
                "output_mode": self.output_mode,
                "sparse": self.sparse,
            }
        )
        return config

    def _check_at_least_two_inputs(self, inputs):
        """Validates that `inputs` is a list or tuple with two or more items.

        Raises:
            ValueError: If `inputs` is not a list/tuple, or has fewer than two
                elements.
        """
        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                "`HashedCrossing` should be called on a list or tuple of "
                f"inputs. Received: inputs={inputs}"
            )
        if len(inputs) < 2:
            raise ValueError(
                "`HashedCrossing` should be called on at least two inputs. "
                f"Received: inputs={inputs}"
            )

    def _check_input_shape_and_type(self, inputs):
        """Validates shape, density and dtype of already-converted inputs.

        Args:
            inputs: A list of tensors, as produced in `call`.

        Raises:
            ValueError: If the first input has rank > 2 or is rank 2 with a
                last dimension other than 1; if any other input's shape
                differs from the first; if any input is a ragged or sparse
                tensor; or if any input is neither integer nor string typed.
        """
        first_shape = inputs[0].shape.as_list()
        rank = len(first_shape)
        if rank > 2 or (rank == 2 and first_shape[-1] != 1):
            raise ValueError(
                "All `HashedCrossing` inputs should have shape `[]`, "
                "`[batch_size]` or `[batch_size, 1]`. "
                f"Received: inputs={inputs}"
            )
        if not all(x.shape.as_list() == first_shape for x in inputs[1:]):
            raise ValueError(
                "All `HashedCrossing` inputs should have equal shape. "
                f"Received: inputs={inputs}"
            )
        if any(
            isinstance(x, (tf.RaggedTensor, tf.SparseTensor)) for x in inputs
        ):
            raise ValueError(
                "All `HashedCrossing` inputs should be dense tensors. "
                f"Received: inputs={inputs}"
            )
        if not all(x.dtype.is_integer or x.dtype == tf.string for x in inputs):
            raise ValueError(
                "All `HashedCrossing` inputs should have an integer or "
                f"string dtype. Received: inputs={inputs}"
            )