Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/static_analysis/reaching_fndefs.py: 30%

83 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""An analysis that determines the reach of a function definition. 

16 

17A function definition is said to reach a statement if that function may exist 

18(and therefore may be called) when that statement executes. 

19""" 

20 

21import gast 

22 

23from tensorflow.python.autograph.pyct import anno 

24from tensorflow.python.autograph.pyct import cfg 

25from tensorflow.python.autograph.pyct import transformer 

26 

27 

class Definition:
    """Describes one unique definition of a function.

    Attributes:
      def_node: the node this definition was created from (presumably the
        defining AST node; not used elsewhere in this module's visible code).
    """

    def __init__(self, def_node):
        self.def_node = def_node

33 

34 

class _NodeState(object):
    """State of the CFG walk for the reaching function definitions analysis.

    This is a value type: every operator returns a new instance and never
    mutates its operands. Only the strictly necessary operators are
    implemented.

    Attributes:
      value: Set of function-definition AST nodes that may exist at this point
        of the CFG. `Analyzer` adds `gast.FunctionDef` / `gast.Lambda` nodes
        to it.
    """

    # The state holds a mutable set and defines __eq__, so it is deliberately
    # unhashable (Python 3 already implies this; stated here for clarity).
    __hash__ = None

    def __init__(self, init_from=None):
        """Creates a state holding a copy of the elements of `init_from`.

        Args:
          init_from: Optional iterable of definitions. None or an empty
            iterable yields an empty state.
        """
        # Always copy, so the new state never aliases the caller's collection.
        self.value = set(init_from) if init_from else set()

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of raising AttributeError on `other.value`.
        if not isinstance(other, _NodeState):
            return NotImplemented
        return self.value == other.value

    def __ne__(self, other):
        if not isinstance(other, _NodeState):
            return NotImplemented
        return self.value != other.value

    def __or__(self, other):
        """Returns a new state holding the union of both states' definitions."""
        # NotImplemented (-> TypeError) instead of an `assert`, which would
        # be stripped when running with -O.
        if not isinstance(other, _NodeState):
            return NotImplemented
        result = _NodeState(self.value)
        result.value.update(other.value)
        return result

    def __add__(self, value):
        """Returns a new state with the single definition `value` added."""
        result = _NodeState(self.value)
        result.value.add(value)
        return result

    def __repr__(self):
        return 'NodeState[%s]=%s' % (id(self), repr(self.value))

70 

71 

class Analyzer(cfg.GraphVisitor):
    """Forward CFG visitor computing which function definitions reach each node.

    `visit_node` stores, for every CFG node `n`, the definitions that may
    exist before it in `in_[n]` and after it in `out[n]` (both `_NodeState`).
    """

    def __init__(self, graph, external_defs):
        super(Analyzer, self).__init__(graph)
        # Definitions that reach the graph's entry from outside the graph,
        # e.g. the ones a nested function closes over.
        self.external_defs = external_defs

    def init_state(self, _):
        # Every CFG node starts out with no known definitions.
        return _NodeState()

    def visit_node(self, node):
        """Recomputes the in/out states of `node`.

        Returns:
          True if the output state of `node` changed, False otherwise.
        """
        old_out = self.out[node]

        # The entry node is seeded with the external definitions; any other
        # node starts from its own previous output, so that output never
        # shrinks between visits.
        # NOTE(review): because a definition node's output contains the node
        # itself, revisiting a FunctionDef/Lambda node makes it appear in its
        # own input state. This is a may-reach over-approximation; confirm it
        # is intended.
        incoming = (_NodeState(self.external_defs)
                    if node is self.graph.entry else old_out)
        for pred in node.prev:
            incoming = incoming | self.out[pred]

        # A function definition (or lambda) adds itself to what reaches its
        # successors.
        fn_node = node.ast_node
        outgoing = (incoming + fn_node
                    if isinstance(fn_node, (gast.Lambda, gast.FunctionDef))
                    else incoming)

        self.in_[node] = incoming
        self.out[node] = outgoing
        return old_out != outgoing

103 

104 

class TreeAnnotator(transformer.Base):
    """AST visitor that annotates nodes with the function definitions reaching them.

    Every AST node that has a node in the current function's CFG receives an
    `anno.Static.DEFINED_FNS_IN` annotation: the set of function definitions
    that may exist when it executes.

    Simultaneously, the visitor runs the dataflow analysis on each function
    node, accounting for the effect of closures. For example:

      def foo():
        def f():
          pass
        def g():
          # `def f` reaches here
    """

    def __init__(self, source_info, graphs):
        super(TreeAnnotator, self).__init__(source_info)
        # Maps each FunctionDef / Lambda node to its CFG.
        self.graphs = graphs
        self.allow_skips = False
        # Analyzer of the innermost function being visited; None outside of
        # any function.
        self.current_analyzer = None

    def _defs_reaching(self, node):
        """Returns the definitions reaching `node` in the current CFG.

        Returns None when there is no current analyzer (e.g. before entering
        the top-level function) or when `node` has no node in the current CFG.
        """
        analyzer = self.current_analyzer
        if analyzer is None or node not in analyzer.graph.index:
            return None
        return analyzer.in_[analyzer.graph.index[node]].value

    def _annotate(self, node):
        """Sets DEFINED_FNS_IN on `node` if it belongs to the current CFG."""
        defs = self._defs_reaching(node)
        if defs is not None:
            anno.setanno(node, anno.Static.DEFINED_FNS_IN, defs)

    # NOTE: the misspelled name is kept for compatibility with any code that
    # overrides or calls it.
    def _proces_function(self, node):
        """Analyzes the CFG of a FunctionDef/Lambda node, then visits its body.

        The definitions reaching the function node in the enclosing CFG seed
        the nested function's entry state, which models closures.
        """
        parent_analyzer = self.current_analyzer
        subgraph = self.graphs[node]

        defined_in = self._defs_reaching(node)
        if defined_in is None:
            defined_in = ()

        analyzer = Analyzer(subgraph, defined_in)
        analyzer.visit_forward()

        self.current_analyzer = analyzer
        try:
            node = self.generic_visit(node)
        finally:
            # Restore the enclosing function's analyzer even if visiting the
            # body raises, so the annotator is not left in a stale state.
            self.current_analyzer = parent_analyzer
        return node

    def visit_FunctionDef(self, node):
        return self._proces_function(node)

    def visit_Lambda(self, node):
        return self._proces_function(node)

    def visit(self, node):
        # There is no current analyzer before entering the top level function;
        # _annotate is a no-op in that case.
        self._annotate(node)

        # A node may carry an extra loop test node that is looked up in the
        # CFG on its own. Annotate it with the same guards as the main node,
        # so a missing analyzer or CFG entry does not raise.
        extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
        if extra_node is not None:
            self._annotate(extra_node)

        return super(TreeAnnotator, self).visit(node)

164 

165 

def resolve(node, source_info, graphs):
    """Annotates `node` with the function definitions that reach its statements.

    Args:
      node: ast.AST, the tree to annotate.
      source_info: transformer.SourceInfo
      graphs: Dict[ast.AST, cfg.Graph], must hold a CFG for every FunctionDef
        and Lambda node that gets visited.
    Returns:
      ast.AST, the tree returned by the annotating visitor.
    """
    annotator = TreeAnnotator(source_info, graphs)
    return annotator.visit(node)