Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py: 25%

135 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Generic source code transformation infrastructure.""" 

16 

17import inspect 

18import threading 

19import types 

20 

21import gast 

22 

23from tensorflow.python.autograph.pyct import cache 

24from tensorflow.python.autograph.pyct import inspect_utils 

25from tensorflow.python.autograph.pyct import loader 

26from tensorflow.python.autograph.pyct import naming 

27from tensorflow.python.autograph.pyct import origin_info 

28from tensorflow.python.autograph.pyct import parser 

29from tensorflow.python.autograph.pyct import templates 

30from tensorflow.python.autograph.pyct import transformer 

31from tensorflow.python.autograph.utils import ag_logging as logging 

32 

33 

def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Wraps an AST into the body of a factory with consistent lexical context.

  The AST is expected to define some symbol with a name given by `entity_name`.

  This mechanism ensures that the resulting transformed entity has lexical
  scoping identical to that of the source entity, while allowing extra
  parametrization.

  Two nested factories achieve the following:

   1. The inner factory dynamically creates the entity represented by `nodes`.
   2. The inner factory is parametrized by a custom set of arguments.
   3. The inner factory has a closure identical to that of the transformed
       entity.
   4. The inner factory has local variables named like `factory_args`, which
       `nodes` may use as additional parameters.
   5. The inner factory returns the variable given by `entity_name`.
   6. The outer factory is niladic.
   7. The outer factory has no closure.
   8. The outer factory creates the necessary lexical scope for the inner
       factory, so that the loaded code has the given configuration for
       closure/globals.
   9. The outer factory returns the inner factory.

  Roughly speaking, the following code is generated:

      from __future__ import future_feature_1
      from __future__ import future_feature_2
      ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The lexical scoping is created using dummy symbol declarations which create
  local variables in the body of the outer factory, so that the Python parser
  correctly marks them as free non-global variables upon load (that is, it
  creates cell slots for each symbol). These symbols are initialized with None,
  but their values are not expected to be used; instead, the caller is expected
  to replace them with the cells of the source entity. For more details, see:
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: Tuple[ast.AST], the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional arguments for the
      inner factory. Useful to configure variables that the converted code can
      use. Typically, these are modules.
    future_features: Iterable[Text], names of future statements to associate the
      code with.

  Returns:
    ast.AST, the result of `templates.replace` applied to the factory template.
  """
  # One `<name> = None` assignment per closure variable. Binding the names in
  # the outer factory turns them into cells that the inner factory closes over.
  dummy_closure_defs = []
  for var_name in closure_vars:
    template = """
      var_name = None
    """
    dummy_closure_defs.extend(templates.replace(template, var_name=var_name))

  # All requested __future__ features go into a single ImportFrom node; with
  # no features, an empty list makes the template placeholder expand to
  # nothing.
  if future_features:
    future_imports = gast.ImportFrom(
        module='__future__',
        names=[gast.alias(name=name, asname=None) for name in future_features],
        level=0)
  else:
    future_imports = []

  # Parameter nodes for the inner factory, one per extra argument name. A new
  # local is used so the `factory_args` parameter keeps its documented meaning.
  factory_arg_nodes = [
      gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None)
      for name in factory_args
  ]

  template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      template,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name,
      factory_args=factory_arg_nodes,
      future_imports=future_imports,
      inner_factory_name=inner_factory_name,
      outer_factory_name=outer_factory_name)

141 

142 

143class _PythonFnFactory(object): 

144 """Helper object that wraps a Python function factory.""" 

145 

146 def __init__(self, name, freevars, extra_locals): 

147 """Creates a new factory for a Python function. 

148 

149 Args: 

150 name: The function name. 

151 freevars: The list of non-global free variables for the function. 

152 extra_locals: Dict[Text, Any], names and values for custom variables that 

153 are accessible to the generated code as local variables. 

154 """ 

155 self._name = name 

156 self._freevars = freevars 

157 self._extra_locals = extra_locals 

158 

159 self._unbound_factory = None 

160 self.module = None 

161 self.source_map = None 

162 

163 def create(self, 

164 nodes, 

165 namer, 

166 inner_factory_name='inner_factory', 

167 outer_factory_name='outer_factory', 

168 future_features=()): 

169 """Initializes a function.""" 

170 if self._unbound_factory is not None: 

171 raise ValueError('double initialization; create a new object instead') 

172 

173 inner_factory_name = namer.new_symbol(inner_factory_name, ()) 

174 outer_factory_name = namer.new_symbol(outer_factory_name, ()) 

175 nodes = _wrap_into_factory(nodes, self._name, inner_factory_name, 

176 outer_factory_name, self._freevars, 

177 self._extra_locals.keys(), future_features) 

178 

179 module, _, source_map = loader.load_ast( 

180 nodes, include_source_map=True) 

181 outer_factory = getattr(module, outer_factory_name) 

