Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/resource.py: 59%

128 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Definitions for resource-type trackable object classes.""" 

16 

17import contextlib 

18import copy 

19import weakref 

20 

21from tensorflow.python.eager import context 

22from tensorflow.python.eager import def_function 

23from tensorflow.python.framework import ops 

24from tensorflow.python.trackable import base 

25from tensorflow.python.util import tf_contextlib 

26from tensorflow.python.util.tf_export import tf_export 

27 

28# global _RESOURCE_TRACKER_STACK 

29_RESOURCE_TRACKER_STACK = [] 

30 

31 

class ResourceTracker:
  """Accumulates the resources registered with it, in registration order."""

  __slots__ = ["_resources"]

  def __init__(self):
    # Start with nothing tracked; `add_resource` appends in call order.
    self._resources = []

  @property
  def resources(self):
    """The live list of tracked resources (not a copy)."""
    return self._resources

  def add_resource(self, resource):
    """Records `resource` at the end of the tracked list."""
    tracked = self._resources
    tracked.append(resource)

46 

47 

@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """A context to manage resource trackers.

  Use this in order to collect up all resources created within a block of code.
  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Scopes may be nested: a `TrackableResource` created inside nested scopes is
  added to every tracker that is active at that point.

  Args:
    resource_tracker: The passed in ResourceTracker object

  Yields:
    A scope in which the resource_tracker is active.
  """
  # Snapshot the stack so it can be restored exactly on exit, even if the
  # body of the scope leaves it in an unexpected state.
  saved = list(_RESOURCE_TRACKER_STACK)
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    # Restore in place instead of rebinding the module global, so any code
    # holding a reference to the stack list keeps seeing the live stack.
    _RESOURCE_TRACKER_STACK[:] = saved

75 

76 

77def _make_getter(captured_getter, captured_previous): 

78 """To avoid capturing loop variables.""" 

79 

80 def getter(*args, **kwargs): 

81 return captured_getter(captured_previous, *args, **kwargs) 

82 

83 return getter 

84 

85 

class _ResourceMetaclass(type):
  """Metaclass for CapturableResource.

  Instance construction is routed through the resource creators registered on
  the default graph under `cls._resource_type()`. Each creator receives the
  rest of the chain as its first argument, and the chain ends in a default
  creator that actually builds the object.
  """

  def __call__(cls, *args, **kwargs):

    def default_resource_creator(next_creator, *a, **kw):
      # The default creator terminates the chain; nothing may follow it.
      assert next_creator is None
      obj = cls.__new__(cls, *a, **kw)
      obj.__init__(*a, **kw)
      return obj

    def base_creator(*a, **kw):
      return default_resource_creator(None, *a, **kw)

    graph = ops.get_default_graph()
    creators = graph._resource_creator_stack[cls._resource_type()]  # pylint: disable=protected-access
    chain = base_creator
    for creator in creators:
      # Each creator wraps the chain built so far, so the last one in the
      # list is outermost and runs first.
      chain = _make_getter(creator, chain)
    return chain(*args, **kwargs)

103 

104 

class CapturableResource(base.Trackable, metaclass=_ResourceMetaclass):
  """Holds a Tensor which a tf.function can capture.

  `CapturableResource`s are discovered by traversing the graph of object
  attributes, e.g. during `tf.saved_model.save`. They are excluded from the
  scope-based tracking of `TrackableResource`; generally things that require
  initialization should inherit from `TrackableResource` instead of
  `CapturableResource` directly.

  Subclasses must override `_create_resource`. They may also override
  `_initialize` and `_destroy_resource`. The handle is created lazily on the
  first access to `resource_handle` and cached afterwards.
  """

  def __init__(self, device=""):
    """Initialize the `CapturableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    self._resource_handle_value = None
    self._resource_device = device
    # Remember a factory for the context this object was built in (eager
    # mode, or the graph that is currently the default) so that `__del__`
    # can re-enter it when destroying the resource.
    self._self_destruction_context = (
        context.eager_mode if context.executing_eagerly()
        else ops.get_default_graph().as_default)

  @classmethod
  def _resource_type(cls):
    """Returns the key used to look up resource creators for this class."""
    return cls.__name__

  @property
  def _destruction_context(self):
    """Returns the context-manager factory that `__del__` enters.

    If the attribute was never set (e.g. `__init__` did not run), this falls
    back to `contextlib.suppress`. Called with no arguments, `suppress`
    suppresses nothing, so it acts as a no-op context.
    """
    return getattr(self, "_self_destruction_context",
                   # no-op context
                   contextlib.suppress)

  @_destruction_context.setter
  def _destruction_context(self, destruction_context):
    self._self_destruction_context = destruction_context

  def _create_resource(self):
    """A function that creates a resource handle. Subclasses must override."""
    raise NotImplementedError("TrackableResource._create_resource not "
                              "implemented.")

  @property
  def _resource_handle(self):
    """The cached resource handle, or None if it has not been created yet."""
    return self._resource_handle_value

  @_resource_handle.setter
  def _resource_handle(self, value):
    # Tensor handles get a weak back-reference to their owner; being weak,
    # it does not keep `self` alive.
    if isinstance(value, (ops.Tensor, ops.EagerTensor)):
      value._parent_trackable = weakref.ref(self)  # pylint: disable=protected-access
    self._resource_handle_value = value

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  def _destroy_resource(self):
    """A function that destroys the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    On first access, the handle is created by `_create_resource` under
    `ops.device` with the `device` given to the constructor. It is then
    cached for later accesses.
    """
    if self._resource_handle is None:
      with ops.device(self._resource_device):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _export_to_saved_model_graph(
      self, object_map, tensor_map, **unused_kwargs):
    """For implementing `Trackable`.

    Makes a shallow copy of this object that holds a freshly created handle.
    The copy is recorded in `object_map`, and this object's handle is mapped
    to the new one in `tensor_map`.

    Returns:
      A one-element list containing this object's resource handle.
    """
    new_obj = copy.copy(self)
    # pylint: disable=protected-access
    with ops.device(self._resource_device):
      new_resource = new_obj._create_resource()
    new_obj._resource_handle = new_resource
    # pylint: enable=protected-access
    object_map[self] = new_obj
    # Read the (lazily created) handle once and reuse it.
    handle = self.resource_handle
    tensor_map[handle] = new_resource
    return [handle]

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns trackable children, plus resource functions for SavedModel.

    When `save_type` is SavedModel, `_create_resource`, `_initialize` and
    `_destroy_resource` are added as argument-less `tf.function`s under those
    same names.
    """
    children = super()._trackable_children(save_type, **kwargs)
    # Compare against the enum member, consistent with the default above.
    # NOTE(review): relies on `base.SaveType` being str-valued (as the previous
    # comparison with the literal "savedmodel" implied), so plain strings
    # still match.
    if save_type == base.SaveType.SAVEDMODEL:
      @def_function.function(input_signature=[], autograph=False)
      def _creator():
        return self._create_resource()

      @def_function.function(input_signature=[], autograph=False)
      def _initializer():
        self._initialize()
        return 1  # Dummy return

      @def_function.function(input_signature=[], autograph=False)
      def _destroyer():
        self._destroy_resource()
        return 1  # Dummy return

      children.update({
          "_create_resource": _creator,
          "_initialize": _initializer,
          "_destroy_resource": _destroyer,
      })
    return children

  def __del__(self):
    try:
      # Outer race condition: on program exit, the destruction context may be
      # deleted before this __del__ is called. At this point we can safely
      # exit without calling _destroy_resource() and let Python handle things.
      with self._destruction_context():
        # Inner race condition: possible between this and `ScopedTFFunction`
        # whereby if an entire garbage collection chain containing both
        # objects is moved to unreachable during the same garbage collection
        # cycle, the __del__ for `ScopedTFFunction` can be collected before
        # this method is called. In that case, we can't do much but
        # continue.
        self._destroy_resource()
    except Exception:  # pylint: disable=broad-except
      # Destruction is deliberately best-effort: silence all errors raised
      # while attempting to destroy this resource.
      pass

230 

231 

@tf_export("saved_model.experimental.TrackableResource")
class TrackableResource(CapturableResource):
  """Holds a Tensor which a tf.function can capture.

  A TrackableResource is most useful for stateful Tensors that require
  initialization, such as `tf.lookup.StaticHashTable`. `TrackableResource`s
  are discovered by traversing the graph of object attributes, e.g. during
  `tf.saved_model.save`. On construction, a TrackableResource also registers
  itself with every tracker made active by `resource_tracker_scope`.

  A TrackableResource has three methods to override:

  * `_create_resource` should create the resource tensor handle.
  * `_initialize` should initialize the resource held at `self.resource_handle`.
  * `_destroy_resource` is called upon a `TrackableResource`'s destruction
    and should decrement the resource's ref count. For most resources, this
    should be done with a call to `tf.raw_ops.DestroyResourceOp`.

  Example usage:

  >>> class DemoResource(tf.saved_model.experimental.TrackableResource):
  ...   def __init__(self):
  ...     super().__init__()
  ...     self._initialize()
  ...   def _create_resource(self):
  ...     return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])
  ...   def _initialize(self):
  ...     tf.raw_ops.AssignVariableOp(
  ...         resource=self.resource_handle, value=tf.ones([2]))
  ...   def _destroy_resource(self):
  ...     tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)
  >>> class DemoModule(tf.Module):
  ...   def __init__(self):
  ...     self.resource = DemoResource()
  ...   def increment(self, tensor):
  ...     return tensor + tf.raw_ops.ReadVariableOp(
  ...         resource=self.resource.resource_handle, dtype=tf.float32)
  >>> demo = DemoModule()
  >>> demo.increment([5, 1])
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>
  """

  def __init__(self, device=""):
    """Registers with active trackers, then initializes the base resource.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # The stack is only read here, so no `global` declaration is needed.
    for tracker in _RESOURCE_TRACKER_STACK:
      tracker.add_resource(self)
    super().__init__(device=device)

286 

287 

# TODO(b/124205571,b/124092991): Solve destruction of resources.
class RestoredResource(TrackableResource):
  """Restored SavedResource."""

  def __init__(self, device=""):
    super().__init__(device=device)

  @classmethod
  def _deserialize_from_proto(cls, object_proto, dependencies, **unused_kwargs):
    """Builds an instance placed on `object_proto.resource.device`.

    If `dependencies` has a `_create_resource` entry, that entry is stored on
    the instance, shadowing the inherited `_create_resource` method. This
    class does not override that method, so without the entry the inherited
    version raises `NotImplementedError` when the handle is first created.
    """
    restored = cls(device=object_proto.resource.device)
    creator = dependencies.get("_create_resource")
    if creator is None:
      return restored
    restored._create_resource = creator  # pylint: disable=protected-access
    return restored

  def _add_trackable_child(self, name, value):
    """Attaches `value` as attribute `name` on this object.

    A `base.Trackable` that is not a `def_function.Function` is additionally
    registered as a tracked dependency under `name`.
    """
    setattr(self, name, value)
    if not isinstance(value, base.Trackable):
      return
    if isinstance(value, def_function.Function):
      return
    self._track_trackable(value, name)