Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/ast_util.py: 17%

180 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""AST manipulation utilities.""" 

16 

17import ast 

18 

19import gast 

20 

21from tensorflow.python.autograph.pyct import anno 

22from tensorflow.python.autograph.pyct import parser 

23from tensorflow.python.autograph.pyct import qual_names 

24 

25 

class CleanCopier(object):
  """Deep-copies ASTs, skipping '__'-prefixed fields.

  Behaves like a NodeTransformer that rebuilds every node it visits. Selected
  annotations can optionally be carried over to the copied nodes.
  """

  def __init__(self, preserve_annos):
    """Initializes the copier.

    Args:
      preserve_annos: Optional iterable of annotation keys; when truthy, each
        listed annotation is copied from every original node to its copy.
    """
    super(CleanCopier, self).__init__()
    self.preserve_annos = preserve_annos

  def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean)."""
    if isinstance(node, list):
      return [self.copy(item) for item in node]
    if isinstance(node, tuple):
      return tuple(self.copy(item) for item in node)
    if not isinstance(node, (gast.AST, ast.AST)):
      # Anything that is not an AST, list or tuple is treated as a value type
      # that can be shared between the original and the copy.
      return node

    copied_fields = {
        field: self.copy(getattr(node, field))
        for field in node._fields
        if not field.startswith('__') and hasattr(node, field)
    }
    new_node = type(node)(**copied_fields)

    for key in self.preserve_annos or ():
      anno.copyanno(node, new_node, key)
    return new_node

57 

58 

def copy_clean(node, preserve_annos=None):
  """Returns a deep copy of an AST.

  Fields whose names start with '__' are left out of the copy; the only
  annotations carried over are the user-specified ones in `preserve_annos`.

  Args:
    node: ast.AST (lists and tuples of nodes are copied element-wise)
    preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
      copy
  Returns:
    ast.AST
  """
  copier = CleanCopier(preserve_annos)
  return copier.copy(node)

73 

74 

class SymbolRenamer(gast.NodeTransformer):
  """Transformer that renames symbols to simple names.

  `name_map` maps qualified names (qual_names.QN) to their replacements; each
  replacement is converted with str(). Name and Attribute nodes are looked up
  through their anno.Basic.QN annotation.
  """

  def __init__(self, name_map):
    self.name_map = name_map

  def _process_name_node(self, node):
    # NOTE(review): reads the QN annotation unconditionally, so Name nodes are
    # presumably expected to be qual_names-annotated.
    qn = anno.getanno(node, anno.Basic.QN)
    if qn not in self.name_map:
      return self.generic_visit(node)

    renamed = gast.Name(
        str(self.name_map[qn]),
        ctx=node.ctx,
        annotation=None,
        type_comment=None)
    # Every annotation of the original node is carried over to the new one.
    for key in anno.keys(node):
      anno.copyanno(node, renamed, key)
    return renamed

  def _process_list_of_strings(self, names):
    # Renames in place; the same list object is returned.
    for idx, name in enumerate(names):
      qn = qual_names.QN(name)
      if qn in self.name_map:
        names[idx] = str(self.name_map[qn])
    return names

  def _rename_declared_names(self, node):
    node.names = self._process_list_of_strings(node.names)
    return node

  def visit_Nonlocal(self, node):
    return self._rename_declared_names(node)

  def visit_Global(self, node):
    return self._rename_declared_names(node)

  def visit_Name(self, node):
    return self._process_name_node(node)

  def visit_Attribute(self, node):
    # Only attributes that carry a qualified name can be renamed; plain
    # attribute renaming is not supported, so just recurse.
    if not anno.hasanno(node, anno.Basic.QN):
      return self.generic_visit(node)
    return self._process_name_node(node)

  def visit_FunctionDef(self, node):
    qn = qual_names.QN(node.name)
    if qn in self.name_map:
      node.name = str(self.name_map[qn])
    return self.generic_visit(node)

124 

125 

def rename_symbols(node, name_map):
  """Renames symbols in an AST. Requires qual_names annotations.

  Args:
    node: ast.AST, or a list/tuple of nodes (the container type is preserved)
    name_map: mapping from qualified names to replacement names
  Returns:
    The transformed node(s).
  """
  renamer = SymbolRenamer(name_map)
  if isinstance(node, (list, tuple)):
    renamed = [renamer.visit(n) for n in node]
    return renamed if isinstance(node, list) else tuple(renamed)
  return renamer.visit(node)

