Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/backprop.py: 20%
416 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Code for backpropagation using the tape utilities."""
17# TODO(b/159343581): Properly support CompositeTensor in all functions in this
18# file.
20import functools
21import operator
23from tensorflow.python import pywrap_tfe
24from tensorflow.python.eager import backprop_util
25from tensorflow.python.eager import context
26from tensorflow.python.eager import execute
27from tensorflow.python.eager import imperative_grad
28from tensorflow.python.eager import tape
29from tensorflow.python.framework import composite_tensor
30from tensorflow.python.framework import composite_tensor_gradient
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import indexed_slices
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.framework import type_spec
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import check_ops
40from tensorflow.python.ops import control_flow_util
41from tensorflow.python.ops import default_gradient
42from tensorflow.python.ops import gen_array_ops
43from tensorflow.python.ops import gen_math_ops
44from tensorflow.python.ops import resource_variable_ops
45from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops
46from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.util import _pywrap_utils
49from tensorflow.python.util import nest
50from tensorflow.python.util import tf_contextlib
51from tensorflow.python.util import tf_inspect
52from tensorflow.python.util import variable_utils
53from tensorflow.python.util.tf_export import tf_export
# Memo of attribute-type lookups, keyed on (op_type, attr_name).
_op_attr_type_cache = {}


def op_attr_type(op_type, attr_name):
  """Looks up (and memoizes) the attribute type code of an op attribute.

  On a cache miss the eager context is initialized and the type is fetched
  through `pywrap_tfe.TFE_OpNameGetAttrType`; the answer is then stored in
  `_op_attr_type_cache` so later lookups skip the C API call.

  Args:
    op_type: name of the op type that owns the attribute.
    attr_name: name of the attribute.

  Returns:
    The value reported by `pywrap_tfe.TFE_OpNameGetAttrType`.
  """
  key = (op_type, attr_name)
  if key in _op_attr_type_cache:
    return _op_attr_type_cache[key]
  # The C API call needs a live context handle.
  context.ensure_initialized()
  handle = context.context()._handle  # pylint: disable=protected-access
  attr_type = pywrap_tfe.TFE_OpNameGetAttrType(handle, op_type, attr_name)
  _op_attr_type_cache[key] = attr_type
  return attr_type
def make_attr(attr_type, value):
  """Converts a raw attribute value into the form gradient functions expect.

  `attr_type` is either a single attr-type code or a one-element list holding
  a code, the latter marking a list-valued attribute.

  pybind11 enums do not hand back their raw integer the way SWIG enums did, so
  the enum members are converted with `int()` before being compared against
  `attr_type`.
  TODO(amitpatankar): After all SWIG transitions, convert the enum comparisons
  from integer value to class.

  Args:
    attr_type: an integer attr-type code, or a list containing one such code.
    value: the raw attribute value.

  Returns:
    `dtypes.as_dtype(value)` (element-wise for lists) for type attrs, the
    shape proto of `tensor_shape.as_shape(value)` (element-wise for lists) for
    shape attrs, and otherwise `value` with every `str` leaf encoded to bytes.
  """
  type_code = int(pywrap_tfe.TF_ATTR_TYPE)
  shape_code = int(pywrap_tfe.TF_ATTR_SHAPE)

  if attr_type == type_code:
    return dtypes.as_dtype(value)
  if attr_type == [type_code]:
    return [dtypes.as_dtype(v) for v in value]
  if attr_type == shape_code:
    return tensor_shape.as_shape(value).as_proto()
  if attr_type == [shape_code]:
    return [tensor_shape.as_shape(v).as_proto() for v in value]

  def _encode_str(leaf):
    return leaf.encode() if isinstance(leaf, str) else leaf

  return nest.map_structure(_encode_str, value)
90class _MockOp(object):
91 """Pretends to be a tf.Operation for the gradient functions."""
93 def __init__(self, attrs, inputs, outputs, typ, skip_input_indices):
94 self.attrs = attrs
95 self.inputs = inputs
96 self.outputs = outputs
97 self.type = typ
98 self.skip_input_indices = skip_input_indices
100 def get_attr(self, attr):
101 typ = op_attr_type(self.type, attr)
102 for i in range(0, len(self.attrs), 2):
103 if self.attrs[i] == attr:
104 return make_attr(typ, self.attrs[i + 1])
105 raise KeyError(attr)
107 def _get_control_flow_context(self):
108 raise NotImplementedError(
109 "tf.GradientTape.gradients() does not support graph control flow "
110 "operations like tf.cond or tf.while at this time. Use tf.gradients() "
111 "instead. If you need this feature, please file a feature request at "
112 "https://github.com/tensorflow/tensorflow/issues/new"
113 )
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
                       out_grads, skip_input_indices, forward_pass_name_scope):
  """Invokes the registered gradient function for a single recorded op.

  Args:
    op_name: the name of the op to be differentiated.
    attr_tuple: the attrs, as a tuple.
    num_inputs: the number of inputs to the op.
    inputs: inputs to the original operation.
    outputs: outputs to the original operation.
    out_grads: gradients of the operation wrt its outputs.
    skip_input_indices: a tuple that is passed to the gradient function,
      indicating which inputs to skip calculating the gradient for
    forward_pass_name_scope: the namescope of the op in the forward pass.

  Returns:
    The gradients with respect to the op's inputs, as a list. When no gradient
    function is registered for `op_name`, a list of `num_inputs` Nones.
  """
  mock_op = _MockOp(attr_tuple, inputs, outputs, op_name, skip_input_indices)
  grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
  if grad_fn is None:
    return [None] * num_inputs

  # This does not work with v1 TensorArrays.
  use_gradient_scope = (
      ops.executing_eagerly_outside_functions() or
      control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
  if not use_gradient_scope:
    return grad_fn(mock_op, *out_grads)

  scope = "gradient_tape/"
  if forward_pass_name_scope:
    scope = "gradient_tape/" + forward_pass_name_scope + "/"
  with ops.name_scope(scope):
    return grad_fn(mock_op, *out_grads)


# Hand `_gradient_function` to the pywrap_tfe layer; presumably the native tape
# calls it back during backprop (NOTE(review): confirm in pywrap_tfe).
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
def _must_record_gradient():
  """Returns True unless pywrap_tfe reports the tape set as empty."""
  tape_set_empty = pywrap_tfe.TFE_Py_TapeSetIsEmpty()
  return not tape_set_empty
@tf_export("__internal__.record_gradient", v1=[])
def record_gradient(op_name, inputs, attrs, outputs):
  """Records a gradient entry for an op that has already been executed.

  The op is recorded under the current name scope.

  Args:
    op_name: The op name as listed in the `OpDef` for the op.
    inputs: A list of tensor inputs to the op.
    attrs: The op attributes as a flattened list of alternating attribute names
      and attribute values.
    outputs: A list of tensor outputs from the op.
  """
  current_scope = ops.get_name_scope()
  pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, outputs,
                                   current_scope)


# Point the `execute` module's gradient hooks at this module's implementations.
execute.must_record_gradient = _must_record_gradient
execute.record_gradient = record_gradient
def implicit_val_and_grad(f):
  """Returns a function which differentiates f with respect to variables.

  The wrapped function returns the value and the gradient of f when called with
  the same arguments. The gradient is with respect to all trainable TFE
  variables accessed by `f`.

  This function is useful when the exact set of variables to differentiate with
  is not known ahead of time.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  val_grad_fn = tfe.implicit_value_and_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  value, grads_and_vars = val_grad_fn(x, y)
  print('Value of loss: %s' % value)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar.

  Returns:
    A function which, when called, returns a tuple pair.
    Its first element is the value to which the function evaluates.
    Its second element is list of (gradient, variable) pairs.

  Raises:
    ValueError: (when the returned function is called) if `f` returns None, if
      no trainable variables were accessed while `f` ran, or if the handle of
      an accessed variable is a packed tensor.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.

  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    this_tape = tape.push_new_tape()
    try:
      end_node = f(*args, **kwds)
      if end_node is None:
        # Callable objects and functools.partial have no `__name__`; fall back
        # to the object itself so this ValueError is not masked by an
        # AttributeError.
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             getattr(f, "__name__", f)))
    finally:
      tape.pop_tape(this_tape)
    # Note: variables are returned in construction order. This ensures unique
    # order across executions.
    variables = this_tape.watched_variables()
    if not variables:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")

    sources = [v.handle for v in variables]
    if any(getattr(s, "is_packed", False) for s in sources):
      raise ValueError(
          "GradientTape.gradient is not supported on packed EagerTensors yet."
      )
    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))

  return grad_fn
