Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/logical_expressions.py: 30%
63 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""
17import gast
19from tensorflow.python.autograph.core import converter
20from tensorflow.python.autograph.pyct import parser
21from tensorflow.python.autograph.pyct import templates
23# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
24# Note that this isn't completely safe either, because tensors may have control
25# dependencies.
26# Note that for loops, this should be done after the loop has been converted
27# to tf.while_loop, so that the expanded conditionals are properly scoped.
# Marker used to signal that an operand is safe for non-lazy evaluation.
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'


# Logical operator node types, keyed to the name of the function that
# overloads each of them.
LOGICAL_OPERATORS = {
    gast.And: 'ag__.and_',
    gast.Not: 'ag__.not_',
    gast.Or: 'ag__.or_',
}

# Equality operator node types, keyed to the name of the function that
# overloads each of them.
EQUALITY_OPERATORS = {
    gast.Eq: 'ag__.eq',
    gast.NotEq: 'ag__.not_eq',
}
class LogicalExpressionTransformer(converter.Base):
  """Converts logical expressions to corresponding TF calls."""

  def _overload_of(self, operator):
    """Returns the name of the function overloading `operator`, if any.

    Args:
      operator: An operator node (e.g. an instance of `gast.And`).

    Returns:
      The overload's qualified name as a string, or None when the operator
      has no overload. Equality operators are only overloaded when the
      EQUALITY_OPERATORS feature is enabled in the user options.
    """
    op_type = type(operator)
    if op_type in LOGICAL_OPERATORS:
      return LOGICAL_OPERATORS[op_type]
    if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS):
      if op_type in EQUALITY_OPERATORS:
        return EQUALITY_OPERATORS[op_type]
    return None

  def _as_lambda(self, expr):
    """Wraps `expr` in a zero-argument lambda: `lambda: expr`."""
    return templates.replace_as_expression('lambda: expr', expr=expr)

  def _as_binary_function(self, func_name, arg1, arg2):
    """Builds the call expression `func_name(arg1, arg2)`.

    Args:
      func_name: String, the (possibly dotted) name of the function to call.
      arg1: AST node used as the first argument.
      arg2: AST node used as the second argument.

    Returns:
      The expression node produced by the template.
    """
    return templates.replace_as_expression(
        'func_name(arg1, arg2)',
        func_name=parser.parse_expression(func_name),
        arg1=arg1,
        arg2=arg2)

  def _as_binary_operation(self, op, arg1, arg2):
    """Builds the plain comparison `arg1 <op> arg2`, without any overload."""
    template = templates.replace_as_expression(
        'arg1 is arg2',  # Note: `is` will be replaced with `op` below.
        arg1=arg1,
        arg2=arg2)
    template.ops[0] = op
    return template

  def _as_unary_function(self, func_name, arg):
    """Builds the call expression `func_name(arg)`."""
    return templates.replace_as_expression(
        'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)

  def _process_binop(self, op, left, right):
    """Converts one binary comparison, using the overload when there is one."""
    overload = self._overload_of(op)
    if overload is None:
      return self._as_binary_operation(op, left, right)
    return self._as_binary_function(overload, left, right)

  def visit_Compare(self, node):
    """Converts a (possibly chained) comparison.

    Repeated comparisons are converted to conjunctions:
      a < b < c   ->   a < b and b < c

    Args:
      node: The comparison node to convert.

    Returns:
      The expression node replacing `node`.
    """
    node = self.generic_visit(node)

    op_tree = None
    left = node.left
    # Iterate the (operator, comparator) pairs directly; no intermediate list
    # or repeated pop(0) is needed.
    for op, right in zip(node.ops, node.comparators):
      binary_comparison = self._process_binop(op, left, right)
      if op_tree is None:
        op_tree = binary_comparison
      else:
        op_tree = self._as_binary_function(
            'ag__.and_',
            self._as_lambda(op_tree),
            self._as_lambda(binary_comparison))
      # The right operand of this comparison is the left operand of the next.
      left = right

    assert op_tree is not None
    return op_tree

  def visit_UnaryOp(self, node):
    """Converts a unary operation; returns it unchanged if not overloaded."""
    node = self.generic_visit(node)

    overload = self._overload_of(node.op)
    if overload is None:
      return node

    return self._as_unary_function(overload, node.operand)

  def visit_BoolOp(self, node):
    """Converts a boolean operation into nested binary overload calls.

    The operands are folded from right to left, each wrapped in a lambda:
      a and b and c   ->   and_(lambda: a, lambda: and_(lambda: b, lambda: c))

    Args:
      node: The boolean operation node to convert.

    Returns:
      The expression node replacing `node`.
    """
    node = self.generic_visit(node)
    # The overload depends only on the operator, so look it up once.
    overload = self._overload_of(node.op)

    # Read the operands without popping them, so that `node.values` is left
    # intact instead of being emptied as a side effect of the conversion.
    values = node.values
    result = values[-1]
    for left in reversed(values[:-1]):
      result = self._as_binary_function(
          overload, self._as_lambda(left), self._as_lambda(result))
    return result
def transform(node, ctx):
  """Runs LogicalExpressionTransformer over `node` and returns the result.

  Args:
    node: The AST node to convert.
    ctx: The context object handed to the transformer's constructor.

  Returns:
    Whatever the transformer's `visit` returns for `node`.
  """
  return LogicalExpressionTransformer(ctx).visit(node)