134 

135 

def keywords_to_dict(keywords):
  """Converts a list of ast.keyword objects to a gast.Dict node.

  Each `name=value` keyword becomes a `'name': value` entry. A `**mapping`
  keyword (whose `arg` is None) becomes a `**mapping` unpacking entry, which
  Dict nodes represent with a None key.

  Args:
    keywords: Iterable[ast.keyword]
  Returns:
    gast.Dict
  """
  keys = []
  values = []
  for kw in keywords:
    if kw.arg is None:
      # `f(**kw)`: a Constant(None) key would mean the literal `{None: kw}`;
      # a bare None key denotes dictionary unpacking, i.e. `{**kw}`.
      keys.append(None)
    else:
      keys.append(gast.Constant(kw.arg, kind=None))
    values.append(kw.value)
  return gast.Dict(keys=keys, values=values)

144 

145 

class PatternMatcher(gast.NodeVisitor):
  """Matches a node against a pattern represented by a node.

  After visiting a node, `matches` is True iff the node matched `pattern`.
  The pattern may use `_` (a Name with id '_', or the string '_') as a
  wildcard. A single-element list containing a wildcard matches any list.
  """

  def __init__(self, pattern):
    self.pattern = pattern
    self.pattern_stack = []
    self.matches = True

  def compare_and_visit(self, node, pattern):
    # Temporarily swap in the sub-pattern while visiting the sub-node.
    self.pattern_stack.append(self.pattern)
    self.pattern = pattern
    self.generic_visit(node)
    self.pattern = self.pattern_stack.pop()

  def no_match(self):
    self.matches = False
    return False

  def is_wildcard(self, p):
    if isinstance(p, (list, tuple)) and len(p) == 1:
      p, = p
    if isinstance(p, gast.Name) and p.id == '_':
      return True
    if p == '_':
      return True
    return False

  def _compare_value(self, v, p):
    """Compares one field value, or one list element, against its pattern.

    The same rules apply at every level: wildcards match anything, lists must
    have equal lengths and match element-wise, AST nodes must have compatible
    types and are compared recursively, and anything else is compared as a
    value. Mismatches are recorded through no_match().
    """
    if self.is_wildcard(p):
      return
    if isinstance(v, (list, tuple)):
      if not isinstance(p, (list, tuple)) or len(v) != len(p):
        self.no_match()
        return
      for v_item, p_item in zip(v, p):
        # Elements go through the full comparison too: they may be non-AST
        # values (e.g. strings in Global.names, None in Dict.keys), and AST
        # elements must be type-checked (e.g. Lt vs Gt in Compare.ops).
        self._compare_value(v_item, p_item)
        if not self.matches:
          return
    elif isinstance(v, (gast.AST, ast.AST)):
      if not isinstance(v, type(p)) and not isinstance(p, type(v)):
        self.no_match()
        return
      self.compare_and_visit(v, p)
    elif v != p:
      # Assume everything else is a value type.
      self.no_match()

  def generic_visit(self, node):
    if not self.matches:
      return

    pattern = self.pattern
    for f in node._fields:
      if f.startswith('__'):
        continue

      if not hasattr(node, f):
        # A field missing from the node only matches an empty/falsy field
        # in the pattern.
        if hasattr(pattern, f) and getattr(pattern, f):
          return self.no_match()
        continue
      if not hasattr(pattern, f):
        return self.no_match()

      self._compare_value(getattr(node, f), getattr(pattern, f))
      if not self.matches:
        return

208 

209 

def matches(node, pattern):
  """Basic pattern matcher for AST.

  The pattern may contain wildcards represented by the symbol '_'. A node
  matches a pattern if for every node in the tree, either there is a node of
  the same type in pattern, or a Name node with id='_'.

  Args:
    node: ast.AST
    pattern: Union[ast.AST, str]; a string is first parsed with
      parser.parse_str
  Returns:
    bool
  """
  pattern_node = (
      parser.parse_str(pattern) if isinstance(pattern, str) else pattern)
  matcher = PatternMatcher(pattern_node)
  matcher.visit(node)
  return matcher.matches

