Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_switch_case.py: 28%
58 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Switch case for Control Flow Operations."""
17from tensorflow.python.eager import context
18from tensorflow.python.framework import ops
19from tensorflow.python.ops import array_ops
20from tensorflow.python.ops import control_flow_util as util
21from tensorflow.python.ops import gen_functional_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.util.lazy_loader import LazyLoader
24from tensorflow.python.util.tf_export import tf_export
# TODO(b/269483538): needed for references while refactors are in progress
# `cond_v2` is loaded lazily to break the circular import
# cond_v2 -> gradients_util -> control_flow_ops.
cond_v2 = LazyLoader(
    "cond_v2", globals(), "tensorflow.python.ops.cond_v2")
def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
                                               branch_index):
  """Verifies input arguments for the case function.

  Args:
    branch_fns: Dict or list of pairs of an `int` and a callable which returns a
      list of tensors, or a list/tuple of callables (in which case each
      callable's position is its key).
    default: Optional callable that returns a list of tensors.
    branch_index: An integer `Tensor` which selects the branch to execute.

  Raises:
    TypeError: If `branch_index` is not a `Tensor` or does not have an integer
      dtype.
    TypeError: If `branch_fns` is not a list/tuple/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If a key is not a Python `int`, a branch function is not
      callable, or `default` is provided but is not callable.
    ValueError: If `branch_fns` is empty, or its keys do not form the
      contiguous range `[0, len(branch_fns))`.

  Returns:
    branch_fns: validated list of callables for each branch, ordered by key,
      with `default` appended last when it is provided.
  """
  if not isinstance(branch_index, ops.Tensor):
    raise TypeError("'branch_index' must be a Tensor, got {}".format(
        type(branch_index)))
  if not branch_index.dtype.is_integer:
    raise TypeError("'branch_index' must be an integer Tensor, got {}".format(
        branch_index.dtype))

  if not branch_fns:
    raise ValueError("Must provide at least one item in 'branch_fns'")
  if not isinstance(branch_fns, (list, tuple, dict)):
    raise TypeError("'branch_fns' must be a list, tuple, or dict")
  # Validate `default` up front, as documented; otherwise a non-callable
  # default would only fail later, when the chosen branch is invoked.
  if default is not None and not callable(default):
    raise TypeError("'default' must be callable, got {}".format(type(default)))

  if isinstance(branch_fns, dict):
    branch_fns = list(branch_fns.items())

  # A plain sequence of callables is keyed implicitly by its position.
  if all(callable(fn) for fn in branch_fns):
    branch_fns = list(enumerate(branch_fns))

  for key_fn_pair in branch_fns:
    if not isinstance(key_fn_pair, tuple) or len(key_fn_pair) != 2:
      raise TypeError("Each entry in 'branch_fns' must be a 2-tuple. "
                      f"Received {key_fn_pair}.")
    key, branch_fn = key_fn_pair

    if not isinstance(key, int):
      raise TypeError("key must be a Python `int`, got {}".format(type(key)))

    if not callable(branch_fn):
      raise TypeError("fn for key {} must be callable.".format(key))

  keys = [key for key, _ in branch_fns]
  # Keys must be exactly 0..len(keys)-1: all in range and no duplicates.
  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
    raise ValueError(
        "branch indices (keys) must form contiguous range of [0 to {}) but "
        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
  # Sort on the (unique) key only so the callables are never compared.
  actions = [fn for _, fn in sorted(branch_fns, key=lambda pair: pair[0])]
  if default is not None:
    actions.append(default)
  return actions
def _indexed_case_helper(branch_fns,
                         default,
                         branch_index,
                         name,
                         lower_using_switch_merge=None):
  """Implementation of case that emits the n-way indexed Case op.

  Args:
    branch_fns: Dict or list of pairs of an `int` key and a callable which
      returns a list of tensors, or a list of callables (keyed by position).
    default: Optional callable that returns a list of tensors.
    branch_index: An integer `Tensor` which selects the branch to execute.
    name: A name for this operation (optional).
    lower_using_switch_merge: Lower this op using switch merge ops (optional).
      Forwarded to `cond_v2.indexed_case`; not used on the eager path.

  Returns:
    The tensors returned by the pair whose key matched branch_index, or
    those returned by `default` if none does. Without a `default`, an
    out-of-range index selects the branch with the largest key.

  Raises:
    TypeError: If `branch_index` is not an integer `Tensor`.
    TypeError: If `branch_fns` is not a list/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
    ValueError: If `branch_fns` is empty or its keys are not the contiguous
      range `[0, len(branch_fns))`.
  """
  branch_fns = _indexed_case_verify_and_canonicalize_args(
      branch_fns, default, branch_index)
  with ops.name_scope(name, "case", [branch_index]):
    if context.executing_eagerly() and not hasattr(branch_index, "graph"):
      # Eager fast path: call the selected Python callable directly rather
      # than emitting a Case op. Any index outside [0, len(branch_fns)) is
      # redirected to the last entry, which is `default` when one was given.
      num_branches = len(branch_fns)
      branch_index = array_ops.where(
          math_ops.less(branch_index, 0)
          | math_ops.greater_equal(branch_index, num_branches),
          num_branches - 1, branch_index)
      return branch_fns[int(branch_index)]()
    return cond_v2.indexed_case(
        branch_index,
        branch_fns,
        lower_using_switch_merge=lower_using_switch_merge)
@tf_export("__internal__.execute_fn_for_device", v1=[])
def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
  """Executes one of the provided callables based on the device placement.

  This API is used when the implementations for high level function depend on
  the underlying device placement. It takes a dictionary of device type to
  callables. The device type includes "CPU", "GPU", "TPU", etc. When the type of
  the device where to run this op matches the key in 'device_branch_fns',
  the corresponding callable is executed, falling back to 'default_fn' if none
  matches.

  **Example:**
  ```python
  def f1(): return tf.constant(1)
  def f2(): return tf.constant(2)
  r = tf.__internal__.execute_fn_for_device(
      {"CPU": f1, "GPU": f2}, default_fn=f1)
  ```
  'r' is evaluated as 1 when it runs on CPU, 2 running on GPU, 1 running on
  any other device types.


  Args:
    device_branch_fns: a dictionary of device types to the callables. Each
      callable must return a matching structure of tensors. Keys are
      upper-cased before matching, so they are case-insensitive; if two keys
      differ only in case, the one that appears later in the dict wins. If
      the dictionary is empty, `default_fn` is executed.
    default_fn: fallback callable when the underlying device does not match any
      key in the 'device_branch_fns'.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the callable identified by device type during
    execution, or those returned by 'default_fn' if no key matches.
  """
  # Always execute the default fn for XLA to avoid complicated graph by case op.
  # see more discussions in b/167276293.
  if util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    return default_fn()
  # With no device-specific branches every device falls back to `default_fn`;
  # run it directly, since the case helper rejects an empty branch list.
  if not device_branch_fns:
    return default_fn()
  device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
  devices = list(device_branch_fns_upper.keys())
  branch_fns = list(device_branch_fns_upper.values())
  device_index = gen_functional_ops.device_index(device_names=devices)
  return _indexed_case_helper(
      branch_fns,
      default_fn,
      device_index,
      name,
      lower_using_switch_merge=False)
@tf_export("switch_case")
def switch_case(branch_index, branch_fns, default=None, name="switch_case"):
  """Create a switch/case operation, i.e. an integer-indexed conditional.

  See also `tf.case`.

  This op can be substantially more efficient than `tf.case` when exactly one
  branch will be selected. `tf.switch_case` is more like a C++ switch/case
  statement than `tf.case`, which is more like an if/elif/elif/else chain.

  The `branch_fns` parameter is either a dict from `int` to callables, or list
  of (`int`, callable) pairs, or simply a list of callables (in which case the
  index is implicitly the key). The `branch_index` `Tensor` is used to select an
  element in `branch_fns` with matching `int` key, falling back to `default`
  if none match, or `max(keys)` if no `default` is provided. The keys must form
  a contiguous set from `0` to `len(branch_fns) - 1`.

  `tf.switch_case` supports nested structures as implemented in `tf.nest`. All
  callables must return the same (possibly nested) value structure of lists,
  tuples, and/or named tuples.

  **Example:**

  Pseudocode:

  ```c++
  switch (branch_index) {  // c-style switch
    case 0: return 17;
    case 1: return 31;
    default: return -1;
  }
  ```
  or
  ```python
  branches = {0: lambda: 17, 1: lambda: 31}
  branches.get(branch_index, lambda: -1)()
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(31)
  def f3(): return tf.constant(-1)
  r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
  # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
  ```

  Args:
    branch_index: An int Tensor specifying which of `branch_fns` should be
      executed.
    branch_fns: A `dict` mapping `int`s to callables, or a `list` of (`int`,
      callable) pairs, or simply a list of callables (in which case the index
      serves as the key). Each callable must return a matching structure of
      tensors.
    default: Optional callable that returns a structure of tensors.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the callable identified by `branch_index`, or those
    returned by `default` if no key matches and `default` was provided, or those
    returned by the max-keyed `branch_fn` if no `default` is provided.

  Raises:
    TypeError: If `branch_index` is not a `Tensor` or does not have an integer
      dtype.
    TypeError: If `branch_fns` is not a list/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
    ValueError: If `branch_fns` is empty, or its keys do not form the
      contiguous range `[0, len(branch_fns))`.
  """
  return _indexed_case_helper(branch_fns, default, branch_index, name)