Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/c_api_util.py: 44%
100 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Utilities for using the TensorFlow C API."""
18import contextlib
19from tensorflow.core.framework import api_def_pb2
20from tensorflow.core.framework import op_def_pb2
21from tensorflow.python.client import pywrap_tf_session as c_api
22from tensorflow.python.util import compat
23from tensorflow.python.util import tf_contextlib
class AlreadyGarbageCollectedError(Exception):
  """Raised when a wrapped C-API object is accessed after being released.

  Args:
    name: Name of the wrapper whose object was released.
    obj_type: Type description of the released object.
  """

  def __init__(self, name, obj_type):
    message = (f"{name} of type {obj_type} has already been garbage "
               "collected and cannot be called.")
    super().__init__(message)
# FIXME(b/235488206): Convert all Scoped objects to the context manager
# to protect against deletion during use when the object is attached to
# an attribute.
class UniquePtr(object):
  """Wrapper around single-ownership C-API objects that handles deletion.

  When this wrapper's `__del__` runs, it calls `deleter(obj)` at most once.
  It clears the stored object first, so later calls are no-ops. Use `get()` to
  access the object while guaranteeing it is still alive.

  Attributes:
    deleter: Callable invoked with the wrapped object to release it.
    name: Name used in error messages.
    type_name: `str(type(obj))` of the wrapped object, kept so error messages
      can describe it after the object itself has been released.
  """

  __slots__ = ["_obj", "deleter", "name", "type_name"]

  def __init__(self, name, obj, deleter):
    """Takes ownership of `obj`.

    Args:
      name: Name used in error messages.
      obj: The C-API object to own.
      deleter: Callable that releases `obj`; called with `obj` as its only
        argument.
    """
    # '_' prefix marks _obj private, but unclear if it is required also to
    # maintain a special CPython destruction order.
    self._obj = obj
    self.name = name
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # deleter function here, we retain the ability to cleanly destroy the
    # object at shutdown, which satisfies leak checkers.
    self.deleter = deleter
    self.type_name = str(type(obj))

  @contextlib.contextmanager
  def get(self):
    """Yields the managed C-API Object, guaranteeing aliveness.

    This is a context manager. Inside the context the C-API object is
    guaranteed to be alive.

    Yields:
      The wrapped C-API object.

    Raises:
      AlreadyGarbageCollectedError: if the object is already deleted.
    """
    # Thread-safety: self.__del__ never runs during the call of this function
    # because there is a reference to self from the argument list.
    if self._obj is None:
      raise AlreadyGarbageCollectedError(self.name, self.type_name)
    yield self._obj

  def __del__(self):
    # `getattr` with a default tolerates instances whose __init__ never
    # completed (e.g. a subclass raised while computing its super().__init__
    # arguments). Reading an unset slot would otherwise raise AttributeError
    # during garbage collection.
    obj = getattr(self, "_obj", None)
    if obj is not None:
      # Clear the slot before releasing so the object is never freed twice.
      self._obj = None
      self.deleter(obj)
class ScopedTFStatus(object):
  """Owns a TF_Status created by TF_NewStatus and frees it on collection."""

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # While the interpreter shuts down, the c_api module (or the function we
    # need from it) may already have been torn down; skip deletion then.
    api = c_api
    if api is None or api.TF_DeleteStatus is None:
      return
    api.TF_DeleteStatus(self.status)
class ScopedTFImportGraphDefOptions(object):
  """Owns a TF_ImportGraphDefOptions handle and frees it on collection."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # While the interpreter shuts down, the c_api module (or the function we
    # need from it) may already have been torn down; skip deletion then.
    api = c_api
    if api is None or api.TF_DeleteImportGraphDefOptions is None:
      return
    api.TF_DeleteImportGraphDefOptions(self.options)
class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion."""

  __slots__ = ["results"]

  def __init__(self, results):
    """Takes ownership of `results`.

    Args:
      results: Handle that is passed to `c_api.TF_DeleteImportGraphDefResults`
        when this wrapper is garbage collected.
    """
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)
class ScopedTFFunction(UniquePtr):
  """Single-owner wrapper for a TF_Function, released via TF_DeleteFunction.

  Args:
    func: The TF_Function handle to own.
    name: Name used in error messages.
  """

  def __init__(self, func, name):
    deleter = c_api.TF_DeleteFunction
    super().__init__(name=name, obj=func, deleter=deleter)
class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime.

  Creates a TF_Buffer from a string and deletes it when this wrapper is
  garbage collected.
  """

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    """Creates a TF_Buffer holding `buf_string`.

    Args:
      buf_string: Buffer contents. They are converted with `compat.as_bytes`,
        so presumably `bytes`, `bytearray` or `str`.
    """
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    # If __init__ raised (e.g. `compat.as_bytes` rejected the input), the slot
    # was never assigned; reading it would raise AttributeError during GC.
    buf = getattr(self, "buffer", None)
    if buf is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules, so c_api or its
    # functions may be None; skip deletion in that case.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      c_api.TF_DeleteBuffer(buf)
class ApiDefMap(object):
  """Wrapper around Tf_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    """Builds the map from the op list returned by `c_api.TF_GetAllOpList`."""
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The buffer is only needed to build the OpList proto and the map.
      c_api.TF_DeleteBuffer(buf)

    # If several OpDefs share a name, the last one wins.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # __init__ may have raised before `_api_def_map` was assigned; reading an
    # unset slot would raise AttributeError during garbage collection.
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Adds the ApiDef text `text` to the map.

    Args:
      text: ApiDef text, as `str` or `bytes`.
    """
    # The C call receives an explicit length alongside the text. `len()` of a
    # `str` counts code points, which is shorter than its UTF-8 encoding for
    # non-ASCII text, so encode first and measure the bytes.
    text = compat.as_bytes(text)
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the `api_def_pb2.ApiDef` stored for `op_name`.

    Args:
      op_name: Op name, as `str` or `bytes`.
    """
    api_def_proto = api_def_pb2.ApiDef()
    # Same byte-length reasoning as in `put_api_def`.
    op_name = compat.as_bytes(op_name)
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the stored OpDef proto for `op_name`.

    Raises:
      ValueError: if no OpDef with that name was loaded.
    """
    if op_name in self._op_per_name:
      return self._op_per_name[op_name]
    raise ValueError(f"No op_def found for op name {op_name}.")

  def op_names(self):
    """Returns a keys view of all op names loaded into this map."""
    return self._op_per_name.keys()
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that creates a TF_Buffer and deletes it on exit.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. If it is non-empty,
      the yielded buffer will contain this data; if it is None or empty, an
      empty buffer from `TF_NewBuffer()` is yielded.

  Yields:
    Created TF_Buffer
  """
  new_buffer = (c_api.TF_NewBufferFromString(compat.as_bytes(data))
                if data else c_api.TF_NewBuffer())
  try:
    yield new_buffer
  finally:
    # Runs on normal exit and when the with-body raises.
    c_api.TF_DeleteBuffer(new_buffer)
def tf_output(c_op, index):
  """Builds a wrapped TF_Output referring to output `index` of `c_op`.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  output = c_api.TF_Output()
  output.oper, output.index = c_op, index
  return output