1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""An analysis that determines the reach of a function definition.
16
17A function definition is said to reach a statement if that function may exist
18(and therefore may be called) when that statement executes.
19"""
20
21import gast
22
23from tensorflow.python.autograph.pyct import anno
24from tensorflow.python.autograph.pyct import cfg
25from tensorflow.python.autograph.pyct import transformer
26
27
class Definition(object):
  """Describes one unique definition of a function.

  Attributes:
    def_node: the node this definition was constructed from, stored as given.
  """

  def __init__(self, def_node):
    self.def_node = def_node
33
34
class _NodeState(object):
  """State of the CFG walk for the reaching function definitions analysis.

  This is a value type: `|` and `+` return new instances and leave their
  operands untouched. Only implements the strictly necessary operators.

  Attributes:
    value: Set of the definitions that reach a node. Its elements are the
      items passed as `init_from` plus anything added with `+`.
  """

  def __init__(self, init_from=None):
    # Always copy, so that a state never shares its set with its source.
    self.value = set(init_from) if init_from else set()

  def __eq__(self, other):
    # Defer to the other operand for foreign types instead of failing on a
    # missing `value` attribute.
    if not isinstance(other, _NodeState):
      return NotImplemented
    return self.value == other.value

  def __ne__(self, other):
    if not isinstance(other, _NodeState):
      return NotImplemented
    return self.value != other.value

  def __or__(self, other):
    """Returns a new state holding the union of both states."""
    assert isinstance(other, _NodeState)
    result = _NodeState(self.value)
    result.value.update(other.value)
    return result

  def __add__(self, value):
    """Returns a new state that additionally contains `value`."""
    result = _NodeState(self.value)
    result.value.add(value)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
70
71
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that computes reaching function definitions per statement.

  After a walk, `in_[n]` and `out[n]` hold `_NodeState` values with the
  definitions reaching CFG node `n` on entry and on exit, respectively.
  """

  def __init__(self, graph, external_defs):
    super(Analyzer, self).__init__(graph)
    # Definitions that reach the graph's entry node from outside the graph,
    # e.g. those that a function closes over.
    self.external_defs = external_defs

  def init_state(self, _):
    # Every node starts out with no reaching definitions.
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the states of `node`.

    Args:
      node: the CFG node to process.
    Returns:
      True if the output state of the node changed, False otherwise.
    """
    previous_out = self.out[node]

    # The entry node is seeded with the external definitions; any other node
    # starts from its own previous output.
    incoming = (
        _NodeState(self.external_defs)
        if node is self.graph.entry else previous_out)

    # Merge in whatever flows out of each predecessor.
    for predecessor in node.prev:
      incoming = incoming | self.out[predecessor]

    # A function or lambda node contributes itself as a new definition.
    outgoing = incoming
    if isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef)):
      outgoing = outgoing + node.ast_node

    self.in_[node] = incoming
    self.out[node] = outgoing

    return previous_out != outgoing
103
104
class TreeAnnotator(transformer.Base):
  """AST visitor that annotates nodes with the function definitions reaching them.

  The dataflow analysis is run as each function or lambda node is entered, and
  it is seeded with the definitions that reach that node in the enclosing
  function, which accounts for the effect of closures. For example:

    def foo():
      def f():
        pass
      def g():
        # `def f` reaches here
  """

  def __init__(self, source_info, graphs):
    super(TreeAnnotator, self).__init__(source_info)
    self.graphs = graphs
    self.allow_skips = False
    # Analyzer of the function currently being traversed, if any.
    self.current_analyzer = None

  def _proces_function(self, node):
    """Analyzes a function or lambda node, then annotates its contents."""
    enclosing = self.current_analyzer
    fn_graph = self.graphs[node]

    # Definitions that reach this node in the enclosing function, if that
    # function's graph knows about the node, are passed on as external ones.
    closure_defs = ()
    if enclosing is not None and node in enclosing.graph.index:
      outer_cfg_node = enclosing.graph.index[node]
      closure_defs = enclosing.in_[outer_cfg_node].value

    fn_analyzer = Analyzer(fn_graph, closure_defs)
    fn_analyzer.visit_forward()

    # The new analyzer is only in effect while visiting the function's
    # children; the enclosing one is reinstated afterwards.
    self.current_analyzer = fn_analyzer
    node = self.generic_visit(node)
    self.current_analyzer = enclosing
    return node

  def visit_FunctionDef(self, node):
    return self._proces_function(node)

  def visit_Lambda(self, node):
    return self._proces_function(node)

  def visit(self, node):
    analyzer = self.current_analyzer

    # There is no analyzer before entering the top level function.
    if analyzer is not None and node in analyzer.graph.index:
      reaching = analyzer.in_[analyzer.graph.index[node]].value
      anno.setanno(node, anno.Static.DEFINED_FNS_IN, reaching)

    # A node may carry an extra loop test node as an annotation; that node is
    # looked up in the current graph and annotated as well.
    loop_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
    if loop_test is not None:
      test_cfg_node = self.current_analyzer.graph.index[loop_test]
      anno.setanno(loop_test, anno.Static.DEFINED_FNS_IN,
                   self.current_analyzer.in_[test_cfg_node].value)

    return super(TreeAnnotator, self).visit(node)
164
165
def resolve(node, source_info, graphs):
  """Annotates an AST with the function definitions reaching its nodes.

  The work is delegated to a `TreeAnnotator` that visits `node`.

  Args:
    node: ast.AST, the root of the tree to annotate
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph], looked up by function node
  Returns:
    ast.AST, the node returned by the visitor
  """
  return TreeAnnotator(source_info, graphs).visit(node)