Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py: 32%
31 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""smart_cond and related utilities."""
17from tensorflow.python.framework import ops
18from tensorflow.python.framework import tensor_util
19from tensorflow.python.ops import cond
20from tensorflow.python.ops import control_flow_case
21from tensorflow.python.util.tf_export import tf_export
@tf_export("__internal__.smart_cond.smart_cond", v1=[])
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Call `true_fn()` or `false_fn()` depending on the predicate `pred`.

  When the value of `pred` can be determined statically (it is a Python bool
  or has a constant value), exactly one of the two callables is invoked
  directly. Otherwise the decision is deferred to `tf.cond`, which receives
  both callables.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  # Both branches are validated up front, `true_fn` first, regardless of
  # which one ends up being selected.
  if not callable(true_fn):
    raise TypeError(f"Argument `true_fn` must be callable. Received {true_fn}")
  if not callable(false_fn):
    raise TypeError(
        f"Argument `false_fn` must be callable. Received {false_fn}")

  static_pred = smart_constant_value(pred)

  # No static value available: let `cond.cond` route at run time.
  if static_pred is None:
    return cond.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)

  # Static value known: only the selected branch is ever called.
  selected_fn = true_fn if static_pred else false_fn
  return selected_fn()
def smart_constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Accepted inputs are Python bools, values equal to `0` or `1`, and tensors.
  For a tensor, `tensor_util.constant_value` is tried first and
  `tensor_util.try_evaluate_constant` is used as a fallback when that yields
  None.

  Args:
    pred: A scalar, either a Python bool (or 1/0) or tensor.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Tensor, a Python bool, or 1/0. This includes
      unhashable objects such as lists, which are reported with the same
      descriptive message as any other unsupported type.
  """
  # Handle real bools first. `True == 1` and `False == 0`, so if this check
  # came after the `{0, 1}` membership test it could never be reached.
  if isinstance(pred, bool):
    return pred

  # Tensors must be handled before the membership test below, which would
  # otherwise try to hash/compare the tensor itself.
  if isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
    # TODO(skyewm): consider folding this into tensor_util.constant_value.
    # pylint: disable=protected-access
    if pred_value is None:
      pred_value = tensor_util.try_evaluate_constant(pred)
    # pylint: enable=protected-access
    return pred_value

  try:
    is_zero_or_one = pred in {0, 1}  # Accept 1/0 as valid boolean values
  except TypeError:
    # Set membership hashes `pred`; unhashable inputs (list, dict, ...) land
    # here. Treat them as unsupported so the caller sees the descriptive
    # error below instead of "unhashable type: ...".
    is_zero_or_one = False

  if is_zero_or_one:
    return bool(pred)

  raise TypeError("Argument `pred` must be a Tensor, or a Python bool, or 1 "
                  f"or 0. Received: pred={pred} of type "
                  f"{type(pred).__name__}")
def smart_case(pred_fn_pairs, default=None, exclusive=False, name="smart_case"):
  """Variant of tf.case that tries to resolve predicates statically.

  Whenever a predicate in `pred_fn_pairs` is a bool or has a constant value,
  its callable is either called or omitted according to that value. When no
  predicate can be resolved that way, this behaves like tf.case.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  # All the work is delegated to the shared case helper, with `smart_cond`
  # supplied as the conditional and Python-valued predicates allowed.
  case_helper = control_flow_case._case_helper  # pylint: disable=protected-access
  return case_helper(
      smart_cond, pred_fn_pairs, default, exclusive, name,
      allow_python_preds=True)