Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py: 25%
135 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Generic source code transformation infrastructure."""
17import inspect
18import threading
19import types
21import gast
23from tensorflow.python.autograph.pyct import cache
24from tensorflow.python.autograph.pyct import inspect_utils
25from tensorflow.python.autograph.pyct import loader
26from tensorflow.python.autograph.pyct import naming
27from tensorflow.python.autograph.pyct import origin_info
28from tensorflow.python.autograph.pyct import parser
29from tensorflow.python.autograph.pyct import templates
30from tensorflow.python.autograph.pyct import transformer
31from tensorflow.python.autograph.utils import ag_logging as logging
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Wraps an AST into the body of a factory with consistent lexical context.

  The AST is expected to define some symbol with a name given by `entity_name`.

  This mechanism ensures that the resulting transformed entity has lexical
  scoping identical to that of the source entity, while allowing extra
  parametrization.

  Two nested factories achieve the following:

   1. The inner factory dynamically creates the entity represented by `nodes`.
   2. The inner factory is parametrized by a custom set of arguments.
   3. The inner factory has a closure identical to that of the transformed
       entity.
   4. The inner factory has local variables named like `factory_args`, which
       `nodes` may use as additional parameters.
   5. The inner factory returns the variable given by `entity_name`.
   6. The outer factory is niladic.
   7. The outer factory has no closure.
   8. The outer factory creates the necessary lexical scope for the inner
       factory, so that the loaded code has the given configuration for
       closure/globals.
   9. The outer factory returns the inner factory.

  Roughly speaking, the following code is generated:

      from __future__ import future_feature_1
      from __future__ import future_feature_2
      ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The lexical scoping is created using dummy symbol declarations, which create
  local variables in the body of the outer factory. As a result, the Python
  parser correctly marks them as free non-global variables upon load (that is,
  it creates cell slots for each symbol). These symbols are initialized with
  None, but their values are not expected to be used; instead, the caller is
  expected to replace them with the cells of the source entity. For more
  details, see:
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: ast.AST or a sequence of ast.AST, the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional arguments for the
      inner factory. Useful to configure variables that the converted code can
      use. Typically, these are modules.
    future_features: Iterable[Text], names of future statements to associate the
      code with.

  Returns:
    List[ast.AST], the statements produced by `templates.replace`.
  """
  closure_var_template = """
    var_name = None
  """
  # One `<name> = None` statement per closure variable. These assignments only
  # exist so that the outer factory owns a local, and hence a cell, for each.
  dummy_closure_defs = [
      stmt
      for var_name in closure_vars
      for stmt in templates.replace(closure_var_template, var_name=var_name)
  ]

  future_imports = (
      gast.ImportFrom(
          module='__future__',
          names=[
              gast.alias(name=feature, asname=None)
              for feature in future_features
          ],
          level=0) if future_features else [])

  arg_nodes = [
      gast.Name(arg_name, ctx=gast.Param(), annotation=None, type_comment=None)
      for arg_name in factory_args
  ]

  factory_template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      factory_template,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name,
      factory_args=arg_nodes,
      future_imports=future_imports,
      inner_factory_name=inner_factory_name,
      outer_factory_name=outer_factory_name)
