Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_case.py: 30%
79 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Case functions for Control Flow Operations."""
17import collections
18import functools
19from tensorflow.python.eager import context
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops_stack
24from tensorflow.python.ops import cond
25from tensorflow.python.ops import control_flow_assert
26from tensorflow.python.ops import math_ops
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.util import dispatch
29from tensorflow.python.util.tf_export import tf_export
@tf_export("case", v1=[])
@dispatch.add_dispatch_support
def case_v2(pred_fn_pairs,
            default=None,
            exclusive=False,
            strict=False,
            name="case"):
  """Builds a multi-way conditional from (predicate, callable) pairs.

  See also `tf.switch_case`.

  `pred_fn_pairs` holds N pairs. The first element of each pair is a boolean
  scalar tensor; the second is a Python callable that builds the tensors to
  return when that boolean evaluates to True. `default` is a callable that
  produces a list of tensors. Every callable in `pred_fn_pairs`, and `default`
  when it is supplied, should return the same number and types of tensors.

  With `exclusive=True` every predicate is evaluated and an exception is thrown
  when more than one of them evaluates to `True`. With `exclusive=False` the
  predicates are tested in order, evaluation stops at the first one that is
  True, and the tensors produced by its callable are returned immediately. When
  no predicate evaluates to True the tensors generated by `default` are
  returned. When `default` is omitted, the callable of the last pair is used as
  the fallback and an assertion checks that its predicate is True.

  `tf.case` supports nested structures as implemented in `tf.nest`. All
  callables must return the same (possibly nested) structure of lists, tuples,
  and/or named tuples. The single exception is a singleton list or tuple: when
  a callable returns one, it is implicitly unpacked to a single value. Pass
  `strict=True` to turn this unpacking off.

  @compatibility(v2)
  In v1 `pred_fn_pairs` could be a dictionary. Since tf.Tensor and tf.Variable
  are no longer hashable in v2 they cannot serve as dictionary keys, so pass a
  list or a tuple instead.
  @end_compatibility


  **Example 1:**

  Pseudocode:

  ```
  if (x < y) return 17;
  else return 23;
  ```

  Expressions:

  ```python
  f1 = lambda: tf.constant(17)
  f2 = lambda: tf.constant(23)
  r = tf.case([(tf.less(x, y), f1)], default=f2)
  ```

  **Example 2:**

  Pseudocode:

  ```
  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
  if (x < y) return 17;
  else if (x > z) return 23;
  else return -1;
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(23)
  def f3(): return tf.constant(-1)
  r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
           default=f3, exclusive=True)
  ```

  Args:
    pred_fn_pairs: List of pairs, each made of a boolean scalar tensor and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/tuple.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  # `strict` is not consumed by the helper itself; it is forwarded untouched
  # to `cond.cond` for every nested conditional.
  return _case_helper(
      cond_fn=cond.cond,
      pred_fn_pairs=pred_fn_pairs,
      default=default,
      exclusive=exclusive,
      name=name,
      allow_python_preds=False,
      strict=strict)
