Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py: 31%
42 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for control flow.
17This file is copied from tensorflow/python/ops/control_flow_util.py.
18"""
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import smart_cond as smart_module
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import cond
24from tensorflow.python.ops import variables
def InXlaContext(graph):
  """Reports whether `graph`'s current control flow context is under XLA.

  Args:
    graph: A graph object exposing `_get_control_flow_context()`.

  Returns:
    True if the graph's current control flow context is an XLAContext or is
    nested (via `outer_context`) inside one; False otherwise.
  """
  control_flow_ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctxt = GetContainingXLAContext(control_flow_ctxt)
  return xla_ctxt is not None
def GraphOrParentsInXlaContext(graph):
  """Reports whether `graph` or any graph enclosing it is in an XLA context.

  Starting at `graph`, follows the `outer_graph` attribute outward and checks
  each graph with `InXlaContext`.

  Args:
    graph: The innermost graph to start the search from.

  Returns:
    True as soon as a graph in the chain is in an XLA context; False once a
    graph without an `outer_graph` attribute is reached.
  """
  current = graph
  while not InXlaContext(current):
    try:
      current = current.outer_graph
    except AttributeError:
      # Reached the outermost graph (no `outer_graph` attribute).
      return False
  return True
def IsInWhileLoop(op):
  """Reports whether `op`'s control flow context is, or is inside, a while loop.

  Args:
    op: An object exposing `_get_control_flow_context()`.

  Returns:
    True if the op's control flow context is a WhileContext or is nested
    (via `outer_context`) inside one; False otherwise.
  """
  control_flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  while_ctxt = GetContainingWhileContext(control_flow_ctxt)
  return while_ctxt is not None
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Finds the nearest WhileContext at or above `ctxt`.

  The search starts at `ctxt` itself and walks outward through each
  context's `outer_context` until it runs out of contexts.

  Args:
    ctxt: ControlFlowContext to start from (may be None).
    stop_ctxt: ControlFlowContext, optional. When a context in the chain
      compares equal to it, the search ends and that context is returned.

  Returns:
    The first context in the chain that is a WhileContext or equals
    `stop_ctxt`, or None if no such context exists.
  """
  current = ctxt
  while current:
    # Checked first so that a WhileContext wins without evaluating `==`.
    if current.IsWhileContext():
      return current
    if current == stop_ctxt:
      return current
    current = current.outer_context
  return None
def GetContainingXLAContext(ctxt):
  """Returns the first ancestor XLAContext of `ctxt`.

  Returns `ctxt` if `ctxt` is an XLAContext, or None if `ctxt` is not inside
  an XLA context.

  Args:
    ctxt: ControlFlowContext to start the search from (may be None).

  Returns:
    `ctxt` if `ctxt` is an XLAContext, the most nested XLAContext containing
    `ctxt` (found by walking `outer_context`), or None if `ctxt` is not
    inside an XLA context.
  """
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None
def smart_cond(pred, true_fn=None, false_fn=None, name=None):  # pylint: disable=invalid-name
  """Runs `true_fn()` when `pred` is true, otherwise `false_fn()`.

  A `pred` that is a `variables.Variable` is always routed through
  `cond.cond`. Any other `pred` is handed to `smart_module.smart_cond`,
  which returns the chosen branch directly when `pred` is a bool or has a
  constant value, and otherwise builds a `tf.cond` over both branches.

  Args:
    pred: Scalar selecting between `true_fn` and `false_fn`.
    true_fn: Callable invoked when `pred` is true.
    false_fn: Callable invoked when `pred` is false.
    name: Optional name prefix used if a `tf.cond` is built.

  Returns:
    Tensors produced by whichever of `true_fn` / `false_fn` was selected.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable (raised by the
      underlying cond implementation).
  """
  dispatch = (cond.cond if isinstance(pred, variables.Variable)
              else smart_module.smart_cond)
  return dispatch(pred, true_fn=true_fn, false_fn=false_fn, name=name)
def constant_value(pred):  # pylint: disable=invalid-name
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Args:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor, or the Python integer 1 or 0.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0.
  """
  if isinstance(pred, ops.Tensor):
    # Delegates to tensor_util, which yields the statically-known value or
    # None when it cannot be determined.
    return tensor_util.constant_value(pred)
  # Variables must be handled before the `in {0, 1}` membership test below:
  # set membership hashes `pred`, and TF2 variables raise TypeError from
  # `__hash__`, which would make this branch unreachable.
  if isinstance(pred, variables.Variable):
    return None
  if isinstance(pred, bool):
    return pred
  try:
    is_zero_or_one = pred in {0, 1}  # Accept 1/0 as valid boolean values.
  except TypeError:
    # Unhashable inputs cannot be 0 or 1; report them with the descriptive
    # error below rather than a bare "unhashable type" error.
    is_zero_or_one = False
  if is_zero_or_one:
    return bool(pred)
  raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. "
                  "Found instead: %s" % type(pred))