Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/version_utils.py: 35%
55 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Utilities for Keras classes with v1 and v2 versions."""
18import tensorflow.compat.v2 as tf
20from keras.src.utils.generic_utils import LazyLoader
22# TODO(b/134426265): Switch back to single-quotes once the issue
23# with copybara is fixed.
# Engine and callback modules are wrapped in `LazyLoader` rather than imported
# directly. NOTE(review): presumably this defers the real import until the
# module is first used (e.g. to avoid an import cycle with keras.src.engine) —
# confirm against `generic_utils.LazyLoader`.
# Each v2 module is declared next to its v1 counterpart.

# Model implementations.
training = LazyLoader(
    "training", globals(), "keras.src.engine.training"
)
training_v1 = LazyLoader(
    "training_v1", globals(), "keras.src.engine.training_v1"
)

# Layer implementations.
base_layer = LazyLoader(
    "base_layer", globals(), "keras.src.engine.base_layer"
)
base_layer_v1 = LazyLoader(
    "base_layer_v1", globals(), "keras.src.engine.base_layer_v1"
)

# Callback implementations (TensorBoard lives here).
callbacks = LazyLoader(
    "callbacks", globals(), "keras.src.callbacks"
)
callbacks_v1 = LazyLoader(
    "callbacks_v1", globals(), "keras.src.callbacks_v1"
)
class ModelVersionSelector:
    """Base class whose `__new__` picks the v1 or v2 Keras `Model` base.

    On every instantiation the inheritance tree of the class being built is
    rewritten by `swap_class` so that it contains `training.Model` when
    `should_use_v2()` is true, and `training_v1.Model` otherwise.
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are accepted so any __init__ signature works;
        # they are not passed on to the parent __new__.
        v2_enabled = should_use_v2()
        target_cls = swap_class(
            cls, training.Model, training_v1.Model, v2_enabled
        )
        # `swap_class` returns either `cls` itself (bases possibly rewritten)
        # or one of the two Model classes when `cls` was one of them.
        return super(ModelVersionSelector, target_cls).__new__(target_cls)
