Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/numpy_ops/np_array_ops.py: 20%
1045 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Common array methods."""
16# pylint: disable=g-direct-tensorflow-import
18import enum
19import functools
20import math
21import numbers
22import numpy as np
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import array_ops_stack
30from tensorflow.python.ops import clip_ops
31from tensorflow.python.ops import control_flow_assert
32from tensorflow.python.ops import linalg_ops
33from tensorflow.python.ops import manip_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import sort_ops
36from tensorflow.python.ops.numpy_ops import np_arrays
37from tensorflow.python.ops.numpy_ops import np_dtypes
38from tensorflow.python.ops.numpy_ops import np_export
39from tensorflow.python.ops.numpy_ops import np_utils
40from tensorflow.python.util import nest
# Export NumPy's `newaxis` sentinel as a named constant of this module.
newaxis = np_export.np_export_constant(
    __name__, 'newaxis', np.newaxis)
@np_utils.np_doc('empty')
def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
  # Delegates to `zeros`, so the returned array is zero-filled.
  result = zeros(shape, dtype)
  return result
@np_utils.np_doc('empty_like')
def empty_like(a, dtype=None):
  # Delegates to `zeros_like`, so the returned array is zero-filled.
  result = zeros_like(a, dtype)
  return result
@np_utils.np_doc('zeros')
def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
  # A falsy dtype (e.g. None) falls back to the module's default float type.
  if dtype:
    dtype = np_utils.result_type(dtype)
  else:
    dtype = np_dtypes.default_float_type()
  return array_ops.zeros(shape, dtype=dtype)
@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
  # Resolve the result dtype, then normalize it to a TF DType before calling
  # the op. Work around b/149877262.
  tf_dtype = dtypes.as_dtype(np_utils.result_type_unary(a, dtype))
  return array_ops.zeros_like(a, tf_dtype)
@np_utils.np_doc('ones')
def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
  # Resolve the dtype exactly like `zeros` does: a falsy dtype (such as an
  # explicit `dtype=None`) falls back to the module's default float type
  # instead of forwarding `None` to `array_ops.ones`.
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
  return array_ops.ones(shape, dtype=dtype)
@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
  # Resolve the output dtype from `a` and the optional override.
  resolved = np_utils.result_type_unary(a, dtype)
  return array_ops.ones_like(a, resolved)
@np_utils.np_doc('eye')
def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
  # Resolve the dtype the same way `zeros` does, so an explicit `dtype=None`
  # yields the default float type rather than reaching the TF ops as `None`.
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
  # Only a missing `M` defaults to `N`; `M=0` is a valid (empty) column count
  # and must not be replaced.
  if M is None:
    M = N
  # Making sure N, M and k are `int`
  N = int(N)
  M = int(M)
  k = int(k)
  if k >= M or -k >= N:
    # The requested diagonal lies entirely outside the matrix, so the result
    # is all zeros (tf.linalg.diag would raise an error in this case).
    return zeros([N, M], dtype=dtype)
  if k == 0:
    return linalg_ops.eye(N, M, dtype=dtype)
  # We need the precise length, otherwise tf.linalg.diag will raise an error
  diag_len = min(N, M)
  if k > 0:
    if N >= M:
      diag_len -= k
    elif N + k > M:
      diag_len = M - k
  elif k <= 0:
    if M >= N:
      diag_len += k
    elif M - k > N:
      diag_len = N + k
  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)
@np_utils.np_doc('identity')
def identity(n, dtype=float):
  # An identity matrix is `eye` with equal row/column counts and k=0.
  return eye(n, M=n, k=0, dtype=dtype)
@np_utils.np_doc('full')
def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
  # Shapes given as Python values are converted (int32 by preference) so they
  # can be passed to `broadcast_to` as an ndarray.
  if not isinstance(shape, np_arrays.ndarray):
    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
  target_shape = atleast_1d(shape)
  value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(value, target_shape)
# Using doc only here since np full_like signature doesn't seem to have the
# shape argument (even though it exists in the documentation online).
@np_utils.np_doc_only('full_like')
def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
  """`order`, `subok` and `shape` must be left at their default values."""
  if order != 'K':
    raise ValueError('Non-standard orders are not supported.')
  if not subok:
    raise ValueError('subok being False is not supported.')
  if shape:
    raise ValueError('Overriding the shape is not supported.')

  a = asarray(a)
  # Without an explicit dtype, the fill value takes `a`'s dtype.
  target_dtype = dtype or np_utils.result_type(a)
  filled = asarray(fill_value, dtype=target_dtype)
  return array_ops.broadcast_to(filled, array_ops.shape(a))
def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Main implementation of np.array().

  Args:
    val: the value to convert (Tensor, ndarray or array-like Python value).
    dtype: optional target dtype. For non-Tensor inputs it is combined with
      `val`'s inferred type via `np_utils.result_type_unary`.
    copy: if True, the result is passed through `array_ops.identity`.
    ndmin: minimum number of dimensions; leading size-1 dimensions are
      prepended when the result has fewer. At most 32.

  Returns:
    The converted tensor.

  Raises:
    ValueError: if `ndmin` exceeds 32.
  """
  if isinstance(val, ops.Tensor):
    result = math_ops.cast(val, dtype) if dtype else val
  else:
    dtype = np_utils.result_type_unary(val, dtype)
    # `convert_to_tensor(val, dtype=dtype)` would reject combinations such as
    # (5.5, int) that np.array accepts, so convert first and cast afterwards.
    # The dtype hint lets EagerTensor conversion do a best-effort conversion
    # to the final dtype; without it, mixed-type lists such as [2., 2j] are
    # rejected because the type is inferred from a single element.
    converted = np_arrays.convert_to_tensor(val, dtype_hint=dtype)
    result = math_ops.cast(converted, dtype=dtype)

  if copy:
    result = array_ops.identity(result)

  max_ndmin = 32
  if ndmin > max_ndmin:
    raise ValueError('ndmin bigger than allowable number of dimensions: '
                     f'{max_ndmin}.')

  if ndmin == 0:
    return result

  rank = array_ops.rank(result)

  def _prepend_unit_dims():
    current_shape = array_ops.shape(result)
    padded_shape = array_ops.concat(
        [array_ops.ones(ndmin - rank, dtypes.int32), current_shape], axis=0)
    return array_ops.reshape(result, padded_shape)

  return np_utils.cond(
      np_utils.greater(ndmin, rank), _prepend_unit_dims, lambda: result)
