Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py: 71%
397 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type-based dispatch for TensorFlow's Python APIs.
17"Python APIs" refers to Python functions that have been exported with
18`tf_export`, such as `tf.add` and `tf.linalg.matmul`; they are sometimes also
19referred to as "ops".
21There are currently two dispatch systems for TensorFlow:
23 * The "fallback dispatch" system calls an API's standard implementation first,
24 and only tries to perform dispatch if that standard implementation raises a
25 TypeError (or ValueError) exception.
27 * The "type-based dispatch" system checks the types of the parameters passed
28 to an API, and performs dispatch if those types match any signatures that
29 have been registered for dispatch.
31The fallback dispatch system was the original dispatch system, but it was
32somewhat brittle and had limitations, such as an inability to support dispatch
33for some operations (like convert_to_tensor). We plan to remove the fallback
34dispatch system in favor of the type-based dispatch system, once all users have
35been switched over to use it.
37### Fallback Dispatch
39The fallback dispatch system is based on "operation dispatchers", which can be
40used to override the behavior for TensorFlow ops when they are called with
41otherwise unsupported argument types. In particular, when an operation is
42called with arguments that would cause it to raise a TypeError, it falls back on
43its registered operation dispatchers. If any registered dispatchers can handle
44the arguments, then its result is returned. Otherwise, the original TypeError is
45raised.
47### Type-based Dispatch
49The main interface for the type-based dispatch system is the `dispatch_for_api`
50decorator, which overrides the default implementation for a TensorFlow API.
51The decorated function (known as the "dispatch target") will override the
52default implementation for the API when the API is called with parameters that
53match a specified type signature.
55### Dispatch Support
57By default, dispatch support is added to the generated op wrappers for any
58visible ops. APIs/ops that are implemented in Python can opt in to
59dispatch support using the `add_dispatch_support` decorator.
60"""
62import collections
63import itertools
64import typing # pylint: disable=unused-import (used in doctests)
66from tensorflow.python.framework import _pywrap_python_api_dispatcher as _api_dispatcher
67from tensorflow.python.framework import ops
68from tensorflow.python.util import tf_decorator
69from tensorflow.python.util import tf_export as tf_export_lib
70from tensorflow.python.util import tf_inspect
71from tensorflow.python.util import traceback_utils
72from tensorflow.python.util import type_annotations
73from tensorflow.python.util.tf_export import tf_export
# Names of the private attributes that TensorFlow API functions use to hold
# their dispatchers: a list of fallback `OpDispatcher`s, and a type-based
# `PythonAPIDispatcher`, respectively.
FALLBACK_DISPATCH_ATTR = "_tf_fallback_dispatchers"
TYPE_BASED_DISPATCH_ATTR = "_tf_type_based_dispatcher"

# Dispatchers that are consulted for every operation, after the operation's
# own fallback dispatchers (see `dispatch`).
_GLOBAL_DISPATCHERS = list()


################################################################################
# Fallback Dispatch
################################################################################
@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operation dispatchers.

  An operation dispatcher overrides the behavior of a single TensorFlow
  operation. Its `handle` method receives the operation's arguments; any
  value it returns other than `OpDispatcher.NOT_SUPPORTED` is used as the
  operation's result.
  """

  # Sentinel that `handle` returns to signal that this dispatcher does not
  # support the arguments it was given.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handles this dispatcher's operation for the given arguments.

    Subclasses that can handle `args`/`kwargs` should return the operation's
    result (or raise an appropriate exception). The base implementation
    handles nothing.

    Args:
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher cannot handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Registers this dispatcher as a fallback handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled.
        It must already have a fallback dispatch list (generated ops get one
        automatically; Python ops get one from the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` has no fallback dispatch list.
    """
    if hasattr(op, FALLBACK_DISPATCH_ATTR):
      getattr(op, FALLBACK_DISPATCH_ATTR).append(self)
    else:
      raise AssertionError("Dispatching not enabled for %s" % op)
@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operation dispatchers.

  Registered global dispatchers are consulted by `dispatch` for every
  operation, after that operation's own `OpDispatcher`s have declined.
  """

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handles the specified operation with the specified arguments.

    Subclasses should override this method. The base implementation handles
    nothing. It returns `NOT_SUPPORTED` explicitly: `dispatch` treats any
    other value (including an implicit `None`) as a handled result, so an
    un-overridden dispatcher must not short-circuit dispatch.

    Args:
      op: Python function: the operation being dispatched.
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `NOT_SUPPORTED` if this dispatcher
      cannot handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self):
    """Registers this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)
def dispatch(op, args, kwargs):
  """Returns the result of the first dispatcher that can handle a call to `op`.

  The `OpDispatcher`s registered on `op` are tried first, in registration
  order; then every registered `GlobalOpDispatcher` is tried, also in
  registration order.

  Args:
    op: Python function: the operation to dispatch for.
    args: The positional arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The first result that is not `OpDispatcher.NOT_SUPPORTED`, or
    `OpDispatcher.NOT_SUPPORTED` if no dispatcher can handle the arguments.
  """
  unsupported = OpDispatcher.NOT_SUPPORTED

  for op_dispatcher in getattr(op, FALLBACK_DISPATCH_ATTR):
    outcome = op_dispatcher.handle(args, kwargs)
    if outcome is not unsupported:
      return outcome

  for global_dispatcher in _GLOBAL_DISPATCHERS:
    outcome = global_dispatcher.handle(op, args, kwargs)
    if outcome is not unsupported:
      return outcome

  return unsupported
class _TypeBasedDispatcher(OpDispatcher):
  """Fallback dispatcher that fires when an argument has a registered type.

  The override function is called whenever any positional or keyword argument,
  or any element of an argument that is a `list` or `tuple`, is an instance of
  one of the registered types.
  """

  def __init__(self, override_func, types):
    self._override_func = override_func
    self._types = types

  def _handles(self, args, kwargs):
    """Returns True if some argument (or list/tuple element) matches."""
    targets = self._types

    def _matches(value):
      if isinstance(value, targets):
        return True
      if not isinstance(value, (list, tuple)):
        return False
      return any(isinstance(item, targets) for item in value)

    return any(
        _matches(value) for value in itertools.chain(args, kwargs.values()))

  def handle(self, args, kwargs):
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)
# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator declaring that a Python function overrides `op` for some types.

  The decorated function is called in place of `op` whenever any argument or
  keyword argument (including elements of lists or tuples) is an instance of
  one of `types`.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.

  Raises (when the decorator is applied):
    AssertionError: If the decorated function's argspec differs from `op`'s.
  """

  def decorator(func):
    func_spec = tf_inspect.getargspec(func)
    op_spec = tf_inspect.getargspec(op)
    if func_spec != op_spec:
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    _TypeBasedDispatcher(func, types).register(op)
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield
def add_fallback_dispatch_list(target):
  """Decorator that gives `target` an empty fallback dispatch list.

  Args:
    target: The op (Python function) to enable fallback dispatch for.

  Returns:
    `target` itself, with its dispatch-list attribute set.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  already_enabled = hasattr(target, FALLBACK_DISPATCH_ATTR)
  if already_enabled:
    raise AssertionError("%s already has a dispatch list" % target)
  setattr(target, FALLBACK_DISPATCH_ATTR, list())
  return target


# Alias kept for backwards compatibility.
add_dispatch_list = add_fallback_dispatch_list
################################################################################
# Type-based Dispatch
################################################################################


@tf_export("experimental.dispatch_for_api")
def dispatch_for_api(api, *signatures):
  """Decorator that overrides the default implementation for a TensorFlow API.

  The decorated function (known as the "dispatch target") will override the
  default implementation for the API when the API is called with parameters that
  match a specified type signature. Signatures are specified using dictionaries
  that map parameter names to type annotations. E.g., in the following example,
  `masked_add` will be called for `tf.add` if both `x` and `y` are
  `MaskedTensor`s:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
  ... def masked_add(x, y, name=None):
  ...   return MaskedTensor(x.values + y.values, x.mask & y.mask)

  >>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
  >>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
  values=[11 12], mask=[ True False]

  If multiple type signatures are specified, then the dispatch target will be
  called if any of the signatures match. For example, the following code
  registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
  a `MaskedTensor`.

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y':MaskedTensor})
  ... def masked_add(x, y):
  ...   x_values = x.values if isinstance(x, MaskedTensor) else x
  ...   x_mask = x.mask if isinstance(x, MaskedTensor) else True
  ...   y_values = y.values if isinstance(y, MaskedTensor) else y
  ...   y_mask = y.mask if isinstance(y, MaskedTensor) else True
  ...   return MaskedTensor(x_values + y_values, x_mask & y_mask)

  The type annotations in type signatures may be type objects (e.g.,
  `MaskedTensor`), `typing.List` values, or `typing.Union` values. For
  example, the following will register `masked_concat` to be called if `values`
  is a list of `MaskedTensor` values:

  >>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
  ... def masked_concat(values, axis):
  ...   return MaskedTensor(tf.concat([v.values for v in values], axis),
  ...                       tf.concat([v.mask for v in values], axis))

  Each type signature must contain at least one subclass of `tf.CompositeTensor`
  (which includes subclasses of `tf.ExtensionType`), and dispatch will only be
  triggered if at least one type-annotated parameter contains a
  `CompositeTensor` value. This rule avoids invoking dispatch in degenerate
  cases, such as the following examples:

  * `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
    dispatch to the decorated dispatch target when the user calls
    `tf.concat([])`.

  * `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
    Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
    target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.

  The dispatch target's signature must match the signature of the API that is
  being overridden. In particular, parameters must have the same names, and
  must occur in the same order. The dispatch target may optionally elide the
  "name" parameter, in which case it will be wrapped with a call to
  `tf.name_scope` when appropriate.

  Args:
    api: The TensorFlow API to override.
    *signatures: Dictionaries mapping parameter names or indices to type
      annotations, specifying when the dispatch target should be called. In
      particular, the dispatch target will be called if any signature matches;
      and a signature matches if all of the specified parameters have types that
      match with the indicated type annotations. If no signatures are
      specified, then a signature will be read from the dispatch target
      function's type annotations.

  Returns:
    A decorator that overrides the default implementation for `api`.

  #### Registered APIs

  The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:

  <<API_LIST>>
  """
  dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
  if dispatcher is None:
    raise ValueError(f"{api} does not support dispatch.")

  api_signature = tf_inspect.signature(api)
  # Explicit signatures are validated eagerly, when `dispatch_for_api` is
  # called, rather than when the decorator is applied.
  signature_checkers = [
      _make_signature_checker(api_signature, sig) for sig in signatures
  ]

  def decorator(dispatch_target):
    """Registers `dispatch_target` as an override for `api`.

    Returns the registered function, which may be a wrapper that consumes the
    API's `name` argument (see `_add_name_scope_wrapper`).
    """
    if not callable(dispatch_target):
      raise TypeError("Expected dispatch_target to be callable; "
                      f"got {dispatch_target!r}")
    dispatch_target = _add_name_scope_wrapper(dispatch_target, api_signature)
    _check_signature(api_signature, dispatch_target)

    if signature_checkers:
      for checker in signature_checkers:
        dispatcher.Register(checker, dispatch_target)
      _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].extend(signatures)
    else:
      # No explicit signatures were given: derive one from the target's own
      # parameter annotations. The registry entry is created first, as it
      # would be for an explicit (empty) registration.
      recorded = _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target]
      inferred = _signature_from_annotations(dispatch_target)
      inferred_checker = _make_signature_checker(api_signature, inferred)
      dispatcher.Register(inferred_checker, dispatch_target)
      recorded.append(inferred)

    return dispatch_target

  return decorator
# Registry of type-based dispatch signatures: a nested dict mapping
# `api_func` -> `dispatch_target` -> `List[signature]`. It can be used for
# documentation generation and for better error messages when an API is
# called with unsupported types.
_TYPE_BASED_DISPATCH_SIGNATURES = dict()
def apis_with_type_based_dispatch():
  """Returns a list of TensorFlow APIs that support type-based dispatch.

  The list is sorted by each API's qualified name (`module.name`).
  """

  def qualified_name(api):
    return f"{api.__module__}.{api.__name__}"

  return sorted(_TYPE_BASED_DISPATCH_SIGNATURES, key=qualified_name)
def type_based_dispatch_signatures_for(cls):
  """Returns the dispatch signatures registered for a given class.

  This function is intended for documentation-generation purposes.

  Args:
    cls: The class to search for. Type signatures are searched recursively, so
      e.g., if `cls=RaggedTensor`, then information will be returned for all
      dispatch targets that have `RaggedTensor` anywhere in their type
      annotations (including nested in `typing.Union` or `typing.List`).

  Returns:
    A `dict` mapping `api` -> `signatures`, where `api` is a TensorFlow API
    function, and `signatures` is the list of its dispatch signatures that
    mention `cls`. (Each signature is a dict mapping argument names to type
    annotations; see `dispatch_for_api` for more info.)
  """

  def contains_cls(annotation):
    """Returns True if `annotation` (or a signature dict) mentions `cls`."""
    if isinstance(annotation, dict):
      return any(contains_cls(value) for value in annotation.values())
    if annotation is cls:
      return True
    if (type_annotations.is_generic_list(annotation) or
        type_annotations.is_generic_union(annotation)):
      nested = type_annotations.get_generic_type_args(annotation)
      return any(contains_cls(arg) for arg in nested)
    return False

  matches = {}
  for api, targets in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    for signatures in targets.values():
      relevant = [sig for sig in signatures if contains_cls(sig)]
      if relevant:
        matches.setdefault(api, []).extend(relevant)
  return matches
# TODO(edloper): Consider using a mechanism like this to automatically add
# the `name` argument to all TensorFlow APIs that are implemented in Python
# (so each Python function doesn't need to do it manually).
def _add_name_scope_wrapper(func, api_signature):
  """Wraps `func` to accept a "name" arg and use it to call `ops.name_scope`.

  If `api_signature` has no "name" parameter, or `func` already accepts one
  (either explicitly or through `**kwargs`), then `func` is returned as-is.

  Args:
    func: The function to wrap. Its signature must match `api_signature`,
      except that the "name" parameter may be missing.
    api_signature: The signature of the original API (used to find the
      position of the "name" parameter).

  Returns:
    The wrapped function (or the original function if no wrapping is needed).
  """
  if "name" not in api_signature.parameters:
    return func  # The API has no "name" parameter, so there is nothing to add.

  func_signature = tf_inspect.signature(func)
  has_var_keywords = tf_inspect.getargspec(func).keywords is not None
  if has_var_keywords or "name" in func_signature.parameters:
    return func  # `func` can already receive `name`.

  # Position of "name" in the API's signature, for positional calls.
  name_index = list(api_signature.parameters).index("name")

  def wrapped_func(*args, **kwargs):
    if len(args) > name_index:
      # `name` was passed positionally; strip it before calling `func`.
      name = args[name_index]
      args = args[:name_index] + args[name_index + 1:]
    else:
      name = kwargs.pop("name", None)

    if name is None:
      return func(*args, **kwargs)
    with ops.name_scope(name):
      return func(*args, **kwargs)

  wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
  # Advertise the API's "name" parameter, appended after `func`'s own.
  extended_params = list(func_signature.parameters.values())
  extended_params.append(api_signature.parameters["name"])
  wrapped_func.__signature__ = func_signature.replace(
      parameters=extended_params)
  # NOTE(review): presumably removed so introspection uses `__signature__`
  # rather than unwrapping back to `func` -- confirm.
  del wrapped_func._tf_decorator
  return wrapped_func
@tf_export("experimental.unregister_dispatch_for")
def unregister_dispatch_for(dispatch_target):
  """Unregisters a function that was registered with `@dispatch_for_*`.

  This is primarily intended for testing purposes.

  Example:

  >>> # Define a type and register a dispatcher to override `tf.abs`:
  >>> class MyTensor(tf.experimental.ExtensionType):
  ...   value: tf.Tensor
  >>> @tf.experimental.dispatch_for_api(tf.abs)
  ... def my_abs(x: MyTensor):
  ...   return MyTensor(tf.abs(x.value))
  >>> tf.abs(MyTensor(5))
  MyTensor(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>)

  >>> # Unregister the dispatcher, so `tf.abs` no longer calls `my_abs`.
  >>> unregister_dispatch_for(my_abs)
  >>> tf.abs(MyTensor(5))
  Traceback (most recent call last):
  ...
  ValueError: Attempt to convert a value ... to a Tensor.

  Args:
    dispatch_target: The function to unregister.

  Raises:
    ValueError: If `dispatch_target` was not registered using
      `@dispatch_for_api`, `@dispatch_for_unary_elementwise_apis`, or
      `@dispatch_for_binary_elementwise_apis`.
  """
  found = False

  # Targets registered directly with `@dispatch_for_api`.
  for api, targets in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    if dispatch_target in targets:
      getattr(api, TYPE_BASED_DISPATCH_ATTR).Unregister(dispatch_target)
      del targets[dispatch_target]
      found = True

  # Handlers registered with `@dispatch_for_*_elementwise_apis`: each one has
  # a per-API dispatch target recorded in `_ELEMENTWISE_API_TARGETS`, and
  # those targets are unregistered recursively.
  matching_keys = {
      key for key, handler in _ELEMENTWISE_API_HANDLERS.items()
      if handler is dispatch_target
  }
  for key in matching_keys:
    for _, api_target in _ELEMENTWISE_API_TARGETS[key]:
      unregister_dispatch_for(api_target)
    del _ELEMENTWISE_API_HANDLERS[key]
    del _ELEMENTWISE_API_TARGETS[key]
    found = True

  if not found:
    raise ValueError(f"Function {dispatch_target} was not registered using "
                     "a `@dispatch_for_*` decorator.")
def register_dispatchable_type(cls):
  """Class decorator that makes `cls` usable with type-based dispatch.

  Do *not* use this with subclasses of `CompositeTensor` or `ExtensionType`,
  which are registered automatically.

  Note: this function exists to support internal legacy use cases (such as
  RaggedTensorValue), and will probably not be exposed as a public API.

  Args:
    cls: The class to register.

  Returns:
    `cls`, unchanged.
  """
  _api_dispatcher.register_dispatchable_type(cls)
  return cls
def add_type_based_api_dispatcher(target):
  """Attaches a `PythonAPIDispatcher` to the TensorFlow API function `target`.

  If the unwrapped API function takes `*args` or `**kwargs`, then `target` is
  returned unchanged and no dispatcher is attached.

  Args:
    target: The TensorFlow API function.

  Returns:
    `target`.

  Raises:
    ValueError: If `target` already has a type-based API dispatcher.
  """
  if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
    raise ValueError(f"{target} already has a type-based API dispatcher.")

  _, unwrapped = tf_decorator.unwrap(target)
  argspec = tf_inspect.getargspec(unwrapped)
  takes_var_args = argspec.varargs or argspec.keywords
  if takes_var_args:
    # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
    # and keywords. Examples of APIs that take varargs and kwargs: meshgrid,
    # einsum, map_values, map_flat_values.
    return target

  api_dispatcher = _api_dispatcher.PythonAPIDispatcher(
      unwrapped.__name__, argspec.args, argspec.defaults)
  setattr(target, TYPE_BASED_DISPATCH_ATTR, api_dispatcher)
  _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
  return target
def _check_signature(api_signature, func):
  """Checks that a dispatch target's signature is compatible with an API.

  Two signatures are compatible if they have the same number of parameters,
  and corresponding parameters have the same `name` and `kind`. Default
  values and annotations are not compared. A target whose signature is
  exactly `(*args, **kwargs)` is always accepted.

  Args:
    api_signature: The signature of the TensorFlow API.
    func: The dispatch target.

  Raises:
    ValueError: If the signatures are incompatible.
  """
  func_argspec = tf_inspect.getargspec(func)
  accepts_anything = (func_argspec.varargs is not None and
                      func_argspec.keywords is not None and
                      not func_argspec.args)
  if accepts_anything:
    return

  func_signature = tf_inspect.signature(func)
  api_params = list(api_signature.parameters.values())
  func_params = list(func_signature.parameters.values())
  compatible = len(api_params) == len(func_params) and all(
      api_param.name == func_param.name and api_param.kind == func_param.kind
      for api_param, func_param in zip(api_params, func_params))
  if not compatible:
    raise ValueError(f"Dispatch function's signature {func_signature} does "
                     f"not match API's signature {api_signature}.")
def _make_signature_checker(api_signature, signature):
  """Builds a PySignatureChecker for the given type signature.

  Args:
    api_signature: The `inspect.Signature` of the API whose signature is
      being checked.
    signature: Dictionary mapping parameter names (or in-range integer
      positions) to type annotations.

  Returns:
    A `PySignatureChecker`.

  Raises:
    TypeError: If `signature` is not a dict keyed by `str`/`int`.
    ValueError: If `signature` names an unknown parameter, or a parameter
      that is not positional.
  """
  if not (isinstance(signature, dict) and
          all(isinstance(k, (str, int)) for k in signature)):
    raise TypeError("signatures must be dictionaries mapping parameter names "
                    "to type annotations.")

  param_names = list(api_signature.parameters)
  checkers = []
  for key, annotation in signature.items():
    param_name = key
    # Translate an integer position into the corresponding parameter name.
    if isinstance(key, int) and key < len(api_signature.parameters):
      param_name = param_names[key]

    param = api_signature.parameters.get(param_name, None)
    if param is None:
      raise ValueError("signature includes annotation for unknown "
                       f"parameter {param_name!r}.")
    positional_kinds = (tf_inspect.Parameter.POSITIONAL_ONLY,
                        tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
    if param.kind not in positional_kinds:
      raise ValueError("Dispatch currently only supports type annotations "
                       "for positional parameters; can't handle annotation "
                       f"for {param.kind!r} parameter {param_name}.")

    type_checker = make_type_checker(annotation)
    checkers.append((param_names.index(param_name), type_checker))

  return _api_dispatcher.PySignatureChecker(checkers)
# Cache of InstanceTypeChecker objects, keyed by a type (or by a tuple of
# types for unions). Only one checker is created per key, because each
# checker keeps an internal cache that avoids repeated calls back into
# Python's `isinstance`.
_is_instance_checker_cache = dict()
def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  Supported annotations are type objects, `None` (checked as `type(None)`),
  generic unions such as `typing.Union[...]`, and generic lists such as
  `typing.List[...]` with exactly one type parameter. Instance checkers are
  cached in `_is_instance_checker_cache`.

  Args:
    annotation: The type annotation to build a checker for.

  Returns:
    A type checker built by `_api_dispatcher`.

  Raises:
    AssertionError: If a list annotation does not have exactly one type
      parameter.
    ValueError: If the annotation is not supported.
  """
  if type_annotations.is_generic_union(annotation):
    members = type_annotations.get_generic_type_args(annotation)

    # Two or more plain types in the union share a single InstanceChecker,
    # keyed by the tuple of those types in a canonical (id-sorted) order.
    plain_types = tuple(
        sorted((t for t in members if isinstance(t, type)), key=id))
    if len(plain_types) > 1:
      if plain_types not in _is_instance_checker_cache:
        _is_instance_checker_cache[plain_types] = (
            _api_dispatcher.MakeInstanceChecker(*plain_types))
      alternatives = [_is_instance_checker_cache[plain_types]]
      alternatives.extend(
          make_type_checker(t) for t in members if not isinstance(t, type))
    else:
      alternatives = [make_type_checker(t) for t in members]
    return _api_dispatcher.MakeUnionChecker(alternatives)

  if type_annotations.is_generic_list(annotation):
    members = type_annotations.get_generic_type_args(annotation)
    if len(members) != 1:
      raise AssertionError("Expected List[...] to have a single type parameter")
    return _api_dispatcher.MakeListChecker(make_type_checker(members[0]))

  if isinstance(annotation, type):
    if annotation not in _is_instance_checker_cache:
      _is_instance_checker_cache[annotation] = (
          _api_dispatcher.MakeInstanceChecker(annotation))
    return _is_instance_checker_cache[annotation]

  if annotation is None:
    return make_type_checker(type(None))

  raise ValueError(f"Type annotation {annotation} is not currently supported"
                   " by dispatch. Supported annotations: type objects, "
                   " List[...], and Union[...]")
def _signature_from_annotations(func):
  """Builds a dispatch signature from `func`'s parameter annotations.

  Args:
    func: The dispatch target whose annotations should be used.

  Returns:
    A dict mapping the name of each annotated parameter of `func` to its
    annotation. Unannotated parameters are omitted.

  Raises:
    ValueError: If none of `func`'s parameters has a type annotation.
  """
  func_signature = tf_inspect.signature(func)

  # `Parameter.empty` is a sentinel class, so test it by identity rather than
  # with `!=` (which could invoke an arbitrary annotation's `__ne__`).
  signature = {
      name: param.annotation
      for name, param in func_signature.parameters.items()
      if param.annotation is not tf_inspect.Parameter.empty
  }
  if not signature:
    raise ValueError("The dispatch_for_api decorator must be called with at "
                     "least one signature, or applied to a function that "
                     "has type annotations on its parameters.")
  return signature
# Registries for elementwise APIs and elementwise API handlers:
#
# _*_ELEMENTWISE_APIS: Lists of TensorFlow APIs that have been registered as
#   elementwise operations using the `register_*_elementwise_api` decorators.
#
# _ELEMENTWISE_API_HANDLERS: Dict mapping argument type(s) to the API handler
#   registered for them with a `dispatch_for_*_elementwise_apis` decorator.
#
# _ELEMENTWISE_API_TARGETS: Dict mapping argument type(s) to lists of
#   `(api, dispatch_target)` pairs. Used by `unregister_dispatch_for` to
#   unregister the per-API dispatch targets created for a handler.
_UNARY_ELEMENTWISE_APIS = list()
_BINARY_ELEMENTWISE_APIS = list()
_BINARY_ELEMENTWISE_ASSERT_APIS = list()
_ELEMENTWISE_API_HANDLERS = dict()
_ELEMENTWISE_API_TARGETS = dict()

_ASSERT_API_TAG = "ASSERT_API_TAG"
@tf_export("experimental.dispatch_for_unary_elementwise_apis")
def dispatch_for_unary_elementwise_apis(x_type):
  """Decorator to override default implementation for unary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any unary elementwise API whenever the value
  for the first argument (typically named `x`) matches the type annotation
  `x_type`. The elementwise api handler is called with two arguments:

    `elementwise_api_handler(api_func, x)`

  Where `api_func` is a function that takes a single parameter and performs the
  elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the
  elementwise api.

  The following example shows how this decorator can be used to update all
  unary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_unary_elementwise_apis(MaskedTensor)
  ... def unary_elementwise_api_handler(api_func, x):
  ...   return MaskedTensor(api_func(x.values), x.mask)
  >>> mt = MaskedTensor([1, -2, -3], [True, False, True])
  >>> abs_mt = tf.abs(mt)
  >>> print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
  values=[1 2 3], mask=[ True False True]

  For unary elementwise operations that take extra arguments beyond `x`, those
  arguments are *not* passed to the elementwise api handler, but are
  automatically added when `api_func` is called. E.g., in the following
  example, the `dtype` parameter is not passed to
  `unary_elementwise_api_handler`, but is added by `api_func`.

  >>> ones_mt = tf.ones_like(mt, dtype=tf.float32)
  >>> print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
  values=[1.0 1.0 1.0], mask=[ True False True]

  Args:
    x_type: A type annotation indicating when the api handler should be called.
      See `dispatch_for_api` for a list of supported annotation types.

  Returns:
    A decorator.

  #### Registered APIs

  The unary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    key = (x_type,)
    if key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A unary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[key]}) "
                       f"has already been registered for {x_type}.")
    _ELEMENTWISE_API_HANDLERS[key] = handler
    # Apply the handler to every unary elementwise API registered so far.
    for api in _UNARY_ELEMENTWISE_APIS:
      _add_dispatch_for_unary_elementwise_api(api, x_type, handler)
    return handler

  return decorator
@tf_export("experimental.dispatch_for_binary_elementwise_apis")
def dispatch_for_binary_elementwise_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any binary elementwise API whenever the value
  for the first two arguments (typically named `x` and `y`) match the specified
  type annotations. The elementwise api handler is called with three arguments:

    `elementwise_api_handler(api_func, x, y)`

  Where `x` and `y` are the first two arguments to the elementwise api, and
  `api_func` is a TensorFlow function that takes two parameters and performs the
  elementwise operation (e.g., `tf.add`).

  The following example shows how this decorator can be used to update all
  binary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_api_handler(api_func, x, y):
  ...   return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
  >>> a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
  >>> b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])
  >>> c = tf.add(a, b)
  >>> print(f"values={c.values.numpy()}, mask={c.mask.numpy()}")
  values=[ 3 6 9 12 5], mask=[ True True True False False]

  Args:
    x_type: A type annotation for the first argument (`x`); the api handler
      is called when both `x_type` and `y_type` match.
    y_type: A type annotation for the second argument (`y`).

  Returns:
    A decorator.

  #### Registered APIs

  The binary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    key = (x_type, y_type)
    if key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A binary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[key]}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[key] = handler
    # Apply the handler to every binary elementwise API registered so far.
    for api in _BINARY_ELEMENTWISE_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)
    return handler

  return decorator
@tf_export("experimental.dispatch_for_binary_elementwise_assert_apis")
def dispatch_for_binary_elementwise_assert_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise assert APIs.

  The decorated function (known as the "elementwise assert handler")
  overrides the default implementation for any binary elementwise assert API
  whenever the value for the first two arguments (typically named `x` and `y`)
  match the specified type annotations.  The handler is called with two
  arguments:

    `elementwise_assert_handler(assert_func, x, y)`

  Where `x` and `y` are the first two arguments to the binary elementwise assert
  operation, and `assert_func` is a TensorFlow function that takes two
  parameters and performs the elementwise assert operation (e.g.,
  `tf.debugging.assert_equal`).

  The following example shows how this decorator can be used to update all
  binary elementwise assert operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_assert_api_handler(assert_func, x, y):
  ...   merged_mask = tf.logical_and(x.mask, y.mask)
  ...   selected_x_values = tf.boolean_mask(x.values, merged_mask)
  ...   selected_y_values = tf.boolean_mask(y.values, merged_mask)
  ...   assert_func(selected_x_values, selected_y_values)
  >>> a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
  >>> b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
  >>> tf.debugging.assert_equal(a, b) # assert passed; no exception was thrown

  >>> a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
  >>> b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
  >>> tf.debugging.assert_greater(a, b)
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Condition x > y did not hold.

  Args:
    x_type: A type annotation for the first argument (`x`), indicating when the
      assert handler should be called.
    y_type: A type annotation for the second argument (`y`), indicating when
      the assert handler should be called.

  Returns:
    A decorator.

  #### Registered APIs

  The binary elementwise assert APIs are:

  <<API_LIST>>
  """
  # The trailing tag keeps assert handlers separate from ordinary binary
  # elementwise handlers registered for the same (x_type, y_type) pair.
  handler_key = (x_type, y_type, _ASSERT_API_TAG)

  def decorator(handler):
    if handler_key in _ELEMENTWISE_API_HANDLERS:
      previous = _ELEMENTWISE_API_HANDLERS[handler_key]
      raise ValueError("A binary elementwise assert dispatch handler "
                       f"({previous}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[handler_key] = handler
    # NOTE(review): `_add_dispatch_for_binary_elementwise_api` records the
    # generated targets in `_ELEMENTWISE_API_TARGETS` under `(x_type, y_type)`,
    # not under `handler_key`. Confirm that unregistration looks them up
    # consistently.
    for assert_api in _BINARY_ELEMENTWISE_ASSERT_APIS:
      _add_dispatch_for_binary_elementwise_api(assert_api, x_type, y_type,
                                               handler)
    return handler

  return decorator
def register_unary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a unary elementwise API.

  Unary handlers that were registered earlier (keyed by a 1-tuple
  `(x_type,)`) are attached to `func` immediately.

  Args:
    func: The function that implements the unary elementwise API.

  Returns:
    `func`, unchanged.
  """
  _UNARY_ELEMENTWISE_APIS.append(func)
  unary_handlers = [(key[0], handler)
                    for key, handler in _ELEMENTWISE_API_HANDLERS.items()
                    if len(key) == 1]
  for x_type, handler in unary_handlers:
    _add_dispatch_for_unary_elementwise_api(func, x_type, handler)
  return func
def register_binary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise API.

  Binary handlers that were registered earlier (keyed by `(x_type, y_type)`)
  are attached to `func` immediately. Assert handlers have a third key element
  and are skipped.

  Args:
    func: The function that implements the binary elementwise API.

  Returns:
    `func`, unchanged.
  """
  _BINARY_ELEMENTWISE_APIS.append(func)
  binary_handlers = [(key, handler)
                     for key, handler in _ELEMENTWISE_API_HANDLERS.items()
                     if len(key) == 2]
  for (x_type, y_type), handler in binary_handlers:
    _add_dispatch_for_binary_elementwise_api(func, x_type, y_type, handler)
  return func
def register_binary_elementwise_assert_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise assert API.

  Unlike `register_binary_elementwise_api`, this decorator is used for assert
  APIs, such as assert_equal, assert_none_equal, etc, which return None in
  eager mode and an op in graph mode.  Only handlers whose key is
  `(x_type, y_type, _ASSERT_API_TAG)` are attached to `func`.

  Args:
    func: The function that implements the binary elementwise assert API.

  Returns:
    `func`
  """
  _BINARY_ELEMENTWISE_ASSERT_APIS.append(func)
  for key, handler in _ELEMENTWISE_API_HANDLERS.items():
    is_assert_handler = len(key) == 3 and key[2] is _ASSERT_API_TAG
    if not is_assert_handler:
      continue
    x_type, y_type = key[0], key[1]
    _add_dispatch_for_binary_elementwise_api(func, x_type, y_type, handler)
  return func
def unary_elementwise_apis():
  """Returns a tuple of APIs that have been registered as unary elementwise.

  The result is a snapshot copy; later registrations do not change it.
  """
  return tuple(_UNARY_ELEMENTWISE_APIS)
def binary_elementwise_apis():
  """Returns a tuple of APIs that have been registered as binary elementwise.

  The result is a snapshot copy; later registrations do not change it.
  """
  return tuple(_BINARY_ELEMENTWISE_APIS)
def _add_dispatch_for_unary_elementwise_api(api, x_type,
                                            elementwise_api_handler):
  """Registers a unary elementwise handler as a dispatcher for a given API.

  A dispatch target is registered (via `dispatch_for_api`) for `api`'s first
  parameter with type annotation `x_type`.  When it fires, it calls
  `elementwise_api_handler(tensor_api, x)`, where `tensor_api` is a
  one-argument callable forwarding to `api` with any other arguments bound.
  If a `name` argument was supplied, the handler runs inside that name scope.

  Args:
    api: The API to add a dispatch target for.
    x_type: Type annotation that the first argument must match.
    elementwise_api_handler: The handler to invoke on a match.
  """
  signature = tf_inspect.signature(api)
  param_names = list(signature.parameters)
  x_name = param_names[0]
  name_index = _find_name_index(signature)

  # `api` can be handed to the handler directly only when its parameters are
  # at most `x` and `name`; otherwise extra arguments are bound via a closure.
  need_to_bind_api_args = len(param_names) > 2 or "name" not in param_names

  @dispatch_for_api(api, {x_name: x_type})
  def dispatch_target(*args, **kwargs):
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    if not args:
      x = kwargs.pop(x_name)
    else:
      x, args = args[0], args[1:]

    if not need_to_bind_api_args:
      tensor_api = api
    else:
      tensor_api = lambda v: api(v, *args, **kwargs)

    if name is None:
      return elementwise_api_handler(tensor_api, x)
    with ops.name_scope(name, None, [x]):
      return elementwise_api_handler(tensor_api, x)

  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Remember the generated target so that it can be unregistered later.
  _ELEMENTWISE_API_TARGETS.setdefault((x_type,), []).append(
      (api, dispatch_target))
def _add_dispatch_for_binary_elementwise_api(api, x_type, y_type,
                                             elementwise_api_handler):
  """Registers a binary elementwise handler as a dispatcher for a given API.

  A dispatch target is registered (via `dispatch_for_api`) for `api`'s first
  two parameters with type annotations `x_type` and `y_type`.  When it fires,
  it calls `elementwise_api_handler(tensor_api, x, y)`, where `tensor_api` is a
  two-argument callable forwarding to `api` with any other arguments bound.
  If a `name` argument was supplied, the handler runs inside that name scope.

  Args:
    api: The API to add a dispatch target for.
    x_type: Type annotation that the first argument must match.
    y_type: Type annotation that the second argument must match.
    elementwise_api_handler: The handler to invoke on a match.
  """
  signature = tf_inspect.signature(api)
  param_names = list(signature.parameters)
  x_name, y_name = param_names[:2]
  name_index = _find_name_index(signature)

  # `api` can be handed to the handler directly only when its parameters are
  # at most `x`, `y` and `name`; otherwise extra arguments are bound.
  need_to_bind_api_args = len(param_names) > 3 or "name" not in param_names

  @dispatch_for_api(api, {x_name: x_type, y_name: y_type})
  def dispatch_target(*args, **kwargs):
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    num_positional = len(args)
    # Operands missing from `args` are taken (and removed) from `kwargs`.
    x = args[0] if num_positional > 0 else kwargs.pop(x_name, None)
    y = args[1] if num_positional > 1 else kwargs.pop(y_name, None)
    # Whatever follows `x` and `y` positionally is forwarded to `api`.
    args = args[2:]

    if not need_to_bind_api_args:
      tensor_api = api
    else:
      tensor_api = lambda v1, v2: api(v1, v2, *args, **kwargs)

    if name is None:
      return elementwise_api_handler(tensor_api, x, y)
    with ops.name_scope(name, None, [x, y]):
      return elementwise_api_handler(tensor_api, x, y)

  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Remember the generated target so that it can be unregistered later.
  # NOTE(review): assert handlers also go through this function, so their
  # targets are recorded under (x_type, y_type) as well. Confirm this matches
  # how the unregistration code looks them up.
  _ELEMENTWISE_API_TARGETS.setdefault((x_type, y_type), []).append(
      (api, dispatch_target))
1046def _find_name_index(signature):
1047 """Returns the index of the `name` parameter, or -1 if it's not present."""
1048 try:
1049 return list(signature.parameters).index("name")
1050 except ValueError:
1051 return -1
1054def _extract_name_arg(args, kwargs, name_index):
1055 """Extracts the parameter `name` and returns `(args, kwargs, name_value)`."""
1056 if name_index < 0:
1057 name_value = None
1058 elif name_index < len(args):
1059 name_value = args[name_index]
1060 args = args[:name_index] + args[name_index + 1:]
1061 else:
1062 name_value = kwargs.pop("name", None)
1063 return args, kwargs, name_value
def update_docstrings_with_api_lists():
  """Updates the docstrings of dispatch decorators with API lists.

  Updates docstrings for `dispatch_for_api`,
  `dispatch_for_unary_elementwise_apis`, `dispatch_for_binary_elementwise_apis`
  and `dispatch_for_binary_elementwise_assert_apis`, by replacing the string
  '<<API_LIST>>' with a list of APIs that have been registered for that
  decorator.
  """
  decorators_and_apis = (
      (dispatch_for_unary_elementwise_apis, _UNARY_ELEMENTWISE_APIS),
      (dispatch_for_binary_elementwise_apis, _BINARY_ELEMENTWISE_APIS),
      (dispatch_for_binary_elementwise_assert_apis,
       _BINARY_ELEMENTWISE_ASSERT_APIS),
      (dispatch_for_api, _TYPE_BASED_DISPATCH_SIGNATURES),
  )
  for dispatch_decorator, registered_apis in decorators_and_apis:
    _update_docstring_with_api_list(dispatch_decorator, registered_apis)
1084def _update_docstring_with_api_list(target, api_list):
1085 """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs."""
1086 lines = []
1087 for func in api_list:
1088 name = tf_export_lib.get_canonical_name_for_symbol(
1089 func, add_prefix_to_v1_names=True)
1090 if name is not None:
1091 params = tf_inspect.signature(func).parameters.keys()
1092 lines.append(f" * `tf.{name}({', '.join(params)})`")
1093 lines.sort()
1094 target.__doc__ = target.__doc__.replace(" <<API_LIST>>", "\n".join(lines))
1097################################################################################
1098# Dispatch Support
1099################################################################################
@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target=None, iterable_parameters=None):
  """Decorator that adds a dispatch handling wrapper to a TensorFlow Python API.

  This wrapper adds the decorated function as an API that can be overridden
  using the `@dispatch_for_api` decorator.  In the following example, we first
  define a new API (`double`) that supports dispatch, then define a custom type
  (`MaskedTensor`) and finally use `dispatch_for_api` to override the default
  implementation of `double` when called with `MaskedTensor` values:

  >>> @add_dispatch_support
  ... def double(x):
  ...   return x * 2
  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_api(double, {'x': MaskedTensor})
  ... def masked_double(x):
  ...   return MaskedTensor(x.values * 2, x.mask)

  The optional `iterable_parameters` argument can be used to mark parameters
  that can take arbitrary iterable values (such as generator expressions).
  These need to be handled specially during dispatch, since just iterating over
  an iterable uses up its values.  In the following example, we define a new
  API whose second argument can be an iterable value; and then override the
  default implementation of that API when the iterable contains MaskedTensors:

  >>> @add_dispatch_support(iterable_parameters=['ys'])
  ... def add_tensor_to_list_of_tensors(x, ys):
  ...   return [x + y for y in ys]
  >>> @dispatch_for_api(add_tensor_to_list_of_tensors,
  ...               {'ys': typing.List[MaskedTensor]})
  ... def masked_add_tensor_to_list_of_tensors(x, ys):
  ...   return [MaskedTensor(x+y.values, y.mask) for y in ys]

  (Note: the only TensorFlow API that currently supports iterables is `add_n`.)

  Args:
    target: The TensorFlow API that should support dispatch.  If `None`, a
      decorator is returned instead (this enables the
      `@add_dispatch_support(iterable_parameters=...)` form).
    iterable_parameters: Optional list (or tuple) of parameter names that may
      be called with iterables (such as the `inputs` parameter for `tf.add_n`).
      Each name must be a positional parameter of the decorated function; the
      lookup uses `list.index`, which raises `ValueError` otherwise.

  Returns:
    If `target` is `None`, a decorator.  Otherwise, the dispatch-enabled
    wrapper for `target`.

  Raises:
    TypeError: If `iterable_parameters` is not `None` or a list/tuple of str.
  """

  if not (iterable_parameters is None or
          (isinstance(iterable_parameters, (list, tuple)) and
           all(isinstance(p, str) for p in iterable_parameters))):
    raise TypeError("iterable_parameters should be a list or tuple of string.")

  def decorator(dispatch_target):

    # Get the name & index for each iterable parameter.
    if iterable_parameters is None:
      iterable_params = None
    else:
      arg_names = tf_inspect.getargspec(dispatch_target).args
      iterable_params = [
          (name, arg_names.index(name)) for name in iterable_parameters
      ]

    @traceback_utils.filter_traceback
    def op_dispatch_handler(*args, **kwargs):
      """Call `dispatch_target`, performing dispatch when appropriate."""

      # Type-based dispatch system (dispatch v2):
      # `api_dispatcher` is assigned in the enclosing scope *after* this
      # function is defined; it is resolved at call time, so that is fine.
      if api_dispatcher is not None:
        if iterable_params is not None:
          # Materialize iterables so that checking them for dispatch does not
          # consume values that the API itself still needs.
          args, kwargs = replace_iterable_params(args, kwargs, iterable_params)
        result = api_dispatcher.Dispatch(args, kwargs)
        if result is not NotImplemented:
          return result

      # Fallback dispatch system (dispatch v1):
      try:
        return dispatch_target(*args, **kwargs)
      except (TypeError, ValueError):
        # Note: convert_to_eager_tensor currently raises a ValueError, not a
        # TypeError, when given unexpected types.  So we need to catch both.
        # `op_dispatch_handler` here is looked up at call time, so it refers
        # to the decorated wrapper that is rebound below.
        result = dispatch(op_dispatch_handler, args, kwargs)
        if result is not OpDispatcher.NOT_SUPPORTED:
          return result
        else:
          raise

    add_fallback_dispatch_list(op_dispatch_handler)
    # Rebinding the name replaces the raw handler with a wrapper that carries
    # `dispatch_target`'s metadata; the closure above sees this new binding.
    op_dispatch_handler = tf_decorator.make_decorator(dispatch_target,
                                                      op_dispatch_handler)
    add_type_based_api_dispatcher(op_dispatch_handler)
    # Read back the type-based dispatcher (if one was attached) for use by the
    # v2 dispatch path in `op_dispatch_handler`.
    api_dispatcher = getattr(op_dispatch_handler, TYPE_BASED_DISPATCH_ATTR,
                             None)
    return op_dispatch_handler

  if target is None:
    return decorator
  else:
    return decorator(target)
def replace_iterable_params(args, kwargs, iterable_params):
  """Returns (args, kwargs) with any iterable parameters converted to lists.

  Converting to lists lets each value be iterated more than once (e.g. once
  while checking dispatch signatures and again by the API itself).

  Args:
    args: Positional arguments to a function.
    kwargs: Keyword arguments to a function.  This dict is not modified; an
      updated copy is returned instead.
    iterable_params: A list of (name, index) tuples for iterable parameters.

  Returns:
    A tuple (args, kwargs), where any positional or keyword parameters in
    `iterable_params` have their value converted to a `list`.
  """
  args = list(args)
  # Work on a copy so that the caller's dict is not mutated as a side effect.
  kwargs = dict(kwargs)
  for name, index in iterable_params:
    if index < len(args):
      args[index] = list(args[index])
    elif name in kwargs:
      kwargs[name] = list(kwargs[name])
  return tuple(args), kwargs