Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/feature_column/dense_features_v2.py: 47%

36 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""A layer that produces a dense `Tensor` based on given `feature_columns`.""" 

16 

17from __future__ import absolute_import 

18from __future__ import division 

19from __future__ import print_function 

20 

21import tensorflow.compat.v2 as tf 

22 

23from keras.src.feature_column import base_feature_layer as kfc 

24from keras.src.feature_column import dense_features 

25from keras.src.utils import tf_contextlib 

26 

27# isort: off 

28from tensorflow.python.util.tf_export import keras_export 

29 

30 

@keras_export("keras.layers.DenseFeatures", v1=[])
class DenseFeatures(dense_features.DenseFeatures):
    """A layer that produces a dense `Tensor` based on given `feature_columns`.

    Generally a single example in training data is described with
    FeatureColumns. At the first layer of the model, this column oriented data
    should be converted to a single `Tensor`.

    This layer can be called multiple times with different features.

    This is the V2 version of this layer that uses name_scopes to create
    variables instead of variable_scopes. But this approach currently lacks
    support for partitioned variables. In that case, use the V1 version instead.

    Example:

    ```python
    price = tf.feature_column.numeric_column('price')
    keywords_embedded = tf.feature_column.embedding_column(
        tf.feature_column.categorical_column_with_hash_bucket("keywords",
                                                              10000),
        dimension=16)
    columns = [price, keywords_embedded, ...]
    feature_layer = tf.keras.layers.DenseFeatures(columns)

    features = tf.io.parse_example(
        ..., features=tf.feature_column.make_parse_example_spec(columns))
    dense_tensor = feature_layer(features)
    for units in [128, 64, 32]:
        dense_tensor = tf.keras.layers.Dense(units, activation='relu')(
            dense_tensor)
    prediction = tf.keras.layers.Dense(1)(dense_tensor)
    ```
    """

    def __init__(self, feature_columns, trainable=True, name=None, **kwargs):
        """Creates a DenseFeatures object.

        Args:
          feature_columns: An iterable containing the FeatureColumns to use as
            inputs to your model. All items should be instances of classes
            derived from `DenseColumn` such as `numeric_column`,
            `embedding_column`, `bucketized_column`, `indicator_column`. If you
            have categorical features, you can wrap them with an
            `embedding_column` or `indicator_column`.
          trainable: Boolean, whether the layer's variables will be updated via
            gradient descent during training.
          name: Name to give to the DenseFeatures.
          **kwargs: Keyword arguments to construct a layer.

        Raises:
          ValueError: if an item in `feature_columns` is not a `DenseColumn`.
        """
        super().__init__(
            feature_columns=feature_columns,
            trainable=trainable,
            name=name,
            **kwargs
        )
        # Use the V2 state manager defined in this module. Its variables are
        # created under the `tf.name_scope`s opened in `build` below.
        self._state_manager = _StateManagerImplV2(self, self.trainable)

    def build(self, _):
        """Creates each feature column's state, then finishes building.

        Args:
          _: The input shape; unused because the state each column needs is
            derived from the column itself.
        """
        for column in self._feature_columns:
            # Scope each column's variables by the column's name.
            with tf.name_scope(column.name):
                column.create_state(self._state_manager)
        # We would like to call Layer.build and not _DenseFeaturesHelper.build.
        # `super(kfc._BaseFeaturesLayer, self)` resolves to the next class in
        # the MRO *after* `_BaseFeaturesLayer`, skipping the parents' own
        # `build` implementations (presumably landing on keras `Layer.build`).
        super(kfc._BaseFeaturesLayer, self).build(None)

99 

100 

class _StateManagerImplV2(tf.__internal__.feature_column.StateManager):
    """State manager backing the V2 `DenseFeatures` layer.

    New variables are created through the owning layer's `add_weight` and are
    then tracked explicitly under a `<column name>/<variable name>` key, since
    a bare variable name may collide between different feature columns.
    """

    def create_variable(
        self,
        feature_column,
        name,
        shape,
        dtype=None,
        trainable=True,
        use_resource=True,
        initializer=None,
    ):
        """Creates a variable for `feature_column` and registers it.

        Args:
          feature_column: The feature column the new variable belongs to.
          name: Variable name; must not already be registered for
            `feature_column`.
          shape: Shape of the variable, forwarded to `add_weight`.
          dtype: Optional dtype, forwarded to `add_weight`.
          trainable: Requested trainability. The variable is trainable only
            when this state manager is trainable as well.
          use_resource: Forwarded to `add_weight`.
          initializer: Optional initializer, forwarded to `add_weight`.

        Returns:
          The variable produced by the layer's `add_weight`.

        Raises:
          ValueError: If `feature_column` already owns a variable called
            `name`.
        """
        registered = self._cols_to_vars_map[feature_column]
        if name in registered:
            raise ValueError("Variable already exists.")

        # `add_weight` would otherwise track the variable by `name` alone,
        # which is not guaranteed to be unique; switch that off here and
        # track it below under a column-qualified key instead.
        with no_manual_dependency_tracking_scope(self._layer):
            new_var = self._layer.add_weight(
                name=name,
                shape=shape,
                dtype=dtype,
                initializer=initializer,
                trainable=self._trainable and trainable,
                use_resource=use_resource,
            )

        if isinstance(new_var, tf.__internal__.tracking.Trackable):
            tracking_key = feature_column.name + "/" + name
            self._layer._track_trackable(new_var, tracking_key)

        registered[name] = new_var
        return new_var

132 

133 

@tf_contextlib.contextmanager
def no_manual_dependency_tracking_scope(obj):
    """A context that disables manual dependency tracking for the given `obj`.

    Sometimes library methods might track objects on their own and we might
    want to disable that and do the tracking on our own. One can then use this
    context manager to disable the tracking the library method does and do
    your own tracking.

    For example:

        class TestLayer(tf.keras.layers.Layer):
            def build(self, input_shape):
                with no_manual_dependency_tracking_scope(self):
                    # Creates a variable and doesn't track it.
                    var = self.add_weight("name1")
                # Track the variable under the name `name2` instead.
                self._track_trackable(var, "name2")

    Args:
      obj: A trackable object. Its `_manual_tracking` attribute is set to
        `False` while the scope is active.

    Yields:
      a scope in which the object doesn't track dependencies manually.
      On exit, including exit via an exception, `obj._manual_tracking` is
      restored to its previous value. If `obj` had no `_manual_tracking`
      attribute before entering, it is set to `True` on exit.
    """

    previous_value = getattr(obj, "_manual_tracking", True)
    obj._manual_tracking = False
    try:
        yield
    finally:
        # Always restore, so an exception in the body cannot leave manual
        # tracking permanently disabled on `obj`.
        obj._manual_tracking = previous_value

165