Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py: 22%

148 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Reaching definition analysis. 

16 

17This analysis attaches a set of Definition objects to each symbol, one 

18for each distinct definition that may reach it. The Definition objects are 

19mutable and may be used by subsequent analyses to further annotate data like 

20static type and value information. 

21The analysis also attaches the set of the symbols defined at the entry of 

22control flow statements. 

23 

24Requires activity analysis. 

25""" 

26 

27import weakref 

28 

29import gast 

30 

31from tensorflow.python.autograph.pyct import anno 

32from tensorflow.python.autograph.pyct import cfg 

33from tensorflow.python.autograph.pyct import transformer 

34 

35 

class Definition(object):
  """A single, unique definition of a variable.

  Definitions are compared by identity: the class defines no __eq__, so two
  Definition objects denote the same definition only if they are the same
  object. Callers can supply their own Definition subclass by passing a
  matching factory function to `resolve`.

  Attributes:
    param_of: None, or, when the definition was created for a function
      parameter, a reference to that parameter's AST node (the Analyzer in
      this module stores a `weakref.ref` to it).
    directives: Dict, optional annotations attached to this definition.
  """

  def __init__(self):
    self.directives = {}
    self.param_of = None

  def __repr__(self):
    # Include the object id, since identity is what distinguishes definitions.
    return '%s[%d]' % (type(self).__name__, id(self))

53 

54 

55class _NodeState(object): 

56 """Abstraction for the state of the CFG walk for reaching definition analysis. 

57 

58 This is a value type. Only implements the strictly necessary operators. 

59 

60 Attributes: 

61 value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and 

62 their possible definitions 

