Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_util.py: 21%

153 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Utility functions for control flow. 

17 

18This file is necessary to avoid cyclic dependencies between ops.py and 

19control_flow_ops.py. 

20""" 

21 

22import os 

23import traceback 

24 

25from tensorflow.python import tf2 

26from tensorflow.python.platform import tf_logging as logging 

27 

# Control flow v2 is used when tf2.enabled() is true and TF_ENABLE_CONTROL_FLOW_V2
# is not explicitly "0", or when any of the opt-in environment variables below
# is set to something other than "0".
ENABLE_CONTROL_FLOW_V2 = (
    (tf2.enabled() and os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
    any(os.getenv(env_var, "0") != "0"
        for env_var in ("TF_ENABLE_CONTROL_FLOW_V2",
                        "TF_ENABLE_COND_V2",
                        "TF_ENABLE_WHILE_V2",
                        "TF_ENABLE_TENSOR_ARRAY_V2")))

34 

35 

# TODO(b/137793122): Remove this.
def enable_control_flow_v2():  # pylint: disable=invalid-name
  """Switches this module over to control flow v2.

  Sets the module-level `ENABLE_CONTROL_FLOW_V2` flag to True.

  Do not use this symbol. This will be removed.
  """
  global ENABLE_CONTROL_FLOW_V2
  ENABLE_CONTROL_FLOW_V2 = True

44 

45 

def EnableControlFlowV2(graph):
  """Returns whether control flow v2 should be used in `graph`.

  The global `ENABLE_CONTROL_FLOW_V2` flag wins when set. Otherwise v2 is used
  only for graphs that are building a function and do not have a `_captured`
  attribute (i.e. new FuncGraphs, not legacy _FuncGraphs).
  """
  if ENABLE_CONTROL_FLOW_V2:
    return ENABLE_CONTROL_FLOW_V2
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return graph.building_function and not hasattr(graph, "_captured")

52 

53 

def IsInXLAContext(op):
  """Returns true if `op` is marked for XLA compilation or is in an XLA context.

  A truthy `_XlaCompile` attribute on `op` short-circuits to True. A ValueError
  from `get_attr` (presumably raised when the attribute is not set) is ignored,
  and the op's control flow context chain is searched for an XLA context.
  """
  try:
    if op.get_attr("_XlaCompile"):
      return True
  except ValueError:
    pass
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctxt = GetContainingXLAContext(flow_ctxt)
  return xla_ctxt is not None

62 

63 

def InXlaContext(graph):
  """Returns true if `graph`'s control flow context is within an XLA context."""
  graph_ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctxt = GetContainingXLAContext(graph_ctxt)
  return xla_ctxt is not None

67 

68 

def GraphOrParentsInXlaContext(graph):
  """Returns true if `graph` or any graph reachable via `outer_graph` is in XLA.

  The walk follows `outer_graph` links and stops (returning False) at the first
  graph that has no `outer_graph` attribute.
  """
  current = graph
  while not InXlaContext(current):
    try:
      current = current.outer_graph
    except AttributeError:
      return False
  return True

76 

77 

def IsInWhileLoop(op):
  """Returns true if `op`'s control flow context is inside a while loop."""
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  while_ctxt = GetContainingWhileContext(flow_ctxt)
  return while_ctxt is not None

81 

82 

def IsInCond(op):
  """Returns true if `op`'s control flow context is inside a cond."""
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  cond_ctxt = GetContainingCondContext(flow_ctxt)
  return cond_ctxt is not None

86 

87 

def IsSwitch(op):
  """Return true if `op` is a Switch (including the RefSwitch variant)."""
  return op.type in ("Switch", "RefSwitch")

91 

92 

def IsMerge(op):
  """Return true if `op` is a Merge (including the RefMerge variant)."""
  return op.type in ("Merge", "RefMerge")

96 

97 

def IsLoopEnter(op):
  """Returns true if `op` is an Enter (including the RefEnter variant)."""
  return op.type in ("Enter", "RefEnter")

101 

102 

def IsLoopExit(op):
  """Return true if `op` is an Exit (including the RefExit variant)."""
  return op.type in ("Exit", "RefExit")

106 

107 

