Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py: 29%
1254 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# Tests for this file live in python/kernel_tests/array_ops_test.py
16"""Support for manipulating tensors."""
18import numbers
19import numpy as np
21from tensorflow.python.eager import context
22from tensorflow.python.eager import record
23from tensorflow.python.framework import common_shapes
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import indexed_slices
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_conversion_registry
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import tensor_util
34# 'Constant' gets imported in the module 'array_ops'.
35from tensorflow.python.framework.constant_op import constant
36from tensorflow.python.ops import array_ops_stack
37from tensorflow.python.ops import gen_array_ops
38from tensorflow.python.ops import gen_math_ops
39from tensorflow.python.ops import shape_util
40# go/tf-wildcard-import
41# pylint: disable=wildcard-import
42from tensorflow.python.ops.gen_array_ops import *
43from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
44from tensorflow.python.types import core
45from tensorflow.python.util import _pywrap_utils
46from tensorflow.python.util import deprecation
47from tensorflow.python.util import dispatch
48from tensorflow.python.util import nest
49from tensorflow.python.util import tf_decorator
50from tensorflow.python.util.tf_export import tf_export
51# pylint: enable=wildcard-import
53# Used for slicing to specify a new 1 size dimension
54newaxis = None
55tf_export("newaxis").export_constant(__name__, "newaxis")
57# We override the 'slice' for the "slice" op, so we keep Python's
58# existing 'slice' for later use in this module.
59_BaseSlice = slice
62@tf_export("reshape", v1=["reshape", "manip.reshape"])
63@dispatch.add_dispatch_support
64def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
65 r"""Reshapes a tensor.
67 Given `tensor`, this operation returns a new `tf.Tensor` that has the same
68 values as `tensor` in the same order, except with a new shape given by
69 `shape`.
71 >>> t1 = [[1, 2, 3],
72 ... [4, 5, 6]]
73 >>> print(tf.shape(t1).numpy())
74 [2 3]
75 >>> t2 = tf.reshape(t1, [6])
76 >>> t2
77 <tf.Tensor: shape=(6,), dtype=int32,
78 numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
79 >>> tf.reshape(t2, [3, 2])
80 <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
81 array([[1, 2],
82 [3, 4],
83 [5, 6]], dtype=int32)>
85 The `tf.reshape` does not change the order of or the total number of elements
86 in the tensor, and so it can reuse the underlying data buffer. This makes it
87 a fast operation independent of how big of a tensor it is operating on.
89 >>> tf.reshape([1, 2, 3], [2, 2])
90 Traceback (most recent call last):
91 ...
92 InvalidArgumentError: Input to reshape is a tensor with 3 values, but the
93 requested shape has 4
95 To instead reorder the data to rearrange the dimensions of a tensor, see
96 `tf.transpose`.
98 >>> t = [[1, 2, 3],
99 ... [4, 5, 6]]
100 >>> tf.reshape(t, [3, 2]).numpy()
101 array([[1, 2],
102 [3, 4],
103 [5, 6]], dtype=int32)
104 >>> tf.transpose(t, perm=[1, 0]).numpy()
105 array([[1, 4],
106 [2, 5],
107 [3, 6]], dtype=int32)
109 If one component of `shape` is the special value -1, the size of that
110 dimension is computed so that the total size remains constant. In particular,
111 a `shape` of `[-1]` flattens into 1-D. At most one component of `shape` can
112 be -1.
114 >>> t = [[1, 2, 3],
115 ... [4, 5, 6]]
116 >>> tf.reshape(t, [-1])
117 <tf.Tensor: shape=(6,), dtype=int32,
118 numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
119 >>> tf.reshape(t, [3, -1])
120 <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
121 array([[1, 2],
122 [3, 4],
123 [5, 6]], dtype=int32)>
124 >>> tf.reshape(t, [-1, 2])
125 <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
126 array([[1, 2],
127 [3, 4],
128 [5, 6]], dtype=int32)>
130 `tf.reshape(t, [])` reshapes a tensor `t` with one element to a scalar.
132 >>> tf.reshape([7], []).numpy()
133 7
135 More examples:
137 >>> t = [1, 2, 3, 4, 5, 6, 7, 8, 9]
138 >>> print(tf.shape(t).numpy())
139 [9]
140 >>> tf.reshape(t, [3, 3])
141 <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
142 array([[1, 2, 3],
143 [4, 5, 6],
144 [7, 8, 9]], dtype=int32)>
146 >>> t = [[[1, 1], [2, 2]],
147 ... [[3, 3], [4, 4]]]
148 >>> print(tf.shape(t).numpy())
149 [2 2 2]
150 >>> tf.reshape(t, [2, 4])
151 <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
152 array([[1, 1, 2, 2],
153 [3, 3, 4, 4]], dtype=int32)>
155 >>> t = [[[1, 1, 1],
156 ... [2, 2, 2]],
157 ... [[3, 3, 3],
158 ... [4, 4, 4]],
159 ... [[5, 5, 5],
160 ... [6, 6, 6]]]
161 >>> print(tf.shape(t).numpy())
162 [3 2 3]
163 >>> # Pass '[-1]' to flatten 't'.
164 >>> tf.reshape(t, [-1])
165 <tf.Tensor: shape=(18,), dtype=int32,
166 numpy=array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
167 dtype=int32)>
168 >>> # -- Using -1 to infer the shape --
169 >>> # Here -1 is inferred to be 9:
170 >>> tf.reshape(t, [2, -1])
171 <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
172 array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
173 [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
174 >>> # -1 is inferred to be 2:
175 >>> tf.reshape(t, [-1, 9])
176 <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
177 array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
178 [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
179 >>> # -1 is inferred to be 3:
180 >>> tf.reshape(t, [ 2, -1, 3])
181 <tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
182 array([[[1, 1, 1],
183 [2, 2, 2],
184 [3, 3, 3]],
185 [[4, 4, 4],
186 [5, 5, 5],
187 [6, 6, 6]]], dtype=int32)>
189 Args:
190 tensor: A `Tensor`.
191 shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
192 Defines the shape of the output tensor.
193 name: Optional string. A name for the operation.
195 Returns:
196 A `Tensor`. Has the same type as `tensor`.
197 """
198 result = gen_array_ops.reshape(tensor, shape, name)
199 shape_util.maybe_set_static_shape(result, shape)
200 return result
203@tf_export("fill")
204@dispatch.add_dispatch_support
205def fill(dims, value, name=None):
206 r"""Creates a tensor filled with a scalar value.
208 See also `tf.ones`, `tf.zeros`, `tf.one_hot`, `tf.eye`.
210 This operation creates a tensor of shape `dims` and fills it with `value`.
212 For example:
214 >>> tf.fill([2, 3], 9)
215 <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
216 array([[9, 9, 9],
217 [9, 9, 9]], dtype=int32)>
219 `tf.fill` evaluates at graph runtime and supports dynamic shapes based on
220 other runtime `tf.Tensors`, unlike `tf.constant(value, shape=dims)`, which
221 embeds the value as a `Const` node.
223 Args:
224 dims: A 1-D sequence of non-negative numbers. Represents the shape of the
225 output `tf.Tensor`. Entries should be of type: `int32`, `int64`.
226 value: A value to fill the returned `tf.Tensor`.
227 name: Optional string. The name of the output `tf.Tensor`.
229 Returns:
230 A `tf.Tensor` with shape `dims` and the same dtype as `value`.
232 Raises:
233 InvalidArgumentError: `dims` contains negative entries.
234 NotFoundError: `dims` contains non-integer entries.
236 @compatibility(numpy)
237 Similar to `np.full`. In `numpy`, more parameters are supported. Passing a
238 number argument as the shape (`np.full(5, value)`) is valid in `numpy` for
239 specifying a 1-D shaped result, while TensorFlow does not support this syntax.
240 @end_compatibility
241 """
242 result = gen_array_ops.fill(dims, value, name=name)
243 shape_util.maybe_set_static_shape(result, dims)
244 return result
247@tf_export("identity")
248@dispatch.add_dispatch_support
249def identity(input, name=None): # pylint: disable=redefined-builtin
250 r"""Return a Tensor with the same shape and contents as input.
252 The return value is not the same Tensor as the original, but contains the same
253 values. This operation is fast when used on the same device.
255 For example:
257 >>> a = tf.constant([0.78])
258 >>> a_identity = tf.identity(a)
259 >>> a.numpy()
260 array([0.78], dtype=float32)
261 >>> a_identity.numpy()
262 array([0.78], dtype=float32)
264 Calling `tf.identity` on a variable will make a Tensor that represents the
265 value of that variable at the time it is called. This is equivalent to calling
266 `<variable>.read_value()`.
268 >>> a = tf.Variable(5)
269 >>> a_identity = tf.identity(a)
270 >>> a.assign_add(1)
271 <tf.Variable ... shape=() dtype=int32, numpy=6>
272 >>> a.numpy()
273 6
274 >>> a_identity.numpy()
275 5
277 This function can also be used to explicitly transfer tensors between devices.
278 For example, to transfer a tensor in GPU memory back to host memory, one can
279 use:
281 >>> with tf.device("/gpu:0"):
282 ... x_on_gpu = tf.constant(1)
283 >>> with tf.device("/cpu:0"):
284 ... x_on_cpu = tf.identity(x_on_gpu)
285 >>> x_on_cpu.device
286 '/job:localhost/replica:0/task:0/device:CPU:0'
288 Args:
289 input: A `Tensor`, a `Variable`, a `CompositeTensor` or anything that can be
290 converted to a tensor using `tf.convert_to_tensor`.
291 name: A name for the operation (optional).
293 Returns:
294 A `Tensor` or CompositeTensor. Has the same type and contents as `input`.
295 """
296 # Don't expand ResourceVariables, so identity(variable) will return a Tensor.
297 if (isinstance(input, composite_tensor.CompositeTensor) and
298 not _pywrap_utils.IsResourceVariable(input)):
299 return nest.map_structure(identity, input, expand_composites=True)
300 if context.executing_eagerly() and not hasattr(input, "graph"):
301 # Make sure we get an input with handle data attached from resource
302 # variables. Variables have correct handle data when graph building.
303 input = ops.convert_to_tensor(input)
304 ret = gen_array_ops.identity(input, name=name)
305 # Propagate handle data for happier shape inference for resource variables.
306 if hasattr(input, "_handle_data"):
307 ret._handle_data = input._handle_data # pylint: disable=protected-access
308 return ret
311# pylint: disable=redefined-builtin,protected-access
312@tf_export(v1=["expand_dims"])
313@dispatch.add_dispatch_support
314@deprecation.deprecated_args(None, "Use the `axis` argument instead", "dim")
315def expand_dims(input, axis=None, name=None, dim=None):
316 """Returns a tensor with a length 1 axis inserted at index `axis`.
318 Given a tensor `input`, this operation inserts a dimension of length 1 at the
319 dimension index `axis` of `input`'s shape. The dimension index follows Python
320 indexing rules: It's zero-based, a negative index it is counted backward
321 from the end.
323 This operation is useful to:
325 * Add an outer "batch" dimension to a single element.
326 * Align axes for broadcasting.
327 * To add an inner vector length axis to a tensor of scalars.
329 For example:
331 If you have a single image of shape `[height, width, channels]`:
333 >>> image = tf.zeros([10,10,3])
335 You can add an outer `batch` axis by passing `axis=0`:
337 >>> tf.expand_dims(image, axis=0).shape.as_list()
338 [1, 10, 10, 3]
340 The new axis location matches Python `list.insert(axis, 1)`:
342 >>> tf.expand_dims(image, axis=1).shape.as_list()
343 [10, 1, 10, 3]
345 Following standard Python indexing rules, a negative `axis` counts from the
346 end so `axis=-1` adds an inner most dimension:
348 >>> tf.expand_dims(image, -1).shape.as_list()
349 [10, 10, 3, 1]
351 This operation requires that `axis` is a valid index for `input.shape`,
352 following Python indexing rules:
354 ```
355 -1-tf.rank(input) <= axis <= tf.rank(input)
356 ```
358 This operation is related to:
360 * `tf.squeeze`, which removes dimensions of size 1.
361 * `tf.reshape`, which provides more flexible reshaping capability.
362 * `tf.sparse.expand_dims`, which provides this functionality for
363 `tf.SparseTensor`
365 Args:
366 input: A `Tensor`.
367 axis: 0-D (scalar). Specifies the dimension index at which to expand the
368 shape of `input`. Must be in the range `[-rank(input) - 1, rank(input)]`.
369 name: The name of the output `Tensor` (optional).
370 dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.
372 Returns:
373 A `Tensor` with the same data as `input`, but its shape has an additional
374 dimension of size 1 added.
376 Raises:
377 ValueError: if either both or neither of `dim` and `axis` are specified.
378 """
379 axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
380 if axis is None:
381 raise ValueError("Must specify an axis argument to tf.expand_dims()")
382 return expand_dims_v2(input, axis, name)
385@tf_export("expand_dims", v1=[])
386@dispatch.add_dispatch_support
387def expand_dims_v2(input, axis, name=None):
388 """Returns a tensor with a length 1 axis inserted at index `axis`.
390 Given a tensor `input`, this operation inserts a dimension of length 1 at the
391 dimension index `axis` of `input`'s shape. The dimension index follows Python
392 indexing rules: It's zero-based, a negative index it is counted backward
393 from the end.
395 This operation is useful to:
397 * Add an outer "batch" dimension to a single element.
398 * Align axes for broadcasting.
399 * To add an inner vector length axis to a tensor of scalars.
401 For example:
403 If you have a single image of shape `[height, width, channels]`:
405 >>> image = tf.zeros([10,10,3])
407 You can add an outer `batch` axis by passing `axis=0`:
409 >>> tf.expand_dims(image, axis=0).shape.as_list()
410 [1, 10, 10, 3]
412 The new axis location matches Python `list.insert(axis, 1)`:
414 >>> tf.expand_dims(image, axis=1).shape.as_list()
415 [10, 1, 10, 3]
417 Following standard Python indexing rules, a negative `axis` counts from the
418 end so `axis=-1` adds an inner most dimension:
420 >>> tf.expand_dims(image, -1).shape.as_list()
421 [10, 10, 3, 1]
423 This operation requires that `axis` is a valid index for `input.shape`,
424 following Python indexing rules:
426 ```
427 -1-tf.rank(input) <= axis <= tf.rank(input)
428 ```
430 This operation is related to:
432 * `tf.squeeze`, which removes dimensions of size 1.
433 * `tf.reshape`, which provides more flexible reshaping capability.
434 * `tf.sparse.expand_dims`, which provides this functionality for
435 `tf.SparseTensor`
437 Args:
438 input: A `Tensor`.
439 axis: Integer specifying the dimension index at which to expand the
440 shape of `input`. Given an input of D dimensions, `axis` must be in range
441 `[-(D+1), D]` (inclusive).
442 name: Optional string. The name of the output `Tensor`.
444 Returns:
445 A tensor with the same data as `input`, with an additional dimension
446 inserted at the index specified by `axis`.
448 Raises:
449 TypeError: If `axis` is not specified.
450 InvalidArgumentError: If `axis` is out of range `[-(D+1), D]`.
451 """
452 return gen_array_ops.expand_dims(input, axis, name)
455# pylint: enable=redefined-builtin,protected-access
458# Aliases for some automatically-generated names.
459# pylint: disable=protected-access
460@deprecation.deprecated("2016-11-30",
461 "This op will be removed after the deprecation date. "
462 "Please switch to tf.setdiff1d().")
463def listdiff(x, y, out_idx=None, name=None):
464 return gen_array_ops.list_diff(x, y, out_idx, name)
467listdiff.__doc__ = gen_array_ops.list_diff.__doc__ + "\n" + listdiff.__doc__
469# pylint: enable=protected-access
472# pylint: disable=undefined-variable
473@deprecation.deprecated("2018-11-30",
474 "This op will be removed after the deprecation date. "
475 "Please switch to tf.sets.difference().")
476@tf_export(v1=["setdiff1d"])
477@dispatch.add_dispatch_support
478def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
479 """Computes the difference between two lists of numbers or strings.
481 Given a list x and a list y, this operation returns a list out that
482 represents all values that are in x but not in y. The returned list
483 out is sorted in the same order that the numbers appear in x
484 (duplicates are preserved). This operation also returns a list idx
485 that represents the position of each out element in x.
487 In other words:
489 ```python
490 out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]
491 ```
493 Example usage:
495 >>> x = [1, 2, 3, 4, 5, 6]
496 >>> y = [1, 3, 5]
497 >>> setdiff1d(x,y)
498 ListDiff(out=<tf.Tensor: id=2, shape=(3,), dtype=int32,
499 numpy=array([2, 4, 6], dtype=int32)>, idx=<tf.Tensor: id=3,
500 shape=(3,), dtype=int32, numpy=array([1, 3, 5], dtype=int32)>)
502 Args:
503 x: A Tensor. 1-D. Values to keep.
504 y: A Tensor. Must have the same type as x. 1-D. Values to remove.
505 out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
506 tf.int32.
507 name: A name for the operation (optional).
509 Returns:
510 A tuple of Tensor objects (out, idx).
511 out: A Tensor. Has the same type as x.
512 idx: A Tensor of type out_idx.
513 """
514 return gen_array_ops.list_diff(x, y, index_dtype, name)
517setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__
520@tf_export("broadcast_dynamic_shape")
521@dispatch.add_dispatch_support
522def broadcast_dynamic_shape(shape_x, shape_y):
523 """Computes the shape of a broadcast given symbolic shapes.
525 When `shape_x` and `shape_y` are Tensors representing shapes (i.e. the result
526 of calling tf.shape on another Tensor) this computes a Tensor which is the
527 shape of the result of a broadcasting op applied in tensors of shapes
528 `shape_x` and `shape_y`.
530 This is useful when validating the result of a broadcasting operation when the
531 tensors do not have statically known shapes.
533 Example:
535 >>> shape_x = (1, 2, 3)
536 >>> shape_y = (5, 1, 3)
537 >>> tf.broadcast_dynamic_shape(shape_x, shape_y)
538 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([5, 2, 3], ...>
540 Args:
541 shape_x: A rank 1 integer `Tensor`, representing the shape of x.
542 shape_y: A rank 1 integer `Tensor`, representing the shape of y.
544 Returns:
545 A rank 1 integer `Tensor` representing the broadcasted shape.
547 Raises:
548 InvalidArgumentError: If the two shapes are incompatible for
549 broadcasting.
550 """
551 return gen_array_ops.broadcast_args(shape_x, shape_y)
554@tf_export("broadcast_static_shape")
555@dispatch.add_dispatch_support
556def broadcast_static_shape(shape_x, shape_y):
557 """Computes the shape of a broadcast given known shapes.
559 When `shape_x` and `shape_y` are fully known `TensorShape`s this computes a
560 `TensorShape` which is the shape of the result of a broadcasting op applied in
561 tensors of shapes `shape_x` and `shape_y`.
563 For example, if shape_x is `TensorShape([1, 2, 3])` and shape_y is
564 `TensorShape([5, 1, 3])`, the result is a TensorShape whose value is
565 `TensorShape([5, 2, 3])`.
567 This is useful when validating the result of a broadcasting operation when the
568 tensors have statically known shapes.
570 Example:
572 >>> shape_x = tf.TensorShape([1, 2, 3])
573 >>> shape_y = tf.TensorShape([5, 1 ,3])
574 >>> tf.broadcast_static_shape(shape_x, shape_y)
575 TensorShape([5, 2, 3])
577 Args:
578 shape_x: A `TensorShape`
579 shape_y: A `TensorShape`
581 Returns:
582 A `TensorShape` representing the broadcasted shape.
584 Raises:
585 ValueError: If the two shapes can not be broadcasted.
586 """
587 return common_shapes.broadcast_shape(shape_x, shape_y)
590@tf_export("shape", v1=[])
591@dispatch.add_dispatch_support
592def shape_v2(input, out_type=dtypes.int32, name=None):
593 # pylint: disable=redefined-builtin
594 """Returns a tensor containing the shape of the input tensor.
596 See also `tf.size`, `tf.rank`.
598 `tf.shape` returns a 1-D integer tensor representing the shape of `input`.
599 For a scalar input, the tensor returned has a shape of (0,) and its value is
600 the empty vector (i.e. []).
602 For example:
604 >>> tf.shape(1.)
605 <tf.Tensor: shape=(0,), dtype=int32, numpy=array([], dtype=int32)>
607 >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
608 >>> tf.shape(t)
609 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 2, 3], dtype=int32)>
611 Note: When using symbolic tensors, such as when using the Keras API,
612 tf.shape() will return the shape of the symbolic tensor.
614 >>> a = tf.keras.layers.Input((None, 10))
615 >>> tf.shape(a)
616 <... shape=(3,) dtype=int32...>
618 In these cases, using `tf.Tensor.shape` will return more informative results.
620 >>> a.shape
621 TensorShape([None, None, 10])
623 (The first `None` represents the as yet unknown batch size.)
625 `tf.shape` and `Tensor.shape` should be identical in eager mode. Within
626 `tf.function` or within a `compat.v1` context, not all dimensions may be
627 known until execution time. Hence, when defining custom layers and models
628 for graph mode, prefer the dynamic `tf.shape(x)` over the static `x.shape`.
630 Args:
631 input: A `Tensor` or `SparseTensor`.
632 out_type: (Optional) The specified output type of the operation (`int32` or
633 `int64`). Defaults to `tf.int32`.
634 name: A name for the operation (optional).
636 Returns:
637 A `Tensor` of type `out_type`.
638 """
639 return shape(input, name, out_type)
642@tf_export(v1=["shape"])
643@dispatch.add_dispatch_support
644def shape(input, name=None, out_type=dtypes.int32):
645 # pylint: disable=redefined-builtin
646 """Returns the shape of a tensor.
648 This operation returns a 1-D integer tensor representing the shape of `input`.
650 For example:
652 ```python
653 t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
654 tf.shape(t) # [2, 2, 3]
655 ```
657 Args:
658 input: A `Tensor` or `SparseTensor`.
659 name: A name for the operation (optional).
660 out_type: (Optional) The specified output type of the operation (`int32`
661 or `int64`). Defaults to `tf.int32`.
663 Returns:
664 A `Tensor` of type `out_type`.
665 """
666 return shape_internal(input, name, optimize=True, out_type=out_type)
669def shape_internal(input, name=None, optimize=True, out_type=None):
670 # pylint: disable=redefined-builtin
671 """Returns the shape of a tensor.
673 If `out_type` is not specified and the shape is fully known, then we look at
674 the dimension values to determine whether to return an int32 or int64 tensor.
675 If the shape is not fully known, we default to int32.
677 Args:
678 input: A `Tensor` or `SparseTensor`.
679 name: A name for the operation (optional).
680 optimize: if true, encode the shape as a constant when possible.
681 out_type: (Optional) The specified output type of the operation (`int32` or
682 `int64`). Defaults to tf.int32.
684 Returns:
685 A `Tensor` of type `out_type`.
687 """
688 with ops.name_scope(name, "Shape", [input]) as name:
689 if isinstance(
690 input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
691 if not out_type:
692 out_type = dtypes.int32
693 return gen_math_ops.cast(input.dense_shape, out_type)
694 else:
695 if not context.executing_eagerly():
696 input = ops.convert_to_tensor(input)
697 input_shape = input.get_shape()
698 if optimize and input_shape.is_fully_defined():
699 # For fully defined shapes, if the out_type is not specified, we pick
700 # int32 / int64 based on the actual values.
701 if not out_type:
702 return constant_op._tensor_shape_tensor_conversion_function( # pylint: disable=protected-access
703 input_shape)
704 return constant(input_shape.as_list(), out_type, name=name)
705 if not out_type:
706 out_type = dtypes.int32
707 return gen_array_ops.shape(input, name=name, out_type=out_type)
710@tf_export("shape_n")
711@dispatch.add_dispatch_support
712def shape_n(input, out_type=dtypes.int32, name=None):
713 # pylint: disable=redefined-builtin
714 """Returns shape of a list of tensors.
716 Given a list of tensors, `tf.shape_n` is much faster than applying `tf.shape`
717 to each tensor individually.
718 >>> a = tf.ones([1, 2])
719 >>> b = tf.ones([2, 3])
720 >>> c = tf.ones([3, 4])
721 >>> tf.shape_n([a, b, c])
722 [<tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>,
723 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 3], dtype=int32)>,
724 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([3, 4], dtype=int32)>]
726 Args:
727 input: A list of at least 1 `Tensor` object with the same dtype.
728 out_type: The specified output type of the operation (`int32` or `int64`).
729 Defaults to `tf.int32`(optional).
730 name: A name for the operation (optional).
732 Returns:
733 A list of `Tensor` specifying the shape of each input tensor with type of
734 `out_type`.
735 """
737 return gen_array_ops.shape_n(input, out_type=out_type, name=name)
740@tf_export("size", v1=[])
741@dispatch.add_dispatch_support
742def size_v2(input, out_type=dtypes.int32, name=None):
743 # pylint: disable=redefined-builtin
744 """Returns the size of a tensor.
746 See also `tf.shape`.
748 Returns a 0-D `Tensor` representing the number of elements in `input`
749 of type `out_type`. Defaults to tf.int32.
751 For example:
753 >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
754 >>> tf.size(t)
755 <tf.Tensor: shape=(), dtype=int32, numpy=12>
757 Args:
758 input: A `Tensor` or `SparseTensor`.
759 name: A name for the operation (optional).
760 out_type: (Optional) The specified non-quantized numeric output type of the
761 operation. Defaults to `tf.int32`.
763 Returns:
764 A `Tensor` of type `out_type`. Defaults to `tf.int32`.
766 @compatibility(numpy)
767 Equivalent to np.size()
768 @end_compatibility
769 """
771 return size(input, name, out_type)
774@tf_export(v1=["size"])
775@dispatch.add_dispatch_support
776def size(input, name=None, out_type=dtypes.int32):
777 # pylint: disable=redefined-builtin
778 """Returns the size of a tensor.
780 Returns a 0-D `Tensor` representing the number of elements in `input`
781 of type `out_type`. Defaults to tf.int32.
783 For example:
785 ```python
786 t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
787 tf.size(t) # 12
788 ```
790 Args:
791 input: A `Tensor` or `SparseTensor`.
792 name: A name for the operation (optional).
793 out_type: (Optional) The specified non-quantized numeric output type of the
794 operation. Defaults to `tf.int32`.
796 Returns:
797 A `Tensor` of type `out_type`. Defaults to `tf.int32`.
799 @compatibility(numpy)
800 Equivalent to np.size()
801 @end_compatibility
802 """
803 return size_internal(input, name, optimize=True, out_type=out_type)
806def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
807 # pylint: disable=redefined-builtin,protected-access
808 """Returns the size of a tensor.
810 Args:
811 input: A `Tensor` or `SparseTensor`.
812 name: A name for the operation (optional).
813 optimize: if true, encode the size as a constant when possible.
814 out_type: (Optional) The specified non-quantized numeric output type of the
815 operation. Defaults to `tf.int32`.
817 Returns:
818 A `Tensor` of type `out_type`. Defaults to `tf.int32`.
819 """
820 if (context.executing_eagerly() and not hasattr(input, "graph") and
821 not isinstance(
822 input,
823 (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue))):
824 input = ops.convert_to_tensor(input)
825 np_out_type = out_type.as_numpy_dtype
826 num_elements = np.prod(input._shape_tuple(), dtype=np_out_type) # pylint: disable=protected-access
827 return ops.convert_to_tensor(num_elements, dtype=out_type)
828 with ops.name_scope(name, "Size", [input]) as name:
829 if isinstance(
830 input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
831 return gen_math_ops.prod(
832 gen_math_ops.cast(input.dense_shape, out_type), 0, name=name)
833 else:
834 input = ops.convert_to_tensor(input)
835 input_shape = input.get_shape()
836 if optimize:
837 if input_shape.is_fully_defined():
838 return constant(input_shape.num_elements(), out_type, name=name)
839 if input_shape.dims and any(dim == 0 for dim in input_shape.dims):
840 return constant(0, out_type, name=name)
841 return gen_array_ops.size(input, name=name, out_type=out_type)
844@tf_export("rank")
845@dispatch.add_dispatch_support
846def rank(input, name=None):
847 # pylint: disable=redefined-builtin
848 """Returns the rank of a tensor.
850 See also `tf.shape`.
852 Returns a 0-D `int32` `Tensor` representing the rank of `input`.
854 For example:
856 ```python
857 # shape of tensor 't' is [2, 2, 3]
858 t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
859 tf.rank(t) # 3
860 ```
862 **Note**: The rank of a tensor is not the same as the rank of a matrix. The
863 rank of a tensor is the number of indices required to uniquely select each
864 element of the tensor. Rank is also known as "order", "degree", or "ndims."
866 Args:
867 input: A `Tensor` or `SparseTensor`.
868 name: A name for the operation (optional).
870 Returns:
871 A `Tensor` of type `int32`.
873 @compatibility(numpy)
874 Equivalent to np.ndim
875 @end_compatibility
876 """
877 return rank_internal(input, name, optimize=True)
880def rank_internal(input, name=None, optimize=True):
881 # pylint: disable=redefined-builtin
882 """Returns the rank of a tensor.
884 Args:
885 input: A `Tensor` or `SparseTensor`.
886 name: A name for the operation (optional).
887 optimize: if true, encode the rank as a constant when possible.
889 Returns:
890 A `Tensor` of type `int32`.
891 """
892 with ops.name_scope(name, "Rank", [input]) as name:
893 if isinstance(
894 input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
895 return gen_array_ops.size(input.dense_shape, name=name)
896 else:
897 input = ops.convert_to_tensor(input)
898 input_shape = input.get_shape()
899 if optimize and input_shape.ndims is not None:
900 return constant(input_shape.ndims, dtypes.int32, name=name)
901 return gen_array_ops.rank(input, name=name)
904_SLICE_TYPE_ERROR = (
905 "Only integers, slices (`:`), ellipsis (`...`), "
906 "tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid "
907 "indices")
909_SUPPORTED_SLICE_DTYPES = (dtypes.int16, dtypes.int32, dtypes.int32_ref,
910 dtypes.int64, dtypes.int64_ref)
913def _check_index(idx):
914 """Check if a given value is a valid index into a tensor."""
915 if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
916 return
918 # Optimistic check. Assumptions:
919 # * any object with a dtype is supported
920 # * any object with a dtype has a sizeable shape attribute.
921 dtype = getattr(idx, "dtype", None)
922 if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or
923 idx.shape and len(idx.shape) == 1):
924 # TODO(slebedev): IndexError seems more appropriate here, but it
925 # will break `_slice_helper` contract.
926 raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx))
929def _is_undefined_dimension(d):
930 return isinstance(d, tensor_shape.Dimension) and d.value is None
933@tf_export("__operators__.getitem", v1=[])
934@dispatch.add_dispatch_support
935def _slice_helper(tensor, slice_spec, var=None):
936 """Overload for Tensor.__getitem__.
938 This operation extracts the specified region from the tensor.
939 The notation is similar to NumPy with the restriction that
940 currently only support basic indexing. That means that
941 using a non-scalar tensor as input is not currently allowed.
943 Some useful examples:
945 ```python
946 # Strip leading and trailing 2 elements
947 foo = tf.constant([1,2,3,4,5,6])
948 print(foo[2:-2]) # => [3,4]
950 # Skip every other row and reverse the order of the columns
951 foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
952 print(foo[::2,::-1]) # => [[3,2,1], [9,8,7]]
954 # Use scalar tensors as indices on both dimensions
955 print(foo[tf.constant(0), tf.constant(2)]) # => 3
957 # Insert another dimension
958 foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
959 print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
960 print(foo[:, tf.newaxis, :]) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]]
961 print(foo[:, :, tf.newaxis]) # => [[[1],[2],[3]], [[4],[5],[6]],
962 [[7],[8],[9]]]
964 # Ellipses (3 equivalent operations)
965 foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
966 print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
967 print(foo[tf.newaxis, ...]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
968 print(foo[tf.newaxis]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
970 # Masks
971 foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
972 print(foo[foo > 2]) # => [3, 4, 5, 6, 7, 8, 9]
973 ```
975 Notes:
976 - `tf.newaxis` is `None` as in NumPy.
977 - An implicit ellipsis is placed at the end of the `slice_spec`
978 - NumPy advanced indexing is currently not supported.
980 Purpose in the API:
982 This method is exposed in TensorFlow's API so that library developers
983 can register dispatching for `Tensor.__getitem__` to allow it to handle
984 custom composite tensors & other custom objects.
986 The API symbol is not intended to be called by users directly and does
987 appear in TensorFlow's generated documentation.
989 Args:
990 tensor: An ops.Tensor object.
991 slice_spec: The arguments to Tensor.__getitem__.
992 var: In the case of variable slice assignment, the Variable object to slice
993 (i.e. tensor is the read-only view of this variable).
995 Returns:
996 The appropriate slice of "tensor", based on "slice_spec".
998 Raises:
999 ValueError: If a slice range is negative size.
1000 TypeError: If the slice indices aren't int, slice, ellipsis,
1001 tf.newaxis or scalar int32/int64 tensors.
1002 """
1003 tensor = ops.convert_to_tensor(tensor)
1004 # TODO(wangpeng): Consider supporting var
1005 if var is None and ops._numpy_style_slicing: # pylint: disable=protected-access
1006 return tensor._numpy_style_getitem(slice_spec) # pylint: disable=protected-access
1008 if isinstance(slice_spec, bool) or \
1009 (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
1010 (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
1011 return boolean_mask(tensor=tensor, mask=slice_spec)
1013 if not isinstance(slice_spec, (list, tuple)):
1014 slice_spec = [slice_spec]
1016 begin, end, strides = [], [], []
1017 index = 0
1019 new_axis_mask, shrink_axis_mask = 0, 0
1020 begin_mask, end_mask = 0, 0
1021 ellipsis_mask = 0
1022 for s in slice_spec:
1023 if isinstance(s, _BaseSlice):
1024 # Finds the best dtype for begin, end, and strides.
1025 dtype = None
1026 for t in [s.start, s.stop, s.step]:
1027 if t is None or not isinstance(t, ops.Tensor):
1028 continue
1029 if t.dtype == dtypes.int64:
1030 dtype = dtypes.int64
1031 elif t.dtype == dtypes.int32 and dtype != dtypes.int64:
1032 dtype = dtypes.int32
1033 elif t.dtype == dtypes.int16 and dtype is None:
1034 dtype = dtypes.int16
1036 if s.start is not None and not _is_undefined_dimension(s.start):
1037 _check_index(s.start)
1038 begin.append(s.start)
1039 else:
1040 if dtype is not None:
1041 begin.append(constant_op.constant(0, dtype=dtype))
1042 else:
1043 begin.append(0)
1044 begin_mask |= (1 << index)
1045 if s.stop is not None and not _is_undefined_dimension(s.stop):
1046 _check_index(s.stop)
1047 end.append(s.stop)
1048 else:
1049 if dtype is not None:
1050 end.append(constant_op.constant(0, dtype=dtype))
1051 else:
1052 end.append(0)
1053 end_mask |= (1 << index)
1054 if s.step is not None and not _is_undefined_dimension(s.step):
1055 _check_index(s.step)
1056 strides.append(s.step)
1057 else:
1058 if dtype is not None:
1059 strides.append(constant_op.constant(1, dtype=dtype))
1060 else:
1061 strides.append(1)
1062 elif s is Ellipsis:
1063 begin.append(0)
1064 end.append(0)
1065 strides.append(1)
1066 ellipsis_mask |= (1 << index)
1067 elif s is newaxis:
1068 begin.append(0)
1069 end.append(0)
1070 strides.append(1)
1071 new_axis_mask |= (1 << index)
1072 else:
1073 _check_index(s)
1074 begin.append(s)
1075 end.append(s + 1)
1076 # TODO(mdan): Investigate why we can't set int32 here.
1077 if isinstance(s, ops.Tensor) and (s.dtype == dtypes.int16 or
1078 s.dtype == dtypes.int64):
1079 strides.append(constant_op.constant(1, dtype=s.dtype))
1080 else:
1081 strides.append(1)
1082 shrink_axis_mask |= (1 << index)
1083 index += 1
1085 # stack possibly involves no tensors, so we must use op_scope correct graph.
1086 with ops.name_scope(
1087 None,
1088 "strided_slice", [tensor] + begin + end + strides,
1089 skip_on_eager=False) as name:
1090 if begin:
1091 packed_begin, packed_end, packed_strides = (
1092 array_ops_stack.stack(begin),
1093 array_ops_stack.stack(end),
1094 array_ops_stack.stack(strides))
1095 # TODO(mdan): Instead of implicitly casting, it's better to enforce the
1096 # same dtypes.
1097 if (packed_begin.dtype == dtypes.int64 or
1098 packed_end.dtype == dtypes.int64 or
1099 packed_strides.dtype == dtypes.int64):
1100 if packed_begin.dtype != dtypes.int64:
1101 packed_begin = gen_math_ops.cast(packed_begin, dtypes.int64)
1102 if packed_end.dtype != dtypes.int64:
1103 packed_end = gen_math_ops.cast(packed_end, dtypes.int64)
1104 if packed_strides.dtype != dtypes.int64:
1105 packed_strides = gen_math_ops.cast(packed_strides, dtypes.int64)
1106 elif (packed_begin.dtype == dtypes.int16 and
1107 packed_end.dtype == dtypes.int16 and
1108 packed_strides.dtype == dtypes.int16):
1109 if packed_begin.dtype != dtypes.int16:
1110 packed_begin = gen_math_ops.cast(packed_begin, dtypes.int16)
1111 if packed_end.dtype != dtypes.int16:
1112 packed_end = gen_math_ops.cast(packed_end, dtypes.int16)
1113 if packed_strides.dtype != dtypes.int16:
1114 packed_strides = gen_math_ops.cast(packed_strides, dtypes.int16)
1115 else:
1116 var_empty = constant([], dtype=dtypes.int32)
1117 packed_begin = packed_end = packed_strides = var_empty
1118 return strided_slice(
1119 tensor,
1120 packed_begin,
1121 packed_end,
1122 packed_strides,
1123 begin_mask=begin_mask,
1124 end_mask=end_mask,
1125 shrink_axis_mask=shrink_axis_mask,
1126 new_axis_mask=new_axis_mask,
1127 ellipsis_mask=ellipsis_mask,
1128 var=var,
1129 name=name)
1132# pylint: disable=undefined-variable,protected-access,redefined-outer-name
1133@tf_export("slice")
1134@dispatch.add_dispatch_support
1135def slice(input_, begin, size, name=None):
1136 # pylint: disable=redefined-builtin
1137 """Extracts a slice from a tensor.
1139 See also `tf.strided_slice`.
1141 This operation extracts a slice of size `size` from a tensor `input_` starting
1142 at the location specified by `begin`. The slice `size` is represented as a
1143 tensor shape, where `size[i]` is the number of elements of the 'i'th dimension
1144 of `input_` that you want to slice. The starting location (`begin`) for the
1145 slice is represented as an offset in each dimension of `input_`. In other
1146 words, `begin[i]` is the offset into the i'th dimension of `input_` that you
1147 want to slice from.
1149 Note that `tf.Tensor.__getitem__` is typically a more pythonic way to
1150 perform slices, as it allows you to write `foo[3:7, :-2]` instead of
1151 `tf.slice(foo, [3, 0], [4, foo.get_shape()[1]-2])`.
1153 `begin` is zero-based; `size` is one-based. If `size[i]` is -1,
1154 all remaining elements in dimension i are included in the
1155 slice. In other words, this is equivalent to setting:
1157 `size[i] = input_.dim_size(i) - begin[i]`
1159 This operation requires that:
1161 `0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n]`
1163 For example:
1165 ```python
1166 t = tf.constant([[[1, 1, 1], [2, 2, 2]],
1167 [[3, 3, 3], [4, 4, 4]],
1168 [[5, 5, 5], [6, 6, 6]]])
1169 tf.slice(t, [1, 0, 0], [1, 1, 3]) # [[[3, 3, 3]]]
1170 tf.slice(t, [1, 0, 0], [1, 2, 3]) # [[[3, 3, 3],
1171 # [4, 4, 4]]]
1172 tf.slice(t, [1, 0, 0], [2, 1, 3]) # [[[3, 3, 3]],
1173 # [[5, 5, 5]]]
1174 ```
1176 Args:
1177 input_: A `Tensor`.
1178 begin: An `int32` or `int64` `Tensor`.
1179 size: An `int32` or `int64` `Tensor`.
1180 name: A name for the operation (optional).
1182 Returns:
1183 A `Tensor` the same type as `input_`.
1184 """
1185 return gen_array_ops._slice(input_, begin, size, name=name)
1188# pylint: disable=invalid-name
1189@tf_export("strided_slice")
1190@dispatch.add_dispatch_support
1191def strided_slice(input_,
1192 begin,
1193 end,
1194 strides=None,
1195 begin_mask=0,
1196 end_mask=0,
1197 ellipsis_mask=0,
1198 new_axis_mask=0,
1199 shrink_axis_mask=0,
1200 var=None,
1201 name=None):
1202 """Extracts a strided slice of a tensor (generalized Python array indexing).
1204 See also `tf.slice`.
1206 **Instead of calling this op directly most users will want to use the
1207 NumPy-style slicing syntax (e.g. `tensor[..., 3:4:-1, tf.newaxis, 3]`), which
1208 is supported via `tf.Tensor.__getitem__` and `tf.Variable.__getitem__`.**
1209 The interface of this op is a low-level encoding of the slicing syntax.
1211 Roughly speaking, this op extracts a slice of size `(end-begin)/stride`
1212 from the given `input_` tensor. Starting at the location specified by `begin`
1213 the slice continues by adding `stride` to the index until all dimensions are
1214 not less than `end`.
1215 Note that a stride can be negative, which causes a reverse slice.
1217 Given a Python slice `input[spec0, spec1, ..., specn]`,
1218 this function will be called as follows.
1220 `begin`, `end`, and `strides` will be vectors of length n.
1221 n in general is not equal to the rank of the `input_` tensor.
1223 In each mask field (`begin_mask`, `end_mask`, `ellipsis_mask`,
1224 `new_axis_mask`, `shrink_axis_mask`) the ith bit will correspond to
1225 the ith spec.
1227 If the ith bit of `begin_mask` is set, `begin[i]` is ignored and
1228 the fullest possible range in that dimension is used instead.
1229 `end_mask` works analogously, except with the end range.
1231 `foo[5:,:,:3]` on a 7x8x9 tensor is equivalent to `foo[5:7,0:8,0:3]`.
1232 `foo[::-1]` reverses a tensor with shape 8.
1234 If the ith bit of `ellipsis_mask` is set, as many unspecified dimensions
1235 as needed will be inserted between other dimensions. Only one
1236 non-zero bit is allowed in `ellipsis_mask`.
1238 For example `foo[3:5,...,4:5]` on a shape 10x3x3x10 tensor is
1239 equivalent to `foo[3:5,:,:,4:5]` and
1240 `foo[3:5,...]` is equivalent to `foo[3:5,:,:,:]`.
1242 If the ith bit of `new_axis_mask` is set, then `begin`,
1243 `end`, and `stride` are ignored and a new length 1 dimension is
1244 added at this point in the output tensor.
1246 For example,
1247 `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor.
1249 If the ith bit of `shrink_axis_mask` is set, it implies that the ith
1250 specification shrinks the dimensionality by 1, taking on the value at index
1251 `begin[i]`. `end[i]` and `strides[i]` are ignored in this case. For example in
1252 Python one might do `foo[:, 3, :]` which would result in `shrink_axis_mask`
1253 equal to 2.
1256 NOTE: `begin` and `end` are zero-indexed.
1257 `strides` entries must be non-zero.
1260 ```python
1261 t = tf.constant([[[1, 1, 1], [2, 2, 2]],
1262 [[3, 3, 3], [4, 4, 4]],
1263 [[5, 5, 5], [6, 6, 6]]])
1264 tf.strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1]) # [[[3, 3, 3]]]
1265 tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1]) # [[[3, 3, 3],
1266 # [4, 4, 4]]]
1267 tf.strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1]) # [[[4, 4, 4],
1268 # [3, 3, 3]]]
1269 ```
1271 Args:
1272 input_: A `Tensor`.
1273 begin: An `int32` or `int64` `Tensor`.
1274 end: An `int32` or `int64` `Tensor`.
1275 strides: An `int32` or `int64` `Tensor`.
1276 begin_mask: An `int32` mask.
1277 end_mask: An `int32` mask.
1278 ellipsis_mask: An `int32` mask.
1279 new_axis_mask: An `int32` mask.
1280 shrink_axis_mask: An `int32` mask.
1281 var: The variable corresponding to `input_` or None
1282 name: A name for the operation (optional).
1284 Returns:
1285 A `Tensor` the same type as `input`.
1286 """
1288 if strides is None:
1289 strides = ones_like(begin)
1291 op = gen_array_ops.strided_slice(
1292 input=input_,
1293 begin=begin,
1294 end=end,
1295 strides=strides,
1296 name=name,
1297 begin_mask=begin_mask,
1298 end_mask=end_mask,
1299 ellipsis_mask=ellipsis_mask,
1300 new_axis_mask=new_axis_mask,
1301 shrink_axis_mask=shrink_axis_mask)
1303 parent_name = name
1305 if var is not None:
1306 def assign(val, name=None):
1307 """Closure that holds all the arguments to create an assignment."""
1309 if name is None:
1310 name = parent_name + "_assign"
1312 return var._strided_slice_assign(
1313 begin=begin,
1314 end=end,
1315 strides=strides,
1316 value=val,
1317 name=name,
1318 begin_mask=begin_mask,
1319 end_mask=end_mask,
1320 ellipsis_mask=ellipsis_mask,
1321 new_axis_mask=new_axis_mask,
1322 shrink_axis_mask=shrink_axis_mask)
1324 op.assign = assign
1326 return op
1329def _SliceHelperVar(var, slice_spec):
1330 """Creates a slice helper object given a variable.
1332 This allows creating a sub-tensor from part of the current contents
1333 of a variable. See `tf.Tensor.__getitem__` for detailed examples
1334 of slicing.
1336 This function in addition also allows assignment to a sliced range.
1337 This is similar to `__setitem__` functionality in Python. However,
1338 the syntax is different so that the user can capture the assignment
1339 operation for grouping or passing to `sess.run()` in TF1.
1340 For example,
1342 ```python
1343 import tensorflow as tf
1344 A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
1345 print(A[:2, :2]) # => [[1,2], [4,5]]
1347 A[:2,:2].assign(22. * tf.ones((2, 2))))
1348 print(A) # => [[22, 22, 3], [22, 22, 6], [7,8,9]]
1349 ```
1351 Note that assignments currently do not support NumPy broadcasting
1352 semantics.
1354 Args:
1355 var: An `ops.Variable` object.
1356 slice_spec: The arguments to `Tensor.__getitem__`.
1358 Returns:
1359 The appropriate slice of "tensor", based on "slice_spec".
1360 As an operator. The operator also has a `assign()` method
1361 that can be used to generate an assignment operator.
1363 Raises:
1364 ValueError: If a slice range is negative size.
1365 TypeError: TypeError: If the slice indices aren't int, slice,
1366 ellipsis, tf.newaxis or int32/int64 tensors.
1368 """
1370 return _slice_helper(var.value(), slice_spec, var)
1373ops.Tensor._override_operator("__getitem__", _slice_helper)
1376@tf_export("parallel_stack")
1377@dispatch.add_dispatch_support
1378def parallel_stack(values, name="parallel_stack"):
1379 """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel.
1381 Requires that the shape of inputs be known at graph construction time.
1383 Packs the list of tensors in `values` into a tensor with rank one higher than
1384 each tensor in `values`, by packing them along the first dimension.
1385 Given a list of length `N` of tensors of shape `(A, B, C)`; the `output`
1386 tensor will have the shape `(N, A, B, C)`.
1388 For example:
1390 ```python
1391 x = tf.constant([1, 4])
1392 y = tf.constant([2, 5])
1393 z = tf.constant([3, 6])
1394 tf.parallel_stack([x, y, z]) # [[1, 4], [2, 5], [3, 6]]
1395 ```
1397 The difference between `stack` and `parallel_stack` is that `stack` requires
1398 all the inputs be computed before the operation will begin but doesn't require
1399 that the input shapes be known during graph construction.
1401 `parallel_stack` will copy pieces of the input into the output as they become
1402 available, in some situations this can provide a performance benefit.
1404 Unlike `stack`, `parallel_stack` does NOT support backpropagation.
1406 This is the opposite of unstack. The numpy equivalent is
1408 tf.parallel_stack([x, y, z]) = np.asarray([x, y, z])
1410 @compatibility(eager)
1411 parallel_stack is not compatible with eager execution.
1412 @end_compatibility
1414 Args:
1415 values: A list of `Tensor` objects with the same shape and type.
1416 name: A name for this operation (optional).
1418 Returns:
1419 output: A stacked `Tensor` with the same type as `values`.
1421 Raises:
1422 RuntimeError: if executed in eager mode.
1423 """
1424 if context.executing_eagerly():
1425 raise RuntimeError("tf.parallel_stack() is not compatible with "
1426 "eager execution.")
1427 with ops.name_scope(name):
1428 value_t = ops.convert_to_tensor(values[0])
1429 value_shape = ops.convert_to_tensor(value_t).get_shape()
1431 output_shape = tensor_shape.TensorShape([len(values)])
1432 output_shape = output_shape.concatenate(value_shape)
1433 # expand_dims converts concat to stack.
1434 return gen_array_ops.parallel_concat(
1435 [expand_dims(value, 0) for value in values], shape=output_shape)
1438# pylint: disable=invalid-name
1439def _autopacking_helper(list_or_tuple, dtype, name):
1440 """Converts the given list or tuple to a tensor by packing.
1442 Args:
1443 list_or_tuple: A (possibly nested) list or tuple containing a tensor.
1444 dtype: The element type of the returned tensor.
1445 name: A name for the returned tensor.
1447 Returns:
1448 A `tf.Tensor` with value equivalent to `list_or_tuple`.
1449 """
1450 if context.executing_eagerly():
1451 # NOTE: Fast path when all the items are tensors, this doesn't do any type
1452 # checking.
1453 if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
1454 return gen_array_ops.pack(list_or_tuple, name=name)
1455 must_pack = False
1456 converted_elems = []
1457 with ops.name_scope(name) as scope:
1458 for i, elem in enumerate(list_or_tuple):
1459 if isinstance(elem, core.Tensor):
1460 if dtype is not None and elem.dtype.base_dtype != dtype:
1461 raise TypeError(f"Cannot convert a list containing a tensor of dtype "
1462 f"{elem.dtype} to {dtype} (Tensor is: {elem!r})")
1463 converted_elems.append(elem)
1464 must_pack = True
1465 elif isinstance(elem, (list, tuple)):
1466 converted_elem = _autopacking_helper(elem, dtype, str(i))
1467 if isinstance(converted_elem, core.Tensor):
1468 must_pack = True
1469 converted_elems.append(converted_elem)
1470 else:
1471 converted_elems.append(elem)
1472 if must_pack:
1473 elems_as_tensors = []
1474 for i, elem in enumerate(converted_elems):
1475 if isinstance(elem, core.Tensor):
1476 elems_as_tensors.append(elem)
1477 else:
1478 # NOTE(mrry): This is inefficient, but it enables us to
1479 # handle the case where the list arguments are other
1480 # convertible-to-tensor types, such as numpy arrays.
1481 elems_as_tensors.append(
1482 constant_op.constant(elem, dtype=dtype, name=str(i)))
1483 return gen_array_ops.pack(elems_as_tensors, name=scope)
1484 else:
1485 return converted_elems
1488def _get_dtype_from_nested_lists(list_or_tuple):
1489 """Returns the dtype of any tensor-like object in `list_or_tuple`, if found.
1491 Args:
1492 list_or_tuple: A list or tuple representing an object that can be converted
1493 to a `tf.Tensor`.
1495 Returns:
1496 The dtype of any tensor-like object in `list_or_tuple`, or `None` if no
1497 such object exists.
1498 """
1499 for elem in list_or_tuple:
1500 if isinstance(elem, core.Tensor):
1501 return elem.dtype.base_dtype
1502 elif isinstance(elem, (list, tuple)):
1503 maybe_dtype = _get_dtype_from_nested_lists(elem)
1504 if maybe_dtype is not None:
1505 return maybe_dtype
1506 return None
1509def _cast_nested_seqs_to_dtype(dtype):
1511 def _maybe_cast(elem):
1512 if isinstance(elem, core.Tensor):
1513 if dtype != elem.dtype.base_dtype:
1514 elem = gen_math_ops.cast(elem, dtype)
1515 return elem
1517 return _maybe_cast
1520_NON_AUTOPACKABLE_TYPES = set(np.core.numerictypes.ScalarType)
1521_NON_AUTOPACKABLE_TYPES.add(np.ndarray)
1524def _should_not_autopack(v):
1525 # The condition we really want is
1526 # any(isinstance(elem, core.Tensor))
1527 # but it is >5x slower due to abc.ABCMeta.__instancecheck__.
1528 # pylint: disable=unidiomatic-typecheck
1529 # TODO(slebedev): add nest.all?
1530 return all(type(elem) in _NON_AUTOPACKABLE_TYPES for elem in nest.flatten(v))
1531 # pylint: enable=unidiomatic-typecheck
1534def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
1535 """Tensor conversion function that automatically packs arguments."""
1536 if as_ref or _should_not_autopack(v):
1537 return NotImplemented
1538 inferred_dtype = _get_dtype_from_nested_lists(v)
1539 if inferred_dtype is None:
1540 # We did not find any tensor-like objects in the nested lists, so defer to
1541 # other conversion functions.
1542 return NotImplemented
1543 if dtype is None:
1544 dtype = inferred_dtype
1545 elif dtype != inferred_dtype:
1546 v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
1547 return _autopacking_helper(v, dtype, name or "packed")
1550# pylint: enable=invalid-name
1552# NOTE: Register this conversion function to run *before* one that
1553# assumes every element is a value.
1554tensor_conversion_registry.register_tensor_conversion_function(
1555 (list, tuple), _autopacking_conversion_function, 99)
1558@tf_export("concat")
1559@dispatch.add_dispatch_support
1560def concat(values, axis, name="concat"):
1561 """Concatenates tensors along one dimension.
1563 See also `tf.tile`, `tf.stack`, `tf.repeat`.
1565 Concatenates the list of tensors `values` along dimension `axis`. If
1566 `values[i].shape = [D0, D1, ... Daxis(i), ...Dn]`, the concatenated
1567 result has shape
1569 [D0, D1, ... Raxis, ...Dn]
1571 where
1573 Raxis = sum(Daxis(i))
1575 That is, the data from the input tensors is joined along the `axis`
1576 dimension.
1578 The number of dimensions of the input tensors must match, and all dimensions
1579 except `axis` must be equal.
1581 For example:
1583 >>> t1 = [[1, 2, 3], [4, 5, 6]]
1584 >>> t2 = [[7, 8, 9], [10, 11, 12]]
1585 >>> tf.concat([t1, t2], 0)
1586 <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
1587 array([[ 1, 2, 3],
1588 [ 4, 5, 6],
1589 [ 7, 8, 9],
1590 [10, 11, 12]], dtype=int32)>
1592 >>> tf.concat([t1, t2], 1)
1593 <tf.Tensor: shape=(2, 6), dtype=int32, numpy=
1594 array([[ 1, 2, 3, 7, 8, 9],
1595 [ 4, 5, 6, 10, 11, 12]], dtype=int32)>
1597 As in Python, the `axis` could also be negative numbers. Negative `axis`
1598 are interpreted as counting from the end of the rank, i.e.,
1599 `axis + rank(values)`-th dimension.
1601 For example:
1603 >>> t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
1604 >>> t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
1605 >>> tf.concat([t1, t2], -1)
1606 <tf.Tensor: shape=(2, 2, 4), dtype=int32, numpy=
1607 array([[[ 1, 2, 7, 4],
1608 [ 2, 3, 8, 4]],
1609 [[ 4, 4, 2, 10],
1610 [ 5, 3, 15, 11]]], dtype=int32)>
1612 Note: If you are concatenating along a new axis consider using stack.
1613 E.g.
1615 ```python
1616 tf.concat([tf.expand_dims(t, axis) for t in tensors], axis)
1617 ```
1619 can be rewritten as
1621 ```python
1622 tf.stack(tensors, axis=axis)
1623 ```
1625 Args:
1626 values: A list of `Tensor` objects or a single `Tensor`.
1627 axis: 0-D `int32` `Tensor`. Dimension along which to concatenate. Must be
1628 in the range `[-rank(values), rank(values))`. As in Python, indexing for
1629 axis is 0-based. Positive axis in the rage of `[0, rank(values))` refers
1630 to `axis`-th dimension. And negative axis refers to `axis +
1631 rank(values)`-th dimension.
1632 name: A name for the operation (optional).
1634 Returns:
1635 A `Tensor` resulting from concatenation of the input tensors.
1636 """
1637 if not isinstance(values, (list, tuple)):
1638 values = [values]
1639 # TODO(mrry): Change to return values?
1640 if len(values) == 1: # Degenerate case of one tensor.
1641 # Make a throwaway call to convert_to_tensor to make sure
1642 # that axis is of the correct type, and make sure that
1643 # the returned tensor is a scalar.
1644 # TODO(keveman): Implement a standalone type and shape checker.
1645 with ops.name_scope(name) as scope:
1646 ops.convert_to_tensor(
1647 axis, name="concat_dim",
1648 dtype=dtypes.int32).get_shape().assert_has_rank(0)
1649 return identity(values[0], name=name)
1650 return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
1653@tf_export(v1=["boolean_mask"])
1654@dispatch.add_dispatch_support
1655def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
1656 """Apply boolean mask to tensor.
1658 Numpy equivalent is `tensor[mask]`.
1660 In general, `0 < dim(mask) = K <= dim(tensor)`, and `mask`'s shape must match
1661 the first K dimensions of `tensor`'s shape. We then have:
1662 `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
1663 where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
1664 The `axis` could be used with `mask` to indicate the axis to mask from.
1665 In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
1666 the first `axis + dim(mask)` dimensions of `tensor`'s shape.
1668 See also: `tf.ragged.boolean_mask`, which can be applied to both dense and
1669 ragged tensors, and can be used if you need to preserve the masked dimensions
1670 of `tensor` (rather than flattening them, as `tf.boolean_mask` does).
1672 Examples:
1674 ```python
1675 # 1-D example
1676 tensor = [0, 1, 2, 3]
1677 mask = np.array([True, False, True, False])
1678 tf.boolean_mask(tensor, mask) # [0, 2]
1680 # 2-D example
1681 tensor = [[1, 2], [3, 4], [5, 6]]
1682 mask = np.array([True, False, True])
1683 tf.boolean_mask(tensor, mask) # [[1, 2], [5, 6]]
1684 ```
1686 Args:
1687 tensor: N-D Tensor.
1688 mask: K-D boolean Tensor, K <= N and K must be known statically.
1689 name: A name for this operation (optional).
1690 axis: A 0-D int Tensor representing the axis in `tensor` to mask from. By
1691 default, axis is 0 which will mask from the first dimension. Otherwise K +
1692 axis <= N.
1694 Returns:
1695 (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
1696 to `True` values in `mask`.
1698 Raises:
1699 ValueError: If shapes do not conform.
1700 """
1702 def _apply_mask_1d(reshaped_tensor, mask, axis=None):
1703 """Mask tensor along dimension 0 with a 1-D mask."""
1704 indices = squeeze(where_v2(mask), axis=[1])
1705 return gather(reshaped_tensor, indices, axis=axis)
1707 with ops.name_scope(name, values=[tensor, mask]):
1708 tensor = ops.convert_to_tensor(tensor, name="tensor")
1709 mask = ops.convert_to_tensor(mask, name="mask")
1711 shape_mask = mask.get_shape()
1712 ndims_mask = shape_mask.ndims
1713 shape_tensor = tensor.get_shape()
1714 if ndims_mask == 0:
1715 raise ValueError("mask cannot be scalar.")
1716 if ndims_mask is None:
1717 raise ValueError(
1718 "Number of mask dimensions must be specified, even if some dimensions"
1719 " are None. E.g. shape=[None] is ok, but shape=None is not.")
1720 axis = 0 if axis is None else axis
1721 axis_value = tensor_util.constant_value(axis)
1722 if axis_value is not None:
1723 axis = axis_value
1724 shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask)
1726 leading_size = gen_math_ops.prod(shape(tensor)[axis:axis + ndims_mask], [0])
1727 tensor = reshape(
1728 tensor,
1729 concat([
1730 shape(tensor)[:axis], [leading_size],
1731 shape(tensor)[axis + ndims_mask:]
1732 ], 0))
1733 # TODO(yongtang): tf.reshape in C++ kernel might have set the shape
1734 # correctly, so the following may not be needed? It still might be possible
1735 # that there are some edge case where tensor_util.constant_value resolves
1736 # more cases than ShapeInference of tf.reshape in C++ kernel.
1737 if axis_value is not None:
1738 first_dim = shape_tensor[axis:axis + ndims_mask].num_elements()
1739 tensor.set_shape(
1740 tensor_shape.as_shape(shape_tensor[:axis]).concatenate(
1741 [first_dim]).concatenate(shape_tensor[axis + ndims_mask:]))
1743 mask = reshape(mask, [-1])
1744 return _apply_mask_1d(tensor, mask, axis)
1747@tf_export("boolean_mask", v1=[])
1748@dispatch.add_dispatch_support
1749def boolean_mask_v2(tensor, mask, axis=None, name="boolean_mask"):
1750 """Apply boolean mask to tensor.
1752 Numpy equivalent is `tensor[mask]`.
1754 In general, `0 < dim(mask) = K <= dim(tensor)`, and `mask`'s shape must match
1755 the first K dimensions of `tensor`'s shape. We then have:
1756 `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
1757 where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
1758 The `axis` could be used with `mask` to indicate the axis to mask from.
1759 In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
1760 the first `axis + dim(mask)` dimensions of `tensor`'s shape.
1762 See also: `tf.ragged.boolean_mask`, which can be applied to both dense and
1763 ragged tensors, and can be used if you need to preserve the masked dimensions
1764 of `tensor` (rather than flattening them, as `tf.boolean_mask` does).
1766 Examples:
1768 >>> tensor = [0, 1, 2, 3] # 1-D example
1769 >>> mask = np.array([True, False, True, False])
1770 >>> tf.boolean_mask(tensor, mask)
1771 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 2], dtype=int32)>
1773 >>> tensor = [[1, 2], [3, 4], [5, 6]] # 2-D example
1774 >>> mask = np.array([True, False, True])
1775 >>> tf.boolean_mask(tensor, mask)
1776 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
1777 array([[1, 2],
1778 [5, 6]], dtype=int32)>
1780 Args:
1781 tensor: N-D Tensor.
1782 mask: K-D boolean Tensor, K <= N and K must be known statically.
1783 axis: A 0-D int Tensor representing the axis in `tensor` to mask from. By
1784 default, axis is 0 which will mask from the first dimension. Otherwise K +
1785 axis <= N.
1786 name: A name for this operation (optional).
1788 Returns:
1789 (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
1790 to `True` values in `mask`.
1792 Raises:
1793 ValueError: If shapes do not conform.
1795 Examples:
1797 ```python
1798 # 2-D example
1799 tensor = [[1, 2], [3, 4], [5, 6]]
1800 mask = np.array([True, False, True])
1801 boolean_mask(tensor, mask) # [[1, 2], [5, 6]]
1802 ```
1803 """
1804 return boolean_mask(tensor, mask, name, axis)
1807@tf_export("sparse.mask", v1=["sparse.mask", "sparse_mask"])
1808@deprecation.deprecated_endpoints("sparse_mask")
1809def sparse_mask(a, mask_indices, name=None):
1810 """Masks elements of `IndexedSlices`.
1812 Given an `IndexedSlices` instance `a`, returns another `IndexedSlices` that
1813 contains a subset of the slices of `a`. Only the slices at indices not
1814 specified in `mask_indices` are returned.
1816 This is useful when you need to extract a subset of slices in an
1817 `IndexedSlices` object.
1819 For example:
1821 ```python
1822 # `a` contains slices at indices [12, 26, 37, 45] from a large tensor
1823 # with shape [1000, 10]
1824 a.indices # [12, 26, 37, 45]
1825 tf.shape(a.values) # [4, 10]
1827 # `b` will be the subset of `a` slices at its second and third indices, so
1828 # we want to mask its first and last indices (which are at absolute
1829 # indices 12, 45)
1830 b = tf.sparse.mask(a, [12, 45])
1832 b.indices # [26, 37]
1833 tf.shape(b.values) # [2, 10]
1834 ```
1836 Args:
1837 a: An `IndexedSlices` instance.
1838 mask_indices: Indices of elements to mask.
1839 name: A name for the operation (optional).
1841 Returns:
1842 The masked `IndexedSlices` instance.
1843 """
1844 with ops.name_scope(name, "sparse_mask", [a, mask_indices]) as name:
1845 indices = a.indices
1846 out_indices, to_gather = gen_array_ops.list_diff(indices, mask_indices)
1847 out_values = gather(a.values, to_gather, name=name)
1848 return indexed_slices.IndexedSlices(out_values, out_indices, a.dense_shape)
1851@tf_export("unique")
1852@dispatch.add_dispatch_support
1853def unique(x, out_idx=dtypes.int32, name=None):
1854 """Finds unique elements in a 1-D tensor.
1856 See also `tf.unique_with_counts`.
1858 This operation returns a tensor `y` containing all the unique elements
1859 of `x` sorted in the same order that they occur in `x`. This operation
1860 also returns a tensor `idx` the same size as `x` that contains the index
1861 of each value of `x` in the unique output `y`. In other words:
1864 y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]
1866 Example usage:
1868 >>> x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
1869 >>> y, idx = unique(x)
1870 >>> y
1871 <tf.Tensor: id=5, shape=(5,), dtype=int32,
1872 numpy=array([1, 2, 4, 7, 8], dtype=int32)>
1873 >>> idx
1874 <tf.Tensor: id=6, shape=(9,), dtype=int32,
1875 numpy=array([0, 0, 1, 2, 2, 2, 3, 4, 4], dtype=int32)>
1877 Args:
1878 x: A Tensor. 1-D.
1879 out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
1880 tf.int32.
1881 name: A name for the operation (optional).
1883 Returns:
1884 A tuple of Tensor objects (y, idx).
1885 y: A Tensor. Has the same type as x.
1886 idx: A Tensor of type out_idx.
1888 """
1889 # TODO(yongtang): switch to v2 once API deprecation
1890 # period (3 weeks) pass.
1891 # TODO(yongtang): The documentation should also
1892 # be updated when switch to v2.
1893 return gen_array_ops.unique(x, out_idx, name)
1896unique.__doc__ = gen_array_ops.unique.__doc__
1899@tf_export("unique_with_counts")
1900@dispatch.add_dispatch_support
1901def unique_with_counts(x, out_idx=dtypes.int32, name=None):
1902 """Finds unique elements in a 1-D tensor.
1904 See also `tf.unique`.
1906 This operation returns a tensor `y` containing all the unique elements
1907 of `x` sorted in the same order that they occur in `x`. This operation
1908 also returns a tensor `idx` the same size as `x` that contains the index
1909 of each value of `x` in the unique output `y`. Finally, it returns a
1910 third tensor `count` that contains the count of each element of `y`
1911 in `x`. In other words:
1913 y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]
1915 Example usage:
1917 >>> x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8, 8])
1918 >>> y, idx, count = unique_with_counts(x)
1919 >>> y
1920 <tf.Tensor: id=8, shape=(5,), dtype=int32,
1921 numpy=array([1, 2, 4, 7, 8], dtype=int32)>
1922 >>> idx
1923 <tf.Tensor: id=9, shape=(9,), dtype=int32,
1924 numpy=array([0, 0, 1, 2, 2, 2, 3, 4, 4], dtype=int32)>
1925 >>> count
1926 <tf.Tensor: id=10, shape=(5,), dtype=int32,
1927 numpy=array([2, 1, 3, 1, 2], dtype=int32)>
1929 Args:
1930 x: A Tensor. 1-D.
1931 out_idx: An optional tf.DType from: tf.int32, tf.int64. Defaults to
1932 tf.int32.
1933 name: A name for the operation (optional).
1935 Returns:
1936 A tuple of Tensor objects (y, idx, count).
1937 y: A Tensor. Has the same type as x.
1938 idx: A Tensor of type out_idx.
1939 count: A Tensor of type out_idx.
1941 """
1942 # TODO(yongtang): switch to v2 once API deprecation
1943 # period (3 weeks) pass.
1944 # TODO(yongtang): The documentation should also
1945 # be updated when switch to v2.
1946 return gen_array_ops.unique_with_counts(x, out_idx, name)
1949unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__
1952@tf_export("split")
1953@dispatch.add_dispatch_support
1954def split(value, num_or_size_splits, axis=0, num=None, name="split"):
1955 """Splits a tensor `value` into a list of sub tensors.
1957 See also `tf.unstack`.
1959 If `num_or_size_splits` is an `int`, then it splits `value` along the
1960 dimension `axis` into `num_or_size_splits` smaller tensors. This requires that
1961 `value.shape[axis]` is divisible by `num_or_size_splits`.
1963 If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
1964 `len(num_or_size_splits)` elements. The shape of the `i`-th
1965 element has the same size as the `value` except along dimension `axis` where
1966 the size is `num_or_size_splits[i]`.
1968 For example:
1970 >>> x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
1971 >>>
1972 >>> # Split `x` into 3 tensors along dimension 1
1973 >>> s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
1974 >>> tf.shape(s0).numpy()
1975 array([ 5, 10], dtype=int32)
1976 >>>
1977 >>> # Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
1978 >>> split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
1979 >>> tf.shape(split0).numpy()
1980 array([5, 4], dtype=int32)
1981 >>> tf.shape(split1).numpy()
1982 array([ 5, 15], dtype=int32)
1983 >>> tf.shape(split2).numpy()
1984 array([ 5, 11], dtype=int32)
1986 Args:
1987 value: The `Tensor` to split.
1988 num_or_size_splits: Either an `int` indicating the number of splits
1989 along `axis` or a 1-D integer `Tensor` or Python list containing the sizes
1990 of each output tensor along `axis`. If an `int`, then it must evenly
1991 divide `value.shape[axis]`; otherwise the sum of sizes along the split
1992 axis must match that of the `value`.
1993 axis: An `int` or scalar `int32` `Tensor`. The dimension along which
1994 to split. Must be in the range `[-rank(value), rank(value))`. Defaults to
1995 0.
1996 num: Optional, an `int`, used to specify the number of outputs when it
1997 cannot be inferred from the shape of `size_splits`.
1998 name: A name for the operation (optional).
2000 Returns:
2001 if `num_or_size_splits` is an `int` returns a list of
2002 `num_or_size_splits` `Tensor` objects; if `num_or_size_splits` is a 1-D
2003 list or 1-D `Tensor` returns `num_or_size_splits.get_shape[0]`
2004 `Tensor` objects resulting from splitting `value`.
2006 Raises:
2007 ValueError: If `num` is unspecified and cannot be inferred.
2008 ValueError: If `num_or_size_splits` is a scalar `Tensor`.
2009 """
2010 if isinstance(num_or_size_splits,
2011 (numbers.Integral, tensor_shape.Dimension)):
2012 return gen_array_ops.split(
2013 axis=axis, num_split=num_or_size_splits, value=value, name=name)
2015 size_splits = ops.convert_to_tensor(num_or_size_splits)
2017 if size_splits._rank() == 0:
2018 raise ValueError(
2019 "Rank-0 tensors are not supported as the num_or_size_splits argument "
2020 "to split. Argument provided: %s" % (num_or_size_splits,))
2022 if num is None:
2023 size_splits_shape = size_splits._shape_tuple()
2024 if size_splits_shape:
2025 num = size_splits_shape[0]
2026 if num is None:
2027 raise ValueError(
2028 f"Cannot infer argument `num` from shape {num_or_size_splits}")
2030 return gen_array_ops.split_v(
2031 value=value, size_splits=size_splits, axis=axis, num_split=num, name=name)
2034@tf_export("transpose", v1=[])
2035@dispatch.add_dispatch_support
2036def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
2037 """Transposes `a`, where `a` is a Tensor.
2039 Permutes the dimensions according to the value of `perm`.
2041 The returned tensor's dimension `i` will correspond to the input dimension
2042 `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is the rank
2043 of the input tensor. Hence, by default, this operation performs a regular
2044 matrix transpose on 2-D input Tensors.
2046 If conjugate is `True` and `a.dtype` is either `complex64` or `complex128`
2047 then the values of `a` are conjugated and transposed.
2049 @compatibility(numpy)
2050 In `numpy` transposes are memory-efficient constant time operations as they
2051 simply return a new view of the same data with adjusted `strides`.
2053 TensorFlow does not support strides, so `transpose` returns a new tensor with
2054 the items permuted.
2055 @end_compatibility
2057 For example:
2059 >>> x = tf.constant([[1, 2, 3], [4, 5, 6]])
2060 >>> tf.transpose(x)
2061 <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
2062 array([[1, 4],
2063 [2, 5],
2064 [3, 6]], dtype=int32)>
2066 Equivalently, you could call `tf.transpose(x, perm=[1, 0])`.
2068 If `x` is complex, setting conjugate=True gives the conjugate transpose:
2070 >>> x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2071 ... [4 + 4j, 5 + 5j, 6 + 6j]])
2072 >>> tf.transpose(x, conjugate=True)
2073 <tf.Tensor: shape=(3, 2), dtype=complex128, numpy=
2074 array([[1.-1.j, 4.-4.j],
2075 [2.-2.j, 5.-5.j],
2076 [3.-3.j, 6.-6.j]])>
2078 'perm' is more useful for n-dimensional tensors where n > 2:
2080 >>> x = tf.constant([[[ 1, 2, 3],
2081 ... [ 4, 5, 6]],
2082 ... [[ 7, 8, 9],
2083 ... [10, 11, 12]]])
2085 As above, simply calling `tf.transpose` will default to `perm=[2,1,0]`.
2087 To take the transpose of the matrices in dimension-0 (such as when you are
2088 transposing matrices where 0 is the batch dimension), you would set
2089 `perm=[0,2,1]`.
2091 >>> tf.transpose(x, perm=[0, 2, 1])
2092 <tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
2093 array([[[ 1, 4],
2094 [ 2, 5],
2095 [ 3, 6]],
2096 [[ 7, 10],
2097 [ 8, 11],
2098 [ 9, 12]]], dtype=int32)>
2100 Note: This has a shorthand `linalg.matrix_transpose`):
2102 Args:
2103 a: A `Tensor`.
2104 perm: A permutation of the dimensions of `a`. This should be a vector.
2105 conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2106 to tf.math.conj(tf.transpose(input)).
2107 name: A name for the operation (optional).
2109 Returns:
2110 A transposed `Tensor`.
2111 """
2112 return transpose(a=a, perm=perm, name=name, conjugate=conjugate)
2115@tf_export(v1=["transpose"])
2116@dispatch.add_dispatch_support
2117def transpose(a, perm=None, name="transpose", conjugate=False):
2118 """Transposes `a`.
2120 Permutes the dimensions according to `perm`.
2122 The returned tensor's dimension i will correspond to the input dimension
2123 `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
2124 the rank of the input tensor. Hence, by default, this operation performs a
2125 regular matrix transpose on 2-D input Tensors. If conjugate is True and
2126 `a.dtype` is either `complex64` or `complex128` then the values of `a`
2127 are conjugated and transposed.
2129 @compatibility(numpy)
2130 In `numpy` transposes are memory-efficient constant time operations as they
2131 simply return a new view of the same data with adjusted `strides`.
2133 TensorFlow does not support strides, so `transpose` returns a new tensor with
2134 the items permuted.
2135 @end_compatibility
2137 For example:
2139 ```python
2140 x = tf.constant([[1, 2, 3], [4, 5, 6]])
2141 tf.transpose(x) # [[1, 4]
2142 # [2, 5]
2143 # [3, 6]]
2145 # Equivalently
2146 tf.transpose(x, perm=[1, 0]) # [[1, 4]
2147 # [2, 5]
2148 # [3, 6]]
2150 # If x is complex, setting conjugate=True gives the conjugate transpose
2151 x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2152 [4 + 4j, 5 + 5j, 6 + 6j]])
2153 tf.transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
2154 # [2 - 2j, 5 - 5j],
2155 # [3 - 3j, 6 - 6j]]
2157 # 'perm' is more useful for n-dimensional tensors, for n > 2
2158 x = tf.constant([[[ 1, 2, 3],
2159 [ 4, 5, 6]],
2160 [[ 7, 8, 9],
2161 [10, 11, 12]]])
2163 # Take the transpose of the matrices in dimension-0
2164 # (this common operation has a shorthand `linalg.matrix_transpose`)
2165 tf.transpose(x, perm=[0, 2, 1]) # [[[1, 4],
2166 # [2, 5],
2167 # [3, 6]],
2168 # [[7, 10],
2169 # [8, 11],
2170 # [9, 12]]]
2171 ```
2173 Args:
2174 a: A `Tensor`.
2175 perm: A permutation of the dimensions of `a`.
2176 name: A name for the operation (optional).
2177 conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2178 to tf.math.conj(tf.transpose(input)).
2180 Returns:
2181 A transposed `Tensor`.
2182 """
2183 with ops.name_scope(name, "transpose", [a]) as name:
2184 if not tensor_util.is_tf_type(a):
2185 a = ops.convert_to_tensor(a, name="a")
2187 if conjugate and a.dtype.is_complex:
2188 transpose_fn = gen_array_ops.conjugate_transpose
2189 else:
2190 transpose_fn = gen_array_ops.transpose
2192 if perm is not None:
2193 return transpose_fn(a, perm, name=name)
2195 rank = a.shape.rank
2196 if rank is None:
2197 perm = gen_math_ops._range(gen_array_ops.rank(a) - 1, -1, -1)
2198 else:
2199 perm = np.arange(rank - 1, -1, -1, dtype=np.int32)
2200 return transpose_fn(a, perm, name=name)
2203# pylint: disable=invalid-name
2204@tf_export(
2205 "linalg.matrix_transpose",
2206 v1=["linalg.transpose", "linalg.matrix_transpose", "matrix_transpose"])
2207@dispatch.add_dispatch_support
2208@deprecation.deprecated_endpoints("matrix_transpose", "linalg.transpose")
2209def matrix_transpose(a, name="matrix_transpose", conjugate=False):
2210 """Transposes last two dimensions of tensor `a`.
2212 For example:
2214 ```python
2215 x = tf.constant([[1, 2, 3], [4, 5, 6]])
2216 tf.linalg.matrix_transpose(x) # [[1, 4],
2217 # [2, 5],
2218 # [3, 6]]
2220 x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
2221 [4 + 4j, 5 + 5j, 6 + 6j]])
2222 tf.linalg.matrix_transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
2223 # [2 - 2j, 5 - 5j],
2224 # [3 - 3j, 6 - 6j]]
2226 # Matrix with two batch dimensions.
2227 # x.shape is [1, 2, 3, 4]
2228 # tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
2229 ```
2231 Note that `tf.matmul` provides kwargs allowing for transpose of arguments.
2232 This is done with minimal cost, and is preferable to using this function. E.g.
2234 ```python
2235 # Good! Transpose is taken at minimal additional cost.
2236 tf.matmul(matrix, b, transpose_b=True)
2238 # Inefficient!
2239 tf.matmul(matrix, tf.linalg.matrix_transpose(b))
2240 ```
2242 @compatibility(numpy)
2243 In `numpy` transposes are memory-efficient constant time operations as they
2244 simply return a new view of the same data with adjusted `strides`.
2246 TensorFlow does not support strides, `linalg.matrix_transpose` returns a new
2247 tensor with the items permuted.
2248 @end_compatibility
2250 Args:
2251 a: A `Tensor` with `rank >= 2`.
2252 name: A name for the operation (optional).
2253 conjugate: Optional bool. Setting it to `True` is mathematically equivalent
2254 to tf.math.conj(tf.linalg.matrix_transpose(input)).
2256 Returns:
2257 A transposed batch matrix `Tensor`.
2259 Raises:
2260 ValueError: If `a` is determined statically to have `rank < 2`.
2261 """
2262 with ops.name_scope(name, values=[a]):
2263 a = ops.convert_to_tensor(a, name="a")
2265 # If we know the number of dimensions (statically), we can do two things:
2266 # 1. Check that `a` is a (batch) matrix.
2267 # 2. Use a Python list for perm. This preserves static shape information
2268 # and avoids extra computations.
2269 a_shape = a.get_shape()
2270 ndims = a_shape.ndims
2271 if ndims is not None:
2272 if ndims < 2:
2273 raise ValueError("Argument `a` should be a (batch) matrix with rank "
2274 f">= 2. Received `a` = {a} with shape: {a_shape}")
2275 perm = list(range(ndims - 2)) + [ndims - 1] + [ndims - 2]
2276 else:
2277 a_rank = rank(a)
2278 perm = concat(
2279 (gen_math_ops._range(0, a_rank - 2, 1), [a_rank - 1, a_rank - 2]), 0)
2281 return transpose(a, perm=perm, conjugate=conjugate)
2284@tf_export("linalg.diag", v1=["linalg.diag", "matrix_diag"])
2285@dispatch.add_dispatch_support
2286@deprecation.deprecated_endpoints("matrix_diag")
2287def matrix_diag(diagonal,
2288 name="diag",
2289 k=0,
2290 num_rows=-1,
2291 num_cols=-1,
2292 padding_value=0,
2293 align="RIGHT_LEFT"):
2294 """Returns a batched diagonal tensor with given batched diagonal values.
2296 Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th
2297 diagonals of a matrix, with everything else padded with `padding`. `num_rows`
2298 and `num_cols` specify the dimension of the innermost matrix of the output. If
2299 both are not specified, the op assumes the innermost matrix is square and
2300 infers its size from `k` and the innermost dimension of `diagonal`. If only
2301 one of them is specified, the op assumes the unspecified value is the smallest
2302 possible based on other criteria.
2304 Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor
2305 has rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only
2306 one diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has
2307 rank `r` with shape `[I, J, ..., L, num_rows, num_cols]`.
2309 The second innermost dimension of `diagonal` has double meaning. When `k` is
2310 scalar or `k[0] == k[1]`, `M` is part of the batch size [I, J, ..., M], and
2311 the output tensor is:
2313 ```
2314 output[i, j, ..., l, m, n]
2315 = diagonal[i, j, ..., l, n-max(d_upper, 0)] ; if n - m == d_upper
2316 padding_value ; otherwise
2317 ```
2319 Otherwise, `M` is treated as the number of diagonals for the matrix in the
2320 same batch (`M = k[1]-k[0]+1`), and the output tensor is:
2322 ```
2323 output[i, j, ..., l, m, n]
2324 = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
2325 padding_value ; otherwise
2326 ```
2327 where `d = n - m`, `diag_index = k[1] - d`, and
2328 `index_in_diag = n - max(d, 0) + offset`.
2330 `offset` is zero except when the alignment of the diagonal is to the right.
2331 ```
2332 offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2333 and `d >= 0`) or
2334 (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2335 and `d <= 0`)
2336 0 ; otherwise
2337 ```
2338 where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2340 For example:
2342 ```
2343 # The main diagonal.
2344 diagonal = np.array([[1, 2, 3, 4], # Input shape: (2, 4)
2345 [5, 6, 7, 8]])
2346 tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0], # Output shape: (2, 4, 4)
2347 [0, 2, 0, 0],
2348 [0, 0, 3, 0],
2349 [0, 0, 0, 4]],
2350 [[5, 0, 0, 0],
2351 [0, 6, 0, 0],
2352 [0, 0, 7, 0],
2353 [0, 0, 0, 8]]]
2355 # A superdiagonal (per batch).
2356 diagonal = np.array([[1, 2, 3], # Input shape: (2, 3)
2357 [4, 5, 6]])
2358 tf.matrix_diag(diagonal, k = 1)
2359 ==> [[[0, 1, 0, 0], # Output shape: (2, 4, 4)
2360 [0, 0, 2, 0],
2361 [0, 0, 0, 3],
2362 [0, 0, 0, 0]],
2363 [[0, 4, 0, 0],
2364 [0, 0, 5, 0],
2365 [0, 0, 0, 6],
2366 [0, 0, 0, 0]]]
2368 # A tridiagonal band (per batch).
2369 diagonals = np.array([[[8, 9, 0], # Input shape: (2, 2, 3)
2370 [1, 2, 3],
2371 [0, 4, 5]],
2372 [[2, 3, 0],
2373 [6, 7, 9],
2374 [0, 9, 1]]])
2375 tf.matrix_diag(diagonals, k = (-1, 1))
2376 ==> [[[1, 8, 0], # Output shape: (2, 3, 3)
2377 [4, 2, 9],
2378 [0, 5, 3]],
2379 [[6, 2, 0],
2380 [9, 7, 3],
2381 [0, 1, 9]]]
2383 # RIGHT_LEFT alignment.
2384 diagonals = np.array([[[0, 8, 9], # Input shape: (2, 2, 3)
2385 [1, 2, 3],
2386 [4, 5, 0]],
2387 [[0, 2, 3],
2388 [6, 7, 9],
2389 [9, 1, 0]]])
2390 tf.matrix_diag(diagonals, k = (-1, 1), align="RIGHT_LEFT")
2391 ==> [[[1, 8, 0], # Output shape: (2, 3, 3)
2392 [4, 2, 9],
2393 [0, 5, 3]],
2394 [[6, 2, 0],
2395 [9, 7, 3],
2396 [0, 1, 9]]]
2398 # Rectangular matrix.
2399 diagonal = np.array([1, 2]) # Input shape: (2)
2400 tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4)
2401 ==> [[0, 0, 0, 0], # Output shape: (3, 4)
2402 [1, 0, 0, 0],
2403 [0, 2, 0, 0]]
2405 # Rectangular matrix with inferred num_cols and padding_value = 9.
2406 tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
2407 ==> [[9, 9], # Output shape: (3, 2)
2408 [1, 9],
2409 [9, 2]]
2410 ```
2412 Args:
2413 diagonal: A `Tensor` with `rank k >= 1`.
2414 name: A name for the operation (optional).
2415 k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2416 main diagonal, and negative value means subdiagonals. `k` can be a single
2417 integer (for a single diagonal) or a pair of integers specifying the low
2418 and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2419 num_rows: The number of rows of the output matrix. If it is not provided,
2420 the op assumes the output matrix is a square matrix and infers the matrix
2421 size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
2422 num_cols: The number of columns of the output matrix. If it is not provided,
2423 the op assumes the output matrix is a square matrix and infers the matrix
2424 size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
2425 padding_value: The value to fill the area outside the specified diagonal
2426 band with. Default is 0.
2427 align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2428 `align` is a string specifying how superdiagonals and subdiagonals should
2429 be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2430 (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2431 aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2432 the left (right-pads the row). It is the packing format LAPACK uses.
2433 cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2435 Returns:
2436 A Tensor. Has the same type as `diagonal`.
2437 """
2438 # Special case to sidestep the tf.constant conversion error:
2439 # TypeError: Expected bool, got 0 of type 'int' instead.
2440 if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
2441 padding_value = bool(padding_value)
2443 return gen_array_ops.matrix_diag_v3(
2444 diagonal=diagonal,
2445 k=k,
2446 num_rows=num_rows,
2447 num_cols=num_cols,
2448 padding_value=padding_value,
2449 align=align,
2450 name=name)
2453@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
2454@dispatch.add_dispatch_support
2455@deprecation.deprecated_endpoints("matrix_diag_part")
2456def matrix_diag_part(
2457 input, # pylint:disable=redefined-builtin
2458 name="diag_part",
2459 k=0,
2460 padding_value=0,
2461 align="RIGHT_LEFT"):
2462 """Returns the batched diagonal part of a batched tensor.
2464 Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched
2465 `input`.
2467 Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`.
2468 Let `max_diag_len` be the maximum length among all diagonals to be extracted,
2469 `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
2470 Let `num_diags` be the number of diagonals to extract,
2471 `num_diags = k[1] - k[0] + 1`.
2473 If `num_diags == 1`, the output tensor is of rank `r - 1` with shape
2474 `[I, J, ..., L, max_diag_len]` and values:
2476 ```
2477 diagonal[i, j, ..., l, n]
2478 = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
2479 padding_value ; otherwise.
2480 ```
2481 where `y = max(-k[1], 0)`, `x = max(k[1], 0)`.
2483 Otherwise, the output tensor has rank `r` with dimensions
2484 `[I, J, ..., L, num_diags, max_diag_len]` with values:
2486 ```
2487 diagonal[i, j, ..., l, m, n]
2488 = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
2489 padding_value ; otherwise.
2490 ```
2491 where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`.
2493 `offset` is zero except when the alignment of the diagonal is to the right.
2494 ```
2495 offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2496 and `d >= 0`) or
2497 (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2498 and `d <= 0`)
2499 0 ; otherwise
2500 ```
2501 where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2503 The input must be at least a matrix.
2505 For example:
2507 ```
2508 input = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4)
2509 [5, 6, 7, 8],
2510 [9, 8, 7, 6]],
2511 [[5, 4, 3, 2],
2512 [1, 2, 3, 4],
2513 [5, 6, 7, 8]]])
2515 # A main diagonal from each batch.
2516 tf.linalg.diag_part(input) ==> [[1, 6, 7], # Output shape: (2, 3)
2517 [5, 2, 7]]
2519 # A superdiagonal from each batch.
2520 tf.linalg.diag_part(input, k = 1)
2521 ==> [[2, 7, 6], # Output shape: (2, 3)
2522 [4, 3, 8]]
2524 # A band from each batch.
2525 tf.linalg.diag_part(input, k = (-1, 2))
2526 ==> [[[3, 8, 0], # Output shape: (2, 4, 3)
2527 [2, 7, 6],
2528 [1, 6, 7],
2529 [0, 5, 8]],
2530 [[3, 4, 0],
2531 [4, 3, 8],
2532 [5, 2, 7],
2533 [0, 1, 6]]]
2535 # RIGHT_LEFT alignment.
2536 tf.linalg.diag_part(input, k = (-1, 2), align="RIGHT_LEFT")
2537 ==> [[[0, 3, 8], # Output shape: (2, 4, 3)
2538 [2, 7, 6],
2539 [1, 6, 7],
2540 [5, 8, 0]],
2541 [[0, 3, 4],
2542 [4, 3, 8],
2543 [5, 2, 7],
2544 [1, 6, 0]]]
2546 # max_diag_len can be shorter than the main diagonal.
2547 tf.linalg.diag_part(input, k = (-2, -1))
2548 ==> [[[5, 8],
2549 [0, 9]],
2550 [[1, 6],
2551 [0, 5]]]
2553 # padding_value = 9
2554 tf.linalg.diag_part(input, k = (1, 3), padding_value = 9)
2555 ==> [[[4, 9, 9], # Output shape: (2, 3, 3)
2556 [3, 8, 9],
2557 [2, 7, 6]],
2558 [[2, 9, 9],
2559 [3, 4, 9],
2560 [4, 3, 8]]]
2562 ```
2564 Args:
2565 input: A `Tensor` with `rank k >= 2`.
2566 name: A name for the operation (optional).
2567 k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2568 main diagonal, and negative value means subdiagonals. `k` can be a single
2569 integer (for a single diagonal) or a pair of integers specifying the low
2570 and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2571 padding_value: The value to fill the area outside the specified diagonal
2572 band with. Default is 0.
2573 align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2574 `align` is a string specifying how superdiagonals and subdiagonals should
2575 be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2576 (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2577 aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2578 the left (right-pads the row). It is the packing format LAPACK uses.
2579 cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2581 Returns:
2582 A Tensor containing diagonals of `input`. Has the same type as `input`.
2584 Raises:
2585 InvalidArgumentError: When `k` is out of bound or when `k[0]>k[1:]`.
2586 """
2587 # Special case to sidestep the tf.constant conversion error:
2588 # TypeError: Expected bool, got 0 of type 'int' instead.
2589 if hasattr(input, "dtype") and input.dtype == "bool":
2590 padding_value = bool(padding_value)
2592 return gen_array_ops.matrix_diag_part_v3(
2593 input=input, k=k, padding_value=padding_value, align=align, name=name)
2596@tf_export(
2597 "linalg.tensor_diag_part", v1=["linalg.tensor_diag_part", "diag_part"])
2598@dispatch.add_dispatch_support
2599@deprecation.deprecated_endpoints("diag_part")
2600def tensor_diag_part(
2601 input, # pylint:disable=redefined-builtin
2602 name=None):
2603 """Returns the diagonal part of the tensor.
2605 This operation returns a tensor with the `diagonal` part
2606 of the `input`. The `diagonal` part is computed as follows:
2608 Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a
2609 tensor of rank `k` with dimensions `[D1,..., Dk]` where:
2611 `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`.
2613 For a rank 2 tensor, `linalg.diag_part` and `linalg.tensor_diag_part`
2614 produce the same result. For rank 3 and higher, linalg.diag_part extracts
2615 the diagonal of each inner-most matrix in the tensor. An example where
2616 they differ is given below.
2618 >>> x = [[[[1111,1112],[1121,1122]],
2619 ... [[1211,1212],[1221,1222]]],
2620 ... [[[2111, 2112], [2121, 2122]],
2621 ... [[2211, 2212], [2221, 2222]]]
2622 ... ]
2623 >>> tf.linalg.tensor_diag_part(x)
2624 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
2625 array([[1111, 1212],
2626 [2121, 2222]], dtype=int32)>
2627 >>> tf.linalg.diag_part(x).shape
2628 TensorShape([2, 2, 2])
2630 Args:
2631 input: A `Tensor` with rank `2k`.
2632 name: A name for the operation (optional).
2634 Returns:
2635 A Tensor containing diagonals of `input`. Has the same type as `input`, and
2636 rank `k`.
2637 """
2638 return gen_array_ops.diag_part(input=input, name=name)
2641@tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"])
2642@dispatch.add_dispatch_support
2643@deprecation.deprecated_endpoints("matrix_set_diag")
2644def matrix_set_diag(
2645 input, # pylint:disable=redefined-builtin
2646 diagonal,
2647 name="set_diag",
2648 k=0,
2649 align="RIGHT_LEFT"):
2650 """Returns a batched matrix tensor with new batched diagonal values.
2652 Given `input` and `diagonal`, this operation returns a tensor with the
2653 same shape and values as `input`, except for the specified diagonals of the
2654 innermost matrices. These will be overwritten by the values in `diagonal`.
2656 `input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or
2657 `k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`.
2658 Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`.
2659 `num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`.
2660 `max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`,
2661 `max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
2663 The output is a tensor of rank `k+1` with dimensions `[I, J, ..., L, M, N]`.
2664 If `k` is scalar or `k[0] == k[1]`:
2666 ```
2667 output[i, j, ..., l, m, n]
2668 = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
2669 input[i, j, ..., l, m, n] ; otherwise
2670 ```
2672 Otherwise,
2674 ```
2675 output[i, j, ..., l, m, n]
2676 = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
2677 input[i, j, ..., l, m, n] ; otherwise
2678 ```
2679 where `d = n - m`, `diag_index = k[1] - d`, and
2680 `index_in_diag = n - max(d, 0) + offset`.
2682 `offset` is zero except when the alignment of the diagonal is to the right.
2683 ```
2684 offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
2685 and `d >= 0`) or
2686 (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
2687 and `d <= 0`)
2688 0 ; otherwise
2689 ```
2690 where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
2692 For example:
2694 ```
2695 # The main diagonal.
2696 input = np.array([[[7, 7, 7, 7], # Input shape: (2, 3, 4)
2697 [7, 7, 7, 7],
2698 [7, 7, 7, 7]],
2699 [[7, 7, 7, 7],
2700 [7, 7, 7, 7],
2701 [7, 7, 7, 7]]])
2702 diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3)
2703 [4, 5, 6]])
2704 tf.matrix_set_diag(input, diagonal)
2705 ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
2706 [7, 2, 7, 7],
2707 [7, 7, 3, 7]],
2708 [[4, 7, 7, 7],
2709 [7, 5, 7, 7],
2710 [7, 7, 6, 7]]]
2712 # A superdiagonal (per batch).
2713 tf.matrix_set_diag(input, diagonal, k = 1)
2714 ==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4)
2715 [7, 7, 2, 7],
2716 [7, 7, 7, 3]],
2717 [[7, 4, 7, 7],
2718 [7, 7, 5, 7],
2719 [7, 7, 7, 6]]]
2721 # A band of diagonals.
2722 diagonals = np.array([[[9, 1, 0], # Diagonal shape: (2, 4, 3)
2723 [6, 5, 8],
2724 [1, 2, 3],
2725 [0, 4, 5]],
2726 [[1, 2, 0],
2727 [5, 6, 4],
2728 [6, 1, 2],
2729 [0, 3, 4]]])
2730 tf.matrix_set_diag(input, diagonals, k = (-1, 2))
2731 ==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4)
2732 [4, 2, 5, 1],
2733 [7, 5, 3, 8]],
2734 [[6, 5, 1, 7],
2735 [3, 1, 6, 2],
2736 [7, 4, 2, 4]]]
2738 # RIGHT_LEFT alignment.
2739 diagonals = np.array([[[0, 9, 1], # Diagonal shape: (2, 4, 3)
2740 [6, 5, 8],
2741 [1, 2, 3],
2742 [4, 5, 0]],
2743 [[0, 1, 2],
2744 [5, 6, 4],
2745 [6, 1, 2],
2746 [3, 4, 0]]])
2747 tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="RIGHT_LEFT")
2748 ==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4)
2749 [4, 2, 5, 1],
2750 [7, 5, 3, 8]],
2751 [[6, 5, 1, 7],
2752 [3, 1, 6, 2],
2753 [7, 4, 2, 4]]]
2755 ```
2757 Args:
2758 input: A `Tensor` with rank `k + 1`, where `k >= 1`.
2759 diagonal: A `Tensor` with rank `k`, when `d_lower == d_upper`, or `k + 1`,
2760 otherwise. `k >= 1`.
2761 name: A name for the operation (optional).
2762 k: Diagonal offset(s). Positive value means superdiagonal, 0 refers to the
2763 main diagonal, and negative value means subdiagonals. `k` can be a single
2764 integer (for a single diagonal) or a pair of integers specifying the low
2765 and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
2766 align: Some diagonals are shorter than `max_diag_len` and need to be padded.
2767 `align` is a string specifying how superdiagonals and subdiagonals should
2768 be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
2769 (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
2770 aligns superdiagonals to the right (left-pads the row) and subdiagonals to
2771 the left (right-pads the row). It is the packing format LAPACK uses.
2772 cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
2773 """
2774 return gen_array_ops.matrix_set_diag_v3(
2775 input=input, diagonal=diagonal, k=k, align=align, name=name)
2778# pylint: enable=invalid-name
2781def _constant_if_small(value, shape, dtype, name):
2782 try:
2783 if np.prod(shape) < 1000:
2784 return constant(value, shape=shape, dtype=dtype, name=name)
2785 except (NotImplementedError, TypeError):
2786 # Happens when shape is a Tensor, list with Tensor elements, etc.
2787 pass
2788 return None
2791def _tag_zeros_tensor(fun):
2792 """ Tags the result of function by setting _is_zeros_tensor attribute.
2794 This is useful to compute Hessians of fused ops such as cross_entropy.
2795 """
2797 def wrapped(*args, **kwargs):
2798 tensor = fun(*args, **kwargs)
2799 tensor._is_zeros_tensor = True
2800 return tensor
2802 return tf_decorator.make_decorator(fun, wrapped)
2805@tf_export("zeros")
2806@dispatch.add_dispatch_support
2807@_tag_zeros_tensor
2808def zeros(shape, dtype=dtypes.float32, name=None):
2809 """Creates a tensor with all elements set to zero.
2811 See also `tf.zeros_like`, `tf.ones`, `tf.fill`, `tf.eye`.
2813 This operation returns a tensor of type `dtype` with shape `shape` and
2814 all elements set to zero.
2816 >>> tf.zeros([3, 4], tf.int32)
2817 <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
2818 array([[0, 0, 0, 0],
2819 [0, 0, 0, 0],
2820 [0, 0, 0, 0]], dtype=int32)>
2822 Args:
2823 shape: A `list` of integers, a `tuple` of integers, or
2824 a 1-D `Tensor` of type `int32`.
2825 dtype: The DType of an element in the resulting `Tensor`.
2826 name: Optional string. A name for the operation.
2828 Returns:
2829 A `Tensor` with all elements set to zero.
2830 """
2831 dtype = dtypes.as_dtype(dtype).base_dtype
2832 with ops.name_scope(name, "zeros", [shape]) as name:
2833 if dtype == dtypes.bool:
2834 zero = False
2835 elif dtype == dtypes.string:
2836 zero = ""
2837 elif dtype.is_quantized:
2838 zero = np.zeros([]).astype(dtype.as_numpy_dtype)
2839 else:
2840 zero = 0
2842 if not isinstance(shape, ops.Tensor):
2843 try:
2844 if not context.executing_eagerly():
2845 # Create a constant if it won't be very big. Otherwise, create a fill
2846 # op to prevent serialized GraphDefs from becoming too large.
2847 output = _constant_if_small(zero, shape, dtype, name)
2848 if output is not None:
2849 return output
2851 # Go through tensor shapes to get int64-if-needed semantics
2852 shape = constant_op._tensor_shape_tensor_conversion_function(
2853 tensor_shape.TensorShape(shape))
2854 except (TypeError, ValueError, errors.UnimplementedError):
2855 # Happens when shape is a list with tensor elements
2856 shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
2857 if not shape._shape_tuple():
2858 shape = reshape(shape, [-1]) # Ensure it's a vector
2859 output = fill(shape, constant(zero, dtype=dtype), name=name)
2860 assert output.dtype.base_dtype == dtype
2861 return output
2864@tf_export(v1=["zeros_like"])
2865@dispatch.register_unary_elementwise_api
2866@dispatch.add_dispatch_support
2867def zeros_like(tensor, dtype=None, name=None, optimize=True):
2868 """Creates a tensor with all elements set to zero.
2870 See also `tf.zeros`.
2872 Given a single tensor (`tensor`), this operation returns a tensor of the
2873 same type and shape as `tensor` with all elements set to zero. Optionally,
2874 you can use `dtype` to specify a new type for the returned tensor.
2876 Examples:
2878 >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
2879 >>> tf.zeros_like(tensor)
2880 <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
2881 array([[0, 0, 0],
2882 [0, 0, 0]], dtype=int32)>
2884 >>> tf.zeros_like(tensor, dtype=tf.float32)
2885 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
2886 array([[0., 0., 0.],
2887 [0., 0., 0.]], dtype=float32)>
2889 Args:
2890 tensor: A `Tensor`.
2891 dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
2892 `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
2893 `complex64`, `complex128`, `bool` or `string`. (optional)
2894 name: A name for the operation (optional).
2895 optimize: if `True`, attempt to statically determine the shape of `tensor`
2896 and encode it as a constant. (optional, defaults to `True`)
2898 Returns:
2899 A `Tensor` with all elements set to zero.
2900 """
2901 return zeros_like_impl(tensor, dtype, name, optimize)
2904@tf_export("zeros_like", v1=[])
2905@dispatch.register_unary_elementwise_api
2906@dispatch.add_dispatch_support
2907def zeros_like_v2(
2908 input, # pylint: disable=redefined-builtin
2909 dtype=None,
2910 name=None):
2911 """Creates a tensor with all elements set to zero.
2913 See also `tf.zeros`.
2915 Given a single tensor or array-like object (`input`), this operation returns
2916 a tensor of the same type and shape as `input` with all elements set to zero.
2917 Optionally, you can use `dtype` to specify a new type for the returned tensor.
2919 Examples:
2921 >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
2922 >>> tf.zeros_like(tensor)
2923 <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
2924 array([[0, 0, 0],
2925 [0, 0, 0]], dtype=int32)>
2927 >>> tf.zeros_like(tensor, dtype=tf.float32)
2928 <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
2929 array([[0., 0., 0.],
2930 [0., 0., 0.]], dtype=float32)>
2932 >>> tf.zeros_like([[1, 2, 3], [4, 5, 6]])
2933 <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
2934 array([[0, 0, 0],
2935 [0, 0, 0]], dtype=int32)>
2937 Args:
2938 input: A `Tensor` or array-like object.
2939 dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
2940 `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
2941 `complex64`, `complex128`, `bool` or `string` (optional).
2942 name: A name for the operation (optional).
2944 Returns:
2945 A `Tensor` with all elements set to zero.
2946 """
2947 return zeros_like_impl(input, dtype, name, optimize=True)
2950@_tag_zeros_tensor
2951def zeros_like_impl(tensor, dtype, name, optimize=True):
2952 """Internal implementation for the v1/v2 zeros_like API calls."""
2953 with ops.name_scope(name, "zeros_like", [tensor]) as name:
2954 if not tensor_util.is_tf_type(tensor):
2955 tensor = ops.convert_to_tensor(tensor, name="tensor")
2956 tensor_shape = tensor.shape
2957 tensor_dtype = tensor.dtype
2959 if context.executing_eagerly():
2960 if dtype is not None and dtype != tensor_dtype:
2961 return zeros(
2962 shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
2963 return gen_array_ops.zeros_like(tensor, name=name)
2965 # For now, variant types must be created via zeros_like; as we need to
2966 # pass the input variant object to the proper zeros callback.
2968 if (optimize and tensor_shape.is_fully_defined() and
2969 tensor_dtype != dtypes.variant):
2970 # We can produce a zeros tensor independent of the value of 'tensor',
2971 # since the shape is known statically.
2972 return zeros(tensor_shape, dtype=dtype or tensor_dtype, name=name)
2974 if dtype is not None and dtype != tensor_dtype and dtype != dtypes.variant:
2975 return zeros(
2976 shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
2977 else:
2978 return gen_array_ops.zeros_like(tensor, name=name)
2981@tf_export(v1=["ones_like"])
2982@dispatch.register_unary_elementwise_api
2983@dispatch.add_dispatch_support
2984def ones_like(tensor, dtype=None, name=None, optimize=True):
2985 """Creates a tensor with all elements set to 1.
2987 See also `tf.ones`.
2989 Given a single tensor (`tensor`), this operation returns a tensor of the same
2990 type and shape as `tensor` with all elements set to 1. Optionally, you can
2991 specify a new type (`dtype`) for the returned tensor.
2993 For example:
2995 ```python
2996 tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
2997 tf.ones_like(tensor) # [[1, 1, 1], [1, 1, 1]]
2998 ```
3000 Args:
3001 tensor: A `Tensor`.
3002 dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
3003 `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`, `complex64`,
3004 `complex128` or `bool`.
3005 name: A name for the operation (optional).
3006 optimize: if true, attempt to statically determine the shape of 'tensor' and
3007 encode it as a constant.
3009 Returns:
3010 A `Tensor` with all elements set to 1.
3011 """
3012 return ones_like_impl(tensor, dtype, name, optimize)
3015@tf_export("ones_like", v1=[])
3016@dispatch.register_unary_elementwise_api
3017@dispatch.add_dispatch_support
3018def ones_like_v2(
3019 input, # pylint: disable=redefined-builtin
3020 dtype=None,
3021 name=None):
3022 """Creates a tensor of all ones that has the same shape as the input.
3024 See also `tf.ones`.
3026 Given a single tensor (`tensor`), this operation returns a tensor of the
3027 same type and shape as `tensor` with all elements set to 1. Optionally,
3028 you can use `dtype` to specify a new type for the returned tensor.
3030 For example:
3032 >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
3033 >>> tf.ones_like(tensor)
3034 <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3035 array([[1, 1, 1],
3036 [1, 1, 1]], dtype=int32)>
3038 Args:
3039 input: A `Tensor`.
3040 dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
3041 `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
3042 `complex64`, `complex128`, `bool` or `string`.
3043 name: A name for the operation (optional).
3045 Returns:
3046 A `Tensor` with all elements set to one.
3047 """
3048 return ones_like_impl(input, dtype, name, optimize=True)
3051def ones_like_impl(tensor, dtype, name, optimize=True):
3052 """Internal implementation for the v1/v2 ones_like API calls."""
3053 with ops.name_scope(name, "ones_like", [tensor]) as name:
3054 tensor = ops.convert_to_tensor(tensor, name="tensor")
3055 ones_shape = shape_internal(tensor, optimize=optimize)
3056 if dtype is None:
3057 dtype = tensor.dtype
3058 ret = ones(ones_shape, dtype=dtype, name=name)
3059 if not context.executing_eagerly():
3060 ret.set_shape(tensor.get_shape())
3061 return ret
3064@tf_export("ones")
3065@dispatch.add_dispatch_support
3066def ones(shape, dtype=dtypes.float32, name=None):
3067 """Creates a tensor with all elements set to one (1).
3069 See also `tf.ones_like`, `tf.zeros`, `tf.fill`, `tf.eye`.
3071 This operation returns a tensor of type `dtype` with shape `shape` and
3072 all elements set to one.
3074 >>> tf.ones([3, 4], tf.int32)
3075 <tf.Tensor: shape=(3, 4), dtype=int32, numpy=
3076 array([[1, 1, 1, 1],
3077 [1, 1, 1, 1],
3078 [1, 1, 1, 1]], dtype=int32)>
3080 Args:
3081 shape: A `list` of integers, a `tuple` of integers, or
3082 a 1-D `Tensor` of type `int32`.
3083 dtype: Optional DType of an element in the resulting `Tensor`. Default is
3084 `tf.float32`.
3085 name: Optional string. A name for the operation.
3087 Returns:
3088 A `Tensor` with all elements set to one (1).
3089 """
3090 dtype = dtypes.as_dtype(dtype).base_dtype
3091 with ops.name_scope(name, "ones", [shape]) as name:
3092 if dtype == dtypes.bool:
3093 one = True
3094 elif dtype.is_quantized:
3095 one = np.ones([]).astype(dtype.as_numpy_dtype)
3096 else:
3097 one = 1
3098 if not isinstance(shape, ops.Tensor):
3099 try:
3100 if not context.executing_eagerly():
3101 # Create a constant if it won't be very big. Otherwise, create a fill
3102 # op to prevent serialized GraphDefs from becoming too large.
3103 output = _constant_if_small(one, shape, dtype, name)
3104 if output is not None:
3105 return output
3107 # Go through tensor shapes to get int64-if-needed semantics
3108 shape = constant_op._tensor_shape_tensor_conversion_function(
3109 tensor_shape.TensorShape(shape))
3110 except (TypeError, ValueError):
3111 # Happens when shape is a list with tensor elements
3112 shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
3113 if not shape._shape_tuple():
3114 shape = reshape(shape, [-1]) # Ensure it's a vector
3115 output = fill(shape, constant(one, dtype=dtype), name=name)
3116 assert output.dtype.base_dtype == dtype
3117 return output
3120@tf_export(v1=["placeholder"])
3121def placeholder(dtype, shape=None, name=None):
3122 """Inserts a placeholder for a tensor that will be always fed.
3124 **Important**: This tensor will produce an error if evaluated. Its value must
3125 be fed using the `feed_dict` optional argument to `Session.run()`,
3126 `Tensor.eval()`, or `Operation.run()`.
3128 For example:
3130 ```python
3131 x = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024))
3132 y = tf.matmul(x, x)
3134 with tf.compat.v1.Session() as sess:
3135 print(sess.run(y)) # ERROR: will fail because x was not fed.
3137 rand_array = np.random.rand(1024, 1024)
3138 print(sess.run(y, feed_dict={x: rand_array})) # Will succeed.
3139 ```
3141 Args:
3142 dtype: The type of elements in the tensor to be fed.
3143 shape: The shape of the tensor to be fed (optional). If the shape is not
3144 specified, you can feed a tensor of any shape.
3145 name: A name for the operation (optional).
3147 Returns:
3148 A `Tensor` that may be used as a handle for feeding a value, but not
3149 evaluated directly.
3151 Raises:
3152 RuntimeError: if eager execution is enabled
3154 @compatibility(TF2)
3155 This API is not compatible with eager execution and `tf.function`. To migrate
3156 to TF2, rewrite the code to be compatible with eager execution. Check the
3157 [migration
3158 guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
3159 on replacing `Session.run` calls. In TF2, you can just pass tensors directly
3160 into ops and layers. If you want to explicitly set up your inputs, also see
3161 [Keras functional API](https://www.tensorflow.org/guide/keras/functional) on
3162 how to use `tf.keras.Input` to replace `tf.compat.v1.placeholder`.
3163 `tf.function` arguments also do the job of `tf.compat.v1.placeholder`.
3164 For more details please read [Better
3165 performance with tf.function](https://www.tensorflow.org/guide/function).
3166 @end_compatibility
3167 """
3168 if context.executing_eagerly():
3169 raise RuntimeError("tf.placeholder() is not compatible with "
3170 "eager execution.")
3172 return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
3175@tf_export(v1=["placeholder_with_default"])
3176def placeholder_with_default(input, shape, name=None): # pylint: disable=redefined-builtin
3177 """A placeholder op that passes through `input` when its output is not fed.
3179 @compatibility(TF2)
3180 This API is strongly discouraged for use with eager execution and
3181 `tf.function`. The primary use of this API is for testing computation wrapped
3182 within a `tf.function` where the input tensors might not have statically known
3183 fully-defined shapes. The same can be achieved by creating a
3184 [concrete function](
3185 https://www.tensorflow.org/guide/function#obtaining_concrete_functions)
3186 from the `tf.function` with a `tf.TensorSpec` input which has partially
3187 defined shapes. For example, the code
3189 >>> @tf.function
3190 ... def f():
3191 ... x = tf.compat.v1.placeholder_with_default(
3192 ... tf.constant([[1., 2., 3.], [4., 5., 6.]]), [None, 3])
3193 ... y = tf.constant([[1.],[2.], [3.]])
3194 ... z = tf.matmul(x, y)
3195 ... assert z.shape[0] == None
3196 ... assert z.shape[1] == 1
3198 >>> f()
3200 can easily be replaced by
3202 >>> @tf.function
3203 ... def f(x):
3204 ... y = tf.constant([[1.],[2.], [3.]])
3205 ... z = tf.matmul(x, y)
3206 ... assert z.shape[0] == None
3207 ... assert z.shape[1] == 1
3209 >>> g = f.get_concrete_function(tf.TensorSpec([None, 3]))
3211 You can learn more about `tf.function` at [Better
3212 performance with tf.function](https://www.tensorflow.org/guide/function).
3213 @end_compatibility
3215 Args:
3216 input: A `Tensor`. The default value to produce when output is not fed.
3217 shape: A `tf.TensorShape` or list of `int`s. The (possibly partial) shape of
3218 the tensor.
3219 name: A name for the operation (optional).
3221 Returns:
3222 A `Tensor`. Has the same type as `input`.
3223 """
3224 return gen_array_ops.placeholder_with_default(input, shape, name)
3227@tf_export(v1=["sparse.placeholder", "sparse_placeholder"])
3228@deprecation.deprecated_endpoints("sparse_placeholder")
3229def sparse_placeholder(dtype, shape=None, name=None):
3230 """Inserts a placeholder for a sparse tensor that will be always fed.
3232 **Important**: This sparse tensor will produce an error if evaluated.
3233 Its value must be fed using the `feed_dict` optional argument to
3234 `Session.run()`, `Tensor.eval()`, or `Operation.run()`.
3236 For example:
3238 ```python
3239 x = tf.compat.v1.sparse.placeholder(tf.float32)
3240 y = tf.sparse.reduce_sum(x)
3242 with tf.compat.v1.Session() as sess:
3243 print(sess.run(y)) # ERROR: will fail because x was not fed.
3245 indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
3246 values = np.array([1.0, 2.0], dtype=np.float32)
3247 shape = np.array([7, 9, 2], dtype=np.int64)
3248 print(sess.run(y, feed_dict={
3249 x: tf.compat.v1.SparseTensorValue(indices, values, shape)})) # Will
3250 succeed.
3251 print(sess.run(y, feed_dict={
3252 x: (indices, values, shape)})) # Will succeed.
3254 sp = tf.sparse.SparseTensor(indices=indices, values=values,
3255 dense_shape=shape)
3256 sp_value = sp.eval(session=sess)
3257 print(sess.run(y, feed_dict={x: sp_value})) # Will succeed.
3258 ```
3261 Args:
3262 dtype: The type of `values` elements in the tensor to be fed.
3263 shape: The shape of the tensor to be fed (optional). If the shape is not
3264 specified, you can feed a sparse tensor of any shape.
3265 name: A name for prefixing the operations (optional).
3267 Returns:
3268 A `SparseTensor` that may be used as a handle for feeding a value, but not
3269 evaluated directly.
3271 Raises:
3272 RuntimeError: if eager execution is enabled
3274 @compatibility(TF2)
3275 This API is not compatible with eager execution and `tf.function`. To migrate
3276 to TF2, rewrite the code to be compatible with eager execution. Check the
3277 [migration
3278 guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
3279 on replacing `Session.run` calls. In TF2, you can just pass tensors directly
3280 into ops and layers. If you want to explicitly set up your inputs, also see
3281 [Keras functional API](https://www.tensorflow.org/guide/keras/functional) on
3282 how to use `tf.keras.Input` to replace `tf.compat.v1.sparse_placeholder`.
3283 `tf.function` arguments also do the job of `tf.compat.v1.sparse_placeholder`.
3284 For more details please read [Better
3285 performance with tf.function](https://www.tensorflow.org/guide/function).
3286 @end_compatibility
3287 """
3288 if context.executing_eagerly():
3289 raise RuntimeError("`sparse_placeholder` is not compatible with "
3290 "eager execution.")
3292 shape_name = (name + "/shape") if name is not None else None
3293 default_shape_name = (name + "/shape_default") if name is not None else None
3294 if shape is None:
3295 rank = None
3296 dense_shape = placeholder(dtypes.int64, shape=[rank], name=shape_name)
3297 dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)
3298 else:
3299 if isinstance(shape, ops.Tensor):
3300 rank = shape.get_shape()[0]
3301 dense_shape_default = tensor_util.constant_value_as_shape(shape)
3302 else:
3303 rank = len(shape)
3304 # determine the shape, to override the `.shape` property of the
3305 # `SparseTensor`
3306 dense_shape_default = tensor_shape.TensorShape(
3307 tuple(None if dim == -1 else dim for dim in shape))
3308 shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
3309 shape = tuple(-1 if dim is None else dim for dim in shape)
3310 shape = ops.convert_to_tensor(
3311 shape, dtype=dtypes.int64, name=default_shape_name)
3313 # `dense_shape` needs to be feedable (for users that treat this as an
3314 # actual placeholder). `constant_value_as_shape` sets constants to
3315 # not-feedable. placeholder_with_default works, but blocks `SparseTensor`
3316 # from reading the default value back out.
3317 dense_shape = placeholder_with_default(
3318 shape, shape=shape.shape, name=shape_name)
3320 result = sparse_tensor.SparseTensor(
3321 values=placeholder(
3322 dtype,
3323 shape=[None],
3324 name=(name + "/values") if name is not None else None),
3325 indices=placeholder(
3326 dtypes.int64,
3327 shape=[None, rank],
3328 name=(name + "/indices") if name is not None else None),
3329 dense_shape=dense_shape)
3331 # Now the SparseTensor.shape is a list of `None`s, since it couldn't read the
3332 # default shape out of the placeholder. Override that
3333 # shape to be the value determined here, so partial shapes can be
3334 # propagated.
3335 result.set_shape(dense_shape_default)
3336 return result
3338# pylint: enable=redefined-outer-name
3341@tf_export("pad", v1=[])
3342@dispatch.add_dispatch_support
3343def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None):
3344 """Pads a tensor.
3346 This operation pads a `tensor` according to the `paddings` you specify.
3347 `paddings` is an integer tensor with shape `[n, 2]`, where n is the rank of
3348 `tensor`. For each dimension D of `input`, `paddings[D, 0]` indicates how
3349 many values to add before the contents of `tensor` in that dimension, and
3350 `paddings[D, 1]` indicates how many values to add after the contents of
3351 `tensor` in that dimension. If `mode` is "REFLECT" then both `paddings[D, 0]`
3352 and `paddings[D, 1]` must be no greater than `tensor.dim_size(D) - 1`. If
3353 `mode` is "SYMMETRIC" then both `paddings[D, 0]` and `paddings[D, 1]` must be
3354 no greater than `tensor.dim_size(D)`.
3356 The padded size of each dimension D of the output is:
3358 `paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]`
3360 For example:
3362 ```python
3363 t = tf.constant([[1, 2, 3], [4, 5, 6]])
3364 paddings = tf.constant([[1, 1,], [2, 2]])
3365 # 'constant_values' is 0.
3366 # rank of 't' is 2.
3367 tf.pad(t, paddings, "CONSTANT") # [[0, 0, 0, 0, 0, 0, 0],
3368 # [0, 0, 1, 2, 3, 0, 0],
3369 # [0, 0, 4, 5, 6, 0, 0],
3370 # [0, 0, 0, 0, 0, 0, 0]]
3372 tf.pad(t, paddings, "REFLECT") # [[6, 5, 4, 5, 6, 5, 4],
3373 # [3, 2, 1, 2, 3, 2, 1],
3374 # [6, 5, 4, 5, 6, 5, 4],
3375 # [3, 2, 1, 2, 3, 2, 1]]
3377 tf.pad(t, paddings, "SYMMETRIC") # [[2, 1, 1, 2, 3, 3, 2],
3378 # [2, 1, 1, 2, 3, 3, 2],
3379 # [5, 4, 4, 5, 6, 6, 5],
3380 # [5, 4, 4, 5, 6, 6, 5]]
3381 ```
3383 Args:
3384 tensor: A `Tensor`.
3385 paddings: A `Tensor` of type `int32`.
3386 mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
3387 constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
3388 same type as `tensor`.
3389 name: A name for the operation (optional).
3391 Returns:
3392 A `Tensor`. Has the same type as `tensor`.
3394 Raises:
3395 ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
3396 """
3397 return pad(tensor, paddings, mode, name, constant_values)
3400@tf_export(v1=["pad"])
3401@dispatch.add_dispatch_support
3402def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pylint: disable=invalid-name
3403 """Pads a tensor.
3405 This operation pads a `tensor` according to the `paddings` you specify.
3406 `paddings` is an integer tensor with shape `[n, 2]`, where n is the rank of
3407 `tensor`. For each dimension D of `input`, `paddings[D, 0]` indicates how
3408 many values to add before the contents of `tensor` in that dimension, and
3409 `paddings[D, 1]` indicates how many values to add after the contents of
3410 `tensor` in that dimension. If `mode` is "REFLECT" then both `paddings[D, 0]`
3411 and `paddings[D, 1]` must be no greater than `tensor.dim_size(D) - 1`. If
3412 `mode` is "SYMMETRIC" then both `paddings[D, 0]` and `paddings[D, 1]` must be
3413 no greater than `tensor.dim_size(D)`.
3415 The padded size of each dimension D of the output is:
3417 `paddings[D, 0] + tensor.dim_size(D) + paddings[D, 1]`
3419 For example:
3421 ```python
3422 t = tf.constant([[1, 2, 3], [4, 5, 6]])
3423 paddings = tf.constant([[1, 1,], [2, 2]])
3424 # 'constant_values' is 0.
3425 # rank of 't' is 2.
3426 tf.pad(t, paddings, "CONSTANT") # [[0, 0, 0, 0, 0, 0, 0],
3427 # [0, 0, 1, 2, 3, 0, 0],
3428 # [0, 0, 4, 5, 6, 0, 0],
3429 # [0, 0, 0, 0, 0, 0, 0]]
3431 tf.pad(t, paddings, "REFLECT") # [[6, 5, 4, 5, 6, 5, 4],
3432 # [3, 2, 1, 2, 3, 2, 1],
3433 # [6, 5, 4, 5, 6, 5, 4],
3434 # [3, 2, 1, 2, 3, 2, 1]]
3436 tf.pad(t, paddings, "SYMMETRIC") # [[2, 1, 1, 2, 3, 3, 2],
3437 # [2, 1, 1, 2, 3, 3, 2],
3438 # [5, 4, 4, 5, 6, 6, 5],
3439 # [5, 4, 4, 5, 6, 6, 5]]
3440 ```
3442 Args:
3443 tensor: A `Tensor`.
3444 paddings: A `Tensor` of type `int32`.
3445 mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
3446 name: A name for the operation (optional).
3447 constant_values: In "CONSTANT" mode, the scalar pad value to use. Must be
3448 same type as `tensor`.
3450 Returns:
3451 A `Tensor`. Has the same type as `tensor`.
3453 Raises:
3454 ValueError: When mode is not one of "CONSTANT", "REFLECT", or "SYMMETRIC".
3455 """
3457 # Convert lower/mixed case to upper for NumPy compatibility
3458 # NumPy uses all lower-case modes.
3459 mode = mode.upper()
3460 if mode == "CONSTANT":
3461 # TODO(rjryan): Once the forward compatibility period (3 weeks) have passed
3462 # remove the "Pad" fallback here.
3463 if (not tensor_util.is_tf_type(constant_values) and
3464 np.ndim(constant_values) == 0 and
3465 constant_values == np.zeros_like(constant_values)):
3466 result = gen_array_ops.pad(tensor, paddings, name=name)
3467 else:
3468 result = gen_array_ops.pad_v2(
3469 tensor, paddings, constant_values, name=name)
3470 elif mode == "REFLECT":
3471 result = gen_array_ops.mirror_pad(
3472 tensor, paddings, mode="REFLECT", name=name)
3473 elif mode == "SYMMETRIC":
3474 result = gen_array_ops.mirror_pad(
3475 tensor, paddings, mode="SYMMETRIC", name=name)
3476 else:
3477 raise ValueError("Value of argument `mode` expected to be "
3478 """one of "CONSTANT", "REFLECT", or "SYMMETRIC". """
3479 f"Received `mode` = {mode}")
3481 # Restore shape information where possible.
3482 if not context.executing_eagerly():
3483 paddings_constant = _get_paddings_constant(paddings)
3484 input_shape = (
3485 tensor_shape.TensorShape(tensor.shape)
3486 if isinstance(tensor, ops.Tensor) else result.op.inputs[0].shape)
3487 if (input_shape.ndims is not None and
3488 not result.shape.is_fully_defined() and paddings_constant is not None):
3489 new_shape = []
3490 for padding, dim in zip(paddings_constant, input_shape.as_list()):
3491 if padding is None or dim is None or any((x is None for x in padding)):
3492 new_shape.append(None)
3493 else:
3494 new_shape.append(sum(padding) + dim)
3495 result.set_shape(new_shape)
3497 return result
3500def _get_paddings_constant(paddings):
3501 """Helper to get the constant values of the paddings arg to pad().
3503 Used under V1 graph mode to facilitate computation of the shape of the output
3504 tensor of `pad()`.
3506 Args:
3507 paddings: The same paddings arg as passed to pad(). Can be a Tensor, or
3508 a nested list or tuple of Tensor and/or numbers.
3510 Returns:
3511 A nested list or numbers or `None`, in which `None` indicates unknown
3512 padding size.
3513 """
3514 if isinstance(paddings, ops.Tensor):
3515 return tensor_util.constant_value(paddings, partial=True)
3516 elif isinstance(paddings, (list, tuple)):
3517 return [_get_paddings_constant(x) for x in paddings]
3518 else:
3519 return paddings
3522@tf_export("meshgrid")
3523@dispatch.add_dispatch_support
3524def meshgrid(*args, **kwargs):
3525 """Broadcasts parameters for evaluation on an N-D grid.
3527 Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
3528 of N-D coordinate arrays for evaluating expressions on an N-D grid.
3530 Notes:
3532 `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
3533 When the `indexing` argument is set to 'xy' (the default), the broadcasting
3534 instructions for the first two dimensions are swapped.
3536 Examples:
3538 Calling `X, Y = meshgrid(x, y)` with the tensors
3540 ```python
3541 x = [1, 2, 3]
3542 y = [4, 5, 6]
3543 X, Y = tf.meshgrid(x, y)
3544 # X = [[1, 2, 3],
3545 # [1, 2, 3],
3546 # [1, 2, 3]]
3547 # Y = [[4, 4, 4],
3548 # [5, 5, 5],
3549 # [6, 6, 6]]
3550 ```
3552 Args:
3553 *args: `Tensor`s with rank 1.
3554 **kwargs:
3555 - indexing: Either 'xy' or 'ij' (optional, default: 'xy').
3556 - name: A name for the operation (optional).
3558 Returns:
3559 outputs: A list of N `Tensor`s with rank N.
3561 Raises:
3562 TypeError: When no keyword arguments (kwargs) are passed.
3563 ValueError: When indexing keyword argument is not one of `xy` or `ij`.
3564 """
3566 indexing = kwargs.pop("indexing", "xy")
3567 name = kwargs.pop("name", "meshgrid")
3568 if kwargs:
3569 key = list(kwargs.keys())[0]
3570 raise TypeError("'{}' is an invalid keyword argument "
3571 "for this function".format(key))
3573 if indexing not in ("xy", "ij"):
3574 raise ValueError("Argument `indexing` parameter must be either "
3575 f"'xy' or 'ij', got '{indexing}'")
3577 with ops.name_scope(name, "meshgrid", args) as name:
3578 ndim = len(args)
3579 s0 = (1,) * ndim
3581 if not ndim:
3582 return []
3584 # Prepare reshape by inserting dimensions with size 1 where needed
3585 output = []
3586 for i, x in enumerate(args):
3587 output.append(
3588 reshape(array_ops_stack.stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
3589 # Create parameters for broadcasting each tensor to the full size
3590 shapes = [size(x) for x in args]
3592 output_dtype = ops.convert_to_tensor(args[0]).dtype.base_dtype
3594 if indexing == "xy" and ndim > 1:
3595 output[0] = reshape(output[0], (1, -1) + (1,) * (ndim - 2))
3596 output[1] = reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
3597 shapes[0], shapes[1] = shapes[1], shapes[0]
3599 # TODO(nolivia): improve performance with a broadcast
3600 mult_fact = ones(shapes, output_dtype)
3601 return [x * mult_fact for x in output]
3604NEW_AXIS = -1
3605SHRINK_AXIS = -2
3608# PEP-8 naming
3609# pylint: disable=invalid-name,redefined-outer-name
3610def _compute_size_of_strided_dim(shrink, spec, size):
3611 """Computes the size of a single strided slice dimension."""
3613 unknown = None # Document what None means here.
3614 use_full_range = None # Document other use of None.
3615 # if this is a shrink axis (i.e. a non-range index)
3616 # it either will produce an error or return 1
3617 if shrink:
3618 return 1
3619 if size is unknown or size.value is unknown:
3620 return unknown
3621 size = size.value
3622 stride = spec.step
3623 if stride is not unknown:
3624 if stride == 0:
3625 return unknown
3626 stride = spec.step
3627 valid_range = [0, size] if stride > 0 else [-1, size - 1]
3629 # PEP-8 naming
3630 # pylint: disable=invalid-name
3631 def canonical(x, c):
3632 if x is use_full_range:
3633 return valid_range[c] if stride > 0 else valid_range[(c + 1) & 1]
3634 else:
3635 x_fwd = size + x if x < 0 else x # make negative indices positive
3636 return max(valid_range[0], min(valid_range[1], x_fwd))
3638 begin = canonical(spec.start, 0)
3639 end = canonical(spec.stop, 1)
3640 interval_length = end - begin
3641 if interval_length == 0 or ((interval_length < 0) != (stride < 0)):
3642 return 0
3643 else:
3644 remainder = 1 if interval_length % stride != 0 else 0
3645 return interval_length // stride + remainder
3646 else:
3647 return unknown # unknown because stride is unknown
3650def _TileGradShape(op):
3651 """Shape function for the TileGrad op."""
3652 multiples_shape = op.inputs[1].get_shape().with_rank(1)
3653 input_shape = op.inputs[0].get_shape().with_rank(multiples_shape[0])
3654 # NOTE(mrry): Represent `multiples` as a `TensorShape` because (i)
3655 # it is a vector of non-negative integers, and (ii) doing so allows
3656 # us to handle partially-known multiples.
3657 multiples = tensor_util.constant_value_as_shape(op.inputs[1]).with_rank(
3658 input_shape.ndims)
3659 if multiples.ndims is None:
3660 return [tensor_shape.unknown_shape()]
3661 else:
3662 output_dims = []
3663 for dim, multiple in zip(input_shape.dims, multiples.dims):
3664 output_dims.append(dim // multiple)
3665 return [tensor_shape.TensorShape(output_dims)]
3668@tf_export("edit_distance")
3669@dispatch.add_dispatch_support
3670def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
3671 """Computes the Levenshtein distance between sequences.
3673 This operation takes variable-length sequences (`hypothesis` and `truth`),
3674 each provided as a `SparseTensor`, and computes the Levenshtein distance.
3675 You can normalize the edit distance by length of `truth` by setting
3676 `normalize` to true.
3678 For example:
3680 Given the following input,
3681 * `hypothesis` is a `tf.SparseTensor` of shape `[2, 1, 1]`
3682 * `truth` is a `tf.SparseTensor` of shape `[2, 2, 2]`
3684 >>> hypothesis = tf.SparseTensor(
3685 ... [[0, 0, 0],
3686 ... [1, 0, 0]],
3687 ... ["a", "b"],
3688 ... (2, 1, 1))
3689 >>> truth = tf.SparseTensor(
3690 ... [[0, 1, 0],
3691 ... [1, 0, 0],
3692 ... [1, 0, 1],
3693 ... [1, 1, 0]],
3694 ... ["a", "b", "c", "a"],
3695 ... (2, 2, 2))
3696 >>> tf.edit_distance(hypothesis, truth, normalize=True)
3697 <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
3698 array([[inf, 1. ],
3699 [0.5, 1. ]], dtype=float32)>
3701 The operation returns a dense Tensor of shape `[2, 2]` with
3702 edit distances normalized by `truth` lengths.
3704 **Note**: It is possible to calculate edit distance between two
3705 sparse tensors with variable-length values. However, attempting to create
3706 them while eager execution is enabled will result in a `ValueError`.
3708 For the following inputs,
3710 ```python
3711 # 'hypothesis' is a tensor of shape `[2, 1]` with variable-length values:
3712 # (0,0) = ["a"]
3713 # (1,0) = ["b"]
3714 hypothesis = tf.sparse.SparseTensor(
3715 [[0, 0, 0],
3716 [1, 0, 0]],
3717 ["a", "b"],
3718 (2, 1, 1))
3720 # 'truth' is a tensor of shape `[2, 2]` with variable-length values:
3721 # (0,0) = []
3722 # (0,1) = ["a"]
3723 # (1,0) = ["b", "c"]
3724 # (1,1) = ["a"]
3725 truth = tf.sparse.SparseTensor(
3726 [[0, 1, 0],
3727 [1, 0, 0],
3728 [1, 0, 1],
3729 [1, 1, 0]],
3730 ["a", "b", "c", "a"],
3731 (2, 2, 2))
3733 normalize = True
3735 # The output would be a dense Tensor of shape `(2,)`, with edit distances
3736 normalized by 'truth' lengths.
3737 # output => array([0., 0.5], dtype=float32)
3738 ```
3740 Args:
3741 hypothesis: A `SparseTensor` containing hypothesis sequences.
3742 truth: A `SparseTensor` containing truth sequences.
3743 normalize: A `bool`. If `True`, normalizes the Levenshtein distance by
3744 length of `truth.`
3745 name: A name for the operation (optional).
3747 Returns:
3748 A dense `Tensor` with rank `R - 1`, where R is the rank of the
3749 `SparseTensor` inputs `hypothesis` and `truth`.
3751 Raises:
3752 TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`.
3753 """
3754 if not isinstance(
3755 hypothesis,
3756 (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
3757 raise TypeError("Hypothesis must be a SparseTensor.")
3758 if not isinstance(
3759 truth, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
3760 raise TypeError("Truth must be a SparseTensor.")
3762 return gen_array_ops.edit_distance(
3763 hypothesis.indices,
3764 hypothesis.values,
3765 hypothesis.dense_shape,
3766 truth.indices,
3767 truth.values,
3768 truth.dense_shape,
3769 normalize=normalize,
3770 name=name)
3773@ops.RegisterGradient("FakeQuantWithMinMaxArgs")
3774def _FakeQuantWithMinMaxArgsGradient(op, grad):
3775 """Gradient for FakeQuantWithMinMaxArgs op."""
3776 return fake_quant_with_min_max_args_gradient(
3777 grad,
3778 op.inputs[0],
3779 min=op.get_attr("min"),
3780 max=op.get_attr("max"),
3781 num_bits=op.get_attr("num_bits"),
3782 narrow_range=op.get_attr("narrow_range"))
3785@ops.RegisterGradient("FakeQuantWithMinMaxVars")
3786def _FakeQuantWithMinMaxVarsGradient(op, grad):
3787 """Gradient for FakeQuantWithMinMaxVars op."""
3788 return fake_quant_with_min_max_vars_gradient(
3789 grad,
3790 op.inputs[0],
3791 op.inputs[1],
3792 op.inputs[2],
3793 num_bits=op.get_attr("num_bits"),
3794 narrow_range=op.get_attr("narrow_range"))
3797@ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
3798def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
3799 """Gradient for FakeQuantWithMinMaxVarsPerChannel op."""
3800 return fake_quant_with_min_max_vars_per_channel_gradient(
3801 grad,
3802 op.inputs[0],
3803 op.inputs[1],
3804 op.inputs[2],
3805 num_bits=op.get_attr("num_bits"),
3806 narrow_range=op.get_attr("narrow_range"))
3809@ops.RegisterGradient("QuantizeAndDequantizeV4")
3810def _QuantizeAndDequantizeV4Grad(op, grad):
3811 """Gradient for QuantizeAndDequantizeV4 op."""
3812 return quantize_and_dequantize_v4_grad(
3813 grad,
3814 op.inputs[0],
3815 op.inputs[1],
3816 op.inputs[2],
3817 axis=op.get_attr("axis"))
3820@ops.RegisterGradient("QuantizeAndDequantizeV4Grad")
3821def _QuantizeAndDequantizeV4GradGrad(op, grad):
3822 """Gradient for QuantizeAndDequantizeV4Grad op."""
3823 return _QuantizeAndDequantizeV4Grad(op, grad)
3826@tf_export("required_space_to_batch_paddings")
3827def required_space_to_batch_paddings(input_shape,
3828 block_shape,
3829 base_paddings=None,
3830 name=None):
3831 """Calculate padding required to make block_shape divide input_shape.
3833 This function can be used to calculate a suitable paddings argument for use
3834 with space_to_batch_nd and batch_to_space_nd.
3836 Args:
3837 input_shape: int32 Tensor of shape [N].
3838 block_shape: int32 Tensor of shape [N].
3839 base_paddings: Optional int32 Tensor of shape [N, 2]. Specifies the minimum
3840 amount of padding to use. All elements must be >= 0. If not specified,
3841 defaults to 0.
3842 name: string. Optional name prefix.
3844 Returns:
3845 (paddings, crops), where:
3847 `paddings` and `crops` are int32 Tensors of rank 2 and shape [N, 2]
3848 satisfying:
3850 paddings[i, 0] = base_paddings[i, 0].
3851 0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
3852 (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0
3854 crops[i, 0] = 0
3855 crops[i, 1] = paddings[i, 1] - base_paddings[i, 1]
3857 Raises: ValueError if called with incompatible shapes.
3858 """
3859 with ops.name_scope(name, "required_space_to_batch_paddings",
3860 [input_shape, block_shape]):
3861 input_shape = ops.convert_to_tensor(
3862 input_shape, dtype=dtypes.int32, name="input_shape")
3863 block_shape = ops.convert_to_tensor(
3864 block_shape, dtype=dtypes.int32, name="block_shape")
3866 block_shape.get_shape().assert_is_fully_defined()
3867 block_shape.get_shape().assert_has_rank(1)
3868 num_block_dims = block_shape.get_shape().dims[0].value
3869 if num_block_dims == 0:
3870 return zeros([0, 2], dtypes.int32), zeros([0, 2], dtypes.int32)
3872 input_shape.get_shape().assert_is_compatible_with([num_block_dims])
3874 if base_paddings is not None:
3875 base_paddings = ops.convert_to_tensor(
3876 base_paddings, dtype=dtypes.int32, name="base_paddings")
3877 base_paddings.get_shape().assert_is_compatible_with([num_block_dims, 2])
3878 else:
3879 base_paddings = zeros([num_block_dims, 2], dtypes.int32)
3881 const_block_shape = tensor_util.constant_value(block_shape)
3882 const_input_shape = tensor_util.constant_value(input_shape)
3883 const_base_paddings = tensor_util.constant_value(base_paddings)
3884 if (const_block_shape is not None and const_input_shape is not None and
3885 const_base_paddings is not None):
3886 block_shape = const_block_shape
3887 input_shape = const_input_shape
3888 base_paddings = const_base_paddings
3890 # Use same expression for both constant and non-constant case.
3891 pad_start = base_paddings[:, 0]
3892 orig_pad_end = base_paddings[:, 1]
3893 full_input_shape = input_shape + pad_start + orig_pad_end
3894 pad_end_extra = (block_shape - full_input_shape % block_shape) % block_shape
3895 pad_end = orig_pad_end + pad_end_extra
3897 result_paddings = array_ops_stack.stack(
3898 [[pad_start[i], pad_end[i]] for i in range(num_block_dims)],
3899 name="paddings")
3900 result_crops = array_ops_stack.stack(
3901 [[0, pad_end_extra[i]] for i in range(num_block_dims)], name="crops")
3902 return result_paddings, result_crops
3905@tf_export(v1=["nn.space_to_batch", "space_to_batch"])
3906@dispatch.add_dispatch_support
3907@deprecation.deprecated_endpoints("space_to_batch")
3908def space_to_batch( # pylint: disable=missing-docstring
3909 input, # pylint: disable=redefined-builtin
3910 paddings,
3911 block_size=None,
3912 name=None,
3913 block_shape=None): # pylint: disable=redefined-builtin
3914 block_size = deprecation.deprecated_argument_lookup("block_shape",
3915 block_shape, "block_size",
3916 block_size)
3917 result = space_to_batch_nd(
3918 input,
3919 paddings=paddings,
3920 block_shape=np.array([block_size, block_size], dtype=np.int64),
3921 name=name)
3922 result.set_shape(result.get_shape().with_rank(4))
3923 return result
3926space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__
3929@tf_export("space_to_batch", "nn.space_to_batch", v1=[])
3930@dispatch.add_dispatch_support
3931def space_to_batch_v2(input, block_shape, paddings, name=None): # pylint: disable=redefined-builtin
3932 return space_to_batch_nd(input, block_shape, paddings, name)
3935space_to_batch_v2.__doc__ = gen_array_ops.space_to_batch_nd.__doc__
3938@tf_export(v1=["nn.space_to_depth", "space_to_depth"])
3939@dispatch.add_dispatch_support
3940@deprecation.deprecated_endpoints("space_to_depth")
3941def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
3942 return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
3945space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
3948@tf_export("nn.space_to_depth", v1=[])
3949@dispatch.add_dispatch_support
3950def space_to_depth_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin
3951 return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
3954space_to_depth_v2.__doc__ = gen_array_ops.space_to_depth.__doc__
3957@tf_export(v1=["nn.depth_to_space", "depth_to_space"])
3958@dispatch.add_dispatch_support
3959@deprecation.deprecated_endpoints("depth_to_space")
3960def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
3961 return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
3964depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__
3967@tf_export("nn.depth_to_space", v1=[])
3968@dispatch.add_dispatch_support
3969def depth_to_space_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin
3970 return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
3973depth_to_space_v2.__doc__ = gen_array_ops.depth_to_space.__doc__
3976@tf_export(v1=["batch_to_space"])
3977@dispatch.add_dispatch_support
3978def batch_to_space(input, crops, block_size, name=None, block_shape=None): # pylint: disable=redefined-builtin,missing-docstring
3979 block_size = deprecation.deprecated_argument_lookup("block_shape",
3980 block_shape, "block_size",
3981 block_size)
3982 result = batch_to_space_nd(
3983 input,
3984 crops=crops,
3985 block_shape=np.array([block_size, block_size], dtype=np.int64),
3986 name=name)
3987 result.set_shape(result.get_shape().with_rank(4))
3988 return result
3991batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__
3994@tf_export("batch_to_space", v1=[])
3995@dispatch.add_dispatch_support
3996def batch_to_space_v2(input, block_shape, crops, name=None): # pylint: disable=redefined-builtin
3997 """BatchToSpace for N-D tensors of type T.
3999 This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
4000 shape `block_shape + [batch]`, interleaves these blocks back into the grid
4001 defined by the spatial dimensions `[1, ..., M]`, to obtain a result with the
4002 same rank as the input. The spatial dimensions of this intermediate result
4003 are then optionally cropped according to `crops` to produce the output. This
4004 is the reverse of SpaceToBatch (see `tf.space_to_batch`).
4006 Args:
4007 input: A N-D `Tensor` with shape `input_shape = [batch] + spatial_shape +
4008 remaining_shape`, where `spatial_shape` has M dimensions.
4009 block_shape: A 1-D `Tensor` with shape [M]. Must be one of the following
4010 types: `int32`, `int64`. All values must be >= 1. For backwards
4011 compatibility with TF 1.0, this parameter may be an int, in which case it
4012 is converted to
4013 `numpy.array([block_shape, block_shape],
4014 dtype=numpy.int64)`.
4015 crops: A 2-D `Tensor` with shape `[M, 2]`. Must be one of the
4016 following types: `int32`, `int64`. All values must be >= 0.
4017 `crops[i] = [crop_start, crop_end]` specifies the amount to crop from
4018 input dimension `i + 1`, which corresponds to spatial dimension `i`.
4019 It is required that
4020 `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`.
4021 This operation is equivalent to the following steps:
4022 1. Reshape `input` to `reshaped` of shape: [block_shape[0], ...,
4023 block_shape[M-1], batch / prod(block_shape), input_shape[1], ...,
4024 input_shape[N-1]]
4025 2. Permute dimensions of `reshaped` to produce `permuted` of shape
4026 [batch / prod(block_shape), input_shape[1], block_shape[0], ...,
4027 input_shape[M], block_shape[M-1], input_shape[M+1],
4028 ..., input_shape[N-1]]
4029 3. Reshape `permuted` to produce `reshaped_permuted` of shape
4030 [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
4031 input_shape[M] * block_shape[M-1], input_shape[M+1], ...,
4032 input_shape[N-1]]
4033 4. Crop the start and end of dimensions `[1, ..., M]` of
4034 `reshaped_permuted` according to `crops` to produce the output
4035 of shape:
4036 [batch / prod(block_shape), input_shape[1] *
4037 block_shape[0] - crops[0,0] - crops[0,1], ..., input_shape[M] *
4038 block_shape[M-1] - crops[M-1,0] - crops[M-1,1], input_shape[M+1],
4039 ..., input_shape[N-1]]
4040 name: A name for the operation (optional).
4042 Examples:
4044 1. For the following input of shape `[4, 1, 1, 1]`,
4045 `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4047 ```python
4048 [[[[1]]],
4049 [[[2]]],
4050 [[[3]]],
4051 [[[4]]]]
4052 ```
4054 The output tensor has shape `[1, 2, 2, 1]` and value:
4056 ```
4057 x = [[[[1], [2]],
4058 [[3], [4]]]]
4059 ```
4061 2. For the following input of shape `[4, 1, 1, 3]`,
4062 `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4064 ```python
4065 [[[1, 2, 3]],
4066 [[4, 5, 6]],
4067 [[7, 8, 9]],
4068 [[10, 11, 12]]]
4069 ```
4071 The output tensor has shape `[1, 2, 2, 3]` and value:
4073 ```python
4074 x = [[[[1, 2, 3], [4, 5, 6 ]],
4075 [[7, 8, 9], [10, 11, 12]]]]
4076 ```
4078 3. For the following
4079 input of shape `[4, 2, 2, 1]`,
4080 `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
4082 ```python
4083 x = [[[[1], [3]], [[ 9], [11]]],
4084 [[[2], [4]], [[10], [12]]],
4085 [[[5], [7]], [[13], [15]]],
4086 [[[6], [8]], [[14], [16]]]]
4087 ```
4089 The output tensor has shape `[1, 4, 4, 1]` and value:
4091 ```python
4092 x = [[[1], [2], [ 3], [ 4]],
4093 [[5], [6], [ 7], [ 8]],
4094 [[9], [10], [11], [12]],
4095 [[13], [14], [15], [16]]]
4096 ```
4098 4. For the following input of shape
4099 `[8, 1, 3, 1]`,
4100 `block_shape = [2, 2]`, and `crops = [[0, 0], [2, 0]]`:
4102 ```python
4103 x = [[[[0], [ 1], [ 3]]],
4104 [[[0], [ 9], [11]]],
4105 [[[0], [ 2], [ 4]]],
4106 [[[0], [10], [12]]],
4107 [[[0], [ 5], [ 7]]],
4108 [[[0], [13], [15]]],
4109 [[[0], [ 6], [ 8]]],
4110 [[[0], [14], [16]]]]
4111 ```
4113 The output tensor has shape `[2, 2, 4, 1]` and value:
4115 ```python
4116 x = [[[[ 1], [ 2], [ 3], [ 4]],
4117 [[ 5], [ 6], [ 7], [ 8]]],
4118 [[[ 9], [10], [11], [12]],
4119 [[13], [14], [15], [16]]]]
4120 ```
4122 Returns:
4123 A `Tensor`. Has the same type as `input`.
4124 """
4125 if isinstance(block_shape, int):
4126 block_shape = np.array([block_shape, block_shape], dtype=np.int64)
4128 return batch_to_space_nd(
4129 input=input, block_shape=block_shape, crops=crops, name=name)
4132@tf_export("one_hot")
4133@dispatch.add_dispatch_support
4134def one_hot(indices,
4135 depth,
4136 on_value=None,
4137 off_value=None,
4138 axis=None,
4139 dtype=None,
4140 name=None):
4141 """Returns a one-hot tensor.
4143 See also `tf.fill`, `tf.eye`.
4145 The locations represented by indices in `indices` take value `on_value`,
4146 while all other locations take value `off_value`.
4148 `on_value` and `off_value` must have matching data types. If `dtype` is also
4149 provided, they must be the same data type as specified by `dtype`.
4151 If `on_value` is not provided, it will default to the value `1` with type
4152 `dtype`
4154 If `off_value` is not provided, it will default to the value `0` with type
4155 `dtype`
4157 If the input `indices` is rank `N`, the output will have rank `N+1`. The
4158 new axis is created at dimension `axis` (default: the new axis is appended
4159 at the end).
4161 If `indices` is a scalar the output shape will be a vector of length `depth`
4163 If `indices` is a vector of length `features`, the output shape will be:
4165 ```
4166 features x depth if axis == -1
4167 depth x features if axis == 0
4168 ```
4170 If `indices` is a matrix (batch) with shape `[batch, features]`, the output
4171 shape will be:
4173 ```
4174 batch x features x depth if axis == -1
4175 batch x depth x features if axis == 1
4176 depth x batch x features if axis == 0
4177 ```
4179 If `indices` is a RaggedTensor, the 'axis' argument must be positive and refer
4180 to a non-ragged axis. The output will be equivalent to applying 'one_hot' on
4181 the values of the RaggedTensor, and creating a new RaggedTensor from the
4182 result.
4184 If `dtype` is not provided, it will attempt to assume the data type of
4185 `on_value` or `off_value`, if one or both are passed in. If none of
4186 `on_value`, `off_value`, or `dtype` are provided, `dtype` will default to the
4187 value `tf.float32`.
4189 Note: If a non-numeric data type output is desired (`tf.string`, `tf.bool`,
4190 etc.), both `on_value` and `off_value` _must_ be provided to `one_hot`.
4192 For example:
4194 ```python
4195 indices = [0, 1, 2]
4196 depth = 3
4197 tf.one_hot(indices, depth) # output: [3 x 3]
4198 # [[1., 0., 0.],
4199 # [0., 1., 0.],
4200 # [0., 0., 1.]]
4202 indices = [0, 2, -1, 1]
4203 depth = 3
4204 tf.one_hot(indices, depth,
4205 on_value=5.0, off_value=0.0,
4206 axis=-1) # output: [4 x 3]
4207 # [[5.0, 0.0, 0.0], # one_hot(0)
4208 # [0.0, 0.0, 5.0], # one_hot(2)
4209 # [0.0, 0.0, 0.0], # one_hot(-1)
4210 # [0.0, 5.0, 0.0]] # one_hot(1)
4212 indices = [[0, 2], [1, -1]]
4213 depth = 3
4214 tf.one_hot(indices, depth,
4215 on_value=1.0, off_value=0.0,
4216 axis=-1) # output: [2 x 2 x 3]
4217 # [[[1.0, 0.0, 0.0], # one_hot(0)
4218 # [0.0, 0.0, 1.0]], # one_hot(2)
4219 # [[0.0, 1.0, 0.0], # one_hot(1)
4220 # [0.0, 0.0, 0.0]]] # one_hot(-1)
4222 indices = tf.ragged.constant([[0, 1], [2]])
4223 depth = 3
4224 tf.one_hot(indices, depth) # output: [2 x None x 3]
4225 # [[[1., 0., 0.],
4226 # [0., 1., 0.]],
4227 # [[0., 0., 1.]]]
4228 ```
4230 Args:
4231 indices: A `Tensor` of indices.
4232 depth: A scalar defining the depth of the one hot dimension.
4233 on_value: A scalar defining the value to fill in output when `indices[j]
4234 = i`. (default: 1)
4235 off_value: A scalar defining the value to fill in output when `indices[j]
4236 != i`. (default: 0)
4237 axis: The axis to fill (default: -1, a new inner-most axis).
4238 dtype: The data type of the output tensor.
4239 name: A name for the operation (optional).
4241 Returns:
4242 output: The one-hot tensor.
4244 Raises:
4245 TypeError: If dtype of either `on_value` or `off_value` don't match `dtype`
4246 TypeError: If dtype of `on_value` and `off_value` don't match one another
4247 """
4248 with ops.name_scope(
4249 name, "one_hot",
4250 [indices, depth, on_value, off_value, axis, dtype]) as name:
4251 on_exists = on_value is not None
4252 off_exists = off_value is not None
4254 if on_exists:
4255 on_value = ops.convert_to_tensor(on_value, dtype_hint=dtype)
4256 if off_exists:
4257 off_value = ops.convert_to_tensor(off_value, dtype_hint=dtype)
4259 on_dtype = on_value.dtype.base_dtype if on_exists else None
4260 off_dtype = off_value.dtype.base_dtype if off_exists else None
4262 if on_exists or off_exists:
4263 if dtype is not None:
4264 # Ensure provided on_value and/or off_value match dtype
4265 if on_exists and on_dtype != dtype:
4266 raise TypeError("dtype {0} of on_value does not match "
4267 "dtype parameter {1}".format(on_dtype, dtype))
4268 if off_exists and off_dtype != dtype:
4269 raise TypeError("dtype {0} of off_value does not match "
4270 "dtype parameter {1}".format(off_dtype, dtype))
4271 else:
4272 # dtype not provided: automatically assign it
4273 dtype = on_dtype if on_exists else off_dtype
4274 elif dtype is None:
4275 # None of on_value, off_value, or dtype provided. Default dtype to float32
4276 dtype = dtypes.float32
4278 if not on_exists:
4279 # on_value not provided: assign to value 1 of type dtype
4280 on_value = ops.convert_to_tensor(1, dtype, name="on_value")
4281 on_dtype = dtype
4282 if not off_exists:
4283 # off_value not provided: assign to value 0 of type dtype
4284 off_value = ops.convert_to_tensor(0, dtype, name="off_value")
4285 off_dtype = dtype
4287 if on_dtype != off_dtype:
4288 raise TypeError("dtype {0} of on_value does not match "
4289 "dtype {1} of off_value".format(on_dtype, off_dtype))
4291 return gen_array_ops.one_hot(indices, depth, on_value, off_value, axis,
4292 name)
4295def _all_dimensions(x):
4296 """Returns a 1D-tensor listing all dimensions in x."""
4297 # Fast path: avoid creating Rank and Range ops if ndims is known.
4298 if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
4299 return constant_op.constant(
4300 np.arange(x.get_shape().ndims), dtype=dtypes.int32)
4301 if (isinstance(x, sparse_tensor.SparseTensor) and
4302 x.dense_shape.get_shape().is_fully_defined()):
4303 r = x.dense_shape.get_shape().dims[0].value # sparse.dense_shape is 1-D.
4304 return constant_op.constant(np.arange(r), dtype=dtypes.int32)
4306 # Otherwise, we rely on `range` and `rank` to do the right thing at runtime.
4307 return gen_math_ops._range(0, rank(x), 1)
4310@tf_export("sequence_mask")
4311@dispatch.add_dispatch_support
4312def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
4313 """Returns a mask tensor representing the first N positions of each cell.
4315 If `lengths` has shape `[d_1, d_2, ..., d_n]` the resulting tensor `mask` has
4316 dtype `dtype` and shape `[d_1, d_2, ..., d_n, maxlen]`, with
4318 ```
4319 mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
4320 ```
4322 Examples:
4324 ```python
4325 tf.sequence_mask([1, 3, 2], 5) # [[True, False, False, False, False],
4326 # [True, True, True, False, False],
4327 # [True, True, False, False, False]]
4329 tf.sequence_mask([[1, 3],[2,0]]) # [[[True, False, False],
4330 # [True, True, True]],
4331 # [[True, True, False],
4332 # [False, False, False]]]
4333 ```
4335 Args:
4336 lengths: integer tensor, all its values <= maxlen.
4337 maxlen: scalar integer tensor, size of last dimension of returned tensor.
4338 Default is the maximum value in `lengths`.
4339 dtype: output type of the resulting tensor.
4340 name: name of the op.
4342 Returns:
4343 A mask tensor of shape `lengths.shape + (maxlen,)`, cast to specified dtype.
4344 Raises:
4345 ValueError: if `maxlen` is not a scalar.
4346 """
4347 with ops.name_scope(name, "SequenceMask", [lengths, maxlen]):
4348 lengths = ops.convert_to_tensor(lengths)
4350 if maxlen is None:
4351 maxlen = gen_math_ops._max(lengths, _all_dimensions(lengths))
4352 maxlen = gen_math_ops.maximum(constant(0, maxlen.dtype), maxlen)
4353 else:
4354 maxlen = ops.convert_to_tensor(maxlen)
4355 if maxlen.get_shape().ndims is not None and maxlen.get_shape().ndims != 0:
4356 raise ValueError("Argument `maxlen` must be scalar for sequence_mask, "
4357 f"received `maxlen` = {maxlen} "
4358 f"with shape '{maxlen.get_shape()}' instead")
4360 # The basic idea is to compare a range row vector of size maxlen:
4361 # [0, 1, 2, 3, 4]
4362 # to length as a matrix with 1 column: [[1], [3], [2]].
4363 # Because of broadcasting on both arguments this comparison results
4364 # in a matrix of size (len(lengths), maxlen)
4365 row_vector = gen_math_ops._range(
4366 constant(0, maxlen.dtype), maxlen, constant(1, maxlen.dtype))
4367 # Since maxlen >= max(lengths), it is safe to use maxlen as a cast
4368 # authoritative type. Whenever maxlen fits into tf.int32, so do the lengths.
4369 matrix = gen_math_ops.cast(expand_dims(lengths, -1), maxlen.dtype)
4370 result = row_vector < matrix
4371 if dtype is None or result.dtype.is_compatible_with(dtype):
4372 return result
4373 else:
4374 return gen_math_ops.cast(result, dtype)
4377@tf_export(v1=["squeeze"])
4378@dispatch.add_dispatch_support
4379@deprecation.deprecated_args(None, "Use the `axis` argument instead",
4380 "squeeze_dims")
4381def squeeze(input, axis=None, name=None, squeeze_dims=None):
4382 # pylint: disable=redefined-builtin
4383 """Removes dimensions of size 1 from the shape of a tensor.
4385 Given a tensor `input`, this operation returns a tensor of the same type with
4386 all dimensions of size 1 removed. If you don't want to remove all size 1
4387 dimensions, you can remove specific size 1 dimensions by specifying
4388 `axis`.
4390 For example:
4392 >>> # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4393 >>> t = tf.ones([1, 2, 1, 3, 1, 1])
4394 >>> print(tf.shape(tf.squeeze(t)).numpy())
4395 [2 3]
4397 Or, to remove specific size 1 dimensions:
4399 >>> # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4400 >>> t = tf.ones([1, 2, 1, 3, 1, 1])
4401 >>> print(tf.shape(tf.squeeze(t, [2, 4])).numpy())
4402 [1 2 3 1]
4404 Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
4405 time, where `N` is the number of elements in the squeezed dimensions.
4407 Args:
4408 input: A `Tensor`. The `input` to squeeze.
4409 axis: An optional list of `ints`. Defaults to `[]`. If specified, only
4410 squeezes the dimensions listed. The dimension index starts at 0. It is an
4411 error to squeeze a dimension that is not 1. Must be in the range
4412 `[-rank(input), rank(input))`. Must be specified if `input` is a
4413 `RaggedTensor`.
4414 name: A name for the operation (optional).
4415 squeeze_dims: Deprecated keyword argument that is now axis.
4417 Returns:
4418 A `Tensor`. Has the same type as `input`.
4419 Contains the same data as `input`, but has one or more dimensions of
4420 size 1 removed.
4422 Raises:
4423 ValueError: When both `squeeze_dims` and `axis` are specified.
4424 """
4425 axis = deprecation.deprecated_argument_lookup("axis", axis, "squeeze_dims",
4426 squeeze_dims)
4427 if np.isscalar(axis):
4428 axis = [axis]
4429 return gen_array_ops.squeeze(input, axis, name)
4432@tf_export("squeeze", v1=[])
4433@dispatch.add_dispatch_support
4434def squeeze_v2(input, axis=None, name=None):
4435 """Removes dimensions of size 1 from the shape of a tensor.
4437 Given a tensor `input`, this operation returns a tensor of the same type with
4438 all dimensions of size 1 removed. If you don't want to remove all size 1
4439 dimensions, you can remove specific size 1 dimensions by specifying
4440 `axis`.
4442 For example:
4444 ```python
4445 # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4446 tf.shape(tf.squeeze(t)) # [2, 3]
4447 ```
4449 Or, to remove specific size 1 dimensions:
4451 ```python
4452 # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
4453 tf.shape(tf.squeeze(t, [2, 4])) # [1, 2, 3, 1]
4454 ```
4456 Unlike the older op `tf.compat.v1.squeeze`, this op does not accept a
4457 deprecated `squeeze_dims` argument.
4459 Note: if `input` is a `tf.RaggedTensor`, then this operation takes `O(N)`
4460 time, where `N` is the number of elements in the squeezed dimensions.
4462 Note: If squeeze is performed on dimensions of unknown sizes, then the
4463 returned Tensor will be of unknown shape. A common situation is when the
4464 first (batch) dimension is of size `None`, `tf.squeeze` returns
4465 `<unknown>` shape which may be a surprise. Specify the `axis=` argument
4466 to get the expected result, as illustrated in the following example:
4468 ```python
4469 @tf.function
4470 def func(x):
4471 print('x.shape:', x.shape)
4472 known_axes = [i for i, size in enumerate(x.shape) if size == 1]
4473 y = tf.squeeze(x, axis=known_axes)
4474 print('shape of tf.squeeze(x, axis=known_axes):', y.shape)
4475 y = tf.squeeze(x)
4476 print('shape of tf.squeeze(x):', y.shape)
4477 return 0
4479 _ = func.get_concrete_function(tf.TensorSpec([None, 1, 2], dtype=tf.int32))
4480 # Output is.
4481 # x.shape: (None, 1, 2)
4482 # shape of tf.squeeze(x, axis=known_axes): (None, 2)
4483 # shape of tf.squeeze(x): <unknown>
4484 ```
4486 Args:
4487 input: A `Tensor`. The `input` to squeeze.
4488 axis: An optional list of `ints`. Defaults to `[]`. If specified, only
4489 squeezes the dimensions listed. The dimension index starts at 0. It is an
4490 error to squeeze a dimension that is not 1. Must be in the range
4491 `[-rank(input), rank(input))`. Must be specified if `input` is a
4492 `RaggedTensor`.
4493 name: A name for the operation (optional).
4495 Returns:
4496 A `Tensor`. Has the same type as `input`.
4497 Contains the same data as `input`, but has one or more dimensions of
4498 size 1 removed.
4500 Raises:
4501 ValueError: The input cannot be converted to a tensor, or the specified
4502 axis cannot be squeezed.
4503 """
4504 # pylint: disable=redefined-builtin
4505 return squeeze(input, axis, name)
4508@tf_export(v1=["where"])
4509@dispatch.add_dispatch_support
4510def where(condition, x=None, y=None, name=None):
4511 """Return the elements, either from `x` or `y`, depending on the `condition`.
4513 If both `x` and `y` are None, then this operation returns the coordinates of
4514 true elements of `condition`. The coordinates are returned in a 2-D tensor
4515 where the first dimension (rows) represents the number of true elements, and
4516 the second dimension (columns) represents the coordinates of the true
4517 elements. Keep in mind, the shape of the output tensor can vary depending on
4518 how many true values there are in input. Indices are output in row-major
4519 order.
4521 If both non-None, `x` and `y` must have the same shape.
4522 The `condition` tensor must be a scalar if `x` and `y` are scalar.
4523 If `x` and `y` are tensors of higher rank, then `condition` must be either a
4524 vector with size matching the first dimension of `x`, or must have the same
4525 shape as `x`.
4527 The `condition` tensor acts as a mask that chooses, based on the value at each
4528 element, whether the corresponding element / row in the output should be taken
4529 from `x` (if true) or `y` (if false).
4531 If `condition` is a vector and `x` and `y` are higher rank matrices, then it
4532 chooses which row (outer dimension) to copy from `x` and `y`. If `condition`
4533 has the same shape as `x` and `y`, then it chooses which element to copy from
4534 `x` and `y`.
4536 Args:
4537 condition: A `Tensor` of type `bool`
4538 x: A Tensor which may have the same shape as `condition`. If `condition` is
4539 rank 1, `x` may have higher rank, but its first dimension must match the
4540 size of `condition`.
4541 y: A `tensor` with the same shape and type as `x`.
4542 name: A name of the operation (optional)
4544 Returns:
4545 A `Tensor` with the same type and shape as `x`, `y` if they are non-None.
4546 Otherwise, a `Tensor` with shape `(num_true, rank(condition))`.
4548 Raises:
4549 ValueError: When exactly one of `x` or `y` is non-None.
4551 @compatibility(TF2)
4553 This API is compatible with eager execution and `tf.function`. However, this
4554 is still a legacy API endpoint originally designed for TF1. To migrate to
4555 fully-native TF2, please replace its usage with `tf.where` instead, which is
4556 directly backwards compatible with `tf.compat.v1.where`.
4558 However,`tf.compat.v1.where` is more restrictive than `tf.where`, requiring
4559 `x` and `y` to have the same shape, and returning a `Tensor` with the same
4560 type and shape as `x`, `y` (if they are both non-None).
4562 `tf.where` will accept `x`, `y` that are not the same shape as long as they
4563 are broadcastable with one another and with `condition`, and will return a
4564 `Tensor` with shape broadcast from `condition`, `x`, and `y`.
4566 For example, the following works with `tf.where` but not `tf.compat.v1.where`:
4568 >>> tf.where([True, False, False, True], [1,2,3,4], [100])
4569 <tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 1, 100, 100, 4],
4570 dtype=int32)>
4572 >>> tf.where(True, [1,2,3,4], 100)
4573 <tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4],
4574 dtype=int32)>
4576 @end_compatibility
4577 """
4578 if x is None and y is None:
4579 with ops.name_scope(name, "Where", [condition]) as name:
4580 condition = ops.convert_to_tensor(
4581 condition, preferred_dtype=dtypes.bool, name="condition")
4582 return gen_array_ops.where(condition=condition, name=name)
4583 elif x is not None and y is not None:
4584 return gen_math_ops.select(condition=condition, x=x, y=y, name=name)
4585 else:
4586 raise ValueError("x and y must both be non-None or both be None.")
4589@tf_export("where", v1=["where_v2"])
4590@dispatch.add_dispatch_support
4591def where_v2(condition, x=None, y=None, name=None):
4592 """Returns the indices of non-zero elements, or multiplexes `x` and `y`.
4594 This operation has two modes:
4596 1. **Return the indices of non-zero elements** - When only
4597 `condition` is provided the result is an `int64` tensor where each row is
4598 the index of a non-zero element of `condition`. The result's shape
4599 is `[tf.math.count_nonzero(condition), tf.rank(condition)]`.
4600 2. **Multiplex `x` and `y`** - When both `x` and `y` are provided the
4601 result has the shape of `x`, `y`, and `condition` broadcast together. The
4602 result is taken from `x` where `condition` is non-zero
4603 or `y` where `condition` is zero.
4605 #### 1. Return the indices of non-zero elements
4607 Note: In this mode `condition` can have a dtype of `bool` or any numeric
4608 dtype.
4610 If `x` and `y` are not provided (both are None):
4612 `tf.where` will return the indices of `condition` that are non-zero,
4613 in the form of a 2-D tensor with shape `[n, d]`, where `n` is the number of
4614 non-zero elements in `condition` (`tf.count_nonzero(condition)`), and `d` is
4615 the number of axes of `condition` (`tf.rank(condition)`).
4617 Indices are output in row-major order. The `condition` can have a `dtype` of
4618 `tf.bool`, or any numeric `dtype`.
4620 Here `condition` is a 1-axis `bool` tensor with 2 `True` values. The result
4621 has a shape of `[2,1]`
4623 >>> tf.where([True, False, False, True]).numpy()
4624 array([[0],
4625 [3]])
4627 Here `condition` is a 2-axis integer tensor, with 3 non-zero values. The
4628 result has a shape of `[3, 2]`.
4630 >>> tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
4631 array([[0, 0],
4632 [1, 0],
4633 [1, 2]])
4635 Here `condition` is a 3-axis float tensor, with 5 non-zero values. The output
4636 shape is `[5, 3]`.
4638 >>> float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
4639 ... [[0, 0], [0, 0], [99, 0]]]
4640 >>> tf.where(float_tensor).numpy()
4641 array([[0, 0, 0],
4642 [0, 1, 1],
4643 [0, 2, 0],
4644 [0, 2, 1],
4645 [1, 2, 0]])
4647 These indices are the same that `tf.sparse.SparseTensor` would use to
4648 represent the condition tensor:
4650 >>> sparse = tf.sparse.from_dense(float_tensor)
4651 >>> sparse.indices.numpy()
4652 array([[0, 0, 0],
4653 [0, 1, 1],
4654 [0, 2, 0],
4655 [0, 2, 1],
4656 [1, 2, 0]])
4658 A complex number is considered non-zero if either the real or imaginary
4659 component is non-zero:
4661 >>> tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
4662 array([[1],
4663 [2],
4664 [3]])
4666 #### 2. Multiplex `x` and `y`
4668 Note: In this mode `condition` must have a dtype of `bool`.
4670 If `x` and `y` are also provided (both have non-None values) the `condition`
4671 tensor acts as a mask that chooses whether the corresponding
4672 element / row in the output should be taken from `x` (if the element in
4673 `condition` is `True`) or `y` (if it is `False`).
4675 The shape of the result is formed by
4676 [broadcasting](https://docs.scipy.org/doc/numpy/reference/ufuncs.html)
4677 together the shapes of `condition`, `x`, and `y`.
4679 When all three inputs have the same size, each is handled element-wise.
4681 >>> tf.where([True, False, False, True],
4682 ... [1, 2, 3, 4],
4683 ... [100, 200, 300, 400]).numpy()
4684 array([ 1, 200, 300, 4], dtype=int32)
4686 There are two main rules for broadcasting:
4688 1. If a tensor has fewer axes than the others, length-1 axes are added to the
4689 left of the shape.
4690 2. Axes with length-1 are streched to match the coresponding axes of the other
4691 tensors.
4693 A length-1 vector is streched to match the other vectors:
4695 >>> tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
4696 array([ 1, 100, 100, 4], dtype=int32)
4698 A scalar is expanded to match the other arguments:
4700 >>> tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
4701 array([[ 1, 100], [100, 4]], dtype=int32)
4702 >>> tf.where([[True, False], [False, True]], 1, 100).numpy()
4703 array([[ 1, 100], [100, 1]], dtype=int32)
4705 A scalar `condition` returns the complete `x` or `y` tensor, with
4706 broadcasting applied.
4708 >>> tf.where(True, [1, 2, 3, 4], 100).numpy()
4709 array([1, 2, 3, 4], dtype=int32)
4710 >>> tf.where(False, [1, 2, 3, 4], 100).numpy()
4711 array([100, 100, 100, 100], dtype=int32)
4713 For a non-trivial example of broadcasting, here `condition` has a shape of
4714 `[3]`, `x` has a shape of `[3,3]`, and `y` has a shape of `[3,1]`.
4715 Broadcasting first expands the shape of `condition` to `[1,3]`. The final
4716 broadcast shape is `[3,3]`. `condition` will select columns from `x` and `y`.
4717 Since `y` only has one column, all columns from `y` will be identical.
4719 >>> tf.where([True, False, True],
4720 ... x=[[1, 2, 3],
4721 ... [4, 5, 6],
4722 ... [7, 8, 9]],
4723 ... y=[[100],
4724 ... [200],
4725 ... [300]]
4726 ... ).numpy()
4727 array([[ 1, 100, 3],
4728 [ 4, 200, 6],
4729 [ 7, 300, 9]], dtype=int32)
4731 Note that if the gradient of either branch of the `tf.where` generates
4732 a `NaN`, then the gradient of the entire `tf.where` will be `NaN`. This is
4733 because the gradient calculation for `tf.where` combines the two branches, for
4734 performance reasons.
4736 A workaround is to use an inner `tf.where` to ensure the function has
4737 no asymptote, and to avoid computing a value whose gradient is `NaN` by
4738 replacing dangerous inputs with safe inputs.
4740 Instead of this,
4742 >>> x = tf.constant(0., dtype=tf.float32)
4743 >>> with tf.GradientTape() as tape:
4744 ... tape.watch(x)
4745 ... y = tf.where(x < 1., 0., 1. / x)
4746 >>> print(tape.gradient(y, x))
4747 tf.Tensor(nan, shape=(), dtype=float32)
4749 Although, the `1. / x` values are never used, its gradient is a `NaN` when
4750 `x = 0`. Instead, we should guard that with another `tf.where`
4752 >>> x = tf.constant(0., dtype=tf.float32)
4753 >>> with tf.GradientTape() as tape:
4754 ... tape.watch(x)
4755 ... safe_x = tf.where(tf.equal(x, 0.), 1., x)
4756 ... y = tf.where(x < 1., 0., 1. / safe_x)
4757 >>> print(tape.gradient(y, x))
4758 tf.Tensor(0.0, shape=(), dtype=float32)
4760 See also:
4762 * `tf.sparse` - The indices returned by the first form of `tf.where` can be
4763 useful in `tf.sparse.SparseTensor` objects.
4764 * `tf.gather_nd`, `tf.scatter_nd`, and related ops - Given the
4765 list of indices returned from `tf.where` the `scatter` and `gather` family
4766 of ops can be used fetch values or insert values at those indices.
4767 * `tf.strings.length` - `tf.string` is not an allowed dtype for the
4768 `condition`. Use the string length instead.
4770 Args:
4771 condition: A `tf.Tensor` of dtype bool, or any numeric dtype. `condition`
4772 must have dtype `bool` when `x` and `y` are provided.
4773 x: If provided, a Tensor which is of the same type as `y`, and has a shape
4774 broadcastable with `condition` and `y`.
4775 y: If provided, a Tensor which is of the same type as `x`, and has a shape
4776 broadcastable with `condition` and `x`.
4777 name: A name of the operation (optional).
4779 Returns:
4780 If `x` and `y` are provided:
4781 A `Tensor` with the same type as `x` and `y`, and shape that
4782 is broadcast from `condition`, `x`, and `y`.
4783 Otherwise, a `Tensor` with shape `[tf.math.count_nonzero(condition),
4784 tf.rank(condition)]`.
4786 Raises:
4787 ValueError: When exactly one of `x` or `y` is non-None, or the shapes
4788 are not all broadcastable.
4789 """
4790 if x is None and y is None:
4791 with ops.name_scope(name, "Where", [condition]) as name:
4792 condition = ops.convert_to_tensor(
4793 condition, preferred_dtype=dtypes.bool, name="condition")
4794 return gen_array_ops.where(condition=condition, name=name)
4795 elif x is not None and y is not None:
4796 return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name)
4797 else:
4798 raise ValueError("x and y must both be non-None or both be None.")
4801# pylint: disable=redefined-builtin
4802@tf_export(v1=["reverse_sequence"])
4803@deprecation.deprecated_args(None,
4804 "seq_dim is deprecated, use seq_axis instead",
4805 "seq_dim")
4806@deprecation.deprecated_args(None,
4807 "batch_dim is deprecated, use batch_axis instead",
4808 "batch_dim")
4809def reverse_sequence(input,
4810 seq_lengths,
4811 seq_axis=None,
4812 batch_axis=None,
4813 name=None,
4814 seq_dim=None,
4815 batch_dim=None):
4816 """Reverses variable length slices.
4818 This op first slices `input` along the dimension `batch_axis`, and for
4819 each slice `i`, reverses the first `seq_lengths[i]` elements along the
4820 dimension `seq_axis`.
4822 The elements of `seq_lengths` must obey `seq_lengths[i] <=
4823 input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
4824 `input.dims[batch_axis]`.
4826 The output slice `i` along dimension `batch_axis` is then given by
4827 input slice `i`, with the first `seq_lengths[i]` slices along
4828 dimension `seq_axis` reversed.
4830 Example usage:
4832 >>> seq_lengths = [7, 2, 3, 5]
4833 >>> input = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
4834 ... [1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
4835 >>> output = tf.reverse_sequence(input, seq_lengths, seq_axis=1, batch_axis=0)
4836 >>> output
4837 <tf.Tensor: shape=(4, 8), dtype=int32, numpy=
4838 array([[0, 0, 5, 4, 3, 2, 1, 0],
4839 [2, 1, 0, 0, 0, 0, 0, 0],
4840 [3, 2, 1, 4, 0, 0, 0, 0],
4841 [5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>
4843 Args:
4844 input: A `Tensor`. The input to reverse.
4845 seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
4846 `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
4847 input.dims(seq_axis)`
4848 seq_axis: An `int`. The dimension which is partially reversed.
4849 batch_axis: An optional `int`. Defaults to `0`. The dimension along which
4850 reversal is performed.
4851 name: A name for the operation (optional).
4853 Returns:
4854 A Tensor. Has the same type as input.
4855 """
4856 seq_axis = deprecation.deprecated_argument_lookup("seq_axis", seq_axis,
4857 "seq_dim", seq_dim)
4858 batch_axis = deprecation.deprecated_argument_lookup("batch_axis", batch_axis,
4859 "batch_dim", batch_dim)
4860 return gen_array_ops.reverse_sequence(
4861 input=input,
4862 seq_lengths=seq_lengths,
4863 seq_dim=seq_axis,
4864 batch_dim=batch_axis,
4865 name=name)
4868@tf_export("reverse_sequence", v1=[])
4869@dispatch.add_dispatch_support
4870def reverse_sequence_v2(input,
4871 seq_lengths,
4872 seq_axis=None,
4873 batch_axis=None,
4874 name=None):
4875 """Reverses variable length slices.
4877 This op first slices `input` along the dimension `batch_axis`, and for
4878 each slice `i`, reverses the first `seq_lengths[i]` elements along the
4879 dimension `seq_axis`.
4881 The elements of `seq_lengths` must obey `seq_lengths[i] <=
4882 input.dims[seq_axis]`, and `seq_lengths` must be a vector of length
4883 `input.dims[batch_axis]`.
4885 The output slice `i` along dimension `batch_axis` is then given by
4886 input slice `i`, with the first `seq_lengths[i]` slices along
4887 dimension `seq_axis` reversed.
4889 Example usage:
4891 >>> seq_lengths = [7, 2, 3, 5]
4892 >>> input = [[1, 2, 3, 4, 5, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0],
4893 ... [1, 2, 3, 4, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
4894 >>> output = tf.reverse_sequence(input, seq_lengths, seq_axis=1, batch_axis=0)
4895 >>> output
4896 <tf.Tensor: shape=(4, 8), dtype=int32, numpy=
4897 array([[0, 0, 5, 4, 3, 2, 1, 0],
4898 [2, 1, 0, 0, 0, 0, 0, 0],
4899 [3, 2, 1, 4, 0, 0, 0, 0],
4900 [5, 4, 3, 2, 1, 6, 7, 8]], dtype=int32)>
4902 Args:
4903 input: A `Tensor`. The input to reverse.
4904 seq_lengths: A `Tensor`. Must be one of the following types: `int32`,
4905 `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <=
4906 input.dims(seq_axis)`
4907 seq_axis: An `int`. The dimension which is partially reversed.
4908 batch_axis: An optional `int`. Defaults to `0`. The dimension along which
4909 reversal is performed.
4910 name: A name for the operation (optional).
4912 Returns:
4913 A Tensor. Has the same type as input.
4914 """
4915 return gen_array_ops.reverse_sequence(
4916 input=input,
4917 seq_lengths=seq_lengths,
4918 seq_dim=seq_axis,
4919 batch_dim=batch_axis,
4920 name=name)
4922# pylint: enable=redefined-builtin
4925@tf_export(v1=["gather"])
4926@dispatch.add_dispatch_support
4927@deprecation.deprecated_args(None,
4928 ("The `validate_indices` argument has no effect. "
4929 "Indices are always validated on CPU and never "
4930 "validated on GPU."),
4931 ("validate_indices", None))
4932def gather(params,
4933 indices,
4934 validate_indices=None,
4935 name=None,
4936 axis=None,
4937 batch_dims=0): # pylint: disable=g-doc-args
4938 r"""Gather slices from params axis `axis` according to indices.
4940 Gather slices from `params` axis `axis` according to `indices`. `indices`
4941 must be an integer tensor of any dimension (often 1-D).
4943 `Tensor.__getitem__` works for scalars, `tf.newaxis`, and
4944 [python slices](https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing)
4946 `tf.gather` extends indexing to handle tensors of indices.
4948 In the simplest case it's identical to scalar indexing:
4950 >>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
4951 >>> params[3].numpy()
4952 b'p3'
4953 >>> tf.gather(params, 3).numpy()
4954 b'p3'
4956 The most common case is to pass a single axis tensor of indices (this
4957 can't be expressed as a python slice because the indices are not sequential):
4959 >>> indices = [2, 0, 2, 5]
4960 >>> tf.gather(params, indices).numpy()
4961 array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
4963 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
4964 <img style="width:100%" src="https://www.tensorflow.org/images/Gather.png"
4965 alt>
4966 </div>
4968 The indices can have any shape. When the `params` has 1 axis, the
4969 output shape is equal to the input shape:
4971 >>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
4972 array([[b'p2', b'p0'],
4973 [b'p2', b'p5']], dtype=object)
4975 The `params` may also have any shape. `gather` can select slices
4976 across any axis depending on the `axis` argument (which defaults to 0).
4977 Below it is used to gather first rows, then columns from a matrix:
4979 >>> params = tf.constant([[0, 1.0, 2.0],
4980 ... [10.0, 11.0, 12.0],
4981 ... [20.0, 21.0, 22.0],
4982 ... [30.0, 31.0, 32.0]])
4983 >>> tf.gather(params, indices=[3,1]).numpy()
4984 array([[30., 31., 32.],
4985 [10., 11., 12.]], dtype=float32)
4986 >>> tf.gather(params, indices=[2,1], axis=1).numpy()
4987 array([[ 2., 1.],
4988 [12., 11.],
4989 [22., 21.],
4990 [32., 31.]], dtype=float32)
4992 More generally: The output shape has the same shape as the input, with the
4993 indexed-axis replaced by the shape of the indices.
4995 >>> def result_shape(p_shape, i_shape, axis=0):
4996 ... return p_shape[:axis] + i_shape + p_shape[axis+1:]
4997 >>>
4998 >>> result_shape([1, 2, 3], [], axis=1)
4999 [1, 3]
5000 >>> result_shape([1, 2, 3], [7], axis=1)
5001 [1, 7, 3]
5002 >>> result_shape([1, 2, 3], [7, 5], axis=1)
5003 [1, 7, 5, 3]
5005 Here are some examples:
5007 >>> params.shape.as_list()
5008 [4, 3]
5009 >>> indices = tf.constant([[0, 2]])
5010 >>> tf.gather(params, indices=indices, axis=0).shape.as_list()
5011 [1, 2, 3]
5012 >>> tf.gather(params, indices=indices, axis=1).shape.as_list()
5013 [4, 1, 2]
5015 >>> params = tf.random.normal(shape=(5, 6, 7, 8))
5016 >>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
5017 >>> result = tf.gather(params, indices, axis=2)
5018 >>> result.shape.as_list()
5019 [5, 6, 10, 11, 8]
5021 This is because each index takes a slice from `params`, and
5022 places it at the corresponding location in the output. For the above example
5024 >>> # For any location in indices
5025 >>> a, b = 0, 1
5026 >>> tf.reduce_all(
5027 ... # the corresponding slice of the result
5028 ... result[:, :, a, b, :] ==
5029 ... # is equal to the slice of `params` along `axis` at the index.
5030 ... params[:, :, indices[a, b], :]
5031 ... ).numpy()
5032 True
5034 ### Batching:
5036 The `batch_dims` argument lets you gather different items from each element
5037 of a batch.
5039 Using `batch_dims=1` is equivalent to having an outer loop over the first
5040 axis of `params` and `indices`:
5042 >>> params = tf.constant([
5043 ... [0, 0, 1, 0, 2],
5044 ... [3, 0, 0, 0, 4],
5045 ... [0, 5, 0, 6, 0]])
5046 >>> indices = tf.constant([
5047 ... [2, 4],
5048 ... [0, 4],
5049 ... [1, 3]])
5051 >>> tf.gather(params, indices, axis=1, batch_dims=1).numpy()
5052 array([[1, 2],
5053 [3, 4],
5054 [5, 6]], dtype=int32)
5056 This is equivalent to:
5058 >>> def manually_batched_gather(params, indices, axis):
5059 ... batch_dims=1
5060 ... result = []
5061 ... for p,i in zip(params, indices):
5062 ... r = tf.gather(p, i, axis=axis-batch_dims)
5063 ... result.append(r)
5064 ... return tf.stack(result)
5065 >>> manually_batched_gather(params, indices, axis=1).numpy()
5066 array([[1, 2],
5067 [3, 4],
5068 [5, 6]], dtype=int32)
5070 Higher values of `batch_dims` are equivalent to multiple nested loops over
5071 the outer axes of `params` and `indices`. So the overall shape function is
5073 >>> def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
5074 ... return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis+1:]
5075 >>>
5076 >>> batched_result_shape(
5077 ... p_shape=params.shape.as_list(),
5078 ... i_shape=indices.shape.as_list(),
5079 ... axis=1,
5080 ... batch_dims=1)
5081 [3, 2]
5083 >>> tf.gather(params, indices, axis=1, batch_dims=1).shape.as_list()
5084 [3, 2]
5086 This comes up naturally if you need to use the indices of an operation like
5087 `tf.argsort`, or `tf.math.top_k` where the last dimension of the indices
5088 indexes into the last dimension of input, at the corresponding location.
5089 In this case you can use `tf.gather(values, indices, batch_dims=-1)`.
5091 See also:
5093 * `tf.Tensor.__getitem__`: The direct tensor index operation (`t[]`), handles
5094 scalars and python-slices `tensor[..., 7, 1:-1]`
5095 * `tf.scatter`: A collection of operations similar to `__setitem__`
5096 (`t[i] = x`)
5097 * `tf.gather_nd`: An operation similar to `tf.gather` but gathers across
5098 multiple axis at once (it can gather elements of a matrix instead of rows
5099 or columns)
5100 * `tf.boolean_mask`, `tf.where`: Binary indexing.
5101 * `tf.slice` and `tf.strided_slice`: For lower level access to the
5102 implementation of `__getitem__`'s python-slice handling (`t[1:-1:2]`)
5104 Args:
5105 params: The `Tensor` from which to gather values. Must be at least rank
5106 `axis + 1`.
5107 indices: The index `Tensor`. Must be one of the following types: `int32`,
5108 `int64`. The values must be in range `[0, params.shape[axis])`.
5109 validate_indices: Deprecated, does nothing. Indices are always validated on
5110 CPU, never validated on GPU.
5112 Caution: On CPU, if an out of bound index is found, an error is raised.
5113 On GPU, if an out of bound index is found, a 0 is stored in the
5114 corresponding output value.
5115 axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
5116 `axis` in `params` to gather `indices` from. Must be greater than or equal
5117 to `batch_dims`. Defaults to the first non-batch dimension. Supports
5118 negative indexes.
5119 batch_dims: An `integer`. The number of batch dimensions. Must be less
5120 than or equal to `rank(indices)`.
5121 name: A name for the operation (optional).
5123 Returns:
5124 A `Tensor`. Has the same type as `params`.
5125 """
5126 del validate_indices
5128 if axis is None:
5129 axis = batch_dims
5130 if tensor_util.constant_value(axis) != 0:
5131 return gen_array_ops.gather_v2(
5132 params, indices, axis, batch_dims=batch_dims, name=name)
5133 try:
5134 # TODO(apassos) find a less bad way of detecting resource variables
5135 # without introducing a circular dependency.
5136 return params.sparse_read(indices, name=name)
5137 except AttributeError:
5138 return gen_array_ops.gather_v2(params, indices, axis, name=name)
5141@tf_export("gather", v1=[])
5142@dispatch.add_dispatch_support
5143def gather_v2(params,
5144 indices,
5145 validate_indices=None,
5146 axis=None,
5147 batch_dims=0,
5148 name=None):
5149 return gather(
5150 params,
5151 indices,
5152 validate_indices=validate_indices,
5153 name=name,
5154 axis=axis,
5155 batch_dims=batch_dims)
5158gather_v2.__doc__ = gather.__doc__
5161@tf_export(v1=["batch_gather"])
5162@dispatch.add_dispatch_support
5163@deprecation.deprecated(
5164 "2017-10-25", "`tf.batch_gather` is deprecated, please use `tf.gather` "
5165 "with `batch_dims=tf.rank(indices) - 1` instead.") # pylint: disable=missing-docstring
5166def batch_gather(params, indices, name=None):
5167 """Gather slices from params according to indices with leading batch dims."""
5168 with ops.name_scope(name, "BatchGather", [params, indices]):
5169 indices = ops.convert_to_tensor(indices, name="indices")
5170 params = ops.convert_to_tensor(params, name="params")
5171 if indices.shape.ndims is None:
5172 raise ValueError(
5173 "batch_gather does not allow indices with unknown shape.")
5174 return _batch_gather(params, indices, batch_dims=indices.shape.ndims - 1)
5177def _batch_gather(params, indices, batch_dims, axis=None):
5178 r"""Gather slices from params according to indices with leading batch dims.
5180 This operation assumes that the leading `batch_dims` dimensions of `indices`
5181 and `params` are batch dimensions; and performs a `tf.gather` operation within
5182 each batch. (If `batch_dims` is not specified, then it defaults to
5183 `rank(indices)-1`.) In the case in which `batch_dims==0`, this operation
5184 is equivalent to `tf.gather`.
5186 Args:
5187 params: A Tensor. The tensor from which to gather values.
5188 indices: A Tensor. Must be one of the following types: int32, int64. Index
5189 tensor. Must be in range `[0, params.shape[batch_dims]]`.
5190 batch_dims: An integer or none. The number of batch dimensions. Must be
5191 less than `rank(indices)`. Defaults to `rank(indices) - 1` if None.
5192 axis: A `Tensor`. Must be one of the following types: `int32`, `int64`. The
5193 `axis` in `params` to gather `indices` from. Must be greater than or equal
5194 to `batch_dims`. Defaults to the first non-batch dimension. Supports
5195 negative indexes.
5197 Returns:
5198 A Tensor. Has the same type as `params`.
5200 Raises:
5201 ValueError: if `indices` has an unknown shape.
5202 """
5203 if batch_dims is not None and not isinstance(batch_dims, int):
5204 raise TypeError("Argument `batch_dims` must be an int. "
5205 f"Received `batch_dims` = {batch_dims} instead")
5206 indices = ops.convert_to_tensor(indices, name="indices")
5207 params = ops.convert_to_tensor(params, name="params")
5209 indices_ndims = indices.shape.ndims
5210 if indices_ndims is None:
5211 raise ValueError("tf.gather does not allow indices with unknown "
5212 "rank when batch_dims is specified.")
5213 if batch_dims is None:
5214 batch_dims = indices_ndims - 1
5215 if batch_dims < 0:
5216 batch_dims += indices_ndims
5217 if batch_dims < 0 or batch_dims >= indices_ndims:
5218 raise ValueError(f"Argument `batch_dims` = {batch_dims} must be less than "
5219 f"rank(`indices`) = {indices_ndims}")
5220 if params.shape.ndims is not None and batch_dims >= params.shape.ndims:
5221 raise ValueError(f"Argument `batch_dims` = {batch_dims} must be less than "
5222 f"rank(`params`) = {params.shape.ndims}")
5224 # Handle axis by transposing the axis dimension to be the first non-batch
5225 # dimension, recursively calling batch_gather with axis=0, and then
5226 # transposing the result to put the pre-axis dimensions before the indices
5227 # dimensions.
5228 if axis is not None and axis != batch_dims:
5229 # Adjust axis to be positive.
5230 if not isinstance(axis, int):
5231 axis = tf.where(axis < 0, axis + array_ops.rank(params), axis)
5232 elif axis < 0 and params.shape.ndims is None:
5233 axis = axis + array_ops.rank(params)
5234 else:
5235 if (axis < -params.shape.ndims) or (axis >= params.shape.ndims):
5236 raise ValueError(f"Argument `axis` = {axis} out of range "
5237 f"[{-params.shape.ndims}, {params.shape.ndims})")
5238 if axis < 0:
5239 axis += params.shape.ndims
5240 if axis < batch_dims:
5241 raise ValueError(f"Argument `batch_dims` = {batch_dims} must be less "
5242 f"than or equal to argument `axis` = {axis}")
5244 # Move params[axis] up to params[batch_dims].
5245 perm = [
5246 list(range(batch_dims)), [axis],
5247 gen_math_ops._range(batch_dims, axis, 1),
5248 gen_math_ops._range(axis + 1, rank(params), 1)
5249 ]
5250 params = transpose(params, concat(perm, axis=0))
5252 result = _batch_gather(params, indices, batch_dims=batch_dims)
5254 # Move the result dimensions corresponding to params[batch_dims:axis]
5255 # to just before the dimensions corresponding to indices[batch_dims:].
5256 params_start = indices_ndims + axis - batch_dims
5257 perm = [
5258 list(range(batch_dims)),
5259 gen_math_ops._range(indices_ndims, params_start, 1),
5260 list(range(batch_dims, indices_ndims)),
5261 gen_math_ops._range(params_start, rank(result), 1)
5262 ]
5263 return transpose(result, perm=concat(perm, axis=0))
5265 indices_shape = shape(indices)
5266 params_shape = shape(params)
5267 batch_indices = indices
5268 indices_dtype = indices.dtype.base_dtype
5269 accum_dim_value = ones((), dtype=indices_dtype)
5270 # Use correct type for offset index computation
5271 casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype)
5272 for dim in range(batch_dims, 0, -1):
5273 dim_value = casted_params_shape[dim - 1]
5274 accum_dim_value *= casted_params_shape[dim]
5275 start = zeros((), dtype=indices_dtype)
5276 step = ones((), dtype=indices_dtype)
5277 dim_indices = gen_math_ops._range(start, dim_value, step)
5278 dim_indices *= accum_dim_value
5279 dim_shape = array_ops_stack.stack(
5280 [1] * (dim - 1) + [dim_value] + [1] * (indices_ndims - dim), axis=0)
5281 batch_indices += reshape(dim_indices, dim_shape)
5283 flat_indices = reshape(batch_indices, [-1])
5284 outer_shape = params_shape[batch_dims + 1:]
5285 flat_inner_shape = gen_math_ops.prod(params_shape[:batch_dims + 1], [0],
5286 False)
5288 flat_params = reshape(params, concat([[flat_inner_shape], outer_shape],
5289 axis=0))
5290 flat_result = gather(flat_params, flat_indices)
5291 result = reshape(flat_result, concat([indices_shape, outer_shape], axis=0))
5292 final_shape = indices.get_shape()[:batch_dims].merge_with(
5293 params.get_shape()[:batch_dims])
5294 final_shape = final_shape.concatenate(indices.get_shape().dims[batch_dims:])
5295 final_shape = final_shape.concatenate(params.get_shape()[batch_dims + 1:])
5296 result.set_shape(final_shape)
5297 return result
5300@tf_export(v1=["gather_nd", "manip.gather_nd"])
5301@dispatch.add_dispatch_support
5302@deprecated_endpoints("manip.gather_nd")
5303def gather_nd(params, indices, name=None, batch_dims=0):
5304 r"""Gather slices from `params` into a Tensor with shape specified by `indices`.
5306 `indices` is a `Tensor` of indices into `params`. The index vectors are
5307 arranged along the last axis of `indices`.
5309 This is similar to `tf.gather`, in which `indices` defines slices into the
5310 first dimension of `params`. In `tf.gather_nd`, `indices` defines slices into
5311 the first `N` dimensions of `params`, where `N = indices.shape[-1]`.
5313 Caution: On CPU, if an out of bound index is found, an error is returned.
5314 On GPU, if an out of bound index is found, a 0 is stored in the
5315 corresponding output value.
5317 ## Gathering scalars
5319 In the simplest case the vectors in `indices` index the full rank of `params`:
5321 >>> tf.gather_nd(
5322 ... indices=[[0, 0],
5323 ... [1, 1]],
5324 ... params = [['a', 'b'],
5325 ... ['c', 'd']]).numpy()
5326 array([b'a', b'd'], dtype=object)
5328 In this case the result has 1-axis fewer than `indices`, and each index vector
5329 is replaced by the scalar indexed from `params`.
5331 In this case the shape relationship is:
5333 ```
5334 index_depth = indices.shape[-1]
5335 assert index_depth == params.shape.rank
5336 result_shape = indices.shape[:-1]
5337 ```
5339 If `indices` has a rank of `K`, it is helpful to think `indices` as a
5340 (K-1)-dimensional tensor of indices into `params`.
5342 ## Gathering slices
5344 If the index vectors do not index the full rank of `params` then each location
5345 in the result contains a slice of params. This example collects rows from a
5346 matrix:
5348 >>> tf.gather_nd(
5349 ... indices = [[1],
5350 ... [0]],
5351 ... params = [['a', 'b', 'c'],
5352 ... ['d', 'e', 'f']]).numpy()
5353 array([[b'd', b'e', b'f'],
5354 [b'a', b'b', b'c']], dtype=object)
5356 Here `indices` contains `[2]` index vectors, each with a length of `1`.
5357 The index vectors each refer to rows of the `params` matrix. Each
5358 row has a shape of `[3]` so the output shape is `[2, 3]`.
5360 In this case, the relationship between the shapes is:
5362 ```
5363 index_depth = indices.shape[-1]
5364 outer_shape = indices.shape[:-1]
5365 assert index_depth <= params.shape.rank
5366 inner_shape = params.shape[index_depth:]
5367 output_shape = outer_shape + inner_shape
5368 ```
5370 It is helpful to think of the results in this case as tensors-of-tensors.
5371 The shape of the outer tensor is set by the leading dimensions of `indices`.
5372 While the shape of the inner tensors is the shape of a single slice.
5374 ## Batches
5376 Additionally, both `params` and `indices` can have `M` leading batch
5377 dimensions that exactly match. In this case `batch_dims` must be set to `M`.
5379 For example, to collect one row from each of a batch of matrices you could
5380 set the leading elements of the index vectors to be their location in the
5381 batch:
5383 >>> tf.gather_nd(
5384 ... indices = [[0, 1],
5385 ... [1, 0],
5386 ... [2, 4],
5387 ... [3, 2],
5388 ... [4, 1]],
5389 ... params=tf.zeros([5, 7, 3])).shape.as_list()
5390 [5, 3]
5392 The `batch_dims` argument lets you omit those leading location dimensions
5393 from the index:
5395 >>> tf.gather_nd(
5396 ... batch_dims=1,
5397 ... indices = [[1],
5398 ... [0],
5399 ... [4],
5400 ... [2],
5401 ... [1]],
5402 ... params=tf.zeros([5, 7, 3])).shape.as_list()
5403 [5, 3]
5405 This is equivalent to caling a separate `gather_nd` for each location in the
5406 batch dimensions.
5409 >>> params=tf.zeros([5, 7, 3])
5410 >>> indices=tf.zeros([5, 1])
5411 >>> batch_dims = 1
5412 >>>
5413 >>> index_depth = indices.shape[-1]
5414 >>> batch_shape = indices.shape[:batch_dims]
5415 >>> assert params.shape[:batch_dims] == batch_shape
5416 >>> outer_shape = indices.shape[batch_dims:-1]
5417 >>> assert index_depth <= params.shape.rank
5418 >>> inner_shape = params.shape[batch_dims + index_depth:]
5419 >>> output_shape = batch_shape + outer_shape + inner_shape
5420 >>> output_shape.as_list()
5421 [5, 3]
5423 ### More examples
5425 Indexing into a 3-tensor:
5427 >>> tf.gather_nd(
5428 ... indices = [[1]],
5429 ... params = [[['a0', 'b0'], ['c0', 'd0']],
5430 ... [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5431 array([[[b'a1', b'b1'],
5432 [b'c1', b'd1']]], dtype=object)
5436 >>> tf.gather_nd(
5437 ... indices = [[0, 1], [1, 0]],
5438 ... params = [[['a0', 'b0'], ['c0', 'd0']],
5439 ... [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5440 array([[b'c0', b'd0'],
5441 [b'a1', b'b1']], dtype=object)
5444 >>> tf.gather_nd(
5445 ... indices = [[0, 0, 1], [1, 0, 1]],
5446 ... params = [[['a0', 'b0'], ['c0', 'd0']],
5447 ... [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5448 array([b'b0', b'b1'], dtype=object)
5450 The examples below are for the case when only indices have leading extra
5451 dimensions. If both 'params' and 'indices' have leading batch dimensions, use
5452 the 'batch_dims' parameter to run gather_nd in batch mode.
5454 Batched indexing into a matrix:
5456 >>> tf.gather_nd(
5457 ... indices = [[[0, 0]], [[0, 1]]],
5458 ... params = [['a', 'b'], ['c', 'd']]).numpy()
5459 array([[b'a'],
5460 [b'b']], dtype=object)
5464 Batched slice indexing into a matrix:
5466 >>> tf.gather_nd(
5467 ... indices = [[[1]], [[0]]],
5468 ... params = [['a', 'b'], ['c', 'd']]).numpy()
5469 array([[[b'c', b'd']],
5470 [[b'a', b'b']]], dtype=object)
5473 Batched indexing into a 3-tensor:
5475 >>> tf.gather_nd(
5476 ... indices = [[[1]], [[0]]],
5477 ... params = [[['a0', 'b0'], ['c0', 'd0']],
5478 ... [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5479 array([[[[b'a1', b'b1'],
5480 [b'c1', b'd1']]],
5481 [[[b'a0', b'b0'],
5482 [b'c0', b'd0']]]], dtype=object)
5485 >>> tf.gather_nd(
5486 ... indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
5487 ... params = [[['a0', 'b0'], ['c0', 'd0']],
5488 ... [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5489 array([[[b'c0', b'd0'],
5490 [b'a1', b'b1']],
5491 [[b'a0', b'b0'],
5492 [b'c1', b'd1']]], dtype=object)
5494 >>> tf.gather_nd(
5495 ... indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]],
5496 ... params = [[['a0', 'b0'], ['c0', 'd0']],
5497 ... [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5498 array([[b'b0', b'b1'],
5499 [b'd0', b'c1']], dtype=object)
5502 Examples with batched 'params' and 'indices':
5504 >>> tf.gather_nd(
5505 ... batch_dims = 1,
5506 ... indices = [[1],
5507 ... [0]],
5508 ... params = [[['a0', 'b0'],
5509 ... ['c0', 'd0']],
5510 ... [['a1', 'b1'],
5511 ... ['c1', 'd1']]]).numpy()
5512 array([[b'c0', b'd0'],
5513 [b'a1', b'b1']], dtype=object)
5516 >>> tf.gather_nd(
5517 ... batch_dims = 1,
5518 ... indices = [[[1]], [[0]]],
5519 ... params = [[['a0', 'b0'], ['c0', 'd0']],
5520 ... [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5521 array([[[b'c0', b'd0']],
5522 [[b'a1', b'b1']]], dtype=object)
5524 >>> tf.gather_nd(
5525 ... batch_dims = 1,
5526 ... indices = [[[1, 0]], [[0, 1]]],
5527 ... params = [[['a0', 'b0'], ['c0', 'd0']],
5528 ... [['a1', 'b1'], ['c1', 'd1']]]).numpy()
5529 array([[b'c0'],
5530 [b'b1']], dtype=object)
5533 See also `tf.gather`.
5535 Args:
5536 params: A `Tensor`. The tensor from which to gather values.
5537 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
5538 Index tensor.
5539 name: A name for the operation (optional).
5540 batch_dims: An integer or a scalar 'Tensor'. The number of batch dimensions.
5542 Returns:
5543 A `Tensor`. Has the same type as `params`.
5544 """
5545 batch_dims_ = tensor_util.constant_value(batch_dims)
5546 if batch_dims_ is not None:
5547 batch_dims = int(batch_dims_)
5548 if batch_dims == 0:
5549 try:
5550 # TODO(apassos) find a less bad way of detecting resource variables
5551 # without introducing a circular dependency.
5552 return params.gather_nd(indices, name=name)
5553 except AttributeError:
5554 return gen_array_ops.gather_nd(params, indices, name=name)
5555 else:
5556 return batch_gather_nd(params, indices, batch_dims=batch_dims, name=name)
5559@tf_export("gather_nd", v1=[])
5560@dispatch.add_dispatch_support
5561def gather_nd_v2(params, indices, batch_dims=0, name=None):
5562 return gather_nd(params, indices, name=name, batch_dims=batch_dims)
5565gather_nd_v2.__doc__ = gather_nd.__doc__
5568def batch_gather_nd(params, indices, batch_dims, name=None):
5569 """gather_nd implementation with batch support."""
5570 with ops.name_scope(name, "BatchGatherND", [params, indices]):
5571 indices = ops.convert_to_tensor(indices, name="indices")
5572 params = ops.convert_to_tensor(params, name="params")
5574 if not isinstance(batch_dims, int):
5575 raise TypeError(f"Argument `batch_dims` must be an int; got {batch_dims}")
5576 if batch_dims < 0:
5577 raise ValueError("tf.gather_nd does not allow negative batch_dims.")
5578 params_ndims = params.shape.ndims
5579 indices_ndims = indices.shape.ndims
5580 if indices_ndims is not None and batch_dims >= indices_ndims:
5581 raise ValueError(f"Argument `batch_dims` = {batch_dims} must be "
5582 f"less than rank(`indices`) = {indices_ndims}")
5583 if params_ndims is not None and batch_dims >= params_ndims:
5584 raise ValueError(f"Argument `batch_dims` = {batch_dims} must be "
5585 f"less than rank(`params`) = {params_ndims}")
5587 expand = batch_dims == 0
5588 if expand:
5589 # Normally gather_nd will be called when batch_dims == 0.
5590 # But if this function is called with batch_dims = 0, e.g. for testing
5591 # purposes, this adds a dummy batch dimension to make batch_dims = 1.
5592 params = expand_dims(params, axis=0)
5593 indices = expand_dims(indices, axis=0)
5594 batch_dims = 1
5596 params_shape = shape(params)
5597 indices_shape = shape(indices)
5598 batch_shape = params_shape[:batch_dims]
5599 batch_size = gen_math_ops.prod(batch_shape, [0])
5600 index_internal_ndims = rank(indices) - batch_dims - 1
5601 indices_internal_shape = indices_shape[batch_dims:-1]
5603 # Assuming a 'params' with shape [b1, ..., bM, g1, ..., gN] and an 'indices'
5604 # with shape [b1, ..., bM, i1, ..., iK, C], where C <= N, we need to modify
5605 # 'indices' s.t. it has shape [i1, ..., iK, D], where D <= M + N and slices
5606 # to the entire 'params' tensor.
5607 # Assuming we have a batch of shape [B1, B2], we use meshgrid to create a
5608 # grid of size B1 x B2.
5609 batch_dim_list = array_ops_stack.unstack(batch_shape, axis=0)
5610 dim_ranges = [
5611 gen_math_ops.cast(gen_math_ops._range(0, x, 1), indices.dtype)
5612 for x in batch_dim_list
5613 ]
5614 mesh_list = meshgrid(*dim_ranges, indexing="ij") if dim_ranges else []
5615 # Then we flatten and stack the tensors to form a (B1.B2) by 2 matrix.
5616 flat_list = [reshape(x, shape=(-1,)) for x in mesh_list]
5617 index_grid = transpose(array_ops_stack.stack(flat_list, axis=0))
5618 # We need to concatenate these batch coordinates with the internal indices.
5619 # concat -> index_grid [B1.B2, 2] with indices [i1, ..., iK, C]
5620 # So we reshape them both to [(B1.B2), i1, ..., iK, *]
5621 index_grid_shape = shape(index_grid)
5622 index_grid = reshape(
5623 index_grid,
5624 concat([
5625 index_grid_shape[:1],
5626 ones(index_internal_ndims, dtype=dtypes.int32), index_grid_shape[1:]
5627 ],
5628 axis=0))
5629 tile_shape = concat(((1,), indices_internal_shape, (1,)), axis=0)
5630 index_grid = tile(index_grid, multiples=tile_shape)
5631 # index_grid now has shape [(B1.B2), i1, ..., iK, 2]
5632 flat_shape = concat(([batch_size], indices_shape[batch_dims:]), axis=0)
5633 flat_indices = reshape(indices, shape=flat_shape)
5634 # flat_indices now has shape [(B1.B2), i1, ..., iK, C]
5635 indices = concat((index_grid, flat_indices), axis=-1)
5636 # indices has shape [(B1.B2), i1, ..., iK, 2+C]
5637 out = gen_array_ops.gather_nd(params, indices)
5638 # out has shape [(B1.B2), i1, ..., iK, N-C]. Now we reshape batch to
5639 # its original form.
5640 out_shape = shape(out)
5641 out = reshape(out, shape=concat((batch_shape, out_shape[1:]), axis=0))
5642 if expand:
5643 out = squeeze(out, axis=0)
5644 return out
5647@deprecation.deprecated_endpoints("tensor_scatter_update")
5648@tf_export(
5649 "tensor_scatter_nd_update",
5650 v1=["tensor_scatter_nd_update", "tensor_scatter_update"])
5651@dispatch.add_dispatch_support
5652def tensor_scatter_nd_update(tensor, indices, updates, name=None):
5653 """Scatter `updates` into an existing tensor according to `indices`.
5655 This operation creates a new tensor by applying sparse `updates` to the
5656 input `tensor`. This is similar to an index assignment.
5658 ```
5659 # Not implemented: tensors cannot be updated inplace.
5660 tensor[indices] = updates
5661 ```
5663 If an out of bound index is found on CPU, an error is returned.
5665 > **WARNING**: There are some GPU specific semantics for this operation.
5666 >
5667 > - If an out of bound index is found, the index is ignored.
5668 > - The order in which updates are applied is nondeterministic, so the output
5669 > will be nondeterministic if `indices` contains duplicates.
5671 This operation is very similar to `tf.scatter_nd`, except that the updates are
5672 scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
5673 for the existing tensor cannot be re-used, a copy is made and updated.
5675 In general:
5677 * `indices` is an integer tensor - the indices to update in `tensor`.
5678 * `indices` has **at least two** axes, the last axis is the depth of the
5679 index vectors.
5680 * For each index vector in `indices` there is a corresponding entry in
5681 `updates`.
5682 * If the length of the index vectors matches the rank of the `tensor`, then
5683 the index vectors each point to scalars in `tensor` and each update is a
5684 scalar.
5685 * If the length of the index vectors is less than the rank of `tensor`, then
5686 the index vectors each point to the slices of `tensor` and shape of the updates
5687 must match that slice.
5689 Overall this leads to the following shape constraints:
5691 ```
5692 assert tf.rank(indices) >= 2
5693 index_depth = indices.shape[-1]
5694 batch_shape = indices.shape[:-1]
5695 assert index_depth <= tf.rank(tensor)
5696 outer_shape = tensor.shape[:index_depth]
5697 inner_shape = tensor.shape[index_depth:]
5698 assert updates.shape == batch_shape + inner_shape
5699 ```
5701 Typical usage is often much simpler than this general form, and it
5702 can be better understood starting with simple examples:
5704 ### Scalar updates
5706 The simplest usage inserts scalar elements into a tensor by index.
5707 In this case, the `index_depth` must equal the rank of the
5708 input `tensor`, slice each column of `indices` is an index into an axis of the
5709 input `tensor`.
5711 In this simplest case the shape constraints are:
5713 ```
5714 num_updates, index_depth = indices.shape.as_list()
5715 assert updates.shape == [num_updates]
5716 assert index_depth == tf.rank(tensor)`
5717 ```
5719 For example, to insert 4 scattered elements in a rank-1 tensor with
5720 8 elements.
5722 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
5723 <img style="width:100%"
5724 src="https://www.tensorflow.org/images/ScatterNd1.png">
5725 </div>
5727 This scatter operation would look like this:
5729 >>> tensor = [0, 0, 0, 0, 0, 0, 0, 0] # tf.rank(tensor) == 1
5730 >>> indices = [[1], [3], [4], [7]] # num_updates == 4, index_depth == 1
5731 >>> updates = [9, 10, 11, 12] # num_updates == 4
5732 >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
5733 tf.Tensor([ 0 9 0 10 11 0 0 12], shape=(8,), dtype=int32)
5735 The length (first axis) of `updates` must equal the length of the `indices`:
5736 `num_updates`. This is the number of updates being inserted. Each scalar
5737 update is inserted into `tensor` at the indexed location.
5739 For a higher rank input `tensor` scalar updates can be inserted by using an
5740 `index_depth` that matches `tf.rank(tensor)`:
5742 >>> tensor = [[1, 1], [1, 1], [1, 1]] # tf.rank(tensor) == 2
5743 >>> indices = [[0, 1], [2, 0]] # num_updates == 2, index_depth == 2
5744 >>> updates = [5, 10] # num_updates == 2
5745 >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
5746 tf.Tensor(
5747 [[ 1 5]
5748 [ 1 1]
5749 [10 1]], shape=(3, 2), dtype=int32)
5751 ### Slice updates
5753 When the input `tensor` has more than one axis scatter can be used to update
5754 entire slices.
5756 In this case it's helpful to think of the input `tensor` as being a two level
5757 array-of-arrays. The shape of this two level array is split into the
5758 `outer_shape` and the `inner_shape`.
5760 `indices` indexes into the outer level of the input tensor (`outer_shape`).
5761 and replaces the sub-array at that location with the corresponding item from
5762 the `updates` list. The shape of each update is `inner_shape`.
5764 When updating a list of slices the shape constraints are:
5766 ```
5767 num_updates, index_depth = indices.shape.as_list()
5768 outer_shape = tensor.shape[:index_depth]
5769 inner_shape = tensor.shape[index_depth:]
5770 assert updates.shape == [num_updates, inner_shape]
5771 ```
5773 For example, to update rows of a `(6, 3)` `tensor`:
5775 >>> tensor = tf.zeros([6, 3], dtype=tf.int32)
5777 Use an index depth of one.
5779 >>> indices = tf.constant([[2], [4]]) # num_updates == 2, index_depth == 1
5780 >>> num_updates, index_depth = indices.shape.as_list()
5782 The `outer_shape` is `6`, the inner shape is `3`:
5784 >>> outer_shape = tensor.shape[:index_depth]
5785 >>> inner_shape = tensor.shape[index_depth:]
5787 2 rows are being indexed so 2 `updates` must be supplied.
5788 Each update must be shaped to match the `inner_shape`.
5790 >>> # num_updates == 2, inner_shape==3
5791 >>> updates = tf.constant([[1, 2, 3],
5792 ... [4, 5, 6]])
5794 Altogether this gives:
5796 >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
5797 array([[0, 0, 0],
5798 [0, 0, 0],
5799 [1, 2, 3],
5800 [0, 0, 0],
5801 [4, 5, 6],
5802 [0, 0, 0]], dtype=int32)
5804 #### More slice update examples
5806 A tensor representing a batch of uniformly sized video clips naturally has 5
5807 axes: `[batch_size, time, width, height, channels]`.
5809 For example:
5811 >>> batch_size, time, width, height, channels = 13,11,7,5,3
5812 >>> video_batch = tf.zeros([batch_size, time, width, height, channels])
5814 To replace a selection of video clips:
5815 * Use an `index_depth` of 1 (indexing the `outer_shape`: `[batch_size]`)
5816 * Provide updates each with a shape matching the `inner_shape`:
5817 `[time, width, height, channels]`.
5819 To replace the first two clips with ones:
5821 >>> indices = [[0],[1]]
5822 >>> new_clips = tf.ones([2, time, width, height, channels])
5823 >>> tf.tensor_scatter_nd_update(video_batch, indices, new_clips)
5825 To replace a selection of frames in the videos:
5827 * `indices` must have an `index_depth` of 2 for the `outer_shape`:
5828 `[batch_size, time]`.
5829 * `updates` must be shaped like a list of images. Each update must have a
5830 shape, matching the `inner_shape`: `[width, height, channels]`.
5832 To replace the first frame of the first three video clips:
5834 >>> indices = [[0, 0], [1, 0], [2, 0]] # num_updates=3, index_depth=2
5835 >>> new_images = tf.ones([
5836 ... # num_updates=3, inner_shape=(width, height, channels)
5837 ... 3, width, height, channels])
5838 >>> tf.tensor_scatter_nd_update(video_batch, indices, new_images)
5840 ### Folded indices
5842 In simple cases it's convenient to think of `indices` and `updates` as
5843 lists, but this is not a strict requirement. Instead of a flat `num_updates`,
5844 the `indices` and `updates` can be folded into a `batch_shape`. This
5845 `batch_shape` is all axes of the `indices`, except for the innermost
5846 `index_depth` axis.
5848 ```
5849 index_depth = indices.shape[-1]
5850 batch_shape = indices.shape[:-1]
5851 ```
5853 Note: The one exception is that the `batch_shape` cannot be `[]`. You can't
5854 update a single index by passing indices with shape `[index_depth]`.
5856 `updates` must have a matching `batch_shape` (the axes before `inner_shape`).
5858 ```
5859 assert updates.shape == batch_shape + inner_shape
5860 ```
5862 Note: The result is equivalent to flattening the `batch_shape` axes of
5863 `indices` and `updates`. This generalization just avoids the need
5864 for reshapes when it is more natural to construct "folded" indices and
5865 updates.
5867 With this generalization the full shape constraints are:
5869 ```
5870 assert tf.rank(indices) >= 2
5871 index_depth = indices.shape[-1]
5872 batch_shape = indices.shape[:-1]
5873 assert index_depth <= tf.rank(tensor)
5874 outer_shape = tensor.shape[:index_depth]
5875 inner_shape = tensor.shape[index_depth:]
5876 assert updates.shape == batch_shape + inner_shape
5877 ```
5879 For example, to draw an `X` on a `(5,5)` matrix start with these indices:
5881 >>> tensor = tf.zeros([5,5])
5882 >>> indices = tf.constant([
5883 ... [[0,0],
5884 ... [1,1],
5885 ... [2,2],
5886 ... [3,3],
5887 ... [4,4]],
5888 ... [[0,4],
5889 ... [1,3],
5890 ... [2,2],
5891 ... [3,1],
5892 ... [4,0]],
5893 ... ])
5894 >>> indices.shape.as_list() # batch_shape == [2, 5], index_depth == 2
5895 [2, 5, 2]
5897 Here the `indices` do not have a shape of `[num_updates, index_depth]`, but a
5898 shape of `batch_shape+[index_depth]`.
5900 Since the `index_depth` is equal to the rank of `tensor`:
5902 * `outer_shape` is `(5,5)`
5903 * `inner_shape` is `()` - each update is scalar
5904 * `updates.shape` is `batch_shape + inner_shape == (5,2) + ()`
5906 >>> updates = [
5907 ... [1,1,1,1,1],
5908 ... [1,1,1,1,1],
5909 ... ]
5911 Putting this together gives:
5913 >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
5914 array([[1., 0., 0., 0., 1.],
5915 [0., 1., 0., 1., 0.],
5916 [0., 0., 1., 0., 0.],
5917 [0., 1., 0., 1., 0.],
5918 [1., 0., 0., 0., 1.]], dtype=float32)
5920 Args:
5921 tensor: Tensor to copy/update.
5922 indices: Indices to update.
5923 updates: Updates to apply at the indices.
5924 name: Optional name for the operation.
5926 Returns:
5927 A new tensor with the given shape and updates applied according to the
5928 indices.
5929 """
5930 return gen_array_ops.tensor_scatter_update(
5931 tensor=tensor, indices=indices, updates=updates, name=name)
5934# Define quantize_v2 here in order to make name the second-to-last attribute,
5935# because round_mode was added later.
5936# (And also now because of 'axis' processing).
5937@tf_export(v1=["quantize_v2"])
5938@dispatch.add_dispatch_support
5939@deprecation.deprecated(
5940 "2017-10-25",
5941 "`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` "
5942 "instead.") # pylint: disable=missing-docstring
5943def quantize_v2(
5944 input, # pylint: disable=redefined-builtin
5945 min_range,
5946 max_range,
5947 T,
5948 mode="MIN_COMBINED",
5949 name=None,
5950 round_mode="HALF_AWAY_FROM_ZERO",
5951 narrow_range=False,
5952 axis=None,
5953 ensure_minimum_range=0.01):
5954 if axis is None:
5955 axis = -1
5956 elif axis < 0:
5957 if input.shape.ndims is None:
5958 raise ValueError("input should have known rank to use negative axis.")
5959 axis %= input.shape.ndims
5961 if ensure_minimum_range != 0.01:
5962 return gen_array_ops.quantize_v2(
5963 input,
5964 min_range,
5965 max_range,
5966 T=T,
5967 mode=mode,
5968 name=name,
5969 round_mode=round_mode,
5970 narrow_range=narrow_range,
5971 axis=axis,
5972 ensure_minimum_range=ensure_minimum_range)
5973 return gen_array_ops.quantize_v2(
5974 input,
5975 min_range,
5976 max_range,
5977 T=T,
5978 mode=mode,
5979 name=name,
5980 round_mode=round_mode,
5981 narrow_range=narrow_range,
5982 axis=axis)
5985quantize_v2.__doc__ = """Please use `tf.quantization.quantize` instead."""
5988# We want to expose tf.quantization.quantize instead of
5989# tf.quantization.quantize; we can deprecate tf.quantization.quantize in next
5990# version of TensorFlow.
5991@tf_export("quantization.quantize", v1=["quantization.quantize", "quantize"])
5992@dispatch.add_dispatch_support
5993@deprecation.deprecated_endpoints("quantize")
5994def quantize(
5995 input, # pylint: disable=redefined-builtin
5996 min_range,
5997 max_range,
5998 T,
5999 mode="MIN_COMBINED",
6000 round_mode="HALF_AWAY_FROM_ZERO",
6001 name=None,
6002 narrow_range=False,
6003 axis=None,
6004 ensure_minimum_range=0.01):
6005 """Quantize the input tensor."""
6006 if ensure_minimum_range != 0.01:
6007 return quantize_v2(
6008 input,
6009 min_range,
6010 max_range,
6011 T,
6012 mode=mode,
6013 round_mode=round_mode,
6014 name=name,
6015 narrow_range=narrow_range,
6016 axis=axis,
6017 ensure_minimum_range=ensure_minimum_range)
6018 return quantize_v2(
6019 input,
6020 min_range,
6021 max_range,
6022 T,
6023 mode=mode,
6024 round_mode=round_mode,
6025 name=name,
6026 narrow_range=narrow_range,
6027 axis=axis)
6030@tf_export("quantization.dequantize", v1=["quantization.dequantize",
6031 "dequantize"])
6032@dispatch.add_dispatch_support
6033@deprecation.deprecated_endpoints("dequantize")
6034def dequantize( # pylint: disable=missing-docstring
6035 input, # pylint: disable=redefined-builtin
6036 min_range,
6037 max_range,
6038 mode="MIN_COMBINED",
6039 name=None,
6040 axis=None,
6041 narrow_range=False,
6042 dtype=dtypes.float32):
6043 if axis is None:
6044 axis = -1
6045 elif axis < 0:
6046 if input.shape.ndims is None:
6047 raise ValueError("input should have known rank to use negative axis.")
6048 axis %= input.shape.ndims
6050 if axis >= 0 or narrow_range:
6051 return gen_array_ops.dequantize(
6052 input,
6053 min_range,
6054 max_range,
6055 mode=mode,
6056 name=name,
6057 narrow_range=narrow_range,
6058 axis=axis,
6059 dtype=dtype)
6060 return gen_array_ops.dequantize(
6061 input, min_range, max_range, mode=mode, name=name, dtype=dtype)
6064dequantize.__doc__ = gen_array_ops.dequantize.__doc__
6067@tf_export("quantization.quantize_and_dequantize")
6068@dispatch.add_dispatch_support
6069@deprecation.deprecated(None,
6070 "This Op has been deprecated, use" +
6071 "`quantize_and_dequantize_v2` instead. To " +
6072 "To simulate the V1 the behavior of " +
6073 "tf.quantization.quantize_and_dequantize(...) use " +
6074 "tf.grad_pass_through(" +
6075 "tf.quantization.quantize_and_dequantize_v2)(...).")
6076def quantize_and_dequantize(
6077 input, # pylint: disable=redefined-builtin
6078 input_min,
6079 input_max,
6080 signed_input=True,
6081 num_bits=8,
6082 range_given=False,
6083 round_mode="HALF_TO_EVEN",
6084 name=None,
6085 narrow_range=False,
6086 axis=None):
6087 """Quantizes then dequantizes a tensor.
6089 Args:
6090 input: A `Tensor` to quantize and dequantize.
6091 input_min: If range_given=True, the minimum input value, that needs to be
6092 represented in the quantized representation. If axis is specified, this
6093 should be a vector of minimum values for each slice along axis.
6094 input_max: If range_given=True, the maximum input value that needs to be
6095 represented in the quantized representation. If axis is specified, this
6096 should be a vector of maximum values for each slice along axis.
6097 signed_input: True if the quantization is signed or unsigned.
6098 num_bits: The bitwidth of the quantization.
6099 range_given: If true use `input_min` and `input_max` for the range of the
6100 input, otherwise determine min and max from the input `Tensor`.
6101 round_mode: Rounding mode when rounding from float values to quantized ones.
6102 one of ['HALF_TO_EVEN', 'HALF_UP']
6103 name: Optional name for the operation.
6104 narrow_range: If true, then the absolute value of the quantized minimum
6105 value is the same as the quantized maximum value, instead of 1 greater.
6106 i.e. for 8 bit quantization, the minimum value is -127 instead of -128.
6107 axis: Integer. If specified, refers to a dimension of the input tensor, such
6108 that quantization will be per slice along that dimension.
6110 Returns:
6111 A `Tensor`. Each element is the result of quantizing and dequantizing the
6112 corresponding element of `input`.
6113 """
6114 if axis is None:
6115 axis = -1
6116 elif axis < 0:
6117 if input.shape.ndims is None:
6118 raise ValueError("input should have known rank to use negative axis.")
6119 axis %= input.shape.ndims
6121 return gen_array_ops.quantize_and_dequantize_v2(
6122 input,
6123 input_min=input_min,
6124 input_max=input_max,
6125 signed_input=signed_input,
6126 num_bits=num_bits,
6127 range_given=range_given,
6128 round_mode=round_mode,
6129 narrow_range=narrow_range,
6130 axis=axis,
6131 name=name)
6134@tf_export("quantization.quantize_and_dequantize_v2")
6135@dispatch.add_dispatch_support
6136def quantize_and_dequantize_v2(
6137 input, # pylint: disable=redefined-builtin
6138 input_min,
6139 input_max,
6140 signed_input=True,
6141 num_bits=8,
6142 range_given=False,
6143 round_mode="HALF_TO_EVEN",
6144 name=None,
6145 narrow_range=False,
6146 axis=None):
6147 """Quantizes then dequantizes a tensor.
6149 Updates the gradient definition for quantization that is outside the range to
6150 be 0.To simulate the V1 the behavior of
6151 tf.quantization.quantize_and_dequantize(...) use
6152 tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
6154 Example usage:
6156 ```python
6157 def getQuantizeOp(input):
6158 input_tensor = tf.placeholder(tf.float32, shape=[4, 4])
6159 net = tf.quantization.quantize_and_dequantize(input,
6160 input_min=min_threshold,
6161 input_max=max_threshold,
6162 range_given=True)
6164 To simulate v1 behavior:
6166 def testDecomposeQuantizeDequantize(self):
6167 def f(input_tensor):
6168 return tf.quantization.quantize_and_dequantize_v2(input_tensor,
6169 input_min = 5.0,
6170 input_max= -10.0,
6171 range_given=True)
6172 input_tensor = tf.placeholder(tf.float32, shape=[4, 4])
6173 net = tf.grad_pass_through(f)(input_tensor)
6174 ```
6176 Args:
6177 input: A `Tensor` to quantize and dequantize.
6178 input_min: If range_given=True, the minimum input value, that needs to be
6179 represented in the quantized representation. If axis is specified, this
6180 should be a vector of minimum values for each slice along axis.
6181 input_max: If range_given=True, the maximum input value that needs to be
6182 represented in the quantized representation. If axis is specified, this
6183 should be a vector of maximum values for each slice along axis.
6184 signed_input: True if the quantization is signed or unsigned.
6185 num_bits: The bitwidth of the quantization.
6186 range_given: If true use `input_min` and `input_max` for the range of the
6187 input, otherwise determine min and max from the input `Tensor`.
6188 round_mode: Rounding mode when rounding from float values to quantized ones.
6189 one of ['HALF_TO_EVEN', 'HALF_UP']
6190 name: Optional name for the operation.
6191 narrow_range: If true, then the absolute value of the quantized minimum
6192 value is the same as the quantized maximum value, instead of 1 greater.
6193 i.e. for 8 bit quantization, the minimum value is -127 instead of -128.
6194 axis: Integer. If specified, refers to a dimension of the input tensor, such
6195 that quantization will be per slice along that dimension.
6197 Returns:
6198 A `Tensor`. Each element is the result of quantizing and dequantizing the
6199 corresponding element of `input`.
6200 """
6201 if axis is None:
6202 axis = -1
6203 elif axis < 0:
6204 if input.shape.ndims is None:
6205 raise ValueError("input should have known rank to use negative axis.")
6206 axis %= input.shape.ndims
6208 return gen_array_ops.quantize_and_dequantize_v4(
6209 input,
6210 input_min=input_min,
6211 input_max=input_max,
6212 signed_input=signed_input,
6213 num_bits=num_bits,
6214 range_given=range_given,
6215 round_mode=round_mode,
6216 narrow_range=narrow_range,
6217 axis=axis,
6218 name=name)
6221@tf_export("searchsorted")
6222@dispatch.add_dispatch_support
6223def searchsorted(sorted_sequence,
6224 values,
6225 side="left",
6226 out_type=dtypes.int32,
6227 name=None):
6228 """Searches for where a value would go in a sorted sequence.
6230 This is not a method for checking containment (like python `in`).
6232 The typical use case for this operation is "binning", "bucketing", or
6233 "discretizing". The `values` are assigned to bucket-indices based on the
6234 **edges** listed in `sorted_sequence`. This operation
6235 returns the bucket-index for each value.
6237 >>> edges = [-1, 3.3, 9.1, 10.0]
6238 >>> values = [0.0, 4.1, 12.0]
6239 >>> tf.searchsorted(edges, values).numpy()
6240 array([1, 2, 4], dtype=int32)
6242 The `side` argument controls which index is returned if a value lands exactly
6243 on an edge:
6245 >>> seq = [0, 3, 9, 10, 10]
6246 >>> values = [0, 4, 10]
6247 >>> tf.searchsorted(seq, values).numpy()
6248 array([0, 2, 3], dtype=int32)
6249 >>> tf.searchsorted(seq, values, side="right").numpy()
6250 array([1, 2, 5], dtype=int32)
6252 The `axis` is not settable for this operation. It always operates on the
6253 innermost dimension (`axis=-1`). The operation will accept any number of
6254 outer dimensions. Here it is applied to the rows of a matrix:
6256 >>> sorted_sequence = [[0., 3., 8., 9., 10.],
6257 ... [1., 2., 3., 4., 5.]]
6258 >>> values = [[9.8, 2.1, 4.3],
6259 ... [0.1, 6.6, 4.5, ]]
6260 >>> tf.searchsorted(sorted_sequence, values).numpy()
6261 array([[4, 1, 2],
6262 [0, 5, 4]], dtype=int32)
6264 Note: This operation assumes that `sorted_sequence` **is sorted** along the
6265 innermost axis, maybe using `tf.sort(..., axis=-1)`. **If the sequence is not
6266 sorted, no error is raised** and the content of the returned tensor is not well
6267 defined.
6269 Args:
6270 sorted_sequence: N-D `Tensor` containing a sorted sequence.
6271 values: N-D `Tensor` containing the search values.
6272 side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
6273 upper_bound.
6274 out_type: The output type (`int32` or `int64`). Default is `tf.int32`.
6275 name: Optional name for the operation.
6277 Returns:
6278 An N-D `Tensor` the size of `values` containing the result of applying
6279 either lower_bound or upper_bound (depending on side) to each value. The
6280 result is not a global index to the entire `Tensor`, but the index in the
6281 last dimension.
6283 Raises:
6284 ValueError: If the last dimension of `sorted_sequence >= 2^31-1` elements.
6285 If the total size of `values` exceeds `2^31 - 1` elements.
6286 If the first `N-1` dimensions of the two tensors don't match.
6287 """
6288 sequence_size = shape_internal(sorted_sequence)[-1]
6289 values_size = shape_internal(values)[-1]
6290 sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
6291 values_2d = reshape(values, [-1, values_size])
6292 if side == "right":
6293 output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
6294 name)
6295 elif side == "left":
6296 output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
6297 name)
6298 else:
6299 raise ValueError("Argument `side` must be either 'right' or 'left'. "
6300 f"Received: `side` = '{side}'.")
6301 return reshape(output, shape_internal(values))
6304quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
6307@tf_export("image.extract_patches")
6308@dispatch.add_dispatch_support
6309def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None):
6310 r"""Extract `patches` from `images`.
6312 This op collects patches from the input image, as if applying a
6313 convolution. All extracted patches are stacked in the depth (last) dimension
6314 of the output.
6316 Specifically, the op extracts patches of shape `sizes` which are `strides`
6317 apart in the input image. The output is subsampled using the `rates` argument,
6318 in the same manner as "atrous" or "dilated" convolutions.
6320 The result is a 4D tensor which is indexed by batch, row, and column.
6321 `output[i, x, y]` contains a flattened patch of size `sizes[1], sizes[2]`
6322 which is taken from the input starting at
6323 `images[i, x*strides[1], y*strides[2]]`.
6325 Each output patch can be reshaped to `sizes[1], sizes[2], depth`, where
6326 `depth` is `images.shape[3]`.
6328 The output elements are taken from the input at intervals given by the `rate`
6329 argument, as in dilated convolutions.
6331 The `padding` argument has no effect on the size of each patch, it determines
6332 how many patches are extracted. If `VALID`, only patches which are fully
6333 contained in the input image are included. If `SAME`, all patches whose
6334 starting point is inside the input are included, and areas outside the input
6335 default to zero.
6337 Example:
6339 ```
6340 n = 10
6341 # images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100
6342 images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]
6344 # We generate two outputs as follows:
6345 # 1. 3x3 patches with stride length 5
6346 # 2. Same as above, but the rate is increased to 2
6347 tf.image.extract_patches(images=images,
6348 sizes=[1, 3, 3, 1],
6349 strides=[1, 5, 5, 1],
6350 rates=[1, 1, 1, 1],
6351 padding='VALID')
6353 # Yields:
6354 [[[[ 1 2 3 11 12 13 21 22 23]
6355 [ 6 7 8 16 17 18 26 27 28]]
6356 [[51 52 53 61 62 63 71 72 73]
6357 [56 57 58 66 67 68 76 77 78]]]]
6358 ```
6360 If we mark the pixels in the input image which are taken for the output with
6361 `*`, we see the pattern:
6363 ```
6364 * * * 4 5 * * * 9 10
6365 * * * 14 15 * * * 19 20
6366 * * * 24 25 * * * 29 30
6367 31 32 33 34 35 36 37 38 39 40
6368 41 42 43 44 45 46 47 48 49 50
6369 * * * 54 55 * * * 59 60
6370 * * * 64 65 * * * 69 70
6371 * * * 74 75 * * * 79 80
6372 81 82 83 84 85 86 87 88 89 90
6373 91 92 93 94 95 96 97 98 99 100
6374 ```
6376 ```
6377 tf.image.extract_patches(images=images,
6378 sizes=[1, 3, 3, 1],
6379 strides=[1, 5, 5, 1],
6380 rates=[1, 2, 2, 1],
6381 padding='VALID')
6383 # Yields:
6384 [[[[ 1 3 5 21 23 25 41 43 45]
6385 [ 6 8 10 26 28 30 46 48 50]]
6387 [[ 51 53 55 71 73 75 91 93 95]
6388 [ 56 58 60 76 78 80 96 98 100]]]]
6389 ```
6391 We can again draw the effect, this time using the symbols `*`, `x`, `+` and
6392 `o` to distinguish the patches:
6394 ```
6395 * 2 * 4 * x 7 x 9 x
6396 11 12 13 14 15 16 17 18 19 20
6397 * 22 * 24 * x 27 x 29 x
6398 31 32 33 34 35 36 37 38 39 40
6399 * 42 * 44 * x 47 x 49 x
6400 + 52 + 54 + o 57 o 59 o
6401 61 62 63 64 65 66 67 68 69 70
6402 + 72 + 74 + o 77 o 79 o
6403 81 82 83 84 85 86 87 88 89 90
6404 + 92 + 94 + o 97 o 99 o
6405 ```
6407 Args:
6408 images: A 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`.
6409 sizes: The size of the extracted patches. Must be
6410 `[1, size_rows, size_cols, 1]`.
6411 strides: A 1-D Tensor of length 4. How far the centers of two consecutive
6412 patches are in the images. Must be: `[1, stride_rows, stride_cols, 1]`.
6413 rates: A 1-D Tensor of length 4. Must be: `[1, rate_rows, rate_cols, 1]`.
6414 This is the input stride, specifying how far two consecutive patch samples
6415 are in the input. Equivalent to extracting patches with `patch_sizes_eff =
6416 patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by subsampling
6417 them spatially by a factor of `rates`. This is equivalent to `rate` in
6418 dilated (a.k.a. Atrous) convolutions.
6419 padding: The type of padding algorithm to use.
6420 name: A name for the operation (optional).
6422 Returns:
6423 A 4-D Tensor of the same type as the input.
6424 """
6425 return gen_array_ops.extract_image_patches(images, sizes, strides, rates,
6426 padding, name)
6429@tf_export(v1=["image.extract_image_patches", "extract_image_patches"])
6430@dispatch.add_dispatch_support
6431@deprecation.deprecated_args(None, "ksizes is deprecated, use sizes instead",
6432 "ksizes")
6433def extract_image_patches( # pylint: disable=missing-docstring
6434 images,
6435 ksizes=None,
6436 strides=None,
6437 rates=None,
6438 padding=None,
6439 name=None,
6440 sizes=None):
6441 """Extract patches from images and put them in the "depth" output dimension.
6443 Args:
6444 `images`: A `Tensor`. Must be one of the following types: `float32`,
6445 `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`,
6446 `uint16`, `half`, `uint32`, `uint64`. 4-D Tensor with shape
6447 `[batch, in_rows, in_cols, depth]`. `ksizes`: A list of `ints` that has
6448 length `>= 4`. The size of the sliding window for each
6449 dimension of `images`. `strides`: A list of `ints` that has length `>= 4`.
6450 1-D of length 4. How far the centers of two consecutive
6451 patches are in the images. Must be:
6452 `[1, stride_rows, stride_cols, 1]`. `rates`: A list of `ints`
6453 that has length `>= 4`. 1-D of length 4. Must be: `[1, rate_rows, rate_cols,
6454 1]`. This is the input stride, specifying how far two consecutive patch
6455 samples are in the input. Equivalent to extracting patches with
6456 `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`,
6457 followed by subsampling them spatially by a factor of `rates`. This is
6458 equivalent to `rate` in dilated (a.k.a. Atrous) convolutions.
6459 `padding`: A `string` from: "SAME", "VALID". The type of padding algorithm
6460 to use.
6461 We specify the size-related attributes as: ``` ksizes = [1, ksize_rows,
6462 ksize_cols, 1] strides = [1, strides_rows, strides_cols, 1] rates = [1,
6463 rates_rows, rates_cols, 1]
6464 name: A name for the operation (optional). ```
6466 Returns:
6467 A Tensor. Has the same type as images.
6468 """
6469 ksizes = deprecation.deprecated_argument_lookup("sizes", sizes, "ksizes",
6470 ksizes)
6471 return gen_array_ops.extract_image_patches(images, ksizes, strides, rates,
6472 padding, name)
6475extract_image_patches.__doc__ = gen_array_ops.extract_image_patches.__doc__
6478@tf_export("fingerprint")
6479@dispatch.add_dispatch_support
6480def fingerprint(data, method="farmhash64", name=None):
6481 r"""Generates fingerprint values.
6483 Generates fingerprint values of `data`.
6485 Fingerprint op considers the first dimension of `data` as the batch dimension,
6486 and `output[i]` contains the fingerprint value generated from contents in
6487 `data[i, ...]` for all `i`.
6489 Fingerprint op writes fingerprint values as byte arrays. For example, the
6490 default method `farmhash64` generates a 64-bit fingerprint value at a time.
6491 This 8-byte value is written out as an `tf.uint8` array of size 8, in
6492 little-endian order.
6494 For example, suppose that `data` has data type `tf.int32` and shape (2, 3, 4),
6495 and that the fingerprint method is `farmhash64`. In this case, the output
6496 shape is (2, 8), where 2 is the batch dimension size of `data`, and 8 is the
6497 size of each fingerprint value in bytes. `output[0, :]` is generated from
6498 12 integers in `data[0, :, :]` and similarly `output[1, :]` is generated from
6499 other 12 integers in `data[1, :, :]`.
6501 Note that this op fingerprints the raw underlying buffer, and it does not
6502 fingerprint Tensor's metadata such as data type and/or shape. For example, the
6503 fingerprint values are invariant under reshapes and bitcasts as long as the
6504 batch dimension remain the same:
6506 ```python
6507 tf.fingerprint(data) == tf.fingerprint(tf.reshape(data, ...))
6508 tf.fingerprint(data) == tf.fingerprint(tf.bitcast(data, ...))
6509 ```
6511 For string data, one should expect `tf.fingerprint(data) !=
6512 tf.fingerprint(tf.string.reduce_join(data))` in general.
6514 Args:
6515 data: A `Tensor`. Must have rank 1 or higher.
6516 method: A `Tensor` of type `tf.string`. Fingerprint method used by this op.
6517 Currently, available method is `farmhash64`.
6518 name: A name for the operation (optional).
6520 Returns:
6521 A two-dimensional `Tensor` of type `tf.uint8`. The first dimension equals to
6522 `data`'s first dimension, and the second dimension size depends on the
6523 fingerprint algorithm.
6524 """
6525 return gen_array_ops.fingerprint(data, method, name)
6528def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
6529 """Converts the given value to an integer Tensor."""
6530 tensor = ops.convert_to_tensor(
6531 tensor, name=name, preferred_dtype=dtype or dtypes.int32)
6532 if tensor.dtype.is_integer:
6533 if dtype is not None:
6534 tensor = gen_math_ops.cast(tensor, dtype)
6535 else:
6536 raise TypeError(f"Argument `tensor` (name: {name}) must be of type integer."
6537 f" Received `tensor` = {tensor} of dtype: {tensor.dtype}")
6538 return tensor
6541def get_positive_axis(axis, ndims, axis_name="axis", ndims_name="ndims"):
6542 """Validate an `axis` parameter, and normalize it to be positive.
6544 If `ndims` is known (i.e., not `None`), then check that `axis` is in the
6545 range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
6546 `axis + ndims` (otherwise).
6547 If `ndims` is not known, and `axis` is positive, then return it as-is.
6548 If `ndims` is not known, and `axis` is negative, then report an error.
6550 Args:
6551 axis: An integer constant
6552 ndims: An integer constant, or `None`
6553 axis_name: The name of `axis` (for error messages).
6554 ndims_name: The name of `ndims` (for error messages).
6556 Returns:
6557 The normalized `axis` value.
6559 Raises:
6560 ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
6561 `ndims is None`.
6562 """
6563 if not isinstance(axis, int):
6564 raise TypeError(f"{axis_name} must be an int; got {type(axis).__name__}")
6565 if ndims is not None:
6566 if 0 <= axis < ndims:
6567 return axis
6568 elif -ndims <= axis < 0:
6569 return axis + ndims
6570 else:
6571 raise ValueError(f"{axis_name}={axis} out of bounds: "
6572 f"expected {-ndims}<={axis_name}<{ndims}")
6573 elif axis < 0:
6574 raise ValueError(f"{axis_name}={axis} may only be negative "
6575 f"if {ndims_name} is statically known.")
6576 return axis
6579# This op is intended to exactly match the semantics of numpy.repeat, with
6580# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
6581# when axis is not specified. Rather than implement that special behavior, we
6582# simply make `axis` be a required argument.
6583#
6584# External (OSS) `tf.repeat` feature request:
6585# https://github.com/tensorflow/tensorflow/issues/8246
6586def repeat_with_axis(data, repeats, axis, name=None):
6587 """Repeats elements of `data`.
6589 Args:
6590 data: An `N`-dimensional tensor.
6591 repeats: A 1-D integer tensor specifying how many times each element in
6592 `axis` should be repeated. `len(repeats)` must equal `data.shape[axis]`.
6593 Supports broadcasting from a scalar value.
6594 axis: `int`. The axis along which to repeat values. Must be less than
6595 `max(N, 1)`.
6596 name: A name for the operation.
6598 Returns:
6599 A tensor with `max(N, 1)` dimensions. Has the same shape as `data`,
6600 except that dimension `axis` has size `sum(repeats)`.
6602 Example usage:
6604 >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
6605 <tf.Tensor: shape=(5,), dtype=string,
6606 numpy=array([b'a', b'a', b'a', b'c', b'c'], dtype=object)>
6607 >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
6608 <tf.Tensor: shape=(5, 2), dtype=int32, numpy=
6609 array([[1, 2],
6610 [1, 2],
6611 [3, 4],
6612 [3, 4],
6613 [3, 4]], dtype=int32)>
6614 >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
6615 <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
6616 array([[1, 1, 2, 2, 2],
6617 [3, 3, 4, 4, 4]], dtype=int32)>
6619 """
6620 # Whether the execution uses the optimized non-XLA implementation below.
6621 # TODO(b/236387200): Separate the implementations at a lower level, so that
6622 # non-XLA path gets the performance benefits and the XLA path is not broken
6623 # after loading a saved model with the optimization.
6624 use_optimized_non_xla_implementation = False
6626 if not isinstance(axis, int):
6627 raise TypeError("Argument `axis` must be an int. "
6628 f"Received `axis` = {axis} of type {type(axis).__name__}")
6630 with ops.name_scope(name, "Repeat", [data, repeats]):
6631 data = ops.convert_to_tensor(data, name="data")
6632 # Note: We want to pass dtype=None to convert_to_int_tensor so that the
6633 # existing type is maintained instead of force-casting to int32. However,
6634 # this is not compatible with the implementation used on the XLA path.
6635 if not use_optimized_non_xla_implementation:
6636 repeats = convert_to_int_tensor(repeats, name="repeats")
6637 else:
6638 repeats = convert_to_int_tensor(repeats, name="repeats", dtype=None)
6640 repeats.shape.with_rank_at_most(1)
6642 # If `data` is a scalar, then upgrade it to a vector.
6643 data = _with_nonzero_rank(data)
6644 data_shape = shape(data, out_type=repeats.dtype)
6646 # If `axis` is negative, then convert it to a positive value.
6647 axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
6649 # If we know that `repeats` is a scalar, then we can just tile & reshape.
6650 if repeats.shape.num_elements() == 1:
6651 repeats = reshape(repeats, [])
6652 expanded = expand_dims(data, axis + 1)
6653 tiled = tile_one_dimension(expanded, axis + 1, repeats)
6654 result_shape = concat([
6655 data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:]
6656 ],
6657 axis=0)
6658 return reshape(tiled, result_shape)
6660 # Check data Tensor shapes.
6661 if repeats.shape.ndims == 1:
6662 data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
6664 repeats = broadcast_to(repeats, [data_shape[axis]])
6666 # The implementation on the else branch has better performance. However, it
6667 # does not work on the XLA path since it relies on the range op with a
6668 # shape that is not a compile-time constant.
6669 if not use_optimized_non_xla_implementation:
6670 repeats_original = repeats
6672 # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
6673 if repeats.shape.ndims != axis + 1:
6674 repeats_shape = shape(repeats)
6675 repeats_ndims = rank(repeats)
6676 broadcast_shape = concat(
6677 [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
6678 repeats = broadcast_to(repeats, broadcast_shape)
6679 repeats.set_shape([None] * (axis + 1))
6681 # Create a "sequence mask" based on `repeats`, where slices across `axis`
6682 # contain one `True` value for each repetition. E.g., if
6683 # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
6684 max_repeat = gen_math_ops._max(repeats, _all_dimensions(repeats))
6685 max_repeat = gen_math_ops.maximum(
6686 ops.convert_to_tensor(0, name="zero", dtype=max_repeat.dtype),
6687 max_repeat)
6689 mask = sequence_mask(repeats, max_repeat)
6691 # Add a new dimension around each value that needs to be repeated, and
6692 # then tile that new dimension to match the maximum number of repetitions.
6693 expanded = expand_dims(data, axis + 1)
6694 tiled = tile_one_dimension(expanded, axis + 1, max_repeat)
6696 # Use `boolean_mask` to discard the extra repeated values. This also
6697 # flattens all dimensions up through `axis`.
6698 masked = boolean_mask(tiled, mask)
6700 # Reshape the output tensor to add the outer dimensions back.
6701 if axis == 0:
6702 result = masked
6703 else:
6704 repeated_dim_size = gen_math_ops._sum(
6705 repeats_original,
6706 axis=gen_math_ops._range(0, rank(repeats_original), 1))
6707 result_shape = concat(
6708 [data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
6709 axis=0)
6710 result = reshape(masked, result_shape)
6712 # Preserve shape information.
6713 if data.shape.ndims is not None:
6714 new_axis_size = 0 if repeats.shape[0] == 0 else None
6715 result.set_shape(data.shape[:axis].concatenate(
6716 [new_axis_size]).concatenate(data.shape[axis + 1:]))
6718 return result
6720 else:
6721 # Non-XLA path implementation
6722 # E.g., repeats = [3, 4, 0, 2, 1].
6723 # E.g., repeats_scan = [3, 7, 7, 9, 10].
6724 repeats_scan = gen_math_ops.cumsum(repeats)
6725 # This concat just prepends 0 to handle the case when repeats are empty.
6726 # E.g., output_size = [0, 3, 7, 7, 9, 10][-1] = 10.
6727 output_size = concat([zeros(1, dtype=repeats_scan.dtype), repeats_scan],
6728 axis=0)[-1]
6729 # E.g., output_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9].
6730 output_indices = gen_math_ops.range(output_size, dtype=repeats.dtype)
6731 # E.g., gather_indices = [0, 0, 0, 1, 1, 1, 1, 3, 3, 4].
6732 gather_indices = searchsorted(
6733 repeats_scan, output_indices, side="right", out_type=repeats.dtype)
6734 return gather(data, gather_indices, axis=axis)
6737def tile_one_dimension(data, axis, multiple):
6738 """Tiles a single dimension of a tensor."""
6739 # Assumes axis is a nonnegative int.
6740 if data.shape.ndims is not None:
6741 multiples = [1] * data.shape.ndims
6742 multiples[axis] = multiple
6743 else:
6744 ones_value = ones(rank(data), dtypes.int32)
6745 multiples = concat([ones_value[:axis], [multiple], ones_value[axis + 1:]],
6746 axis=0)
6747 return tile(data, multiples)
6750def _with_nonzero_rank(data):
6751 """If `data` is scalar, then add a dimension; otherwise return as-is."""
6752 if data.shape.ndims is not None:
6753 if data.shape.ndims == 0:
6754 return array_ops_stack.stack([data])
6755 else:
6756 return data
6757 else:
6758 data_shape = shape(data)
6759 data_ndims = rank(data)
6760 return reshape(data, concat([[1], data_shape], axis=0)[-data_ndims:])
6763@tf_export("repeat")
6764@dispatch.add_dispatch_support
6765def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin
6766 """Repeat elements of `input`.
6768 See also `tf.concat`, `tf.stack`, `tf.tile`.
6770 Args:
6771 input: An `N`-dimensional Tensor.
6772 repeats: An 1-D `int` Tensor. The number of repetitions for each element.
6773 repeats is broadcasted to fit the shape of the given axis. `len(repeats)`
6774 must equal `input.shape[axis]` if axis is not None.
6775 axis: An int. The axis along which to repeat values. By default, (axis=None),
6776 use the flattened input array, and return a flat output array.
6777 name: A name for the operation.
6779 Returns:
6780 A Tensor which has the same shape as `input`, except along the given axis.
6781 If axis is None then the output array is flattened to match the flattened
6782 input array.
6784 Example usage:
6786 >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
6787 <tf.Tensor: shape=(5,), dtype=string,
6788 numpy=array([b'a', b'a', b'a', b'c', b'c'], dtype=object)>
6790 >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
6791 <tf.Tensor: shape=(5, 2), dtype=int32, numpy=
6792 array([[1, 2],
6793 [1, 2],
6794 [3, 4],
6795 [3, 4],
6796 [3, 4]], dtype=int32)>
6798 >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
6799 <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
6800 array([[1, 1, 2, 2, 2],
6801 [3, 3, 4, 4, 4]], dtype=int32)>
6803 >>> repeat(3, repeats=4)
6804 <tf.Tensor: shape=(4,), dtype=int32, numpy=array([3, 3, 3, 3], dtype=int32)>
6806 >>> repeat([[1,2], [3,4]], repeats=2)
6807 <tf.Tensor: shape=(8,), dtype=int32,
6808 numpy=array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)>
6810 """
6811 if axis is None:
6812 input = reshape(input, [-1])
6813 axis = 0
6814 return repeat_with_axis(input, repeats, axis, name)
6817@tf_export("guarantee_const")
6818@deprecation.deprecated(None, "Not for public use.")
6819def guarantee_const(input, name=None): # pylint: disable=redefined-builtin
6820 """Promise to the TF runtime that the input tensor is a constant.
6822 The runtime is then free to make optimizations based on this.
6824 Returns the input tensor without modification.
6826 Args:
6827 input: A `Tensor`.
6828 name: A name for this operation.
6830 Returns:
6831 A `Tensor`. Has the same dtype as `input`.
6832 """
6833 return gen_array_ops.guarantee_const(input=input, name=name)
6836@tf_export("stop_gradient")
6837@dispatch.add_dispatch_support
6838def stop_gradient(input, name=None): # pylint: disable=redefined-builtin
6839 """Stops gradient computation.
6841 NOTE: This docstring is patched out below. See
6842 tensorflow/core/api_def/base_api/api_def_StopGradient.pbtxt for the full
6843 docstring. That file determines the public documentation page.
6845 Args:
6846 input: A `Tensor`.
6847 name: A name for this operation.
6849 Returns:
6850 A `Tensor`. Has the same dtype as `input`.
6851 """
6852 # Don't expand ResourceVariables, so stop_gradient(variable) will return a
6853 # Tensor.
6854 if (isinstance(input, composite_tensor.CompositeTensor) and
6855 not _pywrap_utils.IsResourceVariable(input)):
6856 return nest.map_structure(stop_gradient, input, expand_composites=True)
6857 # The StopGradient op has a gradient function registered which returns None
6858 # (meaning statically known to be zero). For correctness, that's all we
6859 # need. However, tf.GradientTape often makes decisions about what to keep in
6860 # memory based on which forward-pass tensors are currently being watched, and
6861 # returning None in a gradient is not sufficient to stop watching a tensor
6862 # since the backward function doesn't run in the forward pass. Pausing the
6863 # tape around this op instructs any tf.GradientTapes to ignore the
6864 # forward-pass output of StopGradient, which may be much more efficient.
6865 with record.stop_recording():
6866 return gen_array_ops.stop_gradient(input, name=name)
6869stop_gradient.__doc__ = gen_array_ops.stop_gradient.__doc__
6872# Register elementwise ops that don't have Python wrappers.
6873dispatch.register_unary_elementwise_api(gen_array_ops.check_numerics)