# TODO(wangpeng): investigate whether we can make `copy` default to False.
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@np_utils.np_doc_only('array')
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Since Tensors are immutable, a copy is made only if val is placed on a

  different device than the current one. Even if `copy` is False, a new Tensor
  may need to be built to satisfy `dtype` and `ndmin`. This is used only if
  `val` is an ndarray or a Tensor.
  """  # pylint:disable=g-docstring-missing-newline
  resolved_dtype = np_utils.result_type(dtype) if dtype else dtype
  return _array_internal(val, dtype=resolved_dtype, copy=copy, ndmin=ndmin)


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@np_utils.np_doc('asarray')
def asarray(a, dtype=None):
  if dtype:
    dtype = np_utils.result_type(dtype)
  if not isinstance(a, np_arrays.ndarray):
    return array(a, dtype, copy=False)
  # Already an ndarray: return it as-is unless a different dtype is requested.
  if not dtype or dtype == a.dtype.as_numpy_dtype:
    return a
  return array(a, dtype, copy=False)
@np_utils.np_doc('asanyarray')
def asanyarray(a, dtype=None):
  # There are no ndarray subclasses to preserve, so this is just `asarray`.
  return asarray(a, dtype=dtype)
@np_utils.np_doc('ascontiguousarray')
def ascontiguousarray(a, dtype=None):
  # Always returns at least a 1-d (copied) array.
  return array(a, dtype=dtype, ndmin=1)
# Numerical ranges.
@np_utils.np_doc('arange')
def arange(start, stop=None, step=1, dtype=None):
  """Returns `step`-separated values in the range [start, stop).

  Args:
    start: Start of the interval. Included in the range.
    stop: End of the interval. If omitted, the range starts at 0 and `start`
      is used as the end. If given, it is excluded from the range when `step`
      is an integer; with a floating point `step` it may or may not be
      included.
    step: The difference between 2 consecutive values in the output range.
      Prefer `linspace` over non-integer values for `step`.
    dtype: Optional. Type of the resulting ndarray: a python type, a NumPy type
      or a TensorFlow `DType`. Defaults to the largest type among `start`,
      `stop` and `step`.

  Raises:
    ValueError: If step is zero.
  """
  if not step:
    raise ValueError('step must be non-zero.')
  if dtype:
    dtype = np_utils.result_type(dtype)
  elif stop is None:
    dtype = np_utils.result_type(start, step)
  else:
    dtype = np_utils.result_type(start, step, stop)

  # Without `stop`, the effective interval is [0, start).
  if stop is None:
    lower, upper = 0, start
  else:
    lower, upper = start, stop
  # A step pointing away from the end of the interval yields an empty result.
  if (step > 0 and lower > upper) or (step < 0 and lower < upper):
    return array([], dtype=dtype)

  # TODO(srbs): There are some bugs when start or stop is float type and dtype
  # is integer type.
  values = math_ops.range(start, limit=stop, delta=step)
  return math_ops.cast(values, dtype=dtype)
# Building matrices.
@np_utils.np_doc('diag')
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Raises an error if input is not 1- or 2-d."""
  v = asarray(v)
  v_rank = array_ops.rank(v)

  # Static rank check; only effective when the rank is known while tracing.
  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
  # tracing time if the shape is known.
  is_vector = math_ops.equal(v_rank, 1)
  control_flow_assert.Assert(
      np_utils.logical_or(is_vector, math_ops.equal(v_rank, 2)), [v_rank])

  def _build_matrix(vec, offset):
    # An empty vector produces an |offset| x |offset| zero matrix.
    return np_utils.cond(
        math_ops.equal(array_ops.size(vec), 0),
        lambda: array_ops.zeros([abs(offset), abs(offset)], dtype=vec.dtype),
        lambda: array_ops.matrix_diag(vec, k=offset))

  def _extract_diagonal(mat, offset):
    mat_shape = array_ops.shape(mat)
    # Offsets outside the matrix select an empty diagonal.
    out_of_bounds = np_utils.logical_or(
        np_utils.less_equal(offset, -1 * np_utils.getitem(mat_shape, 0)),
        np_utils.greater_equal(offset, np_utils.getitem(mat_shape, 1)),
    )
    mat, offset = np_utils.cond(
        out_of_bounds,
        lambda: (array_ops.zeros([0, 0], dtype=mat.dtype), 0),
        lambda: (mat, offset))
    return array_ops.matrix_diag_part(mat, k=offset)

  return np_utils.cond(
      is_vector, lambda: _build_matrix(v, k), lambda: _extract_diagonal(v, k))
@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a)

  rank = a.shape.rank
  # Fast path: main diagonal of the two trailing axes.
  uses_last_two_axes = (
      rank is not None and
      (axis1 == rank - 2 or axis1 == -2) and
      (axis2 == rank - 1 or axis2 == -1))
  if rank is not None and offset == 0 and uses_last_two_axes:
    return array_ops.matrix_diag_part(a)

  a = moveaxis(a, (axis1, axis2), (-2, -1))
  a_shape = array_ops.shape(a)

  def _empty_diagonal():  # pylint: disable=missing-docstring
    empty_shape = array_ops.concat([a_shape[:-1], [0]], 0)
    return (array_ops.zeros(empty_shape, dtype=a.dtype), 0)

  # matrix_diag_part cannot handle every offset, so out-of-range offsets are
  # mapped to an all-zeros (empty) result. This goes through `cond` because
  # cond runs shape inference on both branches, and matrix_diag_part shape
  # inference would fail for an out-of-bounds offset.
  out_of_bounds = np_utils.logical_or(
      np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
      np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
  )
  a, offset = np_utils.cond(out_of_bounds, _empty_diagonal,
                            lambda: (a, offset))

  return array_ops.matrix_diag_part(a, k=offset)
@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
  # Flatten to 1-d so `diag` always builds a matrix.
  flat = array_ops.reshape(asarray(v), [-1])
  return diag(flat, k)
def _promote_dtype(*arrays):
  """Converts all `arrays` to their common result dtype.

  ndarrays that already have that dtype are returned as-is; everything else is
  converted without an extra copy.
  """
  common = np_utils.result_type(*arrays)
  converted = []
  for arr in arrays:
    if isinstance(arr, np_arrays.ndarray) and common == arr.dtype.as_numpy_dtype:
      converted.append(arr)
    else:
      converted.append(_array_internal(arr, dtype=common, copy=False))
  return converted
def _promote_dtype_binary(t1, t2):
  """Two-argument variant of `_promote_dtype`, returning a `(t1, t2)` tuple."""
  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access

  def _convert_if_needed(t):
    if isinstance(t, np_arrays.ndarray) and dtype == t.dtype.as_numpy_dtype:
      return t
    return _array_internal(t, dtype=dtype, copy=False)

  return _convert_if_needed(t1), _convert_if_needed(t2)
@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  # Interpret the input as booleans, then AND-reduce.
  bool_a = asarray(a, dtype=bool)
  return math_ops.reduce_all(input_tensor=bool_a, axis=axis, keepdims=keepdims)
@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  # Interpret the input as booleans, then OR-reduce.
  bool_a = asarray(a, dtype=bool)
  return math_ops.reduce_any(input_tensor=bool_a, axis=axis, keepdims=keepdims)
@np_utils.np_doc('compress')
def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
  condition = asarray(condition, dtype=bool)
  a = asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')
  # `np.compress` treats scalars as 1-d arrays.
  if a.ndim == 0:
    a = ravel(a)

  if axis is None:
    a = ravel(a)
    axis = 0

  requested_axis = axis
  if axis < 0:
    axis += a.ndim

  # Validate with an exception: an `assert` would be stripped under `-O`.
  if not 0 <= axis < a.ndim:
    raise ValueError('axis {} is out of bounds for array of dimension {}'
                     .format(requested_axis, a.ndim))

  # `tf.boolean_mask` requires the first dimensions of array and condition to
  # match. `np.compress` pads condition with False when it is shorter.
  condition_t = condition
  if condition.shape[0] < a.shape[axis]:
    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
    condition_t = array_ops.concat([condition_t, padding], axis=0)
  return array_ops.boolean_mask(tensor=a, mask=condition_t, axis=axis)
@np_utils.np_doc('copy')
def copy(a):
  # Force a copy regardless of the input type.
  return array(val=a, copy=True)
def _maybe_promote_to_int(a):
  """Widens narrow integer dtypes of `a`, as NumPy accumulations do.

  If `a` has an integer dtype and NumPy's promotion of that dtype with `int`
  is a different *integer* dtype, `a` is converted to it. Other inputs
  (including integers whose promotion would not stay integral) are returned
  unchanged.
  """
  if dtypes.as_dtype(a.dtype).is_integer:
    a_numpy_dtype = a.dtype.as_numpy_dtype
    output_type = np.promote_types(a_numpy_dtype, int)
    # `np.promote_types(np.uint64, int)` is float64. Never turn an integer
    # accumulation into a floating-point one, which would silently lose
    # precision for large values.
    if (output_type != a_numpy_dtype and
        np.issubdtype(output_type, np.integer)):
      a = asarray(a, dtype=output_type)

  return a
@np_utils.np_doc('cumprod')
def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)
  if dtype is None:
    # Match NumPy's widening of narrow integer types.
    a = _maybe_promote_to_int(a)

  if axis is None:
    # With no axis, NumPy accumulates over the flattened input.
    return math_ops.cumprod(ravel(a), 0)
  if axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumprod(a, axis)
@np_utils.np_doc('cumsum')
def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)
  if dtype is None:
    # Match NumPy's widening of narrow integer types.
    a = _maybe_promote_to_int(a)

  if axis is None:
    # With no axis, NumPy accumulates over the flattened input.
    return math_ops.cumsum(ravel(a), 0)
  if axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumsum(a, axis)
@np_utils.np_doc('imag')
def imag(val):
  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always
  # return an ndarray.
  return math_ops.imag(asarray(val))
