Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/core/function/capture/capture_container.py: 32%

177 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""FuncGraph and related functionality.""" 

16 

17import collections as py_collections 

18import functools 

19from typing import Any, Callable, Hashable, Mapping, Optional 

20 

21from tensorflow.core.function import trace_type 

22from tensorflow.python import pywrap_tfe 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.types import core 

25from tensorflow.python.util import object_identity 

26 

27 

# Eager tensors with at most this many elements are captured as Const ops;
# larger ones (and non-convertible ones) get Placeholder ops instead.
# Used by FunctionCaptures.capture_by_value.
_EAGER_CONST_THRESHOLD = 128

29 

30 

class MutationAwareDict(py_collections.OrderedDict):
  """An OrderedDict that records whether it has been mutated.

  The flag starts out True (a freshly built dict counts as "mutated", so any
  cache derived from it is populated on first use) and is set again by every
  mutating operation. Consumers reset it to False through the `mutated`
  setter once they have refreshed their caches.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._mutated = True

  def pop(self, key, default=None):
    # NOTE: unlike dict.pop, this never raises KeyError — a missing key
    # returns `default` (None unless given). Kept for backward compatibility.
    self._mutated = True
    return super().pop(key, default)

  def popitem(self, last=True):
    # Previously an unflagged mutator: removing an item via popitem() would
    # silently leave derived caches stale.
    self._mutated = True
    return super().popitem(last)

  def setdefault(self, key, default=None):
    # Previously unflagged; may insert `key`, which is a mutation.
    self._mutated = True
    return super().setdefault(key, default)

  def update(self, *args, **kwargs):
    # Previously an unflagged mutator: bulk updates would silently leave
    # derived caches stale.
    self._mutated = True
    return super().update(*args, **kwargs)

  def move_to_end(self, key, last=True):
    # Reordering changes iteration order, which ordered caches depend on.
    self._mutated = True
    return super().move_to_end(key, last)

  def __setitem__(self, key, value):
    self._mutated = True
    return super().__setitem__(key, value)

  def __delitem__(self, key):
    self._mutated = True
    return super().__delitem__(key)

  def clear(self):
    self._mutated = True
    return super().clear()

  @property
  def mutated(self):
    # True if the dict may have changed since the flag was last reset.
    return self._mutated

  @mutated.setter
  def mutated(self, value):
    self._mutated = value

62 

