Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/control_flow.py: 19%

173 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Handles control flow statements: while, for, if.""" 

16 

17import gast 

18 

19from tensorflow.python.autograph.core import converter 

20from tensorflow.python.autograph.lang import directives 

21from tensorflow.python.autograph.pyct import anno 

22from tensorflow.python.autograph.pyct import cfg 

23from tensorflow.python.autograph.pyct import origin_info 

24from tensorflow.python.autograph.pyct import parser 

25from tensorflow.python.autograph.pyct import qual_names 

26from tensorflow.python.autograph.pyct import templates 

27from tensorflow.python.autograph.pyct.static_analysis import activity 

28from tensorflow.python.autograph.pyct.static_analysis import annos 

29from tensorflow.python.autograph.pyct.static_analysis import liveness 

30from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions 

31from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs 

32 

33 

34class _Function(object): 

35 

36 scope = None 

37 

38 

39class ControlFlowTransformer(converter.Base): 

40 """Transforms control flow structures like loops and conditionals.""" 

41 

def visit_Lambda(self, node):
  """Visits a lambda with its scope recorded in the `_Function` state.

  The scope is read from the node's `anno.Static.SCOPE` annotation and is set
  on the state entry before the lambda's children are visited.

  Args:
    node: The lambda node being visited.

  Returns:
    Whatever `self.generic_visit(node)` returns for the node.
  """
  with self.state[_Function] as fn_state:
    fn_state.scope = anno.getanno(node, anno.Static.SCOPE)
    visited = self.generic_visit(node)
  return visited

46 

def visit_FunctionDef(self, node):
  """Visits a function definition with its body scope recorded in the state.

  The scope is read from the node's `annos.NodeAnno.BODY_SCOPE` annotation and
  is set on the `_Function` state entry before the children are visited.

  Args:
    node: The function definition node being visited.

  Returns:
    Whatever `self.generic_visit(node)` returns for the node.
  """
  with self.state[_Function] as fn_state:
    fn_state.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    visited = self.generic_visit(node)
  return visited

51 

def _create_nonlocal_declarations(self, vars_):
  """Builds `global` / `nonlocal` declaration nodes for the given symbols.

  Symbols that the current function scope lists as globals go into a single
  `gast.Global` node. Every remaining non-composite symbol goes into a single
  `gast.Nonlocal` node. Composite symbols are never declared.

  Args:
    vars_: Iterable of symbols; each must provide `is_composite()` and be
      convertible with `str()`.

  Returns:
    A list holding zero, one or two declaration nodes (global first).
  """
  symbols = set(vars_)
  declared_global = self.state[_Function].scope.globals & symbols

  declarations = []
  if declared_global:
    declarations.append(gast.Global([str(sym) for sym in declared_global]))

  declared_nonlocal = [
      sym for sym in symbols
      if not (sym.is_composite() or sym in declared_global)]
  if declared_nonlocal:
    declarations.append(
        gast.Nonlocal([str(sym) for sym in declared_nonlocal]))

  return declarations

66 

def _create_state_functions(
    self, block_vars, nonlocal_declarations, getter_name, setter_name):
  """Builds the state getter and setter function definitions for a block.

  Args:
    block_vars: The symbols whose values the getter returns and the setter
      assigns. May be empty.
    nonlocal_declarations: Declaration nodes placed at the top of the setter.
    getter_name: Name to give the generated getter function.
    setter_name: Name to give the generated setter function.

  Returns:
    The result of `templates.replace` for the getter/setter template.
  """
  if not block_vars:
    # Nothing to track: the getter returns an empty tuple and the setter is a
    # no-op.
    empty_template = """
      def getter_name():
        return ()
      def setter_name(block_vars):
        pass
    """
    return templates.replace(
        empty_template, getter_name=getter_name, setter_name=setter_name)

  # Simple symbols are read directly. Anything else is read through
  # `ag__.ldu(lambda: ..., name)`; what `ldu` does is defined outside this
  # file (presumably a guarded load - confirm against the operators module).
  guarded_block_vars = [
      sym if sym.is_simple() else templates.replace_as_expression(
          'ag__.ldu(lambda: var_, name)',
          var_=sym,
          name=gast.Constant(str(sym), kind=None))
      for sym in block_vars
  ]

  state_template = """
    def getter_name():
      return guarded_state_vars,
    def setter_name(vars_):
      nonlocal_declarations
      state_vars, = vars_
  """
  return templates.replace(
      state_template,
      nonlocal_declarations=nonlocal_declarations,
      getter_name=getter_name,
      guarded_state_vars=guarded_block_vars,
      setter_name=setter_name,
      state_vars=tuple(block_vars))

104 

def _create_loop_options(self, node):
  """Builds a `gast.Dict` with the loop options attached to `node`.

  The options come from the `directives.set_loop_options` entry of the node's
  `anno.Basic.DIRECTIVES` annotation. When the annotation or the entry is
  absent, or when the entry holds no options, an empty dict node is returned.

  Args:
    node: The loop node whose directives are inspected.

  Returns:
    A `gast.Dict` whose keys are `gast.Constant` nodes wrapping the option
    names and whose values are the option values stored in the annotation.
  """
  if not anno.hasanno(node, anno.Basic.DIRECTIVES):
    return gast.Dict([], [])

  loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
  if directives.set_loop_options not in loop_directives:
    return gast.Dict([], [])

  opts_dict = loop_directives[directives.set_loop_options]
  # Build keys and values straight from the mapping. Unpacking
  # `zip(*opts_dict.items())` raises ValueError when the mapping is empty.
  keys = [gast.Constant(name, kind=None) for name in opts_dict]
  # Values must be a list: ast and gast don't play well with tuples.
  values = list(opts_dict.values())
  return gast.Dict(keys, values)

118 

def _create_undefined_assigns(self, undefined_symbols):
  """Builds `var = ag__.Undefined(name)` assignments for each symbol.

  Args:
    undefined_symbols: Iterable of symbols; each must provide `ssf()`, whose
      result is used as the constant passed to `ag__.Undefined`.

  Returns:
    A flat list with the nodes produced by `templates.replace` for every
    symbol, in iteration order.
  """
  template = '''
    var = ag__.Undefined(symbol_name)
  '''
  assignments = []
  for symbol in undefined_symbols:
    assignments.extend(
        templates.replace(
            template,
            var=symbol,
            symbol_name=gast.Constant(symbol.ssf(), kind=None)))
  return assignments

130 

def _get_block_basic_vars(self, modified, live_in, live_out):
  """Selects the non-composite modified symbols that escape the block.

  A non-composite symbol is kept when it is live into the block, live out of
  it, or listed among the enclosing function scope's nonlocals. Symbols that
  satisfy none of these are considered local to the block and are dropped.

  Args:
    modified: Iterable of symbols modified inside the block.
    live_in: Container of symbols live on entry to the block.
    live_out: Container of symbols live on exit from the block.

  Returns:
    A frozenset of the selected symbols.
  """
  fn_nonlocals = self.state[_Function].scope.nonlocals
  # TODO(mdan): Raise an error when a composite is skipped for a TF scope.
  return frozenset(
      sym for sym in modified
      if not sym.is_composite()
      and (sym in live_in or sym in live_out or sym in fn_nonlocals))

144 

145 def _get_block_composite_vars(self, modified, live_in): 

146 # The scope variables corresponding to composite symbols (e.g. `self.x`). 

147 composite_scope_vars = [] 

148 for s in modified: 

149 if not s.is_composite(): 

150 continue 

151 # Mutations made to objects created inside the scope will appear as writes 

152 # to composite symbols. Because these mutations appear as modifications 

153 # made to composite symbols, we check whether the composite's parent is 

154 # actually live into the scope. 

155 # Example: 

156 # while cond: 

157 # x = Foo() 

158 # x.foo = 2 * x.foo # x.foo is live into the scope, but x is not. 

159 # 

160 # Note that some parents might not be symbols - for example, in x['foo'], 

161 # 'foo' is a parent, but it's a literal, not a symbol. We don't check the 

162 # liveness of literals. 

163 support_set_symbols = tuple( 

164 sss for sss in s.support_set if sss.is_symbol()) 

165 if not all(sss in live_in for sss in support_set_symbols): 

166 continue 

167 composite_scope_vars.append(s) 

168 return frozenset(composite_scope_vars) 

169 

def _get_block_vars(self, node, modified):
  """Determines the variables affected inside a control flow statement.

  Args:
    node: The control flow node; it must carry the `DEFINED_VARS_IN`,
      `LIVE_VARS_IN` and `LIVE_VARS_OUT` static annotations.
    modified: Set of symbols modified inside the statement.

  Returns:
    A tuple `(scope_vars, undefined, nouts)` where `scope_vars` is a sorted
    list of the symbols to track, `undefined` is a tuple of non-composite
    symbols that are modified but not defined on entry, and `nouts` is the
    number of entries in `scope_vars` minus the number of input-only symbols.
  """
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
  fn_scope = self.state[_Function].scope

  basic_vars = self._get_block_basic_vars(modified, live_in, live_out)
  composite_vars = self._get_block_composite_vars(modified, live_in)

  # Variables that are modified inside the scope, but not defined before
  # entering it. Only simple variables must be defined. The composite ones
  # will be implicitly checked at runtime.
  maybe_undefined = (
      modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
  undefined = tuple(
      sym for sym in maybe_undefined if not sym.is_composite())

  # Variables that are modified inside the scope, and depend on values outside
  # it. The parentheses spell out the precedence: `-` binds tighter than `&`.
  input_only = basic_vars & (live_in - live_out)

  def output_first(sym):
    # Place the outputs first, then sort lexicographically.
    return (sym in input_only, sym)

  scope_vars = sorted(basic_vars | composite_vars, key=output_first)
  nouts = len(scope_vars) - len(input_only)

  return scope_vars, undefined, nouts

200 

def visit_If(self, node):
  """Rewrites an `if` statement into a call to `ag__.if_stmt`.

  The body and the else branch become nested functions, and the variables the
  statement affects are exposed through generated getter/setter functions.

  Args:
    node: The `if` node; it must carry the `BODY_SCOPE` and `ORELSE_SCOPE`
      annotations.

  Returns:
    The nodes produced by the template. The origin information of `node` is
    copied onto the last of them.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

  cond_vars, undefined, nouts = self._get_block_vars(
      node, body_scope.bound | orelse_scope.bound)
  undefined_assigns = self._create_undefined_assigns(undefined)
  nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

  # Names referenced by either branch are passed to the namer as reserved.
  reserved = body_scope.referenced | orelse_scope.referenced
  new_symbol = self.ctx.namer.new_symbol
  state_getter_name = new_symbol('get_state', reserved)
  state_setter_name = new_symbol('set_state', reserved)
  state_functions = self._create_state_functions(
      cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  # The generated else function needs a body even when there is no `else`.
  orelse_body = node.orelse or [gast.Pass()]

  body_name = new_symbol('if_body', reserved)
  orelse_name = new_symbol('else_body', reserved)
  symbol_names = tuple(
      gast.Constant(str(sym), kind=None) for sym in cond_vars)

  template = """
    state_functions
    def body_name():
      nonlocal_declarations
      body
    def orelse_name():
      nonlocal_declarations
      orelse
    undefined_assigns
    ag__.if_stmt(
      test,
      body_name,
      orelse_name,
      state_getter_name,
      state_setter_name,
      (symbol_names,),
      nouts)
  """
  new_nodes = templates.replace(
      template,
      body=node.body,
      body_name=body_name,
      orelse=orelse_body,
      orelse_name=orelse_name,
      nonlocal_declarations=nonlocal_declarations,
      nouts=gast.Constant(nouts, kind=None),
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      symbol_names=symbol_names,
      test=node.test,
      undefined_assigns=undefined_assigns)
  origin_info.copy_origin(node, new_nodes[-1])
  return new_nodes

257 

def visit_While(self, node):
  """Rewrites a `while` loop into a call to `ag__.while_stmt`.

  The loop body and the loop test become nested functions, and the loop
  variables are exposed through generated getter/setter functions.

  Args:
    node: The `while` node; it must carry the `BODY_SCOPE` annotation.

  Returns:
    The nodes produced by the template. The origin information of `node` is
    copied onto the last of them.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

  # The third element (output count) is not used for loops.
  loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound)
  undefined_assigns = self._create_undefined_assigns(undefined)
  nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

  reserved = body_scope.referenced
  new_symbol = self.ctx.namer.new_symbol
  state_getter_name = new_symbol('get_state', reserved)
  state_setter_name = new_symbol('set_state', reserved)
  state_functions = self._create_state_functions(
      loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  opts = self._create_loop_options(node)

  body_name = new_symbol('loop_body', reserved)
  symbol_names = tuple(
      gast.Constant(str(sym), kind=None) for sym in loop_vars)
  test_name = new_symbol('loop_test', reserved)

  template = """
    state_functions
    def body_name():
      nonlocal_declarations
      body
    def test_name():
      return test
    undefined_assigns
    ag__.while_stmt(
        test_name,
        body_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        opts)
  """
  new_nodes = templates.replace(
      template,
      body=node.body,
      body_name=body_name,
      nonlocal_declarations=nonlocal_declarations,
      opts=opts,
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      symbol_names=symbol_names,
      test=node.test,
      test_name=test_name,
      undefined_assigns=undefined_assigns)
  origin_info.copy_origin(node, new_nodes[-1])
  return new_nodes

307 

def visit_For(self, node):
  """Rewrites a `for` loop into a call to `ag__.for_stmt`.

  The loop body becomes a nested function that receives the current iterate
  as its single argument. When the node carries an `EXTRA_LOOP_TEST`
  annotation, that expression is wrapped in a nested function as well;
  otherwise `None` is passed in its place.

  Args:
    node: The `for` node; it must carry the `BODY_SCOPE` and `ITERATE_SCOPE`
      annotations.

  Returns:
    The nodes produced by the template. The origin information of `node` is
    copied onto the last of them.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

  # The third element (output count) is not used for loops.
  loop_vars, undefined, _ = self._get_block_vars(
      node, body_scope.bound | iter_scope.bound)
  undefined_assigns = self._create_undefined_assigns(undefined)
  nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

  reserved = body_scope.referenced | iter_scope.referenced
  new_symbol = self.ctx.namer.new_symbol
  state_getter_name = new_symbol('get_state', reserved)
  state_setter_name = new_symbol('set_state', reserved)
  state_functions = self._create_state_functions(
      loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  # On top of any user-provided options, record the source text of the loop
  # target under the 'iterate_names' key.
  opts = self._create_loop_options(node)
  iterate_names = parser.unparse(node.target, include_encoding_marker=False)
  opts.keys.append(gast.Constant('iterate_names', kind=None))
  opts.values.append(gast.Constant(iterate_names, kind=None))

  if not anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    extra_test_name = parser.parse_expression('None')
    extra_test_function = []
  else:
    extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
    extra_test_name = new_symbol('extra_test', reserved)
    extra_test_template = """
      def extra_test_name():
        nonlocal_declarations
        return extra_test_expr
    """
    # NOTE: `loop_vars` is passed although the template above never names it.
    extra_test_function = templates.replace(
        extra_test_template,
        extra_test_expr=extra_test,
        extra_test_name=extra_test_name,
        loop_vars=loop_vars,
        nonlocal_declarations=nonlocal_declarations)

  # iterate_arg_name holds a single arg with the iterates, which may be a
  # tuple.
  iterate_arg_name = new_symbol('itr', reserved)
  expansion_template = """
    iterates = iterate_arg_name
  """
  iterate_expansion = templates.replace(
      expansion_template,
      iterate_arg_name=iterate_arg_name,
      iterates=node.target)
  origin_info.copy_origin(node, iterate_expansion)

  body_name = new_symbol('loop_body', reserved)
  symbol_names = tuple(
      gast.Constant(str(sym), kind=None) for sym in loop_vars)

  template = """
    state_functions
    def body_name(iterate_arg_name):
      nonlocal_declarations
      iterate_expansion
      body
    extra_test_function
    undefined_assigns
    ag__.for_stmt(
        iterated,
        extra_test_name,
        body_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        opts)
  """
  new_nodes = templates.replace(
      template,
      body=node.body,
      body_name=body_name,
      extra_test_function=extra_test_function,
      extra_test_name=extra_test_name,
      iterate_arg_name=iterate_arg_name,
      iterate_expansion=iterate_expansion,
      iterated=node.iter,
      nonlocal_declarations=nonlocal_declarations,
      opts=opts,
      symbol_names=symbol_names,
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      undefined_assigns=undefined_assigns)
  origin_info.copy_origin(node, new_nodes[-1])
  return new_nodes

395 

396 

class AnnotatedDef(reaching_definitions.Definition):
  """A reaching definition that additionally carries a `directives` dict."""

  def __init__(self):
    """Initializes the base definition and an empty `directives` mapping."""
    super().__init__()
    # Starts out empty.
    self.directives = {}

402 

403 

def transform(node, ctx):
  """Runs the static analyses, then applies `ControlFlowTransformer`.

  Args:
    node: The AST node to convert.
    ctx: The context object handed to each analysis pass and to the
      transformer.

  Returns:
    The node returned by `ControlFlowTransformer(ctx).visit`.
  """
  # The CFGs are built from the node as received, before any pass runs.
  graphs = cfg.build(node)

  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  # These passes share one call shape and run in this order.
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    node = analysis.resolve(node, ctx, graphs)

  return ControlFlowTransformer(ctx).visit(node)