# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation for AtomicFunction."""

import dataclasses
from typing import Any

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.function import trace_type
from tensorflow.core.function.polymorphism import function_type as function_type_lib
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.eager import record
from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import handle_data_util
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils


class _InterpolateFunctionError(object):
  """Context Manager that interpolates the exception from 'top_level_func'."""

  __slots__ = ["_func"]

  def __init__(self, top_level_func):
    self._func = top_level_func

  def __enter__(self):
    pass

  def __exit__(self, typ, exc, tb):
    if not exc or not isinstance(exc, errors.OpError):
      return False
    message = compat.as_text(exc.message)
    _, func_tags, _ = error_interpolation.parse_message(message)
    g = None
    for func_tag in func_tags:
      # TODO(mdan): Tests should cover this.
      if func_tag.name == compat.as_str(self._func.name):
        g = self._func.graph
      elif g:
        next_func = g._get_function(func_tag.name)  # pylint: disable=protected-access
        if next_func is not None and isinstance(next_func, AtomicFunction):
          g = next_func.graph
    if g:
      exc._message = error_interpolation.interpolate(message, g)  # pylint: disable=protected-access
    return False


# TODO(b/232961485): Remove after quarantined `add_function_callback` removed.
function_callbacks = set()


# TODO(fmuham): Lower to FunctionRecord or remove otherwise.
@dataclasses.dataclass(frozen=True)
class GraphArtifacts:
  control_captures: Any
  graph: Any
  stateful_ops: Any


# Maps the scope_id and name in runtime to the number of AtomicFunctions.
RUNTIME_FUNCTION_REFS = {}
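# For illustration only (the key layout follows the bookkeeping in
# AtomicFunction.__init__/__del__ below): entries are keyed by
# (function_scope_id, function_name) and count live Python AtomicFunction
# objects. A value of 2 for a hypothetical key (0, b"__inference_fn_123")
# means two AtomicFunction instances currently reference that runtime
# function; the runtime definition is only removed once the count reaches
# zero.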


class AtomicFunction:
  """A Python callable for functions in the TF Runtime.

  Supports tf.function features such as structured value inputs and outputs,
  captures and control dependencies.

  Lowest level abstraction in the Python tf.function implementation.
  """

  __slots__ = [
      "_name",
      "_bound_context",
      "_function_type",
      "_graph_artifacts",
      "_cached_definition",
  ]

  def __init__(self, name, bound_context, function_type, graph_artifacts):
    self._name = compat.as_bytes(name)
    self._bound_context = bound_context
    self._function_type = function_type
    self._graph_artifacts = graph_artifacts
    self._cached_definition = None

    ref_key = (self._bound_context.function_scope_id, self.name)
    if ref_key not in RUNTIME_FUNCTION_REFS:
      RUNTIME_FUNCTION_REFS[ref_key] = 1
    else:
      RUNTIME_FUNCTION_REFS[ref_key] += 1

  @property
  def _c_func(self):
    return context.get_c_function(self.name)

  @property
  def function_type(self):
    return self._function_type

  # TODO(fmuham): Remove this property.
  @property
  def graph(self):
    return self._graph_artifacts.graph

  # TODO(fmuham): Remove this property.
  @property
  def stateful_ops(self):
    return self._graph_artifacts.stateful_ops

  @property
  def definition(self):
    """Current FunctionDef in the Runtime."""
    return self._bound_context.get_function_def(self.name)

  # TODO(fmuham): Move caching to dependent code and remove method.
  @property
  def cached_definition(self):
    """Cached FunctionDef (not guaranteed to be fresh)."""
    if self._cached_definition is None:
      self._cached_definition = self.definition

    return self._cached_definition

  @property
  def name(self):
    """Name represented in UTF-8 encoded bytes."""
    return self._name

  @property
  def graph_call_attrs(self):
    """Returns a dictionary of attributes needed to add a call in graph."""
    attrs = {
        "is_stateful": len(self.stateful_ops) > 0,  # pylint: disable=g-explicit-length-test
        "tout": [
            o.dtype.as_datatype_enum for o in self.function_type.flat_outputs
        ],
        "xla_compile_attr": self.cached_definition.attr.get(
            attributes_lib.XLA_COMPILE, None
        ),
    }
    attrs.update(self._bound_context.function_call_options.as_attrs())
    return attrs
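  # For illustration only: with a single stateful float32 output, the
  # dictionary built above would look roughly like
  #   {"is_stateful": True,
  #    "tout": [<DT_FLOAT enum value>],
  #    "xla_compile_attr": None,
  #    ...}
  # plus whatever the bound context's function call options contribute
  # (`make_call_op_in_graph` below reads "config_proto" and "executor_type"
  # from them).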

  def __call__(self, *args):
    """Calls this function with `args` as inputs.

    `ConcreteFunction` execution respects device annotations only if the
    function won't be compiled with XLA.

    Args:
      *args: arguments to call this function with.

    Returns:
      The outputs of the function call.

    Raises:
      ValueError: if the number of arguments is incorrect.
      FunctionAlreadyGarbageCollectedError: if the function is no longer
        available to be called because it has been garbage collected.
    """
    if len(args) != len(self.cached_definition.signature.input_arg):
      raise ValueError(
          "Signature specifies"
          f" {len(list(self.cached_definition.signature.input_arg))} arguments,"
          f" got: {len(args)}."
      )

    with _InterpolateFunctionError(self):
      with ops.control_dependencies(self._graph_artifacts.control_captures):
        # The caller must use record_operation to record this operation in the
        # eager case, so we enforce the same requirement for the non-eager
        # case by explicitly pausing recording. We don't have a gradient
        # registered for PartitionedCall, so recording this operation confuses
        # forwardprop code (GradientTape manages to ignore it).
        with record.stop_recording():
          if self._bound_context.executing_eagerly():
            outputs = self._bound_context.call_function(
                self.name,
                list(args),
                len(self.function_type.flat_outputs),
            )
          else:
            outputs = make_call_op_in_graph(self, list(args))

    for i, output_type in enumerate(self.function_type.flat_outputs):
      handle_data = output_type.dtype._handle_data
      if handle_data:
        handle_data_util.set_handle_data(outputs[i], handle_data)

    # TODO(fmuham): Use FunctionType cast here for all cases.
    if not self._bound_context.executing_eagerly():
      for i, output_type in enumerate(self.function_type.flat_outputs):
        outputs[i].set_shape(output_type.shape)

    return outputs

  def __del__(self):
    key = (self._bound_context.function_scope_id, self.name)
    RUNTIME_FUNCTION_REFS[key] -= 1
    if RUNTIME_FUNCTION_REFS[key] < 0:
      raise RuntimeError(
          f"AtomicFunction Refcounting for {self.name} is invalid."
      )

    if RUNTIME_FUNCTION_REFS[key] == 0:
      try:
        self._bound_context.remove_function(self.name)
        RUNTIME_FUNCTION_REFS.pop(key)
      except TypeError:
        # Suppress some exceptions, mainly for the case when we're running on
        # module deletion. Things that can go wrong include the context module
        # already being unloaded, self._handle._handle_data no longer being
        # valid, and so on. Printing warnings in these cases is silly
        # (exceptions raised from __del__ are printed as warnings to stderr).
        pass  # 'NoneType' object is not callable when the handle has been
              # partially unloaded.
      except AttributeError:
        pass  # 'NoneType' object has no attribute 'eager_mode' when context
              # has been unloaded. Will catch other module unloads as well.
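# A minimal usage sketch (names like `func_graph` and `flat_tensor_args` are
# hypothetical): AtomicFunction instances are normally produced by
# `from_func_graph` below and called with a flat list of tensors matching the
# runtime signature (including captured inputs).
#
#   atomic = from_func_graph(
#       "my_fn", func_graph, func_graph.inputs, func_graph.outputs, attrs={})
#   flat_outputs = atomic(*flat_tensor_args)
#
# Higher layers (ConcreteFunction / tf.function) pack and unpack structured
# values around this flat calling convention.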


def _set_read_only_resource_inputs_attr(op, func_graph):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: PartitionedCall Operation.
    func_graph: FuncGraph.
  """
  read_only_indices = acd.get_read_only_resource_input_indices_graph(
      func_graph)
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        read_only_indices)


def partitioned_call_op(
    name,
    args,
    is_stateful,
    tout,
    config=None,
    executor_type=None,
    xla_compile_attr=None,
):
  """Generates a function call op respecting device annotations.

  Args:
    name: Name of the function to call.
    args: The arguments of the function, including captured inputs.
    is_stateful: If the function is stateful.
    tout: a list containing the output dtype enums.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If
      `None`, all optimizations are disabled. Currently only handled for eager
      defined functions.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      TensorFlow executor will be used.
    xla_compile_attr: (Optional) value of the XLA compilation attribute.

  Returns:
    The generated function call operation.
  """
  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  if executor_type is None:
    executor_type = ""

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.convert_to_tensor(x) for x in args]
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=name))
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type))

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(s=config)

  op_name = "StatefulPartitionedCall" if is_stateful else "PartitionedCall"

  # Propagate the attribute indicating the need to compile from the function to
  # the call itself.
  op_attrs = {
      "Tin": tin_attr,
      "Tout": tout_attr,
      "f": func_attr,
      "config_proto": config_proto,
      "executor_type": executor_type_attr,
  }
  if xla_compile_attr is not None:
    op_attrs[attributes_lib.XLA_COMPILE] = xla_compile_attr

  op = ops.get_default_graph().create_op(
      op_name, args, tout, name=op_name, attrs=op_attrs
  )
  return op
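# For illustration only (the function name is hypothetical): the resulting
# NodeDef is a (Stateful)PartitionedCall op whose `f` attr names the runtime
# function, roughly
#
#   op: "StatefulPartitionedCall"
#   attr { key: "f"    value { func { name: "__inference_fn_123" } } }
#   attr { key: "Tin"  value { list { type: DT_FLOAT } } }
#   attr { key: "Tout" value { list { type: DT_FLOAT } } }
#
# plus `config_proto` and `executor_type` attrs. `make_call_op_in_graph` below
# is the usual entry point that fills these in from an AtomicFunction.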


def make_call_op_in_graph(atomic, tensor_inputs):
  """Adds an AtomicFunction to graph."""
  graph = ops.get_default_graph()
  graph._add_function_recursive(atomic)  # pylint: disable=protected-access

  function_call_attrs = atomic.graph_call_attrs
  op = partitioned_call_op(
      name=atomic.name,
      args=tensor_inputs,
      is_stateful=function_call_attrs["is_stateful"],
      tout=function_call_attrs["tout"],
      config=function_call_attrs["config_proto"],
      executor_type=function_call_attrs["executor_type"],
      xla_compile_attr=function_call_attrs["xla_compile_attr"],
  )
  _set_read_only_resource_inputs_attr(op, atomic.graph)
  if hasattr(atomic.graph, "collective_manager_ids_used"):
    ops.set_int_list_attr(
        op,
        acd.COLLECTIVE_MANAGER_IDS,
        atomic.graph.collective_manager_ids_used,
    )
  return op.outputs if op.outputs else op


# List of AtomicFunction -> AtomicFunction transformation functions.
FUNCTION_TRANSFORMS = []
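# Sketch of a transform (hypothetical; nothing is registered here): each entry
# must accept an AtomicFunction and return an AtomicFunction, e.g.
#
#   def _noop_transform(atomic):
#     return atomic
#
#   FUNCTION_TRANSFORMS.append(_noop_transform)
#
# `from_func_graph` below applies registered transforms in order and raises
# TypeError if a transform returns anything else.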


def from_func_graph(name, graph, inputs, outputs, attrs):
  """Initializes an AtomicFunction from FuncGraph with transforms."""

  atomic = from_func_graph_no_transforms(name, graph, inputs, outputs, attrs)
  for transform in FUNCTION_TRANSFORMS:
    atomic = transform(atomic)
    if not isinstance(atomic, AtomicFunction):
      raise TypeError(
          f"Transformation {transform} did not return an AtomicFunction."
      )

  return atomic


def from_func_graph_no_transforms(
    name, graph, inputs, outputs, attrs, overwrite=False
):
  """Initializes an AtomicFunction from FuncGraph.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function.
    inputs: the tensors in the graph to be used as inputs to the function.
    outputs: the tensors in the graph which will be outputs from the function.
    attrs: dict mapping names of attributes to their AttrValue values.
    overwrite: overwrites function definition in the current context if needed.

  Returns:
    An AtomicFunction instance.
  """
  input_ops = set(arg.op for arg in inputs)
  operations = [op for op in graph.get_operations() if op not in input_ops]

  graph_output_names = graph._output_names  # pylint: disable=protected-access
  if graph_output_names is not None and all(
      ops.tensor_id(t) in graph_output_names for t in outputs
  ):
    output_names = [
        compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
    ]
    if len(set(output_names)) != len(output_names):
      # There are duplicate names for some reason, probably an invalid
      # signature. Revert to auto-naming.
      output_names = []
  else:
    output_names = []

  with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
    fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
        c_graph,
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        output_names,
        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
        [],  # control_output_names
        None,
        compat.as_str(""),
    )

  for attr_name, attr_value in attrs.items():
    serialized = attr_value.SerializeToString()
    pywrap_tf_session.TF_FunctionSetAttrValueProto(
        fn, compat.as_str(attr_name), serialized
    )

  name = compat.as_bytes(name)
  bound_context = context.context()

  if overwrite and bound_context.has_function(name):
    bound_context.remove_function(name)

  bound_context.add_c_function(fn)
  pywrap_tf_session.TF_DeleteFunction(fn)

  graph_artifacts = GraphArtifacts(
      control_captures=graph.function_captures.control,
      graph=graph,
      stateful_ops=tuple(op for op in operations if op._is_stateful),  # pylint: disable=protected-access
  )

  if graph.structured_input_signature is not None:
    input_signature = graph.structured_input_signature
  else:
    input_signature = (
        tuple(tensor_spec.TensorSpec.from_tensor(i) for i in inputs),
        {},
    )

  # TODO(fmuham): Include output structure info from structured_outputs
  output_signature = tuple(trace_type.from_value(o) for o in outputs)

  function_type = function_type_lib.from_structured_signature(
      input_signature,
      output_signature,
      graph.function_captures.capture_types,
  )

  return AtomicFunction(name, bound_context, function_type, graph_artifacts)
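# A minimal sketch (names are hypothetical) of refreshing a runtime definition
# in place: passing `overwrite=True` removes any existing function with the
# same name from the bound context before the new C function is registered.
#
#   atomic = from_func_graph_no_transforms(
#       "my_fn", func_graph, func_graph.inputs, func_graph.outputs,
#       attrs={}, overwrite=True)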