# Modes for `_reduce`'s `promote_int` argument: `_TO_INT_` widens integer/bool
# inputs to an integer type, `_TO_FLOAT` casts them to a float type.
_TO_INT_, _TO_FLOAT = 0, 1
def _reduce(tf_fn,
            a,
            axis=None,
            dtype=None,
            keepdims=None,
            promote_int=_TO_INT_,
            tf_bool_fn=None,
            preserve_bool=False):
  """Shared implementation of the NumPy-style reductions.

  Args:
    tf_fn: the TF reduction function.
    a: the array to be reduced.
    axis: (optional) the axis to reduce along; None reduces all dimensions.
      A Tensor axis whose dtype is not int32/int64 is cast to int64.
    dtype: (optional) the dtype of the result.
    keepdims: (optional) whether to keep the reduced dimension(s). None means
      False.
    promote_int: how integer and bool inputs are promoted before reducing.
      `_TO_INT_` widens types narrower than np.int_ to np.int_ (signed ints and
      bools) or np.uint (unsigned ints); `_TO_FLOAT` casts them to
      `np_dtypes.default_float_type()`; None leaves them unchanged.
    tf_bool_fn: (optional) the TF reduction function for bool inputs. Used only
      when `dtype` is explicitly `np.bool_`, or when `a`'s dtype is `np.bool_`
      and `preserve_bool` is True.
    preserve_bool: whether bool inputs keep using `tf_bool_fn` (as np.max
      does) rather than being promoted (as np.sum does).

  Returns:
    An ndarray.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  keepdims = False if keepdims is None else keepdims
  a = asarray(a, dtype=dtype)

  wants_bool_reduction = tf_bool_fn is not None and (
      dtype == np.bool_ or (preserve_bool and a.dtype == np.bool_))
  if wants_bool_reduction:
    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)

  if dtype is None:
    dtype = a.dtype.as_numpy_dtype

  if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
    if promote_int == _TO_FLOAT:
      a = math_ops.cast(a, np_dtypes.default_float_type())
    elif promote_int == _TO_INT_:
      # NumPy up-casts integer/bool inputs narrower than np.int_ to np.int_
      # (see https://numpy.org/doc/1.18/reference/generated/numpy.sum.html).
      if dtype == np.bool_:
        # Any width below np.int_'s works for bools.
        is_signed, width = True, 8
      else:
        is_signed = np.issubdtype(dtype, np.signedinteger)
        width = np.iinfo(dtype).bits
      # np.int_ and np.uint ('long' / 'unsigned long') share a bit width.
      if width < np.iinfo(np.int_).bits:
        dtype = np.int_ if is_signed else np.uint
        a = math_ops.cast(a, dtype)

  needs_axis_cast = isinstance(axis, ops.Tensor) and axis.dtype not in (
      dtypes.int32, dtypes.int64)
  if needs_axis_cast:
    axis = math_ops.cast(axis, dtypes.int64)

  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
# TODO (DarrenZhang01): Add `axis` support to the `size` API.
@np_utils.np_doc('size')
def size(x, axis=None):  # pylint: disable=missing-docstring
  if axis is not None:
    raise NotImplementedError('axis argument is not supported in the current '
                              '`np.size` implementation')
  # Python and NumPy scalars of these types have exactly one element.
  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
    return 1
  x = asarray(x)
  if not x.shape.is_fully_defined():
    # The static shape is incomplete; compute the size as a tensor.
    return array_ops.size_v2(x)
  return np.prod(x.shape.as_list(), dtype=int)
@np_utils.np_doc('sum')
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
  # With an explicit bool dtype the reduction becomes a logical OR.
  return _reduce(math_ops.reduce_sum, a, axis=axis, dtype=dtype,
                 keepdims=keepdims, tf_bool_fn=math_ops.reduce_any)
@np_utils.np_doc('prod')
def prod(a, axis=None, dtype=None, keepdims=None):
  # With an explicit bool dtype the reduction becomes a logical AND.
  return _reduce(math_ops.reduce_prod, a, axis=axis, dtype=dtype,
                 keepdims=keepdims, tf_bool_fn=math_ops.reduce_all)
@np_utils.np_doc('mean', unsupported_params=['out'])
def mean(a, axis=None, dtype=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  # Integer and bool inputs are averaged in floating point.
  return _reduce(math_ops.reduce_mean, a, axis=axis, dtype=dtype,
                 keepdims=keepdims, promote_int=_TO_FLOAT)
@np_utils.np_doc('amax', unsupported_params=['out'])
def amax(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  # The input dtype is kept; bool inputs reduce with a logical OR.
  return _reduce(math_ops.reduce_max, a, axis=axis, dtype=None,
                 keepdims=keepdims, promote_int=None,
                 tf_bool_fn=math_ops.reduce_any, preserve_bool=True)
@np_utils.np_doc('amin', unsupported_params=['out'])
def amin(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  # The input dtype is kept; bool inputs reduce with a logical AND.
  return _reduce(math_ops.reduce_min, a, axis=axis, dtype=None,
                 keepdims=keepdims, promote_int=None,
                 tf_bool_fn=math_ops.reduce_all, preserve_bool=True)
@np_utils.np_doc('var')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
  if dtype:
    working_dtype = np_utils.result_type(a, dtype)
  else:
    working_dtype = None
  if out is not None:
    raise ValueError('Setting out is not supported.')
  if ddof != 0:
    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
    def reduce_fn(input_tensor, axis, keepdims):
      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
      centered = input_tensor - means
      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
        # |x - mean|^2 for complex inputs, kept in the input's dtype.
        centered = math_ops.cast(
            math_ops.real(centered * math_ops.conj(centered)),
            input_tensor.dtype)
      else:
        centered = math_ops.square(centered)
      squared_deviations = math_ops.reduce_sum(
          centered, axis=axis, keepdims=keepdims)

      # Number of reduced elements, minus the delta degrees of freedom.
      if axis is None:
        n = array_ops.size(input_tensor)
      else:
        if axis < 0:
          axis += array_ops.rank(input_tensor)
        n = math_ops.reduce_prod(
            array_ops.gather(array_ops.shape(input_tensor), axis))
      n = math_ops.cast(n - ddof, input_tensor.dtype)

      # Stay in the working dtype here. The requested `dtype` is applied once
      # below; it may be None, which is not a valid cast target.
      return math_ops.divide(squared_deviations, n)
  else:
    reduce_fn = math_ops.reduce_variance

  result = _reduce(
      reduce_fn,
      a,
      axis=axis,
      dtype=working_dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result
@np_utils.np_doc('std')
def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
  # Integer and bool inputs are handled in floating point.
  return _reduce(math_ops.reduce_std, a, axis=axis, dtype=None,
                 keepdims=keepdims, promote_int=_TO_FLOAT)
@np_utils.np_doc('ravel')
def ravel(a):  # pylint: disable=missing-docstring
  # Row-major flatten into a 1-d array.
  return array_ops.reshape(asarray(a), [-1])
@np_utils.np_doc('real')
def real(val):
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  return math_ops.real(asarray(val))
@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_shape = a._shape_as_list()  # pylint: disable=protected-access
  # Best effort recovery of the static output shape; only attempted when the
  # input shape is fully known.
  known_shape = original_shape is not None and None not in original_shape
  if known_shape:
    repeats_np = np.ravel(np.array(repeats))
    if not original_shape:
      # Scalar input: the output is 1-d with the (single or summed) repeat
      # count as its length. `repeats` may be e.g. `[3]`, so it must not be
      # used as a dimension directly.
      original_shape = (int(repeats_np.sum()),)
    elif repeats_np.size == 1:
      # One repeat count applied to every element.
      count = repeats_np.item()
      if axis is None:
        original_shape = (count * np.prod(original_shape),)
      else:
        original_shape[axis] = count * original_shape[axis]
    else:
      # One repeat count per element: the affected length is their sum.
      if axis is None:
        original_shape = (repeats_np.sum(),)
      else:
        original_shape[axis] = repeats_np.sum()

  repeats = asarray(repeats)
  result = array_ops.repeat(a, repeats, axis)
  if known_shape:
    result.set_shape(original_shape)

  return result
@np_utils.np_doc('around')
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_dtype = a.dtype.as_numpy_dtype
  scale = math.pow(10, decimals)
  if np.issubdtype(original_dtype, np.inexact):
    working_dtype = original_dtype
  else:
    # Exact (e.g. integer) inputs are rounded in floating point, because
    # `decimals` can be negative.
    working_dtype = np_dtypes.default_float_type()
    a = a.astype(working_dtype)
  scale = math_ops.cast(scale, working_dtype)
  rounded = math_ops.round(math_ops.multiply(a, scale))
  return math_ops.divide(rounded, scale).astype(original_dtype)
# Let the builtin `round(x, ndigits)` dispatch to `around` for ndarrays.
np_arrays.ndarray.__round__ = around
@np_utils.np_doc('reshape')
def reshape(a, newshape, order='C'):
  """The `order` argument can only be 'C' or 'F'."""
  if order not in {'C', 'F'}:
    raise ValueError('Unsupported order argument {}'.format(order))

  a = asarray(a)
  if isinstance(newshape, int):
    newshape = [newshape]

  if order == 'C':
    return array_ops.reshape(a, newshape)
  # Column-major: reverse the axes, reshape to the reversed target shape,
  # then reverse the axes back.
  return array_ops.transpose(
      array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
def _reshape_method_wrapper(a, *newshape, **kwargs):
  """Calls `reshape`, accepting the shape as separate ints or one argument."""
  order = kwargs.pop('order', 'C')
  if kwargs:
    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))

  # A single non-int positional argument is the whole shape.
  if len(newshape) == 1 and not isinstance(newshape[0], int):
    (newshape,) = newshape

  return reshape(a, newshape, order=order)
@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
  # Insert a size-1 dimension at `axis`.
  return array_ops.expand_dims(asarray(a), axis=axis)
@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
  # Drop size-1 dimensions (only those in `axis`, if given).
  return array_ops.squeeze(asarray(a), axis)
@np_utils.np_doc('flatten', link=np_utils.NoLink())
def flatten(a, order='C'):
  a = asarray(a)
  if order in ('C', 'A', 'K'):
    # Row major.
    return array_ops.reshape(a, [-1])
  if order == 'F':
    # Column major: flatten the transposed array.
    return array_ops.reshape(array_ops.transpose(a), [-1])
  raise ValueError('order can only be C, A, K (all row major) or F '
                   '(column major).')
@np_utils.np_doc('transpose')
def transpose(a, axes=None):
  # `axes=None` reverses the dimensions.
  perm = None if axes is None else asarray(axes)
  return array_ops.transpose(a=asarray(a), perm=perm)
@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)

  def adjust_axes(axes, rank):
    """Adds `rank` to every negative entry of `axes`."""

    def f(x):
      if isinstance(x, int):
        if x < 0:
          x = x + rank
      else:
        # Non-int axes (e.g. Tensors) are adjusted with ops. Use the `rank`
        # argument rather than the enclosing `a_rank`, so the helper does not
        # depend on a variable that happens to be bound in the outer scope.
        x = array_ops.where_v2(x < 0, np_utils.add(x, rank), x)
      return x

    return nest.map_structure(f, axes)

  if (a.shape.rank is not None and
      isinstance(axis1, int) and isinstance(axis2, int)):
    # This branch makes sure `perm` is statically known, to avoid a
    # not-compile-time-constant XLA error.
    a_rank = a.shape.rank
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = list(range(a_rank))
    perm[axis1] = axis2
    perm[axis2] = axis1
  else:
    a_rank = array_ops.rank(a)
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = math_ops.range(a_rank)
    perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
                                           [axis2, axis1])
  a = array_ops.transpose(a, perm)
  return a
@np_utils.np_doc('moveaxis')
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source and destination have different lengths."""
  # Convert first so every return path yields an ndarray, including the
  # nothing-to-move early return below.
  a = asarray(a)

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  # Checked after the int-to-tuple normalization, so the valid axis `0` is not
  # mistaken for "no axes".
  if not source and not destination:
    return a
  if len(source) != len(destination):
    raise ValueError('The lengths of source and destination must equal')

  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    # Static rank: build the permutation in Python. Start from the axes that
    # are not moved, then insert each moved axis at its destination.
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    # Dynamic rank: build the same permutation with tensor ops.
    r = math_ops.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = array_ops_stack.unstack(
          sort_ops.sort(array_ops_stack.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return array_ops.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = array_ops.scatter_nd(
        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = array_ops.tensor_scatter_update(
        perm, array_ops.expand_dims(destination, 1), source)
  a = array_ops.transpose(a, perm)

  return a
@np_utils.np_doc('pad')
def pad(array, pad_width, mode='constant', **kwargs):  # pylint: disable=redefined-outer-name
  """Only supports modes 'constant', 'reflect' and 'symmetric' currently.

  `mode` defaults to 'constant', as in NumPy. For 'constant' mode, the
  `constant_values` keyword (default 0) gives the fill value.
  """
  constant_values = kwargs.get('constant_values', 0)
  if mode not in ('constant', 'reflect', 'symmetric'):
    # `format` rather than `+`, so a non-string mode still yields this
    # ValueError instead of a TypeError from string concatenation.
    raise ValueError('Unsupported padding mode: {}'.format(mode))
  mode = mode.upper()
  array = asarray(array)
  pad_width = asarray(pad_width, dtype=dtypes.int32)
  return array_ops.pad(
      tensor=array,
      paddings=pad_width,
      mode=mode,
      constant_values=constant_values)
@np_utils.np_doc('take')
def take(a, indices, axis=None, out=None, mode='clip'):
  """out argument is not supported, and default mode is clip."""
  if out is not None:
    raise ValueError('out argument is not supported in take.')
  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    # Without an axis, index into the flattened array.
    a, axis = array_ops.reshape(a, [-1]), 0

  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
  if mode == 'raise':
    raise ValueError("The 'raise' mode to take is not supported.")
  if mode == 'clip':
    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
  else:  # mode == 'wrap'
    indices = math_ops.floormod(indices, axis_size)

  return array_ops.gather(a, indices, axis=axis)
@np_utils.np_doc_only('where')
def where(condition, x=None, y=None):
  """Raises ValueError if exactly one of x or y is not None."""
  condition = asarray(condition, dtype=np.bool_)
  if x is not None and y is not None:
    # Elementwise select after bringing x and y to a common dtype.
    x, y = _promote_dtype(x, y)
    return array_ops.where_v2(condition, x, y)
  if x is None and y is None:
    return nonzero(condition)
  raise ValueError('Both x and y must be ndarrays, or both must be None.')
@np_utils.np_doc('select')
def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
  if len(condlist) != len(choicelist):
    msg = 'condlist must have length equal to choicelist ({} vs {})'
    raise ValueError(msg.format(len(condlist), len(choicelist)))
  if not condlist:
    raise ValueError('condlist must be non-empty')
  output, *choicelist = _promote_dtype(default, *choicelist)
  # Apply the pairs last-to-first so that, where several conditions hold, the
  # earliest one in `condlist` wins.
  for cond, choice in reversed(list(zip(condlist, choicelist))):
    output = where(cond, choice, output)
  return output
@np_utils.np_doc('shape', link=np_utils.Link(
    'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'))
def shape(a):
  # Returns the (static) shape attribute of the converted array.
  return asarray(a).shape
@np_utils.np_doc('ndim', link=np_utils.NoLink())
def ndim(a):
  # Returns the `ndim` attribute of the converted array.
  return asarray(a).ndim
@np_utils.np_doc('isscalar')
def isscalar(num):
  # Scalars are exactly the 0-dimensional arrays.
  rank = ndim(num)
  return rank == 0
def _boundaries_to_sizes(a, boundaries, axis):
  """Converting boundaries of splits to sizes of splits.

  Args:
    a: the array to be split.
    boundaries: the boundaries, as in np.split.
    axis: the axis along which to split. Negative values count from the end.

  Returns:
    A list of sizes of the splits, as in tf.split. Boundaries past the end of
    the axis produce zero-sized splits.

  Raises:
    ValueError: if `axis` is out of range for `a`, or if a boundary is smaller
      than the previous one.
  """
  rank = len(a.shape)
  # Both ends are checked: an axis below -rank is just as invalid as one at or
  # above rank, and should raise the same ValueError.
  if axis >= rank or axis < -rank:
    raise ValueError('axis %s is out of bound for shape %s' % (axis, a.shape))
  total_size = a.shape[axis]
  sizes = []
  sizes_sum = 0
  prev = 0
  for i, b in enumerate(boundaries):
    size = b - prev
    if size < 0:
      raise ValueError('The %s-th boundary %s is smaller than the previous '
                       'boundary %s' % (i, b, prev))
    # Clamp so the sizes never add up to more than the axis length.
    size = min(size, max(0, total_size - sizes_sum))
    sizes.append(size)
    sizes_sum += size
    prev = b
  # Whatever remains after the last boundary forms the final split.
  sizes.append(max(0, total_size - sizes_sum))
  return sizes
@np_utils.np_doc('split')
def split(ary, indices_or_sections, axis=0):
  ary = asarray(ary)
  # `numbers.Integral` covers numpy integer scalars (e.g. np.int64) as well as
  # built-in ints; anything else is treated as a sequence of boundaries.
  if isinstance(indices_or_sections, numbers.Integral):
    indices_or_sections = int(indices_or_sections)
  else:
    indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
  return array_ops.split(ary, indices_or_sections, axis=axis)
def _split_on_axis(np_fun_name, axis):  # pylint: disable=missing-function-docstring

  @np_utils.np_doc(np_fun_name)
  def f(ary, indices_or_sections):
    # Convert first so that `ary.shape` below also works for lists and other
    # array-likes.
    ary = asarray(ary)
    # for 1-D array, hsplit becomes vsplit
    new_axis = np_utils.cond(
        math_ops.equal(axis, 1),
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            math_ops.equal(array_ops.rank(ary), 1), lambda: 0, lambda: axis
        ),
        lambda: axis,
    )
    # Same integer test as `split`: accept numpy integer scalars too.
    if isinstance(indices_or_sections, numbers.Integral):
      ary_shape = ary.shape[new_axis]
      # Only validate divisibility when the dimension is statically known.
      if ary_shape is not None and ary_shape % indices_or_sections:
        raise ValueError(
            'array split does not result in an equal division')
    return split(ary, indices_or_sections, axis=new_axis)

  return f
# numpy-style splitters bound to a fixed axis (hsplit falls back to axis 0 for
# 1-D input; see `_split_on_axis`).
vsplit, hsplit, dsplit = (
    _split_on_axis('vsplit', axis=0),
    _split_on_axis('hsplit', axis=1),
    _split_on_axis('dsplit', axis=2),
)
@np_utils.np_doc('broadcast_to')
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
  # Implemented via `full`, using `array` as the fill value for `shape`.
  broadcasted = full(shape, array)
  return broadcasted
@np_utils.np_doc('stack')
def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
  # A single array/tensor is a sequence of sub-arrays along its first
  # dimension. Stacking those sub-arrays along `axis` is the same as moving
  # dimension 0 to position `axis`. That is a moveaxis, not a swapaxes: the
  # two differ for rank >= 3 when `axis` is not 0 or 1.
  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
    arrays = asarray(arrays)
    if axis == 0:
      return arrays
    return moveaxis(arrays, 0, axis)
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  return asarray(array_ops_stack.stack(list(arrays), axis))
@np_utils.np_doc('hstack')
def hstack(tup):
  # Make every input at least 1-D, then promote all to a common dtype.
  arrays = [atleast_1d(a) for a in tup]
  if not arrays:
    # Fail clearly instead of with an IndexError on `arrays[0]` below.
    raise ValueError('need at least one array to concatenate')
  arrays = list(_promote_dtype(*arrays))  # pylint: disable=protected-access
  # 1-D inputs are joined along axis 0; higher-rank inputs along axis 1.
  rank = array_ops.rank(arrays[0])
  return np_utils.cond(
      math_ops.equal(rank, 1),
      lambda: array_ops.concat(arrays, axis=0),
      lambda: array_ops.concat(arrays, axis=1))
@np_utils.np_doc('vstack')
def vstack(tup):
  # Make every input at least 2-D, then join along the first axis.
  arrays = [atleast_2d(a) for a in tup]
  if not arrays:
    raise ValueError('need at least one array to concatenate')
  arrays = list(_promote_dtype(*arrays))  # pylint: disable=protected-access
  return array_ops.concat(arrays, axis=0)
@np_utils.np_doc('dstack')
def dstack(tup):
  # Make every input at least 3-D, then join along the third axis.
  arrays = [atleast_3d(a) for a in tup]
  if not arrays:
    raise ValueError('need at least one array to concatenate')
  arrays = list(_promote_dtype(*arrays))  # pylint: disable=protected-access
  return array_ops.concat(arrays, axis=2)
def _pad_left_to(n, old_shape):
  """Left-pads the 1-D shape vector `old_shape` with 1s up to length `n`.

  A vector that already has `n` or more entries is returned without padding.
  """
  shape_vec = asarray(old_shape, dtype=np.int32)
  num_missing = math_ops.maximum(n - array_ops.size(shape_vec), 0)
  padded = array_ops.pad(shape_vec, [[num_missing, 0]], constant_values=1)
  return asarray(padded)
def _atleast_nd(n, new_shape, *arys):
  """Reshapes each array in `arys` so that it has rank at least `n`.

  Args:
    n: The minimal rank.
    new_shape: callable `(n, old_shape) -> shape` giving the target shape for
      arrays whose rank is below `n`.
    *arys: ndarray(s) to be reshaped.

  Returns:
    The single reshaped array when one input is given, otherwise a list of
    reshaped arrays in input order.
  """

  def _raise_rank(x):
    x = asarray(x)
    needs_reshape = np_utils.greater(n, array_ops.rank(x))
    # pylint: disable=g-long-lambda
    return asarray(
        np_utils.cond(
            needs_reshape,
            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
            lambda: x))

  reshaped = [_raise_rank(x) for x in arys]
  return reshaped[0] if len(reshaped) == 1 else reshaped
@np_utils.np_doc('atleast_1d')
def atleast_1d(*arys):
  # Lower-rank inputs get leading size-1 axes (see `_pad_left_to`).
  target_rank = 1
  return _atleast_nd(target_rank, _pad_left_to, *arys)
@np_utils.np_doc('atleast_2d')
def atleast_2d(*arys):
  # Lower-rank inputs get leading size-1 axes (see `_pad_left_to`).
  target_rank = 2
  return _atleast_nd(target_rank, _pad_left_to, *arys)
@np_utils.np_doc('atleast_3d')
def atleast_3d(*arys):  # pylint: disable=missing-docstring

  def _to_3d_shape(_, old_shape):
    # () -> (1, 1, 1); (N,) -> (1, N, 1); (M, N) -> (M, N, 1).
    rank = array_ops.size(old_shape)

    def _pad_ones(before, after):
      return array_ops.pad(old_shape, [[before, after]], constant_values=1)

    # pylint: disable=g-long-lambda
    return np_utils.cond(
        math_ops.equal(rank, 0),
        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
        lambda: np_utils.cond(
            math_ops.equal(rank, 1),
            lambda: _pad_ones(1, 1),
            lambda: _pad_ones(0, 1)))

  return _atleast_nd(3, _to_3d_shape, *arys)
@np_utils.np_doc('nonzero')
def nonzero(a):
  a = atleast_1d(a)
  rank = a.shape.rank
  # One output array per dimension, so the rank must be static.
  if rank is None:
    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                     'arrays to return.')
  # where_v2 with only a condition yields a [num_true, rank] coordinate matrix;
  # unstacking its columns gives one index vector per dimension.
  coords = array_ops.where_v2(math_ops.cast(a, dtypes.bool))
  return array_ops_stack.unstack(coords, rank, axis=1)
@np_utils.np_doc('diag_indices')
def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
  if n < 0:
    raise ValueError(
        'n argument to diag_indices must be nonnegative, got {}'.format(n))
  if ndim < 0:
    raise ValueError(
        'ndim argument to diag_indices must be nonnegative, got {}'.format(
            ndim))
  # The same 0..n-1 vector indexes every one of the `ndim` axes.
  diagonal = math_ops.range(n)
  return tuple(diagonal for _ in range(ndim))
@np_utils.np_doc('tri')
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
  # Ones where col - row <= k, zeros elsewhere.
  if M is None:
    M = N
  dtype = (np_dtypes.default_float_type() if dtype is None
           else np_utils.result_type(dtype))

  if k >= 0:
    ones = array_ops.ones([N, M], dtype)
    if k > M:
      # Every column offset is already <= k, so nothing is masked out.
      return ones
    # Full lower triangle plus k superdiagonals.
    return array_ops.matrix_band_part(ones, -1, k)

  # k < 0: keep entries with row - col >= -k. Build the complementary band
  # (row - col <= -k - 1) as bool and invert it.
  num_lower = -k - 1
  if num_lower > N:
    return array_ops.zeros([N, M], dtype)
  bool_ones = array_ops.ones([N, M], dtype=dtypes.bool)
  band = array_ops.matrix_band_part(bool_ones, num_lower, -1)
  return math_ops.cast(math_ops.logical_not(band), dtype)
@np_utils.np_doc('tril')
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to tril should have known rank')
  m_dims = m.shape.as_list()
  if len(m_dims) < 2:
    raise ValueError('Argument to tril must have rank at least 2')
  num_rows, num_cols = m_dims[-2], m_dims[-1]
  if num_cols is None or num_rows is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')
  # Keep m where col - row <= k (the `tri` mask) and zero it elsewhere. The 2-D
  # mask is broadcast over any leading batch dimensions.
  keep = tri(num_rows, num_cols, k=k, dtype=bool)
  keep = array_ops.broadcast_to(keep, array_ops.shape(m))
  zero = constant_op.constant(0, m.dtype)
  return array_ops.where_v2(keep, m, zero)
@np_utils.np_doc('triu')
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to triu should have known rank')
  m_dims = m.shape.as_list()
  if len(m_dims) < 2:
    raise ValueError('Argument to triu must have rank at least 2')
  num_rows, num_cols = m_dims[-2], m_dims[-1]
  if num_cols is None or num_rows is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')
  # Zero out entries with col - row <= k - 1, i.e. keep col - row >= k. The 2-D
  # mask is broadcast over any leading batch dimensions.
  drop = tri(num_rows, num_cols, k=k - 1, dtype=bool)
  drop = array_ops.broadcast_to(drop, array_ops.shape(m))
  zero = constant_op.constant(0, m.dtype)
  return array_ops.where_v2(drop, zero, m)
@np_utils.np_doc('flip')
def flip(m, axis=None):  # pylint: disable=missing-docstring
  # `axis` may be None (reverse every axis), an int, or (as in numpy) a
  # tuple/list of ints.
  m = asarray(m)
  rank = array_ops.rank(m)

  if axis is None:
    return array_ops.reverse(m, math_ops.range(rank))

  if isinstance(axis, (tuple, list)):
    if not axis:
      # Nothing to flip (numpy returns the input unchanged); avoid passing an
      # empty, non-integer axis list to reverse.
      return m
    axes = np_utils._canonicalize_axes(axis, rank)  # pylint: disable=protected-access
    return array_ops.reverse(m, axes)

  axis = np_utils._canonicalize_axis(axis, rank)  # pylint: disable=protected-access
  return array_ops.reverse(m, [axis])
@np_utils.np_doc('flipud')
def flipud(m):  # pylint: disable=missing-docstring
  # Up/down means reversing the first axis.
  return flip(m, axis=0)
@np_utils.np_doc('fliplr')
def fliplr(m):  # pylint: disable=missing-docstring
  # Left/right means reversing the second axis.
  return flip(m, axis=1)
@np_utils.np_doc('roll')
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  if axis is None:
    # With no axis, roll the flattened array, then restore the original shape.
    original_shape = array_ops.shape(a)
    rolled_flat = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
    return array_ops.reshape(rolled_flat, original_shape)
  return manip_ops.roll(a, shift, axis)
@np_utils.np_doc('rot90')
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  # Convert up front so every branch (including k % 4 == 0) returns an ndarray
  # rather than echoing back a Python list or other array-like.
  m = asarray(m)
  m_rank = array_ops.rank(m)
  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  if k == 2:
    return flip(flip(m, ax1), ax2)
  # Odd k: swap the two axes with a transpose permutation, combined with a flip
  # of `ax2` (before the transpose for k == 1, after it for k == 3).
  perm = math_ops.range(m_rank)
  perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])
  if k == 1:
    return transpose(flip(m, ax2), perm)
  return flip(transpose(m, perm), ax2)
