Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/session_ops.py: 32%

125 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Tensor Handle Operations.""" 

17 

18# pylint: disable=g-bad-name 

19import numpy as np 

20 

21from tensorflow.core.framework import resource_handle_pb2 

22from tensorflow.python.client import pywrap_tf_session 

23from tensorflow.python.framework import device as pydev 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import ops 

26from tensorflow.python.ops import array_ops 

27from tensorflow.python.ops import gen_data_flow_ops 

28from tensorflow.python.util import compat 

29from tensorflow.python.util.tf_export import tf_export 

30 

31 

def encode_resource_handle(resource_handle):
  """Encode a ResourceHandle proto as custom numpy struct type.

  Args:
    resource_handle: A proto message providing `SerializeToString()`; in this
      module it is a `resource_handle_pb2.ResourceHandleProto`.

  Returns:
    A numpy array built from the serialized bytes of `resource_handle`, with
    dtype `dtypes.np_resource`, suitable as a feed value.
  """
  serialized = resource_handle.SerializeToString()
  return np.asarray(bytearray(serialized), dtype=dtypes.np_resource)

36 

37 

class TensorHandle:
  """Represents a handle for a live tensor in a session.

  The handle string has the form "tensor_name;unique_id;device_name". While
  automatic garbage collection is enabled (it is disabled by `delete()` and
  `get_raw_handle()`), finalizing this object reports the handle to the owning
  session via `_register_dead_handle`.
  """

  def __init__(self, handle, dtype, session):
    """Constructs a new tensor handle.

    A tensor handle for a persistent tensor is a python string
    that has the form of "tensor_name;unique_id;device_name".

    Args:
      handle: A tensor handle; converted to `str` with `compat.as_str_any`.
      dtype: The data type of the tensor represented by `handle`.
      session: The session in which the tensor is produced.
    """
    self._handle = compat.as_str_any(handle)
    # Built lazily (and cached) by `_get_resource_handle`.
    self._resource_handle = None
    self._dtype = dtype
    self._session = session
    # While True, `__del__` hands the handle back to the session.
    self._auto_gc_enabled = True

  def __del__(self):
    # `getattr` guards against a partially constructed instance (e.g. if
    # `__init__` raised before this attribute was set): such an object owns
    # nothing, and an AttributeError here would only be reported as noise.
    if getattr(self, "_auto_gc_enabled", False):
      self._session._register_dead_handle(self.handle)

  def __str__(self):
    return self._handle

  def _get_resource_handle(self):
    """The ResourceHandle representation of this handle.

    The proto is built once and cached. Its `device` is the last
    ";"-separated component of the handle string, its `container` is
    `pywrap_tf_session.TENSOR_HANDLE_KEY`, and its `name` is the full handle.
    """
    if self._resource_handle is None:
      # Populate fully before caching so a failure never leaves a
      # half-initialized proto behind.
      proto = resource_handle_pb2.ResourceHandleProto()
      proto.device = self._handle.split(";")[-1]
      proto.container = pywrap_tf_session.TENSOR_HANDLE_KEY
      proto.name = self._handle
      self._resource_handle = proto
    return self._resource_handle

  def to_numpy_array(self):
    """Convert a TensorHandle object to a feedable numpy value.

    Returns:
      A numpy array of a custom struct type that can be used as a feed value
      to run().
    """
    return encode_resource_handle(self._get_resource_handle())

  @property
  def handle(self):
    """The string representation of this handle."""
    return self._handle

  def eval(self):
    """Return the value of the tensor represented by this handle.

    Runs a reader subgraph (built on first use and cached in the session's
    graph) with this handle's string fed into its placeholder.

    Raises:
      TypeError: If automatic garbage collection has been disabled (by
        `delete()` or `get_raw_handle()`).
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    holder, reader = _get_handle_reader(self._session.graph, self._handle,
                                        self._dtype)
    return self._session.run(reader, feed_dict={holder: self._handle})

  def delete(self):
    """Force the deletion of this persistent tensor.

    Raises:
      TypeError: If automatic garbage collection has already been disabled
        (by a previous `delete()` or by `get_raw_handle()`).
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    # Disable auto-GC first so `__del__` does not report this handle again.
    self._auto_gc_enabled = False
    # Key 0 selects a cached deleter in `graph._handle_deleters`; that
    # deleter's device comes from whichever handle first created the entry.
    holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
    self._session.run(deleter, feed_dict={holder: self.handle})

  def get_raw_handle(self):
    """Return the raw handle of the tensor.

    Note that the method disables the automatic garbage collection of this
    persistent tensor. The caller is now responsible for managing the life
    time of the tensor.
    """
    self._auto_gc_enabled = False
    return self._handle

  @staticmethod
  def _get_device_name(handle):
    """The canonical device name encoded in the handle (its last component)."""
    handle_str = compat.as_str_any(handle)
    return pydev.canonical_name(handle_str.split(";")[-1])

  @staticmethod
  def _get_reader_key(handle):
    """The graph key for reader: "<tensor_name>;<device_name>".

    Converts with `compat.as_str_any` (as `_get_device_name` does) so that a
    `bytes` handle produces the same key as the equivalent `str` handle,
    instead of being rendered as "b'...'" by `str()`.
    """
    handle_parts = compat.as_str_any(handle).split(";")
    return handle_parts[0] + ";" + handle_parts[-1]

  @staticmethod
  def _get_mover_key(feeder, handle):
    """The graph key for mover: the feeder op name plus the reader key."""
    return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)

132 

133 

@tf_export(v1=["get_session_handle"])
def get_session_handle(data, name=None):
  """Return the handle of `data`.

  This is EXPERIMENTAL and subject to change.

  Keeps `data` "in-place" in the runtime and creates a handle through which
  `data` can be retrieved in a later run().

  Together with `get_session_tensor`, this lets a tensor produced by one run
  call stay in place and serve as an input to a future run call.

  Args:
    data: A tensor to be stored in the session.
    name: Optional name prefix for the return tensor.

  Returns:
    A scalar string tensor representing a unique handle for `data`.

  Raises:
    TypeError: if `data` is not a Tensor.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.compat.v1.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.compat.v1.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  if isinstance(data, ops.Tensor):
    # The handle op is colocated with the tensor it refers to.
    with ops.colocate_with(data):
      return gen_data_flow_ops.get_session_handle(data, name=name)
  raise TypeError("`data` must be of type Tensor.")