63 """ 

64 

65 def __init__(self, init_from=None): 

66 if init_from: 

67 if isinstance(init_from, _NodeState): 

68 self.value = { 

69 s: set(other_infos) for s, other_infos in init_from.value.items() 

70 } 

71 elif isinstance(init_from, dict): 

72 self.value = {s: set((init_from[s],)) for s in init_from} 

73 else: 

74 assert False, init_from 

75 else: 

76 self.value = {} 

77 

78 def __eq__(self, other): 

79 if frozenset(self.value.keys()) != frozenset(other.value.keys()): 

80 return False 

81 ret = all(self.value[s] == other.value[s] for s in self.value) 

82 return ret 

83 

84 def __ne__(self, other): 

85 return not self.__eq__(other) 

86 

87 def __or__(self, other): 

88 assert isinstance(other, _NodeState) 

89 result = _NodeState(self) 

90 for s, other_infos in other.value.items(): 

91 if s in result.value: 

92 result.value[s].update(other_infos) 

93 else: 

94 result.value[s] = set(other_infos) 

95 return result 

96 

97 def __sub__(self, other): 

98 assert isinstance(other, set) 

99 result = _NodeState(self) 

100 for s in other: 

101 result.value.pop(s, None) 

102 return result 

103 

104 def __repr__(self): 

105 return 'NodeState[%s]=%s' % (id(self), repr(self.value)) 

106 

107 

class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level.

  For each CFG node `n`, it evaluates the classic gen/kill dataflow equations:

    in[n]  = union of out[p] over every predecessor p of n
    out[n] = gen[n] | (in[n] - kill[n])

  Here gen[n] holds fresh Definition objects for the symbols bound (or declared
  global) at n, plus one for each parameter. kill[n] is the set of symbols
  modified or deleted at n. Requires the scope annotations produced by
  activity analysis (`anno.Static.SCOPE`).
  """

  def __init__(self, graph, definition_factory):
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    # CFG node -> _NodeState holding the definitions generated by that node.
    # Cached so that every visit of a node reuses the same Definition objects.
    self.gen_map = {}

  def init_state(self, _):
    return _NodeState()

  def visit_node(self, node):
    """Recomputes in/out definitions for `node`.

    `self.in_`, `self.out` and `can_ignore` are not defined in this class;
    presumably they are provided by the cfg.GraphVisitor base class, with
    states created via `init_state`.

    Args:
      node: a CFG node, with `prev` (predecessor nodes) and `ast_node`.

    Returns:
      bool, whether the outgoing state of `node` changed. Presumably the base
      class uses this to decide whether the fixpoint iteration must continue.
    """
    prev_defs_out = self.out[node]

    defs_in = _NodeState()
    for n in node.prev:
      defs_in |= self.out[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      # The definition objects created by each node must be singletons because
      # their ids are used in equality checks.
      if node not in self.gen_map:
        node_symbols = {}
        # Every binding operation (assign, nonlocal, global, etc.) counts as a
        # definition, with the exception of del, which only deletes without
        # creating a new variable.
        newly_defined = ((node_scope.bound | node_scope.globals) -
                         node_scope.deleted)
        for s in newly_defined:
          def_ = self._definition_factory()
          node_symbols[s] = def_
        # Every param receives a definition. Params are not necessarily
        # considered as "modified".
        for s, p in node_scope.params.items():
          def_ = self._definition_factory()
          # Weak reference, so the definition does not keep the AST node alive.
          def_.param_of = weakref.ref(p)
          node_symbols[s] = def_
        self.gen_map[node] = _NodeState(node_symbols)

      # Compute out = gen | (in - kill) exactly once per visit.
      gen = self.gen_map[node]
      kill = node_scope.modified | node_scope.deleted
      defs_out = gen | (defs_in - kill)

    else:
      assert self.can_ignore(node), (node.ast_node, node)
      defs_out = defs_in

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out

163 

164 

class TreeAnnotator(transformer.Base):
  """Annotates each symbol name with the definitions that reach it.

  While walking the AST, the transformer also runs the dataflow analysis
  (an `Analyzer`) on each function definition it enters, accounting for the
  effect of closures. It then uses the analyzer's in/out states to attach:
    * `anno.Static.DEFINITIONS` to every `gast.Name` inside a function;
    * `anno.Static.DEFINED_VARS_IN`, the symbols defined on entry, to if,
      for, while and try statements and to except handlers.

  For example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.allow_skips = False
    # Dict[ast.FunctionDef, cfg.Graph]: the CFG to analyze for each function.
    self.graphs = graphs
    self.definition_factory = definition_factory
    # The CFG node of the statement being visited, and the analyzer of the
    # innermost enclosing function (None outside any function).
    self.current_cfg_node = None
    self.current_analyzer = None

  def visit_FunctionDef(self, node):
    enclosing_analyzer = self.current_analyzer

    fn_analyzer = Analyzer(self.graphs[node], self.definition_factory)
    fn_analyzer.visit_forward()

    # Descend with this function's analysis active, so nested names and
    # nested functions are handled; then restore the enclosing analysis.
    self.current_analyzer = fn_analyzer
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    self.current_analyzer = enclosing_analyzer
    return node

  def visit_Name(self, node):
    analyzer = self.current_analyzer
    if analyzer is None:
      # Names can also occur outside function definitions, e.g. in class
      # bodies; those are left unannotated.
      return node

    cfg_node = self.current_cfg_node
    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    # A load sees the definitions entering the statement; any other context
    # sees the ones leaving it.
    if isinstance(node.ctx, gast.Load):
      state = analyzer.in_[cfg_node]
    else:
      state = analyzer.out[cfg_node]
    anno.setanno(node, anno.Static.DEFINITIONS, tuple(state.value.get(qn, ())))
    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Sets DEFINED_VARS_IN on `node`: every symbol defined by a predecessor."""
    analyzer = self.current_analyzer
    defined_in = set()
    for pred in analyzer.graph.stmt_prev[node]:
      defined_in.update(analyzer.out[pred].value)
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(defined_in))

  def _annotate_then_generic_visit(self, node):
    """Records DEFINED_VARS_IN on `node`, then visits its children."""
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_If(self, node):
    return self._annotate_then_generic_visit(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For: the loop target is visited with the CFG node
    # of `node.iter` as the current statement.
    saved_cfg_node = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = saved_cfg_node

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)
    return node

  def visit_While(self, node):
    return self._annotate_then_generic_visit(node)

  def visit_Try(self, node):
    return self._annotate_then_generic_visit(node)

  def visit_ExceptHandler(self, node):
    self._aggregate_predecessors_defined_in(node)
    # TODO(mdan): Also track the exception type / name symbols.
    node.body = self.visit_block(node.body)
    return node

  def visit(self, node):
    """Visits `node`, making its CFG node current if it has one."""
    saved_cfg_node = self.current_cfg_node
    analyzer = self.current_analyzer
    if analyzer is not None and node in analyzer.graph.index:
      self.current_cfg_node = analyzer.graph.index[node]
    result = super(TreeAnnotator, self).visit(node)
    self.current_cfg_node = saved_cfg_node
    return result

273 

274 

def resolve(node, source_info, graphs, definition_factory=Definition):
  """Annotates every symbol in `node` with its reaching definitions.

  Args:
    node: ast.AST, the tree to annotate.
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph], the CFG of each function
      definition in `node`.
    definition_factory: Callable[[], Definition], creates each new
      definition object; pass a factory to use a Definition subclass.

  Returns:
    ast.AST, the annotated tree.
  """
  annotator = TreeAnnotator(source_info, graphs, definition_factory)
  return annotator.visit(node)