Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_case.py: 30%

79 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Case functions for Control Flow Operations.""" 

16 

17import collections 

18import functools 

19from tensorflow.python.eager import context 

20from tensorflow.python.framework import constant_op 

21from tensorflow.python.framework import dtypes 

22from tensorflow.python.framework import ops 

23from tensorflow.python.ops import array_ops_stack 

24from tensorflow.python.ops import cond 

25from tensorflow.python.ops import control_flow_assert 

26from tensorflow.python.ops import math_ops 

27from tensorflow.python.platform import tf_logging as logging 

28from tensorflow.python.util import dispatch 

29from tensorflow.python.util.tf_export import tf_export 

30 

31 

@tf_export("case", v1=[])
@dispatch.add_dispatch_support
def case_v2(pred_fn_pairs,
            default=None,
            exclusive=False,
            strict=False,
            name="case"):
  """Create a case operation.

  See also `tf.switch_case`.

  `pred_fn_pairs` is a list (or tuple) of N pairs. Each pair holds a boolean
  scalar tensor and a Python callable that builds the tensors to return when
  that boolean is True. `default` is a callable that builds a list of tensors.
  Every callable in `pred_fn_pairs`, and `default` when given, should return
  the same number and types of tensors.

  With `exclusive==True`, every predicate is evaluated and an exception is
  raised if more than one of them evaluates to `True`.
  With `exclusive==False`, evaluation stops at the first predicate that
  evaluates to True, and the tensors built by the matching function are
  returned immediately. When no predicate evaluates to True, the tensors built
  by `default` are returned.

  `tf.case` supports nested structures as implemented in `tf.nest`. All of the
  callables must return the same (possibly nested) value structure of lists,
  tuples, and/or named tuples. The one exception is singleton lists and
  tuples: a callable returning one has it implicitly unpacked to a single
  value. Passing `strict=True` disables this unpacking.

  @compatibility(v2)
  `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and
  tf.Variable are no longer hashable in v2, so cannot be used as a key for a
  dictionary. Please use a list or a tuple instead.
  @end_compatibility


  **Example 1:**

  Pseudocode:

  ```
  if (x < y) return 17;
  else return 23;
  ```

  Expressions:

  ```python
  f1 = lambda: tf.constant(17)
  f2 = lambda: tf.constant(23)
  r = tf.case([(tf.less(x, y), f1)], default=f2)
  ```

  **Example 2:**

  Pseudocode:

  ```
  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
  if (x < y) return 17;
  else if (x > z) return 23;
  else return -1;
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(23)
  def f3(): return tf.constant(-1)
  r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
           default=f3, exclusive=True)
  ```

  Args:
    pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which
      returns a list of tensors.
    default: Optional callable that returns a list of tensors. If omitted, at
      least one predicate must evaluate to True; otherwise a runtime assertion
      fails.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list, tuple, or dict.
    TypeError: If an entry of `pred_fn_pairs` is not a 2-tuple.
    TypeError: If a predicate is not a boolean Tensor, or the function paired
      with a predicate is not callable.
    ValueError: If `pred_fn_pairs` is an unordered dict while executing
      eagerly with `exclusive=False`.
  """
  # Python predicates are rejected here: only boolean Tensors are accepted.
  return _case_helper(
      cond_fn=cond.cond,
      pred_fn_pairs=pred_fn_pairs,
      default=default,
      exclusive=exclusive,
      name=name,
      allow_python_preds=False,
      strict=strict)

135 

136 

@tf_export(v1=["case"])
@dispatch.add_dispatch_support
def case(pred_fn_pairs,
         default=None,
         exclusive=False,
         strict=False,
         name="case"):
  """Create a case operation.

  See also `tf.switch_case`.

  `pred_fn_pairs` is a dict, or a list of N pairs. Each pair holds a boolean
  scalar tensor and a Python callable that builds the tensors to return when
  that boolean is True. `default` is a callable that builds a list of tensors.
  Every callable in `pred_fn_pairs`, and `default` when given, should return
  the same number and types of tensors.

  With `exclusive==True`, every predicate is evaluated and an exception is
  raised if more than one of them evaluates to `True`.
  With `exclusive==False`, evaluation stops at the first predicate that
  evaluates to True, and the tensors built by the matching function are
  returned immediately. When no predicate evaluates to True, the tensors built
  by `default` are returned.

  `tf.case` supports nested structures as implemented in `tf.nest`. All of the
  callables must return the same (possibly nested) value structure of lists,
  tuples, and/or named tuples. The one exception is singleton lists and
  tuples: a callable returning one has it implicitly unpacked to a single
  value. Passing `strict=True` disables this unpacking.

  When `pred_fn_pairs` is an unordered dictionary, the order in which the
  conditions are tested is not guaranteed. It is, however, deterministic, so
  variables created inside conditional branches are created in a fixed order
  across runs.

  @compatibility(eager)
  Unordered dictionaries are not supported in eager mode when `exclusive=False`.
  Use a list of tuples instead.
  @end_compatibility


  **Example 1:**

  Pseudocode:

  ```
  if (x < y) return 17;
  else return 23;
  ```

  Expressions:

  ```python
  f1 = lambda: tf.constant(17)
  f2 = lambda: tf.constant(23)
  r = tf.case([(tf.less(x, y), f1)], default=f2)
  ```

  **Example 2:**

  Pseudocode:

  ```
  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
  if (x < y) return 17;
  else if (x > z) return 23;
  else return -1;
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(23)
  def f3(): return tf.constant(-1)
  r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
           default=f3, exclusive=True)
  ```

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors. If omitted, at
      least one predicate must evaluate to True; otherwise a runtime assertion
      fails.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list, tuple, or dict.
    TypeError: If an entry of `pred_fn_pairs` is not a 2-tuple.
    TypeError: If a predicate is not a boolean Tensor, or the function paired
      with a predicate is not callable.
    ValueError: If `pred_fn_pairs` is an unordered dict while executing
      eagerly with `exclusive=False`.
  """
  # Python predicates are rejected here: only boolean Tensors are accepted.
  return _case_helper(
      cond_fn=cond.cond,
      pred_fn_pairs=pred_fn_pairs,
      default=default,
      exclusive=exclusive,
      name=name,
      allow_python_preds=False,
      strict=strict)

244 

245 

def _assert_at_most_n_true(predicates, n, msg):
  """Builds an Assert op that fails when more than `n` predicates are True.

  Args:
    predicates: list of bool scalar tensors.
    n: maximum number of predicates allowed to be True.
    msg: Prefix for the assertion's error message.

  Returns:
    The op produced by `control_flow_assert.Assert`.
  """
  stacked = array_ops_stack.stack(predicates, name="preds_c")
  as_ints = math_ops.cast(stacked, dtypes.int32)
  true_count = math_ops.reduce_sum(as_ints, name="num_true_conds")
  limit = constant_op.constant(n, name="n_true_conds")
  within_limit = math_ops.less_equal(true_count, limit)

  # Values lacking a `.name` attribute are labelled "?".
  labels = ", ".join(getattr(pred, "name", "?") for pred in predicates)
  header = "%s: more than %d conditions (%s) evaluated as True:" % (
      msg, n, labels)
  # `summarize` is sized so every stacked predicate value is printed.
  return control_flow_assert.Assert(
      within_limit, data=[header, stacked], summarize=len(predicates))

266 

267 

def _case_create_default_action(predicates, actions):
  """Promotes one of the given branches to act as the `default` branch.

  Used when no `default` was supplied. The last (predicate, action) pair is
  split off and wrapped in a callable. Before running the action, that
  callable asserts that this predicate is True and that none of the remaining
  predicates are True.

  Args:
    predicates: a sequence of bool scalar tensors.
    actions: a sequence of callables that return tensors, parallel to
      `predicates`.

  Returns:
    A tuple `(default_action, other_predicates, other_actions)`. It holds the
    wrapping callable, plus `predicates` and `actions` with the promoted pair
    sliced off.
  """
  last = len(predicates) - 1  # Any index would work; the last one is used.
  chosen_predicate = predicates[last]
  chosen_action = actions[last]
  rest_predicates = predicates[:last]
  rest_actions = actions[:last]

  def default_action():
    others_msg = ("Implementation error: "
                  "selected default action #%d was called, but some of other "
                  "predicates are True: " % last)
    # The stack op is created before the assertion ops below.
    stacked_predicates = array_ops_stack.stack(predicates, name="preds_c")
    default_msg = ("Input error: "
                   "None of conditions evaluated as True:",
                   stacked_predicates)
    guards = [
        _assert_at_most_n_true(rest_predicates, n=0, msg=others_msg),
        control_flow_assert.Assert(chosen_predicate, data=default_msg),
    ]
    with ops.control_dependencies(guards):
      return chosen_action()

  return default_action, rest_predicates, rest_actions

300 

301 

def _case_helper(cond_fn,
                 pred_fn_pairs,
                 default,
                 exclusive,
                 name,
                 allow_python_preds=False,
                 **cond_kwargs):
  """Implementation of case that allows for different cond functions.

  Args:
    cond_fn: callable with the signature and semantics of `cond.cond`; it is
      invoked as `cond_fn(pred, true_fn=..., false_fn=..., **cond_kwargs)`.
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors. If None, one of
      the pairs is promoted to act as the default (see
      `_case_create_default_action`).
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors.
    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list, tuple, or dict.
    TypeError: If `pred_fn_pairs` contains an entry that is not a 2-tuple.
    TypeError: If a predicate has the wrong type, or its paired function is
      not callable.
    ValueError: Propagated from `_case_verify_and_canonicalize_args`, e.g. for
      an unordered dict in eager mode with `exclusive=False`.
  """
  predicates, actions = _case_verify_and_canonicalize_args(
      pred_fn_pairs, exclusive, name, allow_python_preds)
  with ops.name_scope(name, "case", [predicates]):
    # Keep every predicate for the exclusivity check below. When no default is
    # given, _case_create_default_action drops the promoted pair from
    # `predicates`. Checking only the remainder would miss inputs where the
    # promoted predicate and an earlier one are both True: the earlier branch
    # is taken, so the default's own assertions never run.
    all_predicates = predicates
    if default is None:
      default, predicates, actions = _case_create_default_action(
          predicates, actions)
    fn = default
    # To evaluate conditions in order, nest them in reverse:
    #   cond_fn(c[0], true_fn=a[0], false_fn=partial(cond_fn, c[1], ...))
    for predicate, action in reversed(list(zip(predicates, actions))):
      fn = functools.partial(
          cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs)
    if not exclusive:
      return fn()
    with ops.control_dependencies([
        _assert_at_most_n_true(
            all_predicates, n=1, msg="Input error: exclusive=True")
    ]):
      return fn()

352 

353 

def _pred_display_name(pred):
  """Returns a printable label for `pred`, for use in error messages.

  Objects exposing `.name` (graph tensors) are labelled by it. Values without
  a usable `name` attribute fall back to `str(pred)`. That covers Python bools,
  which `allow_python_preds` admits, and objects whose `name` raises
  AttributeError.
  """
  return str(getattr(pred, "name", pred))


def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
                                       allow_python_preds):
  """Verifies input arguments for the case function.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for the case operation.
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list, tuple, or dict.
    TypeError: If an entry of `pred_fn_pairs` is not a 2-tuple.
    TypeError: If a predicate is a Tensor whose dtype is not bool, or is not a
      Tensor (nor, when `allow_python_preds` is set, a Python bool).
    TypeError: If the function paired with a predicate is not callable.
    ValueError: If `pred_fn_pairs` is an unordered dict, execution is eager,
      and `exclusive` is False.
    ValueError: If `pred_fn_pairs` contains no pairs.

  Returns:
    a tuple <tuple of predicates, tuple of callables>, in evaluation order.
  """
  if not isinstance(pred_fn_pairs, (list, tuple, dict)):
    raise TypeError("'pred_fn_pairs' must be a list, tuple, or dict. "
                    f"Received: {type(pred_fn_pairs)}")

  if isinstance(pred_fn_pairs, collections.OrderedDict):
    pred_fn_pairs = pred_fn_pairs.items()
  elif isinstance(pred_fn_pairs, dict):
    if context.executing_eagerly():
      # No name to sort on in eager mode. Use dictionary traversal order,
      # which is nondeterministic in versions of Python < 3.6
      if not exclusive:
        raise ValueError("Unordered dictionaries are not supported for the "
                         "'pred_fn_pairs' argument when `exclusive=False` and "
                         "eager mode is enabled.")
      pred_fn_pairs = list(pred_fn_pairs.items())
    else:
      # Sorting by tensor name gives a deterministic order across runs.
      pred_fn_pairs = sorted(
          pred_fn_pairs.items(), key=lambda item: item[0].name)
      if not exclusive:
        logging.warning(
            "%s: An unordered dictionary of predicate/fn pairs was "
            "provided, but exclusive=False. The order of conditional "
            "tests is deterministic but not guaranteed.", name)
  for pred_fn_pair in pred_fn_pairs:
    if not isinstance(pred_fn_pair, tuple) or len(pred_fn_pair) != 2:
      raise TypeError("Each entry in 'pred_fn_pairs' must be a 2-tuple. "
                      f"Received {pred_fn_pair}.")
    pred, fn = pred_fn_pair

    if isinstance(pred, ops.Tensor):
      if pred.dtype != dtypes.bool:
        raise TypeError(
            "pred must be Tensor of type bool: %s" % _pred_display_name(pred))
    elif not allow_python_preds:
      raise TypeError("pred must be a Tensor, got: %s" % pred)
    elif not isinstance(pred, bool):
      raise TypeError("pred must be a Tensor or bool, got: %s" % pred)

    if not callable(fn):
      # `pred` may be a Python bool here, which has no `.name`; reading it
      # directly would raise AttributeError instead of this TypeError.
      raise TypeError(
          "fn for pred %s must be callable." % _pred_display_name(pred))

  if not pred_fn_pairs:
    # Without this check, `zip(*pred_fn_pairs)` fails with an opaque
    # "not enough values to unpack" error.
    raise ValueError("'pred_fn_pairs' must contain at least one "
                     "(predicate, fn) pair.")
  predicates, actions = zip(*pred_fn_pairs)
  return predicates, actions