Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/preprocessing/hashed_crossing.py: 27%

71 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras hashed crossing preprocessing layer."""


import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine import base_layer
from keras.src.engine import base_preprocessing_layer
from keras.src.layers.preprocessing import preprocessing_utils as utils
from keras.src.utils import layer_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export

INT = utils.INT
ONE_HOT = utils.ONE_HOT


@keras_export(
    "keras.layers.HashedCrossing",
    "keras.layers.experimental.preprocessing.HashedCrossing",
    v1=[],
)
class HashedCrossing(base_layer.Layer):
39 """A preprocessing layer which crosses features using the "hashing trick". 

40 

41 This layer performs crosses of categorical features using the "hasing 

42 trick". Conceptually, the transformation can be thought of as: 

43 hash(concatenation of features) % `num_bins`. 

44 

45 This layer currently only performs crosses of scalar inputs and batches of 

46 scalar inputs. Valid input shapes are `(batch_size, 1)`, `(batch_size,)` and 

47 `()`. 

48 

49 For an overview and full list of preprocessing layers, see the preprocessing 

50 [guide](https://www.tensorflow.org/guide/keras/preprocessing_layers). 

51 

52 Args: 

53 num_bins: Number of hash bins. 

54 output_mode: Specification for the output of the layer. Values can be 

55 `"int"`, or `"one_hot"` configuring the layer as follows: 

56 - `"int"`: Return the integer bin indices directly. 

57 - `"one_hot"`: Encodes each individual element in the input into an 

58 array the same size as `num_bins`, containing a 1 at the input's bin 

59 index. 

60 Defaults to `"int"`. 

61 sparse: Boolean. Only applicable to `"one_hot"` mode. If True, returns a 

62 `SparseTensor` instead of a dense `Tensor`. Defaults to `False`. 

63 **kwargs: Keyword arguments to construct a layer. 

64 

65 Examples: 

66 

67 **Crossing two scalar features.** 

68 

69 >>> layer = tf.keras.layers.HashedCrossing( 

70 ... num_bins=5) 

71 >>> feat1 = tf.constant(['A', 'B', 'A', 'B', 'A']) 

72 >>> feat2 = tf.constant([101, 101, 101, 102, 102]) 

73 >>> layer((feat1, feat2)) 

74 <tf.Tensor: shape=(5,), dtype=int64, numpy=array([1, 4, 1, 1, 3])> 

75 

76 **Crossing and one-hotting two scalar features.** 

77 

78 >>> layer = tf.keras.layers.HashedCrossing( 

79 ... num_bins=5, output_mode='one_hot') 

80 >>> feat1 = tf.constant(['A', 'B', 'A', 'B', 'A']) 

81 >>> feat2 = tf.constant([101, 101, 101, 102, 102]) 

82 >>> layer((feat1, feat2)) 

83 <tf.Tensor: shape=(5, 5), dtype=float32, numpy= 

84 array([[0., 1., 0., 0., 0.], 

85 [0., 0., 0., 0., 1.], 

86 [0., 1., 0., 0., 0.], 

87 [0., 1., 0., 0., 0.], 

88 [0., 0., 0., 1., 0.]], dtype=float32)> 

89 """ 

    def __init__(self, num_bins, output_mode="int", sparse=False, **kwargs):
        # By default, output int64 when output_mode="int" and floats otherwise.
        if "dtype" not in kwargs or kwargs["dtype"] is None:
            kwargs["dtype"] = (
                tf.int64 if output_mode == INT else backend.floatx()
            )

        super().__init__(**kwargs)
        base_preprocessing_layer.keras_kpl_gauge.get_cell("HashedCrossing").set(
            True
        )

        # Check dtype only after base layer parses it; dtype parsing is complex.
        if (
            output_mode == INT
            and not tf.as_dtype(self.compute_dtype).is_integer
        ):
            input_dtype = kwargs["dtype"]
            raise ValueError(
                "When `output_mode='int'`, `dtype` should be an integer "
                f"type. Received: dtype={input_dtype}"
            )

        # "output_mode" must be one of (INT, ONE_HOT)
        layer_utils.validate_string_arg(
            output_mode,
            allowable_strings=(INT, ONE_HOT),
            layer_name=self.__class__.__name__,
            arg_name="output_mode",
        )

        self.num_bins = num_bins
        self.output_mode = output_mode
        self.sparse = sparse

    def call(self, inputs):
        # Convert all inputs to tensors and check shape. This layer only
        # supports scalars and batches of scalars for the initial version.
        self._check_at_least_two_inputs(inputs)
        inputs = [utils.ensure_tensor(x) for x in inputs]
        self._check_input_shape_and_type(inputs)

        # Uprank to rank 2 for the cross_hashed op.
        rank = inputs[0].shape.rank
        if rank < 2:
            inputs = [utils.expand_dims(x, -1) for x in inputs]
        if rank < 1:
            inputs = [utils.expand_dims(x, -1) for x in inputs]

        # Perform the cross and convert to dense
        outputs = tf.sparse.cross_hashed(inputs, self.num_bins)
        outputs = tf.sparse.to_dense(outputs)

        # Fix output shape and downrank to match input rank.
        if rank == 2:
            # tf.sparse.cross_hashed output shape will always be None on the
            # last dimension. Given our input shape restrictions, we want to
            # force shape 1 instead.
            outputs = tf.reshape(outputs, [-1, 1])
        elif rank == 1:
            outputs = tf.reshape(outputs, [-1])
        elif rank == 0:
            outputs = tf.reshape(outputs, [])

        # Encode outputs.
        return utils.encode_categorical_inputs(
            outputs,
            output_mode=self.output_mode,
            depth=self.num_bins,
            sparse=self.sparse,
            dtype=self.compute_dtype,
        )

    def compute_output_shape(self, input_shapes):
        self._check_at_least_two_inputs(input_shapes)
        return utils.compute_shape_for_encode_categorical(input_shapes[0])

    def compute_output_signature(self, input_specs):
        input_shapes = [x.shape.as_list() for x in input_specs]
        output_shape = self.compute_output_shape(input_shapes)
        if self.sparse or any(
            isinstance(x, tf.SparseTensorSpec) for x in input_specs
        ):
            return tf.SparseTensorSpec(
                shape=output_shape, dtype=self.compute_dtype
            )
        return tf.TensorSpec(shape=output_shape, dtype=self.compute_dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_bins": self.num_bins,
                "output_mode": self.output_mode,
                "sparse": self.sparse,
            }
        )
        return config

    def _check_at_least_two_inputs(self, inputs):
        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                "`HashedCrossing` should be called on a list or tuple of "
                f"inputs. Received: inputs={inputs}"
            )
        if len(inputs) < 2:
            raise ValueError(
                "`HashedCrossing` should be called on at least two inputs. "
                f"Received: inputs={inputs}"
            )

    def _check_input_shape_and_type(self, inputs):
        first_shape = inputs[0].shape.as_list()
        rank = len(first_shape)
        if rank > 2 or (rank == 2 and first_shape[-1] != 1):
            raise ValueError(
                "All `HashedCrossing` inputs should have shape `[]`, "
                "`[batch_size]` or `[batch_size, 1]`. "
                f"Received: inputs={inputs}"
            )
        if not all(x.shape.as_list() == first_shape for x in inputs[1:]):
            raise ValueError(
                "All `HashedCrossing` inputs should have equal shape. "
                f"Received: inputs={inputs}"
            )
        if any(
            isinstance(x, (tf.RaggedTensor, tf.SparseTensor)) for x in inputs
        ):
            raise ValueError(
                "All `HashedCrossing` inputs should be dense tensors. "
                f"Received: inputs={inputs}"
            )
        if not all(x.dtype.is_integer or x.dtype == tf.string for x in inputs):
            raise ValueError(
                "All `HashedCrossing` inputs should have an integer or "
                f"string dtype. Received: inputs={inputs}"
            )
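
The doctests in the class docstring cover eager calls on scalar batches. As a supplementary illustration (a minimal sketch, not part of this module; the feature names and `num_bins=5` below are arbitrary), the snippet shows the `sparse` one-hot output and one assumed way to wire the layer into a functional model.

import tensorflow as tf

# Sparse one-hot crossing: with output_mode="one_hot" and sparse=True the
# layer returns a tf.SparseTensor with dense shape (batch_size, num_bins).
layer = tf.keras.layers.HashedCrossing(
    num_bins=5, output_mode="one_hot", sparse=True
)
feat1 = tf.constant(["A", "B", "A"])
feat2 = tf.constant([101, 102, 101])
sparse_cross = layer((feat1, feat2))  # tf.SparseTensor, dense shape (3, 5)

# Functional-model wiring (assumed usage; `inp1`/`inp2` are illustrative
# names): scalar features enter as shape (1,) inputs, i.e. (batch_size, 1)
# tensors, which is one of the valid input shapes listed in the docstring.
inp1 = tf.keras.Input(shape=(1,), dtype=tf.string, name="feat1")
inp2 = tf.keras.Input(shape=(1,), dtype=tf.int64, name="feat2")
crossed = tf.keras.layers.HashedCrossing(num_bins=5)((inp1, inp2))
model = tf.keras.Model(inputs=[inp1, inp2], outputs=crossed)
# With the default output_mode="int", `crossed` has shape (batch_size, 1)
# and dtype int64 (see the rank == 2 branch of call()).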