Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_switch_case.py: 28%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Switch case for Control Flow Operations.""" 

16 

17from tensorflow.python.eager import context 

18from tensorflow.python.framework import ops 

19from tensorflow.python.ops import array_ops 

20from tensorflow.python.ops import control_flow_util as util 

21from tensorflow.python.ops import gen_functional_ops 

22from tensorflow.python.ops import math_ops 

23from tensorflow.python.util.lazy_loader import LazyLoader 

24from tensorflow.python.util.tf_export import tf_export 

25 

# TODO(b/269483538): needed for references while refactors are in progress.
# `cond_v2` is obtained through LazyLoader rather than a direct import to break
# the circular dependency cond_v2 -> gradients_util -> control_flow_ops.
cond_v2 = LazyLoader(
    "cond_v2",
    globals(),
    "tensorflow.python.ops.cond_v2",
)

31 

32 

def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
                                               branch_index):
  """Verifies input arguments for the case function.

  Args:
    branch_fns: Dict or list of pairs of an `int` and a callable which returns a
      list of tensors, or a list/tuple of callables (in which case each
      callable's position serves as its key).
    default: Optional callable that returns a list of tensors.
    branch_index: Integer `Tensor` which selects the branch to run.

  Raises:
    TypeError: If `branch_index` is not an integer `Tensor`.
    TypeError: If `branch_fns` is not a list/tuple/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If a key is not a Python `int`.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
    ValueError: If `branch_fns` is empty, or its keys do not form the
      contiguous range `[0, len(branch_fns))`.

  Returns:
    branch_fns: validated list of callables ordered by key, with `default`
      (if provided) appended last.
  """
  if not isinstance(branch_index, ops.Tensor):
    raise TypeError("'branch_index' must be a Tensor, got {}".format(
        type(branch_index)))
  if not branch_index.dtype.is_integer:
    raise TypeError("'branch_index' must be an integer Tensor, got {}".format(
        branch_index.dtype))

  if not branch_fns:
    raise ValueError("Must provide at least one item in 'branch_fns'")
  if not isinstance(branch_fns, (list, tuple, dict)):
    raise TypeError("'branch_fns' must be a list, tuple, or dict")

  if isinstance(branch_fns, dict):
    branch_fns = branch_fns.items()

  # A bare sequence of callables is keyed implicitly by position.
  if all(callable(fn) for fn in branch_fns):
    branch_fns = list(enumerate(branch_fns))

  for key_fn_pair in branch_fns:
    if not isinstance(key_fn_pair, tuple) or len(key_fn_pair) != 2:
      raise TypeError("Each entry in 'branch_fns' must be a 2-tuple. "
                      f"Received {key_fn_pair}.")
    key, branch_fn = key_fn_pair

    if not isinstance(key, int):
      raise TypeError("key must be a Python `int`, got {}".format(type(key)))

    if not callable(branch_fn):
      raise TypeError("fn for key {} must be callable.".format(key))

  # Keys must be exactly {0, 1, ..., n-1}: no negatives, no gaps, no repeats.
  keys = [p[0] for p in branch_fns]
  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
    raise ValueError(
        "branch indices (keys) must form contiguous range of [0 to {}) but "
        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))

  # Documented contract: reject a non-callable `default` up front. Without this
  # check it would only fail later, when the default branch is invoked.
  if default is not None and not callable(default):
    raise TypeError("default must be callable, got {}".format(type(default)))

  # Keys are unique ints, so sorting the (key, fn) pairs never compares fns.
  actions = [p[1] for p in sorted(branch_fns)]
  if default is not None:
    actions.append(default)
  return actions

93 

94 

def _indexed_case_helper(branch_fns,
                         default,
                         branch_index,
                         name,
                         lower_using_switch_merge=None):
  """Implementation of case that emits the n-way indexed Case op.

  Args:
    branch_fns: Dict or list of pairs of an `int` key and a callable which
      returns a list of tensors, or a list of callables (in which case each
      callable's position serves as its key).
    default: Optional callable that returns a list of tensors.
    branch_index: Integer `Tensor` which selects the branch to run.
    name: A name for this operation (optional).
    lower_using_switch_merge: Lower this op using switch merge ops (optional).

  Returns:
    The tensors returned by the branch whose key matched `branch_index`. If no
    key matches, those returned by `default`, or by the max-keyed branch when
    `default` is None.

  Raises:
    TypeError: If `branch_index` is not an integer `Tensor`.
    TypeError: If `branch_fns` is not a list/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
    ValueError: If `branch_fns` is empty or its keys are not contiguous from 0.
  """
  # After canonicalization the list is ordered by key, with `default` (when
  # given) as its last entry.
  branch_fns = _indexed_case_verify_and_canonicalize_args(
      branch_fns, default, branch_index)
  with ops.name_scope(name, "case", [branch_index]):
    if context.executing_eagerly() and not hasattr(branch_index, "graph"):
      # Eager path: choose and call the branch in Python instead of emitting a
      # Case op. Out-of-range indices are redirected to the last entry, i.e.
      # `default` if provided, otherwise the max-keyed branch.
      branch_index = array_ops.where(
          math_ops.less(branch_index, 0)
          | math_ops.greater_equal(branch_index, len(branch_fns)),
          len(branch_fns) - 1, branch_index)
      return branch_fns[int(branch_index)]()
    return cond_v2.indexed_case(
        branch_index,
        branch_fns,
        lower_using_switch_merge=lower_using_switch_merge)

135 

136 

@tf_export("__internal__.execute_fn_for_device", v1=[])
def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
  """Executes one of the provided callables based on the device placement.

  This API is used when the implementations for high level function depend on
  the underlying device placement. It takes a dictionary of device type to
  callables. The device type includes "CPU", "GPU", "TPU", etc. When the type of
  the device where to run this op matches the key in 'device_branch_fns',
  the corresponding callable is executed, falling back to 'default_fn' if none
  matches.

  **Example:**
  ```python
  def f1(): return tf.constant(1)
  def f2(): return tf.constant(2)
  r = tf.__internal__.execute_fn_for_device({"CPU": f1, "GPU": f2},
                                            default_fn=f1)
  ```
  'r' is evaluated as 1 when it runs on CPU, 2 running on GPU, 1 running on
  any other device types.


  Args:
    device_branch_fns: a dictionary of device types to the callables. Keys are
      upper-cased before matching, so "gpu" and "GPU" are equivalent (if two
      keys differ only in case, the later one wins). Each callable must return
      a matching structure of tensors. If empty, `default_fn` is always
      executed.
    default_fn: fallback callable when the underlying device does not match any
      key in the 'device_branch_fns'.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the callable identified by device type during
    execution, or those returned by 'default_fn' if no key matches.
  """
  # Always execute the default fn for XLA to avoid complicated graph by case op.
  # see more discussions in b/167276293.
  is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph())
  if is_in_xla:
    return default_fn()
  # With no device-specific branches the default is the only possible choice;
  # the shared case validator would otherwise reject the empty branch list.
  if not device_branch_fns:
    return default_fn()
  device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
  branch_fns = list(device_branch_fns_upper.values())
  devices = list(device_branch_fns_upper.keys())
  device_index = gen_functional_ops.device_index(device_names=devices)
  return _indexed_case_helper(
      branch_fns,
      default_fn,
      device_index,
      name,
      lower_using_switch_merge=False)

184 

185 

@tf_export("switch_case")
def switch_case(branch_index, branch_fns, default=None, name="switch_case"):
  """Create a switch/case operation, i.e. an integer-indexed conditional.

  See also `tf.case`.

  This op can be substantially more efficient than `tf.case` when exactly one
  branch will be selected. `tf.switch_case` is more like a C++ switch/case
  statement than `tf.case`, which is more like an if/elif/elif/else chain.

  The `branch_fns` parameter is either a dict from `int` to callables, or list
  of (`int`, callable) pairs, or simply a list of callables (in which case the
  index is implicitly the key). The `branch_index` `Tensor` is used to select an
  element in `branch_fns` with matching `int` key, falling back to `default`
  if none match, or `max(keys)` if no `default` is provided. The keys must form
  a contiguous set from `0` to `len(branch_fns) - 1`.

  `tf.switch_case` supports nested structures as implemented in `tf.nest`. All
  callables must return the same (possibly nested) value structure of lists,
  tuples, and/or named tuples.

  **Example:**

  Pseudocode:

  ```c++
  switch (branch_index) {  // c-style switch
    case 0: return 17;
    case 1: return 31;
    default: return -1;
  }
  ```
  or
  ```python
  branches = {0: lambda: 17, 1: lambda: 31}
  branches.get(branch_index, lambda: -1)()
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(31)
  def f3(): return tf.constant(-1)
  r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
  # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
  ```

  Args:
    branch_index: An int Tensor specifying which of `branch_fns` should be
      executed.
    branch_fns: A `dict` mapping `int`s to callables, or a `list` of (`int`,
      callable) pairs, or simply a list of callables (in which case the index
      serves as the key). Each callable must return a matching structure of
      tensors.
    default: Optional callable that returns a structure of tensors.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the callable identified by `branch_index`, or those
    returned by `default` if no key matches and `default` was provided, or those
    returned by the max-keyed `branch_fn` if no `default` is provided.

  Raises:
    TypeError: If `branch_index` is not an integer `Tensor`.
    TypeError: If `branch_fns` is not a list/dictionary.
    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
      callables.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
    ValueError: If `branch_fns` is empty, or its keys do not form a contiguous
      range from `0` to `len(branch_fns) - 1`.
  """
  return _indexed_case_helper(branch_fns, default, branch_index, name)