Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/training_loop.py: 16%
83 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
16"""Library for constructing a training loop, suitable for TPUs."""
18from typing import Any, Callable, Iterable, List, Optional, Union
20from tensorflow.python.compiler.xla import xla
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops import while_loop as while_loop_tf
25from tensorflow.python.tpu import tensor_tracer
26from tensorflow.python.tpu import tpu_feed
27from tensorflow.python.tpu import tpu_function
28from tensorflow.python.types import core as core_types
def while_loop(condition: Callable[..., Any],
               body: Callable[..., Any],
               inputs: Optional[List[Any]] = None,
               infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
               name: Any = None) -> Any:
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`. Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed from
  `infeed_queue` if `infeed_queue` is not None; the dequeued values are
  appended after the loop-carried tensors and are never passed to
  `condition`. `condition` must return a single boolean value that determines
  whether iteration continues. `body` must return an updated list of values
  for the loop-carried tensors, optionally followed by Operations that should
  run on every iteration.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or None
      (equivalent to an empty list). Each value is converted with
      `ops.convert_to_tensor`.
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors. When `inputs` is empty, a
    dummy scalar loop-carried value is used internally and its final value is
    what the underlying while loop returns.

  Raises:
    TypeError: if body or condition has the wrong signature, or if the dtypes
      returned by `body` do not match the dtypes of `inputs`.
    ValueError: if `body` returns an Operation before a Tensor, or if an
      infeed queue is used without a TPU shard context.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  # Validate the signatures up front so errors are reported before any graph
  # construction. Input names are only read on the error paths.
  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          f"Supplied loop body function cannot be called with the specified "
          f"inputs. You specified {input_arity} inputs: {[i.name for i in inputs]}, but the loop body needs {body_arg_error}"
      )
    else:
      raise TypeError(
          f"Supplied loop body function cannot be called with the specified "
          f"inputs. You specified {input_arity} inputs: {[i.name for i in inputs]} and {infeed_queue.number_of_tuple_elements} additional inputs from "
          f"infeed, but the computation needs {body_arg_error}")
  # The condition never receives infeed, hence `None` for the queue here.
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          f"Supplied loop condition function cannot be called with the "
          f"specified inputs. You specified {input_arity} inputs: {[i.name for i in inputs]}, but the loop "
          f"condition needs {condition_arg_error}")
    else:
      raise TypeError(
          f"Supplied loop condition function cannot be called with the "
          f"specified inputs. You specified {input_arity} inputs: {[i.name for i in inputs]}, but the loop "
          f"condition needs {condition_arg_error}. Note that infeed is not passed to the loop condition."
      )

  def condition_wrapper(*inputs):
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue is not None:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = list(infeed_queue.generate_dequeue_op())
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]

    # Tensors must all precede Operations: equivalently, none of the first
    # len(output_tensors) outputs may be an Operation. Checked by type rather
    # than by comparing lists with `!=`, which would fall back to
    # Tensor.__eq__ for non-identical elements (an elementwise op in some
    # execution modes) instead of raising the intended error.
    if any(isinstance(o, ops.Operation)
           for o in outputs[:len(output_tensors)]):
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue outputs as control inputs so they are run by the loop,
    # even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed. It is kept as a one-element list because
    # `control_flow_ops.tuple` below iterates its `tensors` argument, and a
    # bare scalar Tensor cannot be iterated while building a graph.
    if not output_tensors:
      output_tensors = [array_ops.constant(0)]

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return while_loop_tf.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
def repeat(
    n: int,
    body: Callable[..., Union[core_types.TensorLike, Iterable]],  # pylint:disable=g-bare-generic
    inputs: Optional[List[core_types.TensorLike]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    name: Any = None
) -> Union[List[core_types.TensorLike], ops.Operation]:
  """Builds a training loop that executes a fixed number of iterations.

  The set of loop-carried tensors correspond to `inputs`. `body` must be a
  function that takes and returns the values of the loop-carried tensors.
  If `infeed_queue` is set, the dequeued infeed values are passed to `body`
  as extra trailing arguments. An iteration counter starting at 0 is carried
  alongside the loop-carried tensors. It is hidden from `body`, and the loop
  continues while the counter is less than `n`.

  Args:
    n: the number of loop iterations.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, a single
      value (treated as a one-element list), or None (equivalent to an empty
      list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors, as a list. If there are no
    loop-carried tensors, the `.op` of the final iteration-counter tensor is
    returned instead of an empty list.

  Raises:
    TypeError: if `body` has the wrong signature or returns values whose
      dtypes do not match `inputs` (raised by `while_loop`).
    ValueError: if `body` returns an Operation before a Tensor (raised by
      `while_loop`).
  """
  def _convert_to_list(xs):
    # Normalizes a single value or a list/tuple into a fresh list.
    return list(xs) if isinstance(xs, (list, tuple)) else [xs]

  def cond(i, *args):
    del args
    return i < n

  def body_wrapper(i, *args):
    # Advance the hidden counter and forward everything else to `body`.
    return [i + 1] + _convert_to_list(body(*args))

  # The counter occupies slot 0 of the loop-carried values.
  loop_inputs = [0] + ([] if inputs is None else _convert_to_list(inputs))
  outputs = _convert_to_list(
      while_loop(cond, body_wrapper, inputs=loop_inputs,
                 infeed_queue=infeed_queue, name=name))
  if len(outputs) == 1:
    # Only the counter was carried: return its Op rather than an empty list.
    return outputs[0].op
  return outputs[1:]