Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/ast_util.py: 17%
180 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST manipulation utilities."""
17import ast
19import gast
21from tensorflow.python.autograph.pyct import anno
22from tensorflow.python.autograph.pyct import parser
23from tensorflow.python.autograph.pyct import qual_names
class CleanCopier(object):
  """Deep-copies ASTs, in the style of a NodeTransformer.

  Only the attributes named in a node's `_fields` are copied, and of those,
  any whose name starts with '__' is skipped. Annotations are carried over
  only for the keys listed in `preserve_annos`.
  """

  def __init__(self, preserve_annos):
    super(CleanCopier, self).__init__()
    # Annotation keys to carry over to copied nodes; None/empty means none.
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of `node` (see copy_clean for what is excluded).

    Lists and tuples are copied element-wise into a new list/tuple. Values
    that are not AST nodes are returned as-is (shared, not duplicated).
    """
    if isinstance(node, list):
      return list(map(self.copy, node))
    if isinstance(node, tuple):
      return tuple(map(self.copy, node))
    if not isinstance(node, (gast.AST, ast.AST)):
      # Anything other than an AST node or a list/tuple is assumed to be a
      # value type that can simply be shared with the copy.
      return node

    copied_fields = {
        name: self.copy(getattr(node, name))
        for name in node._fields
        if not name.startswith('__') and hasattr(node, name)
    }
    duplicate = type(node)(**copied_fields)

    for key in self.preserve_annos or ():
      anno.copyanno(node, duplicate, key)
    return duplicate
def copy_clean(node, preserve_annos=None):
  """Returns a deep copy of an AST.

  Fields whose name is prefixed by '__' are left out of the copy, and
  annotations are dropped except for the keys the caller asks to keep.

  Args:
    node: ast.AST, or a list/tuple of them.
    preserve_annos: Optional[Set[Hashable]], annotation keys to carry over
      into the copy.
  Returns:
    ast.AST (or a list/tuple of copies, mirroring the input).
  """
  copier = CleanCopier(preserve_annos)
  return copier.copy(node)
class SymbolRenamer(gast.NodeTransformer):
  """Transformer that replaces symbols with simple names.

  `name_map` maps qual_names.QN keys to replacements; each replacement is
  converted with str() before being written into the AST. Name nodes (and
  Attribute nodes carrying a QN annotation) are looked up through their
  anno.Basic.QN annotation, so qual_names annotations are required. Function
  names and the names listed in `global`/`nonlocal` statements are wrapped
  in a QN directly.
  """

  def __init__(self, name_map):
    self.name_map = name_map

  def _process_name_node(self, node):
    """Returns a renamed gast.Name for `node`, or visits it unchanged."""
    qn = anno.getanno(node, anno.Basic.QN)
    if qn not in self.name_map:
      return self.generic_visit(node)

    renamed = gast.Name(
        str(self.name_map[qn]),
        ctx=node.ctx,
        annotation=None,
        type_comment=None)
    # The replacement inherits every annotation of the node it replaces.
    for key in anno.keys(node):
      anno.copyanno(node, renamed, key)
    return renamed

  def _process_list_of_strings(self, names):
    """Renames, in place, the entries of `names` present in the map."""
    for pos, name in enumerate(names):
      qn = qual_names.QN(name)
      if qn in self.name_map:
        names[pos] = str(self.name_map[qn])
    return names

  def _rename_declared_names(self, node):
    # Shared by `global` and `nonlocal`, which both hold a list of strings.
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Nonlocal(self, node):
    return self._rename_declared_names(node)

  def visit_Global(self, node):
    return self._rename_declared_names(node)

  def visit_Name(self, node):
    return self._process_name_node(node)

  def visit_Attribute(self, node):
    if not anno.hasanno(node, anno.Basic.QN):
      # Renaming attributes is not supported.
      return self.generic_visit(node)
    return self._process_name_node(node)

  def visit_FunctionDef(self, node):
    qn = qual_names.QN(node.name)
    if qn in self.name_map:
      node.name = str(self.name_map[qn])
    return self.generic_visit(node)
