Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/control_flow_util.py: 22%
41 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for control flow.
17This file is copied from tensorflow/python/ops/control_flow_util.py.
18"""
20import tensorflow.compat.v2 as tf
def InXlaContext(graph):
    """Reports whether `graph`'s current control flow context is inside XLA.

    Args:
      graph: An object exposing `_get_control_flow_context()` (presumably a
        TF Graph — confirm against callers).

    Returns:
      True if `GetContainingXLAContext` finds an XLAContext at or above the
      graph's current control flow context, False otherwise.
    """
    xla_ctxt = GetContainingXLAContext(graph._get_control_flow_context())
    return xla_ctxt is not None
def GraphOrParentsInXlaContext(graph):
    """Reports whether `graph` or any graph it is nested in is inside XLA.

    Starting at `graph`, each graph is checked with `InXlaContext`; when the
    check fails the walk moves to that graph's `outer_graph`.

    Args:
      graph: The innermost graph to start checking from.

    Returns:
      True as soon as a graph on the chain is in an XLA context. False once a
      graph without an `outer_graph` attribute is reached.
    """
    current = graph
    while not InXlaContext(current):
        try:
            current = current.outer_graph
        except AttributeError:
            # No `outer_graph` attribute: there is no parent left to inspect.
            return False
    return True
def IsInWhileLoop(op):
    """Reports whether `op`'s control flow context lies within a while loop.

    Args:
      op: An object exposing `_get_control_flow_context()` (presumably a TF
        Operation — confirm against callers).

    Returns:
      True if `GetContainingWhileContext` finds a WhileContext at or above the
      op's control flow context, False otherwise.
    """
    while_ctxt = GetContainingWhileContext(op._get_control_flow_context())
    return while_ctxt is not None
def GetContainingWhileContext(ctxt, stop_ctxt=None):
    """Walks outward from `ctxt` to find the innermost enclosing WhileContext.

    The walk starts at `ctxt` itself and follows `outer_context` links. It
    stops at the first context for which `IsWhileContext()` is truthy or that
    compares equal to `stop_ctxt`. It also stops when it reaches a falsy
    context, which marks the end of the chain.

    Args:
      ctxt: ControlFlowContext to start from; may be None.
      stop_ctxt: ControlFlowContext, optional. If provided, the walk ends at
        the first context on the chain that compares equal to it.

    Returns:
      The first context on the chain (starting with `ctxt`) that is a
      WhileContext or equals `stop_ctxt`. Returns None if the chain runs out
      without such a context.
    """
    node = ctxt
    while node:
        found = node.IsWhileContext() or node == stop_ctxt
        if found:
            return node
        node = node.outer_context
    return None
def GetContainingXLAContext(ctxt):
    """Returns the first ancestor XLAContext of `ctxt`.

    Returns `ctxt` if `ctxt` is an XLAContext, or None if `ctxt` is not inside
    an XLA context.

    The search starts at `ctxt` itself and follows `outer_context` links. It
    stops at the first context whose `IsXLAContext()` is truthy, or at a falsy
    context, which marks the end of the chain.

    Args:
      ctxt: ControlFlowContext to start from; may be None.

    Returns:
      `ctxt` if `ctxt` is an XLAContext, otherwise the most nested XLAContext
      containing `ctxt`. Returns None if `ctxt` is not inside an XLA context.
    """
    while ctxt:
        if ctxt.IsXLAContext():
            return ctxt
        ctxt = ctxt.outer_context
    return None
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
    """Runs `true_fn()` when `pred` is true and `false_fn()` otherwise.

    A `tf.Variable` predicate always goes through a dynamic `tf.cond`. Any
    other predicate is handed to TensorFlow's internal `smart_cond`, which
    picks a branch directly when `pred` is a bool or has a constant value
    and otherwise builds a `tf.cond` over both branches.

    Args:
      pred: A scalar deciding whether the result of `true_fn` or `false_fn`
        is returned.
      true_fn: The callable to be performed if pred is true.
      false_fn: The callable to be performed if pred is false.
      name: Optional name prefix when using `tf.cond`.

    Returns:
      Tensors returned by the call to either `true_fn` or `false_fn`.

    Raises:
      TypeError: If `true_fn` or `false_fn` is not callable.
    """
    branch_kwargs = dict(true_fn=true_fn, false_fn=false_fn, name=name)
    if not isinstance(pred, tf.Variable):
        return tf.__internal__.smart_cond.smart_cond(pred, **branch_kwargs)
    # Variables skip the static-folding path. Presumably this is because a
    # variable's value can change between executions — confirm.
    return tf.cond(pred, **branch_kwargs)
def constant_value(pred):
    """Return the bool value for `pred`, or None if `pred` had a dynamic value.

    Args:
      pred: A scalar, either a Python bool or a TensorFlow boolean variable
        or tensor, or the Python integer 1 or 0.

    Returns:
      True or False if `pred` has a constant boolean value, None otherwise.
      For a tensor, the result is whatever `tf.get_static_value` returns.
      For a `tf.Variable`, the result is always None. A non-bool value that
      compares equal to 0 or 1 (e.g. a numpy integer) is converted with
      `bool()`.

    Raises:
      TypeError: If `pred` is not a Variable, Tensor or bool, or Python
        integer 1 or 0.
    """
    if isinstance(pred, tf.Tensor):
        return tf.get_static_value(pred)
    # Check Variables before the `{0, 1}` membership test below. That test
    # hashes `pred`, and `tf.Variable.__hash__` raises TypeError when tensor
    # equality is enabled (the TF2 default). Without this ordering, a
    # Variable would raise instead of returning None.
    if isinstance(pred, tf.Variable):
        return None
    if isinstance(pred, bool):
        return pred
    try:
        is_zero_or_one = pred in {0, 1}  # Accept 1/0 as valid boolean values
    except TypeError:
        # Unhashable inputs (lists, dicts, arrays, ...) cannot be 0 or 1.
        # Fall through to the descriptive error below.
        is_zero_or_one = False
    if is_zero_or_one:
        return bool(pred)
    raise TypeError(
        "`pred` must be a Tensor, or a Python bool, or 1 or 0. "
        f"Received: {type(pred)}"
    )