Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/cond.py: 22%

133 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Cond function for Control Flow Operations.""" 

16 

17from tensorflow.python.eager import context 

18from tensorflow.python.eager.polymorphic_function import eager_function_run 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.framework import indexed_slices 

21from tensorflow.python.framework import ops 

22from tensorflow.python.framework import tensor_util 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import control_flow_util as util 

25from tensorflow.python.ops import math_ops 

26from tensorflow.python.platform import tf_logging as logging 

27from tensorflow.python.types import core 

28from tensorflow.python.util import deprecation 

29from tensorflow.python.util import dispatch 

30from tensorflow.python.util import nest 

31from tensorflow.python.util.lazy_loader import LazyLoader 

32from tensorflow.python.util.tf_export import tf_export 

33 

# TODO(b/269483538): below lazy loads
# needed for references while refactors are in progress
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")
# This is to avoid a circular dependency:
# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
                     "tensorflow.python.ops.cond_v2")

43 

44 

# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected a lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError(
          "cond(): 'true_fn' and 'fn1' may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): 'true_fn' argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError(
          "cond(): 'false_fn' and 'fn2' may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): 'false_fn' argument required")

  if not callable(true_fn):
    raise TypeError("'true_fn' must be callable.")
  if not callable(false_fn):
    raise TypeError("'false_fn' must be callable.")

  # Eager execution short-circuits graph construction entirely; see
  # _eager_cond_implementation below.
  if context.executing_eagerly():
    return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)

  # Always enable control flow v2 if building a function, regardless of toggle.
  if util.EnableControlFlowV2(ops.get_default_graph()):
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  # Legacy (v1) graph-mode implementation: build both branches inside
  # CondContexts gated by Switch ops, then Merge their outputs.
  with ops.name_scope(name, "cond", [pred]):
    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("'pred' must not be a Python bool.")
    # p_1 feeds the true branch and p_2 the false branch, per the
    # "switch_t"/"switch_f" identity names below.
    p_2, p_1 = control_flow_ops.switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
    context_t = control_flow_ops.CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("'true_fn' must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      # Always leave the context, even if branch construction raised.
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = control_flow_ops.CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("'false_fn' must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    # In non-strict mode a singleton list/tuple result is compared (and later
    # returned) as its sole element.
    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
    except (TypeError, ValueError):
      # A structure mismatch may only be an int32-vs-int64 IndexedSlices
      # indices difference; cast both sides to int64 and re-check before
      # reporting an error to the user.
      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
      try:
        nest.assert_same_structure(orig_res_t, orig_res_f,
                                   expand_composites=True)
      except TypeError as e:
        raise TypeError(
            f"Incompatible return types of 'true_fn' and 'false_fn': {e}")
      except ValueError as e:
        raise ValueError(
            f"Incompatible return values of 'true_fn' and 'false_fn': {e}")

    # Add the final merge to the graph.
    if not res_t:
      raise ValueError(
          "'true_fn' and 'false_fn' must return at least one result.")

    res_t_flat = nest.flatten(res_t, expand_composites=True)
    res_f_flat = nest.flatten(res_f, expand_composites=True)

    # Both flattened branch outputs must pair up with matching base dtypes.
    for (x, y) in zip(res_t_flat, res_f_flat):
      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      if x.dtype.base_dtype != y.dtype.base_dtype:
        raise ValueError(
            "Outputs of 'true_fn' and 'false_fn' must have the same type(s). "
            f"Received {x.dtype.name} from 'true_fn' "
            f"and {y.dtype.name} from 'false_fn'.")

    merges = [
        control_flow_ops.merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = nest.map_structure(
        control_flow_ops._convert_flow_to_tensorarray,  # pylint: disable=protected-access
        nest.flatten(orig_res_t, expand_composites=True),
        merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    # Restore the caller-visible (possibly nested) structure.
    merges = nest.pack_sequence_as(
        structure=orig_res_t, flat_sequence=merges, expand_composites=True)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges

247 

248 

@tf_export("cond", v1=[])
@dispatch.add_dispatch_support
def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  Note: This op is automatically used in a `tf.function` to convert Python
  if-statements when the predicate is a `tf.Tensor`, unless `autograph=False` is
  explicitly specified in `tf.function` args. For example, the following are
  equivalent:

  >>> @tf.function
  ... def fun1(x,y):
  ...   if x > 0:  # AutoGraph converts if-statement to tf.cond().
  ...     z = y+1
  ...   else:
  ...     z = y-1
  ...   return z
  >>> fun1(tf.constant(7), tf.constant(3)).numpy()
  4

  >>> @tf.function
  ... def fun2(x,y):
  ...   pred = x > 0
  ...   true_fn =  lambda: y+1
  ...   false_fn = lambda: y-1
  ...   return tf.cond(pred, true_fn, false_fn)  # Use tf.cond() explicitly.
  >>> fun2(tf.constant(7), tf.constant(3)).numpy()
  4

  For more information, see [tf.function and AutoGraph guide](
  https://www.tensorflow.org/guide/function#autograph_transformations).

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected a lazier semantics.
  Consider the following simple program:

  >>> x, y = tf.constant(2, dtype=tf.int32), tf.constant(4, dtype=tf.int32)
  >>> z = tf.multiply(x, y)
  >>> r = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  >>> r.numpy()
  10

  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.

  Note: It is illegal to "directly" use tensors created inside a cond branch
  outside it, e.g. by storing a reference to a branch tensor in the python
  state. If you need to use a tensor created in a branch function you should
  return it as an output of the branch function and use the output from
  `tf.cond` instead.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  >>> x = tf.constant(2)
  >>> y = tf.constant(5)
  >>> def f1(): return tf.multiply(x, 7)
  >>> def f2(): return tf.add(y, 3)
  >>> r = tf.cond(tf.less(x, y), f1, f2)
  >>> # r is set to f1().
  >>> # Operations in f2 (e.g., tf.add) are not executed.
  >>> r.numpy()
  14

  """
  # The v2 API is always strict: singleton list/tuple results are never
  # implicitly unpacked (strict=True below).
  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)