class _PythonFnFactory(object):
  """Helper object that wraps a Python function factory.

  `create` loads the generated factory code once. `instantiate` can then bind
  it to a set of globals and closure cells, producing a new function object on
  each call.
  """

  def __init__(self, name, freevars, extra_locals):
    """Creates a new factory for a Python function.

    Args:
      name: The function name.
      freevars: The list of non-global free variables for the function.
      extra_locals: Dict[Text, Any], names and values for custom variables that
        are accessible to the generated code as local variables.
    """
    self._name = name
    self._freevars = freevars
    self._extra_locals = extra_locals

    # Inner factory returned by the loaded outer factory; set by `create`.
    self._unbound_factory = None
    # Module the generated code was loaded into; set by `create`.
    self.module = None
    # Source map returned by `loader.load_ast`; set by `create`.
    self.source_map = None

  def create(self,
             nodes,
             namer,
             inner_factory_name='inner_factory',
             outer_factory_name='outer_factory',
             future_features=()):
    """Initializes a function.

    Wraps `nodes` into nested factories (see `_wrap_into_factory`), loads the
    result, and stores the inner factory for later `instantiate` calls.

    Args:
      nodes: the AST that defines the function named by this factory's `name`.
      namer: naming.Namer, used to generate the names of the two factories.
      inner_factory_name: Text, base name for the inner factory.
      outer_factory_name: Text, base name for the outer factory.
      future_features: Iterable[Text], __future__ features for the generated
        code.

    Raises:
      ValueError: if called more than once on the same object.
    """
    if self._unbound_factory is not None:
      raise ValueError('double initialization; create a new object instead')

    inner_factory_name = namer.new_symbol(inner_factory_name, ())
    outer_factory_name = namer.new_symbol(outer_factory_name, ())
    nodes = _wrap_into_factory(nodes, self._name, inner_factory_name,
                               outer_factory_name, self._freevars,
                               self._extra_locals.keys(), future_features)

    module, _, source_map = loader.load_ast(
        nodes, include_source_map=True)
    outer_factory = getattr(module, outer_factory_name)
    self._unbound_factory = outer_factory()
    self.module = module
    self.source_map = source_map

  def instantiate(self,
                  globals_,
                  closure,
                  defaults=None,
                  kwdefaults=None):
    """Creates a new function instance.

    Args:
      globals_: Dict, the globals visible to the new function.
      closure: Sequence of cell objects aligned with the `freevars` given to
        the constructor (e.g. the source function's `__closure__`).
      defaults: if truthy, assigned to `__defaults__` of the new function.
      kwdefaults: if truthy, assigned to `__kwdefaults__` of the new function.

    Returns:
      The function produced by calling the inner factory with `extra_locals`.

    Raises:
      ValueError: if `create` was not called, or if `closure` does not match
        the free variables of the source function and the generated factory.
    """
    if self._unbound_factory is None:
      raise ValueError('call create first')

    factory_code = self._unbound_factory.__code__
    factory_freevars = factory_code.co_freevars
    closure = tuple(closure)
    if len(closure) != len(self._freevars):
      # zip() below would otherwise silently drop the unmatched entries.
      raise ValueError(
          'closure mismatch, source function has free variables {}, but {} '
          'cells were given'.format(self._freevars, len(closure)))
    closure_map = dict(zip(self._freevars, closure))
    # The generated factory must close over exactly the source function's free
    # variables. A missing name would have no cell to bind, and an extra
    # source variable would leave a cell unused.
    if set(factory_freevars) != set(closure_map):
      raise ValueError(
          'closure mismatch, requested {}, but source function had {}'.format(
              factory_freevars, self._freevars))
    factory_closure = tuple(closure_map[name] for name in factory_freevars)

    bound_factory = types.FunctionType(
        code=factory_code,
        globals=globals_,
        name=self._name,
        argdefs=(),
        closure=factory_closure)

    # The lint override is a false positive.
    new_fn = bound_factory(**self._extra_locals)  # pylint:disable=not-callable

    if defaults:
      new_fn.__defaults__ = defaults
    if kwdefaults:
      new_fn.__kwdefaults__ = kwdefaults

    return new_fn
