Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/training_loop.py: 16%

83 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15 

16"""Library for constructing a training loop, suitable for TPUs.""" 

17 

18from typing import Any, Callable, Iterable, List, Optional, Union 

19 

20from tensorflow.python.compiler.xla import xla 

21from tensorflow.python.framework import ops 

22from tensorflow.python.ops import array_ops 

23from tensorflow.python.ops import control_flow_ops 

24from tensorflow.python.ops import while_loop as while_loop_tf 

25from tensorflow.python.tpu import tensor_tracer 

26from tensorflow.python.tpu import tpu_feed 

27from tensorflow.python.tpu import tpu_function 

28from tensorflow.python.types import core as core_types 

29 

30 

def while_loop(condition: Callable[..., Any],
               body: Callable[..., Any],
               inputs: Optional[List[Any]] = None,
               infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
               name: Any = None) -> Any:
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`. Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed from
  `infeed_queue` if `infeed_queue` is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors, optionally followed by Operations that must run on
  every iteration.

  Args:
    condition: a Python function that builds the loop condition. It receives
      only the loop-carried tensors; infeed values are never passed to it.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature, or if the dtypes
      of the tensors returned by `body` differ from the dtypes of `inputs`.
    ValueError: if `body` returns an Operation before a Tensor, or if
      `infeed_queue` is given but there is no TPU shard context.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          f"inputs. You specified {input_arity} inputs: "
          f"{[i.name for i in inputs]}, but the loop body needs "
          f"{body_arg_error}")
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          f"inputs. You specified {input_arity} inputs: "
          f"{[i.name for i in inputs]} and "
          f"{infeed_queue.number_of_tuple_elements} additional inputs from "
          f"infeed, but the computation needs {body_arg_error}")

  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    message = (
        "Supplied loop condition function cannot be called with the "
        f"specified inputs. You specified {input_arity} inputs: "
        f"{[i.name for i in inputs]}, but the loop "
        f"condition needs {condition_arg_error}")
    if infeed_queue is not None:
      # Infeed values are appended only to the body's arguments (see
      # body_wrapper), so a condition expecting them can never be satisfied.
      message += ". Note that infeed is not passed to the loop condition."
    raise TypeError(message)

  def condition_wrapper(*inputs):
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue is not None:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = list(infeed_queue.generate_dequeue_op())
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    # All Tensors must come before all Operations, i.e. the first
    # len(output_tensors) entries must contain no Operation. This is checked by
    # type rather than by comparing lists with `!=`, which may fall back to
    # Tensor.__eq__ (element-wise under TF2 tensor-equality semantics) instead
    # of an identity test.
    if any(isinstance(o, ops.Operation)
           for o in outputs[:len(output_tensors)]):
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [t.dtype for t in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed. It is kept inside a list because
    # `control_flow_ops.tuple` below expects a list of tensors, not a single
    # bare Tensor.
    if not output_tensors:
      output_tensors = [array_ops.constant(0)]

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return while_loop_tf.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)

180 

181 

def repeat(
    n: int,
    body: Callable[..., Union[core_types.TensorLike, Iterable]],  # pylint:disable=g-bare-generic
    inputs: Optional[List[core_types.TensorLike]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    name: Any = None) -> List[core_types.TensorLike]:
  """Builds a training loop that executes a fixed number of iterations.

  The set of loop-carried tensors corresponds to `inputs`.
  `body` must be a function that takes and returns the values of the
  loop-carried tensors. Internally an integer iteration counter is prepended
  to the loop-carried values; it is never passed to `body` and is stripped
  from the result.

  Args:
    n: the number of loop iterations.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors. When there are no
    loop-carried tensors, the Operation that produces the loop's final counter
    value is returned instead of an empty list.

  Raises:
    TypeError: if `body` cannot be called with the loop-carried values (plus
      any infeed values), or returns values whose dtypes differ from `inputs`.
    ValueError: if `body` returns an Operation before a Tensor, or if
      `infeed_queue` is given outside a TPU shard context.
  """
  def _convert_to_list(xs):
    # Normalizes a single value or a tuple into a list.
    if not isinstance(xs, (list, tuple)):
      return [xs]
    else:
      return list(xs)

  def cond(i, *args):
    del args
    return i < n

  def body_wrapper(i, *args):
    # Advances the counter and forwards the user's values unchanged.
    return [i + 1] + _convert_to_list(body(*args))

  inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
  outputs = while_loop(
      cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
  outputs = _convert_to_list(outputs)
  if len(outputs) == 1:
    # Only the counter remains: return its Op rather than an empty list.
    return outputs[0].op
  else:
    # Drop the leading counter.
    return outputs[1:]