Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/templates.py: 22%

144 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""AST conversion templates. 

16 

17Adapted from Tangent. 

18""" 

19 

20import ast 

21import textwrap 

22 

23import gast 

24 

25from tensorflow.python.autograph.pyct import anno 

26from tensorflow.python.autograph.pyct import ast_util 

27from tensorflow.python.autograph.pyct import parser 

28from tensorflow.python.autograph.pyct import qual_names 

29 

30 

class ContextAdjuster(gast.NodeTransformer):
  """Rewrites the `ctx` field of nodes so it matches how they are used.

  The adjuster carries a "context override": either None (leave `ctx` alone)
  or a callable that builds a context object, such as `gast.Load`. Node types
  that accept a `ctx` receive a freshly built one when an override is active,
  and some node types change the override that their children see.
  """

  def __init__(self, override_value):
    # Either None or a zero-argument callable producing a ctx object.
    self._ctx_override = override_value

  def visit(self, node):
    # The type-specific visitors below may reassign the override for their
    # subtree; remember the incoming value so it can be put back afterwards.
    saved_override = self._ctx_override
    visited = super(ContextAdjuster, self).visit(node)
    if hasattr(visited, 'ctx'):
      assert visited.ctx is not None, 'node {} has ctx unset'.format(visited)
    self._ctx_override = saved_override
    return visited

  def _apply_override(self, node):
    """Sets `node.ctx` from the active override; no-op when none is active."""
    make_ctx = self._ctx_override
    if make_ctx is None:
      return
    node.ctx = make_ctx()

  def _override_and_descend(self, node):
    """Applies the override to `node`, then visits its children unchanged."""
    self._apply_override(node)
    return self.generic_visit(node)

  def _descend_without_override(self, node):
    """Visits the children of `node` with the override switched off."""
    # We may be able to override these to Load(), but for now it's simpler
    # to just assert that they're set.
    self._ctx_override = None
    return self.generic_visit(node)

  def visit_Attribute(self, node):
    self._apply_override(node)
    # Everything underneath the attribute is visited with a Load override.
    self._ctx_override = gast.Load
    return self.generic_visit(node)

  def visit_Tuple(self, node):
    return self._override_and_descend(node)

  def visit_List(self, node):
    return self._override_and_descend(node)

  def visit_Name(self, node):
    return self._override_and_descend(node)

  def visit_Call(self, node):
    self._apply_override(node)
    return self._descend_without_override(node)

  def visit_Dict(self, node):
    return self._descend_without_override(node)

  def visit_Subscript(self, node):
    self._apply_override(node)
    # The value and the remaining children are visited with a Load override.
    self._ctx_override = gast.Load
    node.value = self.visit(node.value)
    return self.generic_visit(node)

  def visit_comprehension(self, node):
    return self._descend_without_override(node)

  def visit_Lambda(self, node):
    return self._descend_without_override(node)

102 

103 

class ReplaceTransformer(gast.NodeTransformer):
  """Replace AST nodes.

  Placeholders are matched by name against the keys of `replacements`:
  `Name` ids, keyword argument names, function names and attribute names are
  each handled by their own visitor.
  """

  def __init__(self, replacements):
    """Create a new ReplaceTransformer.

    Args:
      replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by.
    """
    self.replacements = replacements
    self.in_replacements = False
    # Annotation keys handed to ast_util.copy_clean as `preserve_annos` when
    # a replacement is copied.
    self.preserved_annos = {
        anno.Basic.DIRECTIVES,
        anno.Basic.EXTRA_LOOP_TEST,
        anno.Basic.ORIGIN,
        anno.Basic.SKIP_PROCESSING,
        anno.Static.ORIG_DEFINITIONS,
        'function_context_name',
    }

  def _prepare_replacement(self, replaced, key):
    """Prepares a replacement AST that's safe to swap in for a node.

    Args:
      replaced: ast.AST, the node being replaced. Currently unused.
      key: Hashable, the key of the replacement AST
    Returns:
      The copied replacement. A single gast node is wrapped in a one-element
      list; any other value is returned as produced by the copy.
    """
    repl = self.replacements[key]

    new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos)
    if isinstance(new_nodes, gast.AST):
      new_nodes = [new_nodes]

    return new_nodes

  def visit_Expr(self, node):
    # When replacing a placeholder with an entire statement, the replacement
    # must stand on its own and not be wrapped in an Expr.
    new_value = self.visit(node.value)
    if new_value is node.value:
      return node
    return new_value

  def visit_keyword(self, node):
    """Replaces a keyword argument whose name is a placeholder.

    Args:
      node: the keyword node being visited.
    Returns:
      The node with its children visited when its name is not a placeholder,
      otherwise the non-empty sequence of replacement keywords.
    Raises:
      ValueError: if the replacement is not a non-empty list or tuple made up
        entirely of keyword nodes.
    """
    if node.arg not in self.replacements:
      return self.generic_visit(node)

    # _prepare_replacement wraps every gast node in a list, so a replacement
    # given as a single keyword shows up here as a one-element list; there is
    # no separate bare-keyword case to handle.
    repl = self._prepare_replacement(node, node.arg)
    if (repl and isinstance(repl, (list, tuple)) and
        all(isinstance(r, gast.keyword) for r in repl)):
      return repl
    # TODO(mdan): We may allow replacing with a string as well.
    # For example, if one wanted to replace foo with bar in foo=baz, then
    # we could allow changing just node arg, so that we end up with bar=baz.
    raise ValueError(
        'a keyword argument may only be replaced by another keyword or a '
        'non-empty list of keywords. Found: {} for keyword {}'.format(
            repl, node.arg))

  def visit_FunctionDef(self, node):
    """Renames a function whose name is a placeholder.

    Args:
      node: the function definition being visited.
    Returns:
      The same node, with its children visited and, if its name is a
      placeholder, `name` set to the id of the replacement.
    Raises:
      ValueError: if the replacement is not a Name node.
    """
    node = self.generic_visit(node)
    if node.name not in self.replacements:
      return node

    repl = self.replacements[node.name]
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'a function name can only be replaced by a Name node. Found: %s' %
          repl)
    node.name = repl.id
    return node

  def visit_Attribute(self, node):
    """Renames an attribute whose name is a placeholder.

    Args:
      node: the attribute node being visited.
    Returns:
      The same node, with its children visited and, if its attribute name is
      a placeholder, `attr` set to the id of the replacement.
    Raises:
      ValueError: if the replacement is not a Name node.
    """
    node = self.generic_visit(node)
    if node.attr not in self.replacements:
      return node

    repl = self.replacements[node.attr]
    # Accept the same Name types as visit_FunctionDef; only `id` is read.
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'An attribute can only be replaced by a Name node. Found: %s' % repl)
    node.attr = repl.id
    return node

  def visit_Name(self, node):
    """Swaps a placeholder name for a copy of its replacement.

    Args:
      node: the Name node being visited.
    Returns:
      The node itself when its id is not a placeholder; otherwise the prepared
      replacement, unwrapped to a single node when it has exactly one element.
      An empty replacement is returned as is.
    """
    if node.id not in self.replacements:
      return node

    new_nodes = self._prepare_replacement(node, node.id)

    if not new_nodes:
      return new_nodes

    # Preserve the target context.
    adjuster = ContextAdjuster(type(node.ctx))
    for n in new_nodes:
      if hasattr(n, 'ctx'):
        adjuster.visit(n)

    if len(new_nodes) == 1:
      new_nodes, = new_nodes

    return new_nodes

212 

213 

def _convert_to_ast(n):
  """Turns a supported shorthand value into AST form.

  Strings become `Name` nodes, qualified names are converted through their
  own `ast()` method, and lists and tuples are converted element by element.
  Any other value is handed back untouched.

  Args:
    n: the value to convert.
  Returns:
    The converted value, or `n` itself when no conversion applies.
  """
  # Note: When generating AST nodes from strings/QNs in isolation, ctx is
  # unknown. ctx must be filled in according to the template being used.
  # See ReplaceTransformer.visit_Name.
  if isinstance(n, str):
    return gast.Name(id=n, ctx=None, annotation=None, type_comment=None)
  if isinstance(n, qual_names.QN):
    return n.ast()
  if isinstance(n, list):
    return list(map(_convert_to_ast, n))
  if isinstance(n, tuple):
    return tuple(map(_convert_to_ast, n))
  return n

228 

229 

def replace(template, **replacements):
  """Replaces placeholders in a Python template.

  AST Name and Tuple nodes always receive the context that is inferred from
  the template. However, when replacing more complex nodes (that can
  potentially contain Name children), the caller is responsible for setting
  the appropriate context.

  Args:
    template: A string representing Python code. Any symbol name that appears
      in the template code can be used as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
      that these placeholders will be replaced by. String values are also
      supported as a shorthand for AST Name nodes with the respective ID.

  Returns:
    A list of AST nodes with the replacements made. Every top-level node
    parsed from the template contributes its transformed result; a result
    that is itself a list or tuple is flattened into the returned list.

  Raises:
    ValueError: if the arguments are incorrect, for example when the template
      is not a string.
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))

  converted = {
      name: _convert_to_ast(value) for name, value in replacements.items()
  }
  template_str = parser.STANDARD_PREAMBLE + textwrap.dedent(template)
  nodes = parser.parse(
      template_str,
      preamble_len=parser.STANDARD_PREAMBLE_LEN,
      single_node=False)

  # The transformer does not depend on the node being visited, so a single
  # instance serves every top-level node.
  transformer = ReplaceTransformer(converted)
  results = []
  for node in nodes:
    replaced = transformer.visit(node)
    if isinstance(replaced, (list, tuple)):
      results.extend(replaced)
    else:
      results.append(replaced)
  return [qual_names.resolve(r) for r in results]

273 

274 

def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks.

  Args:
    template: forwarded to `replace`.
    **replacements: forwarded to `replace`.

  Returns:
    The `value` of the single generated node when it is an `Expr`, or the
    node itself when it is a `Name`.

  Raises:
    ValueError: if the template does not produce exactly one node, or if that
      node is neither an `Expr` nor a `Name`.
  """
  produced = replace(template, **replacements)
  if len(produced) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')

  result = produced[0]
  if isinstance(result, gast.Expr):
    # Unwrap the statement to get at the expression it carries.
    return result.value
  if isinstance(result, gast.Name):
    return result

  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % result)

290 ' instead found %s' % node)