@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
  x = asarray(x)

  # Validate the rank first: `x_shape[0]` below only makes sense for 1-D x.
  rank = array_ops.rank(x)
  rank_temp = np_utils.get_static_value(rank)
  if rank_temp is not None:
    rank = rank_temp
    if rank != 1:
      raise ValueError('x must be a one-dimensional array')
  else:
    control_flow_assert.Assert(math_ops.equal(rank, 1), [rank])

  x_shape = array_ops.shape(x)
  # Default to a square matrix. Test for None explicitly: `N or ...` would
  # silently replace an explicit N=0, which must yield zero columns.
  if N is None:
    N = x_shape[0]

  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
  if N_temp is not None:
    N = N_temp
    if N < 0:
      raise ValueError('N must be nonnegative')
  else:
    control_flow_assert.Assert(N >= 0, [N])

  # Column j holds x ** p_j, with powers N-1..0 by default or 0..N-1 when
  # `increasing`.
  if increasing:
    start, limit, delta = 0, N, 1
  else:
    start, limit, delta = N - 1, -1, -1

  x = array_ops.expand_dims(x, -1)
  return math_ops.pow(
      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))
@np_utils.np_doc('ix_')
def ix_(*args):  # pylint: disable=missing-docstring
  num_args = len(args)
  outputs = []
  for position, arg in enumerate(args):
    arg = asarray(arg)
    rank = array_ops.rank(arg)
    static_rank = np_utils.get_static_value(rank)
    if static_rank is None:
      control_flow_assert.Assert(math_ops.equal(rank, 1), [rank])
    elif static_rank != 1:
      raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
          position, static_rank))

    # Output `position` keeps its length along axis `position` and has size 1
    # on every other axis, so the outputs broadcast into an open mesh.
    target_shape = [1] * num_args
    target_shape[position] = -1
    dtype = arg.dtype
    if dtype == dtypes.bool:
      # Boolean masks contribute the positions where they are True.
      outputs.append(array_ops.reshape(nonzero(arg)[0], target_shape))
    elif dtype.is_integer:
      outputs.append(array_ops.reshape(arg, target_shape))
    else:
      raise ValueError(
          'Only integer and bool dtypes are supported, got {}'.format(dtype))

  return outputs
