Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_util_v2.py: 26%

141 statements  


# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for V2 control flow."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.data.util import structure  # pylint: disable=unused-import
from tensorflow.python.eager import context
from tensorflow.python.eager.polymorphic_function import atomic_function
from tensorflow.python.eager.polymorphic_function import monomorphic_function
from tensorflow.python.eager.polymorphic_function import tracing_compiler
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import gradients_util
from tensorflow.python.util import keras_deps
from tensorflow.python.util import tf_contextlib

_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = None
_DISABLE_LOWER_USING_SWITCH_MERGE = False


CondBranchFuncGraph = control_flow_v2_func_graphs.CondBranchFuncGraph
WhileCondFuncGraph = control_flow_v2_func_graphs.WhileCondFuncGraph
WhileBodyFuncGraph = control_flow_v2_func_graphs.WhileBodyFuncGraph


def in_defun():
  """Returns whether the current graph is, or is nested in, a defun."""
  if context.executing_eagerly(): return False

  graph = ops.get_default_graph()
  while (isinstance(graph, CondBranchFuncGraph) or
         isinstance(graph, WhileBodyFuncGraph) or
         isinstance(graph, WhileCondFuncGraph)):
    graph = graph.outer_graph
  return isinstance(graph, FuncGraph)
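
# Editor's sketch (not part of the original module): `in_defun()` is False
# under eager execution and in a plain v1 Graph, and True while tracing a
# `tf.function`, because the default graph during tracing is a FuncGraph that
# is not one of the cond/while branch graphs skipped above. For example:
#
#   in_defun()  # False in eager mode.
#
#   @tf.function
#   def f(x):
#     assert in_defun()  # True while tracing.
#     return x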

def in_while_loop_defun(graph):
  """Returns whether the graph is a while loop FuncGraph."""
  if context.executing_eagerly(): return False
  return (isinstance(graph, WhileCondFuncGraph) or
          isinstance(graph, WhileBodyFuncGraph))


def create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the current graph.

  Args:
    func_graph: FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  func = atomic_function.from_func_graph(
      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})

  func_graph.outer_graph._add_function_recursive(func)  # pylint: disable=protected-access
  return func_graph.name
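
# Editor's sketch (not part of the original module): callers such as cond_v2
# and while_v2 typically build a branch/body FuncGraph, register it with the
# outer graph via this helper, and then reference it by name when creating
# the `If`/`While` op. Roughly (illustrative only):
#
#   true_name = create_new_tf_function(true_graph)
#   false_name = create_new_tf_function(false_graph)
#   # ... pass `true_name`/`false_name` as the then/else function attrs of
#   # the generated `If` op.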

def unique_fn_name(scope, name):
  """Returns a unique name to use for a control flow function.

  Args:
    scope: A name scope string.
    name: An identifier for this function (e.g. "true", "body").

  Returns:
    A string, the name to use for the function.
  """
  return ("%s%s_%s" % (scope, name, ops.uid())).replace("/", "_")


def unique_grad_fn_name(forward_name):
  """Returns a unique name for the gradient function of `forward_name`."""
  return "%s_grad_%s" % (forward_name, ops.uid())

def maybe_set_lowering_attr(op, lower_using_switch_merge=None):
  """Sets the flag to enable lowering on `op` if necessary.

  Lowering allows cond_v2 and while_v2 to avoid some of the limitations of
  Functions, allowing users to specify devices & colocation inside of cond_v2
  and while_v2 input functions, and enabling non-strict evaluation & partial
  pruning. This brings v2 control flow closer to feature parity with v1
  control flow.

  However, we do not lower in the following cases:
    - When the `If` or `While` ops are in the XLA context, because it is
      easier for XLA to apply its own optimizations to un-lowered control flow
      operators than to low-level control flow primitives.
    - When the eager execution context specifies that functions run on the
      single threaded executor (see context.function_executor_type()), because
      the single threaded executor does not support v1 control flow ops.
    - When `lower_using_switch_merge` is explicitly set to False.

  Args:
    op: An `If` or `While` Operation.
    lower_using_switch_merge: Explicit value to lower or not (optional).
  """
  if lower_using_switch_merge is not None:
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge",
                 attr_value_pb2.AttrValue(b=lower_using_switch_merge))
    # pylint: enable=protected-access
  elif (not _DISABLE_LOWER_USING_SWITCH_MERGE and
        not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
        context.context().function_call_options.executor_type !=
        "SINGLE_THREADED_EXECUTOR"):
    # pylint: disable=protected-access
    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access
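
# Editor's sketch (not part of the original module): given an existing `While`
# (or `If`) op `while_op` built outside an XLA context with the default
# executor, this helper marks the node so the runtime lowers it to
# Switch/Merge ops, and the mark can be read back from the node attribute:
#
#   maybe_set_lowering_attr(while_op)
#   while_op.get_attr("_lower_using_switch_merge")  # True
#
#   maybe_set_lowering_attr(while_op, lower_using_switch_merge=False)
#   while_op.get_attr("_lower_using_switch_merge")  # False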

def maybe_propagate_compile_time_consts_in_xla(op):
  """Tells XLA whether to propagate compile-time consts in the loop body.

  This is needed to make compile time constants available to ops, for example
  `max_num_elements` in `EmptyTensorList`, inside the loop body. Ideally this
  would always be turned on, but that doesn't work with legacy functionalized
  while_loops.

  Args:
    op: A `While` Operation.
  """
  if control_flow_util.GraphOrParentsInXlaContext(op.graph):
    # pylint: disable=protected-access
    op._set_attr("_xla_propagate_compile_time_consts",
                 attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access


def resource_input_index(tensor_name, input_names, node_defs, functions):
  """Returns the index of the input corresponding to `tensor_name`.

  This method is used to find the corresponding index of an arbitrary resource
  tensor in a function (the function could be a loop body). We assume that
  resource handles are never created in functions, so that every resource
  tensor can be traced back to a function input.

  The awkward signature of this method is to make it work with both FuncGraphs
  and FunctionDefs. This is so we can recurse on function call ops without
  building the corresponding FuncGraph (note that even if a FuncGraph for a
  FunctionDef already exists, the input/output/node names may have been
  changed when the FuncGraph was serialized to the FunctionDef, which makes it
  unusable with this algorithm).

  Args:
    tensor_name: the name of the resource tensor to be resolved to an input.
    input_names: a list of the names of all inputs to the function.
    node_defs: a dict mapping op name -> NodeDef for every op in the function.
    functions: a dict mapping function name -> AtomicFunction.

  Returns:
    The index into input_names corresponding to `tensor_name`.
  """
  while tensor_name not in input_names:
    # FunctionDefs and graphs use different tensor naming conventions.
    parts = tensor_name.split(":")
    if len(parts) == 3:
      op_name, _, output_idx = parts
    elif len(parts) == 2:
      op_name, output_idx = parts
    else:
      assert len(parts) == 1
      op_name = parts[0]
      output_idx = 0
      tensor_name = "%s:%d" % (tensor_name, output_idx)
      # Check again for cases where the tensor suffix (":0") is stripped out.
      if tensor_name in input_names:
        break
    output_idx = int(output_idx)
    node_def = node_defs[op_name]

    def _extract_input_index(function_attribute_name):
      func_name = node_def.attr[function_attribute_name].func.name
      fdef = functions[func_name].cached_definition
      output_arg_name = fdef.signature.output_arg[output_idx].name
      output_tensor_name = fdef.ret[output_arg_name]
      return resource_input_index(
          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
          {ndef.name: ndef for ndef in fdef.node_def}, functions)

    if node_def.op in ("Identity", "While"):
      # Captured resources occur at the same index in the lists of inputs and
      # outputs of a while or identity op. So we look up the input of
      # `tensor.op` at the same index as the index of `tensor` in
      # `tensor.op.outputs`.
      tensor_name = node_def.input[output_idx]
    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
      # Functions output any captured resource tensors used by their
      # gradients. `tensor_name` is one of these outputs from a nested
      # function call, so recursively find the corresponding input in the
      # nested FunctionDef.
      tensor_name = node_def.input[_extract_input_index("f")]
    elif node_def.op in ("If", "StatelessIf"):
      input_index = _extract_input_index("then_branch")
      if input_index != _extract_input_index("else_branch"):
        raise AssertionError(
            ("Expected cond branches ({} op) to each have the same "
             "input->output mapping of resources.").format(node_def.op))
      tensor_name = node_def.input[
          # Ignore the `cond` input; the function inputs come after.
          input_index + 1]
    else:
      # We assume there are no other op types that will "forward" resource
      # handles like this, so all other handles must have been created by the
      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
      # which we'll end up accumulating.)
      raise ValueError("Taking gradient of a while loop which creates "
                       "a resource in its body is not supported: %s (%s)"
                       % (op_name, node_def.op))

  return input_names.index(tensor_name)
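
# Editor's sketch (not part of the original module): the name-parsing loop
# above accepts graph-style names ("op_name:out_idx"), FunctionDef-style names
# ("op_name:output_arg_name:out_idx"), and bare argument names. A caller such
# as while_v2 might resolve a resource tensor in a loop body roughly like this
# (illustrative only; `body_graph` and `tensor` are assumed to exist, and
# `body_graph._functions` is the graph's name -> function mapping):
#
#   index = resource_input_index(
#       tensor.name,
#       [inp.name for inp in body_graph.inputs],
#       {op.name: op.node_def for op in body_graph.get_operations()},
#       body_graph._functions)
#   resource_input = body_graph.inputs[index]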

@tf_contextlib.contextmanager
def clear_control_inputs():
  """Clears the control inputs but preserves the ControlFlowContext.

  This is needed to preserve the XLAControlFlowContext when clearing
  control inputs for the gradient accumulators in while_v2.
  `ops.control_dependencies` does not allow that.

  Yields:
    A context manager in which the ops created will not have any control
    inputs by default but the control flow context is the same.
  """
  # pylint: disable=protected-access
  control_flow_context = ops.get_default_graph()._get_control_flow_context()
  with ops.control_dependencies(None):
    ops.get_default_graph()._set_control_flow_context(control_flow_context)
    yield
  # pylint: enable=protected-access
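
# Editor's sketch (not part of the original module): unlike a bare
# `ops.control_dependencies(None)` block, this context manager drops implicit
# control inputs while keeping the enclosing control flow context (e.g. an
# XLA context). Assuming tensors `x` and `z` and the `array_ops` module are
# available:
#
#   with ops.control_dependencies([x]):
#     with clear_control_inputs():
#       y = array_ops.identity(z)  # `y.op` gets no control input on `x.op`,
#                                  # but keeps the control flow context.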

def _is_tpu_strategy(strategy):
  return (strategy is not None and
          strategy.__class__.__name__.startswith("TPUStrategy"))


def _is_building_keras_layer():
  # TODO(srbs): Remove this function when we no longer support sessions with
  # Keras.
  keras_call_context_function = keras_deps.get_call_context_function()
  if keras_call_context_function:
    return keras_call_context_function().layer is not None
  else:
    return False


def output_all_intermediates():
  """Whether to output all intermediates of a functional control flow op.

  The default behavior is to output intermediates only when building a Keras
  Layer in graph mode, and only when certain other conditions are met:
  1. We do not output intermediates if the functional control flow op
     is being built inside a FuncGraph which is not an If/While graph. This
     guards against outputting intermediates in eager mode since Keras adds
     tensors to a FuncGraph named "keras_graph" in that case. Also, since we
     do not output intermediates of tf.function (this feature is only for
     backwards compatibility), outputting intermediates of functional control
     flow ops built inside tf.function is of no value.
  2. We do not output intermediates when the compilation is using XLA or when
     targeting a TPU.
  3. We do not output intermediates when a single threaded executor is used
     since that does not perform inlining and pruning.

  Returns:
    A bool telling whether to output all intermediates.
  """
  if _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE is not None:
    return _EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  if in_defun():
    return False
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    return False
  if (context.context().function_call_options.executor_type ==
      "SINGLE_THREADED_EXECUTOR"):
    return False
  return _is_building_keras_layer()
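
# Editor's sketch (not part of the original module): the module-level
# `_EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE` short-circuits the
# heuristic above; when it is set, the Keras/XLA/executor checks are skipped
# entirely (the `tf.compat.v1.experimental.output_all_intermediates` setter is
# believed to be the public way to flip it):
#
#   from tensorflow.python.ops import control_flow_util_v2 as cfu
#   cfu._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
#   cfu.output_all_intermediates()  # True, regardless of context.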

def get_func_graph(op, input_shapes, func_name):
  """Generates and returns a FuncGraph for the given op and input_shapes."""
  fdef = None
  graph = op.graph
  # Search for the function in `op.graph` and its outer graphs.
  while graph is not None:
    func = graph._get_function(func_name)  # pylint: disable=protected-access
    if func is not None:
      fdef = func.cached_definition
      break
    if hasattr(graph, "outer_graph"):
      graph = graph.outer_graph
    else:
      break

  if fdef is None:
    raise KeyError("%s cannot be found in the graph" % func_name)

  # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # in the case of nested if ops or when the gradient is being computed
  # from inside a Defun. We build the `func_graph` with `op.graph` as its
  # `outer_graph`. This resembles how the `FuncGraph` was built in the
  # forward pass. We need this so that we can resolve references to tensors
  # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
  with op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes=input_shapes)

  # TODO(xjun): Ideally we want to retrieve the gradient functions instead of
  # re-creating them. But the lifetime of gradient functions of
  # PartitionedCall ops is attached to the PartitionedCall ops in the original
  # func_graph, and when we are inside this function we don't have access to
  # the original func_graph or PartitionedCall ops. See cl/499362867 and
  # cl/273858076 for more context.
  for operation in func_graph.get_operations():
    if operation.type in ["PartitionedCall", "StatefulPartitionedCall"]:
      f = graph._get_function(operation.get_attr("f").name)  # pylint: disable=protected-access
      try:
        cf = monomorphic_function.ConcreteFunction(
            f.graph, attrs=f.cached_definition.attr
        )
      except AttributeError:
        # f is not found or f is a _DefinedFunction that doesn't have a graph.
        continue
      operation._gradient_function = cf._get_gradient_function()  # pylint: disable=protected-access

  return func_graph


def get_op_and_outputs(op_or_outputs):
  """Normalizes `op_or_outputs` into an (Operation, list of outputs) pair."""
  if isinstance(op_or_outputs, ops.Operation):
    return op_or_outputs, []
  elif not op_or_outputs:  # Empty list.
    return None, []
  else:
    return op_or_outputs[0].op, op_or_outputs
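
# Editor's sketch (not part of the original module): the helper above
# normalizes the three shapes a functional-op result can take:
#
#   get_op_and_outputs(an_op)      # -> (an_op, [])
#   get_op_and_outputs([])         # -> (None, [])
#   get_op_and_outputs([t0, t1])   # -> (t0.op, [t0, t1])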

def graph_wrapped_for_higher_order_tape_gradients(graph):
  """Check if `graph` is wrapped by `run_as_function_for_tape_gradients`."""
  while graph is not None:
    if "cflow_gradient_wrapper" in getattr(graph, "name", ""):
      return True
    graph = getattr(graph, "outer_graph", None)
  return False


def run_as_function_for_tape_gradients(make_op, inputs):
  """Fix higher-order tape gradients by wrapping `make_op` in a function.

  Args:
    make_op: A function that takes a list of inputs and returns a list of
      output tensors. This function should set any handle data relevant to its
      outputs before returning.
    inputs: A list of tensors to check for tape gradients and pass to
      `make_op`. These should include all tensors used in `make_op`.

  Returns:
    Tensors corresponding to `make_op`'s output.
  """
  # GradientTapes created inside a function currently don't work well with
  # un-wrapped control flow ops in that same function. Wrapping the op in an
  # extra layer of intermediate function means we run extra logic in the
  # function gradient code to record the correct intermediates on the tape.
  #
  # The function attribute inputs to control flow ops are not hashable, so we
  # pass everything as a capture to bypass defun's caching.
  if (gradients_util.PossibleTapeGradientTypes(inputs)
      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
      # We only need one function between the tape and the op; if we've
      # already wrapped once, we stop wrapping to avoid infinite recursion.
      and not (ops.get_default_graph().building_function
               and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
    results = tracing_compiler.TracingCompiler(
        make_op,
        "cflow_gradient_wrapper",
        autograph=False)(inputs)
    return results
  else:
    return make_op(inputs)
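
# Editor's sketch (not part of the original module): a caller such as cond_v2
# or while_v2 would pass a closure that emits the raw functional op, letting
# this helper decide whether to wrap it for higher-order tape gradients
# (illustrative only; `_build_cond_op` and `cond_inputs` are hypothetical):
#
#   outputs = run_as_function_for_tape_gradients(
#       lambda op_inputs: _build_cond_op(op_inputs), cond_inputs)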