1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Reaching definition analysis.
16
This analysis attaches a set of Definition objects to each symbol, one
18for each distinct definition that may reach it. The Definition objects are
19mutable and may be used by subsequent analyses to further annotate data like
20static type and value information.
21The analysis also attaches the set of the symbols defined at the entry of
22control flow statements.
23
24Requires activity analysis.
25"""
26
27import weakref
28
29import gast
30
31from tensorflow.python.autograph.pyct import anno
32from tensorflow.python.autograph.pyct import cfg
33from tensorflow.python.autograph.pyct import transformer
34
35
class Definition(object):
  """A single, unique definition of a variable.

  Instances define no custom equality, so two Definition objects always count
  as distinct definitions. Subclasses may be used by passing an appropriate
  factory function to resolve.

  Attributes:
    param_of: None unless the definition belongs to a parameter, in which case
      the analysis in this module stores a weakref.ref to the object that the
      scope's params mapping associates with the symbol (presumably the AST
      node owning the parameter - confirm against activity analysis).
    directives: Dict, optional definition annotations
  """

  def __init__(self):
    self.directives = {}
    self.param_of = None

  def __repr__(self):
    cls_name = type(self).__name__
    return '%s[%d]' % (cls_name, id(self))
53
54
55class _NodeState(object):
56 """Abstraction for the state of the CFG walk for reaching definition analysis.
57
58 This is a value type. Only implements the strictly necessary operators.
59
60 Attributes:
61 value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
62 their possible definitions
63 """
64
65 def __init__(self, init_from=None):
66 if init_from:
67 if isinstance(init_from, _NodeState):
68 self.value = {
69 s: set(other_infos) for s, other_infos in init_from.value.items()
70 }
71 elif isinstance(init_from, dict):
72 self.value = {s: set((init_from[s],)) for s in init_from}
73 else:
74 assert False, init_from
75 else:
76 self.value = {}
77
78 def __eq__(self, other):
79 if frozenset(self.value.keys()) != frozenset(other.value.keys()):
80 return False
81 ret = all(self.value[s] == other.value[s] for s in self.value)
82 return ret
83
84 def __ne__(self, other):
85 return not self.__eq__(other)
86
87 def __or__(self, other):
88 assert isinstance(other, _NodeState)
89 result = _NodeState(self)
90 for s, other_infos in other.value.items():
91 if s in result.value:
92 result.value[s].update(other_infos)
93 else:
94 result.value[s] = set(other_infos)
95 return result
96
97 def __sub__(self, other):
98 assert isinstance(other, set)
99 result = _NodeState(self)
100 for s in other:
101 result.value.pop(s, None)
102 return result
103
104 def __repr__(self):
105 return 'NodeState[%s]=%s' % (id(self), repr(self.value))
106
107
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level.

  Attributes:
    gen_map: Dict mapping each visited CFG node that carries a scope annotation
      to the _NodeState holding the definitions generated by that node.
  """

  def __init__(self, graph, definition_factory):
    """Creates the analyzer.

    Args:
      graph: the CFG to analyze; forwarded to cfg.GraphVisitor.
      definition_factory: Callable[[], Definition], called once for every
        symbol that a node defines.
    """
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    self.gen_map = {}

  def init_state(self, _):
    """Returns the initial, empty state used for every node."""
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the definitions flowing into and out of a CFG node.

    Args:
      node: the CFG node being visited.

    Returns:
      bool, True if the output state of the node changed during this visit.
    """
    prev_defs_out = self.out[node]

    # The definitions reaching this node are the union of what leaves each of
    # its predecessors.
    defs_in = _NodeState()
    for n in node.prev:
      defs_in |= self.out[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      # The definition objects created by each node must be singletons because
      # their ids are used in equality checks.
      if node not in self.gen_map:
        node_symbols = {}
        # Every binding operation (assign, nonlocal, global, etc.) counts as a
        # definition, with the exception of del, which only deletes without
        # creating a new variable.
        newly_defined = ((node_scope.bound | node_scope.globals) -
                         node_scope.deleted)
        for s in newly_defined:
          def_ = self._definition_factory()
          node_symbols[s] = def_
        # Every param receives a definition. Params are not necessarily
        # considered as "modified".
        for s, p in node_scope.params.items():
          def_ = self._definition_factory()
          def_.param_of = weakref.ref(p)
          node_symbols[s] = def_
        self.gen_map[node] = _NodeState(node_symbols)

      # Standard gen/kill transfer function: out = gen U (in - kill).
      gen = self.gen_map[node]
      kill = node_scope.modified | node_scope.deleted
      defs_out = gen | (defs_in - kill)

    else:
      assert self.can_ignore(node), (node.ast_node, node)
      defs_out = defs_in

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out
163
164
class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  While walking the tree, the visitor also runs the dataflow analysis on every
  function node it encounters, which accounts for the effect of closures. For
  example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here

  Attributes:
    definition_factory: Callable[[], Definition], handed to every Analyzer.
    graphs: mapping from function nodes to their CFGs.
    current_analyzer: the Analyzer of the function currently being visited, or
      None outside of any function.
    current_cfg_node: the CFG node of the statement currently being visited,
      or None.
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.allow_skips = False
    self.graphs = graphs
    self.definition_factory = definition_factory
    self.current_cfg_node = None
    self.current_analyzer = None

  def visit_FunctionDef(self, node):
    """Analyzes the function's CFG, then annotates its args and body."""
    enclosing_analyzer = self.current_analyzer

    fn_analyzer = Analyzer(self.graphs[node], self.definition_factory)
    fn_analyzer.visit_forward()

    # Recursively process any remaining subfunctions.
    self.current_analyzer = fn_analyzer
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    self.current_analyzer = enclosing_analyzer

    return node

  def visit_Name(self, node):
    """Attaches the definitions reaching a name as a DEFINITIONS annotation."""
    analyzer = self.current_analyzer
    if analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    cfg_node = self.current_cfg_node
    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    # Reads see the state entering the statement; every other context sees
    # the state leaving it.
    if isinstance(node.ctx, gast.Load):
      reaching = analyzer.in_[cfg_node]
    else:
      reaching = analyzer.out[cfg_node]
    anno.setanno(node, anno.Static.DEFINITIONS,
                 tuple(reaching.value.get(qn, ())))

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Annotates node with the symbols defined on exit of its predecessors."""
    analyzer = self.current_analyzer
    defined_in = set()
    for pred in analyzer.graph.stmt_prev[node]:
      # Updating from the dict adds its keys, i.e. the defined symbols.
      defined_in.update(analyzer.out[pred].value)
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(defined_in))

  def visit_If(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For: the loop target is visited as if it belonged to
    # the CFG node of the iterated expression.
    saved_cfg_node = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = saved_cfg_node

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_Try(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_ExceptHandler(self, node):
    self._aggregate_predecessors_defined_in(node)
    # TODO(mdan): Also track the exception type / name symbols.
    node.body = self.visit_block(node.body)
    return node

  def visit(self, node):
    """Visits node, tracking the CFG node it maps to for the duration."""
    saved_cfg_node = self.current_cfg_node

    analyzer = self.current_analyzer
    if analyzer is not None:
      index = analyzer.graph.index
      if node in index:
        self.current_cfg_node = index[node]
    node = super(TreeAnnotator, self).visit(node)

    self.current_cfg_node = saved_cfg_node
    return node
273
274
def resolve(node, source_info, graphs, definition_factory=Definition):
  """Annotates every symbol under node with its reaching definitions.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    definition_factory: Callable[[], Definition]
  Returns:
    ast.AST, the node returned by the annotating visitor
  """
  annotator = TreeAnnotator(source_info, graphs, definition_factory)
  return annotator.visit(node)