Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py: 31%

181 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implementation for AtomicFunction.""" 

16 

17import dataclasses 

18from typing import Any 

19 

20from tensorflow.core.framework import attr_value_pb2 

21from tensorflow.core.function import trace_type 

22from tensorflow.core.function.polymorphism import function_type as function_type_lib 

23from tensorflow.python.client import pywrap_tf_session 

24from tensorflow.python.eager import context 

25from tensorflow.python.eager import record 

26from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib 

27from tensorflow.python.framework import auto_control_deps_utils as acd 

28from tensorflow.python.framework import error_interpolation 

29from tensorflow.python.framework import errors 

30from tensorflow.python.framework import ops 

31from tensorflow.python.framework import tensor_spec 

32from tensorflow.python.ops import handle_data_util 

33from tensorflow.python.util import compat 

34from tensorflow.python.util import function_utils 

35 

36 

37class _InterpolateFunctionError(object): 

38 """Context Manager that interpolates the exception from 'top_level_func'.""" 

39 

40 __slots__ = ["_func"] 

41 

42 def __init__(self, top_level_func): 

43 self._func = top_level_func 

44 

45 def __enter__(self): 

46 pass 

47 

48 def __exit__(self, typ, exc, tb): 

49 if not exc or not isinstance(exc, errors.OpError): 

50 return False 

51 message = compat.as_text(exc.message) 

52 _, func_tags, _ = error_interpolation.parse_message(message) 

53 g = None 

54 for func_tag in func_tags: 

55 # TODO(mdan): Tests should cover this. 

56 if func_tag.name == compat.as_str(self._func.name): 

57 g = self._func.graph 

58 elif g: 

59 next_func = g._get_function(func_tag.name) # pylint: disable=protected-access 

60 if next_func is not None and isinstance(next_func, AtomicFunction): 

61 g = next_func.graph 

62 if g: 

63 exc._message = error_interpolation.interpolate(message, g) # pylint: disable=protected-access 

64 return False 

65 

66 

# TODO(b/232961485): Remove once the quarantined `add_function_callback` is
# removed.
function_callbacks = set()

69 

70 

# TODO(fmuham): Lower to FunctionRecord or remove otherwise.
@dataclasses.dataclass(frozen=True)
class GraphArtifacts:
  """Immutable bundle of graph-derived data carried by an AtomicFunction."""

  # Passed to `ops.control_dependencies` around every call of the function.
  control_captures: Any
  # Graph the function was built from; exposed as `AtomicFunction.graph`.
  graph: Any
  # Stateful operations of the function; a non-empty value makes graph calls
  # use the stateful call op.
  stateful_ops: Any

77 

# Maps (function_scope_id, name) of a function in the runtime to the number of
# live AtomicFunctions that refer to it.
RUNTIME_FUNCTION_REFS = {}

80 

81 

class AtomicFunction:
  """A Python callable for functions in the TF Runtime.

  Supports tf.function features such as structured value inputs and outputs,
  captures and control dependencies.

  Lowest level abstraction in the Python tf.function implementation.

  Instances are reference counted per (function_scope_id, name) in
  `RUNTIME_FUNCTION_REFS`; the runtime function is removed from the bound
  context when the last instance referring to it is finalized.
  """
  __slots__ = [
      "_name",
      "_bound_context",
      "_function_type",
      "_graph_artifacts",
      "_cached_definition",
  ]

  def __init__(self, name, bound_context, function_type, graph_artifacts):
    """Binds a runtime function name to its context and counts the reference.

    Args:
      name: Name of the function in the runtime; stored as bytes.
      bound_context: Context the function is registered in. It is used to
        look up, call and eventually remove the function.
      function_type: Type of the function; its `flat_outputs` describe the
        call results.
      graph_artifacts: `GraphArtifacts` providing the graph, stateful ops and
        control captures.
    """
    self._name = compat.as_bytes(name)
    self._bound_context = bound_context
    self._function_type = function_type
    self._graph_artifacts = graph_artifacts
    self._cached_definition = None

    ref_key = (self._bound_context.function_scope_id, self.name)
    RUNTIME_FUNCTION_REFS[ref_key] = RUNTIME_FUNCTION_REFS.get(ref_key, 0) + 1

  @property
  def _c_func(self):
    """C function handle looked up by name via the `context` module."""
    return context.get_c_function(self.name)

  @property
  def function_type(self):
    """Function type supplied at construction."""
    return self._function_type

  # TODO(fmuham): Remove this property.
  @property
  def graph(self):
    """Graph stored in the graph artifacts."""
    return self._graph_artifacts.graph

  # TODO(fmuham): Remove this property.
  @property
  def stateful_ops(self):
    """Stateful ops stored in the graph artifacts."""
    return self._graph_artifacts.stateful_ops

  @property
  def definition(self):
    """Current FunctionDef in the Runtime."""
    return self._bound_context.get_function_def(self.name)

  # TODO(fmuham): Move caching to dependent code and remove method.
  @property
  def cached_definition(self):
    """Cached FunctionDef (not guaranteed to be fresh)."""
    if self._cached_definition is None:
      self._cached_definition = self.definition

    return self._cached_definition

  @property
  def name(self):
    """Name represented in UTF-8 encoded bytes."""
    return self._name

  @property
  def graph_call_attrs(self):
    """Returns a dictionary of attributes needed to add a call in graph."""
    attrs = {
        "is_stateful": len(self.stateful_ops) > 0,  # pylint: disable=g-explicit-length-test
        "tout": [
            o.dtype.as_datatype_enum for o in self.function_type.flat_outputs
        ],
        "xla_compile_attr": self.cached_definition.attr.get(
            attributes_lib.XLA_COMPILE, None
        ),
    }
    # Call options of the bound context are merged in last, so they win on
    # any key collision.
    attrs.update(self._bound_context.function_call_options.as_attrs())
    return attrs

  def __call__(self, *args):
    """Calls this function with `args` as inputs.

    `ConcreteFunction` execution respects device annotations only if the
    function won't be compiled with xla.

    Args:
      *args: arguments to call this function with.

    Returns:
      The outputs of the function call.

    Raises:
      ValueError: if the number of arguments is incorrect.
      FunctionAlreadyGarbageCollectedError: if the function is no longer
        available to be called because it has been garbage collected.
    """
    expected_arg_count = len(self.cached_definition.signature.input_arg)
    if len(args) != expected_arg_count:
      raise ValueError(
          f"Signature specifies {expected_arg_count} arguments,"
          f" got: {len(args)}."
      )

    with _InterpolateFunctionError(self):
      with ops.control_dependencies(self._graph_artifacts.control_captures):
        # The caller must use record_operation to record this operation in the
        # eager case, so we enforce the same requirement for the non-eager
        # case by explicitly pausing recording. We don't have a gradient
        # registered for PartitionedCall, so recording this operation confuses
        # forwardprop code (GradientTape manages to ignore it).
        with record.stop_recording():
          if self._bound_context.executing_eagerly():
            outputs = self._bound_context.call_function(
                self.name,
                list(args),
                len(self.function_type.flat_outputs),
            )
          else:
            outputs = make_call_op_in_graph(self, list(args))

    # Index into `outputs` rather than zipping: in graph mode `outputs` is the
    # bare op when the call produced no tensors, and these loops must then
    # simply not run.
    for i, output_type in enumerate(self.function_type.flat_outputs):
      handle_data = output_type.dtype._handle_data  # pylint: disable=protected-access
      if handle_data:
        handle_data_util.set_handle_data(outputs[i], handle_data)

    # TODO(fmuham): Use FunctionType cast here for all cases.
    if not self._bound_context.executing_eagerly():
      for i, output_type in enumerate(self.function_type.flat_outputs):
        outputs[i].set_shape(output_type.shape)

    return outputs

  def __del__(self):
    """Drops this instance's reference and removes the function at zero."""
    # Module globals may already have been cleared while the interpreter is
    # shutting down; there is nothing left to account for in that case.
    if RUNTIME_FUNCTION_REFS is None:
      return

    key = (self._bound_context.function_scope_id, self.name)
    RUNTIME_FUNCTION_REFS[key] -= 1
    if RUNTIME_FUNCTION_REFS[key] < 0:
      raise RuntimeError(
          f"AtomicFunction Refcounting for {self.name} is invalid."
      )

    if RUNTIME_FUNCTION_REFS[key] == 0:
      try:
        self._bound_context.remove_function(self.name)
        RUNTIME_FUNCTION_REFS.pop(key)
      except TypeError:
        # Suppress some exceptions, mainly for the case when we're running on
        # module deletion. Things that can go wrong include the context module
        # already being unloaded, self._handle._handle_data no longer being
        # valid, and so on. Printing warnings in these cases is silly
        # (exceptions raised from __del__ are printed as warnings to stderr).
        pass  # 'NoneType' object is not callable when the handle has been
        # partially unloaded.
      except AttributeError:
        pass  # 'NoneType' object has no attribute 'eager_mode' when context has
        # been unloaded. Will catch other module unloads as well.

