Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/base_preprocessing_layer.py: 32%
98 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the base ProcessingLayer and a subclass that uses Combiners."""
17import abc
19import tensorflow.compat.v2 as tf
21from keras.src.engine import data_adapter
22from keras.src.engine.base_layer import Layer
23from keras.src.utils import version_utils
25# isort: off
26from tensorflow.python.eager import context
27from tensorflow.python.util.tf_export import keras_export
28from tensorflow.tools.docs import doc_controls
# Module-level boolean usage gauge for Keras preprocessing layers, keyed by a
# single "method" label. Nothing in this module sets it; presumably individual
# preprocessing layers record their usage on it -- confirm at the call sites.
keras_kpl_gauge = tf.__internal__.monitoring.BoolGauge(
    "/tensorflow/api/keras/layers/preprocessing",
    "keras preprocessing layers usage",
    "method",
)
@keras_export("keras.layers.experimental.preprocessing.PreprocessingLayer")
class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
    """Base class for Preprocessing Layers.

    **Don't use this class directly: it's an abstract base class!** You may
    be looking for one of the many built-in
    [preprocessing layers](https://keras.io/guides/preprocessing_layers/)
    instead.

    Preprocessing layers are layers whose state gets computed before model
    training starts. They do not get updated during training. Most
    preprocessing layers implement an `adapt()` method for state computation.

    The `PreprocessingLayer` class is the base class you would subclass to
    implement your own preprocessing layers.
    """

    _must_restore_from_config = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._is_compiled = False
        self._is_adapted = False

        # Sets `is_adapted=False` when `reset_state` is called.
        # The (possibly subclass-overridden) bound `reset_state` is stashed
        # and replaced on the instance by a wrapper, so every call to
        # `layer.reset_state()` also clears the adapted flag.
        self._reset_state_impl = self.reset_state
        self.reset_state = self._reset_state_wrapper

        # Built lazily by `make_adapt_function`; cleared by `compile`.
        self._adapt_function = None

    @property
    def is_adapted(self):
        """Whether the layer has been fit to data already."""
        return self._is_adapted

    @doc_controls.do_not_generate_docs
    def update_state(self, data):
        """Accumulates statistics for the preprocessing layer.

        Arguments:
          data: A mini-batch of inputs to the layer.
        """
        raise NotImplementedError

    @doc_controls.do_not_generate_docs
    def reset_state(self):
        """Resets the statistics of the preprocessing layer."""
        raise NotImplementedError

    @doc_controls.do_not_generate_docs
    def finalize_state(self):
        """Finalize the statistics for the preprocessing layer.

        This method is called at the end of `adapt` or after restoring a
        serialized preprocessing layer's state. This method handles any
        one-time operations that should occur on the layer's state before
        `Layer.__call__`.
        """
        pass

    @doc_controls.do_not_generate_docs
    def make_adapt_function(self):
        """Creates a function to execute one step of `adapt`.

        This method can be overridden to support custom adapt logic.
        This method is called by `PreprocessingLayer.adapt`.

        Typically, this method directly controls `tf.function` settings,
        and delegates the actual state update logic to
        `PreprocessingLayer.update_state`.

        This function is cached the first time `PreprocessingLayer.adapt`
        is called. The cache is cleared whenever `PreprocessingLayer.compile`
        is called.

        Returns:
          Function. The function created by this method should accept a
          `tf.data.Iterator`, retrieve a batch, and update the state of the
          layer.
        """
        if self._adapt_function is not None:
            return self._adapt_function

        def adapt_step(iterator):
            data = next(iterator)
            # Build the layer from the first batch if it is not built yet.
            self._adapt_maybe_build(data)
            self.update_state(data)

        # The single-step vs. multi-step choice is made once, here, from the
        # `steps_per_execution` value configured by `compile`.
        if self._steps_per_execution.numpy().item() == 1:
            adapt_fn = adapt_step
        else:

            def adapt_fn(iterator):
                # Run `steps_per_execution` adapt steps per call.
                for _ in tf.range(self._steps_per_execution):
                    adapt_step(iterator)

        if not self._run_eagerly:
            adapt_fn = tf.function(adapt_fn)

        self._adapt_function = adapt_fn
        return self._adapt_function

    def compile(self, run_eagerly=None, steps_per_execution=None):
        """Configures the layer for `adapt`.

        Calling `compile` discards any cached adapt function, so the next
        call to `adapt` rebuilds it with the new settings.

        Arguments:
          run_eagerly: Bool or `None`. Defaults to `None`, which uses the
            value of `self.dynamic`. If `True`, this layer's adapt logic
            will not be wrapped in a `tf.function`. Recommended to leave
            this as `None` unless your layer cannot be run inside a
            `tf.function`.
          steps_per_execution: Int or `None`. Defaults to `None`, which is
            treated as 1. The number of batches to run during each
            `tf.function` call. Running multiple batches inside a single
            `tf.function` call can greatly improve performance on TPUs or
            small models with a large Python overhead.
        """
        if steps_per_execution is None:
            steps_per_execution = 1
        self._configure_steps_per_execution(steps_per_execution)

        if run_eagerly is None:
            run_eagerly = self.dynamic
        self._run_eagerly = run_eagerly

        # A previously built adapt function fixed the old `run_eagerly` and
        # `steps_per_execution` choices when it was created (see
        # `make_adapt_function`); drop it so `adapt` rebuilds it.
        self._adapt_function = None

        self._is_compiled = True

    def adapt(self, data, batch_size=None, steps=None):
        """Fits the state of the preprocessing layer to the data being passed.

        After calling `adapt` on a layer, a preprocessing layer's state will
        not update during training. In order to make preprocessing layers
        efficient in any distribution context, they are kept constant with
        respect to any compiled `tf.Graph`s that call the layer. This does not
        affect the layer use when adapting each layer only once, but if you
        adapt a layer multiple times you will need to take care to re-compile
        any compiled functions as follows:

         * If you are adding a preprocessing layer to a `keras.Model`, you need
           to call `model.compile` after each subsequent call to `adapt`.
         * If you are calling a preprocessing layer inside
          `tf.data.Dataset.map`, you should call `map` again on the input
          `tf.data.Dataset` after each `adapt`.
         * If you are using a `tf.function` directly which calls a
           preprocessing layer, you need to call `tf.function` again on your
           callable after each subsequent call to `adapt`.

        `tf.keras.Model` example with multiple adapts:

        >>> layer = tf.keras.layers.Normalization(
        ...     axis=None)
        >>> layer.adapt([0, 2])
        >>> model = tf.keras.Sequential(layer)
        >>> model.predict([0, 1, 2])
        array([-1.,  0.,  1.], dtype=float32)
        >>> layer.adapt([-1, 1])
        >>> model.compile() # This is needed to re-compile model.predict!
        >>> model.predict([0, 1, 2])
        array([0., 1., 2.], dtype=float32)

        `tf.data.Dataset` example with multiple adapts:

        >>> layer = tf.keras.layers.Normalization(
        ...     axis=None)
        >>> layer.adapt([0, 2])
        >>> input_ds = tf.data.Dataset.range(3)
        >>> normalized_ds = input_ds.map(layer)
        >>> list(normalized_ds.as_numpy_iterator())
        [array([-1.], dtype=float32),
         array([0.], dtype=float32),
         array([1.], dtype=float32)]
        >>> layer.adapt([-1, 1])
        >>> normalized_ds = input_ds.map(layer) # Re-map over the input dataset.
        >>> list(normalized_ds.as_numpy_iterator())
        [array([0.], dtype=float32),
         array([1.], dtype=float32),
         array([2.], dtype=float32)]

        `adapt()` is meant only as a single machine utility to compute layer
        state. To analyze a dataset that cannot fit on a single machine, see
        [TensorFlow Transform](
        https://www.tensorflow.org/tfx/transform/get_started)
        for a multi-machine, map-reduce solution.

        Arguments:
            data: The data to train on. It can be passed either as a tf.data
              Dataset, or as a numpy array.
            batch_size: Integer or `None`.
                Number of samples per state update. If unspecified,
                `batch_size` will default to 32. Do not specify the
                `batch_size` if your data is in the form of datasets,
                generators, or `keras.utils.Sequence` instances (since they
                generate batches).
            steps: Integer or `None`.
                Total number of steps (batches of samples).
                When training with input tensors such as
                TensorFlow data tensors, the default `None` is equal to
                the number of samples in your dataset divided by
                the batch size, or 1 if that cannot be determined. If x is a
                `tf.data` dataset, and 'steps' is None, the epoch will run
                until the input dataset is exhausted. When passing an
                infinitely repeating dataset, you must specify the `steps`
                argument. This argument is not supported with array inputs.

        Raises:
            RuntimeError: If called inside a `tf.function`, or when not
              running in TensorFlow v2 mode.
        """
        _disallow_inside_tf_function("adapt")
        if not version_utils.should_use_v2():
            raise RuntimeError("`adapt` is only supported in tensorflow v2.")
        if not self._is_compiled:
            self.compile()  # Compile with defaults.
        if self.built:
            # Start from fresh statistics; through the wrapper installed in
            # `__init__` this also resets `is_adapted` to False.
            self.reset_state()
        data_handler = data_adapter.DataHandler(
            data,
            batch_size=batch_size,
            steps_per_epoch=steps,
            epochs=1,
            steps_per_execution=self._steps_per_execution,
            distribute=False,
        )
        self._adapt_function = self.make_adapt_function()
        for _, iterator in data_handler.enumerate_epochs():
            with data_handler.catch_stop_iteration():
                for _ in data_handler.steps():
                    self._adapt_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
        self.finalize_state()
        self._is_adapted = True

    def _reset_state_wrapper(self):
        """Calls `reset_state` and sets `adapted` to `False`."""
        self._reset_state_impl()
        self._is_adapted = False

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _configure_steps_per_execution(self, steps_per_execution):
        """Stores `steps_per_execution` in a fresh int64 `tf.Variable`."""
        self._steps_per_execution = tf.Variable(
            steps_per_execution,
            dtype="int64",
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )

    # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
    def _adapt_maybe_build(self, data):
        """Builds the layer from a batch of adapt data if not built yet.

        Arguments:
          data: A batch of adapt data. If it exposes `.shape`, that shape is
            passed to `build`; otherwise `build` receives `None`.
        """
        if not self.built:
            try:
                # If this is a Numpy array or tensor, we can get shape from
                # .shape. If not, an attribute error will be thrown.
                data_shape = data.shape
                data_shape_nones = (None,) * len(data_shape)
            except AttributeError:
                # The input has an unknown number of dimensions.
                data_shape = None
                data_shape_nones = None

            # TODO (b/159261555): move this to base layer build.
            batch_input_shape = getattr(self, "_batch_input_shape", None)
            if batch_input_shape is None:
                # Set the number of dimensions.
                self._batch_input_shape = data_shape_nones
            self.build(data_shape)
            self.built = True
def _disallow_inside_tf_function(method_name):
    """Disallow calling a method inside a `tf.function`.

    Arguments:
      method_name: Name of the `PreprocessingLayer` method being called;
        used only to build the error message.

    Raises:
      RuntimeError: If `tf.inside_function()` returns `True`.
    """
    if tf.inside_function():
        error_msg = (
            "Detected a call to `PreprocessingLayer.{method_name}` inside a "
            "`tf.function`. `PreprocessingLayer.{method_name}` is a high-level "
            "endpoint that manages its own `tf.function`. Please move the call "
            "to `PreprocessingLayer.{method_name}` outside of all enclosing "
            "`tf.function`s. Note that you can call a `PreprocessingLayer` "
            "directly on `Tensor`s inside a `tf.function` like: `layer(x)`, "
            "or update its state like: `layer.update_state(x)`."
        ).format(method_name=method_name)
        raise RuntimeError(error_msg)