@np_utils.np_doc('broadcast_arrays')
def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
  # Only the numpy default subok=False is supported; other keywords are errors.
  if kwargs.pop('subok', False):
    raise ValueError('subok=True is not supported.')
  if kwargs:
    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))

  return np_utils.tf_broadcast(*(asarray(arg) for arg in args))
@np_utils.np_doc_only('sign')
def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
  # Compare with None rather than truth-testing: `if out:` is ambiguous (and
  # raises) for array arguments.
  if out is not None:
    raise ValueError('tf.numpy doesnt support setting out.')
  # `where=True` is numpy's default (no masking), so accept it. Any other
  # explicit value is rejected.
  where_is_default = where is None or (
      isinstance(where, (bool, np.bool_)) and bool(where))
  if not where_is_default:
    raise ValueError('tf.numpy doesnt support setting where.')
  if kwargs:
    raise ValueError('tf.numpy doesnt support setting {}'.format(kwargs.keys()))

  x = asarray(x)
  dtype = x.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.complexfloating):
    # Complex inputs: the sign of the real part, cast back to the input dtype.
    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
  else:
    result = math_ops.sign(x)

  return result
# Note that np.take_along_axis may not be present in some supported versions of
# numpy.
@np_utils.np_doc('take_along_axis')
def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  # Gathers from `arr` along `axis` at the positions in `indices`. The other
  # dimensions of `arr` and `indices` are broadcast against each other.
  arr = asarray(arr)
  indices = asarray(indices)

  if axis is None:
    # axis=None: operate on the flattened array.
    return take_along_axis(arr.ravel(), indices, 0)

  rank = array_ops.rank(arr)
  # Normalize a negative axis. The result may be a tensor, because `rank` is.
  axis = axis + rank if axis < 0 else axis

  # Broadcast shapes to match, ensure that the axis of interest is not
  # broadcast.
  arr_shape_original = array_ops.shape(arr)
  indices_shape_original = array_ops.shape(indices)
  # Set the `axis` entry to 1 in both shapes so that only the other dimensions
  # take part in broadcasting...
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
                                                  [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
                                                        indices_shape)
  # ...then restore each operand's own size along `axis`.
  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
                                              [arr_shape_original[axis]])
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
  arr = array_ops.broadcast_to(arr, arr_shape)
  indices = array_ops.broadcast_to(indices, indices_shape)

  # Save indices shape so we can restore it later.
  possible_result_shape = indices.shape

  # Correct indices since gather doesn't correctly handle negative indices.
  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)

  swapaxes_ = lambda t: swapaxes(t, axis, -1)

  # Move `axis` to the last position, unless it is already last, so the gather
  # below can work on 2-D [batch, axis] views.
  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
                      lambda: swapaxes_(arr))
  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
                          lambda: swapaxes_(indices))

  # Collapse all leading dimensions into one batch dimension. After the
  # broadcast above, `arr` and `indices` agree on those leading dimensions.
  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

  indices_shape = array_ops.shape(indices)
  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])

  # With batch_dims=1, row i of `indices` gathers from row i of `arr`.
  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  # Undo the axis move and reattach the static shape saved earlier.
  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)

  return result
