Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/parser.py: 17%
192 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converting code to AST.
17Adapted from Tangent.
18"""
20import ast
21import inspect
22import io
23import linecache
24import re
25import sys
26import textwrap
27import tokenize
29import astunparse
30import gast
32from tensorflow.python.autograph.pyct import errors
33from tensorflow.python.autograph.pyct import inspect_utils
34from tensorflow.python.util import tf_inspect
# Preamble prepended to parsed snippets. Python 2 used to need __future__
# imports here; currently both preambles are effectively empty.
PY2_PREAMBLE = textwrap.dedent("""
""")
PY3_PREAMBLE = ''
MAX_SIZE = 0

if sys.version_info >= (3, 9):
  # ast.unparse exists from Python 3.9; rebind the astunparse name to the
  # stdlib ast module so unparse() below uses it instead. Note this shadows
  # the imported astunparse package for the rest of this module.
  astunparse = ast

if sys.version_info >= (3,):
  STANDARD_PREAMBLE = PY3_PREAMBLE
  # MAX_SIZE is used as a "larger than any line number" sentinel (see
  # _parse_lambda, which initializes minl with it).
  MAX_SIZE = sys.maxsize
else:
  STANDARD_PREAMBLE = PY2_PREAMBLE
  MAX_SIZE = sys.maxint

# Number of AST nodes the standard preamble contributes, computed by counting
# its __future__ imports. Passed as preamble_len to parse() so the preamble
# nodes get dropped from results.
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')

# Matches the (possibly empty) leading whitespace of a line.
_LEADING_WHITESPACE = re.compile(r'\s*')
58def _unfold_continuations(code_string):
59 """Removes any backslash line continuations from the code."""
60 return code_string.replace('\\\n', '')
def dedent_block(code_string):
  """Dedents a code block so that its first line starts at row zero.

  Unlike textwrap.dedent, this tokenizes the code and only rewrites INDENT
  tokens, so indentation-like whitespace inside multi-line string literals
  is not affected.

  Args:
    code_string: Text, the code to dedent.

  Returns:
    Text, the dedented code; returned unchanged if the block is not indented.

  Raises:
    errors.UnsupportedLanguageElementError: if the code mixes tabs and spaces
      for indentation.
  """
  code_string = _unfold_continuations(code_string)

  token_gen = tokenize.generate_tokens(io.StringIO(code_string).readline)

  block_indentation = None
  tokens = []
  try:
    for tok in token_gen:
      tokens.append(tok)
  except tokenize.TokenError:
    # Resolution of lambda functions may yield incomplete code, which can
    # in turn generate this error. We silently ignore this error because the
    # parser may still be able to deal with it.
    pass

  # Find the indentation of the first significant token; NL/NEWLINE/STRING/
  # COMMENT tokens may legally precede the block without defining its indent.
  for tok in tokens:
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      block_indentation = tok_string
      break
    elif tok_type not in (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
      block_indentation = ''
      break

  if not block_indentation:
    return code_string

  block_level = len(block_indentation)
  first_indent_uses_tabs = '\t' in block_indentation
  for i, tok in enumerate(tokens):
    tok_type, tok_string, _, _, _ = tok
    if tok_type == tokenize.INDENT:
      if ((' ' in tok_string and first_indent_uses_tabs)
          or ('\t' in tok_string and not first_indent_uses_tabs)):
        # TODO(mdan): We could attempt to convert tabs to spaces by unix rule.
        # See:
        # https://docs.python.org/3/reference/lexical_analysis.html#indentation
        raise errors.UnsupportedLanguageElementError(
            'code mixing tabs and spaces for indentation is not allowed')
      if len(tok_string) >= block_level:
        tok_string = tok_string[block_level:]
      # untokenize accepts 2-tuples of (type, string).
      tokens[i] = (tok_type, tok_string)

  new_code = tokenize.untokenize(tokens)

  # Note: untokenize respects the line structure, but not the whitespace within
  # lines. For example, `def foo()` may be untokenized as `def foo ()`
  # So instead of using the output of dedent, we match the leading whitespace
  # on each line.
  dedented_code = []
  for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
    original_indent = _LEADING_WHITESPACE.match(line).group()
    new_indent = _LEADING_WHITESPACE.match(new_line).group()
    if len(original_indent) > len(new_indent):
      dedented_line = line[len(original_indent) - len(new_indent):]
    else:
      dedented_line = line
    dedented_code.append(dedented_line)
  new_code = '\n'.join(dedented_code)

  return new_code
def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).
  """
  # Lambdas need a dedicated path: their source can't be isolated by
  # inspect-based extraction alone.
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    raw_source = inspect_utils.getimmediatesource(entity)
  except OSError as e:
    raise errors.InaccessibleSourceCodeError(
        f'Unable to locate the source code of {entity}. Note that functions'
        ' defined in certain environments, like the interactive Python shell,'
        ' do not expose their source code. If that is the case, you should'
        ' define them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        f' @tf.autograph.experimental.do_not_convert. Original error: {e}')

  # Prefix the dedented source with the requested __future__ imports so they
  # take effect when the snippet is parsed.
  preamble = ['from __future__ import {}'.format(name)
              for name in future_features]
  source = '\n'.join(preamble + [dedent_block(raw_source)])

  return parse(source, preamble_len=len(future_features)), source
def _without_context(node, lines, minl, maxl):
  """Returns a clean node and source code without indenting and context."""
  # Rebase all line numbers so the extracted snippet effectively starts at
  # line zero.
  for inner in gast.walk(node):
    lineno = getattr(inner, 'lineno', None)
    if lineno is not None:
      inner.lineno = lineno - minl
    end_lineno = getattr(inner, 'end_lineno', None)
    if end_lineno is not None:
      inner.end_lineno = end_lineno - minl

  snippet = lines[minl - 1:maxl]

  # Attempt to clean up surrounding context code.

  # Trim trailing context; end_col_offset is only available in 3.8.
  end_col = getattr(node, 'end_col_offset', None)
  if end_col is not None:
    snippet[-1] = snippet[-1][:end_col]

  # Trim leading context. On older Python without col_offset, fall back to
  # locating the "lambda" keyword textually - this is brittle.
  start_col = getattr(node, 'col_offset', None)
  if start_col is None:
    match = re.search(r'(?<!\w)lambda(?!\w)', snippet[0])
    if match is not None:
      start_col = match.start(0)
  if start_col is not None:
    snippet[0] = snippet[0][start_col:]

  code_block = '\n'.join(line.rstrip() for line in snippet)

  return node, code_block
201def _arg_name(node):
202 if node is None:
203 return None
204 if isinstance(node, gast.Name):
205 return node.id
206 assert isinstance(node, str)
207 return node
def _node_matches_argspec(node, func):
  """Returns True if the lambda node's signature matches func's argspec."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  spec = tf_inspect.getfullargspec(func)

  positional = tuple(_arg_name(a) for a in node.args.args)
  keyword_only = tuple(_arg_name(a) for a in node.args.kwonlyargs)

  # All four signature components must agree: positional args, *varargs,
  # **kwargs, and keyword-only args.
  return (positional == tuple(spec.args)
          and spec.varargs == _arg_name(node.args.vararg)
          and spec.varkw == _arg_name(node.args.kwarg)
          and keyword_only == tuple(spec.kwonlyargs))
def _parse_lambda(lam):
  """Returns the AST and source code of given lambda function.

  Args:
    lam: types.LambdaType, Python function/method/class

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).

  Raises:
    errors.UnsupportedLanguageElementError: if no lambda AST node matching the
      function could be found, or if several matching nodes with identical
      signatures were found.
  """
  # TODO(mdan): Use a fast path if the definition is not multi-line.
  # We could detect that the lambda is in a multi-line expression by looking
  # at the surrounding code - an surrounding set of parentheses indicates a
  # potential multi-line definition.

  mod = inspect.getmodule(lam)
  f = inspect.getsourcefile(lam)
  # co_firstlineno is 1-based, consistent with AST lineno values.
  def_line = lam.__code__.co_firstlineno

  # This method is more robust that just calling inspect.getsource(mod), as it
  # works in interactive shells, where getsource would fail. This is the
  # same procedure followed by inspect for non-modules:
  # https://github.com/python/cpython/blob/3.8/Lib/inspect.py#L772
  lines = linecache.getlines(f, mod.__dict__)
  source = ''.join(lines)

  # Narrow down to the last node starting before our definition node.
  all_nodes = parse(source, preamble_len=0, single_node=False)
  search_nodes = []
  for node in all_nodes:
    # Also include nodes without a line number, for safety. This is defensive -
    # we don't know whether such nodes might exist, and if they do, whether
    # they are not safe to skip.
    # TODO(mdan): Replace this check with an assertion or skip such nodes.
    if getattr(node, 'lineno', def_line) <= def_line:
      search_nodes.append(node)
    else:
      # Found a node starting past our lambda - can stop the search.
      break

  # Extract all lambda nodes from the shortlist.
  lambda_nodes = []
  for node in search_nodes:
    lambda_nodes.extend(
        n for n in gast.walk(node) if isinstance(n, gast.Lambda))

  # Filter down to lambda nodes which span our actual lambda.
  # For each candidate, compute the (minl, maxl) line range covered by the
  # lambda's entire subtree, and keep it only if def_line falls inside.
  candidates = []
  for ln in lambda_nodes:
    minl, maxl = MAX_SIZE, 0
    for n in gast.walk(ln):
      minl = min(minl, getattr(n, 'lineno', minl))
      lineno = getattr(n, 'lineno', maxl)
      end_lineno = getattr(n, 'end_lineno', None)
      if end_lineno is not None:
        # end_lineno is more precise, but lineno should almost always work too.
        lineno = end_lineno
      maxl = max(maxl, lineno)
    if minl <= def_line <= maxl:
      candidates.append((ln, minl, maxl))

  # Happy path: exactly one node found.
  if len(candidates) == 1:
    (node, minl, maxl), = candidates  # pylint:disable=unbalanced-tuple-unpacking
    return _without_context(node, lines, minl, maxl)

  elif not candidates:
    lambda_codes = '\n'.join([unparse(l) for l in lambda_nodes])
    raise errors.UnsupportedLanguageElementError(
        f'could not parse the source code of {lam}:'
        f' no matching AST found among candidates:\n{lambda_codes}')

  # Attempt to narrow down selection by signature is multiple nodes are found.
  matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
  if len(matches) == 1:
    (node, minl, maxl), = matches
    return _without_context(node, lines, minl, maxl)

  # Give up if could not narrow down to a single node.
  matches = '\n'.join(
      'Match {}:\n{}\n'.format(i, unparse(node, include_encoding_marker=False))
      for i, (node, _, _) in enumerate(matches))
  raise errors.UnsupportedLanguageElementError(
      f'could not parse the source code of {lam}: found multiple definitions'
      ' with identical signatures at the location. This error'
      ' may be avoided by defining each lambda on a single line and with'
      f' unique argument names. The matching definitions were:\n{matches}')
# TODO(mdan): This should take futures as input instead.
def parse(src, preamble_len=0, single_node=True):
  """Returns the AST of given piece of code.

  Args:
    src: Text, the code to parse.
    preamble_len: Int, indicates leading nodes in the parsed AST which should
      be dropped.
    single_node: Bool, whether `src` is assumed to be represented by exactly
      one AST node.

  Returns:
    ast.AST

  Raises:
    ValueError: if single_node is True but the body holds more than one node.
  """
  body = gast.parse(src).body
  if preamble_len:
    body = body[preamble_len:]
  if not single_node:
    return body
  if len(body) != 1:
    raise ValueError('expected exactly one node, got {}'.format(body))
  return body[0]
def parse_expression(src):
  """Returns the AST of given identifier.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  # Prepend the standard preamble, then drop its nodes from the result.
  prefixed = STANDARD_PREAMBLE + src.strip()
  node = parse(prefixed, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
  # Sanity check only under __debug__; stripped when running with -O.
  if __debug__ and not isinstance(node, gast.Expr):
    raise ValueError(
        'expected exactly one node of type Expr, got {}'.format(node))
  return node.value
def unparse(node, indentation=None, include_encoding_marker=True):
  """Returns the source code of given AST.

  Args:
    node: The code to compile, as an AST object.
    indentation: Unused, deprecated. The returning code will always be indented
      at 4 spaces.
    include_encoding_marker: Bool, whether to include a comment on the first
      line to explicitly specify UTF-8 encoding.

  Returns:
    code: The source code generated from the AST object
    source_mapping: A mapping between the user and AutoGraph generated code.
  """
  del indentation  # astunparse doesn't allow configuring it.
  nodes = node if isinstance(node, (list, tuple)) else (node,)

  pieces = ['# coding=utf-8'] if include_encoding_marker else []
  for n in nodes:
    ast_node = gast.gast_to_ast(n) if isinstance(n, gast.AST) else n
    if astunparse is ast:
      # The module-level astunparse name is rebound to the stdlib ast module
      # on Python 3.9+; only ast.unparse needs locations filled in.
      ast.fix_missing_locations(ast_node)
    pieces.append(astunparse.unparse(ast_node).strip())

  return '\n'.join(pieces)