Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/dtensor/lazy_variable.py: 28%
78 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Lazily initialized variables, useful for creating a symbolic Keras model."""
17import threading
19# isort: off
20from tensorflow.core.framework import attr_value_pb2
21from tensorflow.python.eager import context
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import gen_resource_variable_ops
24from tensorflow.python.ops import resource_variable_ops
25from tensorflow.python.ops import variable_scope
26from tensorflow.python.trackable import base as trackable
27from tensorflow.python.util import compat
28from tensorflow.python.util import tf_contextlib
# Thread-local flag read by `_lazy_init_variable_creator` and toggled by
# `disable_init_variable_creator`; when its `disabled` attribute is True,
# variable creation falls through to the next creator instead of building
# a LazyInitVariable. Thread-local so scopes on one thread don't leak into
# others.
_DISABLE_LAZY_VARIABLE_INIT = threading.local()
def _infer_shape_dtype_and_create_handle(initial_value, shape, dtype, name):
    """Infer shape and dtype from initial_value and create a variable handle.

    Args:
        initial_value: A tensor, a value convertible to a tensor, or a
            callable producing the initial value.
        shape: Optional shape; falls back to the initial value's shape.
        dtype: Optional dtype; falls back to the initial value's base dtype.
        name: Name used for the variable's name scope.

    Returns:
        A tuple `(initial_value, shape, dtype, handle, handle_name,
        unique_id)` where `handle` is an uninitialized resource handle.
    """
    with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle_name = ops.name_from_scope_name(name)
        unique_id = f"{handle_name}_{ops.uid()}"

        # Simulate `colocate_with` for a variable that doesn't exist yet:
        # pin the colocation class attr while leaving the device unset.
        colocation_attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes(f"loc:@{handle_name}")]
            )
        )
        null_device = ops.NullContextmanager
        graph = ops.get_default_graph()
        with graph._attr_scope({"_class": colocation_attr}):
            with ops.name_scope("Initializer"), null_device(None):
                if not callable(initial_value):
                    if isinstance(
                        initial_value, trackable.CheckpointInitialValue
                    ):
                        raise NotImplementedError(
                            "CheckpointInitialValue is not supported to be the "
                            "initial value of a lazy variable."
                        )
                    initial_value = ops.convert_to_tensor(
                        initial_value, name="initial_value", dtype=dtype
                    )
                assert not callable(initial_value)

                assert initial_value.shape.is_compatible_with(shape)
                dtype = dtype or initial_value.dtype.base_dtype
                shape = shape or initial_value.shape

        assert dtype
        assert shape
        handle = resource_variable_ops._variable_handle_from_shape_and_dtype(
            shape=shape,
            dtype=dtype,
            shared_name=None,  # Never shared
            name=name,
            graph_mode=False,
            initial_value=None,
        )
        return initial_value, shape, dtype, handle, handle_name, unique_id
class LazyInitVariable(resource_variable_ops.BaseResourceVariable):
    """Lazily initialized variables.

    The major use case for this class is to serve as a memory-efficient
    alternative to `tf.Variable`. The resource handle of this class points
    to nothing, which means fetching its value in an eager context raises
    an error. Having said that, it behaves like a normal `tf.Variable`
    when used with graph tensors, like a KerasTensor produced from
    `tf.keras.Input`.
    """

    def __init__(
        self,
        initial_value=None,
        trainable=None,
        collections=None,
        validate_shape=True,
        caching_device=None,
        name=None,
        dtype=None,
        variable_def=None,
        import_scope=None,
        constraint=None,
        distribute_strategy=None,
        synchronization=None,
        aggregation=None,
        shape=None,
        **kwargs,
    ):
        """Create a lazily initialized variable.

        Only a subset of `tf.Variable`'s arguments is supported; the
        asserts below document the simplifying assumptions.
        """
        assert context.executing_eagerly()  # To simplify the logic
        assert variable_def is None  # Not supported yet.
        assert caching_device is None  # Not supported yet

        if initial_value is None:
            raise ValueError(
                "The `initial_value` arg to `tf.Variable` must "
                "be specified except when you are not providing a "
                "`variable_def`. You provided neither."
            )

        # A tensor captured inside a tf.function graph cannot serve as an
        # eager initial value.
        if (
            isinstance(initial_value, ops.Tensor)
            and hasattr(initial_value, "graph")
            and initial_value.graph.building_function
        ):
            raise ValueError(
                f"Argument `initial_value` ({initial_value}) could not "
                "be lifted out of a `tf.function`. "
                f"(Tried to create variable with name='{name}'). "
                "To avoid this error, when constructing `tf.Variable`s "
                "inside of `tf.function` you can create the "
                "`initial_value` tensor in a "
                "`tf.init_scope` or pass a callable `initial_value` "
                "(e.g., `tf.Variable(lambda : "
                "tf.truncated_normal([10, 40]))`). "
                "Please file a feature request if this "
                "restriction inconveniences you."
            )

        if constraint is not None and not callable(constraint):
            # Fix: the message previously contained a duplicated
            # "a callable." fragment.
            raise ValueError(
                "Argument `constraint` must be None or a callable. "
                f"Got a {type(constraint)}: {constraint}"
            )

        self._name = name
        (
            initial_value,
            shape,
            dtype,
            handle,
            handle_name,
            unique_id,
        ) = _infer_shape_dtype_and_create_handle(
            initial_value, shape, dtype, name
        )

        super().__init__(
            distribute_strategy=distribute_strategy,
            initial_value=initial_value,
            shape=shape,
            dtype=dtype,
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            constraint=constraint,
            handle=handle,
            graph_element=None,
            trainable=trainable,
            synchronization=synchronization,
            aggregation=aggregation,
            in_graph_mode=False,
        )

    # TODO(scottzhu): This method and create_and_initialize might be removed
    # if we decide to just use the tf.Variable to replace this class.
    def initialize(self):
        """Assign the stored initial value through the current handle.

        Raises:
            ValueError: If the initial value's shape is incompatible with
                the explicitly supplied `shape` argument.
        """
        with ops.name_scope(self._name, "Variable", skip_on_eager=False):
            with ops.colocate_with(self._handle), ops.name_scope("Initializer"):
                if callable(self._initial_value):
                    initial_value = self._initial_value()
                else:
                    initial_value = self._initial_value

                if not initial_value.shape.is_compatible_with(self._shape):
                    raise ValueError(
                        "In this `tf.Variable` creation, the initial value's "
                        f"shape ({initial_value.shape}) is not compatible with "
                        "the explicitly supplied `shape` "
                        f"argument ({self._shape})."
                    )
                assert self._dtype is initial_value.dtype.base_dtype
                gen_resource_variable_ops.assign_variable_op(
                    self._handle, initial_value
                )

    def create_and_initialize(self):
        """Create a real resource handle and initialize the variable.

        Fix: the original only bound `initial_value` when
        `self._initial_value` was callable, so a plain tensor initial
        value raised NameError at `initial_value.device`.
        """
        if callable(self._initial_value):
            initial_value = self._initial_value()
        else:
            initial_value = self._initial_value

        with ops.device(initial_value.device):
            (
                initial_value,
                shape,
                dtype,
                handle,
                handle_name,
                unique_id,
            ) = _infer_shape_dtype_and_create_handle(
                initial_value, self._shape, self._dtype, self._name
            )
            # NOTE(review): `initialize()` runs before `super().__init__`
            # installs the new handle, so the assignment goes through the
            # pre-existing `self._handle` — confirm this ordering is
            # intentional.
            self.initialize()

            super().__init__(
                trainable=self._trainable,
                shape=shape,
                dtype=dtype,
                handle=handle,
                synchronization=self._synchronization,
                constraint=self._constraint,
                aggregation=self._aggregation,
                distribute_strategy=self._distribute_strategy,
                name=self._name,
                unique_id=unique_id,
                handle_name=handle_name,
                graph_element=None,
                initial_value=initial_value,
                initializer_op=None,
                is_initialized_op=None,
                cached_value=None,
                caching_device=None,
            )
def _lazy_init_variable_creator(next_creator, **kwargs):
    """Variable creator that builds `LazyInitVariable`s unless disabled.

    When the thread-local `_DISABLE_LAZY_VARIABLE_INIT.disabled` flag is
    set, defer to the next creator in the chain instead.
    """
    if getattr(_DISABLE_LAZY_VARIABLE_INIT, "disabled", False):
        return next_creator(**kwargs)
    return LazyInitVariable(**kwargs)
@tf_contextlib.contextmanager
def lazy_init_scope():
    """Scope under which new variables are created as `LazyInitVariable`s.

    Installs `_lazy_init_variable_creator` via
    `variable_scope.variable_creator_scope` for the duration of the block.
    """
    with variable_scope.variable_creator_scope(_lazy_init_variable_creator):
        yield
@tf_contextlib.contextmanager
def disable_init_variable_creator():
    """Temporarily disable lazy variable creation inside this scope.

    While active, `_lazy_init_variable_creator` falls back to the next
    creator instead of producing `LazyInitVariable`s. The previous flag
    value is restored on exit, even if the body raises.
    """
    # No `global` declaration needed: we only mutate an attribute of the
    # thread-local object, never rebind the module-level name.
    previous = getattr(_DISABLE_LAZY_VARIABLE_INIT, "disabled", False)
    try:
        _DISABLE_LAZY_VARIABLE_INIT.disabled = True
        yield
    finally:
        _DISABLE_LAZY_VARIABLE_INIT.disabled = previous