Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_utils.py: 16%
82 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related utilities."""
17import numpy as np
18import tensorflow.compat.v2 as tf
20from keras.src.utils import generic_utils
def slice_arrays(arrays, indices, contiguous=True):
    """Slices batches out of provided arrays (workaround for eager tensors).

    Eager tensors do not support Numpy-style fancy indexing (they follow
    symbolic TF tensor slicing semantics), so `generic_utils.slice_arrays`
    cannot be used directly on them. For tensors we instead slice
    contiguously, or gather-and-concat for non-contiguous indices, which
    carries a performance cost.

    Args:
      arrays: Single array or list of arrays.
      indices: List of indices in the array that should be included in the
        output batch.
      contiguous: Boolean flag indicating whether the indices are contiguous.

    Returns:
      Slice of data (either single array or list of arrays).
    """
    was_single = not isinstance(arrays, list)
    if was_single:
        arrays = [arrays]

    if any(tf.is_tensor(a) for a in arrays):
        if contiguous:
            # Contiguous indices: one cheap slice per array.
            start, stop = indices[0], indices[-1] + 1
            slices = [a[start:stop] for a in arrays]
        else:
            # Non-contiguous: gather unit slices and concatenate them.
            slices = [
                tf.concat([a[i : i + 1] for i in indices], axis=0)
                for a in arrays
            ]
    else:
        slices = generic_utils.slice_arrays(arrays, indices)

    return slices[0] if was_single else slices
def handle_partial_sample_weights(
    outputs, sample_weights, sample_weight_modes, check_all_flat=False
):
    """Adds 1.0 as sample weights for the outputs for which there is no weight.

    Args:
      outputs: List of model outputs.
      sample_weights: List of sample weight inputs.
      sample_weight_modes: List of sample weight modes or None.
      check_all_flat: Ensure that inputs are not nested structures. This is
        not a free check, so we may not want to run it eagerly every
        iteration.

    Returns:
      Tuple of sample weights, one sample weight for every output, and
      booleans describing the raw sample weights.
    """
    if isinstance(sample_weights, (list, tuple)):
        any_sample_weight = sample_weights is not None and any(
            w is not None for w in sample_weights
        )
        partial_sample_weight = any_sample_weight and any(
            w is None for w in sample_weights
        )
    else:
        any_sample_weight = sample_weights is not None
        # NOTE(review): for a single (non-list) weight this expression can
        # never be True — kept as-is to preserve behavior.
        partial_sample_weight = any_sample_weight and sample_weights is None

    if not any_sample_weight:
        return None, any_sample_weight, partial_sample_weight

    if not partial_sample_weight:
        return sample_weights, any_sample_weight, partial_sample_weight

    if check_all_flat:
        # Reject nested structures up front; flattening must be a no-op.
        tf.nest.assert_same_structure(
            list_to_tuple(sample_weights),
            list_to_tuple(tf.nest.flatten(sample_weights)),
        )
        tf.nest.assert_same_structure(
            list_to_tuple(outputs), list_to_tuple(tf.nest.flatten(outputs))
        )
        if sample_weight_modes is not None:
            tf.nest.assert_same_structure(
                sample_weight_modes, tf.nest.flatten(sample_weight_modes)
            )

    new_sample_weights = []
    for idx, weight in enumerate(sample_weights):
        if weight is not None:
            new_sample_weights.append(weight)
            continue

        # Missing weight: synthesize all-ones with the shape the output
        # implies, staying in numpy for numpy outputs and TF otherwise.
        output = outputs[idx]
        use_numpy = isinstance(output, np.ndarray)
        output_shape = output.shape if use_numpy else tf.shape(output)

        is_temporal = (
            sample_weight_modes is not None
            and sample_weight_modes[idx] == "temporal"
        )
        ones_shape = (
            (output_shape[0], output_shape[1])
            if is_temporal
            else (output_shape[0],)
        )
        new_sample_weights.append(
            np.ones(ones_shape) if use_numpy else tf.ones(ones_shape)
        )

    return (
        list_to_tuple(new_sample_weights),
        any_sample_weight,
        partial_sample_weight,
    )
class RespectCompiledTrainableState:
    """Set and restore trainable state if it has changed since compile.

    The keras API guarantees that the value of each Layer's `trainable`
    property at `Model.compile` time will be used when training that model.
    To honor that, this scope sets layers back to their compile-time
    trainable values on entry and restores the current values on exit — but
    only when at least one layer's trainable state actually differs from its
    compiled value, since walking every layer of a model is expensive and
    trainable state rarely changes.
    """

    def __init__(self, model):
        self._model = model
        self._current_trainable_state = None
        self._compiled_trainable_state = None
        self._should_set_trainable = False

    def __enter__(self):
        self._current_trainable_state = self._model._get_trainable_state()
        self._compiled_trainable_state = self._model._compiled_trainable_state

        # Detect drift from the compiled state. A layer absent from the
        # current state, or one whose value still matches, does not count.
        current = self._current_trainable_state
        self._should_set_trainable = any(
            current.get(layer, compiled_value) != compiled_value
            for layer, compiled_value in self._compiled_trainable_state.items()
        )

        # Only pay the cost of rewriting trainable values when needed.
        if self._should_set_trainable:
            self._model._set_trainable_state(self._compiled_trainable_state)

    def __exit__(self, type_arg, value_arg, traceback_arg):
        # Undo the __enter__ override so callers see their own values again.
        if self._should_set_trainable:
            self._model._set_trainable_state(self._current_trainable_state)
        return False  # never suppress exceptions
184# Allow use of methods not exposed to the user.
def get_input_shape_and_dtype(layer):
    """Retrieves input shape and input dtype of layer if applicable.

    Args:
      layer: Layer (or model) instance.

    Returns:
      Tuple (input_shape, input_dtype). Both could be None if the layer
      does not have a defined input shape.

    Raises:
      ValueError: in case an empty Sequential or Functional model is passed.
    """

    def _is_graph_model(candidate):
        # Sequential models and built Functional models both expose layers.
        if candidate.__class__.__name__ == "Sequential":
            return True
        return bool(getattr(candidate, "_is_graph_network", False))

    # For nested models, drill down to the first layer of the deepest model
    # to infer input shape and dtype. Subclassed Models may not have been
    # built so can't be checked.
    while _is_graph_model(layer):
        sub_layers = layer.layers
        if not sub_layers:
            raise ValueError("An empty Model cannot be used as a Layer.")
        layer = sub_layers[0]

    batch_shape = getattr(layer, "_batch_input_shape", None)
    if batch_shape:
        return batch_shape, layer.dtype
    return None, None
def get_static_batch_size(layer):
    """Gets the static batch size of a Layer.

    Args:
      layer: a `Layer` instance.

    Returns:
      The static batch size of a Layer, or None when the layer has no
      defined input shape.
    """
    batch_input_shape, _ = get_input_shape_and_dtype(layer)
    if batch_input_shape is None:
        return None
    # Dimension(...).value normalizes an unknown leading dim to None.
    return tf.compat.v1.Dimension(batch_input_shape[0]).value
def list_to_tuple(maybe_list):
    """Datasets will stack the list of tensor, so switch them to tuples."""
    return tuple(maybe_list) if isinstance(maybe_list, list) else maybe_list