def IsCondSwitch(op):
  """Return true if `op` is the Switch for a conditional.

  Switch nodes are not part of the cond control flow context that they
  represent, so the consumers of the Switch's outputs are inspected instead: a
  Switch is a cond switch iff all its consumers are in cond contexts. For a
  consumer that is a loop Enter, the context enclosing the Enter's own context
  is the one that is checked.

  Args:
    op: Operation to classify.

  Returns:
    False if `op` is not a Switch/RefSwitch, has no outputs, or has any
    consumer whose (adjusted) context is missing or not a CondContext; True
    otherwise (including when the outputs have no consumers).
  """
  if not IsSwitch(op):
    return False
  if not op.outputs:
    return False
  for o in op.outputs:
    for c in o.consumers():
      ctxt = c._get_control_flow_context()  # pylint: disable=protected-access
      # A consumer may have no control flow context at all (e.g. ops imported
      # via import_graph_def); only step outward when there is a context to
      # step out of, otherwise `None.outer_context` would raise.
      if ctxt is not None and IsLoopEnter(c):
        ctxt = ctxt.outer_context
      if ctxt is None or not ctxt.IsCondContext():
        # One non-cond consumer settles the answer; no need to keep looking.
        return False
  return True

127 

128 

def IsCondMerge(op):
  """Return true if `op` is the Merge for a conditional.

  Merge nodes are not part of the cond control flow context that they
  represent, so the output contexts of the Merge's inputs are inspected
  instead: a Merge is a cond merge iff all its inputs are in cond contexts.
  A non-Merge op, or a Merge with no inputs, is never a cond merge.
  """
  if not IsMerge(op) or not op.inputs:
    return False
  input_ctxts = [GetOutputContext(t.op) for t in op.inputs]
  return all(c is not None and c.IsCondContext() for c in input_ctxts)

144 

145 

def IsLoopSwitch(op):
  """Return true if `op` is the Switch for a while loop.

  The Switch must live in a WhileContext and must not be classified as a cond
  switch by IsCondSwitch.
  """
  if not IsSwitch(op):
    return False
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return (flow_ctxt is not None and flow_ctxt.IsWhileContext() and
          not IsCondSwitch(op))

152 

153 

def IsLoopMerge(op):
  """Return true if `op` is the Merge for a while loop.

  The Merge must live in a WhileContext and must not be classified as a cond
  merge by IsCondMerge.
  """
  if not IsMerge(op):
    return False
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return (flow_ctxt is not None and flow_ctxt.IsWhileContext() and
          not IsCondMerge(op))

160 

161 

def IsLoopConstantEnter(op):
  """Return true iff op is a loop invariant.

  Returns False for non-Enter ops; for Enter ops, returns the op's
  `is_constant` attribute.
  """
  if not IsLoopEnter(op):
    return False
  return op.get_attr("is_constant")

165 

166 

def GetLoopConstantEnter(value):
  """Return the enter op if we can infer `value` to be a loop invariant.

  Walks backwards from `value.op` through first inputs of Switch/Identity
  (and their Ref variants) ops, then checks whether the op reached is a
  constant loop Enter. Returns that op, or None.
  """
  passthrough_types = ("Switch", "RefSwitch", "Identity", "RefIdentity")
  source = value.op
  while source.type in passthrough_types:
    source = source.inputs[0].op
  if IsLoopConstantEnter(source):
    return source
  return None

174 

175 

def GetOutputContext(op):
  """Return the control flow context for the output of an op.

  This is the op's own context, except for loop Exit ops, whose outputs belong
  to the enclosing (outer) context.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  # Exit nodes usually have a control flow context, except in the case where
  # the exit node was imported via import_graph_def (in which case no nodes
  # have control flow contexts).
  if op_ctxt is None or not IsLoopExit(op):
    return op_ctxt
  return op_ctxt.outer_context

185 

186 

def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Returns the first ancestor WhileContext of `ctxt`.

  Walks outward through `outer_context` links starting at `ctxt` itself.

  Args:
    ctxt: ControlFlowContext, or None.
    stop_ctxt: ControlFlowContext, optional. If provided, the walk also stops
      at (and returns) the first context equal to stop_ctxt.

  Returns:
    `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing
    `ctxt`, or None if `ctxt` is not in a while loop. If `stop_ctxt` is not
    `None`, this returns `stop_ctxt` if it is reached first in the traversal.
  """
  node = ctxt
  while node and not (node.IsWhileContext() or node == stop_ctxt):
    node = node.outer_context
  return node or None

207 

208 

def GetContainingXLAContext(ctxt):
  """Returns the first ancestor XLAContext of `ctxt`.

  Returns `ctxt` if `ctxt` is a XLAContext, or None if `ctxt` is not in a
  XLAContext.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if `ctxt` is a XLAContext, the most nested XLAContext containing
    `ctxt`, or None if `ctxt` is not in a XLAContext.
  """
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None

226 

227 

