Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/base_preprocessing_layer.py: 32%

98 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the base ProcessingLayer and a subclass that uses Combiners.""" 

16 

17import abc 

18 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src.engine import data_adapter 

22from keras.src.engine.base_layer import Layer 

23from keras.src.utils import version_utils 

24 

25# isort: off 

26from tensorflow.python.eager import context 

27from tensorflow.python.util.tf_export import keras_export 

28from tensorflow.tools.docs import doc_controls 

29 

# Boolean monitoring gauge for Keras preprocessing-layer usage, keyed by a
# single "method" label. Nothing in this module sets a cell on it;
# NOTE(review): presumably concrete preprocessing layers record their usage
# through this gauge -- confirm against the subclasses.
keras_kpl_gauge = tf.__internal__.monitoring.BoolGauge(
    "/tensorflow/api/keras/layers/preprocessing",
    "keras preprocessing layers usage",
    "method",
)

35 

36 

@keras_export("keras.layers.experimental.preprocessing.PreprocessingLayer")
class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
    """Base class for Preprocessing Layers.

    **Don't use this class directly: it's an abstract base class!** You may
    be looking for one of the many built-in
    [preprocessing layers](https://keras.io/guides/preprocessing_layers/)
    instead.

    Preprocessing layers are layers whose state gets computed before model
    training starts. They do not get updated during training. Most
    preprocessing layers implement an `adapt()` method for state computation.

    The `PreprocessingLayer` class is the base class you would subclass to
    implement your own preprocessing layers. Subclasses that support
    `adapt()` are expected to override `update_state` and `reset_state`
    (both raise `NotImplementedError` here) and may override
    `finalize_state` and `make_adapt_function`.
    """

    _must_restore_from_config = True

    def __init__(self, **kwargs):
        """Initializes the adapt bookkeeping on top of the base `Layer`.

        Arguments:
          **kwargs: Forwarded unchanged to `Layer.__init__`.
        """
        super().__init__(**kwargs)
        self._is_compiled = False
        self._is_adapted = False

        # Shadow `reset_state` on the instance with a wrapper so that any
        # call to it (including a subclass override, which is what
        # `self.reset_state` resolves to here) also sets
        # `_is_adapted = False`.
        self._reset_state_impl = self.reset_state
        self.reset_state = self._reset_state_wrapper

        # Cache for the function built by `make_adapt_function`.
        self._adapt_function = None

    @property
    def is_adapted(self):
        """Whether the layer has been fit to data already."""
        return self._is_adapted

    @doc_controls.do_not_generate_docs
    def update_state(self, data):
        """Accumulates statistics for the preprocessing layer.

        Subclasses must override this; the base implementation always
        raises.

        Arguments:
          data: A mini-batch of inputs to the layer.

        Raises:
          NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @doc_controls.do_not_generate_docs
    def reset_state(self):
        """Resets the statistics of the preprocessing layer.

        Subclasses must override this; the base implementation always
        raises.

        Raises:
          NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @doc_controls.do_not_generate_docs
    def finalize_state(self):
        """Finalize the statistics for the preprocessing layer.

        This method is called at the end of `adapt` or after restoring a
        serialized preprocessing layer's state. This method handles any
        one-time operations that should occur on the layer's state before
        `Layer.__call__`. The base implementation does nothing.
        """
        pass

    @doc_controls.do_not_generate_docs
    def make_adapt_function(self):
        """Creates a function to execute one step of `adapt`.

        This method can be overridden to support custom adapt logic.
        This method is called by `PreprocessingLayer.adapt`.

        Typically, this method directly controls `tf.function` settings,
        and delegates the actual state update logic to
        `PreprocessingLayer.update_state`.

        This function is cached the first time `PreprocessingLayer.adapt`
        is called. The cache is cleared whenever `PreprocessingLayer.compile`
        is called.

        This method reads `self._steps_per_execution` and
        `self._run_eagerly`, which this class only assigns in `compile`, so
        `compile` must have run first (`adapt` takes care of that).

        Returns:
          Function. The function created by this method should accept a
          `tf.data.Iterator`, retrieve a batch, and update the state of the
          layer.
        """
        if self._adapt_function is not None:
            return self._adapt_function

        def adapt_step(iterator):
            # Pull one batch, build the layer on first use, then accumulate.
            data = next(iterator)
            self._adapt_maybe_build(data)
            self.update_state(data)

        if self._steps_per_execution.numpy().item() == 1:
            adapt_fn = adapt_step
        else:

            def adapt_fn(iterator):
                # Run several batches per call of the adapt function.
                for _ in tf.range(self._steps_per_execution):
                    adapt_step(iterator)

        if not self._run_eagerly:
            adapt_fn = tf.function(adapt_fn)

        self._adapt_function = adapt_fn
        return self._adapt_function

    def compile(self, run_eagerly=None, steps_per_execution=None):
        """Configures the layer for `adapt`.

        Calling `compile` also discards any adapt function cached by
        `make_adapt_function`, so the next `adapt` rebuilds it with the new
        settings.

        Arguments:
          run_eagerly: Bool or `None`. If `True`, this layer's adapt logic
            will not be wrapped in a `tf.function`. When `None` (the
            default) the value of `self.dynamic` is used. Recommended to
            leave this as `None` unless the layer cannot be run inside a
            `tf.function`.
          steps_per_execution: Int. Defaults to 1. The number of batches to
            run during each `tf.function` call. Running multiple batches
            inside a single `tf.function` call can greatly improve
            performance on TPUs or small models with a large Python
            overhead.
        """
        if steps_per_execution is None:
            steps_per_execution = 1
        self._configure_steps_per_execution(steps_per_execution)

        if run_eagerly is None:
            run_eagerly = self.dynamic
        self._run_eagerly = run_eagerly

        # The cached adapt function closes over the previous steps variable
        # and was (or was not) wrapped in `tf.function` according to the
        # previous `run_eagerly`; drop it so the new settings take effect.
        self._adapt_function = None

        self._is_compiled = True

    def adapt(self, data, batch_size=None, steps=None):
        """Fits the state of the preprocessing layer to the data being passed.

        After calling `adapt` on a layer, a preprocessing layer's state will
        not update during training. In order to make preprocessing layers
        efficient in any distribution context, they are kept constant with
        respect to any compiled `tf.Graph`s that call the layer. This does
        not affect the layer use when adapting each layer only once, but if
        you adapt a layer multiple times you will need to take care to
        re-compile any compiled functions as follows:

         * If you are adding a preprocessing layer to a `keras.Model`, you
           need to call `model.compile` after each subsequent call to
           `adapt`.
         * If you are calling a preprocessing layer inside
           `tf.data.Dataset.map`, you should call `map` again on the input
           `tf.data.Dataset` after each `adapt`.
         * If you are using a `tf.function` directly which calls a
           preprocessing layer, you need to call `tf.function` again on your
           callable after each subsequent call to `adapt`.

        `tf.keras.Model` example with multiple adapts:

        >>> layer = tf.keras.layers.Normalization(
        ...     axis=None)
        >>> layer.adapt([0, 2])
        >>> model = tf.keras.Sequential(layer)
        >>> model.predict([0, 1, 2])
        array([-1., 0., 1.], dtype=float32)
        >>> layer.adapt([-1, 1])
        >>> model.compile() # This is needed to re-compile model.predict!
        >>> model.predict([0, 1, 2])
        array([0., 1., 2.], dtype=float32)

        `tf.data.Dataset` example with multiple adapts:

        >>> layer = tf.keras.layers.Normalization(
        ...     axis=None)
        >>> layer.adapt([0, 2])
        >>> input_ds = tf.data.Dataset.range(3)
        >>> normalized_ds = input_ds.map(layer)
        >>> list(normalized_ds.as_numpy_iterator())
        [array([-1.], dtype=float32),
         array([0.], dtype=float32),
         array([1.], dtype=float32)]
        >>> layer.adapt([-1, 1])
        >>> normalized_ds = input_ds.map(layer) # Re-map over the input dataset.
        >>> list(normalized_ds.as_numpy_iterator())
        [array([0.], dtype=float32),
         array([1.], dtype=float32),
         array([2.], dtype=float32)]

        `adapt()` is meant only as a single machine utility to compute layer
        state. To analyze a dataset that cannot fit on a single machine, see
        [Tensorflow Transform](
        https://www.tensorflow.org/tfx/transform/get_started)
        for a multi-machine, map-reduce solution.

        Arguments:
          data: The data to train on. It can be passed either as a tf.data
            Dataset, or as a numpy array.
          batch_size: Integer or `None`.
            Number of samples per state update. If unspecified,
            `batch_size` will default to 32. Do not specify the
            `batch_size` if your data is in the form of datasets,
            generators, or `keras.utils.Sequence` instances (since they
            generate batches).
          steps: Integer or `None`.
            Total number of steps (batches of samples).
            When training with input tensors such as
            TensorFlow data tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If `data` is
            a `tf.data` dataset, and `steps` is None, the epoch will run
            until the input dataset is exhausted. When passing an infinitely
            repeating dataset, you must specify the `steps` argument. This
            argument is not supported with array inputs.

        Raises:
          RuntimeError: If called inside a `tf.function`, or when TensorFlow
            v2 behavior is not in use.
        """
        _disallow_inside_tf_function("adapt")
        if not version_utils.should_use_v2():
            raise RuntimeError("`adapt` is only supported in tensorflow v2.")
        if not self._is_compiled:
            self.compile()  # Compile with defaults.
        if self.built:
            # Start from a clean slate when re-adapting an already built
            # layer.
            self.reset_state()
        data_handler = data_adapter.DataHandler(
            data,
            batch_size=batch_size,
            steps_per_epoch=steps,
            epochs=1,
            steps_per_execution=self._steps_per_execution,
            distribute=False,
        )
        self._adapt_function = self.make_adapt_function()
        for _, iterator in data_handler.enumerate_epochs():
            with data_handler.catch_stop_iteration():
                for _ in data_handler.steps():
                    self._adapt_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
        self.finalize_state()
        self._is_adapted = True

    def _reset_state_wrapper(self):
        """Calls the real `reset_state` and sets `_is_adapted` to `False`."""
        self._reset_state_impl()
        self._is_adapted = False

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _configure_steps_per_execution(self, steps_per_execution):
        """Stores `steps_per_execution` as a fresh int64 `tf.Variable`.

        Arguments:
          steps_per_execution: Initial value for the variable.
        """
        self._steps_per_execution = tf.Variable(
            steps_per_execution,
            dtype="int64",
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )

    # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
    def _adapt_maybe_build(self, data):
        """Builds the layer from a batch of adapt data if not yet built.

        Arguments:
          data: A batch of adapt data. Its `.shape`, when available, is
            passed to `build`; otherwise `build` receives `None`.
        """
        if not self.built:
            try:
                # If this is a Numpy array or tensor, we can get shape from
                # .shape. If not, an attribute error will be thrown.
                data_shape = data.shape
                data_shape_nones = (None,) * len(data_shape)
            except AttributeError:
                # The input has an unknown number of dimensions.
                data_shape = None
                data_shape_nones = None

            # TODO (b/159261555): move this to base layer build.
            batch_input_shape = getattr(self, "_batch_input_shape", None)
            if batch_input_shape is None:
                # Set the number of dimensions.
                self._batch_input_shape = data_shape_nones
            self.build(data_shape)
            self.built = True

296 self.built = True 

297 

298 

def _disallow_inside_tf_function(method_name):
    """Disallow calling a method inside a `tf.function`.

    Arguments:
      method_name: Name of the `PreprocessingLayer` method being guarded;
        only used to format the error message.

    Raises:
      RuntimeError: If `tf.inside_function()` reports that the caller is
        executing inside a `tf.function`.
    """
    if not tf.inside_function():
        return

    error_msg = (
        "Detected a call to `PreprocessingLayer.{method_name}` inside a "
        "`tf.function`. `PreprocessingLayer.{method_name}` is a high-level "
        "endpoint that manages its own `tf.function`. Please move the call "
        "to `PreprocessingLayer.{method_name}` outside of all enclosing "
        "`tf.function`s. Note that you can call a `PreprocessingLayer` "
        "directly on `Tensor`s inside a `tf.function` like: `layer(x)`, "
        "or update its state like: `layer.update_state(x)`."
    ).format(method_name=method_name)
    raise RuntimeError(error_msg)

312