175 

176 

@tf_export(v1=["get_session_tensor"])
def get_session_tensor(handle, dtype, name=None):
  """Get the tensor of type `dtype` by feeding a tensor handle.

  This is EXPERIMENTAL and subject to change.

  Reads back a tensor that a previous run() produced and stored in the
  session state. The created ops are placed on the device encoded in
  `handle`.

  Args:
    handle: The string representation of a persistent tensor handle.
    dtype: The type of the output tensor.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of tensors. The first is a placeholder for feeding a tensor
    handle; the second is the tensor in the session state keyed by that
    handle.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.compat.v1.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.compat.v1.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  device_name = TensorHandle._get_device_name(handle)
  with ops.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    # Record the placeholder's dtype; `_get_handle_mover` looks it up.
    _register_handle_feeder(feed_placeholder.graph, feed_placeholder, dtype)
    session_tensor = gen_data_flow_ops.get_session_tensor(
        feed_placeholder, dtype, name=name)
  return (feed_placeholder, session_tensor)

216 

217 

@tf_export(v1=["delete_session_tensor"])
def delete_session_tensor(handle, name=None):
  """Delete the tensor for the given tensor handle.

  This is EXPERIMENTAL and subject to change.

  Removes a tensor that a previous run() produced and stored in the session
  state. The created ops are placed on the device encoded in `handle`.

  Args:
    handle: The string representation of a persistent tensor handle.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of graph elements. The first is a placeholder for feeding a
    tensor handle; the second is a deletion operation.
  """
  device_name = TensorHandle._get_device_name(handle)
  with ops.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops.delete_session_tensor(
        feed_placeholder, name=name)
  return (feed_placeholder, delete_op)

240 

241 

def _register_handle_feeder(graph, feeder, dtype):
  """Record in `graph` that placeholder `feeder` feeds a handle of `dtype`.

  Entries are keyed by the feeder op's name in `graph._handle_feeders` and
  are read back by `_get_handle_feeder`. A later call for the same op name
  overwrites the earlier entry.
  """
  feeder_op_name = feeder.op.name
  graph._handle_feeders[feeder_op_name] = dtype

244 

245 

def _get_handle_feeder(graph, feeder):
  """Return the dtype registered for placeholder `feeder`, or None.

  Looks up `feeder.op.name` in `graph._handle_feeders`, which is filled by
  `_register_handle_feeder`.
  """
  registry = graph._handle_feeders
  return registry.get(feeder.op.name, None)

248 

249 

def _get_handle_reader(graph, handle, dtype):
  """Return a read subgraph for this handle.

  The `(placeholder, reader)` pair is cached in `graph._handle_readers` under
  the handle's reader key, so a pair is built at most once per key. A newly
  built pair is placed on the handle's device, and its placeholder is
  registered as a handle feeder of `dtype`.
  """
  reader_key = TensorHandle._get_reader_key(handle)
  cached = graph._handle_readers.get(reader_key)
  if cached is not None:
    return cached

  device_name = TensorHandle._get_device_name(handle)
  with graph.as_default(), graph.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    _register_handle_feeder(feed_placeholder.graph, feed_placeholder, dtype)
    read_op = gen_data_flow_ops.get_session_tensor(feed_placeholder, dtype)
  built = (feed_placeholder, read_op)
  graph._handle_readers[reader_key] = built
  return built

264 

265 

def _get_handle_mover(graph, feeder, handle):
  """Return a move subgraph for this pair of feeder and handle.

  Returns None if `feeder` was never registered as a handle feeder, or if its
  op device already equals the (canonical) device encoded in `handle`.
  Otherwise returns a `(placeholder, mover)` pair cached in
  `graph._handle_movers`. The placeholder belongs to the handle's reader
  subgraph. The mover is a `get_session_handle` op on the feeder's device that
  re-publishes the value read from the handle.
  """
  feed_dtype = _get_handle_feeder(graph, feeder)
  if feed_dtype is None:
    return None
  if feeder.op.device == TensorHandle._get_device_name(handle):
    return None

  # The handle's tensor lives on a different device than the feeder.
  mover_key = TensorHandle._get_mover_key(feeder, handle)
  cached = graph._handle_movers.get(mover_key)
  if cached is not None:
    return cached

  feed_placeholder, read_op = _get_handle_reader(graph, handle, feed_dtype)
  with graph.as_default(), graph.device(feeder.op.device):
    move_op = gen_data_flow_ops.get_session_handle(read_op)
  built = (feed_placeholder, move_op)
  graph._handle_movers[mover_key] = built
  return built

285 

286 

def _get_handle_deleter(graph, deleter_key, handle):
  """Return a deletion subgraph for this handle.

  The `(placeholder, deleter)` pair is cached in `graph._handle_deleters`
  under `deleter_key`. The device of `handle` is only consulted when a new
  pair is built, so an existing entry is reused for any handle that shares
  the same key.
  """
  cached = graph._handle_deleters.get(deleter_key)
  if cached is not None:
    return cached

  device_name = TensorHandle._get_device_name(handle)
  with graph.as_default(), graph.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops.delete_session_tensor(feed_placeholder)
  built = (feed_placeholder, delete_op)
  graph._handle_deleters[deleter_key] = built
  return built