Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py: 32%

31 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""smart_cond and related utilities.""" 

16 

17from tensorflow.python.framework import ops 

18from tensorflow.python.framework import tensor_util 

19from tensorflow.python.ops import cond 

20from tensorflow.python.ops import control_flow_case 

21from tensorflow.python.util.tf_export import tf_export 

22 

23 

@tf_export("__internal__.smart_cond.smart_cond", v1=[])
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Pick between `true_fn()` and `false_fn()`, statically when possible.

  The predicate is first resolved with `smart_constant_value`. When that
  yields a known value, exactly one of the two callables is invoked directly
  and its result returned. When the value is unknown (None), both branches
  are handed to `cond.cond`, which builds a dynamic conditional.

  Args:
    pred: Scalar selecting the branch; a Python bool, 1/0, or a tensor.
    true_fn: Callable invoked when `pred` is true.
    false_fn: Callable invoked when `pred` is false.
    name: Optional name prefix, forwarded only to `cond.cond`.

  Returns:
    Whatever the selected callable returns, or the result of `cond.cond` when
    the predicate cannot be resolved statically.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  # Both callables are validated up front, before the predicate is inspected.
  if not callable(true_fn):
    raise TypeError(f"Argument `true_fn` must be callable. Received {true_fn}")
  if not callable(false_fn):
    raise TypeError(
        f"Argument `false_fn` must be callable. Received {false_fn}")

  static_value = smart_constant_value(pred)
  if static_value is None:
    # Value not known ahead of time: let cond.cond route between branches.
    return cond.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)

  chosen_fn = true_fn if static_value else false_fn
  return chosen_fn()

59 

60 

def smart_constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Args:
    pred: A scalar: a Python bool, a value equal to 1 or 0, or a tensor.

  Returns:
    For a Python bool, `pred` itself. For a value equal to 1 or 0, the
    corresponding Python bool. For a tensor, the value reported by
    `tensor_util.constant_value`, falling back to
    `tensor_util.try_evaluate_constant`. This is presumably a NumPy value
    rather than a Python bool (TODO confirm). Returns None if neither helper
    can determine the tensor's value.

  Raises:
    TypeError: If `pred` is not a Tensor, a Python bool, or a value equal to
      1 or 0 (including unhashable inputs such as lists or arrays).
  """
  # Handle bools first. `True in {0, 1}` holds, so testing membership first
  # made a dedicated bool branch unreachable. A bool is never a Tensor, so
  # checking it before the Tensor case cannot change which branch a Tensor
  # takes.
  if isinstance(pred, bool):
    return pred

  if isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
    # TODO(skyewm): consider folding this into tensor_util.constant_value.
    if pred_value is None:
      pred_value = tensor_util.try_evaluate_constant(pred)
    return pred_value

  try:
    is_zero_or_one = pred in {0, 1}  # Accept 1/0 as valid boolean values.
  except TypeError:
    # Unhashable inputs cannot be looked up in a set. Report them with the
    # documented error below instead of a bare "unhashable type" message.
    is_zero_or_one = False
  if is_zero_or_one:
    return bool(pred)

  raise TypeError("Argument `pred` must be a Tensor, or a Python bool, or 1 "
                  f"or 0. Received: pred={pred} of type "
                  f"{type(pred).__name__}")

90 

91 

def smart_case(pred_fn_pairs, default=None, exclusive=False, name="smart_case"):
  """Variant of `tf.case` that resolves predicates statically when it can.

  This delegates to `control_flow_case._case_helper`, supplying `smart_cond`
  as the conditional builder and allowing Python (non-tensor) predicates.
  Predicates that are bools or have a constant value therefore select or skip
  their callables directly. Everything else behaves like `tf.case`.

  Args:
    pred_fn_pairs: Dict or list of pairs, each a boolean scalar tensor and a
      callable that returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate may evaluate to `True`.
    name: Optional name for the operation.

  Returns:
    The tensors produced by the first pair whose predicate is True, or by
    `default` when none is.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
  """
  case_helper = control_flow_case._case_helper  # pylint: disable=protected-access
  return case_helper(smart_cond, pred_fn_pairs, default, exclusive, name,
                     allow_python_preds=True)