@tf_export(v1=["case"])
@dispatch.add_dispatch_support
def case(pred_fn_pairs,
         default=None,
         exclusive=False,
         strict=False,
         name="case"):
  """Builds a multi-way conditional from (predicate, callable) pairs.

  See also `tf.switch_case`.

  `pred_fn_pairs` is a dict, or a list of pairs, of size N. Each entry couples
  a boolean scalar tensor with a Python callable that builds the tensors to
  return when that boolean evaluates to True. `default` is a callable that
  produces a list of tensors. Every callable in `pred_fn_pairs`, and `default`
  when it is supplied, should return the same number and types of tensors.

  With `exclusive=True` every predicate is evaluated and an exception is thrown
  when more than one of them evaluates to `True`. With `exclusive=False` the
  predicates are tested in order, evaluation stops at the first one that is
  True, and the tensors produced by its callable are returned immediately. When
  no predicate evaluates to True the tensors generated by `default` are
  returned. When `default` is omitted, the callable of the last pair is used as
  the fallback and an assertion checks that its predicate is True.

  `tf.case` supports nested structures as implemented in `tf.nest`. All
  callables must return the same (possibly nested) structure of lists, tuples,
  and/or named tuples. The single exception is a singleton list or tuple: when
  a callable returns one, it is implicitly unpacked to a single value. Pass
  `strict=True` to turn this unpacking off.

  When an unordered dictionary is passed as `pred_fn_pairs`, the order in which
  the conditions are tested is not guaranteed. The order is nevertheless
  deterministic, so variables created inside conditional branches are created
  in a fixed order across runs.

  @compatibility(eager)
  Unordered dictionaries are not supported in eager mode when `exclusive=False`.
  Use a list of tuples instead.
  @end_compatibility


  **Example 1:**

  Pseudocode:

  ```
  if (x < y) return 17;
  else return 23;
  ```

  Expressions:

  ```python
  f1 = lambda: tf.constant(17)
  f2 = lambda: tf.constant(23)
  r = tf.case([(tf.less(x, y), f1)], default=f2)
  ```

  **Example 2:**

  Pseudocode:

  ```
  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
  if (x < y) return 17;
  else if (x > z) return 23;
  else return -1;
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(23)
  def f3(): return tf.constant(-1)
  r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
           default=f3, exclusive=True)
  ```

  Args:
    pred_fn_pairs: Dict or list of pairs, each made of a boolean scalar tensor
      and a callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  # `strict` is not consumed by the helper itself; it is forwarded untouched
  # to `cond.cond` for every nested conditional.
  return _case_helper(
      cond_fn=cond.cond,
      pred_fn_pairs=pred_fn_pairs,
      default=default,
      exclusive=exclusive,
      name=name,
      allow_python_preds=False,
      strict=strict)
def _assert_at_most_n_true(predicates, n, msg):
  """Builds an Assert op that fails when more than `n` predicates are True.

  Args:
    predicates: list of bool scalar tensors.
    n: maximum number of true predicates allowed.
    msg: Error message; used as the prefix of the assertion text.

  Returns:
    The result of `control_flow_assert.Assert` for the condition
    "number of True predicates <= n".
  """
  stacked_preds = array_ops_stack.stack(predicates, name="preds_c")
  # Count the True entries by casting the stacked booleans to int32.
  true_count = math_ops.reduce_sum(
      math_ops.cast(stacked_preds, dtypes.int32), name="num_true_conds")
  limit = constant_op.constant(n, name="n_true_conds")
  within_limit = math_ops.less_equal(true_count, limit)

  # Predicates that expose no usable `name` attribute are reported as "?".
  names = [getattr(pred, "name", "?") for pred in predicates]
  header = "%s: more than %d conditions (%s) evaluated as True:" % (
      msg, n, ", ".join(names))
  return control_flow_assert.Assert(
      within_limit,
      data=[header, stacked_preds],
      summarize=len(predicates))
def _case_create_default_action(predicates, actions):
  """Promotes one of the given actions to act as the default branch.

  The last (predicate, action) pair is taken out of the inputs. Its action is
  wrapped in a callable that first asserts that none of the remaining
  predicates is True and that the chosen predicate itself is True.

  Args:
    predicates: a list of bool scalar tensors
    actions: a list of callable objects which return tensors.

  Returns:
    A 3-tuple `(default_action, other_predicates, other_actions)`: the wrapped
    callable plus the predicates and actions that remain once the chosen pair
    has been removed.
  """
  last = len(predicates) - 1  # any index would do; the last one is picked
  chosen_pred = predicates[last]
  chosen_action = actions[last]
  remaining_preds = predicates[:last]
  remaining_actions = actions[:last]

  def default_action():
    others_msg = ("Implementation error: "
                  "selected default action #%d was called, but some of other "
                  "predicates are True: " % last)
    default_msg = ("Input error: "
                   "None of conditions evaluated as True:",
                   array_ops_stack.stack(predicates, name="preds_c"))
    # Both checks are built before entering the dependency scope, in the same
    # order as they are listed, so that the action only runs after them.
    no_other_true = _assert_at_most_n_true(  # pylint: disable=protected-access
        remaining_preds, n=0, msg=others_msg)
    chosen_is_true = control_flow_assert.Assert(chosen_pred, data=default_msg)
    with ops.control_dependencies([no_other_true, chosen_is_true]):
      return chosen_action()

  return default_action, remaining_preds, remaining_actions
def _case_helper(cond_fn,
                 pred_fn_pairs,
                 default,
                 exclusive,
                 name,
                 allow_python_preds=False,
                 **cond_kwargs):
  """Shared implementation of `case`, parameterized by the cond function.

  Args:
    cond_fn: callable used for each two-way branch; it is invoked as
      `cond_fn(pred, true_fn=..., false_fn=..., **cond_kwargs)`.
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors. When None, one
      of the supplied pairs is turned into the default branch.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors
    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  predicates, actions = _case_verify_and_canonicalize_args(
      pred_fn_pairs, exclusive, name, allow_python_preds)

  with ops.name_scope(name, "case", [predicates]):
    if default is None:
      default, predicates, actions = _case_create_default_action(
          predicates, actions)

    # Build the chain from the innermost branch outwards so that predicates
    # are tested in their given order:
    #   cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...))
    chained = default
    ordered_pairs = list(zip(predicates, actions))
    for pred, branch in ordered_pairs[::-1]:
      chained = functools.partial(
          cond_fn, pred, true_fn=branch, false_fn=chained, **cond_kwargs)

    if not exclusive:
      return chained()

    # In exclusive mode the result must depend on the at-most-one-True check.
    exclusivity_check = _assert_at_most_n_true(  # pylint: disable=protected-access
        predicates, n=1, msg="Input error: exclusive=True")
    with ops.control_dependencies([exclusivity_check]):
      return chained()