def implicit_grad(f):
  """Builds a function returning the gradients of `f` w.r.t. its variables.

  Calling the result with the same arguments as `f` yields the gradient of `f`
  with respect to every trainable TFE variable that `f` accesses, which makes
  it handy when the variables are not known in advance. It is the gradient
  half of `implicit_val_and_grad`.

  Example:

  ```python
  dense_layer = tf.compat.v1.layers.Dense(1)
  def loss(x, y):
    return tf.reduce_sum(tf.square(dense_layer(x) - y))

  # Obtain the gradient function.
  grad_fn = tfe.implicit_gradients(loss)

  # Invoke the gradient function with concrete values of x and y.
  x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  y = tf.constant([[10.0], [20.0]])
  grads_and_vars = grad_fn(x, y)

  # Apply the gradients to Variables.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  optimizer.apply_gradients(grads_and_vars)
  ```

  Args:
    f: function to be differentiated. A scalar result is differentiated
      directly; a tensor or list of tensors is, by default, summed into a
      single scalar first.

  Returns:
    A function which, when called, returns a list of (gradient, variable) pairs.
  """
  # TODO(cais): Remove calls to tf.constant() once the gradients functions
  # accept lists and np.ndarrays.
  val_and_grad_fn = implicit_val_and_grad(f)

  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    _, grads_and_vars = val_and_grad_fn(*args, **kwds)
    return grads_and_vars

  return grad_fn
def _get_arg_spec(f, params, param_args):
  """Maps `params` to the positions in `param_args` to differentiate.

  Args:
    f: the function whose arguments are inspected.
    params: None (meaning every parameter), a list of parameter names, or a
      list of integer positions.
    param_args: the positional arguments `f` is being called with.

  Returns:
    A sequence of integer positions.

  Raises:
    ValueError: if `params` is neither all strings nor all integers, if names
      are given but `f`'s arguments cannot be inspected, or if a given name is
      not one of `f`'s arguments (from `list.index`).
  """
  try:
    arg_names = tf_inspect.getfullargspec(f).args
  except TypeError as e:
    # TypeError can happen when f is a callable object.
    if params is None:
      return range(len(param_args))
    if all(isinstance(p, int) for p in params):
      return params
    raise ValueError("Either callable provided is not a function or could not "
                     "inspect its arguments by name: %s. Original error: %s"
                     % (f, e))

  if params is None:
    if not arg_names:
      return range(len(param_args))
    skips_self = arg_names[0] == "self"
    return range(len(arg_names) - 1) if skips_self else range(len(arg_names))
  if all(isinstance(p, str) for p in params):
    return [arg_names.index(name) for name in params]
  if all(isinstance(p, int) for p in params):
    return params
  raise ValueError(
      "params must be all strings or all integers; got %s." % params)
