Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/parser.py: 17%

192 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Converting code to AST. 

16 

17Adapted from Tangent. 

18""" 

19 

20import ast 

21import inspect 

22import io 

23import linecache 

24import re 

25import sys 

26import textwrap 

27import tokenize 

28 

29import astunparse 

30import gast 

31 

32from tensorflow.python.autograph.pyct import errors 

33from tensorflow.python.autograph.pyct import inspect_utils 

34from tensorflow.python.util import tf_inspect 

35 

36 

# Source preambles prepended before parsing; on Python 3 no __future__
# imports are needed, so the preamble is empty.
PY2_PREAMBLE = textwrap.dedent("""
""")
PY3_PREAMBLE = ''

# From Python 3.9, the stdlib `ast` module provides `unparse`, so it can
# stand in for the third-party astunparse package (see `unparse` below,
# which checks `astunparse is ast`).
if sys.version_info >= (3, 9):
  astunparse = ast

# Largest representable line number, used as the initial "minimum" when
# scanning AST line spans.
if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

# Number of leading AST nodes contributed by the standard preamble (one per
# __future__ import it contains).
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')


# Matches the (possibly empty) run of leading whitespace on a line.
_LEADING_WHITESPACE = re.compile(r'\s*')

56 

57 

def _unfold_continuations(code_string):
  """Joins lines that were split with backslash continuations.

  Args:
    code_string: Text, source code possibly containing `\\`-newline
      continuations.

  Returns:
    Text, the same code with each continuation folded onto one line.
  """
  unfolded = code_string.replace('\\\n', '')
  return unfolded

61 

62 

def dedent_block(code_string):
  """Dedents a code block so that its first line starts at row zero.

  Args:
    code_string: Text, source of a block (e.g. a method body) that may carry
      the indentation of its original context.

  Returns:
    Text, the same code with the block-level indentation stripped from every
    line. Returned unchanged when the block is not indented.

  Raises:
    errors.UnsupportedLanguageElementError: if the block mixes tabs and
      spaces for indentation.
  """
  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(io.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  # The first INDENT token (skipping blank lines, strings and comments)
  # determines the block-level indentation to remove.
  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      # The first significant token is unindented - nothing to dedent.
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      # untokenize accepts 2-tuples; the original positions are dropped on
      # purpose since the indentation change invalidates them.
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = re.match(_LEADING_WHITESPACE, line).group()
    new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code

129 

130 

def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  # Lambdas cannot be located by plain source inspection; delegate.
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    raw_source = inspect_utils.getimmediatesource(entity)
  except OSError as e:
    raise errors.InaccessibleSourceCodeError(
        f'Unable to locate the source code of {entity}. Note that functions'
        ' defined in certain environments, like the interactive Python shell,'
        ' do not expose their source code. If that is the case, you should'
        ' define them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        f' @tf.autograph.experimental.do_not_convert. Original error: {e}')

  # Strip any context indentation so the snippet parses at top level.
  dedented_source = dedent_block(raw_source)

  # Prefix the requested __future__ imports; `preamble_len` below makes the
  # parser drop those leading nodes again.
  future_imports = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  full_source = '\n'.join(future_imports + (dedented_source,))

  node = parse(full_source, preamble_len=len(future_features))
  return node, full_source

165 

166 

def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context.

  Line numbers inside `node` are rebased relative to `minl`, and the code
  block is trimmed to the node's own span as far as the available position
  information allows.
  """
  # Rebase all line numbers so the extracted snippet is self-consistent.
  for descendant in gast.walk(node):
    line_no = getattr(descendant, 'lineno', None)
    if line_no is not None:
      descendant.lineno = line_no - minl
    end_line_no = getattr(descendant, 'end_lineno', None)
    if end_line_no is not None:
      descendant.end_lineno = end_line_no - minl

  code_lines = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  end_col = getattr(node, 'end_col_offset', None)
  if end_col is not None:
    # This is only available in 3.8.
    code_lines[-1] = code_lines[-1][:end_col]

  start_col = getattr(node, 'col_offset', None)
  if start_col is None:
    # Older Python: try to find the "lambda" token. This is brittle.
    match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
    if match is not None:
      start_col = match.start(0)

  if start_col is not None:
    code_lines[0] = code_lines[0][start_col:]

  code_block = '\n'.join([c.rstrip() for c in code_lines])

  return node, code_block

