Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/logical_expressions.py: 30%

63 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`.""" 

16 

17import gast 

18 

19from tensorflow.python.autograph.core import converter 

20from tensorflow.python.autograph.pyct import parser 

21from tensorflow.python.autograph.pyct import templates 

22 

# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
# Note that this isn't completely safe either, because tensors may have control
# dependencies.
# Note that, for loops, this should be done after the loop has been converted
# to tf.while_loop so that the expanded conditionals are properly scoped.

# Used to signal that an operand is safe for non-lazy evaluation.
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'


# gast operator node type -> name of the `ag__` function that replaces it.
# `and`, `not` and `or` are rewritten unconditionally.
LOGICAL_OPERATORS = {
    gast.And: 'ag__.and_',
    gast.Not: 'ag__.not_',
    gast.Or: 'ag__.or_',
}

# gast comparison operator type -> name of the `ag__` replacement for `==` and
# `!=`. In this file these are only consulted when the converter's
# EQUALITY_OPERATORS feature is enabled.
EQUALITY_OPERATORS = {
    gast.Eq: 'ag__.eq',
    gast.NotEq: 'ag__.not_eq',
}

43 

44 

class LogicalExpressionTransformer(converter.Base):
  """Converts logical expressions to corresponding TF calls.

  Boolean operators (`and`, `or`) and `not` are replaced by calls to the `ag__`
  functions named in LOGICAL_OPERATORS. `==` and `!=` are replaced by the
  functions in EQUALITY_OPERATORS when the EQUALITY_OPERATORS converter feature
  is enabled. Operands of boolean operators are wrapped in lambdas, so the
  replacement function receives callables rather than already-evaluated values.
  """

  def _overload_of(self, operator):
    """Returns the `ag__` function name overloading `operator`, or None.

    Args:
      operator: a gast operator node, e.g. `gast.And()` or `gast.Eq()`.

    Returns:
      The dotted function name as a string, or None if the operator is not
      overloaded (or is an equality operator while that feature is disabled).
    """
    op_type = type(operator)
    if op_type in LOGICAL_OPERATORS:
      return LOGICAL_OPERATORS[op_type]
    # Equality overloads are opt-in via a converter feature flag.
    if not self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS):
      return None
    return EQUALITY_OPERATORS.get(op_type)

  def _as_lambda(self, expr):
    """Wraps the expression AST `expr` in a zero-argument lambda AST."""
    return templates.replace_as_expression('lambda: expr', expr=expr)

  def _as_binary_function(self, func_name, arg1, arg2):
    """Builds the AST for the call `func_name(arg1, arg2)`.

    Args:
      func_name: dotted function name as source text, e.g. 'ag__.and_'.
      arg1: AST of the first argument.
      arg2: AST of the second argument.
    """
    callee = parser.parse_expression(func_name)
    return templates.replace_as_expression(
        'func_name(arg1, arg2)', func_name=callee, arg1=arg1, arg2=arg2)

  def _as_binary_operation(self, op, arg1, arg2):
    """Builds the AST for the single comparison `arg1 <op> arg2`."""
    # The template has to contain some comparison operator in order to parse;
    # `is` is a placeholder that gets swapped for `op` right after.
    comparison = templates.replace_as_expression(
        'arg1 is arg2', arg1=arg1, arg2=arg2)
    comparison.ops[0] = op
    return comparison

  def _as_unary_function(self, func_name, arg):
    """Builds the AST for the call `func_name(arg)`."""
    callee = parser.parse_expression(func_name)
    return templates.replace_as_expression(
        'func_name(arg)', func_name=callee, arg=arg)

  def _process_binop(self, op, left, right):
    """Returns `left <op> right`, as an overload call if one applies."""
    overload = self._overload_of(op)
    if overload is not None:
      return self._as_binary_function(overload, left, right)
    return self._as_binary_operation(op, left, right)

  def visit_Compare(self, node):
    """Expands a (possibly chained) comparison.

    A chain such as `a < b < c` is turned into the conjunction
    `ag__.and_(lambda: a < b, lambda: b < c)`. Longer chains nest to the left.

    NOTE(review): the middle operand's expression appears in both generated
    comparisons. Unlike native Python chained comparisons, it may therefore be
    evaluated more than once.
    """
    node = self.generic_visit(node)

    result = None
    left_operand = node.left
    for op, right_operand in zip(node.ops, node.comparators):
      comparison = self._process_binop(op, left_operand, right_operand)
      if result is None:
        result = comparison
      else:
        result = self._as_binary_function(
            'ag__.and_', self._as_lambda(result), self._as_lambda(comparison))
      # Each comparator becomes the left-hand side of the next link.
      left_operand = right_operand

    assert result is not None
    return result

  def visit_UnaryOp(self, node):
    """Replaces an overloaded unary operator (e.g. `not x`) with a call."""
    node = self.generic_visit(node)
    overload = self._overload_of(node.op)
    return (node if overload is None
            else self._as_unary_function(overload, node.operand))

  def visit_BoolOp(self, node):
    """Rewrites `a and b and c` as nested lazy `ag__` calls.

    The values are folded right-to-left, producing
    `and_(lambda: a, lambda: and_(lambda: b, lambda: c))` (or `or_` likewise).
    Note that this pops from, and so empties, `node.values` in place.
    """
    node = self.generic_visit(node)
    func_name = self._overload_of(node.op)
    operands = node.values
    result = operands.pop()
    while operands:
      result = self._as_binary_function(
          func_name, self._as_lambda(operands.pop()), self._as_lambda(result))
    return result

127 

128 

def transform(node, ctx):
  """Runs LogicalExpressionTransformer over `node`.

  Args:
    node: the AST to convert.
    ctx: the converter context passed to the transformer's constructor.

  Returns:
    The result of visiting `node` with the transformer.
  """
  return LogicalExpressionTransformer(ctx).visit(node)