def gradients_function(f, params=None):
  """Returns a function which differentiates f with respect to params.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  # The 2nd order derivatives with respect to x is:
  #   d^2 f / (dx)^2 = 6 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns 1st order gradients.
  grad_fn = tfe.gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the 1st order gradient function.
  x_grad, y_grad = grad_fn(x, y)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # Obtain a function that returns the 2nd order gradient with respect to x.
  gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0])

  # Invoke the 2nd order gradient function.
  x_gradgrad = gradgrad_fn(x, y)[0]
  assert x_gradgrad.numpy() == 6 * 2 * 3

  # To obtain a callable that returns the gradient(s) of `f` with respect to a
  # subset of its inputs, use the `params` keyword argument with
  # `gradients_function()`.
  ygrad_fn = tfe.gradients_function(f, params=[1])

  (y_grad,) = ygrad_fn(x, y)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Note that only tensors with real or complex dtypes are differentiable.

  Args:
    f: function to be differentiated. If `f` returns a scalar, this scalar will
      be differentiated. If `f` returns a tensor or list of tensors, by default
      a scalar will be computed by adding all their values to produce a single
      scalar. If desired, the tensors can be elementwise multiplied by the
      tensors passed as the `dy` keyword argument to the returned gradient
      function.
    params: list of parameter names of f or list of integers indexing the
      parameters with respect to which we'll differentiate. Passing None
      differentiates with respect to all parameters.

  Returns:
    A function which, when called, returns only the gradients of `f` with
    respect to all of `params` (the value of `f` is discarded; use
    `val_and_grad_function` to get both). The function takes an extra optional
    keyword argument `dy`. Setting it allows computation of vector jacobian
    products for vectors other than the vector of ones.

  Raises:
    ValueError: when the returned function is called, if the params are not
      all strings or all integers, or if it receives keyword arguments other
      than `dy`.
  """
  # Building the value-and-gradient wrapper only creates a closure, so it is
  # done once here rather than on every call.
  val_and_grad_fn = val_and_grad_function(f, params=params)

  def decorated(*args, **kwds):
    """Computes the gradient of the decorated function."""
    _, grad = val_and_grad_fn(*args, **kwds)
    return grad

  return decorated
def _ensure_unique_tensor_objects(parameter_positions, args):
  """Gives every differentiated argument its own tensor object.

  If the same tensor is passed at two differentiated positions, every repeat
  after the first is replaced by an identity copy, so each parameter gets an
  independent gradient. For instance, with

    def f(x, y): return x * y
    g = gradients_function(f)
    one = tf.constant(1.)

  `g(one, one)` should give `[1., 1.]` even though both arguments are the same
  Tensor object.

  Args:
    parameter_positions: indices into `args` of the arguments being
      differentiated.
    args: a mutable list of arguments to the function being differentiated.

  Returns:
    `args`, which may have been modified in place.
  """
  seen_ids = set()
  for position, value in enumerate(args):
    if position not in parameter_positions:
      continue
    value_id = ops.tensor_id(value)
    if value_id in seen_ids:
      args[position] = gen_array_ops.identity(args[position])
    else:
      seen_ids.add(value_id)
  return args
def val_and_grad_function(f, params=None):
  """Wraps `f` so a single call yields both its value and its gradients.

  Example:
  ```python
  # f(x, y) = (x ^ 3) * y - x * (y ^ 2)
  # Therefore, the 1st order derivatives are:
  #   df / dx = 3 * (x ^ 2) * y - y ^ 2
  #   df / dy = x ^ 3 - 2 * x * y
  def f(x, y):
    return x * x * x * y - x * y * y

  # Obtain a function that returns the function value and the 1st order
  # gradients.
  val_grads_fn = tfe.value_and_gradients_function(f)

  x = 2.0
  y = 3.0

  # Invoke the value-and-gradients function.
  f_val, (x_grad, y_grad) = val_grads_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3

  # To obtain a callable that returns the value of `f` and the gradient(s) of
  # `f` with respect to a subset of its inputs, use the `params` keyword
  # argument with `value_and_gradients_function()`.
  val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1])

  f_val, (y_grad,) = val_ygrad_fn(x, y)
  assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2)
  assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3
  ```

  Args:
    f: function to be differentiated. A scalar result is differentiated
      directly; a tensor or list of tensors is, by default, summed into one
      scalar. Those tensors can instead be weighted element-wise by the
      tensors passed as the `dy` keyword argument to the returned function.
    params: names of `f`'s parameters, or integer positions, to differentiate
      with respect to. `None` means every parameter.

  Returns:
    A function returning `(value_of_f, gradients)` for the given `params`. It
    accepts one optional keyword argument, "dy", which turns the result into a
    vector-Jacobian product with that vector instead of a vector of ones.

  Raises:
    ValueError: if the params are not all strings or all integers.
  """
  vjp_maker = make_vjp(f, params)

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    dy = kwds.pop("dy", None)
    if kwds:
      raise ValueError("Functions to be differentiated cannot "
                       "receive keyword arguments.")
    value, vjp = vjp_maker(*args)
    return value, vjp(dy=dy)

  return decorated
def make_vjp(f, params=None, persistent=True):
  """Returns a function that computes f and its vjp w.r.t. params.

  The term "vjp" here is an abbreviation for vector-jacobian product.

  Args:
    f: the function to be differentiated.
    params: the parameters (numbers or names) to differentiate with respect to.
       A value of None will differentiate with respect to all parameters.
    persistent: Boolean controlling whether the VJP function can be re-used.
      Must be True or False.

  Returns:
    A function, which when called, returns a tuple (value, vjp), where:
    - value is the result of calling f.
    - vjp is a function, which takes a vector as an argument and
      returns the product of that vector with the Jacobian of f.
      Providing no argument to vjp is equivalent to providing a
      vector of ones.

    For example,
    ```python
    def f(x):
      return x * x

    wrapped_fn = tfe.make_vjp(f)
    result, vjp = wrapped_fn(tf.constant(3.0))
    # result is 9.0
    vjp()  # the vjp function returns 6.0
    ```

  Raises:
    ValueError: (when the returned function is called) if `f` returns None or
      a differentiated argument is a packed EagerTensor.
  """

  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    parameter_positions = _get_arg_spec(f, params, args)
    assert not kwds, "The gradient function can't take keyword arguments."
    this_tape = tape.push_new_tape(persistent=persistent)
    try:
      sources = []
      args = [
          ops.convert_to_tensor(arg) if i in parameter_positions else arg
          for i, arg in enumerate(args)
      ]
      args = _ensure_unique_tensor_objects(parameter_positions, args)
      for i in parameter_positions:
        if getattr(args[i], "is_packed", False):
          raise ValueError(
              "GradientTape.gradient is not supported on packed EagerTensors "
              "yet.")
        sources.append(args[i])
        tape.watch(this_tape, args[i])
      result = f(*args)
      if result is None:
        # Callable objects and functools.partial have no `__name__`; fall back
        # to the object itself so this ValueError is not masked by an
        # AttributeError.
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             getattr(f, "__name__", f)))
      # NOTE(review): every output is passed through identity while the tape is
      # still active; presumably so outputs that alias an input (e.g.
      # `lambda x: x`) become distinct recorded tensors. Confirm before
      # changing.
      flat_result = nest.flatten(result)
      flat_result = [gen_array_ops.identity(x) for x in flat_result]
      result = nest.pack_sequence_as(result, flat_result)
    finally:
      tape.pop_tape(this_tape)

    def vjp(dy=None):
      """Returns gradients of `result` w.r.t. the sources, weighted by `dy`."""
      if dy is not None:
        dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
      return imperative_grad.imperative_grad(
          this_tape, nest.flatten(result), sources, output_gradients=dy)

    return result, vjp

  return decorated