class FunctionCaptures(object):
  """A container for all capture usages within FuncGraph.

  Tracks two families of captures:
    * by-value: an external tensor snapshotted at trace time, keyed by
      `id(tensor)`, paired with an internal constant/placeholder that stands
      in for it inside the graph.
    * by-ref: a zero-argument callable re-evaluated at call time, keyed by a
      user-supplied (or generated) hashable key.
  """

  def __init__(self):
    self._by_ref_internal = py_collections.OrderedDict()
    self._by_ref_external = py_collections.OrderedDict()
    self._by_ref_tracetype = py_collections.OrderedDict()
    # Mutation-aware so `by_val_capture_tuples` can rebuild its tuple cache
    # lazily, only when one of these dicts actually changed.
    self._by_val_internal = MutationAwareDict()
    self._by_val_external = MutationAwareDict()
    self._by_val_tracetype = py_collections.OrderedDict()
    # Cache of (external, internal) by-val pairs; rebuilt on demand by
    # `_recompute_tuple_cache`. Initialized here so the attribute always
    # exists (previously it was only created on the first recompute).
    self._tuple_cache = []

    # Set of external ops on which the graph has a control dependency
    self.control = object_identity.ObjectIdentitySet()

  def clear(self):
    """Removes all captures, both by-ref and by-val."""
    self._by_ref_internal.clear()
    self._by_ref_external.clear()
    self._by_ref_tracetype.clear()
    self._by_val_internal.clear()
    self._by_val_external.clear()
    # Previously left uncleared, which leaked stale entries into
    # `capture_types` after a clear().
    self._by_val_tracetype.clear()

  def capture_by_value(
      self,
      graph: Any,
      tensor: core.Tensor,
      name: Optional[str] = None
  ) -> core.Tensor:
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs. Multiple calls to this method with
    the same `tensor` argument will return the same placeholder. If `tensor` is
    from this graph, returns `tensor`.

    Args:
      graph: The FuncGraph that captures this tensor.
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      InaccessibleTensorError: if any tensors are accessed in a manner that
      bypasses the mechanisms required for the data dependencies to be
      correctly wired.
    """
    if isinstance(tensor, core.Value):
      if name is None:
        # A unique (within the program execution) integer.
        name = str(pywrap_tfe.TFE_Py_UID())

      # Small EagerTensors are captured with Const ops
      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
          functools.reduce(lambda a, b: a*b, tensor.shape, 1) <=
          _EAGER_CONST_THRESHOLD):
        # Access the private dict directly, consistent with the rest of this
        # class (the `by_val_internal` property returns the same object).
        graph_const = self._by_val_internal.get(id(tensor))
        if graph_const is None:
          graph_const = tensor._capture_as_const(name)  # pylint: disable=protected-access
          if graph_const is None:
            # Some eager tensors, e.g. parallel tensors, are not convertible
            # to a single constant. We'll use a placeholder for this case.
            graph_const = self._create_placeholder_helper(graph, tensor, name)
          self.add_or_replace(
              key=id(tensor),
              external=tensor,
              internal=graph_const,
              is_by_ref=False)
          graph.inputs.append(graph_const)
        graph_const._record_tape(tensor)  # pylint: disable=protected-access
        return graph_const

      # Large EagerTensors and resources are captured with Placeholder ops
      return self._create_placeholder_helper(graph, tensor, name)

    if tensor.graph is not graph:
      graph._validate_in_scope(tensor)  # pylint: disable=protected-access
      if name is None:
        assert tensor.op is not None, (
            tensor.__class__,
            dir(tensor),
            tensor.__class__.__name__,
        )
        name = tensor.op.name
      # cond/while graphs override _capture_helper() so cannot call
      # self.create_placeholder_helper() here directly.
      return graph._capture_helper(tensor, name)  # pylint: disable=protected-access
    return tensor

  def add_or_replace(
      self,
      key: Hashable,
      external: Any,
      internal: core.Tensor,
      tracetype: Any = None,
      is_by_ref: bool = False) -> None:
    """Replaces an already existing capture, otherwise adds it.

    Args:
      key: Hashable identifier of the capture.
      external: The captured external value (for by-ref, a callable).
      internal: The in-graph tensor standing in for the external value.
      tracetype: Optional trace type; for by-val captures it is derived from
        `external` when not supplied.
      is_by_ref: Whether this is a by-ref capture.
    """
    if is_by_ref:
      self._by_ref_external[key] = external
      self._by_ref_internal[key] = internal
      self._by_ref_tracetype[key] = tracetype
    else:
      self._by_val_internal[key] = internal
      self._by_val_external[key] = external
      if tracetype is not None:
        self._by_val_tracetype[key] = tracetype
      else:
        self._by_val_tracetype[key] = trace_type.from_value(external)

  def pop(self,
          key: Hashable,
          is_by_ref: bool = False) -> Any:
    """Removes a capture, returning its (external, internal, tracetype) triple.

    Missing entries yield None in the corresponding slot rather than raising.
    """
    if is_by_ref:
      return (self._by_ref_external.pop(key, None),
              self._by_ref_internal.pop(key, None),
              self._by_ref_tracetype.pop(key, None))
    else:
      return (self._by_val_external.pop(key, None),
              self._by_val_internal.pop(key, None),
              self._by_val_tracetype.pop(key, None))

  def reset_captures(self, tensors, placeholders):
    """Set the by-val captures with the provided tensors & placeholders."""
    self._by_val_external = MutationAwareDict()
    self._by_val_internal = MutationAwareDict()
    # Plain OrderedDict, consistent with __init__; mutation tracking is only
    # consulted on the external/internal dicts that feed the tuple cache.
    self._by_val_tracetype = py_collections.OrderedDict()
    for external, internal in zip(tensors, placeholders):
      key = id(external)
      self._by_val_external[key] = external
      self._by_val_internal[key] = internal
      self._by_val_tracetype[key] = trace_type.from_value(external)

  # TODO(panzf): make the method public after supporting lam() returns
  # non-tensor values. Currently, this method is only used by
  # FuncGraph._experimental_capture_side_input_by_ref(), which contains the
  # logics for converting non-tensor values to tensor.
  def _capture_by_ref(self,
                      graph: Any,
                      lam: Callable[[], Any],
                      key: Optional[Hashable] = None) -> Any:
    """Used during tracing process to create/retrieve by-ref captures.

    Args:
      graph: The FuncGraph that captures this tensor.
      lam: A callable that takes no arguments and returns tensor captures.
      key: A hashable identifier; auto-generated when None.

    Returns:
      Tensor from this FuncGraph.
    """
    # Check if the capture exists in self._by_ref
    if key is not None and key in self._by_ref_internal:
      return self._by_ref_internal[key]
    if key is None:
      # Generate an integer key not already in use.
      key = len(self._by_ref_internal)
      while key in self._by_ref_internal:
        key += 1

    value_nested = lam()
    capture_trace_type = trace_type.from_value(value_nested)
    ctx = trace_type.InternalPlaceholderContext(graph)
    internal = capture_trace_type.placeholder_value(ctx)

    def lam_fn():
      # pytype: disable=attribute-error
      value = lam()
      return capture_trace_type._to_tensors(value)  # pylint: disable=protected-access
      # pytype: enable=attribute-error

    self._by_ref_external[key] = lam_fn
    self._by_ref_internal[key] = internal
    self._by_ref_tracetype[key] = capture_trace_type
    return self._by_ref_internal[key]

  def merge_by_ref_with(self, other: "FunctionCaptures") -> None:
    """Add by-ref captures from `other` to `self` if not exist."""
    assert isinstance(other, FunctionCaptures)
    # NOTE(review): internal placeholders are not copied here — presumably
    # because they belong to `other`'s graph; confirm before changing.
    for key in other.by_ref_external:
      if key not in self._by_ref_external:
        self._by_ref_external[key] = other.by_ref_external[key]
        self._by_ref_tracetype[key] = other.by_ref_tracetype[key]

  def get_by_ref_snapshot(self) -> Mapping[Hashable, Any]:
    """Get a snapshot of current values of by-ref captures."""
    snapshot = {}
    for key in self._by_ref_external:
      func = self._by_ref_external[key]
      try:
        value = func()
      except (AttributeError, RuntimeError):
        # b/269680071 In case of by-ref captures are unavailable at dispatch
        # time, use the predefined trace_type instead.
        value = self._by_ref_tracetype[key]
      snapshot[key] = value
    return snapshot

  def _create_placeholder_helper(
      self,
      graph: Any,
      tensor: core.Tensor,
      name: str):
    """A helper function to create capture placeholder.

    Reuses the existing placeholder for `tensor` when one was already
    captured; otherwise builds one from the tensor's trace type and registers
    it as a by-val capture and a graph input.
    """
    placeholder = self._by_val_internal.get(id(tensor))
    if placeholder is None:
      tracing_ctx = trace_type.InternalTracingContext()
      spec = trace_type.from_value(tensor, tracing_ctx)
      spec._name = name  # pylint: disable=protected-access
      if isinstance(tensor, core.Value) and tensor.is_packed:
        composite_device_name = tensor.device
      else:
        composite_device_name = None
      placeholder_ctx = trace_type.InternalPlaceholderContext(
          graph,
          with_none_control_dependencies=True,
          composite_device_name=composite_device_name)
      placeholder = spec.placeholder_value(placeholder_ctx)
      self.add_or_replace(
          key=id(tensor),
          external=tensor,
          internal=placeholder,
          is_by_ref=False)
      graph.inputs.append(placeholder)
    placeholder._record_tape(tensor)  # pylint: disable=protected-access
    return placeholder

  def _recompute_tuple_cache(self):
    """Rebuilds `_tuple_cache` as ordered (external, internal) pairs."""
    assert len(self._by_val_internal) == len(self._by_val_external)
    self._tuple_cache = []
    for key in self._by_val_internal:
      assert key in self._by_val_external
      internal = self._by_val_internal[key]
      external = self._by_val_external[key]
      self._tuple_cache.append((external, internal))

  @property
  def capture_types(self):
    # By-val trace types first, then by-ref, preserving insertion order.
    return {**self._by_val_tracetype, **self._by_ref_tracetype}

  @property
  def by_val_capture_tuples(self):
    # Lazily rebuild the cache only when a by-val dict reports mutation.
    if self._by_val_internal.mutated or self._by_val_external.mutated:
      self._recompute_tuple_cache()
      self._by_val_internal.mutated = False
      self._by_val_external.mutated = False
    return self._tuple_cache

  @property
  def by_ref_internal(self):
    return self._by_ref_internal

  @property
  def by_ref_external(self):
    return self._by_ref_external

  @property
  def by_ref_tracetype(self):
    return self._by_ref_tracetype

  @property
  def by_val_internal(self):
    return self._by_val_internal

  @property
  def by_val_external(self):
    return self._by_val_external

  @property
  def by_val_tracetype(self):
    return self._by_val_tracetype