Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/c_api_util.py: 44%

100 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Utilities for using the TensorFlow C API.""" 

17 

18import contextlib 

19from tensorflow.core.framework import api_def_pb2 

20from tensorflow.core.framework import op_def_pb2 

21from tensorflow.python.client import pywrap_tf_session as c_api 

22from tensorflow.python.util import compat 

23from tensorflow.python.util import tf_contextlib 

24 

25 

class AlreadyGarbageCollectedError(Exception):
  """Raised when a wrapped C-API object is accessed after being released.

  Args:
    name: Human-readable name of the wrapped object.
    obj_type: Type description of the wrapped object, used in the message.
  """

  def __init__(self, name, obj_type):
    message = (f"{name} of type {obj_type} has already been garbage "
               "collected and cannot be called.")
    super().__init__(message)

32 

33 

34# FIXME(b/235488206): Convert all Scoped objects to the context manager 

35# to protect against deletion during use when the object is attached to 

36# an attribute. 

class UniquePtr(object):
  """Single-owner handle for a C-API object; frees it via `deleter` on GC."""

  __slots__ = ["_obj", "deleter", "name", "type_name"]

  def __init__(self, name, obj, deleter):
    # The leading underscore marks `_obj` private; it is unclear whether it is
    # also needed to preserve a particular CPython destruction order.
    self._obj = obj
    self.name = name
    # The deleter is captured at construction time so that the object can
    # still be released during interpreter shutdown, when other modules may
    # already have been torn down (keeps leak checkers satisfied).
    self.deleter = deleter
    self.type_name = str(type(obj))

  @contextlib.contextmanager
  def get(self):
    """Context manager yielding the wrapped C-API object while it is alive.

    Yields:
      The managed C-API object.

    Raises:
      AlreadyGarbageCollectedError: if the object has already been released.
    """
    # `self` is referenced by this call's arguments, so `__del__` cannot run
    # while the context is active.
    managed = self._obj
    if managed is None:
      raise AlreadyGarbageCollectedError(self.name, self.type_name)
    yield managed

  def __del__(self):
    # Detach first so the deleter runs at most once for this object.
    managed, self._obj = self._obj, None
    if managed is not None:
      self.deleter(managed)

75 

76 

class ScopedTFStatus(object):
  """Owns a TF_Status handle and frees it when the wrapper is collected."""

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # While the process is shutting down, other modules (including `c_api`)
    # may already be gone; skip the delete in that case.
    if c_api is None or c_api.TF_DeleteStatus is None:
      return
    c_api.TF_DeleteStatus(self.status)

90 

91 

class ScopedTFImportGraphDefOptions(object):
  """Owns a TF_ImportGraphDefOptions handle and frees it on collection."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # While the process is shutting down, other modules (including `c_api`)
    # may already be gone; skip the delete in that case.
    if c_api is None or c_api.TF_DeleteImportGraphDefOptions is None:
      return
    c_api.TF_DeleteImportGraphDefOptions(self.options)

105 

106 

class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Args:
    results: An existing TF_ImportGraphDefResults handle. This wrapper takes
      ownership and releases it with `TF_DeleteImportGraphDefResults` when it
      is garbage collected.
  """

  __slots__ = ["results"]

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)

120 

121 

class ScopedTFFunction(UniquePtr):
  """Owning handle for a TF_Function, released with `TF_DeleteFunction`.

  Args:
    func: The TF_Function handle to take ownership of.
    name: Name used in error messages if the function is used after release.
  """

  def __init__(self, func, name):
    super().__init__(name=name, obj=func, deleter=c_api.TF_DeleteFunction)

128 

129 

class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime.

  Args:
    buf_string: Contents of the buffer; converted with `compat.as_bytes`
      before being copied into a new TF_Buffer.
  """

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    # If __init__ raised before assigning `buffer` (e.g. the conversion or the
    # C call failed), the slot is unset and there is nothing to free.
    buf = getattr(self, "buffer", None)
    if buf is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      c_api.TF_DeleteBuffer(buf)

140 

141 

class ApiDefMap(object):
  """Wrapper around Tf_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      c_api.TF_DeleteBuffer(buf)

    # Later entries with a duplicate name overwrite earlier ones.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # If __init__ failed before `_api_def_map` was assigned, the slot is unset
    # and there is no C object to release.
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Passes ApiDef `text` (`bytes` or `str`) to `TF_ApiDefMapPut`."""
    # The C API takes a (char*, length) pair, so the length must be the
    # encoded byte count; len() of a non-ASCII `str` would under-count.
    text = compat.as_bytes(text)
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the `ApiDef` proto that `TF_ApiDefMapGet` yields for `op_name`."""
    api_def_proto = api_def_pb2.ApiDef()
    # Same (char*, length) contract as put_api_def: measure the encoded bytes.
    name_bytes = compat.as_bytes(op_name)
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, name_bytes, len(name_bytes))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the stored OpDef proto for `op_name`.

    Raises:
      ValueError: if no OpDef with that name was registered.
    """
    try:
      return self._op_per_name[op_name]
    except KeyError:
      raise ValueError(f"No op_def found for op name {op_name}.") from None

  def op_names(self):
    """Returns a view of all op names known to this map."""
    return self._op_per_name.keys()

189 

190 

@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that creates a TF_Buffer and deletes it on exit.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. If it is truthy,
      the yielded buffer is filled with its contents; otherwise (None or
      empty) a fresh empty buffer is created.

  Yields:
    Created TF_Buffer
  """
  new_buf = (c_api.TF_NewBufferFromString(compat.as_bytes(data))
             if data else c_api.TF_NewBuffer())
  try:
    yield new_buf
  finally:
    # Always release the buffer, even if the with-body raised.
    c_api.TF_DeleteBuffer(new_buf)

222 

223 

def tf_output(c_op, index):
  """Builds a wrapped TF_Output pointing at output `index` of `c_op`.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  output = c_api.TF_Output()
  output.oper = c_op
  output.index = index
  return output