def GetContainingCondContext(ctxt):
  """Returns the first ancestor CondContext of `ctxt`.

  Walks outward through `outer_context` links starting at `ctxt` itself.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if `ctxt` is a CondContext, the most nested CondContext containing
    `ctxt`, or None if `ctxt` is not in a cond.
  """
  node = ctxt
  while node and not node.IsCondContext():
    node = node.outer_context
  return node or None

244 

245 

def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns true if `maybe_containing_ctxt` is or contains `ctxt`.

  Identity (`is`) is used for the comparison. Because the walk ends at None,
  a `maybe_containing_ctxt` of None is treated as containing every context.
  """
  current = ctxt
  while True:
    if current is maybe_containing_ctxt:
      return True
    if current is None:
      return False
    current = current.outer_context

252 

253 

def OpInContext(op, ctxt):
  """Returns true if `op`'s control flow context is `ctxt` or nested in it."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return IsContainingContext(op_ctxt, ctxt)

256 

257 

def TensorInContext(tensor, ctxt):
  """Returns true if the op producing `tensor` is in `ctxt` or nested in it."""
  producer = tensor.op
  return OpInContext(producer, ctxt)

260 

261 

def CheckInputFromValidContext(op, input_op):
  """Checks that `input_op` can be used as an input from `op`'s context.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well. Nothing is returned; an invalid combination
  is reported by raising.

  Args:
    op: Operation
    input_op: Operation

  Raises:
    ValueError: if input_op is from an invalid context. Before raising, the
      while contexts and creation tracebacks of both ops are written to the
      info log.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)
  valid = False

  if not input_ctxt:
    # input_op isn't in a control flow context.
    valid = True
  elif op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    valid = True
  else:
    while_ctxt = GetContainingWhileContext(op_ctxt)
    input_while_ctxt = GetContainingWhileContext(input_ctxt)

    if while_ctxt is None:
      if input_while_ctxt is None:
        # Neither op nor input_op is in a while loop, but one or both are in
        # conds. We allow this, although execution will fail if the branch
        # corresponding to input_op's cond context isn't taken.
        valid = True
      # Invalid if op isn't in a while loop and input_op is. Unless...
      if IsLoopEnter(op):
        # WhileContext._BuildLoop clears context for Enter nodes.
        valid = True
      if IsSwitch(op):
        # CondContext.AddValue clears context for Switch nodes.
        valid = True
    elif IsContainingContext(while_ctxt, input_while_ctxt):
      # input_op is in a while loop which contains op's while loop (or not in a
      # while loop at all). IsContainingContext(x, None) is always true, so
      # every branch below may assume input_while_ctxt is not None.
      valid = True
    elif (while_ctxt.grad_state and
          IsContainingContext(while_ctxt.grad_state.forward_context,
                              input_while_ctxt)):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build while
      # loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      valid = True
    elif (while_ctxt.grad_state and
          while_ctxt.grad_state.forward_context is
          input_while_ctxt._outer_context):  # pylint: disable=protected-access
      # op is in a gradient context and input_op is in a child of the associated
      # forward pass context. This case is needed for the gradients of while
      # loops with conds.
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context is while_ctxt):
      # input_op is in the gradient context of op's context. This case is needed
      # when the gradient of a while loop gradient is requested (this will
      # eventually fail unless there is a stop_gradient() or similar).
      valid = True
    elif (input_while_ctxt.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state and
          input_while_ctxt.grad_state.forward_context.grad_state
          .forward_context is while_ctxt):
      # input_op is in the grad grad context of op's context. This case is
      # needed when the gradient of a while loop gradient is requested (this
      # will eventually fail unless there is a stop_gradient() or similar).
      # grad_state is read from input_while_ctxt -- the same object the guard
      # tested -- rather than from input_ctxt, which may be a non-while
      # context.
      valid = True

  if not valid:
    # `valid` can only be False in the final `else` branch above, so both
    # while_ctxt and input_while_ctxt are bound here.
    if while_ctxt:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because they "
          "are in different while loops.")
    else:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because "
          f"'{input_op.name}' is in a while loop.")

    # Log the error message plus the relevant stack traces. The stacks may be
    # useful for debugging this error, but we don't want to raise an
    # unreadable exception.
    log_msg = error_msg
    log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
    log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
    log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
        op.name, "".join(traceback.format_list(op.traceback)),
        input_op.name, "".join(traceback.format_list(input_op.traceback)))
    logging.info(log_msg)
    raise ValueError(error_msg + " See info log for more details.")

360 

361 

def GetWhileContext(op):
  """Get the WhileContext to which this op belongs.

  Delegates to the op's control flow context's GetWhileContext(); when the op
  has no (or a falsy) context, that context value is returned unchanged.
  """
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return flow_ctxt.GetWhileContext() if flow_ctxt else flow_ctxt