Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/break_statements.py: 21%

80 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Lowers break statements to conditionals.""" 

16 

17from tensorflow.python.autograph.core import converter 

18from tensorflow.python.autograph.pyct import anno 

19from tensorflow.python.autograph.pyct import qual_names 

20from tensorflow.python.autograph.pyct import templates 

21from tensorflow.python.autograph.pyct.static_analysis import activity 

22from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno 

23 

24 

25class _Break(object): 

26 

27 def __init__(self): 

28 self.used = False 

29 self.control_var_name = None 

30 

31 def __repr__(self): 

32 return 'used: %s, var: %s' % (self.used, self.control_var_name) 

33 

34 

class BreakTransformer(converter.Base):
  """Canonicalizes break statements into additional conditionals.

  Each `break` becomes an assignment of True to a per-loop flag followed by
  `continue`. A loop whose body contained such a `break` is re-emitted with
  the flag initialized to False beforehand:

  - `while` loops fold `not flag` into their test.
  - `for` loops carry the `not flag` expression in the
    anno.Basic.EXTRA_LOOP_TEST annotation instead. Presumably a later pass
    consumes it; confirm.

  In both cases the loop's `else` body is emitted as a statement after the
  loop. When a break was lowered, that statement is wrapped in
  `if not flag:`.
  """

  def visit_Break(self, node):
    """Replaces `break` with `<flag> = True; continue` for the current loop."""
    current = self.state[_Break]
    current.used = True
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    return templates.replace(
        """
          var_name = True
          continue
        """,
        var_name=current.control_var_name)

  def _guard_if_present(self, block, var_name):
    """Wraps `block` in `if not var_name:`; an empty block is returned as-is."""
    if block:
      return templates.replace(
          """
            if not var_name:
              block
          """,
          var_name=var_name,
          block=block)
    return block

  def _process_body(self, nodes, break_var):
    """Visits a loop body inside a fresh _Break entry.

    Args:
      nodes: The loop body statements.
      break_var: Flag symbol that breaks in this body should assign.

    Returns:
      A (visited_nodes, break_used) tuple.
    """
    loop_state = self.state[_Break]
    loop_state.enter()
    loop_state.control_var_name = break_var
    visited = self.visit_block(nodes)
    used = loop_state.used
    loop_state.exit()
    return visited, used

  def visit_While(self, node):
    """Lowers breaks in a `while` loop into a flag checked by the loop test."""
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Allocated before visiting children, so outer loops get their flag
    # names first.
    flag = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, flag)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python's else clause only triggers if the loop exited cleanly (e.g.
      # break did not trigger).
      guarded_orelse = self._guard_if_present(node.orelse, flag)
      replacement = templates.replace(
          """
            var_name = False
            while not var_name and test:
              body
            orelse
          """,
          var_name=flag,
          test=node.test,
          body=node.body,
          orelse=guarded_orelse)
      # Element 0 is the flag initialization; the loop follows it.
      loop = replacement[1]
    else:
      replacement = templates.replace(
          """
            while test:
              body
            orelse
          """,
          test=node.test,
          body=node.body,
          orelse=node.orelse)
      loop = replacement[0]

    anno.copyanno(node, loop, anno.Basic.DIRECTIVES)
    return replacement

  def visit_For(self, node):
    """Lowers breaks in a `for` loop; the flag check becomes an extra test."""
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Allocated before visiting children, so outer loops get their flag
    # names first.
    flag = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, flag)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python's else clause only triggers if the loop exited cleanly (e.g.
      # break did not trigger).
      guarded_orelse = self._guard_if_present(node.orelse, flag)
      loop_condition = templates.replace_as_expression(
          'not var_name', var_name=flag)
      # The extra test is hidden in the AST, which will confuse the static
      # analysis. To mitigate that, we insert a no-op statement that ensures
      # the control variable is marked as used.
      # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
      replacement = templates.replace(
          """
            var_name = False
            for target in iter_:
              (var_name,)
              body
            orelse
          """,
          var_name=flag,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=guarded_orelse)
      # Element 0 is the flag initialization; the loop follows it.
      loop = replacement[1]
      anno.setanno(loop, anno.Basic.EXTRA_LOOP_TEST, loop_condition)
    else:
      replacement = templates.replace(
          """
            for target in iter_:
              body
            orelse
          """,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=node.orelse)
      loop = replacement[0]
      anno.copyanno(node, loop, anno.Basic.EXTRA_LOOP_TEST)

    anno.copyanno(node, loop, anno.Basic.DIRECTIVES)
    return replacement

177 

178 

def transform(node, ctx):
  """Lowers all `break` statements under `node`.

  Qualified-name and activity analyses run first. Presumably activity analysis
  supplies the NodeAnno.BODY_SCOPE annotations that the loop visitors read;
  confirm against activity.resolve.

  Args:
    node: The AST to convert.
    ctx: The conversion context handed to BreakTransformer.

  Returns:
    The converted AST.
  """
  annotated = activity.resolve(qual_names.resolve(node), ctx, None)
  return BreakTransformer(ctx).visit(annotated)