def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Args:
    node: a single AST node, or a list or tuple of them.
    name_map: mapping from qual_names.QN to the new name; each new name is
      converted with str().
  Returns:
    The transformed node, or a list/tuple (matching the input container) of
    transformed nodes.
  """
  renamer = SymbolRenamer(name_map)
  if not isinstance(node, (list, tuple)):
    return renamer.visit(node)
  renamed = [renamer.visit(n) for n in node]
  if isinstance(node, tuple):
    return tuple(renamed)
  return renamed
def keywords_to_dict(keywords):
  """Converts a list of ast.keyword objects to a dict.

  Args:
    keywords: Iterable of keyword nodes (e.g. the `keywords` field of a Call).
  Returns:
    gast.Dict whose keys are string constants holding each keyword's name and
    whose values are the keyword values. A `**expr` keyword (whose `arg` is
    None) becomes a bare None key, which in a Dict node denotes `**expr`
    unpacking, so the keywords of `f(a=1, **b)` map to `{'a': 1, **b}`.
  """
  keys = []
  values = []
  for kw in keywords:
    # kw.arg is None for `**expr`. Emitting Constant(None) here would build
    # the literal `{None: expr}`; a bare None key means dict unpacking.
    keys.append(None if kw.arg is None else gast.Constant(kw.arg, kind=None))
    values.append(kw.value)
  return gast.Dict(keys=keys, values=values)
class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node.

  The outcome is accumulated in `self.matches`. It starts as True and is set
  to False by the first mismatch found, after which further visits return
  early.
  """

  def __init__(self, pattern):
    self.pattern = pattern
    # Patterns of the enclosing nodes, saved while a child is being compared.
    self.pattern_stack = []
    self.matches = True

  def compare_and_visit(self, node, pattern):
    """Compares `node` against `pattern`, restoring the current pattern."""
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    """Records a mismatch; returns False for convenient early returns."""
    self.matches = False
    return False

  def is_wildcard(self, p):
    """True for '_', a Name with id '_', or a 1-element list/tuple of those."""
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p, = p
    if isinstance(p, gast.Name) and p.id == '_':
      return True
    if p == '_':
      return True
    return False

  def _match_value(self, v, p):
    """Matches one field value or list element `v` against pattern `p`.

    Returns True on a match; on a mismatch, records it and returns False.
    """
    if self.is_wildcard(p):
      return True
    if isinstance(v, (gast.AST, ast.AST)):
      # Node types must be compatible. Without this, fieldless nodes such as
      # comparison operators (e.g. Lt vs Gt) would always compare equal.
      if not isinstance(v, type(p)) and not isinstance(p, type(v)):
        return self.no_match()
      self.compare_and_visit(v, p)
      return self.matches
    # Assume everything else (str, int, None, ...) is a value type.
    if v != p:
      return self.no_match()
    return True

  def generic_visit(self, node):
    if not self.matches:
      return

    pattern = self.pattern
    for f in node._fields:
      if f.startswith('__'):
        continue

      if not hasattr(node, f):
        if hasattr(pattern, f) and getattr(pattern, f):
          return self.no_match()
        continue
      if not hasattr(pattern, f):
        return self.no_match()

      v = getattr(node, f)
      p = getattr(pattern, f)

      if self.is_wildcard(p):
        continue
      if isinstance(v, (list, tuple)):
        if not isinstance(p, (list, tuple)) or len(v) != len(p):
          return self.no_match()
        # List elements get the same wildcard/type/value treatment as single
        # fields; non-AST elements (e.g. None in Dict keys, strings in
        # Global names) are compared by value instead of being visited.
        for v_item, p_item in zip(v, p):
          if not self._match_value(v_item, p_item):
            return False
      elif not self._match_value(v, p):
        return False
def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards written as the symbol '_'. A node matches
  a pattern if every node in its tree corresponds either to a node of the
  same type in the pattern or to a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: ast.AST, or a str that is first parsed with parser.parse_str.
  Returns:
    bool
  """
  pattern_node = (
      parser.parse_str(pattern) if isinstance(pattern, str) else pattern)
  matcher = PatternMatcher(pattern_node)
  matcher.visit(node)
  return matcher.matches
# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(targets, values, apply_fn):
  """Calls `apply_fn` once per individual (target, value) assignment.

  Breaks a possibly-unpacking assignment (e.g. `a, b = c, d`) down into
  single assignments where it can. This amounts to handing `apply_fn` the
  assigned values in SSA-like form.

  Examples:

    `a, b = c, d` results in apply_fn(a, c), apply_fn(b, d).

    `a, b = c` results in apply_fn(a, c[0]), apply_fn(b, c[1]).

    `a = b, c` results in apply_fn(a, (b, c)).

  Tuple/List targets are unpacked recursively. When the value is itself a
  Tuple/List literal, its elements are paired positionally. Otherwise a new
  `values[i]` Subscript node is synthesized for each target element.

  NOTE(review): Starred targets (e.g. `a, *b = c`) are not special-cased, so
  each is paired with a single positional element. A literal value with
  fewer elements than the target raises IndexError.

  Args:
    targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST]; intended to
      be the targets field of an ast.Assign node.
    values: ast.AST, the assigned value.
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the target and
      value nodes of each single assignment.
  """
  target_list = targets if isinstance(targets, (list, tuple)) else (targets,)
  for target in target_list:
    if not isinstance(target, (gast.Tuple, gast.List)):
      apply_fn(target, values)
      continue

    values_are_literal = isinstance(values, (gast.Tuple, gast.List))
    for i, target_el in enumerate(target.elts):
      if values_are_literal:
        value_el = values.elts[i]
      else:
        idx = parser.parse_expression(str(i))
        value_el = gast.Subscript(values, idx, ctx=gast.Load())
      apply_to_single_assignments(target_el, value_el, apply_fn)
def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure. Traversal uses an explicit
  stack: each pair is yielded before its children are examined, and sibling
  children come out in reverse order. Because of this, some pairs may
  already have been yielded when a structural mismatch is detected deeper
  in the tree. Besides AST nodes, yielded pairs may be strings or None when
  those appear as list elements (e.g. names of a Global statement).

  Args:
    node: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
    other: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
  Yields:
    Tuple[ast.AST, ast.AST]
  Raises:
    ValueError: if the two trees don't have identical structure.
  """
  node_stack = list(node) if isinstance(node, (list, tuple)) else [node]
  other_stack = list(other) if isinstance(other, (list, tuple)) else [other]

  # Checked explicitly: a plain assert would be stripped under -O, and a
  # mismatch such as [] vs [x] used to end the walk silently.
  if len(node_stack) != len(other_stack):
    raise ValueError('inconsistent number of nodes: {} and {}'.format(
        len(node_stack), len(other_stack)))

  while node_stack and other_stack:
    # Invariant: children are only pushed in equal-length batches.
    assert len(node_stack) == len(other_stack)
    n = node_stack.pop()
    o = other_stack.pop()

    if ((not isinstance(n, (ast.AST, gast.AST, str)) and n is not None) or
        (not isinstance(o, (ast.AST, gast.AST, str)) and o is not None) or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue
    if n is None:
      assert o is None, 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      # NOTE(review): a field that is None on only one side is skipped
      # rather than reported as a mismatch; confirm this is intended.
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        node_stack.extend(n_child)
        other_stack.extend(o_child)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))