class LayerVersionSelector:
    """Base class whose `__new__` picks the v1 or v2 Keras `Layer` base.

    On every instantiation the inheritance tree of the class being built is
    rewritten by `swap_class` so that it contains `base_layer.Layer` when
    `should_use_v2()` is true, and `base_layer_v1.Layer` otherwise.
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are accepted so any __init__ signature works;
        # they are not passed on to the parent __new__.
        v2_enabled = should_use_v2()
        target_cls = swap_class(
            cls, base_layer.Layer, base_layer_v1.Layer, v2_enabled
        )
        # `swap_class` returns either `cls` itself (bases possibly rewritten)
        # or one of the two Layer classes when `cls` was one of them.
        return super(LayerVersionSelector, target_cls).__new__(target_cls)
class TensorBoardVersionSelector:
    """Base class whose `__new__` picks the v1 or v2 TensorBoard callback.

    The class being instantiated is passed through `swap_class` with
    `callbacks.TensorBoard` / `callbacks_v1.TensorBoard`. If the v1 class was
    requested but v2 is selected, a fully initialized v2 instance is returned.
    """

    def __new__(cls, *args, **kwargs):
        v2_enabled = should_use_v2()
        requested_cls = cls
        resolved_cls = swap_class(
            requested_cls,
            callbacks.TensorBoard,
            callbacks_v1.TensorBoard,
            v2_enabled,
        )
        replaced_v1_with_v2 = (
            requested_cls == callbacks_v1.TensorBoard
            and resolved_cls == callbacks.TensorBoard
        )
        if not replaced_v1_with_v2:
            return super(TensorBoardVersionSelector, resolved_cls).__new__(
                resolved_cls
            )
        # The v2 class does not subclass the v1 class, so Python will not run
        # __init__ on the object returned here; construct it fully instead.
        return resolved_cls(*args, **kwargs)
def should_use_v2():
    """Decide whether the v2 Keras classes should be selected.

    Returns:
        True when TensorFlow is executing eagerly. Also True when eager
        execution is enabled outside functions and the default graph is not a
        v1 `wrap_function` graph (name starting with "wrapped_function").
        False in every other case.
    """
    if tf.executing_eagerly():
        return True
    if not tf.compat.v1.executing_eagerly_outside_functions():
        return False
    # Code running inside a v1 `wrap_function` FuncGraph is treated like v1
    # code; such graphs are recognized by their name prefix.
    default_graph = tf.compat.v1.get_default_graph()
    inside_wrapped_function = getattr(
        default_graph, "name", False
    ) and default_graph.name.startswith("wrapped_function")
    return not inside_wrapped_function
def swap_class(cls, v2_cls, v1_cls, use_v2):
    """Swaps in v2_cls or v1_cls depending on graph mode.

    Walks the inheritance tree of `cls`. When `use_v2` is true, every
    occurrence of `v1_cls` is replaced with `v2_cls`; otherwise every
    occurrence of `v2_cls` is replaced with `v1_cls`. Classes along the path
    are modified in place by reassigning their `__bases__`, but only when
    their bases actually change.

    Args:
        cls: The class whose hierarchy should be rewritten.
        v2_cls: The v2 implementation of the swappable base class.
        v1_cls: The v1 implementation of the swappable base class.
        use_v2: Whether `v2_cls` (True) or `v1_cls` (False) should end up in
            the hierarchy.

    Returns:
        `v2_cls` or `v1_cls` if `cls` is one of them. Otherwise `cls` itself,
        possibly with rewritten bases.
    """
    if cls is object:
        return cls
    if cls in (v2_cls, v1_cls):
        return v2_cls if use_v2 else v1_cls

    # Recursively search superclasses to swap in the right Keras class.
    old_bases = cls.__bases__
    new_bases = []
    for base in old_bases:
        if (
            use_v2
            and issubclass(base, v1_cls)
            # `v1_cls` often extends `v2_cls`, so it may still call `swap_class`
            # even if it doesn't need to. That being said, it may be the safest
            # not to over optimize this logic for the sake of correctness,
            # especially if we swap v1 & v2 classes that don't extend each
            # other, or when the inheritance order is different.
            or (not use_v2 and issubclass(base, v2_cls))
        ):
            new_base = swap_class(base, v2_cls, v1_cls, use_v2)
        else:
            new_base = base
        new_bases.append(new_base)
    new_bases = tuple(new_bases)

    # Assigning `__bases__` recomputes the MRO of `cls` and all of its
    # subclasses, and it is rejected for immutable/builtin types. This runs on
    # every construction, so skip the assignment when nothing was swapped.
    # If a base's own bases were rewritten in the recursion above, that
    # assignment already refreshed the MRO of `cls`.
    if new_bases != old_bases:
        cls.__bases__ = new_bases
    return cls
def disallow_legacy_graph(cls_name, method_name):
    """Reject calls made while a legacy (TF1-style) graph is active.

    Does nothing when `tf.compat.v1.executing_eagerly_outside_functions()`
    is true.

    Args:
        cls_name: Class name used in the error message.
        method_name: Method name used in the error message.

    Raises:
        ValueError: If eager execution is not enabled outside functions.
    """
    if tf.compat.v1.executing_eagerly_outside_functions():
        return
    raise ValueError(
        f"Calling `{cls_name}.{method_name}` in graph mode is not "
        f"supported when the `{cls_name}` instance was constructed with "
        f"eager mode enabled. Please construct your `{cls_name}` instance "
        f"in graph mode or call `{cls_name}.{method_name}` with "
        "eager mode enabled."
    )
def is_v1_layer_or_model(obj):
    """Return whether `obj` is an instance of the v1 Keras Layer or Model.

    Args:
        obj: Any object.

    Returns:
        True if `obj` is an instance of `base_layer_v1.Layer` or
        `training_v1.Model`, otherwise False.
    """
    legacy_types = (base_layer_v1.Layer, training_v1.Model)
    return isinstance(obj, legacy_types)