Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/control_flow_util.py: 22%

41 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utility functions for control flow. 

16 

17This file is copied from tensorflow/python/ops/control_flow_util.py. 

18""" 

19 

20import tensorflow.compat.v2 as tf 

21 

22 

def InXlaContext(graph):
    """Reports whether `graph`'s current control flow context is under XLA.

    Args:
        graph: A graph object exposing `_get_control_flow_context()`.

    Returns:
        True if the graph's control flow context is, or is nested inside, an
        XLAContext; False otherwise.
    """
    current_context = graph._get_control_flow_context()
    enclosing_xla = GetContainingXLAContext(current_context)
    return enclosing_xla is not None

26 

27 

def GraphOrParentsInXlaContext(graph):
    """Reports whether `graph` or any of its outer graphs is under XLA.

    Starting from `graph`, each graph is checked with `InXlaContext`; the walk
    then moves to `outer_graph`. It stops when a graph has no `outer_graph`
    attribute.

    Args:
        graph: The innermost graph to start from.

    Returns:
        True as soon as a graph in the chain is in an XLA context, False once
        a graph without an `outer_graph` attribute is reached.
    """
    current = graph
    while not InXlaContext(current):
        try:
            current = current.outer_graph
        except AttributeError:
            # No further enclosing graph to inspect.
            return False
    return True

36 

37 

def IsInWhileLoop(op):
    """Reports whether `op` was created inside a while loop context.

    Args:
        op: An object exposing `_get_control_flow_context()`.

    Returns:
        True if the op's control flow context is, or is nested inside, a
        WhileContext; False otherwise.
    """
    enclosing_while = GetContainingWhileContext(op._get_control_flow_context())
    return enclosing_while is not None

41 

42 

def GetContainingWhileContext(ctxt, stop_ctxt=None):
    """Walks outward from `ctxt` to the nearest enclosing WhileContext.

    The walk begins at `ctxt` itself and follows `outer_context` links until
    it reaches a falsy context.

    Args:
        ctxt: ControlFlowContext to start from (may be None).
        stop_ctxt: ControlFlowContext, optional. When given, the walk also
            halts at, and returns, the first context that compares equal to it.

    Returns:
        `ctxt` if it is a WhileContext; otherwise the most nested WhileContext
        containing it; or the context equal to `stop_ctxt` if that is reached
        first; or None if none of these is found.
    """
    node = ctxt
    while node:
        # IsWhileContext() is consulted before the stop_ctxt comparison,
        # mirroring the short-circuit order of `a or b`.
        if node.IsWhileContext():
            return node
        if node == stop_ctxt:
            return node
        node = node.outer_context
    return None

65 

66 

def GetContainingXLAContext(ctxt):
    """Returns the first ancestor XLAContext of `ctxt`.

    Returns `ctxt` if `ctxt` is an XLAContext, or None if `ctxt` is not inside
    an XLA context. The search starts at `ctxt` and follows `outer_context`
    links until a falsy context is reached.

    Args:
        ctxt: ControlFlowContext to start from (may be None).

    Returns:
        `ctxt` if `ctxt` is an XLAContext, the most nested XLAContext
        containing `ctxt`, or None if `ctxt` is not inside an XLA context.
    """
    while ctxt:
        if ctxt.IsXLAContext():
            return ctxt
        ctxt = ctxt.outer_context
    return None

85 

86 

def smart_cond(pred, true_fn=None, false_fn=None, name=None):
    """Returns `true_fn()` when `pred` is true, otherwise `false_fn()`.

    A `tf.Variable` predicate is always routed through `tf.cond`. Any other
    predicate is handed to TensorFlow's internal `smart_cond`. Per the
    contract below, that calls `true_fn()`/`false_fn()` directly when `pred`
    is a bool or has a constant value, and emits a `tf.cond` otherwise.

    Args:
        pred: A scalar determining whether to return the result of `true_fn`
            or `false_fn`.
        true_fn: The callable to be performed if pred is true.
        false_fn: The callable to be performed if pred is false.
        name: Optional name prefix when using `tf.cond`.

    Returns:
        Tensors returned by the call to either `true_fn` or `false_fn`.

    Raises:
        TypeError: If `true_fn` or `false_fn` is not callable.
    """
    if not isinstance(pred, tf.Variable):
        return tf.__internal__.smart_cond.smart_cond(
            pred, true_fn=true_fn, false_fn=false_fn, name=name
        )
    return tf.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)

111 

112 

def constant_value(pred):
    """Return the bool value for `pred`, or None if `pred` had a dynamic value.

    Args:
        pred: A scalar, either a Python bool or a TensorFlow boolean variable
            or tensor, or the Python integer 1 or 0.

    Returns:
        True or False if `pred` has a constant boolean value, None otherwise.
        A `tf.Variable` always yields None (its value is treated as dynamic).

    Raises:
        TypeError: If `pred` is not a Variable, Tensor or bool, or Python
            integer 1 or 0.
    """
    if isinstance(pred, tf.Tensor):
        return tf.get_static_value(pred)
    # The Variable check must precede the `in {0, 1}` membership test below:
    # that test hashes `pred`, and in TF2 `tf.Variable.__hash__` raises
    # TypeError ("Variable is unhashable"), which made this branch unreachable.
    if isinstance(pred, tf.Variable):
        return None
    if isinstance(pred, bool):
        return pred
    try:
        # Accept 1/0 as valid boolean values.
        is_zero_or_one = pred in {0, 1}
    except TypeError:
        # Unhashable input (e.g. a list or array): report it with the
        # descriptive error below rather than a bare "unhashable type".
        is_zero_or_one = False
    if is_zero_or_one:
        return bool(pred)
    raise TypeError(
        "`pred` must be a Tensor, or a Python bool, or 1 or 0. "
        f"Received: {type(pred)}"
    )

139