Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/dtensor/lazy_variable.py: 28%

78 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Lazily initialized variables, useful for creating a symbolic Keras model.""" 

16 

17import threading 

18 

19# isort: off 

20from tensorflow.core.framework import attr_value_pb2 

21from tensorflow.python.eager import context 

22from tensorflow.python.framework import ops 

23from tensorflow.python.ops import gen_resource_variable_ops 

24from tensorflow.python.ops import resource_variable_ops 

25from tensorflow.python.ops import variable_scope 

26from tensorflow.python.trackable import base as trackable 

27from tensorflow.python.util import compat 

28from tensorflow.python.util import tf_contextlib 

29 

# Thread-local toggle: when its `disabled` attribute is truthy, the lazy
# variable creator falls through to the default tf.Variable creation path.
_DISABLE_LAZY_VARIABLE_INIT = threading.local()

31 

32 

def _infer_shape_dtype_and_create_handle(initial_value, shape, dtype, name):
    """Infer shape and dtype from initial_value and create a variable handle.

    Args:
        initial_value: A value convertible to tensor, or a callable producing
            one. Callables are NOT invoked here; they pass through untouched.
        shape: Optional explicit shape; falls back to `initial_value.shape`.
        dtype: Optional explicit dtype; falls back to
            `initial_value.dtype.base_dtype`.
        name: Name for the variable scope.

    Returns:
        Tuple `(initial_value, shape, dtype, handle, handle_name, unique_id)`.
        The handle is an eager resource handle that has NOT been assigned a
        value yet (the `initial_value=None` argument below).

    Raises:
        NotImplementedError: if `initial_value` is a
            `trackable.CheckpointInitialValue`.
    """
    with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle_name = ops.name_from_scope_name(name)
        # uid() makes the id unique even when the same scope name is reused.
        unique_id = "%s_%d" % (handle_name, ops.uid())

        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        device_context_manager = ops.NullContextmanager
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes(f"loc:@{handle_name}")]
            )
        )
        # NOTE(review): `_attr_scope` and
        # `_variable_handle_from_shape_and_dtype` are private TF APIs; their
        # exact semantics are assumed from tf.Variable's own creation path.
        with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), device_context_manager(None):
                if not callable(initial_value):
                    if isinstance(
                        initial_value, trackable.CheckpointInitialValue
                    ):
                        raise NotImplementedError(
                            "CheckpointInitialValue is not supported to be the "
                            "initial value of a lazy variable."
                        )
                    initial_value = ops.convert_to_tensor(
                        initial_value, name="initial_value", dtype=dtype
                    )
                assert not callable(initial_value)

                # Compatibility check runs before the `or` fallbacks so an
                # explicit `shape` is validated against the value's shape.
                assert initial_value.shape.is_compatible_with(shape)
                dtype = dtype or initial_value.dtype.base_dtype
                shape = shape or initial_value.shape

        assert dtype
        assert shape
        handle = (
            resource_variable_ops._variable_handle_from_shape_and_dtype(
                shape=shape,
                dtype=dtype,
                shared_name=None,  # Never shared
                name=name,
                graph_mode=False,
                # The handle is deliberately left uninitialized — that is the
                # whole point of the lazy variable.
                initial_value=None,
            )
        )
        # initial_value=initial_value if not callable(initial_value) else
        # None)
        return initial_value, shape, dtype, handle, handle_name, unique_id

82 

83 

class LazyInitVariable(resource_variable_ops.BaseResourceVariable):
    """Lazily initialized variables.

    The major use case for this class is to serve as a memory-efficient
    alternative to tf.Variable. The resource handle of this class points to
    nothing, which means it will raise an error when its value is fetched in
    an eager context. Having said that, it will behave like a normal
    tf.Variable when used with graph tensors, like KerasTensor produced from
    tf.keras.Input.
    """

    def __init__(
        self,
        initial_value=None,
        trainable=None,
        collections=None,
        validate_shape=True,
        caching_device=None,
        name=None,
        dtype=None,
        variable_def=None,
        import_scope=None,
        constraint=None,
        distribute_strategy=None,
        synchronization=None,
        aggregation=None,
        shape=None,
        **kwargs,
    ):
        """Create a lazy variable whose handle is not yet backed by storage.

        Accepts the tf.Variable constructor signature so it can be swapped in
        via a variable creator scope. `collections`, `validate_shape`,
        `import_scope` and `**kwargs` are accepted for signature compatibility
        but unused here.

        Raises:
            ValueError: if `initial_value` is None, is a tensor captured
                inside a `tf.function`, or `constraint` is not callable.
        """
        assert context.executing_eagerly()  # To simplify the logic
        assert variable_def is None  # Not supported yet.
        assert caching_device is None  # Not supported yet

        if initial_value is None:
            raise ValueError(
                "The `initial_value` arg to `tf.Variable` must "
                "be specified except when you are not providing a "
                "`variable_def`. You provided neither."
            )

        if (
            isinstance(initial_value, ops.Tensor)
            and hasattr(initial_value, "graph")
            and initial_value.graph.building_function
        ):
            raise ValueError(
                f"Argument `initial_value` ({initial_value}) could not "
                "be lifted out of a `tf.function`. "
                f"(Tried to create variable with name='{name}'). "
                "To avoid this error, when constructing `tf.Variable`s "
                "inside of `tf.function` you can create the "
                "`initial_value` tensor in a "
                "`tf.init_scope` or pass a callable `initial_value` "
                "(e.g., `tf.Variable(lambda : "
                "tf.truncated_normal([10, 40]))`). "
                "Please file a feature request if this "
                "restriction inconveniences you."
            )

        if constraint is not None and not callable(constraint):
            # Fixed message: the original duplicated "a callable.".
            raise ValueError(
                "Argument `constraint` must be None or a callable. "
                f"Got a {type(constraint)}: {constraint}"
            )

        self._name = name
        (
            initial_value,
            shape,
            dtype,
            handle,
            handle_name,
            unique_id,
        ) = _infer_shape_dtype_and_create_handle(
            initial_value, shape, dtype, name
        )

        super().__init__(
            distribute_strategy=distribute_strategy,
            initial_value=initial_value,
            shape=shape,
            dtype=dtype,
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            constraint=constraint,
            handle=handle,
            graph_element=None,
            trainable=trainable,
            synchronization=synchronization,
            aggregation=aggregation,
            in_graph_mode=False,
        )

    # TODO(scottzhu): This method and create_and_initialize might be removed if
    # we decide to just use the tf.Variable to replace this class.
    def initialize(self):
        """Assign the initial value to the variable's existing handle.

        Raises:
            ValueError: if the (possibly callable-produced) initial value's
                shape is incompatible with the recorded `self._shape`.
        """
        with ops.name_scope(self._name, "Variable", skip_on_eager=False):
            with ops.colocate_with(self._handle), ops.name_scope("Initializer"):
                if callable(self._initial_value):
                    initial_value = self._initial_value()
                else:
                    initial_value = self._initial_value

                if not initial_value.shape.is_compatible_with(self._shape):
                    raise ValueError(
                        "In this `tf.Variable` creation, the initial value's "
                        f"shape ({initial_value.shape}) is not compatible with "
                        "the explicitly supplied `shape` "
                        f"argument ({self._shape})."
                    )
                assert self._dtype is initial_value.dtype.base_dtype
                gen_resource_variable_ops.assign_variable_op(
                    self._handle, initial_value
                )

    def create_and_initialize(self):
        """Create real backing storage for the variable and initialize it."""
        # BUG FIX: the original assigned `initial_value` only when
        # `self._initial_value` was callable, so a non-callable initial value
        # hit a NameError at `ops.device(initial_value.device)` below.
        if callable(self._initial_value):
            initial_value = self._initial_value()
        else:
            initial_value = self._initial_value

        with ops.device(initial_value.device):
            (
                initial_value,
                shape,
                dtype,
                handle,
                handle_name,
                unique_id,
            ) = _infer_shape_dtype_and_create_handle(
                initial_value, self._shape, self._dtype, self._name
            )
            # NOTE(review): initialize() reads `self._handle`, which is still
            # the pre-existing handle at this point; the freshly created
            # `handle` is only installed by the super().__init__ call below.
            # Preserved as-is — confirm the intended ordering upstream.
            self.initialize()

        super().__init__(
            trainable=self._trainable,
            shape=shape,
            dtype=dtype,
            handle=handle,
            synchronization=self._synchronization,
            constraint=self._constraint,
            aggregation=self._aggregation,
            distribute_strategy=self._distribute_strategy,
            name=self._name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=None,
            initial_value=initial_value,
            initializer_op=None,
            is_initialized_op=None,
            cached_value=None,
            caching_device=None,
        )

235 

236 

def _lazy_init_variable_creator(next_creator, **kwargs):
    """Variable creator hook: build a LazyInitVariable unless disabled.

    When the thread-local disable flag is set, defer to the next creator in
    the chain so a regular variable is produced instead.
    """
    lazy_disabled = getattr(_DISABLE_LAZY_VARIABLE_INIT, "disabled", False)
    if lazy_disabled:
        return next_creator(**kwargs)
    return LazyInitVariable(**kwargs)

242 

243 

@tf_contextlib.contextmanager
def lazy_init_scope():
    """Scope under which created variables are LazyInitVariable instances.

    Installs `_lazy_init_variable_creator` as a variable creator for the
    duration of the `with` block.
    """
    creator_scope = variable_scope.variable_creator_scope(
        _lazy_init_variable_creator
    )
    with creator_scope:
        yield

248 

249 

@tf_contextlib.contextmanager
def disable_init_variable_creator():
    """Temporarily disable lazy variable creation on the current thread.

    Saves the previous `disabled` state and restores it on exit, so nested
    uses of this scope compose correctly.
    """
    try:
        # The original declared `global _DISABLE_LAZY_VARIABLE_INIT`, but the
        # name is never rebound — only an attribute is set — so the statement
        # was a no-op and has been removed.
        existing_value = getattr(_DISABLE_LAZY_VARIABLE_INIT, "disabled", False)
        _DISABLE_LAZY_VARIABLE_INIT.disabled = True
        yield
    finally:
        _DISABLE_LAZY_VARIABLE_INIT.disabled = existing_value

259