182 self._unbound_factory = outer_factory() 

183 self.module = module 

184 self.source_map = source_map 

185 

186 def instantiate(self, 

187 globals_, 

188 closure, 

189 defaults=None, 

190 kwdefaults=None): 

191 """Creates a new function instance.""" 

192 if self._unbound_factory is None: 

193 raise ValueError('call create first') 

194 

195 factory_code = self._unbound_factory.__code__ 

196 factory_freevars = factory_code.co_freevars 

197 closure_map = dict(zip(self._freevars, closure)) 

198 factory_closure = tuple( 

199 closure_map[name] for name in factory_code.co_freevars) 

200 if len(factory_closure) != len(closure): 

201 raise ValueError( 

202 'closure mismatch, requested {}, but source function had {}'.format( 

203 self._freevars, factory_freevars)) 

204 

205 bound_factory = types.FunctionType( 

206 code=factory_code, 

207 globals=globals_, 

208 name=self._name, 

209 argdefs=(), 

210 closure=factory_closure) 

211 

212 # The lint override is a false positive. 

213 new_fn = bound_factory(**self._extra_locals) # pylint:disable=not-callable 

214 

215 if defaults: 

216 new_fn.__defaults__ = defaults 

217 if kwdefaults: 

218 new_fn.__kwdefaults__ = kwdefaults 

219 

220 return new_fn 

221 

222 

