Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/core/tf_op_layer.py: 34%
199 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the TFOpLambda layer."""
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine import keras_tensor
from keras.src.engine.base_layer import Layer

# isort: off
from tensorflow.python.platform import tf_logging
from tensorflow.python.util.tf_export import (
    get_canonical_name_for_symbol,
)
from tensorflow.python.util.tf_export import (
    get_symbol_from_name,
)

class ClassMethod(Layer):
    """Wraps a TF API Class's class method in a `Layer` object.

    It is inserted by the Functional API construction whenever users call
    a supported TF Class's class method on KerasTensors.

    This is useful in the case where users do something like:
    x = keras.Input(...)
    y = keras.Input(...)
    out = tf.RaggedTensor.from_row_splits(x, y)
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, cls_ref, method_name, **kwargs):
        self.cls_ref = cls_ref
        self.method_name = method_name
        self.cls_symbol = get_canonical_name_for_symbol(
            self.cls_ref, add_prefix_to_v1_names=True
        ) or get_canonical_name_for_symbol(
            self.cls_ref, api_name="keras", add_prefix_to_v1_names=True
        )
        if "name" not in kwargs:
            kwargs["name"] = backend.unique_object_name(
                "tf." + self.cls_symbol + "." + self.method_name,
                zero_based=True,
                avoid_observed_names=True,
            )
        kwargs["autocast"] = False

        # Do not individually trace op layers in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Preserve all argument data structures when saving/loading a config
        # (e.g., don't unnest lists that contain one element)
        self._preserve_input_structure_in_config = True

        self._call_spec.expects_training_arg = False
        self._call_spec.expects_mask_arg = False

    def call(self, args, kwargs):
        return getattr(self.cls_ref, self.method_name)(*args, **kwargs)

    def get_config(self):
        if not self.cls_symbol:
            raise ValueError(
                "This Keras class method conversion tried to convert "
                f"a method belonging to class {self.cls_ref}, a class "
                "that is not publicly exposed in the TensorFlow API. "
                "To ensure cross-version compatibility of Keras models "
                "that use op layers, only op layers produced from "
                "public TensorFlow API symbols can be serialized."
            )

        config = {
            "cls_symbol": self.cls_symbol,
            "method_name": self.method_name,
        }

        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = config.copy()
        symbol_name = config.pop("cls_symbol")
        cls_ref = get_symbol_from_name(symbol_name)
        if not cls_ref:
            raise ValueError(
                f"TensorFlow symbol `{symbol_name}` could not be found."
            )

        config["cls_ref"] = cls_ref

        return cls(**config)
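
# Illustrative sketch (not executed at import): ClassMethod serializes by
# canonical symbol name, so a config round-trip only works for classes that
# are publicly exposed in the TF API. Assuming `tf.RaggedTensor` resolves
# through get_symbol_from_name:
#
#   layer = ClassMethod(tf.RaggedTensor, "from_row_splits")
#   config = layer.get_config()
#   # config carries the canonical class symbol (e.g. "RaggedTensor") plus
#   # the method name, not a reference to the class itself.
#   restored = ClassMethod.from_config(config)
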
class KerasOpDispatcher(tf.__internal__.dispatch.GlobalOpDispatcher):
    """A global dispatcher that allows building a functional model with TF
    Ops."""

    def handle(self, op, args, kwargs):
        """Handle the specified operation with the specified arguments."""
        if any(
            isinstance(x, keras_tensor.KerasTensor)
            for x in tf.nest.flatten([args, kwargs])
        ):
            return TFOpLambda(op)(*args, **kwargs)
        else:
            return self.NOT_SUPPORTED


KerasOpDispatcher().register()
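
# Illustrative sketch (not executed at import): once the global dispatcher is
# registered, calling a dispatchable TF op on a KerasTensor builds a
# TFOpLambda layer instead of running the op eagerly. Assuming a standard
# tf.keras setup:
#
#   inputs = tf.keras.Input(shape=(4,))
#   outputs = tf.abs(inputs)  # routed through KerasOpDispatcher.handle
#   model = tf.keras.Model(inputs, outputs)
#   # model.layers now contains a TFOpLambda named like "tf.math.abs"
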
class InstanceProperty(Layer):
    """Wraps an instance property access (e.g. `x.foo`) in a Keras Layer.

    This layer takes an attribute name `attr_name` in the constructor and,
    when called on input tensor `obj`, returns `obj.attr_name`.

    KerasTensors specialized for specific extension types use it to
    represent instance property accesses on the represented object in the
    case where the property needs to be dynamically accessed, as opposed to
    being statically computed from the typespec, e.g.

    x = keras.Input(..., ragged=True)
    out = x.flat_values
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, attr_name, **kwargs):
        self.attr_name = attr_name

        if "name" not in kwargs:
            kwargs["name"] = backend.unique_object_name(
                "input." + self.attr_name,
                zero_based=True,
                avoid_observed_names=True,
            )
        kwargs["autocast"] = False

        # Do not individually trace op layers in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Preserve all argument data structures when saving/loading a config
        # (e.g., don't unnest lists that contain one element)
        self._preserve_input_structure_in_config = True

    def call(self, obj):
        return getattr(obj, self.attr_name)

    def get_config(self):
        config = {"attr_name": self.attr_name}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**config)


class InstanceMethod(InstanceProperty):
    """Wraps an instance method access (e.g. `x.foo(arg)`) in a Keras Layer.

    This layer takes an attribute name `attr_name` in the constructor and,
    when called on input tensor `obj` with additional arguments `args` and
    `kwargs`, returns `obj.attr_name(*args, **kwargs)`.

    KerasTensors specialized for specific extension types use it to
    represent dynamic instance method calls on the represented object, e.g.

    x = keras.Input(..., ragged=True)
    new_values = keras.Input(...)
    out = x.with_values(new_values)
    """

    def call(self, obj, args, kwargs):
        method = getattr(obj, self.attr_name)
        return method(*args, **kwargs)
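
# Illustrative sketch (not executed at import): these two layers are what the
# _delegate_property/_delegate_method helpers further below wire onto the
# KerasTensor subclasses. Called directly, they look like:
#
#   x = tf.keras.Input(shape=(None,), ragged=True)
#   vals = InstanceProperty("flat_values")(x)  # same as x.flat_values
#   out = InstanceMethod("with_values")(x, (vals,), {})  # x.with_values(vals)
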
class TFOpLambda(Layer):
    """Wraps TF API symbols in a `Layer` object.

    It is inserted by the Functional API construction whenever users call
    a supported TF symbol on KerasTensors.

    Like Lambda layers, this layer tries to raise warnings when it detects
    that users are explicitly using variables in the call, to let them know
    that the layer will not capture those variables.

    This is useful in the case where users do something like:
    x = keras.Input(...)
    y = tf.Variable(...)
    out = x * y
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, function, **kwargs):
        self.function = function
        self.symbol = get_canonical_name_for_symbol(
            self.function, add_prefix_to_v1_names=True
        ) or get_canonical_name_for_symbol(
            self.function, api_name="keras", add_prefix_to_v1_names=True
        )
        if "name" not in kwargs:
            # Generate a name.
            # TFOpLambda layers avoid already-observed names,
            # because users cannot easily control the generated names.
            # Without this avoidance, users would be more likely to run
            # into unavoidable duplicate layer name collisions.
            # (For standard layers users could just set `name` when creating
            # the layer to work around a collision, but they can't do that
            # for auto-generated layers)
            if self.symbol:
                name = "tf." + self.symbol
            else:
                name = self.function.__name__
            kwargs["name"] = backend.unique_object_name(
                name, zero_based=True, avoid_observed_names=True
            )
        kwargs["autocast"] = False

        # Decorate the function to produce this layer's call method
        def _call_wrapper(*args, **kwargs):
            return self._call_wrapper(*args, **kwargs)

        self.call = tf.__internal__.decorator.make_decorator(
            function, _call_wrapper
        )

        # Do not individually trace op layers in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Preserve all argument data structures when saving/loading a config
        # (e.g., don't unnest lists that contain one element)
        self._preserve_input_structure_in_config = True

        # Warning on every invocation will be quite irksome in Eager mode.
        self._already_warned = False

        self._call_spec.expects_training_arg = False
        self._call_spec.expects_mask_arg = False

    def _call_wrapper(self, *args, **kwargs):
        created_variables = []

        def _variable_creator(next_creator, **creator_kwargs):
            var = next_creator(**creator_kwargs)
            created_variables.append(var)
            return var

        with tf.GradientTape(
            watch_accessed_variables=True
        ) as tape, tf.variable_creator_scope(_variable_creator):
            # We explicitly drop `name` arguments here,
            # to guard against the case where an op explicitly has a
            # `name` passed (which is susceptible to producing
            # multiple ops w/ the same name when the layer is reused)
            kwargs.pop("name", None)
            result = self.function(*args, **kwargs)
        self._check_variables(created_variables, tape.watched_variables())
        return result

    def _check_variables(self, created_variables, accessed_variables):
        if not created_variables and not accessed_variables:
            # In the common case that a Lambda layer does not touch a
            # Variable, we don't want to incur the runtime cost of assembling
            # any state used for checking only to immediately discard it.
            return

        tracked_weights = set(v.ref() for v in self.weights)
        untracked_new_vars = [
            v for v in created_variables if v.ref() not in tracked_weights
        ]
        if untracked_new_vars:
            variable_str = "\n".join(f"  {i}" for i in untracked_new_vars)
            raise ValueError(
                "The following Variables were created within a Lambda layer "
                f"({self.name}) but are not tracked by said layer: "
                f"{variable_str}\n"
                "The layer cannot safely ensure proper Variable reuse "
                "across multiple calls, and consequently this behavior "
                "is disallowed for safety reasons. Lambda layers are "
                "not well suited for stateful computation; instead, "
                "writing a subclassed Layer is the recommended "
                "way to define layers with Variables."
            )

        untracked_used_vars = [
            v for v in accessed_variables if v.ref() not in tracked_weights
        ]
        if untracked_used_vars and not self._already_warned:
            variable_str = "\n".join(f"  {i}" for i in untracked_used_vars)
            self._warn(
                "The following Variables were used in a Lambda layer's call "
                f"({self.name}), but are not present in its tracked objects: "
                f"{variable_str}. This is a strong indication that the Lambda "
                "layer should be rewritten as a subclassed Layer."
            )
            self._already_warned = True

    def _warn(self, msg):
        # This method will be overridden in a unit test to raise an error,
        # because self.assertWarns is not universally implemented.
        return tf_logging.warning(msg)
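
    # Illustrative sketch (not executed at import): the warning path above
    # fires when the wrapped op reads a Variable this layer does not track.
    # Assuming a standard tf.keras setup:
    #
    #   scale = tf.Variable(2.0)
    #   inputs = tf.keras.Input(shape=(4,))
    #   outputs = inputs * scale  # warns: `scale` is used but not tracked
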
    def get_config(self):
        if not self.symbol:
            raise ValueError(
                f"This Keras op layer was generated from {self.function}, a "
                "method that is not publicly exposed in the TensorFlow API. "
                "This may have happened if the method was explicitly "
                "decorated to add dispatching support, and it was used "
                "during Functional model construction. "
                "To ensure cross-version compatibility of Keras models "
                "that use op layers, only op layers produced from "
                "public TensorFlow API symbols can be serialized."
            )
        config = {"function": self.symbol}

        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = config.copy()
        symbol_name = config["function"]
        function = get_symbol_from_name(symbol_name)
        if not function:
            raise ValueError(f"TF symbol `{symbol_name}` could not be found.")

        config["function"] = function

        return cls(**config)
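
# Illustrative sketch (not executed at import): like ClassMethod, TFOpLambda
# serializes by canonical symbol name rather than by function reference.
# Assuming `tf.abs` resolves to the public symbol "math.abs":
#
#   layer = TFOpLambda(tf.abs)
#   config = layer.get_config()  # {"function": "math.abs", ...}
#   restored = TFOpLambda.from_config(config)  # looks the symbol back up
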
def _delegate_property(keras_tensor_cls, property_name):
    """Register property on a KerasTensor class.

    Calling this multiple times with the same arguments should be a no-op.

    This method exposes a property on the KerasTensor class that will use an
    `InstanceProperty` layer to access the property on the represented
    intermediate values in the model.

    Args:
        keras_tensor_cls: The KerasTensor subclass that should expose the
            property.
        property_name: The name of the property to expose and delegate to the
            represented (Composite)Tensor.
    """
    # We use a lambda because we can't create a Keras layer at import time
    # due to dynamic layer class versioning.
    property_access = property(
        lambda self: InstanceProperty(property_name)(self)
    )
    setattr(keras_tensor_cls, property_name, property_access)
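
# Illustrative sketch (not executed at import): after the delegation loops
# below run, reading a delegated property on a ragged KerasTensor inserts an
# InstanceProperty layer into the functional graph:
#
#   x = tf.keras.Input(shape=(None,), ragged=True)
#   flat = x.flat_values  # equivalent to InstanceProperty("flat_values")(x)
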
def _delegate_method(keras_tensor_cls, method_name):
    """Register method on a KerasTensor class.

    Calling this function multiple times with the same arguments should be a
    no-op.

    This method exposes an instance method on the KerasTensor class that will
    use an `InstanceMethod` layer to run the desired method on the
    represented intermediate values in the model.

    Args:
        keras_tensor_cls: The KerasTensor subclass that should expose the
            method.
        method_name: The name of the method to expose and delegate to the
            represented (Composite)Tensor.
    """

    def delegate(self, *args, **kwargs):
        return InstanceMethod(method_name)(self, args, kwargs)

    setattr(keras_tensor_cls, method_name, delegate)
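
# Illustrative sketch (not executed at import): delegated methods work the
# same way, routing the call through an InstanceMethod layer:
#
#   x = tf.keras.Input(shape=(None,), ragged=True)
#   dense = x.to_tensor()  # equivalent to InstanceMethod("to_tensor")(x, (), {})
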
# We do not support the `uniform_row_length` property because it
# returns either `None` or an int tensor, and code that relies on it tends
# to check `is None` directly. Delegating it here would always return a
# `KerasTensor`, regardless of what can be statically inferred. This would
# never equal `None`, breaking code that expects it to be partially-static
# in unpredictable ways.
for ragged_property in [
    "values",
    "flat_values",
    "row_splits",
    "nested_row_splits",
]:
    _delegate_property(keras_tensor.RaggedKerasTensor, ragged_property)

for ragged_method_name in [
    "value_rowids",
    "nested_value_rowids",
    "nrows",
    "row_starts",
    "row_limits",
    "row_lengths",
    "nested_row_lengths",
    "bounding_shape",
    "with_values",
    "with_flat_values",
    "with_row_splits_dtype",
    "merge_dims",
    "to_tensor",
    "to_sparse",
]:
    _delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name)

for sparse_property in [
    "indices",
    "values",
    "dense_shape",
]:
    _delegate_property(keras_tensor.SparseKerasTensor, sparse_property)

for sparse_method in [
    "with_values",
]:
    _delegate_method(keras_tensor.SparseKerasTensor, sparse_method)


class TFClassMethodDispatcher(tf.__internal__.dispatch.OpDispatcher):
    """A class method dispatcher that allows building a functional model with
    TF class methods."""

    def __init__(self, cls, method_name):
        self.cls = cls
        self.method_name = method_name

    def handle(self, args, kwargs):
        """Handle the specified operation with the specified arguments."""
        if any(
            isinstance(x, keras_tensor.KerasTensor)
            for x in tf.nest.flatten([args, kwargs])
        ):
            return ClassMethod(self.cls, self.method_name)(args[1:], kwargs)
        else:
            return self.NOT_SUPPORTED


for ragged_class_method in [
    "from_value_rowids",
    "from_row_splits",
    "from_row_lengths",
    "from_row_starts",
    "from_row_limits",
    "from_uniform_row_length",
    "from_nested_value_rowids",
    "from_nested_row_splits",
    "from_nested_row_lengths",
    "from_tensor",
    "from_sparse",
]:
    TFClassMethodDispatcher(tf.RaggedTensor, ragged_class_method).register(
        getattr(tf.RaggedTensor, ragged_class_method)
    )
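
# Illustrative sketch (not executed at import): with the dispatchers above
# registered, the docstring example for ClassMethod works end to end. Note
# that handle() passes args[1:], which reads as dropping the class itself
# from the dispatched positional arguments:
#
#   values = tf.keras.Input(shape=(), dtype=tf.int64)
#   splits = tf.keras.Input(shape=(), dtype=tf.int64)
#   out = tf.RaggedTensor.from_row_splits(values, splits)  # ClassMethod layer
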
class SlicingOpLambda(TFOpLambda):
    """Wraps TF slicing API symbols in a `Layer` object.

    It is inserted by the Functional API construction whenever users call
    a supported TF slicing symbol on KerasTensors.

    Like Lambda layers, this layer tries to raise warnings when it detects
    that users are explicitly using variables in the call, to let them know
    that the layer will not capture those variables.

    This is useful in the case where users do something like:
    x = keras.Input(...)
    out = x[:, 1:]
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, function, **kwargs):
        super().__init__(function, **kwargs)

        original_call = self.call

        # Decorate the function to produce this layer's call method
        def _call_wrapper(*args, **kwargs):
            # Turn any slice dicts in the args back into `slice` objects.
            # This conversion cannot use nest.flatten/map_structure,
            # because dicts are flattened by nest while slices aren't.
            # So, map_structure would only see the individual elements in
            # the dict.
            # This can't use map_structure_up_to either because the
            # 'shallowness' of the shallow tree would have to vary depending
            # on if only one dim or multiple are being sliced.
            new_args = []
            for arg in args:
                arg = _dict_to_slice(arg)
                if isinstance(arg, (list, tuple)):
                    new_arg = []
                    for sub_arg in arg:
                        new_arg.append(_dict_to_slice(sub_arg))
                    arg = new_arg
                new_args.append(arg)

            # Handle the kwargs too.
            new_kwargs = {}
            for key, value in kwargs.items():
                value = _dict_to_slice(value)
                if isinstance(value, (list, tuple)):
                    new_value = []
                    for v in value:
                        new_value.append(_dict_to_slice(v))
                    value = new_value
                new_kwargs[key] = value

            return original_call(*new_args, **new_kwargs)

        self.call = tf.__internal__.decorator.make_decorator(
            original_call, _call_wrapper
        )


def _slice_to_dict(x):
    if isinstance(x, slice):
        return {"start": x.start, "stop": x.stop, "step": x.step}
    return x


def _dict_to_slice(x):
    if isinstance(x, dict):
        return slice(x["start"], x["stop"], x["step"])
    return x
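
# Illustrative sketch (not executed at import): the two helpers are inverses,
# letting `slice` objects survive a trip through tf.nest (which flattens
# dicts but treats slices as opaque leaves):
#
#   d = _slice_to_dict(slice(1, None, 2))  # {"start": 1, "stop": None, "step": 2}
#   s = _dict_to_slice(d)                  # slice(1, None, 2)
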
class TFSlicingOpDispatcher(tf.__internal__.dispatch.OpDispatcher):
    """A per-op dispatcher that allows building a functional model with TF
    slicing ops."""

    def __init__(self, op):
        self.op = op

    def handle(self, args, kwargs):
        """Handle the specified operation with the specified arguments."""
        args = tf.nest.map_structure(_slice_to_dict, args)
        kwargs = tf.nest.map_structure(_slice_to_dict, kwargs)
        if any(
            isinstance(x, keras_tensor.KerasTensor)
            for x in tf.nest.flatten([args, kwargs])
        ):
            return SlicingOpLambda(self.op)(*args, **kwargs)
        else:
            return self.NOT_SUPPORTED


for slicing_op in [
    tf.__operators__.getitem,
    tf.compat.v1.boolean_mask,
    tf.boolean_mask,
    tf.__operators__.ragged_getitem,
]:
    TFSlicingOpDispatcher(slicing_op).register(slicing_op)
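
# Illustrative sketch (not executed at import): with the slicing dispatchers
# registered, indexing a KerasTensor builds a SlicingOpLambda layer:
#
#   x = tf.keras.Input(shape=(10,))
#   out = x[:, :2]  # a SlicingOpLambda named like "tf.__operators__.getitem"
#   model = tf.keras.Model(x, out)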