350 

351 

352def _UnpackIfSingleton(res): 

353 if isinstance(res, (list, tuple)) and len(res) == 1: 

354 return res[0] 

355 else: 

356 return res 

357 

358 

def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
  """Special cases for `cond` when executing eagerly.

  When `pred` has a known constant value (the common case), only the selected
  branch callable is invoked, eagerly. When the constant value is unknown
  (e.g. eager tensors from a parallel device), both branches must be
  `tf.function`s and the cond is built via `cond_v2`.

  Args:
    pred: The predicate; converted to a tensor here.
    true_fn: Callable for the true branch.
    false_fn: Callable for the false branch.
    strict: If False, a singleton list/tuple result is unpacked to its sole
      element (only applies on the constant-predicate path).
    name: Optional name scope for the returned tensors.

  Returns:
    The result of the selected branch (constant predicate), or the output of
    `cond_v2.cond_v2` (unknown predicate).

  Raises:
    TypeError: If `pred` has no constant value and either branch is not a
      `tf.function` (`core.GenericFunction`).
  """
  pred = ops.convert_to_tensor(pred)
  pred_constant_value = tensor_util.constant_value(pred)
  if pred_constant_value is None:
    # Eager tensors from a parallel device may not have a constant
    # value. Running the cond op itself would work, but we don't have logic to
    # build cond ops without wrapping in a function first.
    if (not isinstance(true_fn, core.GenericFunction)
        or not isinstance(false_fn, core.GenericFunction)):
      raise TypeError("When running tf.cond on a parallel device, 'true_fn' "
                      "and 'false_fn' must be decorated with `tf.function`.")
    functions_run_eagerly = eager_function_run.functions_run_eagerly()
    if functions_run_eagerly:
      # We need to use tf.function to deal with variable creation inside the
      # cond, and skipping it because of run_functions_eagerly would just
      # crash immediately.
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Parallelized tf.cond requires "
          "tf.function to work. This primitive will override the disable.")
      eager_function_run.run_functions_eagerly(False)
    try:
      return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    finally:
      # Restore the caller's run_functions_eagerly setting if we overrode it.
      if functions_run_eagerly is not None:
        eager_function_run.run_functions_eagerly(functions_run_eagerly)
  else:
    # For conditions which are eager tensors with a constant value (most of
    # them), we only call the relevant branch function and execute it eagerly.
    with ops.name_scope(name, "cond", [pred]):
      if pred_constant_value:
        result = true_fn()
      else:
        result = false_fn()
      if not strict:
        result = _UnpackIfSingleton(result)
      return result

397 

398 

def _cast_indexed_slice_indices(a, b):
  """Cast IndexedSlices.indices from int32 to int64 where necessary.

  If `a` and `b` are both IndexedSlices, and their indices have different
  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
  in-place). Otherwise, does nothing.

  Args:
    a: A value, which may be an IndexedSlices.
    b: A value, which may be an IndexedSlices.
  """
  # Guard clauses: anything other than a pair of IndexedSlices with
  # mismatched index dtypes is left untouched.
  both_slices = (isinstance(a, indexed_slices.IndexedSlices) and
                 isinstance(b, indexed_slices.IndexedSlices))
  if not both_slices:
    return
  if a.indices.dtype == b.indices.dtype:
    return
  # pylint: disable=protected-access
  a._indices = math_ops.cast(a.indices, dtypes.int64)
  b._indices = math_ops.cast(b.indices, dtypes.int64)