def _aggregate_grads(gradients):
  """Combines a non-empty list of gradients into a single gradient.

  Args:
    gradients: A list of `Tensor` and/or `IndexedSlices` gradients.

  Returns:
    The only element when the list has length one; `add_n` over the list when
    every entry is a `Tensor`; otherwise the result of
    `backprop_util.AggregateIndexedSlicesGradients`.
  """
  assert gradients, "No gradients to aggregate"

  if len(gradients) == 1:
    return gradients[0]
  only_dense = all(isinstance(g, ops.Tensor) for g in gradients)
  if only_dense:
    return gen_math_ops.add_n(gradients)
  assert all(
      isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices))
      for g in gradients)
  return backprop_util.AggregateIndexedSlicesGradients(gradients)
def _num_elements(grad):
  """Counts the elements of `grad`, or returns 0 if its shape is not fully known.

  For `IndexedSlices` the count is taken from the `values` tensor.

  Raises:
    ValueError: if `grad` is neither a Tensor nor IndexedSlices.
  """
  if isinstance(grad, ops.Tensor):
    dims = grad._shape_tuple()  # pylint: disable=protected-access
  elif isinstance(grad, indexed_slices.IndexedSlices):
    dims = grad.values._shape_tuple()  # pylint: disable=protected-access
  else:
    raise ValueError("`grad` not a Tensor or IndexedSlices.")
  fully_known = dims is not None and None not in dims
  return functools.reduce(operator.mul, dims, 1) if fully_known else 0
def _fast_fill(value, shape, dtype):
  """Returns `array_ops.fill` of an int32 `shape` with `value` cast to `dtype`."""
  dims = constant_op.constant(shape, dtype=dtypes.int32)
  fill_value = constant_op.constant(value, dtype=dtype)
  return array_ops.fill(dims, fill_value)
def _zeros(shape, dtype):
  """Returns a zeros tensor, reusing a per-context cache when executing eagerly.

  Returns None for string and resource dtypes. Outside eager execution this
  is plain `array_ops.zeros`. In eager mode, results are cached in the
  context's zeros cache, keyed on (shape, dtype, device).
  """
  # Note: variants will use _zeros_like
  if dtype == dtypes.string or dtype == dtypes.resource:
    return None

  ctx = context.context()
  if not ctx.executing_eagerly():
    return array_ops.zeros(shape, dtype)

  device = ctx.device_name
  # Tensor-valued shapes are keyed through `.ref()` rather than directly.
  shape_key = shape.ref() if tensor_util.is_tf_type(shape) else shape
  cache_key = shape_key, dtype, device
  cached = ctx.zeros_cache().get(cache_key)
  if cached is not None:
    return cached
  fill_value = False if dtypes.as_dtype(dtype).is_bool else 0
  cached = _fast_fill(fill_value, shape, dtype)
  ctx.zeros_cache().put(cache_key, cached)
  return cached
def _ones(shape, dtype):
  """Returns a ones tensor of `shape` and `dtype`, or None for string dtypes.

  Boolean dtypes are filled with True.
  """
  resolved = dtypes.as_dtype(dtype)
  if resolved == dtypes.string:
    return None

  if not context.executing_eagerly():
    return array_ops.ones(shape, dtype)

  fill_value = True if resolved.is_bool else 1
  if shape == ():  # pylint: disable=g-explicit-bool-comparison
    # Scalar shape: build the constant directly rather than via fill.
    return constant_op.constant(fill_value, dtype=dtype)
  return _fast_fill(fill_value, shape, dtype)
# Gradient-construction callbacks bundled into an `imperative_grad.VSpace` and
# registered with the pywrap_tfe layer.
_default_vspace = imperative_grad.VSpace(
    aggregate_fn=_aggregate_grads,
    num_elements_fn=_num_elements,
    zeros_fn=_zeros,
    ones_fn=_ones,
    zeros_like_fn=default_gradient.zeros_like,
    ones_like_fn=default_gradient.ones_like,
    graph_shape_fn=gen_array_ops.shape,
)
pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
  """Returns `x.handle` if `x` is a resource variable, else `x` unchanged."""
  return x.handle if resource_variable_ops.is_resource_variable(x) else x
def _extract_tensors_and_variables(tensor):
  """Yields each Tensor or Variable found in the (possibly nested) `tensor`.

  Composite tensors are split into components through their TypeSpec and
  searched recursively.

  Raises:
    ValueError: on a leaf that is not a Tensor, Variable or CompositeTensor.
  """
  for leaf in nest.flatten(tensor):
    if _pywrap_utils.IsTensor(leaf) or _pywrap_utils.IsVariable(leaf):
      yield leaf
      continue
    if not isinstance(leaf, composite_tensor.CompositeTensor):
      raise ValueError(f"Passed in object {leaf} of type {type(leaf).__name__!r}"
                       f", not tf.Tensor or tf.Variable or ExtensionType.")
    spec = type_spec.type_spec_from_value(leaf)
    components = spec._to_components(leaf)  # pylint: disable=protected-access
    yield from _extract_tensors_and_variables(components)
