Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/resource.py: 59%
128 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Definitions for resource-type trackable object classes."""
17import contextlib
18import copy
19import weakref
21from tensorflow.python.eager import context
22from tensorflow.python.eager import def_function
23from tensorflow.python.framework import ops
24from tensorflow.python.trackable import base
25from tensorflow.python.util import tf_contextlib
26from tensorflow.python.util.tf_export import tf_export
# Module-level stack of the currently active `ResourceTracker`s.
# `resource_tracker_scope` appends to it on entry and rebinds it to a
# snapshot on exit; `TrackableResource.__init__` registers each new resource
# with every tracker found here.
_RESOURCE_TRACKER_STACK = list()
class ResourceTracker:
  """Collects the resources registered with it, in registration order."""

  __slots__ = ["_resources"]

  def __init__(self):
    self._resources = list()

  @property
  def resources(self):
    """The tracker's own list of registered resources (not a copy)."""
    return self._resources

  def add_resource(self, resource):
    """Appends `resource` to the end of this tracker's list."""
    # In-place extend: the list object returned by `resources` is unchanged.
    self._resources += [resource]
@tf_contextlib.contextmanager
def resource_tracker_scope(resource_tracker):
  """Context manager that makes `resource_tracker` collect new resources.

  While the scope is active, `resource_tracker` is on the module-level tracker
  stack, so every `TrackableResource` constructed inside the block is added to
  it (and to any trackers of enclosing scopes).

  Example usage:

  ```python
  resource_tracker = ResourceTracker()
  with resource_tracker_scope(resource_tracker):
    resource = TrackableResource()

  assert resource_tracker.resources == [resource]
  ```

  Args:
    resource_tracker: The `ResourceTracker` to activate for the block.

  Yields:
    A scope in which the resource_tracker is active.
  """
  global _RESOURCE_TRACKER_STACK
  saved_stack = _RESOURCE_TRACKER_STACK[:]
  _RESOURCE_TRACKER_STACK.append(resource_tracker)
  try:
    yield
  finally:
    # Rebind to the snapshot taken on entry (instead of popping), so the
    # stack is restored exactly to its pre-scope contents.
    _RESOURCE_TRACKER_STACK = saved_stack
def _make_getter(captured_getter, captured_previous):
  """Binds `captured_previous` as the first argument of `captured_getter`.

  Creating the closure in its own function, rather than inline inside a loop,
  avoids Python's late binding of loop variables.

  Args:
    captured_getter: Callable invoked as `captured_getter(previous, ...)`.
    captured_previous: Value passed as the first positional argument.

  Returns:
    A callable that forwards its arguments to `captured_getter`, after
    `captured_previous`.
  """

  def _chained(*call_args, **call_kwargs):
    return captured_getter(captured_previous, *call_args, **call_kwargs)

  return _chained
class _ResourceMetaclass(type):
  """Metaclass for CapturableResource.

  Routes instance construction through the resource creators registered on
  the default graph under the class's `_resource_type()`. Each creator is
  called with the next creator in the chain followed by the constructor
  arguments.
  """

  def __call__(cls, *args, **kwargs):

    def _construct(next_creator, *a, **kw):
      # Innermost link of the chain: there is nothing further to delegate to.
      assert next_creator is None
      instance = cls.__new__(cls, *a, **kw)
      instance.__init__(*a, **kw)
      return instance

    creator_stack = ops.get_default_graph()._resource_creator_stack  # pylint: disable=protected-access
    chain = lambda *a, **kw: _construct(None, *a, **kw)
    # Each registered creator wraps the chain built so far, so the last one
    # in the list becomes outermost and is invoked first.
    for registered in creator_stack[cls._resource_type()]:
      chain = _make_getter(registered, chain)
    return chain(*args, **kwargs)