1481_SLICE_ERORR = (
1482 'only integers, slices (`:`), ellipsis (`...`), '
1483 'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')
def _as_index(idx, need_scalar=True):
  """Parses `idx` into an integer index usable by TF slicing ops.

  Args:
    idx: the index: a Python integer/Dimension, or anything `asarray` accepts
      (integer arrays, or rank-1 boolean masks).
    need_scalar: if True, reject indices whose static rank is known and
      non-zero.

  Returns:
    A pair `(index, is_scalar)`. `index` is either `idx` itself (for Python
    integers / Dimensions) or an int32/int64 tensor. `is_scalar` is True when
    the rank is known to be 0.

  Raises:
    IndexError: For incorrect indices.
    NotImplementedError: For boolean masks whose rank is not 1.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True

  data = asarray(idx)
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    # A 1-D mask becomes the vector of positions where it is True.
    data = array_ops.reshape(array_ops.where_v2(data), [-1])

  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))

  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing only takes int32/int64. Cast to whichever of the two the
    # original dtype promotes to together with int32.
    promoted = np.promote_types(np.int32, np_dtype)
    if promoted == np.int32:
      target_dtype = dtypes.int32
    elif promoted == np.int64:
      target_dtype = dtypes.int64
    else:
      raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
    data = math_ops.cast(data, target_dtype)
  return data, data.shape.rank == 0
1526class _UpdateMethod(enum.Enum):
1527 UPDATE = 0
1528 ADD = 1
1529 MIN = 2
1530 MAX = 3
def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
  """Helper function for __getitem__ and _with_index_update_helper.

  This function collects the indices in `slice_spec` into two buckets, which we
  can call "idx1" and "idx2" here. idx1 is intended for `strided_slice`, idx2
  for `gather`. They also correspond to "basic indices" and "advanced indices"
  in numpy. This function supports both reading and writing at the indices. The
  reading path can be summarized as `gather(stride_slice(tensor, idx1),
  idx2)`. The writing path can be summarized as `strided_slice_update(tensor,
  idx1, scatter(strided_slice(tensor, idx1), idx2, updates))`. (`gather` here
  means `tf.gather` or `tf.gather_nd`; `scatter` here means
  `tf.tensor_scatter_update`.) The writing path is inefficient because it needs
  to first read out a portion (probably much larger than `updates`) of `tensor`
  using `strided_slice`, update it, and then write the portion back. An
  alternative approach is to only use `scatter`, which amounts to using the
  indexing mechanism of gather/scatter to implement
  strided_slice/strided_slice_update. This is feasible for XLA Gather/Scatter
  because they support spans (e.g. `2:5`) in indices (as begin/end pairs), but
  not TF gather/scatter because they don't support spans (except those that
  cover entire dimensions, i.e. `:`). If we materialize spans into individual
  indices, the size of the index tensor would explode. (Note that XLA
  Gather/Scatter have a similar problem for stride > 1 because they don't
  support strides. Indices such as `1:8:2` will need to be materialized into
  individual indices such as [1, 3, 5, 7].)

  Args:
    tensor: the tensor to be read from or write into.
    slice_spec: the indices.
    update_method: (optional) a member of `_UpdateMethod`, indicating how to
      update the values (replacement, add, etc.). `None` indicates just reading.
    updates: (optional) the new values to write into `tensor`. It must have the
      same dtype as `tensor`.

  Returns:
    The result of reading (if `update_method` is `None`) or the updated `tensor`
    after writing.

  Raises:
    NotImplementedError: when writing through contiguous advanced indices whose
      stacked rank is statically unknown.
  """
  # Arguments for `strided_slice`, one entry per element of `slice_spec`.
  begin, end, strides = [], [], []
  # Bit `index` of each mask flags element `index` of `slice_spec` as a
  # newaxis / scalar (axis-removing) index / open start / open end / ellipsis.
  new_axis_mask, shrink_axis_mask = 0, 0
  begin_mask, end_mask = 0, 0
  ellipsis_mask = 0
  # Tuples of (position in slice_spec, index tensor, whether an Ellipsis
  # appeared before it).
  advanced_indices = []
  # Positions of scalar integer indices; strided_slice removes those axes.
  shrink_indices = []
  for index, s in enumerate(slice_spec):
    if isinstance(s, slice):
      if s.start is not None:
        begin.append(_as_index(s.start)[0])
      else:
        begin.append(0)
        begin_mask |= (1 << index)
      if s.stop is not None:
        end.append(_as_index(s.stop)[0])
      else:
        end.append(0)
        end_mask |= (1 << index)
      if s.step is not None:
        strides.append(_as_index(s.step)[0])
      else:
        strides.append(1)
    elif s is Ellipsis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      ellipsis_mask |= (1 << index)
    elif s is array_ops.newaxis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      new_axis_mask |= (1 << index)
    else:
      s, is_scalar = _as_index(s, False)
      if is_scalar:
        begin.append(s)
        end.append(s + 1)
        strides.append(1)
        shrink_axis_mask |= (1 << index)
        shrink_indices.append(index)
      else:
        # Advanced (array) index: keep the whole axis in the strided slice
        # and apply this index afterwards with gather/scatter.
        begin.append(0)
        end.append(0)
        strides.append(1)
        begin_mask |= (1 << index)
        end_mask |= (1 << index)
        advanced_indices.append((index, s, ellipsis_mask != 0))

  # stack possibly involves no tensors, so we must use op_scope correct graph.
  with ops.name_scope(
      None,
      'strided_slice', [tensor] + begin + end + strides,
      skip_on_eager=False) as name:
    if begin:
      packed_begin, packed_end, packed_strides = (
          array_ops_stack.stack(begin),
          array_ops_stack.stack(end),
          array_ops_stack.stack(strides))
      # If any of the three is int64, cast the others so all dtypes agree.
      if (packed_begin.dtype == dtypes.int64 or
          packed_end.dtype == dtypes.int64 or
          packed_strides.dtype == dtypes.int64):
        if packed_begin.dtype != dtypes.int64:
          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
        if packed_end.dtype != dtypes.int64:
          packed_end = math_ops.cast(packed_end, dtypes.int64)
        if packed_strides.dtype != dtypes.int64:
          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
    else:
      var_empty = constant_op.constant([], dtype=dtypes.int32)
      packed_begin = packed_end = packed_strides = var_empty
    if update_method == _UpdateMethod.UPDATE and not advanced_indices:
      # Fast path: plain replacement through basic indices is a single
      # strided-slice update.
      return array_ops.tensor_strided_slice_update(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          updates,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name)
    else:
      # TODO(b/164251540): Find a better way to support update that does not
      # involve one read + two writes.
      if updates is not None:
        original_tensor = tensor
      # TODO(agarwal): set_shape on tensor to set rank.
      # Read the basic-indexed portion; advanced indices are applied to it.
      tensor = array_ops.strided_slice(
          tensor,
          packed_begin,
          packed_end,
          packed_strides,
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name)
    if not advanced_indices:
      if update_method is None:
        return tensor
      assert update_method != _UpdateMethod.UPDATE
      # TF lacks TensorStridedSliceAdd and alike, so we need to do
      # read+add+update.
      if update_method == _UpdateMethod.ADD:
        update_op = math_ops.add
      elif update_method == _UpdateMethod.MIN:
        update_op = math_ops.minimum
      elif update_method == _UpdateMethod.MAX:
        update_op = math_ops.maximum
      return array_ops.tensor_strided_slice_update(
          original_tensor,
          packed_begin,
          packed_end,
          packed_strides,
          update_op(tensor, updates),
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name + '_2')
    # Map each advanced index to the axis it addresses in the strided-sliced
    # `tensor`. Axes removed by scalar indices are discounted. Indices that
    # follow an Ellipsis are located from the end, so they get negative dims.
    advanced_indices_map = {}
    for index, data, had_ellipsis in advanced_indices:
      if had_ellipsis:
        num_shrink = len([x for x in shrink_indices if x > index])
        dim = index - len(slice_spec) + num_shrink
      else:
        num_shrink = len([x for x in shrink_indices if x < index])
        dim = index - num_shrink
      advanced_indices_map[dim] = data
    dims = sorted(advanced_indices_map.keys())
    # numpy keeps the advanced-index result dims in place only when the indexed
    # axes are adjacent; otherwise those dims move to the front.
    dims_contiguous = True
    if len(dims) > 1:
      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
        dims_contiguous = False
      else:
        for i in range(len(dims) - 1):
          if dims[i] + 1 != dims[i + 1]:
            dims_contiguous = False
            break
    # Bring all advanced indices to one dtype and shape, then stack them into
    # index tuples along a new last axis.
    indices = [advanced_indices_map[x] for x in dims]
    indices = _promote_dtype(*indices)
    indices = np_utils.tf_broadcast(*indices)
    stacked_indices = array_ops_stack.stack(indices, axis=-1)
    # Skip the contiguous-dims optimization for update because there is no
    # tf.*scatter* op that supports the `axis` argument.
    if not dims_contiguous or updates is not None:
      # Move the indexed axes to the front so gather_nd/scatter address them
      # as leading dimensions.
      # NOTE(review): a `range` never compares equal to a list in Python 3, so
      # this condition is always True and moveaxis always runs. That is
      # harmless, but it is wasted work when dims is already 0..len(dims)-1.
      if range(len(dims)) != dims:
        tensor = moveaxis(tensor, dims, range(len(dims)))
      # Wrap negative indices using the sizes of the (now leading) axes.
      tensor_shape_prefix = array_ops.shape(
          tensor, out_type=stacked_indices.dtype)[:len(dims)]
      stacked_indices = array_ops.where_v2(
          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
          stacked_indices)
      if updates is None:
        return array_ops.gather_nd(tensor, stacked_indices)
      else:
        # We only need to move-axis `updates` in the contiguous case because
        # only in this case the result dimensions of advanced indexing are in
        # the middle of `updates`. In the non-contiguous case, those dimensions
        # are always at the front.
        if dims_contiguous:
          # TODO(wangpeng): Support unknown rank (e.g. by partially flattening
          # `updates`)
          if stacked_indices.shape.rank is None:
            raise NotImplementedError(
                'Rank of the advanced indices must currently be known')
          batch_size = stacked_indices.shape.rank - 1
          batch_start = dims[0]
          if batch_start < 0:
            batch_start += len(dims) - batch_size
          def range_(start, length):
            return range(start, start + length)
          updates = moveaxis(updates, range_(batch_start, batch_size),
                             range(batch_size))
        if update_method == _UpdateMethod.UPDATE:
          update_op = array_ops.tensor_scatter_update
        elif update_method == _UpdateMethod.ADD:
          update_op = array_ops.tensor_scatter_add
        elif update_method == _UpdateMethod.MIN:
          update_op = array_ops.tensor_scatter_min
        elif update_method == _UpdateMethod.MAX:
          update_op = array_ops.tensor_scatter_max
        tensor = update_op(
            tensor, stacked_indices, updates)
        # Undo the axis move, then write the slice back into the original.
        if range(len(dims)) != dims:
          tensor = moveaxis(tensor, range(len(dims)), dims)
        return array_ops.tensor_strided_slice_update(
            original_tensor,
            packed_begin,
            packed_end,
            packed_strides,
            tensor,
            begin_mask=begin_mask,
            end_mask=end_mask,
            shrink_axis_mask=shrink_axis_mask,
            new_axis_mask=new_axis_mask,
            ellipsis_mask=ellipsis_mask,
            name=name + '_2')
    # Note that gather_nd does not support gathering from inside the array.
    # To avoid shuffling data back and forth, we transform the indices and
    # do a gather instead.
    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
    dims = [(x + rank if x < 0 else x) for x in dims]
    shape_tensor = array_ops.shape(tensor)
    dim_sizes = array_ops.gather(shape_tensor, dims)
    if len(dims) == 1:
      # A single advanced index is used as-is rather than as 1-element tuples.
      stacked_indices = indices[0]
    stacked_indices = math_ops.cast(stacked_indices, dtypes.int32)
    stacked_indices = array_ops.where_v2(stacked_indices < 0,
                                         stacked_indices + dim_sizes,
                                         stacked_indices)
    axis = dims[0]
    if len(dims) > 1:
      # Linearize index tuples into row-major flat offsets over the indexed
      # axes. An exclusive reverse cumprod of their sizes gives the strides.
      index_scaling = math_ops.cumprod(
          dim_sizes, reverse=True, exclusive=True)
      def _tensordot(a, b):
        # TODO(b/168657656): This function should be replaced by
        # tensordot(axis=1) once MatMul has int32 XLA kernel.
        b = array_ops.broadcast_to(b, array_ops.shape(a))
        return math_ops.reduce_sum(a * b, axis=-1)
      stacked_indices = _tensordot(stacked_indices, index_scaling)
      # Merge the indexed axes into one flat axis to match the flat offsets.
      flat_shape = array_ops.concat(
          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
          axis=0)
      tensor = array_ops.reshape(tensor, flat_shape)

    return array_ops.gather(tensor, stacked_indices, axis=axis)
1803def _as_spec_tuple(slice_spec):
1804 """Convert slice_spec to tuple."""
1805 if isinstance(slice_spec,
1806 (list, tuple)) and not isinstance(slice_spec, np.ndarray):
1807 is_index = True
1808 for s in slice_spec:
1809 if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
1810 is_index = False
1811 break
1812 elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
1813 is_index = False
1814 break
1815 if not is_index:
1816 return tuple(slice_spec)
1817 return (slice_spec,)
def _getitem(self, slice_spec):
  """Implementation of ndarray.__getitem__."""
  is_bool_mask = (
      isinstance(slice_spec, bool)
      or (isinstance(slice_spec, ops.Tensor) and
          slice_spec.dtype == dtypes.bool)
      or (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
          slice_spec.dtype == np.bool_))
  if is_bool_mask:
    # Boolean masks select elements directly.
    return array_ops.boolean_mask(tensor=self, mask=slice_spec)

  spec = (slice_spec if isinstance(slice_spec, tuple)
          else _as_spec_tuple(slice_spec))
  return _slice_helper(self, spec)
def _with_index_update_helper(update_method, a, slice_spec, updates):
  """Implementation of ndarray._with_index_*."""
  is_bool_mask = (
      isinstance(slice_spec, bool)
      or (isinstance(slice_spec, ops.Tensor) and
          slice_spec.dtype == dtypes.bool)
      or (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
          slice_spec.dtype == np.bool_))
  if is_bool_mask:
    # Boolean masks become per-dimension integer index arrays.
    slice_spec = nonzero(slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  original_dtype = a.dtype
  a, updates = _promote_dtype_binary(a, updates)
  updated = _slice_helper(a, slice_spec, update_method, updates)
  # Cast back so the result keeps the dtype `a` had before promotion.
  return updated.astype(original_dtype)
setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
# Each `_with_index_*` method is `_with_index_update_helper` with its update
# mode pre-bound.
for _attr_name, _update_mode in (
    ('_with_index_update', _UpdateMethod.UPDATE),
    ('_with_index_add', _UpdateMethod.ADD),
    ('_with_index_min', _UpdateMethod.MIN),
    ('_with_index_max', _UpdateMethod.MAX),
):
  setattr(np_arrays.ndarray, _attr_name,
          functools.partial(_with_index_update_helper, _update_mode))
# Keep the loop variables out of the module namespace.
del _attr_name, _update_mode