702@tf_export("GradientTape", "autodiff.GradientTape", v1=["GradientTape"])
703class GradientTape:
704 """Record operations for automatic differentiation.
706 Operations are recorded if they are executed within this context manager and
707 at least one of their inputs is being "watched".
709 Trainable variables (created by `tf.Variable` or `tf.compat.v1.get_variable`,
710 where `trainable=True` is default in both cases) are automatically watched.
711 Tensors can be manually watched by invoking the `watch` method on this context
712 manager.
714 For example, consider the function `y = x * x`. The gradient at `x = 3.0` can
715 be computed as:
717 >>> x = tf.constant(3.0)
718 >>> with tf.GradientTape() as g:
719 ... g.watch(x)
720 ... y = x * x
721 >>> dy_dx = g.gradient(y, x)
722 >>> print(dy_dx)
723 tf.Tensor(6.0, shape=(), dtype=float32)
725 GradientTapes can be nested to compute higher-order derivatives. For example,
727 >>> x = tf.constant(5.0)
728 >>> with tf.GradientTape() as g:
729 ... g.watch(x)
730 ... with tf.GradientTape() as gg:
731 ... gg.watch(x)
732 ... y = x * x
733 ... dy_dx = gg.gradient(y, x) # dy_dx = 2 * x
734 >>> d2y_dx2 = g.gradient(dy_dx, x) # d2y_dx2 = 2
735 >>> print(dy_dx)
736 tf.Tensor(10.0, shape=(), dtype=float32)
737 >>> print(d2y_dx2)
738 tf.Tensor(2.0, shape=(), dtype=float32)
740 By default, the resources held by a GradientTape are released as soon as
741 GradientTape.gradient() method is called. To compute multiple gradients over
742 the same computation, create a persistent gradient tape. This allows multiple
743 calls to the gradient() method as resources are released when the tape object
744 is garbage collected. For example:
746 >>> x = tf.constant(3.0)
747 >>> with tf.GradientTape(persistent=True) as g:
748 ... g.watch(x)
749 ... y = x * x
750 ... z = y * y
751 >>> dz_dx = g.gradient(z, x) # (4*x^3 at x = 3)
752 >>> print(dz_dx)
753 tf.Tensor(108.0, shape=(), dtype=float32)
754 >>> dy_dx = g.gradient(y, x)
755 >>> print(dy_dx)
756 tf.Tensor(6.0, shape=(), dtype=float32)
758 By default GradientTape will automatically watch any trainable variables that
759 are accessed inside the context. If you want fine grained control over which
760 variables are watched you can disable automatic tracking by passing
761 `watch_accessed_variables=False` to the tape constructor:
763 >>> x = tf.Variable(2.0)
764 >>> w = tf.Variable(5.0)
765 >>> with tf.GradientTape(
766 ... watch_accessed_variables=False, persistent=True) as tape:
767 ... tape.watch(x)
768 ... y = x ** 2 # Gradients will be available for `x`.
769 ... z = w ** 3 # No gradients will be available as `w` isn't being watched.
770 >>> dy_dx = tape.gradient(y, x)
771 >>> print(dy_dx)
772 tf.Tensor(4.0, shape=(), dtype=float32)
773 >>> # No gradients will be available as `w` isn't being watched.
774 >>> dz_dw = tape.gradient(z, w)
775 >>> print(dz_dw)
776 None
778 Note that when using models you should ensure that your variables exist when
779 using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
780 first iteration not have any gradients:
782 ```python
783 a = tf.keras.layers.Dense(32)
784 b = tf.keras.layers.Dense(32)
786 with tf.GradientTape(watch_accessed_variables=False) as tape:
787 tape.watch(a.variables) # Since `a.build` has not been called at this point
788 # `a.variables` will return an empty list and the
789 # tape will not be watching anything.
790 result = b(a(inputs))
791 tape.gradient(result, a.variables) # The result of this computation will be
792 # a list of `None`s since a's variables
793 # are not being watched.
794 ```
796 Note that only tensors with real or complex dtypes are differentiable.
797 """
799 def __init__(self, persistent=False, watch_accessed_variables=True):
800 """Creates a new GradientTape.
802 Args:
803 persistent: Boolean controlling whether a persistent gradient tape
804 is created. False by default, which means at most one call can
805 be made to the gradient() method on this object.
806 watch_accessed_variables: Boolean controlling whether the tape will
807 automatically `watch` any (trainable) variables accessed while the tape
808 is active. Defaults to True meaning gradients can be requested from any
809 result computed in the tape derived from reading a trainable `Variable`.
810 If False users must explicitly `watch` any `Variable`s they want to
811 request gradients from.
812 """
813 self._tape = None
814 self._persistent = persistent
815 self._watch_accessed_variables = watch_accessed_variables
816 self._watched_variables = ()
817 self._recording = False
819 def __enter__(self):
820 """Enters a context inside which operations are recorded on this tape."""
821 self._push_tape()
822 return self
824 def __exit__(self, typ, value, traceback):
825 """Exits the recording context, no further operations are traced."""
826 if self._recording:
827 self._pop_tape()
829 def _push_tape(self):
830 """Pushes a new tape onto the tape stack."""
831 if self._recording:
832 raise ValueError("Tape is still recording, This can happen if you try to "
833 "re-enter an already-active tape.")
834 if self._tape is None:
835 self._tape = tape.push_new_tape(
836 persistent=self._persistent,
837 watch_accessed_variables=self._watch_accessed_variables)
838 else:
839 tape.push_tape(self._tape)
840 self._recording = True
842 def _pop_tape(self):
843 if not self._recording:
844 raise ValueError("Tape is not recording.")
845 tape.pop_tape(self._tape)
846 self._recording = False
848 @tf_contextlib.contextmanager
849 def _ensure_recording(self):
850 """Ensures that this tape is recording."""
851 if not self._recording:
852 try:
853 self._push_tape()
854 yield
855 finally:
856 self._pop_tape()
857 else:
858 yield
860 # TODO(b/209081027): Add a variable in composite tensor test case after
861 # variables become composite tensors.
862 def watch(self, tensor):
863 """Ensures that `tensor` is being traced by this tape.
865 Args:
866 tensor: a Tensor/Variable or list of Tensors/Variables.
868 Raises:
869 ValueError: if it encounters something that is not a tensor.
870 """
871 for t in _extract_tensors_and_variables(tensor):
872 if not backprop_util.IsTrainable(t):
873 logging.log_first_n(
874 logging.WARN, "The dtype of the watched tensor must be "
875 "floating (e.g. tf.float32), got %r", 5, t.dtype)
876 if hasattr(t, "handle"):
877 # There are many variable-like objects, all of them currently have
878 # `handle` attribute that points to a tensor. If this changes,
879 # internals of watch_variable need to change as well.
880 tape.watch_variable(self._tape, t)
881 else:
882 tape.watch(self._tape, t)
@tf_contextlib.contextmanager
def stop_recording(self):
  """Temporarily stops recording operations on this tape.

  Nothing executed while this context manager is active is recorded on the
  tape, which reduces the memory spent tracing computations. Recording
  resumes when the block exits, even if the block raised.

  For example:

  >>> x = tf.constant(4.0)
  >>> with tf.GradientTape() as tape:
  ...   with tape.stop_recording():
  ...     y = x ** 2
  >>> dy_dx = tape.gradient(y, x)
  >>> print(dy_dx)
  None

  Yields:
    None
  Raises:
    RuntimeError: if this object no longer holds an underlying tape (for
      example, a non-persistent tape whose gradient was already computed).
    ValueError: if the tape exists but is not currently recording.
  """
  if self._tape is not None:
    self._pop_tape()
    try:
      yield
    finally:
      # Always resume recording, even when the body raised.
      self._push_tape()
  else:
    raise RuntimeError(
        "Trying to stop recording a tape which is not recording.")
