# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Script Language Operators."""
17# pylint: disable=g-bad-name
18import threading
20# Used by py_util.cc to get tracebacks.
21import traceback # pylint: disable=unused-import
22import weakref
24import numpy as np
26from tensorflow.python.autograph.impl import api as autograph
27from tensorflow.python.eager import backprop
28from tensorflow.python.eager import backprop_util
29from tensorflow.python.eager import context
30from tensorflow.python.eager import record
31from tensorflow.python.framework import composite_tensor
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import func_graph
35from tensorflow.python.framework import function
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_spec
38from tensorflow.python.framework import type_spec
39from tensorflow.python.lib.core import _pywrap_py_func
40from tensorflow.python.ops import autograph_ops # pylint: disable=unused-import
41from tensorflow.python.ops import gen_script_ops
42from tensorflow.python.ops import resource_variable_ops
43from tensorflow.python.util import compat
44from tensorflow.python.util import deprecation
45from tensorflow.python.util import dispatch
46from tensorflow.python.util import nest
47from tensorflow.python.util import tf_inspect
48from tensorflow.python.util import variable_utils
49from tensorflow.python.util.tf_export import tf_export

# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
# used for differentiation.
tape_cache = {}


def _maybe_copy_to_context_device(tensor, device_name):
  """Copy an EagerTensor to the current device if it's not on `device_name`."""
  in_device = tensor.backing_device
  if device_name == in_device:
    return tensor
  else:
    # Note that EagerTensor._copy bypasses the placer and copies to the context
    # device, which means e.g. int32 Tensors which would normally be forced
    # onto the CPU can instead be placed on the GPU. This is necessary so that
    # the PyFunc kernel always returns Tensors on the device it's executing on.
    return tensor._copy()  # pylint: disable=protected-access


class EagerFunc:
  """A wrapper for a function owned by an EagerPyFunc."""

  def __init__(self, func, Tout, is_grad_func):
    """Constructs an EagerFunc.

    Args:
      func: The function to wrap.
      Tout: A list of datatypes for the output; an empty list if the output is
        None.
      is_grad_func: Whether this EagerFunc is the gradient of another
        EagerPyFunc.
    """
    self._func = func
    self._out_dtypes = Tout
    self._is_grad_func = is_grad_func
    self._support_graph_mode_gradient = False

  def set_support_graph_mode_gradient(self):
    """Indicates the object shall support gradient ops.

    This function is internally used by _EagerPyFuncGrad to support
    graph mode gradient of EagerFunc via tf.gradient().
    """
    self._support_graph_mode_gradient = True

  def _convert(self, value, dtype):
    """Converts `value` to a tensor of type `dtype`, with error checking.

    Args:
      value: The tensor to convert.
      dtype: The desired dtype.

    Returns:
      A tensor of type `dtype`, or a zeros tensor if value is None and
      this function is in fact a gradient function.

    Raises:
      RuntimeError: if `value` is a variable.
    """
    if isinstance(value, resource_variable_ops.ResourceVariable):
      raise RuntimeError(
          "Attempting to return a variable from an eagerly executed py_func. "
          "Only numeric data structures like Tensors or NumPy arrays should "
          "be returned; to return the value of a variable, make sure to obtain "
          "the Tensor backing it by calling `.read_value()` on the variable in "
          f"question: {value}")
    if value is None and self._is_grad_func:
      # Gradient functions may legitimately return a list that contains
      # both Tensors and Python Nones. Unfortunately this breaks the
      # OpKernel, so for now we replace None objects with zeros, which is
      # mathematically correct but will prevent short-circuiting gradient
      # computations.
      #
      # TODO(akshayka): Make it possible to return a list of both Tensors and
      # Nones from an EagerPyFunc.
      return constant_op.constant(0.0, dtype=dtype)
    return ops.convert_to_tensor(value, dtype=dtype)

  def __call__(self, device, token, args):
    """Calls `self._func` in eager mode, recording the tape if needed."""
    use_tape_cache = (
        self._support_graph_mode_gradient or record.could_possibly_record())

    if use_tape_cache:
      with backprop.GradientTape() as tape:
        for tensor in args:
          for t in nest.flatten(tensor):
            if backprop_util.IsTrainable(t):
              tape.watch(t)
        outputs = self._call(device, args)
      tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
    else:
      outputs = self._call(device, args)

    return outputs

  def _call(self, device, args):
    """Passes `args` to `self._func`, which is executed eagerly."""
    with context.eager_mode():
      ret = self._func(*args)
      # Copy the returned tensors to the PyFunc op's device if necessary.
      device_name = device
      if device_name is None:
        # "None" here means "CPU", from the nullptr convention with C++ device
        # pointers.
        device_name = "/job:localhost/replica:0/task:0/device:CPU:0"
      with ops.device(device):
        if isinstance(ret, (tuple, list)):
          outputs = [
              _maybe_copy_to_context_device(self._convert(x, dtype=dtype),
                                            device_name)
              for (x, dtype) in zip(ret, self._out_dtypes)
          ]
        elif ret is None:
          outputs = None
        else:
          outputs = _maybe_copy_to_context_device(
              self._convert(ret, dtype=self._out_dtypes[0]), device_name)
        return outputs


class FuncRegistry:
  """A helper class to keep track of registered py functions.

  FuncRegistry keeps a map from unique tokens (string) to python
  functions, which take numpy arrays and output numpy arrays.
  """

  def __init__(self):
    self._lock = threading.Lock()
    self._unique_id = 0  # GUARDED_BY(self._lock)
    # Only store weakrefs to the functions. The strong reference is stored in
    # the graph.
    self._funcs = weakref.WeakValueDictionary()

  @property
  def _ctx(self):
    # N.B. This is needed to support calling py_func with GPU tensors,
    # which must be transferred to CPU if used in any of the NumPy APIs.
    context.ensure_initialized()
    return context.context()._handle  # pylint: disable=protected-access

  def insert(self, func):
    """Registers `func` and returns a unique token for this entry."""
    token = self._next_unique_token()
    # Store a weakref to the function.
    self._funcs[token] = func
    return token

  def remove(self, token):
    """Removes the registered function corresponding to `token`."""
    self._funcs.pop(token, None)

  def get(self, token, default=None):
    """Gets the registered function corresponding to `token`."""
    return self._funcs.get(token, default)

  @staticmethod
  def _convert(value, dtype=None):
    """Converts an arg to numpy, avoiding dangerous string and unicode dtypes.

    Numpy pads with zeros when using string and unicode dtypes if different
    components of a tensor have different lengths. This is bad: ignoring the
    padding is wrong for text data, and removing the padding is wrong for
    binary data. To avoid this bug, we redo the conversion using an object
    dtype. Additionally, we convert unicode strings to (byte-)strings for
    compatibility.

    Args:
      value: Value to convert to a numpy array.
      dtype: (Optional.) Desired NumPy type for the returned value.

    Returns:
      A numpy array.
    """
    result = np.asarray(value, dtype=dtype, order="C")
    if result.dtype.char == "S" and result is not value:
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U" and result is not value:
      value = np.vectorize(lambda x: x.encode("utf8"))(value)
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U":
      return result.astype(np.bytes_)
    else:
      return result
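
  # Illustrative note (an assumption added for clarity, not in the original
  # source): converting a ragged list of byte strings directly produces a
  # fixed-width dtype, while the object-dtype fallback above keeps each
  # element at its own length, e.g.
  #
  #   np.asarray([b"a", b"bbb"]).dtype         # dtype('S3'): fixed 3-byte width
  #   FuncRegistry._convert([b"a", b"bbb"])    # object array, lengths preserved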

  def __call__(self, token, device, args):
    """Calls the registered function for `token` with args.

    Args:
      token: A key into this `FuncRegistry` identifying which function to call.
      device: Name of the device on which outputs of `token`'s corresponding
        operation should be placed. Used iff the function registered for
        `token` is an EagerPyFunc.
      args: The arguments to pass to the function registered for `token`.

    Returns:
      The output of the function registered for `token`.

    Raises:
      ValueError: if no function is registered for `token`.
    """
    func = self.get(token, None)
    if func is None:
      raise ValueError(f"Could not find callback with key={token} in the "
                       "registry.")
    if isinstance(func, EagerFunc):
      # NB: Different invocations of the same py_func will share the same
      # token, and the entries they stash in the tape_cache will collide.
      # In practice, when executing a graph, this should only happen if
      # the py_func is in a while_loop whose iterations are run in parallel
      # or if the graph is being driven by concurrent session.run() calls.
      #
      # TODO(akshayka): Key the tape cache in a thread-safe way.
      return func(device, token, args)
    else:
      ret = func(*args)
      # Strings seem to lead to a memory leak here if they're not wrapped in a
      # list.
      if isinstance(ret, bytes):
        ret = [ret]
      # Ensures that we return either a single numpy array or a list of numpy
      # arrays.
      if isinstance(ret, (tuple, list)):
        return [self._convert(x) for x in ret]
      else:
        return self._convert(ret)

  def size(self):
    """Returns how many functions are currently registered."""
    return len(self._funcs)

  def _next_unique_token(self):
    """Returns a unique token."""
    with self._lock:
      uid = self._unique_id
      self._unique_id += 1
    return "pyfunc_%d" % uid


# Global registry for py functions.
_py_funcs = FuncRegistry()

_pywrap_py_func.initialize_py_trampoline(_py_funcs)


def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      use_eager_py_func=False,
                      is_grad_func=False,
                      name=None):
  """See documentation for py_func and eager_py_func."""
  if not callable(func):
    raise ValueError(
        f"Expected func to be callable. Received func={func} of type "
        f"{type(func)}.")

  original_func = func
  func = autograph.do_not_convert(func)
  inp = variable_utils.convert_variables_to_tensors(list(inp))

  # Normalize Tout.
  is_list_or_tuple = isinstance(Tout, (list, tuple))
  Tout = Tout if is_list_or_tuple else [Tout]
  Tout = [_as_dtype_or_type_spec(t) for t in Tout]

  # Check if we need to handle CompositeTensor inputs or outputs.
  handle_composite_tensors = (
      use_eager_py_func and
      (any(isinstance(v, composite_tensor.CompositeTensor) for v in inp) or
       any(isinstance(t, type_spec.TypeSpec) for t in Tout)))
  if handle_composite_tensors:
    func, inp, Tout, out_structure = _wrap_for_composites(func, inp, Tout)

  if use_eager_py_func:
    func = EagerFunc(func, Tout, is_grad_func)

  # Tying the registered function's lifetime with the current default graph is
  # not reliable. For example, Estimator-based binaries may switch graphs in
  # between model training and evaluation, via saved_model. Those binaries work
  # because the original function is global, and break once the registered
  # function is an anonymous lambda, like the one produced by do_not_convert.
  # To avoid breaking those cases, we attach the wrapper to the original
  # function so that their lifetimes are connected.
  # TODO(b/144286616): Remove this.
  if tf_inspect.isfunction(original_func):
    # Note: this check is needed because original_func may be a descriptor
    # (https://docs.python.org/3/howto/descriptor.html)
    # and we can't attach attributes to those.
    original_func.ag_dnc_wrapper__ = func

  token = _py_funcs.insert(func)
  # We tie the registered function's lifetime with the current default graph,
  # i.e., when the current graph is destroyed, we remove its py funcs.
  graph = ops.get_default_graph()

  while True:
    current_graph = graph
    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
      graph = graph._outer_graph  # pylint: disable=protected-access
    elif isinstance(graph, func_graph.FuncGraph):
      graph = graph.outer_graph
    if graph is current_graph:
      break

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its members.
  if not hasattr(graph, "_py_funcs_used_in_graph"):
    graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

  # Store a reference to the function in the graph to ensure it stays alive
  # as long as the graph lives. When the graph is destroyed, the function
  # is left to the garbage collector for destruction as well.
  graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

  if use_eager_py_func:
    result = gen_script_ops.eager_py_func(
        input=inp,
        token=token,
        is_async=context.is_async(),
        Tout=Tout,
        name=name)
  else:
    if stateful:
      result = gen_script_ops.py_func(
          input=inp, token=token, Tout=Tout, name=name)
    else:
      result = gen_script_ops.py_func_stateless(
          input=inp, token=token, Tout=Tout, name=name)

  if handle_composite_tensors and Tout:
    result = nest.pack_sequence_as(
        out_structure, result, expand_composites=True)

  return result if is_list_or_tuple else result[0]


# TODO(akshayka): Implement higher-order derivatives.
@ops.RegisterGradient("EagerPyFunc")
def _EagerPyFuncGrad(op, *dy):
  """Computes the gradient of an EagerPyFunc."""

  token = op.get_attr("token")

  def eagerly_executed_grad(*dy):
    tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
    return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)

  with ops.control_dependencies(op.outputs):
    gradient_op = _internal_py_func(
        func=eagerly_executed_grad,
        inp=dy,
        Tout=[tensor.dtype for tensor in op.inputs],
        use_eager_py_func=True,
        is_grad_func=True)

  if not context.executing_eagerly():
    # In graph mode, we find the func object from its token and
    # notify the eager func object that it needs to support gradients.
    func = _py_funcs.get(token.decode())
    assert isinstance(func, EagerFunc), (
        f"EagerPyFuncGrad called on a non-EagerFunc object: {func}.")
    func.set_support_graph_mode_gradient()
  return gradient_op
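
# A minimal sketch (an assumption for illustration, not part of the original
# module; assumes the public `tf` API) of the graph-mode path handled above:
# inside a tf.function, taking tf.gradients of an EagerPyFunc output makes the
# forward EagerFunc cache its tape so this gradient function can pop it later.
#
#   @tf.function
#   def grad_of_square(x):
#     y = tf.py_function(lambda t: t * t, inp=[x], Tout=tf.float32)
#     return tf.gradients(y, x)[0]
#
#   grad_of_square(tf.constant(3.0))  # -> 6.0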
420@tf_export("py_function")
421@dispatch.add_dispatch_support
422def eager_py_func(func, inp, Tout, name=None):
423 """Wraps a python function into a TensorFlow op that executes it eagerly.
425 This function allows expressing computations in a TensorFlow graph as
426 Python functions. In particular, it wraps a Python function `func`
427 in a once-differentiable TensorFlow operation that executes it with eager
428 execution enabled. As a consequence, `tf.py_function` makes it
429 possible to express control flow using Python constructs (`if`, `while`,
430 `for`, etc.), instead of TensorFlow control flow constructs (`tf.cond`,
431 `tf.while_loop`). For example, you might use `tf.py_function` to
432 implement the log huber function:
434 ```python
435 def log_huber(x, m):
436 if tf.abs(x) <= m:
437 return x**2
438 else:
439 return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))
441 x = tf.constant(1.0)
442 m = tf.constant(2.0)
444 with tf.GradientTape() as t:
445 t.watch([x, m])
446 y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)
448 dy_dx = t.gradient(y, x)
449 assert dy_dx.numpy() == 2.0
450 ```
452 You can also use `tf.py_function` to debug your models at runtime
453 using Python tools, i.e., you can isolate portions of your code that
454 you want to debug, wrap them in Python functions and insert `pdb` tracepoints
455 or print statements as desired, and wrap those functions in
456 `tf.py_function`.
458 For more information on eager execution, see the
459 [Eager guide](https://tensorflow.org/guide/eager).
461 `tf.py_function` is similar in spirit to `tf.compat.v1.py_func`, but unlike
462 the latter, the former lets you use TensorFlow operations in the wrapped
463 Python function. In particular, while `tf.compat.v1.py_func` only runs on CPUs
464 and wraps functions that take NumPy arrays as inputs and return NumPy arrays
465 as outputs, `tf.py_function` can be placed on GPUs and wraps functions
466 that take Tensors as inputs, execute TensorFlow operations in their bodies,
467 and return Tensors as outputs.
469 Note: We recommend to avoid using `tf.py_function` outside of prototyping
470 and experimentation due to the following known limitations:
472 * Calling `tf.py_function` will acquire the Python Global Interpreter Lock
473 (GIL) that allows only one thread to run at any point in time. This will
474 preclude efficient parallelization and distribution of the execution of the
475 program.
477 * The body of the function (i.e. `func`) will not be serialized in a
478 `GraphDef`. Therefore, you should not use this function if you need to
479 serialize your model and restore it in a different environment.
481 * The operation must run in the same address space as the Python program
482 that calls `tf.py_function()`. If you are using distributed
483 TensorFlow, you must run a `tf.distribute.Server` in the same process as the
484 program that calls `tf.py_function()` and you must pin the created
485 operation to a device in that server (e.g. using `with tf.device():`).
487 * Currently `tf.py_function` is not compatible with XLA. Calling
488 `tf.py_function` inside `tf.function(jit_compile=True)` will raise an
489 error.
491 Args:
492 func: A Python function that accepts `inp` as arguments, and returns a
493 value (or list of values) whose type is described by `Tout`.
495 inp: Input arguments for `func`. A list whose elements are `Tensor`s or
496 `CompositeTensors` (such as `tf.RaggedTensor`); or a single `Tensor` or
497 `CompositeTensor`.
499 Tout: The type(s) of the value(s) returned by `func`. One of the
500 following.
502 * If `func` returns a `Tensor` (or a value that can be converted to a
503 Tensor): the `tf.DType` for that value.
504 * If `func` returns a `CompositeTensor`: The `tf.TypeSpec` for that value.
505 * If `func` returns `None`: the empty list (`[]`).
506 * If `func` returns a list of `Tensor` and `CompositeTensor` values:
507 a corresponding list of `tf.DType`s and `tf.TypeSpec`s for each value.
509 name: A name for the operation (optional).
511 Returns:
512 The value(s) computed by `func`: a `Tensor`, `CompositeTensor`, or list of
513 `Tensor` and `CompositeTensor`; or an empty list if `func` returns `None`.
514 """
515 if ops.executing_eagerly_outside_functions():
516 with ops.device(context.context().host_address_space()):
517 return _internal_py_func(
518 func=func, inp=inp, Tout=Tout, use_eager_py_func=True, name=name)
520 return _internal_py_func(
521 func=func, inp=inp, Tout=Tout, use_eager_py_func=True, name=name)
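
# A short sketch (added for illustration, not part of the original module;
# assumes the public `tf` API) of the CompositeTensor support documented
# above, using a tf.RaggedTensor input and a tf.RaggedTensorSpec in Tout:
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   out = tf.py_function(
#       func=lambda x: x * 2, inp=[rt],
#       Tout=tf.RaggedTensorSpec(ragged_rank=1, dtype=tf.int32))
#   # out == <tf.RaggedTensor [[2, 4], [6]]>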


def py_func_common(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, which takes numpy arrays as its
  arguments and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as an operation
  in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  input = tf.compat.v1.placeholder(tf.float32)
  y = tf.compat.v1.py_func(my_func, [input], tf.float32)
  ```

  **N.B.** The `tf.compat.v1.py_func()` operation has the following known
  limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.compat.v1.py_func()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as
    the program that calls `tf.compat.v1.py_func()` and you must pin the
    created operation to a device in that server (e.g. using
    `with tf.device():`).

  Note: It produces tensors of unknown shape and rank, as shape inference
  does not work on arbitrary Python code. If you need the shape, you need to
  set it based on statically available information.

  E.g.
  ```python
  import tensorflow as tf
  import numpy as np

  def make_synthetic_data(i):
    return np.cast[np.uint8](i) * np.ones([20, 256, 256, 3],
                                          dtype=np.float32) / 10.

  def preprocess_fn(i):
    ones = tf.py_function(make_synthetic_data, [i], tf.float32)
    ones.set_shape(tf.TensorShape([None, None, None, None]))
    ones = tf.image.resize(ones, [224, 224])
    return ones

  ds = tf.data.Dataset.range(10)
  ds = ds.map(preprocess_fn)
  ```

  Args:
    func: A Python function, which accepts `ndarray` objects as arguments and
      returns a list of `ndarray` objects (or a single `ndarray`). This
      function must accept as many arguments as there are tensors in `inp`,
      and these argument types will match the corresponding `tf.Tensor`
      objects in `inp`. The returned `ndarray`s must match the number and
      types defined in `Tout`.
      Important Note: Input and output numpy `ndarray`s of `func` are not
      guaranteed to be copies. In some cases their underlying memory will be
      shared with the corresponding TensorFlow tensors. In-place modification
      or storing `func` input or return values in python data structures
      without an explicit (np.)copy can have non-deterministic consequences.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) If True, the function should be considered stateful.
      If a function is stateless, when given the same input it will return the
      same output and have no observable side effects. Optimizations such as
      common subexpression elimination are only performed on stateless
      operations.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes.

  @compatibility(TF2)

  This name was deprecated and removed in TF2, but `tf.numpy_function` is a
  near-exact replacement, just drop the `stateful` argument (all
  `tf.numpy_function` calls are considered stateful). It is compatible with
  eager execution and `tf.function`.

  `tf.py_function` is a close but not an exact replacement, passing TensorFlow
  tensors to the wrapped function instead of NumPy arrays, which provides
  gradients and can take advantage of accelerators.

  Before:

  >>> def fn_using_numpy(x):
  ...   x[0] = 0.
  ...   return x
  >>> tf.compat.v1.py_func(fn_using_numpy, inp=[tf.constant([1., 2.])],
  ...     Tout=tf.float32, stateful=False)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

  After:

  >>> tf.numpy_function(fn_using_numpy, inp=[tf.constant([1., 2.])],
  ...     Tout=tf.float32)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

  @end_compatibility

  """
  if context.executing_eagerly():
    result = func(*[np.array(x) for x in inp])
    result = nest.flatten(result)

    result = [x if x is None else ops.convert_to_tensor(x) for x in result]
    if len(result) == 1:
      # Mimic the automatic unwrapping in graph-mode py_func
      result, = result
    return result

  if ops.executing_eagerly_outside_functions():
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func,
          inp=inp,
          Tout=Tout,
          stateful=stateful,
          use_eager_py_func=False,
          name=name)

  return _internal_py_func(
      func=func,
      inp=inp,
      Tout=Tout,
      stateful=stateful,
      use_eager_py_func=False,
      name=name)


@deprecation.deprecated(
    date=None,
    instructions="""tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2.
    - tf.py_function takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    - tf.numpy_function maintains the semantics of the deprecated tf.py_func
    (it is not differentiable, and manipulates numpy arrays). It drops the
    stateful argument making all functions stateful.
    """)
@tf_export(v1=["py_func"])
@dispatch.add_dispatch_support
def py_func(func, inp, Tout, stateful=True, name=None):
  return py_func_common(func, inp, Tout, stateful, name=name)


py_func.__doc__ = "%s" % py_func_common.__doc__
684@tf_export("numpy_function")
685@dispatch.add_dispatch_support
686def numpy_function(func, inp, Tout, stateful=True, name=None):
687 """Wraps a python function and uses it as a TensorFlow op.
689 Given a python function `func` wrap this function as an operation in a
690 TensorFlow function. `func` must take numpy arrays as its arguments and
691 return numpy arrays as its outputs.
693 The following example creates a TensorFlow graph with `np.sinh()` as an
694 operation in the graph:
696 >>> def my_numpy_func(x):
697 ... # x will be a numpy array with the contents of the input to the
698 ... # tf.function
699 ... return np.sinh(x)
700 >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
701 ... def tf_function(input):
702 ... y = tf.numpy_function(my_numpy_func, [input], tf.float32)
703 ... return y * y
704 >>> tf_function(tf.constant(1.))
705 <tf.Tensor: shape=(), dtype=float32, numpy=1.3810978>
707 Comparison to `tf.py_function`:
708 `tf.py_function` and `tf.numpy_function` are very similar, except that
709 `tf.numpy_function` takes numpy arrays, and not `tf.Tensor`s. If you want the
710 function to contain `tf.Tensors`, and have any TensorFlow operations executed
711 in the function be differentiable, please use `tf.py_function`.
713 Note: We recommend to avoid using `tf.numpy_function` outside of
714 prototyping and experimentation due to the following known limitations:
716 * Calling `tf.numpy_function` will acquire the Python Global Interpreter Lock
717 (GIL) that allows only one thread to run at any point in time. This will
718 preclude efficient parallelization and distribution of the execution of the
719 program. Therefore, you are discouraged to use `tf.numpy_function` outside
720 of prototyping and experimentation.
722 * The body of the function (i.e. `func`) will not be serialized in a
723 `tf.SavedModel`. Therefore, you should not use this function if you need to
724 serialize your model and restore it in a different environment.
726 * The operation must run in the same address space as the Python program
727 that calls `tf.numpy_function()`. If you are using distributed
728 TensorFlow, you must run a `tf.distribute.Server` in the same process as the
729 program that calls `tf.numpy_function` you must pin the created
730 operation to a device in that server (e.g. using `with tf.device():`).
732 * Currently `tf.numpy_function` is not compatible with XLA. Calling
733 `tf.numpy_function` inside `tf.function(jit_compile=True)` will raise an
734 error.
736 * Since the function takes numpy arrays, you cannot take gradients
737 through a numpy_function. If you require something that is differentiable,
738 please consider using tf.py_function.
740 Args:
741 func: A Python function, which accepts `numpy.ndarray` objects as arguments
742 and returns a list of `numpy.ndarray` objects (or a single
743 `numpy.ndarray`). This function must accept as many arguments as there are
744 tensors in `inp`, and these argument types will match the corresponding
745 `tf.Tensor` objects in `inp`. The returns `numpy.ndarray`s must match the
746 number and types defined `Tout`.
747 Important Note: Input and output `numpy.ndarray`s of `func` are not
748 guaranteed to be copies. In some cases their underlying memory will be
749 shared with the corresponding TensorFlow tensors. In-place modification
750 or storing `func` input or return values in python datastructures
751 without explicit (np.)copy can have non-deterministic consequences.
752 inp: A list of `tf.Tensor` objects.
753 Tout: A list or tuple of tensorflow data types or a single tensorflow data
754 type if there is only one, indicating what `func` returns.
755 stateful: (Boolean.) Setting this argument to False tells the runtime to
756 treat the function as stateless, which enables certain optimizations.
757 A function is stateless when given the same input it will return the
758 same output and have no side effects; its only purpose is to have a
759 return value.
760 The behavior for a stateful function with the `stateful` argument False
761 is undefined. In particular, caution should be taken when
762 mutating the input arguments as this is a stateful operation.
763 name: (Optional) A name for the operation.
765 Returns:
766 Single or list of `tf.Tensor` which `func` computes.
767 """
768 return py_func_common(func, inp, Tout, stateful=stateful, name=name)


def _as_dtype_or_type_spec(t):
  return t if isinstance(t, type_spec.TypeSpec) else dtypes.as_dtype(t)


def _wrap_for_composites(func, inp, Tout):
  """Wraps user inputs to support composite tensors for `py_function`.

  1. Flattens `inp` to a list of Tensors (by flattening any composite
     tensors).
  2. Creates a wrapper function for `func` that expects flat inputs and:
     - Packs the inputs into the input structure expected by `func`.
     - Calls `func` with the packed inputs.
     - Checks that `func`'s output matches `Tout`.
     - Flattens `func`'s output to a list of Tensors (flattening any composite
       tensors).

  Args:
    func: The function to wrap (`func` argument to `py_function`).
    inp: The input arguments for func (`inp` argument to `py_function`).
    Tout: The expected output types for func (`Tout` argument to
      `py_function`).

  Returns:
    A tuple `(func, inp, Tout, out_structure)`, where `func` is the wrapped
    function, `inp` is the flattened inputs, `Tout` is the list of expected
    dtypes for the flattened outputs, and `out_structure` is the expected
    output structure (which can be used to pack the output tensors).
  """
  in_structure = [
      v if isinstance(v, composite_tensor.CompositeTensor) else 1 for v in inp
  ]
  inp = nest.flatten_up_to(in_structure, inp, expand_composites=True)
  out_structure = Tout
  Tout = [
      v.dtype if isinstance(v, tensor_spec.TensorSpec) else v
      for v in nest.flatten(Tout, expand_composites=True)
  ]

  def wrapped_func(*flat_inp):
    structured_inp = nest.pack_sequence_as(
        in_structure, flat_inp, expand_composites=True)
    out = func(*structured_inp)
    if not out_structure:
      return []  # Ignore return value if none is requested/expected.
    if not isinstance(out, (list, tuple)):
      out = [out]  # func may return a single value instead of a list.
    flat_out = []
    for elt, expected_type in zip(out, out_structure):
      if (isinstance(expected_type, type_spec.TypeSpec) and
          not isinstance(expected_type, tensor_spec.TensorSpec)):
        if not expected_type.is_compatible_with(elt):
          # pylint: disable=protected-access
          raise ValueError(
              f"py_function: func={func} returned {out!r}, "
              f"which did not match Tout={out_structure!r}.\nIn particular, "
              f"{elt!r} is not compatible with {expected_type!r}.")
        flat_out.extend(nest.flatten(elt, expand_composites=True))
      else:
        # Pro-actively check if the return value is a composite tensor when
        # we expect a Tensor. We would catch this later (when we call
        # convert_to_tensor), but checking it here lets us give a better
        # error message.
        if isinstance(elt, composite_tensor.CompositeTensor):
          raise ValueError(
              f"py_function: func={func} returned {out!r}, "
              f"which did not match Tout={out_structure!r}.\nIn particular, "
              f"{elt!r} is not a Tensor.")
        flat_out.append(elt)
    return flat_out

  return wrapped_func, inp, Tout, out_structure
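
# Illustrative round trip (an assumption added for clarity, not in the
# original source; assumes the public `tf` API) of the expand_composites
# flatten/pack used by wrapped_func above:
#
#   rt = tf.ragged.constant([[1, 2], [3]])
#   flat = tf.nest.flatten(rt, expand_composites=True)
#   # -> [values Tensor [1, 2, 3], row_splits Tensor [0, 2, 3]]
#   tf.nest.pack_sequence_as(rt, flat, expand_composites=True)
#   # -> the original ragged structure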
842ops.NotDifferentiable("PyFunc")
843ops.NotDifferentiable("PyFuncStateless")