# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Contains the base ProcessingLayer and a subclass that uses Combiners."""
import abc
import collections

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer')
class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
42 """Base class for Preprocessing Layers.
44 **Don't use this class directly: it's an abstract base class!** You may
45 be looking for one of the many built-in
46 [preprocessing layers](https://keras.io/guides/preprocessing_layers/)
47 instead.
49 Preprocessing layers are layers whose state gets computed before model
50 training starts. They do not get updated during training.
51 Most preprocessing layers implement an `adapt()` method for state computation.
53 The `PreprocessingLayer` class is the base class you would subclass to
54 implement your own preprocessing layers.
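
  For example, a minimal subclass that rescales inputs by the largest absolute
  value seen during `adapt` might look roughly like this (an illustrative
  sketch, not part of this module):

    class MaxAbsScaling(PreprocessingLayer):

      def build(self, input_shape):
        # Non-trainable weight that holds the adapted state.
        self.max_abs = self.add_weight(
            name='max_abs', shape=(), initializer='ones', trainable=False)
        super(MaxAbsScaling, self).build(input_shape)

      def update_state(self, data):
        # Accumulate the largest absolute value seen so far.
        batch_max = tf.reduce_max(tf.abs(tf.cast(data, self.max_abs.dtype)))
        self.max_abs.assign(tf.maximum(self.max_abs, batch_max))

      def reset_state(self):
        self.max_abs.assign(1.)

      def call(self, inputs):
        return tf.cast(inputs, self.max_abs.dtype) / self.max_abs
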
  Attributes:
    streaming: Whether a layer can be adapted multiple times without resetting
      the state of the layer.
  """
  _must_restore_from_config = True

  def __init__(self, streaming=True, **kwargs):
    super(PreprocessingLayer, self).__init__(**kwargs)
    self._streaming = streaming
    self._is_compiled = False
    self._is_adapted = False

    # Sets `is_adapted=False` when `reset_state` is called.
    self._reset_state_impl = self.reset_state
    self.reset_state = self._reset_state_wrapper

    self._adapt_function = None

  @property
  def streaming(self):
    """Whether `adapt` can be called twice without resetting the state."""
    return self._streaming

  @property
  def is_adapted(self):
    """Whether the layer has been fit to data already."""
    return self._is_adapted

  def update_state(self, data):
    """Accumulates statistics for the preprocessing layer.

    Arguments:
      data: A mini-batch of inputs to the layer.
    """
    raise NotImplementedError

  def reset_state(self):  # pylint: disable=method-hidden
    """Resets the statistics of the preprocessing layer."""
    raise NotImplementedError

  def merge_state(self, layers):
    """Merge the statistics of multiple preprocessing layers.

    This layer will contain the merged state.
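
    For example (an illustrative sketch, assuming a concrete subclass such as
    `Normalization` that implements the required state methods):

      layer_1 = tf.keras.layers.experimental.preprocessing.Normalization(
          axis=None)
      layer_1.adapt([0, 2])
      layer_2 = tf.keras.layers.experimental.preprocessing.Normalization(
          axis=None)
      layer_2.adapt([1, 3])
      layer_1.merge_state([layer_2])
      # `layer_1` now holds statistics computed over all four samples.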

    Arguments:
      layers: Layers whose statistics should be merged with the statistics of
        this layer.
    """
    raise NotImplementedError

  def finalize_state(self):
    """Finalize the statistics for the preprocessing layer.

    This method is called at the end of `adapt` or after restoring a serialized
    preprocessing layer's state. This method handles any one-time operations
    that should occur on the layer's state before `Layer.__call__`.
    """
    pass

  def make_adapt_function(self):
    """Creates a function to execute one step of `adapt`.

    This method can be overridden to support custom adapt logic.
    This method is called by `PreprocessingLayer.adapt`.

    Typically, this method directly controls `tf.function` settings,
    and delegates the actual state update logic to
    `PreprocessingLayer.update_state`.

    This function is cached the first time `PreprocessingLayer.adapt`
    is called. The cache is cleared whenever `PreprocessingLayer.compile`
    is called.
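
    For instance, an override that always wraps a single update step in a
    `tf.function` might look roughly like this (an illustrative sketch that
    mirrors the default implementation below):

      def make_adapt_function(self):
        if self._adapt_function is None:

          def adapt_step(iterator):
            data = next(iterator)
            self._adapt_maybe_build(data)
            self.update_state(data)

          self._adapt_function = def_function.function(adapt_step)
        return self._adapt_function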

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, retrieve a batch, and update the state of the
      layer.
    """
    if self._adapt_function is not None:
      return self._adapt_function

    def adapt_step(iterator):
      data = next(iterator)
      self._adapt_maybe_build(data)
      self.update_state(data)

    if self._steps_per_execution.numpy().item() == 1:
      adapt_fn = adapt_step
    else:

      def adapt_fn(iterator):
        for _ in math_ops.range(self._steps_per_execution):
          adapt_step(iterator)

    if not self._run_eagerly:
      adapt_fn = def_function.function(adapt_fn)

    self._adapt_function = adapt_fn
    return self._adapt_function

  def compile(self, run_eagerly=None, steps_per_execution=None):
    """Configures the layer for `adapt`.

    Arguments:
      run_eagerly: Bool. Defaults to `False`. If `True`, this layer's `adapt`
        logic will not be wrapped in a `tf.function`. Recommended to leave
        this as `None` unless your layer cannot be run inside a `tf.function`.
      steps_per_execution: Int. Defaults to 1. The number of batches to run
        during each `tf.function` call. Running multiple batches inside a
        single `tf.function` call can greatly improve performance on TPUs or
        small models with a large Python overhead.
    """
    if steps_per_execution is None:
      steps_per_execution = 1
    self._configure_steps_per_execution(steps_per_execution)

    if run_eagerly is None:
      run_eagerly = self.dynamic
    self._run_eagerly = run_eagerly

    self._is_compiled = True

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    """Fits the state of the preprocessing layer to the data being passed.

    After calling `adapt` on a layer, a preprocessing layer's state will not
    update during training. In order to make preprocessing layers efficient in
    any distribution context, they are kept constant with respect to any
    compiled `tf.Graph`s that call the layer. This does not affect the layer's
    use when adapting each layer only once, but if you adapt a layer multiple
    times you will need to take care to re-compile any compiled functions as
    follows:

    * If you are adding a preprocessing layer to a `keras.Model`, you need to
      call `model.compile` after each subsequent call to `adapt`.
    * If you are calling a preprocessing layer inside `tf.data.Dataset.map`,
      you should call `map` again on the input `tf.data.Dataset` after each
      `adapt`.
    * If you are using a `tf.function` directly which calls a preprocessing
      layer, you need to call `tf.function` again on your callable after
      each subsequent call to `adapt`.

    `tf.keras.Model` example with multiple adapts:

    >>> layer = tf.keras.layers.experimental.preprocessing.Normalization(
    ...     axis=None)
    >>> layer.adapt([0, 2])
    >>> model = tf.keras.Sequential(layer)
    >>> model.predict([0, 1, 2])
    array([-1.,  0.,  1.], dtype=float32)
    >>> layer.adapt([-1, 1])
    >>> model.compile()  # This is needed to re-compile model.predict!
    >>> model.predict([0, 1, 2])
    array([0., 1., 2.], dtype=float32)

    `tf.data.Dataset` example with multiple adapts:

    >>> layer = tf.keras.layers.experimental.preprocessing.Normalization(
    ...     axis=None)
    >>> layer.adapt([0, 2])
    >>> input_ds = tf.data.Dataset.range(3)
    >>> normalized_ds = input_ds.map(layer)
    >>> list(normalized_ds.as_numpy_iterator())
    [array([-1.], dtype=float32),
     array([0.], dtype=float32),
     array([1.], dtype=float32)]
    >>> layer.adapt([-1, 1])
    >>> normalized_ds = input_ds.map(layer)  # Re-map over the input dataset.
    >>> list(normalized_ds.as_numpy_iterator())
    [array([0.], dtype=float32),
     array([1.], dtype=float32),
     array([2.], dtype=float32)]

    Arguments:
      data: The data to train on. It can be passed either as a `tf.data`
        Dataset, or as a numpy array.
      batch_size: Integer or `None`.
        Number of samples per state update.
        If unspecified, `batch_size` will default to 32.
        Do not specify the `batch_size` if your data is in the
        form of datasets, generators, or `keras.utils.Sequence` instances
        (since they generate batches).
      steps: Integer or `None`.
        Total number of steps (batches of samples).
        When training with input tensors such as
        TensorFlow data tensors, the default `None` is equal to
        the number of samples in your dataset divided by
        the batch size, or 1 if that cannot be determined. If `data` is a
        `tf.data` dataset, and `steps` is `None`, the epoch will run until
        the input dataset is exhausted. When passing an infinitely
        repeating dataset, you must specify the `steps` argument. This
        argument is not supported with array inputs.
      reset_state: Optional argument specifying whether to clear the state of
        the layer at the start of the call to `adapt`, or whether to start
        from the existing state. This argument may not be relevant to all
        preprocessing layers: a subclass of `PreprocessingLayer` may choose to
        raise an error if `reset_state` is set to `False`.
    """
    _disallow_inside_tf_function('adapt')
    if not version_utils.should_use_v2():
      raise RuntimeError('`adapt` is only supported in tensorflow v2.')  # pylint: disable=g-doc-exception
    if not self.streaming and self._is_adapted and not reset_state:
      raise ValueError('{} does not support calling `adapt` twice without '
                       'resetting the state.'.format(self.__class__.__name__))
    if not self._is_compiled:
      self.compile()  # Compile with defaults.
    if self.built and reset_state:
      self.reset_state()
    data_handler = data_adapter.DataHandler(
        data,
        batch_size=batch_size,
        steps_per_epoch=steps,
        epochs=1,
        steps_per_execution=self._steps_per_execution,
        distribute=False)
    self._adapt_function = self.make_adapt_function()
    for _, iterator in data_handler.enumerate_epochs():
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          self._adapt_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
    self.finalize_state()
    self._is_adapted = True

  def _reset_state_wrapper(self):
    """Calls `reset_state` and sets `adapted` to `False`."""
    self._reset_state_impl()
    self._is_adapted = False

  @trackable.no_automatic_dependency_tracking
  def _configure_steps_per_execution(self, steps_per_execution):
    self._steps_per_execution = variables.Variable(
        steps_per_execution,
        dtype='int64',
        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)

  # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
  def _adapt_maybe_build(self, data):
    if not self.built:
      try:
        # If this is a Numpy array or tensor, we can get shape from .shape.
        # If not, an attribute error will be thrown.
        data_shape = data.shape
        data_shape_nones = tuple([None] * len(data.shape))
      except AttributeError:
        # The input has an unknown number of dimensions.
        data_shape = None
        data_shape_nones = None

      # TODO (b/159261555): move this to base layer build.
      batch_input_shape = getattr(self, '_batch_input_shape', None)
      if batch_input_shape is None:
        # Set the number of dimensions.
        self._batch_input_shape = data_shape_nones
      self.build(data_shape)
      self.built = True


# TODO(omalleyt): This class will be gradually replaced.
class CombinerPreprocessingLayer(PreprocessingLayer):
  """Base class for PreprocessingLayers that do computation using a Combiner.

  This class provides several helper methods to make creating a
  PreprocessingLayer easier. It assumes that the core of your computation will
  be done via a Combiner object. Subclassing this class to create a
  PreprocessingLayer allows your layer to be compatible with distributed
  computation.

  This class is compatible with TensorFlow 2.0+.
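
  An illustrative sketch of a subclass (assuming `SumCombiner` is a `Combiner`
  like the one sketched in the `Combiner` docstring below, which tracks a
  running sum under the state key `'sum'`):

    class SumLayer(CombinerPreprocessingLayer):

      def __init__(self, **kwargs):
        super(SumLayer, self).__init__(combiner=SumCombiner(), **kwargs)

      def build(self, input_shape):
        # State variables are filled in from the combiner's extracted output.
        self.sum = self._add_state_variable(
            name='sum', shape=(), dtype='float32', initializer='zeros')
        super(SumLayer, self).build(input_shape)

      def call(self, inputs):
        return inputs + self.sum
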
326 """
328 def __init__(self, combiner, **kwargs):
329 super(CombinerPreprocessingLayer, self).__init__(**kwargs)
330 self.state_variables = collections.OrderedDict()
331 self._combiner = combiner
332 self._adapt_accumulator = None
334 def reset_state(self): # pylint: disable=method-hidden
335 self._adapt_accumulator = None
337 @trackable.no_automatic_dependency_tracking
338 def update_state(self, data):
339 if self._adapt_accumulator is None:
340 self._adapt_accumulator = self._get_accumulator()
341 self._adapt_accumulator = self._combiner.compute(data,
342 self._adapt_accumulator)
344 def merge_state(self, layers):
345 accumulators = ([self._get_accumulator()] +
346 [l._get_accumulator() for l in layers]) # pylint: disable=protected-access
347 merged_accumulator = self._combiner.merge(accumulators)
348 self._set_accumulator(merged_accumulator)
350 def finalize_state(self):
351 if self._adapt_accumulator is not None:
352 self._set_accumulator(self._adapt_accumulator)
354 def compile(self, run_eagerly=None, steps_per_execution=None):
355 # TODO(omalleyt): Remove this once sublayers are switched to new APIs.
356 if run_eagerly is None:
357 run_eagerly = True
358 super(CombinerPreprocessingLayer, self).compile(
359 run_eagerly=run_eagerly, steps_per_execution=steps_per_execution)
361 def adapt(self, data, batch_size=None, steps=None, reset_state=True):
362 if not reset_state:
363 self._adapt_accumulator = self._combiner.restore(self._restore_updates())
364 super(CombinerPreprocessingLayer, self).adapt(
365 data, batch_size=batch_size, steps=steps, reset_state=reset_state)
367 def _add_state_variable(self,
368 name,
369 shape,
370 dtype,
371 initializer=None,
372 partitioner=None,
373 use_resource=None,
374 **kwargs):
375 """Add a variable that can hold state which is updated during adapt().
377 Args:
378 name: Variable name.
379 shape: Variable shape. Defaults to scalar if unspecified.
380 dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
381 initializer: initializer instance (callable).
382 partitioner: Partitioner to be passed to the `Trackable` API.
383 use_resource: Whether to use `ResourceVariable`
384 **kwargs: Additional keyword arguments. Accepted values are `getter` and
385 `collections`.
387 Returns:
388 The created variable.
389 """
    weight = self.add_weight(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=None,
        trainable=False,
        constraint=None,
        partitioner=partitioner,
        use_resource=use_resource,
        **kwargs)
    # TODO(momernick): Do not allow collisions here.
    self.state_variables[name] = weight
    return weight

  def _restore_updates(self):
    """Recreates a dict of updates from the layer's weights."""
    data_dict = {}
    for name, var in self.state_variables.items():
      data_dict[name] = var.numpy()
    return data_dict

  def _get_accumulator(self):
    if self._is_adapted:
      return self._combiner.restore(self._restore_updates())
    else:
      return None

  def _set_accumulator(self, accumulator):
    updates = self._combiner.extract(accumulator)
    self._set_state_variables(updates)
    self._adapt_accumulator = None  # Reset accumulator from adapt.

  def _set_state_variables(self, updates):
    """Directly update the internal state of this Layer.

    This method expects a string-keyed dict of {state_variable_name: state}.
    The precise nature of the state, and the names associated with it, are
    described by the subclasses of CombinerPreprocessingLayer.

    Args:
      updates: A string keyed dict of weights to update.

    Raises:
      RuntimeError: if `build()` was not called before `_set_state_variables`.
    """
    # TODO(momernick): Do we need to do any more input sanitization?
    if not self.built:
      raise RuntimeError('_set_state_variables() must be called after build().')

    with ops.init_scope():
      for var_name, value in updates.items():
        self.state_variables[var_name].assign(value)


def convert_to_list(values, sparse_default_value=None):
  """Convert a TensorLike, CompositeTensor, or ndarray into a Python list."""
  if tf_utils.is_ragged(values):
    # There is a corner case when dealing with ragged tensors: if you get an
    # actual RaggedTensor (not a RaggedTensorValue) passed in non-eager mode,
    # you can't call to_list() on it without evaluating it first. However,
    # because we don't yet fully support composite tensors across Keras,
    # backend.get_value() won't evaluate the tensor.
    # TODO(momernick): Get Keras to recognize composite tensors as Tensors
    # and then replace this with a call to backend.get_value.
    if (isinstance(values, ragged_tensor.RaggedTensor) and
        not context.executing_eagerly()):
      values = backend.get_session(values).run(values)
    values = values.to_list()

  if isinstance(values,
                (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
    if sparse_default_value is None:
      if dtypes.as_dtype(values.values.dtype) == dtypes.string:
        sparse_default_value = ''
      else:
        sparse_default_value = -1
    dense_tensor = sparse_ops.sparse_tensor_to_dense(
        values, default_value=sparse_default_value)
    values = backend.get_value(dense_tensor)

  if isinstance(values, ops.Tensor):
    values = backend.get_value(values)

  # We may get passed an ndarray or the code above may give us an ndarray.
  # In either case, we want to force it into a standard python list.
  if isinstance(values, np.ndarray):
    values = values.tolist()

  return values


# TODO(omalleyt): This class will be gradually replaced.
class Combiner(object):
  """Functional object that defines a shardable computation.

  This object defines functions required to create and manipulate data objects.
  These data objects, referred to below as 'accumulators', are computation-
  specific and may be implemented alongside concrete subclasses of Combiner
  (if necessary - some computations may be simple enough that standard Python
  types can be used as accumulators).

  The intent for this class is that by describing computations in this way, we
  can arbitrarily shard a dataset, perform computations on a subset, and then
  merge the computation into a final result. This enables distributed
  computation.

  The combiner itself does not own any state - all computational state is owned
  by the accumulator objects. This is so that we can have an arbitrary number of
  Combiners (thus sharding the computation N ways) without risking any change
  to the underlying computation. These accumulator objects are uniquely
  associated with each Combiner; a Combiner defines what the accumulator object
  should be and will only work with accumulators of that type.
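
  For example, a minimal combiner that tracks a running sum might look roughly
  like this (an illustrative sketch, not part of this module):

    class SumCombiner(Combiner):

      def compute(self, batch_values, accumulator=None):
        # Fold a batch into the accumulator (a plain Python/NumPy scalar).
        batch_sum = np.sum(batch_values)
        return batch_sum if accumulator is None else accumulator + batch_sum

      def merge(self, accumulators):
        return sum(accumulators)

      def extract(self, accumulator):
        # Keys here must match the layer's state variable names.
        return {'sum': np.array(accumulator)}

      def restore(self, output):
        return output['sum']

      def serialize(self, accumulator):
        return np.array(accumulator, dtype='float64').tobytes()

      def deserialize(self, encoded_accumulator):
        return np.frombuffer(encoded_accumulator, dtype='float64')[0]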
503 """
504 __metaclass__ = abc.ABCMeta
506 def __repr__(self):
507 return '<{}>'.format(self.__class__.__name__)
509 @abc.abstractmethod
  def compute(self, batch_values, accumulator=None):
    """Compute a step in this computation, returning a new accumulator.

    This method computes a step of the computation described by this Combiner.
    If an accumulator is passed, the data in that accumulator is also used; so
    compute(batch_values) results in f(batch_values), while
    compute(batch_values, accumulator) results in
    merge([f(batch_values), accumulator]).

    Args:
      batch_values: A list of ndarrays representing the values of the inputs
        for this step of the computation.
      accumulator: the current accumulator. Can be None.

    Returns:
      An accumulator that includes the passed batch of inputs.
    """
    pass

  @abc.abstractmethod
  def merge(self, accumulators):
    """Merge several accumulators to a single accumulator.

    This method takes the partial values in several accumulators and combines
    them into a single accumulator. This computation must not be order-specific
    (that is, merge([a, b]) must return the same result as merge([b, a])).

    Args:
      accumulators: the accumulators to merge, as a list.

    Returns:
      A merged accumulator.
    """
    pass

  @abc.abstractmethod
  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values.

    Args:
      accumulator: The accumulator to convert.

    Returns:
      A dict of ndarrays representing the data in this accumulator.
    """
    pass

  @abc.abstractmethod
  def restore(self, output):
    """Create an accumulator based on 'output'.

    This method creates a new accumulator with identical internal state to the
    one used to create the data in 'output'. This means that if you do

      output_data = combiner.extract(accumulator_1)
      accumulator_2 = combiner.restore(output_data)

    then accumulator_1 and accumulator_2 will have identical internal state,
    and computations using either of them will be equivalent.

    Args:
      output: The data output from a previous computation. Should be in the
        same form as provided by 'extract'.

    Returns:
      A new accumulator.
    """
    pass

  @abc.abstractmethod
  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call.

    This function serializes an accumulator to be sent to a remote process.

    Args:
      accumulator: The accumulator to serialize.

    Returns:
      A byte string representing the passed accumulator.
    """
    pass

  @abc.abstractmethod
  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'.

    This function deserializes an accumulator serialized by 'serialize()'.

    Args:
      encoded_accumulator: A byte string representing an accumulator.

    Returns:
      The accumulator represented by the passed byte string.
    """
    pass


def _disallow_inside_tf_function(method_name):
  """Disallow calling a method inside a `tf.function`."""
  if ops.inside_function():
    error_msg = (
        'Detected a call to `PreprocessingLayer.{method_name}` inside a '
        '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '
        'endpoint that manages its own `tf.function`. Please move the call '
        'to `PreprocessingLayer.{method_name}` outside of all enclosing '
        '`tf.function`s. Note that you can call a `PreprocessingLayer` '
        'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, '
        'or update its state like: `layer.update_state(x)`.').format(
            method_name=method_name)
    raise RuntimeError(error_msg)