def reset(self):
  """Clears all information stored in this tape.

  Equivalent to exiting and reentering the tape context manager with a new
  tape. For example, the two following code blocks are equivalent:

  ```
  with tf.GradientTape() as t:
    loss = loss_fn()
  with tf.GradientTape() as t:
    loss += other_loss_fn()
  t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn


  # The following is equivalent to the above
  with tf.GradientTape() as t:
    loss = loss_fn()
    t.reset()
    loss += other_loss_fn()
  t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
  ```

  This is useful if you don't want to exit the context manager for the tape,
  or can't because the desired reset point is inside a control flow construct:

  ```
  with tf.GradientTape() as t:
    loss = ...
    if loss > k:
      t.reset()
  ```

  Raises:
    ValueError: if the tape is not currently recording (raised while popping
      the current tape).
  """
  # Stop recording on the current tape, forget it, then start recording
  # again with no tape object carried over.
  self._pop_tape()
  self._tape = None
  self._push_tape()
def watched_variables(self):
  """Returns variables watched by this tape in order of construction.

  While the underlying tape is alive, the cached list is refreshed from it.
  Once the tape has been released (set to `None`), the last cached value is
  returned.
  """
  current = self._tape
  if current is None:
    return self._watched_variables
  self._watched_variables = current.watched_variables()
  return self._watched_variables
def gradient(self,
             target,
             sources,
             output_gradients=None,
             unconnected_gradients=UnconnectedGradients.NONE):
  """Computes the gradient using operations recorded in context of this tape.

  Note: Unless you set `persistent=True` a GradientTape can only be used to
  compute one set of gradients (or jacobians).

  Calling this inside the tape's context stops recording on a non-persistent
  tape. On a persistent tape it proceeds, but logs a warning because the
  gradient ops are then recorded as well.

  In addition to Tensors, gradient also supports RaggedTensors. For example,

  >>> x = tf.ragged.constant([[1.0, 2.0], [3.0]])
  >>> with tf.GradientTape() as g:
  ...   g.watch(x)
  ...   y = x * x
  >>> g.gradient(y, x)
  <tf.RaggedTensor [[2.0, 4.0], [6.0]]>

  Args:
    target: a list or nested structure of Tensors or Variables or
      CompositeTensors to be differentiated.
    sources: a list or nested structure of Tensors or Variables or
      CompositeTensors. `target` will be differentiated against elements in
      `sources`.
    output_gradients: a list of gradients, one for each differentiable
      element of target. Defaults to None.
    unconnected_gradients: a value which can either hold 'none' or 'zero' and
      alters the value which will be returned if the target and sources are
      unconnected. The possible values and effects are detailed in
      'UnconnectedGradients' and it defaults to 'none'.

  Returns:
    a list or nested structure of Tensors (or IndexedSlices, or None, or
    CompositeTensor), one for each element in `sources`. Returned structure
    is the same as the structure of `sources`.

  Raises:
    RuntimeError: If called on a used, non-persistent tape.
    TypeError: If the target is a None object.
    ValueError: If a source is a packed EagerTensor. Other errors (for
      example an unknown `unconnected_gradients` value) may be raised by the
      underlying gradient computation.
  """
  if self._tape is None:
    raise RuntimeError("A non-persistent GradientTape can only be used to "
                       "compute one set of gradients (or jacobians)")
  if self._recording:
    if self._persistent:
      logging.log_first_n(
          logging.WARN, "Calling GradientTape.gradient on a persistent "
          "tape inside its context is significantly less "
          "efficient than calling it outside the context (it "
          "causes the gradient ops to be recorded on the "
          "tape, leading to increased CPU and memory usage). "
          "Only call GradientTape.gradient inside the "
          "context if you actually want to trace the "
          "gradient in order to compute higher order "
          "derivatives.", 1)
    else:
      # A non-persistent tape stops recording before differentiating so the
      # gradient ops themselves are not recorded on it.
      self._pop_tape()

  if target is None:
    raise TypeError("Argument `target` should be a list or nested structure"
                    " of Tensors, Variables or CompositeTensors to be "
                    "differentiated, but received None.")

  target_handles = [_handle_or_self(t) for t in nest.flatten(target)]
  flat_targets = composite_tensor_gradient.get_flat_tensors_for_gradients(
      target_handles)
  for flat_target in flat_targets:
    if not backprop_util.IsTrainable(flat_target):
      logging.vlog(
          1, "The dtype of the target tensor must be "
          "floating (e.g. tf.float32) when calling GradientTape.gradient, "
          "got %r", flat_target.dtype)

  flat_sources_raw = nest.flatten(sources)
  source_handles = [_handle_or_self(s) for s in flat_sources_raw]
  flat_sources = composite_tensor_gradient.get_flat_tensors_for_gradients(
      source_handles)
  for flat_source in flat_sources:
    if not backprop_util.IsTrainable(flat_source):
      logging.vlog(
          1, "The dtype of the source tensor must be "
          "floating (e.g. tf.float32) when calling GradientTape.gradient, "
          "got %r", flat_source.dtype)
    if getattr(flat_source, "is_packed", False):
      raise ValueError(
          "GradientTape.gradient is not supported on packed EagerTensors yet."
      )

  if output_gradients is not None:
    flat_output_grads = nest.flatten(
        variable_utils.convert_variables_to_tensors(output_gradients))
    flat_output_grads = (
        composite_tensor_gradient.get_flat_tensors_for_gradients(
            flat_output_grads))
    output_gradients = [
        ops.convert_to_tensor(g) if g is not None else None
        for g in flat_output_grads
    ]

  flat_grad = imperative_grad.imperative_grad(
      self._tape,
      flat_targets,
      flat_sources,
      output_gradients=output_gradients,
      sources_raw=flat_sources_raw,
      unconnected_gradients=unconnected_gradients)

  if not self._persistent:
    # Snapshot the watched variables, then release the single-use tape.
    self._watched_variables = self._tape.watched_variables()
    self._tape = None

  # Reassemble gradients for composite sources from their flat component
  # gradients, then restore the structure of `sources`.
  source_handles_raw = nest.map_structure(_handle_or_self, flat_sources_raw)
  flat_grad = composite_tensor_gradient.replace_flat_tensors_for_gradients(
      source_handles_raw, flat_grad)
  return nest.pack_sequence_as(sources, flat_grad)
