Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/core/lambda_layer.py: 20%
147 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the Lambda layer."""

import sys
import textwrap
import types as python_types
import warnings

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src.engine.base_layer import Layer
from keras.src.saving import serialization_lib
from keras.src.utils import generic_utils
from keras.src.utils import tf_inspect
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.Lambda")
class Lambda(Layer):
38 """Wraps arbitrary expressions as a `Layer` object.
40 The `Lambda` layer exists so that arbitrary expressions can be used
41 as a `Layer` when constructing Sequential
42 and Functional API models. `Lambda` layers are best suited for simple
43 operations or quick experimentation. For more advanced use cases, follow
44 [this guide](
45 https://www.tensorflow.org/guide/keras/custom_layers_and_models)
46 for subclassing `tf.keras.layers.Layer`.
48 WARNING: `tf.keras.layers.Lambda` layers have (de)serialization limitations!
50 The main reason to subclass `tf.keras.layers.Layer` instead of using a
51 `Lambda` layer is saving and inspecting a Model. `Lambda` layers
52 are saved by serializing the Python bytecode, which is fundamentally
53 non-portable. They should only be loaded in the same environment where
54 they were saved. Subclassed layers can be saved in a more portable way
55 by overriding their `get_config()` method. Models that rely on
56 subclassed Layers are also often easier to visualize and reason about.

    Examples:

    ```python
    # add an x -> x^2 layer
    model.add(Lambda(lambda x: x ** 2))
    ```

    ```python
    # add a layer that returns the concatenation
    # of the positive part of the input and
    # the opposite of the negative part

    def antirectifier(x):
        x -= K.mean(x, axis=1, keepdims=True)
        x = K.l2_normalize(x, axis=1)
        pos = K.relu(x)
        neg = K.relu(-x)
        return K.concatenate([pos, neg], axis=1)

    model.add(Lambda(antirectifier))
    ```

    **Note on Variables:**

    While it is possible to use Variables with Lambda layers,
    this practice is discouraged as it can easily lead to bugs.
    For instance, consider the following layer:

    ```python
    scale = tf.Variable(1.)
    scale_layer = tf.keras.layers.Lambda(lambda x: x * scale)
    ```

    Because `scale_layer` does not directly track the `scale` variable, it
    will not appear in `scale_layer.trainable_weights` and will therefore not
    be trained if `scale_layer` is used in a Model.

    A better pattern is to write a subclassed Layer:

    ```python
    class ScaleLayer(tf.keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.scale = tf.Variable(1.)

        def call(self, inputs):
            return inputs * self.scale
    ```

    In general, `Lambda` layers can be convenient for simple stateless
    computation, but anything more complex should use a subclassed `Layer`
    instead.

    Args:
      function: The function to be evaluated. Takes the input tensor as its
        first argument.
      output_shape: Expected output shape from the function. This argument can
        be inferred if not explicitly provided. Can be a tuple or a function.
        If a tuple, it only specifies the shape from the first dimension
        onward; the sample (batch) dimension is assumed to be either the same
        as the input's, `output_shape = (input_shape[0],) + output_shape`, or,
        if the input batch dimension is `None`, it is also `None`,
        `output_shape = (None,) + output_shape`. If a function, it specifies
        the entire output shape as a function of the input shape,
        `output_shape = f(input_shape)`. See the sketch below the argument
        list for both forms.
      mask: Either None (indicating no masking), a callable with the same
        signature as the `compute_mask` layer method, or a tensor that will be
        returned as the output mask regardless of what the input is.
      arguments: Optional dictionary of keyword arguments to be passed to the
        function.

    Input shape: Arbitrary. Use the keyword argument input_shape (tuple of
      integers, does not include the samples axis) when using this layer as
      the first layer in a model.

    Output shape: Specified by the `output_shape` argument.
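
    For illustration only (this snippet is not one of the library's own
    examples; `model` is assumed to be an existing Sequential model),
    `output_shape` can be given either as a tuple that omits the batch
    dimension or as a callable that maps the full input shape to the full
    output shape:

    ```python
    # Tuple form: the batch dimension is filled in automatically.
    model.add(Lambda(lambda x: x[:, :3], output_shape=(3,)))

    # Callable form: receives and returns the complete shape, batch included.
    model.add(
        Lambda(
            lambda x: x[:, :3],
            output_shape=lambda input_shape: (input_shape[0], 3),
        )
    )
    ```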
133 """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(
        self, function, output_shape=None, mask=None, arguments=None, **kwargs
    ):
        super().__init__(**kwargs)

        self.arguments = arguments or {}
        self.function = function

        if mask is not None:
            self.supports_masking = True
        self.mask = mask
        self._output_shape = output_shape

        # Warning on every invocation will be quite irksome in Eager mode.
        self._already_warned = False

        function_args = tf_inspect.getfullargspec(function).args
        self._fn_expects_training_arg = "training" in function_args
        self._fn_expects_mask_arg = "mask" in function_args
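
    # Illustrative note (not in the original source): because `__init__`
    # inspects the wrapped function's argument names, a function defined as
    # `def fn(x, training=None, mask=None): ...` automatically receives the
    # `training` flag and the input mask when the layer is called, while a
    # plain `lambda x: ...` receives only the inputs.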

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        if self._output_shape is None:
            # Make use of existing autocomputation but provide Lambda-specific
            # error message. This is always safe to run even when the outer
            # context is Graph mode because Lambda layers don't have side
            # effects such as `add_loss`.
            with tf.__internal__.eager_context.eager_mode():
                try:
                    return super().compute_output_shape(input_shape)
                except NotImplementedError:
                    raise NotImplementedError(
                        "We could not automatically infer the shape of "
                        "the Lambda's output. Please specify `output_shape` "
                        "for this Lambda."
                    )

        if callable(self._output_shape):
            output_shapes = self._output_shape(input_shape)
            return tf_utils.convert_shapes(output_shapes, to_tuples=False)

        # Output shapes are passed directly and don't include batch dimension.
        input_tensor_shape = tf_utils.convert_shapes(
            input_shape, to_tuples=False
        )
        batch_size = (
            tf.nest.flatten(input_tensor_shape)[0][0] if input_shape else None
        )

        def _add_batch(shape):
            return tf.TensorShape([batch_size] + shape.as_list())

        output_shapes = tf_utils.convert_shapes(
            self._output_shape, to_tuples=False
        )
        return tf.nest.map_structure(_add_batch, output_shapes)
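
    # Worked example (illustrative, not from the original source): with
    # `output_shape=(3,)` and an input of shape `(None, 5)`, the code above
    # prepends the input's batch dimension and returns `TensorShape([None, 3])`.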

    def call(self, inputs, mask=None, training=None):
        # We must copy for thread safety, but it only needs to be a shallow
        # copy.
        kwargs = {k: v for k, v in self.arguments.items()}
        if self._fn_expects_mask_arg:
            kwargs["mask"] = mask
        if self._fn_expects_training_arg:
            kwargs["training"] = training

        created_variables = []

        def _variable_creator(next_creator, **kwargs):
            var = next_creator(**kwargs)
            created_variables.append(var)
            return var

        with tf.GradientTape(
            watch_accessed_variables=True
        ) as tape, tf.variable_creator_scope(_variable_creator):
            result = self.function(inputs, **kwargs)
        self._check_variables(created_variables, tape.watched_variables())
        return result
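
    # Illustrative note (not in the original source): the variable creator
    # scope records every variable *created* inside the wrapped function,
    # while the gradient tape records every variable it *reads*. For example,
    # `Lambda(lambda x: x * tf.Variable(2.0))(inputs)` creates an untracked
    # variable and is rejected by `_check_variables` below.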

    def _check_variables(self, created_variables, accessed_variables):
        if not created_variables and not accessed_variables:
            # In the common case that a Lambda layer does not touch a Variable,
            # we don't want to incur the runtime cost of assembling any state
            # used for checking only to immediately discard it.
            return

        # Filter out the state variable in the tf.random.Generator, which is
        # commonly used for initializers or dropout. The variable is
        # intentionally not tracked and it is not a trainable variable.
        created_variables = [
            v for v in created_variables if "StateVar" not in v.name
        ]

        tracked_weights = set(v.ref() for v in self.weights)
        untracked_new_vars = [
            v for v in created_variables if v.ref() not in tracked_weights
        ]
        if untracked_new_vars:
            variable_str = "\n".join(f"  {i}" for i in untracked_new_vars)
            error_str = textwrap.dedent(
                """
                The following Variables were created within a Lambda layer ({name})
                but are not tracked by said layer:
                {variable_str}
                The layer cannot safely ensure proper Variable reuse across multiple
                calls, and consequently this behavior is disallowed for safety. Lambda
                layers are not well suited to stateful computation; instead, writing a
                subclassed Layer is the recommended way to define layers with
                Variables."""
            ).format(name=self.name, variable_str=variable_str)
            raise ValueError(error_str)

        untracked_used_vars = [
            v for v in accessed_variables if v.ref() not in tracked_weights
        ]
        if untracked_used_vars and not self._already_warned:
            variable_str = "\n".join(f"  {i}" for i in untracked_used_vars)
            self._warn(
                textwrap.dedent(
                    """
                    The following Variables were used in a Lambda layer's call ({name}),
                    but are not present in its tracked objects:
                    {variable_str}
                    It is possible that this is intended behavior, but it is more likely
                    an omission. This is a strong indication that this layer should be
                    formulated as a subclassed Layer rather than a Lambda layer."""
                ).format(name=self.name, variable_str=variable_str)
            )
            self._already_warned = True

    def _warn(self, msg):
        # This method will be overridden in a unit test to raise an error,
        # because self.assertWarns is not universally implemented.
        return tf_logging.warning(msg)

    def compute_mask(self, inputs, mask=None):
        if callable(self.mask):
            return self.mask(inputs, mask)
        return self.mask

    def get_config(self):
        function_config = self._serialize_function_to_config(self.function)
        output_shape_config = self._serialize_function_to_config(
            self._output_shape, allow_raw=True
        )
        config = {
            "function": function_config[0],
            "function_type": function_config[1],
            "module": function_config[2],
            "output_shape": output_shape_config[0],
            "output_shape_type": output_shape_config[1],
            "output_shape_module": output_shape_config[2],
        }
        if self.mask is not None:
            mask_config = self._serialize_function_to_config(self.mask)
            config.update(
                {
                    "mask": mask_config[0],
                    "mask_type": mask_config[1],
                    "mask_module": mask_config[2],
                }
            )
        config["arguments"] = self.arguments

        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def _serialize_function_to_config(self, inputs, allow_raw=False):
        if isinstance(inputs, python_types.LambdaType):
            output = generic_utils.func_dump(inputs)
            output_type = "lambda"
            module = inputs.__module__
        elif callable(inputs):
            output = inputs.__name__
            output_type = "function"
            module = inputs.__module__
        elif allow_raw:
            output = inputs
            output_type = "raw"
            module = None
        else:
            raise ValueError(
                f"Invalid input for serialization, type: {type(inputs)} "
            )

        return output, output_type, module
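
    # Illustrative sketch (not in the original source): for a Lambda wrapping
    # a named function such as `def square(x): return x ** 2`, the helper
    # above yields roughly `("square", "function", "<defining module>")`, so
    # `get_config()` stores a name to look up again at load time. A Python
    # `lambda` is instead serialized as marshalled bytecode via
    # `generic_utils.func_dump` and tagged with `"function_type": "lambda"`.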

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = config.copy()
        function = cls._parse_function_from_config(
            config, custom_objects, "function", "module", "function_type"
        )

        output_shape = cls._parse_function_from_config(
            config,
            custom_objects,
            "output_shape",
            "output_shape_module",
            "output_shape_type",
        )
        if "mask" in config:
            mask = cls._parse_function_from_config(
                config, custom_objects, "mask", "mask_module", "mask_type"
            )
        else:
            mask = None

        config["function"] = function
        config["output_shape"] = output_shape
        config["mask"] = mask

        # If arguments were numpy arrays, they have been saved as lists.
        # We need to recover the ndarrays.
        if "arguments" in config:
            for key in config["arguments"]:
                if isinstance(config["arguments"][key], dict):
                    arg_dict = config["arguments"][key]
                    if "type" in arg_dict and arg_dict["type"] == "ndarray":
                        # Overwrite the argument with its numpy translation
                        config["arguments"][key] = np.array(arg_dict["value"])

        return cls(**config)

    @classmethod
    def _parse_function_from_config(
        cls,
        config,
        custom_objects,
        func_attr_name,
        module_attr_name,
        func_type_attr_name,
    ):
        globs = globals().copy()
        module = config.pop(module_attr_name, None)
        if module in sys.modules:
            globs.update(sys.modules[module].__dict__)
        elif module is not None:
            # Note: we don't know the name of the function if it's a lambda.
            warnings.warn(
                "{} is not loaded, but a Lambda layer uses it. "
                "It may cause errors.".format(module),
                UserWarning,
                stacklevel=2,
            )
        if custom_objects:
            globs.update(custom_objects)
        function_type = config.pop(func_type_attr_name)
        if function_type == "function":
            # Simple lookup in custom objects
            function = serialization_lib.deserialize_keras_object(
                config[func_attr_name],
                custom_objects=custom_objects,
                printable_module_name="function in Lambda layer",
            )
        elif function_type == "lambda":
            if serialization_lib.in_safe_mode():
                raise ValueError(
                    "Requested the deserialization of a Lambda layer with a "
                    "Python `lambda` inside it. "
                    "This carries a potential risk of arbitrary code execution "
                    "and thus it is disallowed by default. If you trust the "
                    "source of the saved model, you can pass `safe_mode=False` "
                    "to the loading function in order to allow "
                    "Lambda layer loading."
                )
            # /!\ Unsafe deserialization from bytecode! Danger! /!\
            function = generic_utils.func_load(
                config[func_attr_name], globs=globs
            )
        elif function_type == "raw":
            function = config[func_attr_name]
        else:
            supported_types = ["function", "lambda", "raw"]
            raise TypeError(
                "Unsupported value for `function_type` argument. Received: "
                f"function_type={function_type}. "
                f"Expected one of {supported_types}"
            )
        return function
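

# Usage sketch (illustrative; not part of the original module, and the model
# and file names below are made up). A Lambda that wraps a Python `lambda` is
# serialized as bytecode, so reloading such a model requires opting out of
# safe mode, as the error message above describes:
#
#     model = tf.keras.Sequential(
#         [tf.keras.Input((4,)), tf.keras.layers.Lambda(lambda x: x ** 2)]
#     )
#     model.save("model_with_lambda.keras")
#     reloaded = tf.keras.models.load_model(
#         "model_with_lambda.keras", safe_mode=False
#     )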