239 

240 

def _set_read_only_resource_inputs_attr(op, func_graph):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: PartitionedCall Operation.
    func_graph: FuncGraph.
  """
  # The indices are computed from the function graph and stored on the call
  # op as an int-list attribute.
  read_only_indices = acd.get_read_only_resource_input_indices_graph(func_graph)
  ops.set_int_list_attr(
      op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR, read_only_indices
  )

253 

254 

def partitioned_call_op(
    name,
    args,
    is_stateful,
    tout,
    config=None,
    executor_type=None,
    xla_compile_attr=None,
):
  """Generates a function call op respecting device annotations.

  Args:
    name: Name of the function to call.
    args: The arguments of the function, including captured inputs.
    is_stateful: If the function is stateful.
    tout: a list containing the output dtypes enums
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If `None`,
      all optimizations are disabled. Currently only handled for eager defined
      functions.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.
    xla_compile_attr: (Optional) value of the XLA compilation attribute.

  Returns:
    Returns the operation.
  """
  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  if executor_type is None:
    executor_type = ""

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.convert_to_tensor(x) for x in args]
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]
      )
  )
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout)
  )
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=name)
  )
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type)
  )

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(s=config)

  op_name = "StatefulPartitionedCall" if is_stateful else "PartitionedCall"

  op_attrs = {
      "Tin": tin_attr,
      "Tout": tout_attr,
      "f": func_attr,
      "config_proto": config_proto,
      "executor_type": executor_type_attr,
  }
  # Propagate the attribute indicating the need to compile from function to the
  # call itself.
  if xla_compile_attr is not None:
    op_attrs[attributes_lib.XLA_COMPILE] = xla_compile_attr

  return ops.get_default_graph().create_op(
      op_name, args, tout, name=op_name, attrs=op_attrs
  )

324 

325 

def make_call_op_in_graph(atomic, tensor_inputs):
  """Adds an AtomicFunction to graph.

  The function is registered on the default graph first, then a call op for it
  is created there and annotated with its read-only resource inputs and, when
  the function graph has them, the collective manager ids it used.

  Args:
    atomic: The AtomicFunction to call.
    tensor_inputs: Inputs to pass to the call op.

  Returns:
    The outputs of the call op, or the op itself when it has no outputs.
  """
  graph = ops.get_default_graph()
  graph._add_function_recursive(atomic)  # pylint: disable=protected-access

  # Read the property once; building it looks up the function definition and
  # the context's call options.
  function_call_attrs = atomic.graph_call_attrs
  op = partitioned_call_op(
      name=atomic.name,
      args=tensor_inputs,
      is_stateful=function_call_attrs["is_stateful"],
      tout=function_call_attrs["tout"],
      config=function_call_attrs["config_proto"],
      executor_type=function_call_attrs["executor_type"],
      xla_compile_attr=function_call_attrs["xla_compile_attr"],
  )
  _set_read_only_resource_inputs_attr(op, atomic.graph)
  if hasattr(atomic.graph, "collective_manager_ids_used"):
    ops.set_int_list_attr(
        op,
        acd.COLLECTIVE_MANAGER_IDS,
        atomic.graph.collective_manager_ids_used,
    )
  return op.outputs if op.outputs else op

349 

# List of AtomicFunction -> AtomicFunction transformation functions, applied
# in order by `from_func_graph`. Each one must return an AtomicFunction.
FUNCTION_TRANSFORMS = []

352 

353 

def from_func_graph(name, graph, inputs, outputs, attrs):
  """Initializes an AtomicFunction from FuncGraph with transforms.

  Builds the function with `from_func_graph_no_transforms` and then passes it
  through every entry of `FUNCTION_TRANSFORMS`, in list order.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs from the function
    attrs: dict mapping names of attributes to their AttrValue values

  Returns:
    The AtomicFunction produced by the last transform, or the untransformed
    one when no transforms are registered.

  Raises:
    TypeError: if a transform returns something other than an AtomicFunction.
  """
  atomic = from_func_graph_no_transforms(name, graph, inputs, outputs, attrs)
  for transform in FUNCTION_TRANSFORMS:
    atomic = transform(atomic)
    # Check after every step so the error names the offending transform.
    if not isinstance(atomic, AtomicFunction):
      raise TypeError(
          f"Transformation {transform} did not return an AtomicFunction."
      )

  return atomic

366 

367 

def from_func_graph_no_transforms(
    name, graph, inputs, outputs, attrs, overwrite=False
):
  """Initializes an AtomicFunction from FuncGraph.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs from the function
    attrs: dict mapping names of attributes to their AttrValue values
    overwrite: overwrites function definition in the current context if needed

  Returns:
    An AtomicFunction instance.
  """
  # Ops that produce the inputs are not part of the function body.
  input_ops = {arg.op for arg in inputs}
  operations = [op for op in graph.get_operations() if op not in input_ops]

  # Use the graph's explicit output names only when every output has one and
  # the names are unique; an empty list means auto-naming.
  output_names = []
  graph_output_names = graph._output_names  # pylint: disable=protected-access
  if graph_output_names is not None and all(
      ops.tensor_id(t) in graph_output_names for t in outputs
  ):
    candidate_names = [
        compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
    ]
    # Duplicate names are probably an invalid signature; in that case keep
    # the auto-naming default.
    if len(set(candidate_names)) == len(candidate_names):
      output_names = candidate_names

  with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
    fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
        c_graph,
        compat.as_str(name),
        False,  # presumably append_hash_to_fn_name -- TODO confirm
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        output_names,
        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
        [],  # control_output_names
        None,  # presumably opts -- TODO confirm
        compat.as_str(""),  # presumably description -- TODO confirm
    )

  # `fn` is deleted here once it has been handed to the context. The `finally`
  # also releases it when setting an attribute or registering the function
  # raises, so the handle is not leaked on the error path.
  try:
    for attr_name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      pywrap_tf_session.TF_FunctionSetAttrValueProto(
          fn, compat.as_str(attr_name), serialized
      )

    name = compat.as_bytes(name)
    bound_context = context.context()

    if overwrite and bound_context.has_function(name):
      bound_context.remove_function(name)

    bound_context.add_c_function(fn)
  finally:
    pywrap_tf_session.TF_DeleteFunction(fn)

  graph_artifacts = GraphArtifacts(
      control_captures=graph.function_captures.control,
      graph=graph,
      stateful_ops=tuple(op for op in operations if op._is_stateful),  # pylint: disable=protected-access
  )

  if graph.structured_input_signature is not None:
    input_signature = graph.structured_input_signature
  else:
    # No structured signature: positional TensorSpecs derived from the input
    # tensors and no keyword arguments.
    input_signature = (
        tuple(tensor_spec.TensorSpec.from_tensor(i) for i in inputs),
        {},
    )

  # TODO(fmuham): Include output structure info from structured_outputs
  output_signature = tuple(trace_type.from_value(o) for o in outputs)

  function_type = function_type_lib.from_structured_signature(
      input_signature,
      output_signature,
      graph.function_captures.capture_types,
  )

  return AtomicFunction(name, bound_context, function_type, graph_artifacts)