229 

230 

# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(targets, values, apply_fn):
  """Applies a function to each individual assignment.

  This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
  It tries to break down the unpacking if possible. In effect, it has the same
  effect as passing the assigned values in SSA form to apply_fn.

  Examples:

  The following will result in apply_fn(a, c), apply_fn(b, d):

      a, b = c, d

  The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

      a, b = c

  The following will result in apply_fn(a, (b, c)):

      a = b, c

  It uses the visitor pattern to allow subclasses to process single
  assignments individually.

  Args:
    targets: Union[List[ast.AST], Tuple[ast.AST, ...], ast.AST], should be
      used with the targets field of an ast.Assign node
    values: ast.AST
    apply_fn: Callable[[ast.AST, ast.AST], None], called with the
      respective nodes of each single assignment
  """
  if not isinstance(targets, (list, tuple)):
    targets = (targets,)

  for target in targets:
    if not isinstance(target, (gast.Tuple, gast.List)):
      apply_fn(target, values)
      continue

    # When the right-hand side is itself a literal tuple/list, pair elements
    # directly; otherwise index into it with synthesized subscripts.
    values_are_literal = isinstance(values, (gast.Tuple, gast.List))
    for i, target_el in enumerate(target.elts):
      if values_are_literal:
        value_el = values.elts[i]
      else:
        value_el = gast.Subscript(
            values, parser.parse_expression(str(i)), ctx=gast.Load())
      apply_to_single_assignments(target_el, value_el, apply_fn)

277 

278 

def parallel_walk(node, other):
  """Walks two ASTs in parallel.

  The two trees must have identical structure. Lists and tuples passed as
  `node`/`other` are walked element-wise; any other value is treated as a
  single root. Each pair is yielded before its children are examined, in
  depth-first (stack) order. Fields that are None on either side are skipped.

  Args:
    node: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
    other: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]]
  Yields:
    Tuple[ast.AST, ast.AST]
  Raises:
    ValueError: if the two trees don't have identical structure.
  """
  node_stack = list(node) if isinstance(node, (list, tuple)) else [node]
  other_stack = list(other) if isinstance(other, (list, tuple)) else [other]

  # Top-level sequences of different lengths cannot be paired up. Report this
  # as the documented ValueError instead of tripping the assert below (or, if
  # one side is empty, silently yielding nothing).
  if len(node_stack) != len(other_stack):
    raise ValueError('inconsistent nodes: {} and {}'.format(node, other))

  while node_stack and other_stack:
    # Children are always pushed in equal-length pairs, so the stacks stay
    # aligned.
    assert len(node_stack) == len(other_stack)
    n = node_stack.pop()
    o = other_stack.pop()

    # Each side must be an AST node, a string or None, and both sides must be
    # of the same kind.
    if ((not isinstance(n, (ast.AST, gast.AST, str)) and n is not None) or
        (not isinstance(o, (ast.AST, gast.AST, str)) and o is not None) or
        n.__class__.__name__ != o.__class__.__name__):
      raise ValueError('inconsistent nodes: {} ({}) and {} ({})'.format(
          n, n.__class__.__name__, o, o.__class__.__name__))

    yield n, o

    if isinstance(n, str):
      assert isinstance(o, str), 'The check above should have ensured this'
      continue
    if n is None:
      assert o is None, 'The check above should have ensured this'
      continue

    for f in n._fields:
      n_child = getattr(n, f, None)
      o_child = getattr(o, f, None)
      if f.startswith('__') or n_child is None or o_child is None:
        continue

      if isinstance(n_child, (list, tuple)):
        if (not isinstance(o_child, (list, tuple)) or
            len(n_child) != len(o_child)):
          raise ValueError(
              'inconsistent values for field {}: {} and {}'.format(
                  f, n_child, o_child))
        node_stack.extend(n_child)
        other_stack.extend(o_child)

      elif isinstance(n_child, (gast.AST, ast.AST)):
        node_stack.append(n_child)
        other_stack.append(o_child)

      elif n_child != o_child:
        # Scalar field values (identifiers, constants, ...) must be equal.
        raise ValueError(
            'inconsistent values for field {}: {} and {}'.format(
                f, n_child, o_child))