class CapturableResource(base.Trackable, metaclass=_ResourceMetaclass):
  """Holds a Tensor which a tf.function can capture.

  `CapturableResource`s are discovered by traversing the graph of object
  attributes, e.g. during `tf.saved_model.save`. They are excluded from the
  scope-based tracking of `TrackableResource`; generally things that require
  initialization should inherit from `TrackableResource` instead of
  `CapturableResource` directly.

  Subclasses must override `_create_resource`; `_initialize` and
  `_destroy_resource` are optional hooks that do nothing by default. The
  handle is created lazily, on the first read of `resource_handle`.
  Construction goes through `_ResourceMetaclass`, so resource creators
  registered on the default graph for `_resource_type()` are applied.
  """

  def __init__(self, device=""):
    """Initialize the `CapturableResource`.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # No handle yet; the `resource_handle` property creates it on first use.
    self._resource_handle_value = None
    self._resource_device = device
    # Capture the context active at construction time (eager mode, or the
    # then-default graph) so `__del__` can call `_destroy_resource` under
    # that same context.
    self._self_destruction_context = (
        context.eager_mode if context.executing_eagerly()
        else ops.get_default_graph().as_default)

  @classmethod
  def _resource_type(cls):
    """Returns the key used to look up resource creators for this class.

    `_ResourceMetaclass.__call__` indexes the default graph's
    `_resource_creator_stack` with this value. Defaults to the class name.
    """
    return cls.__name__

  @property
  def _destruction_context(self):
    """Context-manager factory that `__del__` enters around destruction.

    Falls back to `contextlib.suppress`, which with no exception types
    suppresses nothing and so acts as a no-op context, if
    `_self_destruction_context` was never set.
    """
    return getattr(self, "_self_destruction_context",
                   # no-op context
                   contextlib.suppress)

  @_destruction_context.setter
  def _destruction_context(self, destruction_context):
    self._self_destruction_context = destruction_context

  def _create_resource(self):
    """A function that creates a resource handle. Subclasses must override."""
    raise NotImplementedError("TrackableResource._create_resource not "
                              "implemented.")

  @property
  def _resource_handle(self):
    """The cached handle, or None if it has not been created yet."""
    return self._resource_handle_value

  @_resource_handle.setter
  def _resource_handle(self, value):
    # Give Tensor handles a weak back-reference to their owning object; a
    # weakref means the handle does not keep this object alive.
    if isinstance(value, (ops.Tensor, ops.EagerTensor)):
      value._parent_trackable = weakref.ref(self)  # pylint: disable=protected-access
    self._resource_handle_value = value

  def _initialize(self):
    """A function that initializes the resource. Optional."""
    pass

  def _destroy_resource(self):
    """A function that destroys the resource. Optional."""
    pass

  @property
  def resource_handle(self):
    """Returns the resource handle associated with this Resource.

    On first access the handle is created by calling `_create_resource()`
    under `ops.device(self._resource_device)`, then cached.
    """
    if self._resource_handle is None:
      with ops.device(self._resource_device):
        self._resource_handle = self._create_resource()
    return self._resource_handle

  def _export_to_saved_model_graph(
      self, object_map, tensor_map, **unused_kwargs):
    """For implementing `Trackable`.

    Builds a shallow copy of this object that owns a freshly created handle,
    records it as `object_map[self]`, and maps this object's handle to the
    new handle in `tensor_map`.

    Returns:
      A one-element list holding this object's own handle.
    """
    new_obj = copy.copy(self)
    # pylint: disable=protected-access
    with ops.device(self._resource_device):
      new_resource = new_obj._create_resource()
    new_obj._resource_handle = new_resource
    # pylint: enable=protected-access
    object_map[self] = new_obj
    # Reading `self.resource_handle` creates this object's own handle here
    # if it has not been created yet.
    tensor_map[self.resource_handle] = new_resource
    return [self.resource_handle]

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the trackable children, plus resource functions for SavedModel.

    When saving a SavedModel, the result also maps `_create_resource`,
    `_initialize` and `_destroy_resource` to zero-argument `tf.function`s
    wrapping the corresponding methods. Presumably this lets them be restored
    at load time; `RestoredResource._deserialize_from_proto` reads
    `_create_resource` from its dependencies.
    """
    children = super()._trackable_children(save_type, **kwargs)
    # NOTE(review): relies on `base.SaveType` comparing equal to the string
    # "savedmodel" — confirm against the `base.SaveType` definition.
    if save_type == "savedmodel":
      @def_function.function(input_signature=[], autograph=False)
      def _creator():
        resource = self._create_resource()
        return resource

      @def_function.function(input_signature=[], autograph=False)
      def _initializer():
        self._initialize()
        return 1  # Dummy return

      @def_function.function(input_signature=[], autograph=False)
      def _destroyer():
        self._destroy_resource()
        return 1  # Dummy return

      children.update({
          "_create_resource": _creator,
          "_initialize": _initializer,
          "_destroy_resource": _destroyer,
      })
    return children

  def __del__(self):
    # Best-effort cleanup: call `_destroy_resource` under the context captured
    # at construction and deliberately swallow every error, since an
    # exception cannot usefully propagate out of `__del__`.
    try:
      # Outer race condition: on program exit, the destruction context may be
      # deleted before this __del__ is called. At this point we can safely
      # exit without calling _destroy_resource() and let Python handle things.
      with self._destruction_context():
        # Inner race condition: possible between this and `ScopedTFFunction`
        # whereby if an entire garbage collection chain containing both
        # objects is moved to unreachable during the same garbage collection
        # cycle, the __del__ for `ScopedTFFunction` can be collected before
        # this method is called. In that case, we can't do much but
        # continue.
        self._destroy_resource()
    except Exception:  # pylint: disable=broad-except
      # Silence all error logs that occur when attempting to destroy this
      # resource.
      pass
