Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/version_utils.py: 35%
57 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Utilities for Keras classes with v1 and v2 versions."""
18from tensorflow.python.eager import context
19from tensorflow.python.framework import ops
20from tensorflow.python.keras.utils.generic_utils import LazyLoader
# TODO(b/134426265): Switch back to single-quotes once the issue
# with copybara is fixed.
# pylint: disable=g-inconsistent-quotes
# Handles to the v1/v2 engine and callbacks modules, created through
# `LazyLoader` rather than a direct import. NOTE(review): presumably the
# modules are loaded on first attribute access to avoid circular imports with
# modules that import this file — confirm against `LazyLoader`.
training = LazyLoader("training", globals(),
                      "tensorflow.python.keras.engine.training")
training_v1 = LazyLoader("training_v1", globals(),
                         "tensorflow.python.keras.engine.training_v1")
base_layer = LazyLoader("base_layer", globals(),
                        "tensorflow.python.keras.engine.base_layer")
base_layer_v1 = LazyLoader("base_layer_v1", globals(),
                           "tensorflow.python.keras.engine.base_layer_v1")
callbacks = LazyLoader("callbacks", globals(),
                       "tensorflow.python.keras.callbacks")
callbacks_v1 = LazyLoader("callbacks_v1", globals(),
                          "tensorflow.python.keras.callbacks_v1")
# pylint: enable=g-inconsistent-quotes
class ModelVersionSelector(object):
  """Selects the v1 or v2 Keras `Model` base class when an instance is made.

  Before the instance is allocated, `swap_class` rewrites the inheritance of
  the class being instantiated. It then derives from `training.Model` when
  `should_use_v2()` is true, and from `training_v1.Model` otherwise.
  """

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    # Constructor arguments are consumed by `__init__`; only the class matters
    # here. The version decision is made before the lazy modules are touched.
    v2_enabled = should_use_v2()
    target_cls = swap_class(cls, training.Model, training_v1.Model, v2_enabled)
    return super(ModelVersionSelector, target_cls).__new__(target_cls)
class LayerVersionSelector(object):
  """Selects the v1 or v2 Keras `Layer` base class when an instance is made.

  Before the instance is allocated, `swap_class` rewrites the inheritance of
  the class being instantiated. It then derives from `base_layer.Layer` when
  `should_use_v2()` is true, and from `base_layer_v1.Layer` otherwise.
  """

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    # Constructor arguments are consumed by `__init__`; only the class matters
    # here. The version decision is made before the lazy modules are touched.
    v2_enabled = should_use_v2()
    target_cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer,
                            v2_enabled)
    return super(LayerVersionSelector, target_cls).__new__(target_cls)
class TensorBoardVersionSelector(object):
  """Selects the v1 or v2 Keras `TensorBoard` callback class on construction.

  The class being instantiated is passed through `swap_class`, choosing
  `callbacks.TensorBoard` when `should_use_v2()` is true and
  `callbacks_v1.TensorBoard` otherwise.
  """

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    v2_enabled = should_use_v2()
    requested_cls = cls
    chosen_cls = swap_class(requested_cls, callbacks.TensorBoard,
                            callbacks_v1.TensorBoard, v2_enabled)
    switched_v1_to_v2 = (requested_cls == callbacks_v1.TensorBoard and
                         chosen_cls == callbacks.TensorBoard)
    if not switched_v1_to_v2:
      return super(TensorBoardVersionSelector, chosen_cls).__new__(chosen_cls)
    # The v2 class is not a subclass of the v1 class, so Python will not run
    # `__init__` on an object returned from here. Build a fully constructed
    # v2 instance directly instead.
    return chosen_cls(*args, **kwargs)
def should_use_v2():
  """Decides whether the v2 (rather than v1) Keras classes should be used.

  Returns:
    True if `context.executing_eagerly()` is true.
    Otherwise, True if `ops.executing_eagerly_outside_functions()` is true and
    the default graph is not a v1 `wrap_function` graph (a graph whose `name`
    starts with "wrapped_function").
    False in every other case.
  """
  if context.executing_eagerly():
    return True
  if not ops.executing_eagerly_outside_functions():
    return False
  # Code inside a v1 `wrap_function` FuncGraph is treated like v1 code; such
  # graphs are recognised by their name prefix. Not every graph has a `name`,
  # hence the `getattr` guard.
  default_graph = ops.get_default_graph()
  inside_wrapped_function = (
      getattr(default_graph, "name", False) and
      default_graph.name.startswith("wrapped_function"))
  return not inside_wrapped_function
def swap_class(cls, v2_cls, v1_cls, use_v2):
  """Swaps in v2_cls or v1_cls depending on graph mode.

  If `cls` is `v2_cls` or `v1_cls` itself, the requested version is returned
  directly. Otherwise the direct bases of `cls` are walked. Any base that
  subclasses the version that was *not* requested is recursively swapped.
  `cls.__bases__` is then rewritten in place, but only if one of its direct
  bases actually changed.

  Args:
    cls: The class to inspect (and possibly mutate).
    v2_cls: The v2 version of the Keras class.
    v1_cls: The v1 version of the Keras class.
    use_v2: If True swap in `v2_cls`, otherwise swap in `v1_cls`.

  Returns:
    `v2_cls` or `v1_cls` when `cls` is one of them; otherwise `cls` itself,
    possibly with rewritten bases.
  """
  if cls is object:
    return cls
  if cls in (v2_cls, v1_cls):
    return v2_cls if use_v2 else v1_cls

  # Recursively search superclasses to swap in the right Keras class.
  # `v1_cls` often extends `v2_cls`, so this may recurse even when nothing
  # needs to change. That is kept deliberately for correctness, e.g. when the
  # v1 and v2 classes do not extend each other or the inheritance order
  # differs.
  new_bases = []
  for base in cls.__bases__:
    if ((use_v2 and issubclass(base, v1_cls)) or
        (not use_v2 and issubclass(base, v2_cls))):
      new_bases.append(swap_class(base, v2_cls, v1_cls, use_v2))
    else:
      new_bases.append(base)
  new_bases = tuple(new_bases)

  # Assigning `__bases__` makes CPython recompute the MRO (and invalidate the
  # type caches) of `cls` and every subclass of `cls`. This function runs on
  # every instantiation, so skip the assignment when no direct base changed.
  # A swap made deeper in the hierarchy has already propagated to `cls`:
  # assigning an ancestor's `__bases__` recomputes its subclasses' MROs.
  if new_bases != cls.__bases__:
    cls.__bases__ = new_bases
  return cls
def disallow_legacy_graph(cls_name, method_name):
  """Raises when a method is called outside eager-outside-functions mode.

  Args:
    cls_name: Class name interpolated into the error message.
    method_name: Method name interpolated into the error message.

  Raises:
    ValueError: If `ops.executing_eagerly_outside_functions()` is false.
  """
  if ops.executing_eagerly_outside_functions():
    return
  message_template = (
      "Calling `{cls_name}.{method_name}` in graph mode is not supported "
      "when the `{cls_name}` instance was constructed with eager mode "
      "enabled. Please construct your `{cls_name}` instance in graph mode or"
      " call `{cls_name}.{method_name}` with eager mode enabled.")
  raise ValueError(
      message_template.format(cls_name=cls_name, method_name=method_name))
def is_v1_layer_or_model(obj):
  """Returns True if `obj` is an instance of the v1 Keras `Layer` or `Model`."""
  v1_classes = (base_layer_v1.Layer, training_v1.Model)
  return isinstance(obj, v1_classes)