# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Python front-end supports for functions.
17NOTE: At this time, functions are experimental and subject to change!. Proceed
18with caution.
19"""
import collections
import hashlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect


# TODO(b/136040013): Drop support for Defun.
class Defun(object):
  """Obsolete. Slated for deletion. Please use tf.function instead.

  Known feature gaps while migrating to tf.function (could be outdated):
  - tf.function doesn’t support Send/Recv capability since it doesn’t share
    rendezvous with the main graph but always creates a new one.
  - tf.function doesn’t support custom gradient function directly, instead you
    need to define the function inside a tf.custom_gradient wrapper together
    with the gradient function.
  - Unlike Defun, Keras layers used inside a tf.function need to be created
    only once to avoid variable recreation.
  - Defun respects the device assignments and applies them to the function
    body, but tf.function needs it to be done manually.
  - Defun might prune out unused ops automatically, but tf.function doesn't.

  Limitations of Defun:
  - Original source locations are not preserved, so errors do not include
    full/valid stack traces.
  - Only supports a linear sequence of arguments and return values, putting
    the burden on the caller to pack/unpack everything across a Defun boundary
    into tuples (as opposed to passing list- and dict-like structures
    directly).
  - Does not support overloading or late-bound specializations.
  - Has its own way of defining gradient overrides which does not follow
    current conventions.
  - Cannot support imperative control flow or automatic control dependencies.
  - Does not reflect statefulness in the graph and has a calling convention
    that differs from how more modern tools interact.
  - Is only compatible with graph building mode.

  Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects. Call the decorator with named arguments, one for each
  argument of the function to decorate, with the expected type of the argument
  as value.

  For example, if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

  @Defun(tf.float32, tf.float32)
  def foo(x, y):
    ...

  When you call the decorated function, it adds the `call` ops to the
  default graph. In addition, it adds the definition of the function into the
  default graph. Because the addition of the function into the graph
  is deferred, the decorator can be used anywhere in the program.

  Any variables created inside of the function are hoisted into the outer
  graph. Note that the variables are created in the variable scope that was
  active during the first call to the function. Subsequent function calls will
  refer to the same set of variables.

  Definitions of functions in a graph are frozen as soon as the graph is used
  to create a session. However, new functions and new calls to existing
  functions may be added to the graph, with the new functions themselves
  becoming immediately frozen.

  Example, but also see the [How To on functions](link_needed).

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```
  """

  def __init__(self, *input_types, **kwargs):
    """Create a `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`
      **kwargs: Optional keyword arguments, including
        func_name - (optional). A python string, the name to use to
          declare this `Function` in the graph.

        grad_func - (optional). A function implementing the gradient
          of the function-to-register. This must be a
          `_DefinedFunction` object. The gradient
          function must satisfy the criterion defined in
          function.proto:GradientDef.

        python_grad_func - (optional). A function implementing the
          gradient of the function python-side. This function must
          take the current op and the gradients w.r.t. its outputs,
          and return the gradients w.r.t. the inputs. That is, it must
          implement the interface expected by `tf.RegisterGradient`.
          This will be called by tf.gradients to add the gradient ops
          to the graph. At most one of grad_func and python_grad_func
          can be specified.

        out_names - (optional). A list of strings, one per output
          tensor.

        shape_func - (optional). A function taking the op and returning a
          list of static shapes to set for the function's outputs.
    """
    self._input_types = input_types
    self._func_name = kwargs.pop("func_name", None)
    self._grad_func = kwargs.pop("grad_func", None)
    self._python_grad_func = kwargs.pop("python_grad_func", None)
    self._out_names = kwargs.pop("out_names", None)
    self._extra_kwargs = kwargs

  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError(f"Function {func} must be a callable.")

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "Functions with argument defaults or keyword arguments are not "
          f"supported. {func} has defaults {argspec.defaults} and keywords "
          f"{argspec.keywords}.")

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of tf.function input types is not compatible with "
            f"the allowed arguments of {func}. The tf.function has {num} "
            f"input types, while the Python function allows a minimum of "
            f"{min_args} and a maximum of {max_args} arguments.")
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)
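
# A minimal usage sketch (illustrative, not part of this module): assumes
# TF1-style graph mode via `tensorflow.compat.v1`. `Defun` is obsolete and
# internal, so prefer `tf.function` in new code.
#
#   import tensorflow.compat.v1 as tf
#   from tensorflow.python.framework import function
#
#   @function.Defun(tf.float32, tf.float32)
#   def MyAdd(x, y):
#     return x + y
#
#   with tf.Graph().as_default():
#     a = tf.constant(1.0)
#     b = tf.constant(2.0)
#     c = MyAdd(a, b)  # Adds the function def plus a call op to the graph.
#     with tf.Session() as sess:
#       print(sess.run(c))  # 3.0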


class _DefinedFunctionDeleter(object):
  """Unregister function from eager context."""

  __slots__ = ["name"]

  def __init__(self, name):
    self.name = name

  def __del__(self):
    try:
      context.remove_function(self.name)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


class _DefinedFunction(object):
  """_DefinedFunction encapsulates a function definition and its properties.

  Attributes:
    name: The function name.
    definition: The definition of this function. A FunctionDef proto.
    cached_definition: Same as definition. Needed to match AtomicFunction API.
    grad_func_name: If not None, the name of this function's gradient function.
    python_grad_func: A python callable implementing the gradient of
      the function python-side.
  """

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               allowlisted_stateful_ops=None,
               capture_resource_var_by_value=True,
               **kwargs):
    """Creates _DefinedFunction.

    Args:
      func: A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case it is
        derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      allowlisted_stateful_ops: A set of ops that, if stateful, we ignore and
        copy into the function body when `capture_by_value` is True.
      capture_resource_var_by_value: Boolean (defaults to True). If False,
        captured resource variables return the handle instead of the value.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.
    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    if self._allowlisted_stateful_ops is None:
      self._allowlisted_stateful_ops = set()
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily
    self._definition = None
    # Constructed only when C API is enabled, lazily
    self._c_func = None
    self._function_deleter = None
    self._sub_functions = {}  # Constructed with _definition or _c_func
    # pylint: disable=protected-access
    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access

    # Get the innermost device if possible.
    self._caller_device = device_funcs[-1] if device_funcs else None

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled the whole _definition is available and this is simply
    # another reference to _definition.signature
    self._op_def = None

    assert isinstance(input_types, (list, tuple))
    self._arg_types = input_types
    self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
                       for i in range(len(input_types))]

  @property
  def name(self):
    """Function name."""
    self._create_definition_if_needed()
    return self._func_name

  @property
  def cached_definition(self):
    return self.definition

  @property
  def definition(self):
    """Function definition proto."""
    self._create_definition_if_needed()
    if self._c_func:
      with c_api_util.tf_buffer() as buf:
        with self._c_func.get() as func:
          c_api.TF_FunctionToFunctionDef(func, buf)
          fdef = function_pb2.FunctionDef()
          proto_data = c_api.TF_GetBuffer(buf)
          fdef.ParseFromString(compat.as_bytes(proto_data))
          with ops.init_scope():
            if context.executing_eagerly():
              context.add_c_function(func)
              self._function_deleter = _DefinedFunctionDeleter(
                  fdef.signature.name)
      return fdef
    return self._definition

  @property
  def _signature(self):
    self._create_definition_if_needed()
    return self._op_def

  def set_grad_func(self, grad_func):
    """Specifies the gradient function of this function."""
    assert not self._grad_func
    assert isinstance(grad_func, _DefinedFunction)
    self._grad_func = grad_func

  @property
  def grad_func_name(self):
    """Returns the name of the gradient function."""
    return self._grad_func.name if self._grad_func else None

  @property
  def python_grad_func(self):
    """Python gradient function callable."""
    return self._python_grad_func

  @property
  def declared_input_types(self):
    """Returns the list of data types of explicitly declared inputs."""
    return self._input_types

  @property
  def captured_inputs(self):
    """Returns the list of implicitly captured inputs."""
    self._create_definition_if_needed()
    return self._extra_inputs

  @property
  def stateful_ops(self):
    """Returns the list of stateful ops in function definition.

    Returns:
      A list of (op.name, op.type) pairs.
    """
    self._create_definition_if_needed()
    return self._stateful_ops

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet."""
    with context.graph_mode():
      self._create_definition_if_needed_impl()

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed."""
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between
    # the func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    parent_graph = ops.get_default_graph()
    collections_ref = {
        key: parent_graph.get_collection_ref(key) for key in variable_keys}

    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    # FIXME(feyu): C API is always enabled now. The if-true branch never runs.
    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use. If not specified,
      # make up something which is almost certainly unique (but
      # deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      with temp_graph._c_graph.get() as c_graph:
        c_func = c_api.TF_GraphToFunction_wrapper(
            c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in temp_graph.inputs],
            [t._as_tf_output() for t in temp_graph.outputs],
            output_names,
            [],  # control_outputs
            [],  # control_output_names
            None,  # opts
            description)
      self._c_func = c_api_util.ScopedTFFunction(c_func, base_func_name)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      with self._c_func.get() as func:
        c_api.TF_FunctionSetAttrValueProto(func, compat.as_str(name),
                                           serialized)

  def _create_hash_str(self, input_arg, output_arg, node_def):
    """Creates an 8-character string unique to this input.

    Args:
      input_arg: the input_arg field of an OpDef
        (e.g. self._definition.signature.input_arg)
      output_arg: the output_arg field of an OpDef
        (e.g. self._definition.signature.output_arg)
      node_def: the node_def field of a FunctionDef
        (e.g. self._definition.node_def)

    Returns:
      The unique string for this input
    """
    hasher = hashlib.sha1()

    def update_num(n):
      hasher.update(compat.as_bytes("%x" % n))

    def update_str(s):
      update_num(len(s))
      hasher.update(compat.as_bytes(s))

    def update_strs(slist):
      update_num(len(slist))
      for s in slist:
        update_str(s)

    for adef in input_arg:
      update_str(adef.SerializeToString())

    for adef in output_arg:
      update_str(adef.SerializeToString())

    for n in sorted(node_def, key=lambda n: n.name):
      update_str(n.name)
      update_str(n.op)
      update_strs(n.input)
      update_num(len(n.attr))
      # NOTE: protobuf map serialization does not guarantee ordering.
      for k in sorted(n.attr):
        update_str(k)
        update_str(n.attr[k].SerializeToString())

    return hasher.hexdigest()[:8]

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    self._create_definition_if_needed()

    # Adds this function into 'g'.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      g._add_function(self)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.
    for f in self._sub_functions.values():
      g._add_function_recursive(f)  # pylint: disable=protected-access

    # Adds its gradient function, too.
    if self._grad_func:
      self._grad_func.add_to_graph(g)

  def __call__(self, *args, **kwargs):
    self.add_to_graph(ops.get_default_graph())
    args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
    ret, op = _call(self._signature, *args, **kwargs)

    # Set a hidden attr in 'op' so that gradients_impl can refer back
    # to this _DefinedFunction instance to access python_grad_func.
    assert isinstance(op, ops.Operation)
    setattr(op, "__defun", self)

    if self._shape_func is not None:
      shapes = self._shape_func(op)
      if len(shapes) != len(op.outputs):
        raise ValueError(f"shape_func {self._shape_func} produced "
                         f"{len(shapes):d} shapes, which does not match "
                         f"{len(op.outputs)} outputs.")
      for (t, shape) in zip(op.outputs, shapes):
        t.set_shape(shape)
    return ret
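
# Hedged sketch of the `shape_func` hook exercised by __call__ above: since
# Defun call ops carry `_disable_call_shape_inference`, a caller can restore
# static output shapes manually. `Identityish` is a hypothetical example.
#
#   @Defun(dtypes.float32,
#          shape_func=lambda op: [op.inputs[0].get_shape()])
#   def Identityish(x):
#     return x * 1.0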


class _OverloadedFunction(object):
  """_OverloadedFunction encapsulates an overloaded function.

  _OverloadedFunction maintains a mapping from input types to
  instantiated _DefinedFunction in self._overload.
  """

  def __init__(self,
               func,
               argnames,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               **kwargs):
    """Creates an _OverloadedFunction.

    Args:
      func: A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      func_name: The function name. Defaults to None, in which case it is
        derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: A list of strings for the function return value names.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.
    """
    self._func = func
    self._argnames = argnames
    self._func_name = func_name
    assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._extra_kwargs = kwargs
    self._overload = {}

  def instantiate(self, input_types):
    """Instantiate this function given input argument types.

    Args:
      input_types: A list of data types for the inputs.

    Returns:
      _DefinedFunction for the given input types.
    """
    # Stringify the type list.
    key = _type_list_to_str(input_types)
    defined = self._overload.get(key)
    if not defined:
      # If not defined yet, define the function given the input types.
      name = self._func_name
      if name is not None:
        name = "_".join([name, key])
      defined = _DefinedFunction(
          self._func,
          self._argnames,
          input_types,
          name,
          None,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)
      _ = defined.name  # Fully instantiate the function definition.
      if self._grad_func:
        # If _grad_func is given, it is another
        # _OverloadedFunction. We need to instantiate it with the
        # right input types.
        output_types = [
            dtypes.DType(_.type) for _ in defined._signature.output_arg  # pylint: disable=protected-access
        ]
        # pylint: disable=protected-access
        defined._grad_func = self._grad_func.instantiate(input_types +
                                                         output_types)
        # pylint: enable=protected-access
      self._overload[key] = defined
    return defined

  def __call__(self, *args, **kwargs):
    input_types = []
    args = list(args)
    for (i, x) in enumerate(args):
      x = ops.convert_to_tensor(x)
      if not isinstance(x, ops.Tensor):
        raise ValueError(f"Expected a Tensor but got {x} with type {type(x)}.")
      input_types.append(x.dtype)
      args[i] = x
    return self.instantiate(input_types)(*args, **kwargs)
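
# Sketch of the overloaded path (illustrative only; assumes `tf` is
# `tensorflow.compat.v1` in graph mode, as in the sketch after Defun above):
# decorating with no input types yields an _OverloadedFunction that is
# specialized and cached per call-site dtype list.
#
#   @Defun()  # No input types: overloaded.
#   def Double(x):
#     return x * 2
#
#   # A float32 call instantiates one _DefinedFunction; a later int32 call
#   # instantiates and caches another under a distinct _type_list_to_str key.
#   y = Double(tf.constant(1.0))
#   z = Double(tf.constant(1))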


class _FuncGraph(ops.Graph):
  """A helper for constructing a function.

  _FuncGraph overrides ops.Graph's create_op() so that we can keep
  track of all inputs into every op created inside the function. If
  any input is from another graph, we keep track of it in self._captured
  and substitute the input with a placeholder.

  Each captured input's corresponding placeholder is converted into a
  function argument and the caller passes in the captured tensor.
  """

  def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
               capture_resource_var_by_value, *args, **kwargs):
    super(_FuncGraph, self).__init__(*args, **kwargs)
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._building_function = True
    self._outer_graph = ops.get_default_graph()
    self._vscope = vs.get_variable_scope()
    self._old_custom_getter = self._vscope.custom_getter

    # The name of the function.
    self.name = name
    # Placeholder tensors representing the inputs to this function. The
    # tensors are in this _FuncGraph.
    self.inputs = []
    # Tensors that will be returned by this function. The tensors are in this
    # _FuncGraph.
    self.outputs = []
    # Maps external tensor -> internal tensor (e.g. input placeholder).
    self._captured = {}
    # The external tensors that have been captured as inputs and must be
    # passed to this function (empty if capturing by value, otherwise these
    # are the keys of _captured).
    self.extra_inputs = []
    # Input placeholders that have been added for captured values (empty if
    # capturing by value).
    self.extra_args = []
    # Captured variables.
    # TODO(skyewm): is this needed?
    self.extra_vars = []

  # pylint: disable=g-doc-return-or-yield

  @property
  def outer_graph(self):
    """The graph active when this _FuncGraph was created."""
    return self._outer_graph

  @tf_contextlib.contextmanager
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Overridden from `tf.Graph` to update both the init_scope container
    and the present inner container. This is necessary to make sure setting
    containers applies correctly both to created variables and to stateful
    ops.

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
      yields the container name.
    """
    original_container = self._container
    # pylint: disable=protected-access
    with ops.init_scope():
      original_init_container = ops.get_default_graph()._container
    try:
      self._container = container_name
      with ops.init_scope():
        ops.get_default_graph()._container = container_name
      yield self._container
    finally:
      self._container = original_container
      with ops.init_scope():
        ops.get_default_graph()._container = original_init_container
    # pylint: enable=protected-access

  # pylint: enable=g-doc-return-or-yield

  def getvar(
      self,
      getter,
      name,
      shape=None,
      dtype=None,
      initializer=None,
      reuse=None,
      trainable=True,
      collections=None,  # pylint: disable=redefined-outer-name
      use_resource=None,
      **kwargs):
    """A custom variable getter."""
    # Here, we switch the default graph to the outer graph and ask the
    # variable scope in which the function is defined to give us the
    # variable. The variable is stashed in extra_vars and returned to
    # the caller.
    #
    # We capture these variables so that the variable definition is
    # hoisted upward to the outermost graph.
    with self._outer_graph.as_default():
      # pylint: disable=protected-access
      var = self._vscope.get_variable(
          vs._get_default_variable_store(),
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          use_resource=use_resource)
      self.extra_vars.append(var)
      if (isinstance(var, resource_variable_ops.BaseResourceVariable) and
          self._capture_resource_var_by_value):
        # For resource-based variables read the variable outside the function
        # and pass in the value. This ensures that the function is pure and
        # differentiable. TODO(apassos) this may have performance problems if
        # the function will only do embedding lookups on the variable.
        return var.value()
      return var

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    for i, x in enumerate(inputs):
      if isinstance(x, ops.EagerTensor) or x.graph is not self:
        inputs[i] = self.capture(x)
    return super(_FuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def capture(self, tensor, name=None):
    """Adds the given tensor to this graph and returns the captured tensor."""
    if tensor.ref() in self._captured:
      # Captured already.
      return self._captured[tensor.ref()]
    elif self._capture_by_value:
      return self._add_tensor_and_parents(tensor)
    else:
      return self._capture_tensor_as_extra_input(tensor, name)

  @property
  def captures(self):
    """Pairs of external tensors and their captured substitutes."""
    return [(k.deref(), v) for k, v in self._captured.items()]

  def _capture_tensor_as_extra_input(self, tensor, name=None):
    # Substitute with a placeholder.
    self.extra_inputs.append(tensor)
    # Hoist the new input placeholder out of any control flow context
    # we're currently in.
    with ops.control_dependencies(None):
      ph = array_ops.placeholder(
          tensor.dtype, shape=tensor.get_shape(), name=name)
    # pylint: disable=protected-access
    if isinstance(tensor, ops.EagerTensor):
      handle_data = tensor._handle_data
      if handle_data:
        handle_data = handle_data.SerializeToString()
    else:
      with tensor.graph._c_graph.get() as c_graph:
        handle_data = c_api.GetHandleShapeAndType(c_graph,
                                                  tensor._as_tf_output())

    if handle_data:
      with ph.graph._c_graph.get() as c_graph:
        c_api.SetHandleShapeAndType(c_graph, ph._as_tf_output(),
                                    compat.as_bytes(handle_data))
    # pylint: enable=protected-access
    self.inputs.append(ph)
    self._captured[tensor.ref()] = ph
    self.extra_args.append(ph)
    if _is_guaranteed_const(tensor):
      with ops.control_dependencies(None):
        return array_ops.guarantee_const(ph)
    else:
      return ph

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    # pylint: disable=protected-access
    op_def = graph_to_function_def._get_op_def(op)
    if op._is_stateful and op not in self._allowlisted_stateful_ops:
      raise ValueError(f"Cannot capture a stateful node (name:{op.name}, "
                       f"type:{op.type}) by value.")
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError(f"Cannot capture a placeholder (name:{op.name}, "
                       f"type:{op.type}) by value.")
    # pylint: enable=protected-access

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self._create_op_internal(
        op.type,
        captured_inputs, [o.dtype for o in op.outputs],
        name=op.name,
        attrs=op.node_def.attr,
        op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t.ref()] = captured_t

    return captured_op


def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            allowlisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Returns a _FuncGraph generated from `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into
      the function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes.
    allowlisted_stateful_ops: A set of ops that, if stateful, we ignore and
      re-create.
    capture_resource_var_by_value: Boolean (defaults to True). If False,
      captured resource variables return the handle instead of the value.

  Returns:
    A _FuncGraph.

  Raises:
    ValueError: if func returns None.
  """
  if not name:
    name = function_utils.get_func_name(func)
  func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops,
                          capture_resource_var_by_value)

  with func_graph.as_default(), ops.device(device):
    # pylint: disable=protected-access
    if collections_ref is not None:
      func_graph._collections = collections_ref
    if container is not None:
      func_graph._container = container
    if colocation_stack is not None:
      func_graph._colocation_stack = colocation_stack
    # pylint: enable=protected-access

    if arg_shapes is None:
      arg_shapes = [None] * len(arg_types)

    # Create placeholders for the function arguments.
    for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes):
      argholder = array_ops.placeholder(argtype, shape=argshape, name=argname)
      func_graph.inputs.append(argholder)
    # Call func and gather the output tensors.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      outputs = func(*func_graph.inputs)

    # There is no way of distinguishing between a function not returning
    # anything and a function returning None in Python.
    # We need to allow the former and ideally want to forbid the latter as
    # it is most likely user error.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, we allow a single None return and interpret it as a function
    # with no output.
    if outputs is None:
      outputs = []
    else:
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any(_ is None for _ in outputs):
        raise ValueError(f"Function {name} can not return None.")
    # Ensures each output is a Tensor in the function graph.
    outputs = [ops.convert_to_tensor(t) for t in outputs]
    outputs = [func_graph.capture(t) if t.graph is not func_graph else t
               for t in outputs]
    func_graph.outputs = outputs
  return func_graph
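
# Illustrative sketch (assumes graph mode; `constant_op` is the internal
# constant module, named here only for the example): building a _FuncGraph
# directly and inspecting what it captured.
#
#   with ops.Graph().as_default():
#     outer = constant_op.constant(3.0)  # Lives in the outer graph.
#     fg = func_graph_from_py_func(
#         lambda x: x + outer, ["x"], [dtypes.float32])
#     assert len(fg.inputs) == 2        # "x" plus a placeholder for `outer`.
#     assert fg.extra_inputs == [outer]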


def _is_guaranteed_const(tensor):
  """Determines whether `tensor` is guaranteed to be a constant.

  A tensor is guaranteed to be a constant if either it was produced by
  a `GuaranteeConst` op or if all of its children are guaranteed to be
  constants.

  Args:
    tensor: The tensor for which to determine const-ness.

  Returns:
    True if `tensor` is guaranteed to be a constant, False otherwise.
  """
  if isinstance(tensor, ops.EagerTensor):
    return False

  class Work(object):

    def __init__(self, op, leaving):
      self.op = op
      self.leaving = leaving

  is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst"
  constants = set([])

  def all_inputs_const(op):
    # If all inputs of an op are guaranteed constants, then we can infer that
    # the op produces a constant as well.
    return op.inputs and all(inp.op in constants for inp in op.inputs)

  visited = set([])
  stack = [Work(tensor.op, leaving=False)]
  while stack:
    work = stack.pop()
    if work.leaving:
      if all_inputs_const(work.op):
        constants.add(work.op)
      continue
    visited.add(work.op)
    if is_guaranteed_const(work.op):
      constants.add(work.op)
      continue

    # This op will be revisited after all its inputs are checked for
    # const-ness.
    stack.append(Work(work.op, leaving=True))
    for inp in work.op.inputs:
      if inp.op not in visited:
        stack.append(Work(inp.op, leaving=False))
  return tensor.op in constants
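
# Worked example of the traversal above (a sketch, not from this file): for
#
#   c = constant_op.constant(1.0)
#   g = array_ops.guarantee_const(c)
#   y = g + g
#
# the post-order walk first marks `g`'s op as constant (it is a
# GuaranteeConst), then, when leaving `y`'s op, marks it constant too because
# every input op is already in `constants`. _capture_tensor_as_extra_input
# relies on this to wrap such captures in array_ops.guarantee_const.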


def _call(sig, *inputs, **kwargs):
  """Adds a node calling a function.

  This adds a `call` op to the default graph that calls the function
  of signature `sig`, passing the tensors in `inputs` as arguments.
  It returns the outputs of the call, which are one or more tensors.

  `sig` is the OpDef signature of a `_DefinedFunction` object.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  You can pass an optional keyword parameter `noinline=True|False` to
  instruct the runtime not to inline the function body into the call
  site.

  Args:
    sig: OpDef. The signature of the function.
    *inputs: arguments to the function.
    **kwargs: Optional keyword arguments. Can only contain 'name' or
      'noinline'.

  Returns:
    A 2-element tuple. First element: a Tensor if the function returns a
    single value; a list of Tensors if the function returns multiple values;
    the Operation if the function returns no values. Second element: the
    Operation.

  Raises:
    ValueError: if the arguments are invalid.
  """
  if len(inputs) != len(sig.input_arg):
    raise ValueError(f"Expected {len(sig.input_arg):d} arguments, got "
                     f"{len(inputs):d}.")
  name = kwargs.pop("name", None)
  g = ops.get_default_graph()
  func_name = sig.name
  if name is None:
    name = func_name
  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
  output_types = [dtypes.DType(x.type) for x in sig.output_arg]
  op = g._create_op_internal(  # pylint: disable=protected-access
      func_name, list(inputs), output_types, name=name, attrs=attrs,
      op_def=sig)
  if op.outputs:
    if len(op.outputs) == 1:
      ret = op.outputs[0]
    else:
      ret = tuple(op.outputs)
  else:
    ret = op
  return ret, op
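
# Call-site keyword sketch (mirrors the docstring of _call): `name` and
# `noinline` are the only keywords accepted, and `noinline` is turned into a
# `_noinline` attr by _parse_kwargs_as_attrs. `MyAdd` is hypothetical.
#
#   c = MyAdd(a, b, name="my_add_call", noinline=True)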


def _from_definition(fdef, grad_func=None):
  """Creates a _DefinedFunction initialized from a FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.

  # The Python callable is only needed to create a FunctionDef. Since we have
  # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do
  # we have access to such a callable here).
  func = None
  argnames = [arg.name for arg in fdef.signature.input_arg]
  input_types = tuple(
      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
  func_name = fdef.signature.name
  # Note: FunctionDefs do not include python gradient functions, so if the
  # original _DefinedFunction included one it will not be reflected here.
  python_grad_func = None
  out_names = [arg.name for arg in fdef.signature.output_arg]
  result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
                            python_grad_func, out_names)
  # pylint: disable=protected-access
  serialized = fdef.SerializeToString()
  c_func = c_api.TF_FunctionImportFunctionDef(serialized)
  result._c_func = c_api_util.ScopedTFFunction(c_func, func_name)
  result._extra_inputs = []
  result._op_def = fdef.signature
  # pylint: enable=protected-access

  return result


def from_library(lib):
  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.

  This method handles assigning the correct gradient functions to each
  function.

  Args:
    lib: a FunctionDefLibrary

  Returns:
    A list of _DefinedFunctions

  Raises:
    ValueError: `lib` is invalid
  """
  if not lib.function and not lib.gradient:
    return []

  # function name -> FunctionDef proto
  funcs = {fdef.signature.name: fdef for fdef in lib.function}

  # Validate that all referenced function names have function defs
  for g in lib.gradient:
    if g.function_name not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.function_name}' "
                       f"FunctionDef\n{lib}")
    if g.gradient_func not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.gradient_func}' "
                       f"FunctionDef\n{lib}")

  # function name -> gradient function name
  func_to_grad = collections.defaultdict(lambda: None)
  # gradient function name -> names of functions having that grad function
  grad_to_funcs = collections.defaultdict(list)

  for gdef in lib.gradient:
    func_to_grad[gdef.function_name] = gdef.gradient_func
    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)

  # Start with functions without gradients
  ready = [
      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
  ]
  if not ready:
    raise ValueError(
        f"FunctionDefLibrary contains cyclic gradient functions!\n{lib}")
  # function name -> _DefinedFunction
  initialized = {}

  while ready:
    fdef = ready.pop()
    name = fdef.signature.name

    grad = initialized.get(func_to_grad[name])
    if func_to_grad[name]:
      assert grad
    defined_func = _from_definition(fdef, grad_func=grad)
    initialized[name] = defined_func

    ready.extend(funcs[f] for f in grad_to_funcs[name])

  return initialized.values()
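
# Round-trip sketch (hedged; `g` is a hypothetical tf.Graph already holding
# Defun-defined functions): rebuilding _DefinedFunctions from a serialized
# graph's function library.
#
#   library = g.as_graph_def().library  # A FunctionDefLibrary proto.
#   for f in from_library(library):
#     print(f.name, f.grad_func_name)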


def _get_experimental_kwarg_as_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, bool):
    return attr_value_pb2.AttrValue(b=value)
  elif isinstance(value, int):
    return attr_value_pb2.AttrValue(i=value)
  elif isinstance(value, float):
    return attr_value_pb2.AttrValue(f=value)
  elif isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError(f"Attribute {attr_name} must be bool, int, float, or "
                     f"str. Got {type(value)}.")


def _get_kwarg_as_str_attr(attr_name, value):
  """Creates an AttrValue for a python string."""
  if isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError(f"Attribute {attr_name} must be str. Got {type(value)}.")


def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Parses **kwargs into a node's attributes."""
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  # For compatibility with previous behavior, Defun does not perform shape
  # inference through its function call operations.
  attrs["_disable_call_shape_inference"] = attr_value_pb2.AttrValue(b=True)

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))
    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())
    # pylint: enable=protected-access

  kwargs_keys = list(kwargs.keys())
  for key in kwargs_keys:
    if key.startswith("experimental_"):
      attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
      del kwargs[key]
    # Support for https://github.com/tensorflow/community/pull/113/files.
    elif key == "_implements" or key == "_reference":
      attrs[key] = _get_kwarg_as_str_attr(key, kwargs[key])
      del kwargs[key]
  if kwargs:
    raise ValueError(f"Unknown keyword arguments: {kwargs.keys()}.")
  return attrs
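
# Attr-parsing sketch (only documented keys are exercised; the values below
# follow directly from the code above):
#
#   attrs = _parse_kwargs_as_attrs("f", noinline=True, experimental_tag="x")
#   assert attrs["_noinline"].b
#   assert attrs["_disable_call_shape_inference"].b
#   assert attrs["experimental_tag"].s == b"x"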


def get_extra_vars():
  """Returns the variables captured by the function.

  Returns:
    If the default graph is being used to define a function, the
    returned list of variables are those created inside the function
    body so far. Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_vars
  else:
    return []


def get_extra_inputs():
  """Returns the input tensors captured by the function.

  Returns:
    If the default graph is being used to define a function, the
    returned list of tensors are those accessed inside the function body
    but defined outside the function body so far. Otherwise, returns an
    empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_inputs
  else:
    return []


def get_extra_args():
  """Returns the corresponding function arguments for the captured inputs.

  Returns:
    If the default graph is being used to define a function, the
    returned list of placeholders are those used inside the function
    body corresponding to those returned by get_extra_inputs(). Otherwise,
    returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_args
  else:
    return []


def _type_list_to_str(types):
  if any(_ not in _DTYPE_TO_STR for _ in types):
    unsupported_types = [type_ for type_ in types if type_ not in _DTYPE_TO_STR]
    raise ValueError(f"Unsupported dtypes {unsupported_types} in "
                     "`types`. Supported dtypes are "
                     f"{_DTYPE_TO_STR.keys()}.")
  return "".join(_DTYPE_TO_STR[_] for _ in types)


# NOTE: The list needs to be extended when more data types are added.
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "i8",
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16",
    dtypes.float8_e5m2: "f8e5m2",
    dtypes.float8_e4m3fn: "f8e4m3fn"
}
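
# Key-generation sketch: the overload cache key used by _OverloadedFunction is
# just the concatenated dtype codes from the table above.
#
#   _type_list_to_str([dtypes.float32, dtypes.int64])  # -> "f32i64"
#
# Note that uint8 and int8 both map to "i8" in the table, so overloads
# differing only in that pair would collide on the same cache key.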