199 

200 

def _arg_name(node):
  """Returns the argument name held by a gast.Name node or plain string."""
  if node is None:
    return None
  if isinstance(node, str):
    # gast sometimes represents argument names as plain strings.
    return node
  # Anything else must be a Name node carrying the identifier.
  assert isinstance(node, gast.Name)
  return node.id

208 

209 

def _node_matches_argspec(node, func):
  """Returns True if `node` (a lambda AST) fits the argspec of `func`."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  spec = tf_inspect.getfullargspec(func)

  # Compare positional args, *args, **kwargs and keyword-only args in turn;
  # bail out on the first mismatch.
  if tuple(_arg_name(a) for a in node.args.args) != tuple(spec.args):
    return False
  if spec.varargs != _arg_name(node.args.vararg):
    return False
  if spec.varkw != _arg_name(node.args.kwarg):
    return False
  if tuple(_arg_name(a) for a in node.args.kwonlyargs) != tuple(
      spec.kwonlyargs):
    return False
  return True

230 

231 

def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - a surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  def_line = lam.__code__.co_firstlineno

  # This method is more robust than just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  candidates = []
  for ln in lambda_nodes:
    # Compute the line span covered by each lambda node and its children.
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    lambda_codes = '\n'.join([unparse(l) for l in lambda_nodes])
    raise errors.UnsupportedLanguageElementError(
        f'could not parse the source code of {lam}:'
        f' no matching AST found among candidates:\n{lambda_codes}')

  # Attempt to narrow down selection by signature if multiple nodes are found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if could not narrow down to a single node.
  # NOTE(review): `matches` is rebound from a list to a string here, which is
  # confusing but harmless since the function raises immediately after.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      f'could not parse the source code of {lam}: found multiple definitions'
      ' with identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      f' unique argument names. The matching definitions were:\n{matches}')

319 

320 

# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text
    preamble_len: Int, indicates leading nodes in the parsed AST which should be
      dropped.
    single_node: Bool, whether `src` is assumed to be represented by exactly one
      AST node.

  Returns:
    ast.AST
  """
  parsed_module = gast.parse(src)
  body_nodes = parsed_module.body
  # Drop the preamble nodes (e.g. __future__ imports added by the caller).
  if preamble_len:
    body_nodes = body_nodes[preamble_len:]
  if not single_node:
    return body_nodes
  if len(body_nodes) != 1:
    raise ValueError('expected exactly one node, got {}'.format(body_nodes))
  return body_nodes[0]

344 

345 

def parse_expression(src):
  """Returns the AST of given identifier.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  prefixed_src = STANDARD_PREAMBLE + src.strip()
  node = parse(
      prefixed_src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  # Sanity check, skipped under -O: the single node must be an expression.
  if __debug__:
    if not isinstance(node, gast.Expr):
      raise ValueError(
          'expected exactly one node of type Expr, got {}'.format(node))
  return node.value

363 

364 

def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object.
    indentation: Unused, deprecated. The returning code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    code: The source code generated from the AST object
    source_mapping: A mapping between the user and AutoGraph generated code.
  """
  del indentation  # astunparse doesn't allow configuring it.
  # Normalize to a sequence of nodes.
  if not isinstance(node, (list, tuple)):
    node = (node,)

  pieces = []
  if include_encoding_marker:
    pieces.append('# coding=utf-8')
  for item in node:
    # gast nodes must be converted back to native ast before unparsing.
    ast_item = gast.gast_to_ast(item) if isinstance(item, gast.AST) else item
    if astunparse is ast:
      # Only the stdlib unparser requires complete location info.
      ast.fix_missing_locations(ast_item)
    pieces.append(astunparse.unparse(ast_item).strip())

  return '\n'.join(pieces)