def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
                                       allow_python_preds):
  """Verifies input arguments for the case function.

  Dictionaries are converted to a list of `(pred, fn)` tuples, after which
  every entry is type-checked.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for the case operation.
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/tuple/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If a predicate is not a bool Tensor (or, when
               `allow_python_preds` is true, a Python bool).
    TypeError: If `fns[i]` is not callable for any i.
    ValueError: If `pred_fn_pairs` is an unordered dict, eager mode is enabled
                and `exclusive` is False.
    ValueError: If `pred_fn_pairs` is empty.

  Returns:
    a tuple <list of scalar bool tensors, list of callables>.
  """
  if not isinstance(pred_fn_pairs, (list, tuple, dict)):
    raise TypeError("'pred_fn_pairs' must be a list, tuple, or dict. "
                    f"Received: {type(pred_fn_pairs)}")

  if isinstance(pred_fn_pairs, collections.OrderedDict):
    pred_fn_pairs = list(pred_fn_pairs.items())
  elif isinstance(pred_fn_pairs, dict):
    if context.executing_eagerly():
      # No name to sort on in eager mode. Use dictionary traversal order,
      # which is nondeterministic in versions of Python < 3.6
      if not exclusive:
        raise ValueError("Unordered dictionaries are not supported for the "
                         "'pred_fn_pairs' argument when `exclusive=False` and "
                         "eager mode is enabled.")
      pred_fn_pairs = list(pred_fn_pairs.items())
    else:
      # Sort on the predicate name so the test order is stable across runs.
      pred_fn_pairs = sorted(
          pred_fn_pairs.items(), key=lambda item: item[0].name)
      if not exclusive:
        logging.warn(
            "%s: An unordered dictionary of predicate/fn pairs was "
            "provided, but exclusive=False. The order of conditional "
            "tests is deterministic but not guaranteed.", name)

  # Without this check the final `zip(*...)` unpacking fails with an opaque
  # "not enough values to unpack" ValueError; keep the type, fix the message.
  if not pred_fn_pairs:
    raise ValueError("'pred_fn_pairs' must contain at least one "
                     "(predicate, callable) pair.")

  for pred_fn_pair in pred_fn_pairs:
    if not isinstance(pred_fn_pair, tuple) or len(pred_fn_pair) != 2:
      raise TypeError("Each entry in 'pred_fn_pairs' must be a 2-tuple. "
                      f"Received {pred_fn_pair}.")
    pred, fn = pred_fn_pair

    # Python bools have no `name`, and some tensors raise AttributeError when
    # it is read, so fall back to the predicate itself for error messages
    # instead of masking the TypeError with an AttributeError.
    pred_label = getattr(pred, "name", pred)

    if isinstance(pred, ops.Tensor):
      if pred.dtype != dtypes.bool:
        raise TypeError(
            "pred must be Tensor of type bool: %s" % (pred_label,))
    elif not allow_python_preds:
      # Wrap in a 1-tuple so a tuple-valued `pred` is formatted as a single
      # value rather than being spread over the format arguments.
      raise TypeError("pred must be a Tensor, got: %s" % (pred,))
    elif not isinstance(pred, bool):
      raise TypeError("pred must be a Tensor or bool, got: %s" % (pred,))

    if not callable(fn):
      raise TypeError("fn for pred %s must be callable." % (pred_label,))

  predicates, actions = zip(*pred_fn_pairs)
  return predicates, actions