@tf_export("saved_model.experimental.TrackableResource")
class TrackableResource(CapturableResource):
  """Holds a Tensor which a tf.function can capture.

  `TrackableResource` is intended for stateful Tensors that need
  initialization, such as `tf.lookup.StaticHashTable`. Instances are found by
  walking the graph of object attributes, e.g. during `tf.saved_model.save`.
  They are also registered with every tracker made active by
  `resource_tracker_scope`.

  Subclasses override three methods:

  * `_create_resource` creates and returns the resource tensor handle.
  * `_initialize` initializes the resource held at `self.resource_handle`.
  * `_destroy_resource` runs when the `TrackableResource` is destroyed and
    should decrement the resource's ref count. For most resources, this
    should be done with a call to `tf.raw_ops.DestroyResourceOp`.

  Example usage:

  >>> class DemoResource(tf.saved_model.experimental.TrackableResource):
  ...   def __init__(self):
  ...     super().__init__()
  ...     self._initialize()
  ...   def _create_resource(self):
  ...     return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])
  ...   def _initialize(self):
  ...     tf.raw_ops.AssignVariableOp(
  ...         resource=self.resource_handle, value=tf.ones([2]))
  ...   def _destroy_resource(self):
  ...     tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)
  >>> class DemoModule(tf.Module):
  ...   def __init__(self):
  ...     self.resource = DemoResource()
  ...   def increment(self, tensor):
  ...     return tensor + tf.raw_ops.ReadVariableOp(
  ...         resource=self.resource.resource_handle, dtype=tf.float32)
  >>> demo = DemoModule()
  >>> demo.increment([5, 1])
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>
  """

  def __init__(self, device=""):
    """Initialize the `TrackableResource`.

    First registers this object with every tracker on the active
    `resource_tracker_scope` stack, then runs `CapturableResource` setup.

    Args:
      device: A string indicating a required placement for this resource,
        e.g. "CPU" if this resource must be created on a CPU device. A blank
        device allows the user to place resource creation, so generally this
        should be blank unless the resource only makes sense on one device.
    """
    # Only reading the module-level stack, so no `global` declaration is
    # needed.
    for tracker in _RESOURCE_TRACKER_STACK:
      tracker.add_resource(self)
    super().__init__(device=device)
# TODO(b/124205571,b/124092991): Solve destruction of resources.
class RestoredResource(TrackableResource):
  """Restored SavedResource.

  A `TrackableResource` rebuilt from a SavedModel object proto.
  """

  def __init__(self, device=""):
    super().__init__(device=device)

  @classmethod
  def _deserialize_from_proto(cls, object_proto, dependencies, **unused_kwargs):
    """Builds an instance from `object_proto` and its restored dependencies.

    Args:
      object_proto: Proto whose `resource.device` gives the placement.
      dependencies: Mapping-like object of restored children. If it has a
        `_create_resource` entry, that callable replaces the instance's
        `_create_resource` method.
      **unused_kwargs: Ignored.

    Returns:
      The newly constructed instance.
    """
    restored = cls(device=object_proto.resource.device)
    creator_fn = dependencies.get("_create_resource")
    if creator_fn is not None:
      restored._create_resource = creator_fn  # pylint: disable=protected-access
    return restored

  def _add_trackable_child(self, name, value):
    """Stores `value` as attribute `name`, tracking it if it is a Trackable.

    `def_function.Function` values are stored as attributes but are not
    tracked as dependencies.
    """
    setattr(self, name, value)
    if not isinstance(value, base.Trackable):
      return
    if isinstance(value, def_function.Function):
      return
    self._track_trackable(value, name)