class GenericTranspiler(object):
  """A generic transpiler for Python functions.

  Its interface is the `transform` API, which can process Python function
  objects. Internally, it handles parsing.

  Users typically subclass this, customizing the `transform_ast` method. The
  output of `transform_ast` is returned by `transform`, paired with the
  `transformer.Context` used during the transformation. Existing methods like
  `transform_function` may also be overloaded.

  Example:

      class MyTransformer(GenericTranspiler):

        def transform_ast(self, node, ctx):
          result = <<transform node>>
          return result

      transformer = MyTransformer()

      result, ctx = transformer.transform(f, ...)
      # result is the output of transform_ast
  """

  def get_transformed_name(self, node):
    """Returns a name for the output function. Subclasses may override this.

    Args:
      node: the parsed entity, a `gast.Lambda` or a `gast.FunctionDef`.

    Returns:
      'lam' for lambdas, otherwise the function's own name.

    Raises:
      ValueError: for any other node type.
    """
    if isinstance(node, gast.Lambda):
      return 'lam'
    elif isinstance(node, gast.FunctionDef):
      return node.name
    raise ValueError('Unknown node type {}'.format(node))

  def transform_ast(self, node, ctx):
    """Performs an actual transformation of a function's AST.

    Subclasses must implement this method, and do not usually call it.

    Args:
      node: One or more ast.AST nodes representing the AST to be transformed.
      ctx: transformer.Context.
    """
    raise NotImplementedError('subclasses must override this')

  def transform(self, obj, user_context):
    """Transforms a Python object.

    Users typically call this method.

    Args:
      obj: A Python object, function, type, etc.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user attribute.
    Returns:
      The result of calling transform_function.

    Raises:
      NotImplementedError: if the type of obj is not handled. Only functions
        and methods are handled here.
    """
    if inspect.isfunction(obj) or inspect.ismethod(obj):
      return self.transform_function(obj, user_context)

    raise NotImplementedError('Non-function: {}'.format(type(obj)))

  def _erase_arg_defaults(self, node):
    """Erase arg default expressions, which would otherwise be unbound.

    Each positional default and each present keyword-only default is replaced
    with a freshly parsed `None` expression. Keyword-only args without a
    default (`None` entries in `kw_defaults`) are left untouched. `PyToPy`
    reattaches the source function's real defaults when it instantiates the
    result.

    Args:
      node: a function or lambda AST node; modified in place.

    Returns:
      `node`.
    """
    args = node.args
    # Slice assignment keeps the existing list objects. A fresh node is parsed
    # for every placeholder, so no AST node is shared between positions.
    args.defaults[:] = [parser.parse_expression('None') for _ in args.defaults]
    args.kw_defaults[:] = [
        d if d is None else parser.parse_expression('None')
        for d in args.kw_defaults
    ]
    return node

  def transform_module(self, mod, user_context):
    """Transforms a module.

    Subclasses may override this method. The return value is opaque.

    The method receives the module object and calls `transform` on each of
    its members, except submodules, which are not transformed recursively.
    Members that `transform` does not support are skipped. With the default
    `transform`, that means anything other than functions and methods.

    Args:
      mod: A Python module.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user attribute.
    Returns:
      List[Any], the result of `transform` for each supported member, in
      `mod.__dict__` order. By default each element is the output of
      transform_ast, paired with a `transformer.Context` that contains
      information about the transformation process.
    """
    result = []
    for member in mod.__dict__.values():
      if inspect.ismodule(member):
        continue  # Not transforming modules recursively.
      try:
        result.append(self.transform(member, user_context))
      except NotImplementedError:
        pass  # Skip unsupported elements.
    return result

  def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The method receives the function object. It parses the function's source
    and attaches origin information. It then erases argument defaults and
    calls `transform_ast` on the resulting AST. The result is passed as-is to
    the output of `transform`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user attribute.
    Returns:
      Tuple[Any, Any]. By default it returns the output of transform_ast,
      together with a `transformer.Context` containing information about the
      transformation process.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    new_name = namer.new_symbol(self.get_transformed_name(node), ())
    entity_info = transformer.EntityInfo(
        name=new_name,
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = transformer.Context(entity_info, namer, user_context)

    node = self._erase_arg_defaults(node)
    result = self.transform_ast(node, context)

    return result, context
class PyToPy(GenericTranspiler):
  """A generic Python-to-Python transpiler.

  Its `transform` method offers a function-in, function-out interface.
  Internally, it takes care of parsing, caching and loading of the translated
  code.

  Users typically subclass this, overriding `transform_ast`.

  Usually, instances of this class are singletons, since each instance manages
  its own cache. The caching can be controlled by overriding `get_caching_key`.

  Example:

      class MyTransformer(PyToPy):

        def transform_ast(self, node, ctx):
          node = <<transform node, usually using ast.NodeTransformer classes>>
          return node

      transformer = MyTransformer()

      new_f, module, source_map = transformer.transform_function(f, ...)
      # new_f is a function with signature identical to f

  The transformed function has access to the same namespace as the original
  function. To allow access to internal APIs, users may inject additional
  symbols by overriding `get_extra_locals`.
  """

  def __init__(self):
    # Serializes the slow path (cache miss) of `transform_function`.
    self._cache_lock = threading.RLock()
    # Indexed as `self._cache[fn][subkey]`; each entry is a `_PythonFnFactory`.
    self._cache = cache.CodeObjectCache()

  def get_extra_locals(self):
    """Returns extra static local variables to make available to transformed code.

    Subclasses must override this.

    Returns:
      extra_locals: A Dict[Text, Any] containing additional variables to make
        available to the transformed code.
    """
    raise NotImplementedError('subclasses must override this')

  def get_caching_key(self, user_context):
    """Returns a unique key to use for caching.

    Subclasses must override this.

    Calls made to `transform_function` with functions that have the same code
    object and caching key will return a cached instance on subsequent
    invocations.

    Args:
      user_context: The context object which was passed to `transform`.

    Returns:
      A hashable caching subkey.
    """
    raise NotImplementedError('subclasses must override this')

  def _cached_factory(self, fn, cache_subkey):
    """Returns the cached factory for `fn` and `cache_subkey`, logging the hit."""
    hit = self._cache[fn][cache_subkey]
    logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey, hit)
    return hit

  def transform_function(self, fn, user_context):
    """Transforms a function. See GenericTranspiler.transform_function.

    This overload wraps the parent's `transform_function`, adding caching and
    facilities to instantiate the output as a Python object. It also
    adds facilities to make new symbols available to the generated Python code,
    visible as local variables - see `get_extra_locals`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user attribute.
    Returns:
      A tuple:
        * A function or lambda with the same signature and closure as `fn`
        * The temporary module into which the transformed function was loaded
        * The source map as a
            Dict[origin_info.LineLocation, origin_info.OriginInfo]
    """
    subkey = self.get_caching_key(user_context)

    if self._cache.has(fn, subkey):
      # Fast path: use a lock-free check.
      factory = self._cached_factory(fn, subkey)
    else:
      with self._cache_lock:
        # Check again under lock: another thread may have filled the entry.
        if self._cache.has(fn, subkey):
          factory = self._cached_factory(fn, subkey)
        else:
          logging.log(1, '%s is not cached for subkey %s', fn, subkey)
          # TODO(mdan): Confusing overloading pattern. Fix.
          nodes, context = super(PyToPy, self).transform_function(
              fn, user_context)
          entity_name = context.info.name

          if not isinstance(nodes, gast.Lambda):
            nodes.name = entity_name
          else:
            # A lambda is an expression. Bind it to the new name so the
            # factory can return it just like a named function.
            target = gast.Name(
                entity_name,
                ctx=gast.Store(),
                annotation=None,
                type_comment=None)
            nodes = gast.Assign(targets=[target], value=nodes)

          if logging.has_verbosity(2):
            logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))

          factory = _PythonFnFactory(
              entity_name, fn.__code__.co_freevars, self.get_extra_locals())
          factory.create(
              nodes, context.namer,
              future_features=context.info.future_features)
          self._cache[fn][subkey] = factory

    new_fn = factory.instantiate(
        globals_=fn.__globals__,
        closure=fn.__closure__ or (),
        defaults=fn.__defaults__,
        kwdefaults=getattr(fn, '__kwdefaults__', None))
    return new_fn, factory.module, factory.source_map