def jacobian(self,
             target,
             sources,
             unconnected_gradients=UnconnectedGradients.NONE,
             parallel_iterations=None,
             experimental_use_pfor=True):
  """Computes the jacobian using operations recorded in context of this tape.

  Note: Unless you set `persistent=True` a GradientTape can only be used to
  compute one set of gradients (or jacobians).

  Note: By default the jacobian implementation uses parallel for (pfor), which
  creates a tf.function under the hood for each jacobian call. For better
  performance, and to avoid recompilation and vectorization rewrites on each
  call, enclose GradientTape code in @tf.function.

  See [wikipedia
  article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant)
  for the definition of a Jacobian.

  Example usage:

  ```python
  with tf.GradientTape() as g:
    x = tf.constant([1.0, 2.0])
    g.watch(x)
    y = x * x
  jacobian = g.jacobian(y, x)
  # jacobian value is [[2., 0.], [0., 4.]]
  ```

  Args:
    target: Tensor to be differentiated.
    sources: a list or nested structure of Tensors or Variables. `target`
      will be differentiated against elements in `sources`.
    unconnected_gradients: a value which can either hold 'none' or 'zero' and
      alters the value which will be returned if the target and sources are
      unconnected. The possible values and effects are detailed in
      'UnconnectedGradients' and it defaults to 'none'.
    parallel_iterations: A knob to control how many iterations are dispatched
      in parallel. This knob can be used to control the total memory usage.
    experimental_use_pfor: If true, vectorizes the jacobian computation. Else
      falls back to a sequential while_loop. Vectorization can sometimes fail
      or lead to excessive memory usage. This option can be used to disable
      vectorization in such cases.

  Returns:
    A list or nested structure of Tensors (or None), one for each element in
    `sources`. Returned structure is the same as the structure of `sources`.
    Note if any gradient is sparse (IndexedSlices), jacobian function
    currently makes it dense and returns a Tensor instead. This may change in
    the future.

  Raises:
    RuntimeError: If called on a used, non-persistent tape.
    RuntimeError: If called on a non-persistent tape with eager execution
      enabled and without enabling experimental_use_pfor.
    ValueError: If vectorization of jacobian computation fails.
  """
  if self._tape is None:
    raise RuntimeError("A non-persistent GradientTape can only be used to "
                       "compute one set of gradients (or jacobians)")

  flat_sources = nest.flatten(sources)
  target_static_shape = target.shape
  target_shape = array_ops.shape(target)
  # The tape is (re)entered around the reshape and the gather below so that
  # gradients can flow back through these enclosing operations.
  with self._ensure_recording():
    target = array_ops.reshape(target, [-1])

  def loop_fn(i):
    with self._ensure_recording():
      y = array_ops.gather(target, i)
    # Called outside the recording scope: on a non-persistent tape,
    # `gradient` pops the tape itself.
    return self.gradient(y, flat_sources,
                         unconnected_gradients=unconnected_gradients)

  try:
    target_size = int(target.shape[0])
  except TypeError:
    # Unknown static size: fall back to the dynamic element count.
    target_size = array_ops.shape(target)[0]

  if not experimental_use_pfor:
    if context.executing_eagerly() and not self._persistent:
      raise RuntimeError(
          "GradientTape must be created with persistent=True"
          " to compute the jacobian with eager execution enabled and with "
          " experimental_use_pfor set to False.")
    output = pfor_ops.for_loop(
        loop_fn, [target.dtype] * len(flat_sources), target_size,
        parallel_iterations=parallel_iterations)
  else:
    try:
      output = pfor_ops.pfor(loop_fn, target_size,
                             parallel_iterations=parallel_iterations)
    except ValueError as err:
      raise ValueError(
          "Encountered an exception while vectorizing the "
          "jacobian computation. Vectorization can be disabled by setting"
          " experimental_use_pfor to False.") from err

  # Restore each stacked gradient to shape target.shape + source.shape.
  for idx, grad in enumerate(output):
    if grad is None:
      continue
    new_shape = array_ops.concat(
        [target_shape, array_ops.shape(grad)[1:]], axis=0)
    grad = array_ops.reshape(grad, new_shape)
    if context.executing_eagerly():
      grad.set_shape(target_static_shape.concatenate(flat_sources[idx].shape))
    output[idx] = grad

  return nest.pack_sequence_as(sources, output)