class GenericTranspiler(object):
  """A generic transpiler for Python functions.

  Its interface is the `transform` API, which can process Python function
  objects. Internally, it handles parsing.

  Users typically subclass this, customizing the `transform_ast` method. The
  output of `transform_ast` is returned directly by `transform`. Existing
  methods like `transform_function` may also be overloaded.

  Example:

      class MyTransformer(GenericTranspiler):

        def transform_ast(self, node, ctx):
          result = <<transform node>>
          return result

      transformer = MyTransformer()

      result = transformer.transform(f, ...)
      # result is the output
  """

  def get_transformed_name(self, node):
    """Returns a name for the output function. Subclasses may override this."""
    if isinstance(node, gast.Lambda):
      return 'lam'
    elif isinstance(node, gast.FunctionDef):
      return node.name
    raise ValueError('Unknown node type {}'.format(node))

  def transform_ast(self, node, ctx):
    """Performs an actual transformation of a function's AST.

    Subclasses must implement this method, and do not usually call it.

    Args:
      node: One or more ast.AST nodes representing the AST to be transformed.
      ctx: transformer.Context.
    """
    raise NotImplementedError('subclasses must override this')

  def transform(self, obj, user_context):
    """Transforms a Python object.

    Users typically call this method.

    Args:
      obj: A Python object, function, type, etc.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user attribute.
    Returns:
      The result of calling transform_function.

    Raises:
      NotImplementedError: if the type of obj is not handled.
    """
    if inspect.isfunction(obj) or inspect.ismethod(obj):
      return self.transform_function(obj, user_context)

    raise NotImplementedError('Non-function: {}'.format(type(obj)))

  def _erase_arg_defaults(self, node):
    """Erase arg default expressions, which would otherwise be unbound.

    Each positional default and each present keyword-only default is replaced
    in place with a freshly parsed `None` expression. A new node is parsed per
    slot so that no AST node is shared. Keyword-only arguments without a
    default (`None` entries in `kw_defaults`) are left as they are.

    Args:
      node: A function or lambda AST node; its `args` are modified in place.

    Returns:
      The same `node`.
    """
    args = node.args
    args.defaults[:] = [
        parser.parse_expression('None') for _ in args.defaults
    ]
    args.kw_defaults[:] = [
        None if d is None else parser.parse_expression('None')
        for d in args.kw_defaults
    ]
    return node

  def transform_module(self, mod, user_context):
    """Transforms a module.

    Subclasses may override this method. The return value is opaque.

    The method receives a module object and calls `transform` on each of its
    members that is not itself a module. The collected results are returned
    as the output of this method.

    Args:
      mod: A Python module.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user attribute.
    Returns:
      List[Tuple[Any, Any]]. By default it returns the output of transform_ast,
      evaluated on each supported member, other than modules, together with a
      `transformer.Context` containing information about the transformation
      process.
    """
    result = []
    for member in mod.__dict__.values():
      if inspect.ismodule(member):
        continue  # Not transforming modules recursively.
      try:
        result.append(self.transform(member, user_context))
      except NotImplementedError:
        # Skip unsupported elements.
        # NOTE(review): this also hides a NotImplementedError raised from
        # inside transform_ast (e.g. a subclass that forgot to override it).
        pass
    return result

  def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The method receives the original function object. It parses it into an
    AST, erases argument defaults, and passes the AST to `transform_ast`. The
    result is passed as-is to the output of `transform`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user attribute.
    Returns:
      Tuple[Any, Any]. By default it returns the output of transform_ast,
      together with a `transformer.Context` containing information about the
      transformation process.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    origin_info.resolve_entity(node, source, fn)

    # The namer is seeded with fn's namespace so the new name does not
    # collide with symbols visible to fn.
    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    new_name = namer.new_symbol(self.get_transformed_name(node), ())
    entity_info = transformer.EntityInfo(
        name=new_name,
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = transformer.Context(entity_info, namer, user_context)

    node = self._erase_arg_defaults(node)
    result = self.transform_ast(node, context)

    return result, context

362 

363 

class PyToPy(GenericTranspiler):
  """A generic Python-to-Python transpiler.

  Its `transform` method offers a function-in, function-out interface.
  Internally, it takes care of parsing, caching and loading of the translated
  code.

  Users typically subclass this, overriding `transform_ast`.

  Usually, instances of this class are singletons, since each instance manages
  its own cache. The caching can be controlled by overriding `get_caching_key`.

  Example:

      class MyTransformer(PyToPy):

        def transform_ast(self, node, ctx):
          node = <<transform node, usually using ast.NodeTransformer classes>>
          return node

      transformer = MyTransformer()

      new_f, module, source_map = transformer.transform_function(f, ...)
      # new_f is a function with signature identical to f

  The transformed function has access to the same namespace as the original
  function. To allow access to internal APIs, users may inject additional
  symbols by overriding `get_extra_locals`.
  """

  def __init__(self):
    # Serializes creation of new cache entries in `transform_function`.
    self._cache_lock = threading.RLock()
    self._cache = cache.CodeObjectCache()

  def get_extra_locals(self):
    """Returns extra static local variables to make available to transformed code.

    Subclasses must override this.

    Returns:
      extra_locals: A Dict[Text, Any] containing additional variables to make
        available to the transformed code.
    """
    raise NotImplementedError('subclasses must override this')

  def get_caching_key(self, user_context):
    """Returns a unique key to use for caching.

    Subclasses must override this.

    Calls made to `transform_function` with functions that have the same code
    object and caching key will return a cached instance on subsequent
    invocations.

    Args:
      user_context: The context object which was passed to `transform`.

    Returns:
      A hashable caching key, used together with the function to index the
      cache.
    """
    raise NotImplementedError('subclasses must override this')

  def _cached_factory(self, fn, cache_subkey):
    """Returns the cached `_PythonFnFactory` for `fn` and `cache_subkey`."""
    cached_factory = self._cache[fn][cache_subkey]
    logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
                cached_factory)
    return cached_factory

  def transform_function(self, fn, user_context):
    """Transforms a function. See GenericTranspiler.transform_function.

    This overload wraps the parent's `transform_function`, adding caching and
    facilities to instantiate the output as a Python object. It also
    adds facilities to make new symbols available to the generated Python code,
    visible as local variables - see `get_extra_locals`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user attribute.
    Returns:
      A tuple:
        * A function or lambda with the same signature and closure as `fn`
        * The temporary module into which the transformed function was loaded
        * The source map as a
            Dict[origin_info.LineLocation, origin_info.OriginInfo]
    """
    cache_subkey = self.get_caching_key(user_context)

    if self._cache.has(fn, cache_subkey):
      # Fast path: use a lock-free check.
      factory = self._cached_factory(fn, cache_subkey)

    else:
      with self._cache_lock:
        # Check again under lock: another thread may have filled the entry
        # between the unlocked check above and acquiring the lock.
        if self._cache.has(fn, cache_subkey):
          factory = self._cached_factory(fn, cache_subkey)

        else:
          logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
          # TODO(mdan): Confusing overloading pattern. Fix.
          nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)

          if isinstance(nodes, gast.Lambda):
            # A lambda is an expression, so bind it to the new name with an
            # assignment that the generated factory can return by name.
            nodes = gast.Assign(
                targets=[
                    gast.Name(
                        ctx.info.name,
                        ctx=gast.Store(),
                        annotation=None,
                        type_comment=None)
                ],
                value=nodes)
          else:
            nodes.name = ctx.info.name

          if logging.has_verbosity(2):
            logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))

          factory = _PythonFnFactory(
              ctx.info.name, fn.__code__.co_freevars, self.get_extra_locals())
          factory.create(
              nodes, ctx.namer, future_features=ctx.info.future_features)
          self._cache[fn][cache_subkey] = factory

    # Bind the (possibly cached) factory to this particular fn's globals,
    # closure cells and defaults. Functions without free variables have
    # `__closure__` set to None, hence the `or ()`.
    transformed_fn = factory.instantiate(
        globals_=fn.__globals__,
        closure=fn.__closure__ or (),
        defaults=fn.__defaults__,
        kwdefaults=getattr(fn, '__kwdefaults__', None))
    return transformed_fn, factory.module, factory.source_map