def batch_jacobian(self,
                   target,
                   source,
                   unconnected_gradients=UnconnectedGradients.NONE,
                   parallel_iterations=None,
                   experimental_use_pfor=True):
  """Computes and stacks per-example jacobians.

  See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant)
  for the definition of a Jacobian. This function is essentially an efficient
  implementation of the following:

  `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`.

  Note that compared to `GradientTape.jacobian` which computes gradient of
  each output value w.r.t each input value, this function is useful when
  `target[i,...]` is independent of `source[j,...]` for `j != i`. This
  assumption allows more efficient computation as compared to
  `GradientTape.jacobian`. The output, as well as intermediate activations,
  are lower dimensional and avoid a bunch of redundant zeros which would
  result in the jacobian computation given the independence assumption.

  Note: Unless you set `persistent=True` a GradientTape can only be used to
  compute one set of gradients (or jacobians).

  Note: By default the batch_jacobian implementation uses parallel for (pfor),
  which creates a tf.function under the hood for each batch_jacobian call.
  For better performance, and to avoid recompilation and vectorization
  rewrites on each call, enclose GradientTape code in @tf.function.


  Example usage:

  ```python
  with tf.GradientTape() as g:
    x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)
    g.watch(x)
    y = x * x
  batch_jacobian = g.batch_jacobian(y, x)
  # batch_jacobian is [[[2, 0], [0, 4]], [[6, 0], [0, 8]]]
  ```

  Args:
    target: A tensor with rank 2 or higher and with shape [b, y1, ..., y_n].
      `target[i,...]` should only depend on `source[i,...]`.
    source: A tensor with rank 2 or higher and with shape [b, x1, ..., x_m].
    unconnected_gradients: a value which can either hold 'none' or 'zero' and
      alters the value which will be returned if the target and sources are
      unconnected. The possible values and effects are detailed in
      'UnconnectedGradients' and it defaults to 'none'.
    parallel_iterations: A knob to control how many iterations are dispatched
      in parallel. This knob can be used to control the total memory usage.
    experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else
      uses a tf.while_loop.

  Returns:
    A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]`
    is the jacobian of `target[i, ...]` w.r.t. `source[i, ...]`, i.e. stacked
    per-example jacobians.

  Raises:
    RuntimeError: If called on a used, non-persistent tape.
    RuntimeError: If called on a non-persistent tape with eager execution
      enabled and without enabling experimental_use_pfor.
    ValueError: If vectorization of jacobian computation fails, if `target`
      or `source` has a known rank below 2, or if the first dimension of
      `target` and `source` do not match.
  """
  if self._tape is None:
    raise RuntimeError("A non-persistent GradientTape can only be used to "
                       "compute one set of gradients (or jacobians)")
  target_shape = target.shape
  # `with_rank_at_least` raises ValueError for a known rank below 2 and
  # otherwise returns the shape unchanged. Its result must not be used as a
  # boolean: an unknown-rank TensorShape is falsy, which would make
  # unknown-rank inputs fail the batch-dimension check below spuriously.
  target_shape.with_rank_at_least(2)
  source.shape.with_rank_at_least(2)
  if target_shape.rank is None:
    dim = tensor_shape.Dimension(None)
  else:
    dim = target_shape.dims[0]
  if not dim.is_compatible_with(source.shape[0]):
    raise ValueError(
        "Need first dimension of target shape (%s) and "
        "source shape (%s) to match." % (target.shape, source.shape))
  if target_shape.is_fully_defined():
    batch_size = int(target_shape[0])
    target_row_size = target_shape.num_elements() // batch_size
  else:
    # Fall back to dynamic shapes; `target_shape` becomes a Tensor here.
    target_shape = array_ops.shape(target)
    batch_size = target_shape[0]
    target_row_size = array_ops.size(target) // batch_size
  source_shape = array_ops.shape(source)
  # Flatten target to 2-D.
  # Note that we push and pop the tape here and below. This is needed since we
  # need gradients through the enclosed operations.
  with self._ensure_recording():
    with ops.control_dependencies(
        [check_ops.assert_equal(batch_size, source_shape[0])]):
      target = array_ops.reshape(target, [batch_size, target_row_size])

  run_once = False

  def loop_fn(i):
    # A non-persistent tape can only produce gradients once, so fail loudly
    # if the loop body is ever traced or run a second time.
    nonlocal run_once
    if run_once and not self._persistent:
      if parallel_iterations is not None:
        raise RuntimeError(
            "GradientTape must be created with persistent=True"
            " to compute the batch_jacobian with parallel_iterations.")
      else:
        raise RuntimeError(
            "GradientTape must be created with persistent=True"
            " to compute the batch_jacobian.")
    run_once = True

    with self._ensure_recording():
      y = array_ops.gather(target, i, axis=1)
    # Called outside the recording scope: on a non-persistent tape,
    # `gradient` pops the tape itself.
    return self.gradient(y, source,
                         unconnected_gradients=unconnected_gradients)

  if experimental_use_pfor:
    try:
      output = pfor_ops.pfor(loop_fn, target_row_size,
                             parallel_iterations=parallel_iterations)
    except ValueError as err:
      raise ValueError(
          "Encountered an exception while vectorizing the "
          "batch_jacobian computation. Vectorization can be disabled by "
          "setting experimental_use_pfor to False.") from err
  else:
    if context.executing_eagerly() and not self._persistent:
      raise RuntimeError(
          "GradientTape must be created with persistent=True"
          " to compute the batch_jacobian with eager execution enabled and "
          " with experimental_use_pfor set to False.")
    output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size,
                               parallel_iterations=parallel_iterations)
  new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
  if output is None:
    # Note that this block is returning zeros when it could use `None` to
    # represent unconnected gradients. This is to maintain compatibility with
    # the previous behavior, which ignored `unconnected_gradients`.
    return array_ops.zeros(new_shape, target.dtype)
  # `output` is stacked per target column: [row_size, batch, ...]. Move the
  # batch dimension first, then restore the full per-example shape.
  output = array_ops.reshape(output, [target_row_size, batch_size, -1])
  output = array_ops.transpose(output, [1, 0, 2])
  return array_ops.reshape(output, new_shape)