# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common array methods."""
# pylint: disable=g-direct-tensorflow-import

import enum
import functools
import math
import numbers
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.util import nest


newaxis = np_export.np_export_constant(__name__, 'newaxis', np.newaxis)


@np_utils.np_doc('empty')
def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
  return zeros(shape, dtype)


@np_utils.np_doc('empty_like')
def empty_like(a, dtype=None):
  return zeros_like(a, dtype)


@np_utils.np_doc('zeros')
def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
  return array_ops.zeros(shape, dtype=dtype)


@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
  dtype = np_utils.result_type_unary(a, dtype)

  dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
  return array_ops.zeros_like(a, dtype)


@np_utils.np_doc('ones')
def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
  if dtype:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones(shape, dtype=dtype)


@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
  dtype = np_utils.result_type_unary(a, dtype)
  return array_ops.ones_like(a, dtype)


@np_utils.np_doc('eye')
def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  if not M:
    M = N
  # Making sure N, M and k are `int`
  N = int(N)
  M = int(M)
  k = int(k)
  if k >= M or -k >= N:
    # tf.linalg.diag will raise an error in this case
    return zeros([N, M], dtype=dtype)
  if k == 0:
    return linalg_ops.eye(N, M, dtype=dtype)
  # We need the precise length, otherwise tf.linalg.diag will raise an error
  diag_len = min(N, M)
  if k > 0:
    if N >= M:
      diag_len -= k
    elif N + k > M:
      diag_len = M - k
  elif k <= 0:
    if M >= N:
      diag_len += k
    elif M - k > N:
      diag_len = N + k
  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)
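
# Editor's note: an illustrative sketch of the `diag_len` arithmetic above,
# kept as comments so module import behavior is unchanged. The values follow
# from the branch logic:
#   eye(4, 3, k=1)  -> ones at (0, 1), (1, 2) only, so diag_len = 3 - 1 = 2
#   eye(3, 4, k=-2) -> ones at (2, 0) only, so diag_len = 3 + (-2) = 1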


@np_utils.np_doc('identity')
def identity(n, dtype=float):
  return eye(N=n, M=n, dtype=dtype)


@np_utils.np_doc('full')
def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
  if not isinstance(shape, np_arrays.ndarray):
    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
  shape = atleast_1d(shape)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, shape)


# Using doc only here since np full_like signature doesn't seem to have the
# shape argument (even though it exists in the documentation online).
@np_utils.np_doc_only('full_like')
def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
  """order, subok and shape arguments mustn't be changed."""
  if order != 'K':
    raise ValueError('Non-standard orders are not supported.')
  if not subok:
    raise ValueError('subok being False is not supported.')
  if shape:
    raise ValueError('Overriding the shape is not supported.')

  a = asarray(a)
  dtype = dtype or np_utils.result_type(a)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, array_ops.shape(a))


def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Main implementation of np.array()."""
  result_t = val

  if not isinstance(result_t, ops.Tensor):
    dtype = np_utils.result_type_unary(result_t, dtype)
    # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
    # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
    # while np.array allows them. We need to convert-then-cast.

    # EagerTensor conversion complains about "mixed types" when converting
    # tensors with no dtype information. This is because it infers types based
    # on one selected item in the list. So e.g. when converting [2., 2j]
    # to a tensor, it will select float32 as the inferred type and not be able
    # to convert the list to a float32 tensor.
    # Since we have some information about the final dtype we care about, we
    # supply that information so that convert_to_tensor will do best-effort
    # conversion to that dtype first.
    result_t = np_arrays.convert_to_tensor(result_t, dtype_hint=dtype)
    result_t = math_ops.cast(result_t, dtype=dtype)
  elif dtype:
    result_t = math_ops.cast(result_t, dtype)

  if copy:
    result_t = array_ops.identity(result_t)

  max_ndmin = 32
  if ndmin > max_ndmin:
    raise ValueError('ndmin bigger than allowable number of dimensions: '
                     f'{max_ndmin}.')

  if ndmin == 0:
    return result_t

  ndims = array_ops.rank(result_t)

  def true_fn():
    old_shape = array_ops.shape(result_t)
    new_shape = array_ops.concat(
        [array_ops.ones(ndmin - ndims, dtypes.int32), old_shape], axis=0)
    return array_ops.reshape(result_t, new_shape)

  result_t = np_utils.cond(
      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
  return result_t
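
# Editor's note: an illustrative sketch (comments only) of the `ndmin`
# padding above. Missing leading dimensions are filled with 1s:
#   _array_internal([1, 2, 3], ndmin=3)  -> shape (1, 1, 3)
#   _array_internal(5, ndmin=2)          -> shape (1, 1)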


# TODO(wangpeng): investigate whether we can make `copy` default to False.
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@np_utils.np_doc_only('array')
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Since Tensors are immutable, a copy is made only if val is placed on a

  different device than the current one. Even if `copy` is False, a new Tensor
  may need to be built to satisfy `dtype` and `ndmin`. This is used only if
  `val` is an ndarray or a Tensor.
  """  # pylint:disable=g-docstring-missing-newline
  if dtype:
    dtype = np_utils.result_type(dtype)
  return _array_internal(val, dtype, copy, ndmin)


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args


@np_utils.np_doc('asarray')
def asarray(a, dtype=None):
  if dtype:
    dtype = np_utils.result_type(dtype)
  if isinstance(a, np_arrays.ndarray) and (
      not dtype or dtype == a.dtype.as_numpy_dtype):
    return a
  return array(a, dtype, copy=False)


@np_utils.np_doc('asanyarray')
def asanyarray(a, dtype=None):
  return asarray(a, dtype)


@np_utils.np_doc('ascontiguousarray')
def ascontiguousarray(a, dtype=None):
  return array(a, dtype, ndmin=1)


# Numerical ranges.
@np_utils.np_doc('arange')
def arange(start, stop=None, step=1, dtype=None):
  """Returns `step`-separated values in the range [start, stop).

  Args:
    start: Start of the interval. Included in the range.
    stop: End of the interval. If not specified, `start` is treated as 0 and
      `start` value is used as `stop`. If specified, it is not included in the
      range if `step` is integer. When `step` is floating point, it may or may
      not be included.
    step: The difference between 2 consecutive values in the output range. It
      is recommended to use `linspace` instead of using non-integer values for
      `step`.
    dtype: Optional. Type of the resulting ndarray. Could be a python type, a
      NumPy type or a TensorFlow `DType`. If not provided, the largest type of
      `start`, `stop`, `step` is used.

  Raises:
    ValueError: If step is zero.
  """
  if not step:
    raise ValueError('step must be non-zero.')
  if dtype:
    dtype = np_utils.result_type(dtype)
  else:
    if stop is None:
      dtype = np_utils.result_type(start, step)
    else:
      dtype = np_utils.result_type(start, step, stop)
  if step > 0 and ((stop is not None and start > stop) or
                   (stop is None and start < 0)):
    return array([], dtype=dtype)
  if step < 0 and ((stop is not None and start < stop) or
                   (stop is None and start > 0)):
    return array([], dtype=dtype)
  # TODO(srbs): There are some bugs when start or stop is float type and dtype
  # is integer type.
  return math_ops.cast(
      math_ops.range(start, limit=stop, delta=step), dtype=dtype)
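
# Editor's note: illustrative usage (comments only), following the docstring
# above:
#   arange(5)        -> [0, 1, 2, 3, 4]
#   arange(3, 10, 2) -> [3, 5, 7, 9]    (stop is excluded)
#   arange(5, 2)     -> []              (empty: step > 0 but start > stop)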


# Building matrices.
@np_utils.np_doc('diag')
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Raises an error if input is not 1- or 2-d."""
  v = asarray(v)
  v_rank = array_ops.rank(v)

  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
  # tracing time if the shape is known.
  control_flow_assert.Assert(
      np_utils.logical_or(math_ops.equal(v_rank, 1), math_ops.equal(v_rank, 2)),
      [v_rank])

  def _diag(v, k):
    return np_utils.cond(
        math_ops.equal(array_ops.size(v), 0),
        lambda: array_ops.zeros([abs(k), abs(k)], dtype=v.dtype),
        lambda: array_ops.matrix_diag(v, k=k))

  def _diag_part(v, k):
    v_shape = array_ops.shape(v)
    v, k = np_utils.cond(
        np_utils.logical_or(
            np_utils.less_equal(k, -1 * np_utils.getitem(v_shape, 0)),
            np_utils.greater_equal(k, np_utils.getitem(v_shape, 1)),
        ), lambda: (array_ops.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k))
    result = array_ops.matrix_diag_part(v, k=k)
    return result

  result = np_utils.cond(
      math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
  return result
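
# Editor's note: illustrative behavior (comments only). Like np.diag, a 1-d
# input builds a matrix and a 2-d input extracts a diagonal:
#   diag([1, 2, 3])             -> [[1, 0, 0], [0, 2, 0], [0, 0, 3]]
#   diag([[1, 2], [3, 4]], k=1) -> [2]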


@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a)

  maybe_rank = a.shape.rank
  if maybe_rank is not None and offset == 0 and (
      axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
                                                   axis2 == -1):
    return array_ops.matrix_diag_part(a)

  a = moveaxis(a, (axis1, axis2), (-2, -1))

  a_shape = array_ops.shape(a)

  def _zeros():  # pylint: disable=missing-docstring
    return (array_ops.zeros(
        array_ops.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0)

  # All zeros since diag_part doesn't handle all possible k (aka offset).
  # Written this way since cond will run shape inference on both branches,
  # and diag_part shape inference will fail when offset is out of bounds.
  a, offset = np_utils.cond(
      np_utils.logical_or(
          np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
          np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
      ), _zeros, lambda: (a, offset))

  a = array_ops.matrix_diag_part(a, k=offset)
  return a


@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
  v = asarray(v)
  return diag(array_ops.reshape(v, [-1]), k)


def _promote_dtype(*arrays):
  dtype = np_utils.result_type(*arrays)
  def _fast_asarray(a):
    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
      return a
    return _array_internal(a, dtype=dtype, copy=False)
  return [_fast_asarray(a) for a in arrays]


def _promote_dtype_binary(t1, t2):
  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
  if not (
      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
    t1 = _array_internal(t1, dtype=dtype, copy=False)
  if not (
      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
    t2 = _array_internal(t2, dtype=dtype, copy=False)
  return t1, t2


@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('compress')
def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
  condition = asarray(condition, dtype=bool)
  a = asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')
  # `np.compress` treats scalars as 1-d arrays.
  if a.ndim == 0:
    a = ravel(a)

  if axis is None:
    a = ravel(a)
    axis = 0

  if axis < 0:
    axis += a.ndim

  assert axis >= 0 and axis < a.ndim

  # `tf.boolean_mask` requires the first dimensions of array and condition to
  # match. `np.compress` pads condition with False when it is shorter.
  condition_t = condition
  a_t = a
  if condition.shape[0] < a.shape[axis]:
    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
    condition_t = array_ops.concat([condition_t, padding], axis=0)
  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)
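
# Editor's note: illustrative padding behavior (comments only). A short
# condition is padded with False before masking:
#   compress([True, False], [[1, 2], [3, 4], [5, 6]], axis=0)
#   -> condition becomes [True, False, False], result [[1, 2]]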


@np_utils.np_doc('copy')
def copy(a):
  return array(a, copy=True)


def _maybe_promote_to_int(a):
  if dtypes.as_dtype(a.dtype).is_integer:
    # If a is an integer type and its precision is less than that of `int`,
    # the output type will be `int`.
    a_numpy_dtype = a.dtype.as_numpy_dtype
    output_type = np.promote_types(a_numpy_dtype, int)
    if output_type != a_numpy_dtype:
      a = asarray(a, dtype=output_type)

  return a


@np_utils.np_doc('cumprod')
def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumprod(a, axis)


@np_utils.np_doc('cumsum')
def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumsum(a, axis)


@np_utils.np_doc('imag')
def imag(val):
  val = asarray(val)
  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we
  # always return an ndarray.
  return math_ops.imag(val)


_TO_INT_ = 0
_TO_FLOAT = 1


def _reduce(tf_fn,
            a,
            axis=None,
            dtype=None,
            keepdims=None,
            promote_int=_TO_INT_,
            tf_bool_fn=None,
            preserve_bool=False):
  """A general reduction function.

  Args:
    tf_fn: the TF reduction function.
    a: the array to be reduced.
    axis: (optional) the axis along which to do the reduction. If None, all
      dimensions are reduced.
    dtype: (optional) the dtype of the result.
    keepdims: (optional) whether to keep the reduced dimension(s).
    promote_int: how to promote integer and bool inputs. There are three
      choices. (1) `_TO_INT_` always promotes them to np.int_ or np.uint; (2)
      `_TO_FLOAT` always promotes them to a float type (determined by
      dtypes.default_float_type); (3) None: don't promote.
    tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
      only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s
      dtype is `np.bool_` and `preserve_bool` is True.
    preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s
      dtype is `np.bool_` (some reductions such as np.sum convert bools to
      integers, while others such as np.max preserve bools).

  Returns:
    An ndarray.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  if keepdims is None:
    keepdims = False
  a = asarray(a, dtype=dtype)
  if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and
      tf_bool_fn is not None):
    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
  if dtype is None:
    dtype = a.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
    if promote_int == _TO_INT_:
      # If a is an integer/bool type whose bit width is less than np.int_'s,
      # numpy up-casts it to np.int_ based on the documentation at
      # https://numpy.org/doc/1.18/reference/generated/numpy.sum.html
      if dtype == np.bool_:
        is_signed = True
        width = 8  # We can use any number here that is less than 64
      else:
        is_signed = np.issubdtype(dtype, np.signedinteger)
        width = np.iinfo(dtype).bits
      # Numpy int_ and uint are defined as 'long' and 'unsigned long', so
      # should have the same bit width.
      if width < np.iinfo(np.int_).bits:
        if is_signed:
          dtype = np.int_
        else:
          dtype = np.uint
        a = math_ops.cast(a, dtype)
    elif promote_int == _TO_FLOAT:
      a = math_ops.cast(a, np_dtypes.default_float_type())

  if isinstance(axis, ops.Tensor) and axis.dtype not in (
      dtypes.int32, dtypes.int64):
    axis = math_ops.cast(axis, dtypes.int64)

  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
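
# Editor's note: illustrative promotion behavior (comments only), per the
# docstring above; the exact integer width of np.int_ is platform-dependent:
#   sum of an int8 array  -> typically np.int_  (promote_int=_TO_INT_)
#   mean of an int8 array -> floating result    (promote_int=_TO_FLOAT)
#   max of a bool array   -> stays bool         (preserve_bool=True)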


# TODO(DarrenZhang01): Add `axis` support to the `size` API.
@np_utils.np_doc('size')
def size(x, axis=None):  # pylint: disable=missing-docstring
  if axis is not None:
    raise NotImplementedError('axis argument is not supported in the current '
                              '`np.size` implementation')
  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
    return 1
  x = asarray(x)
  if x.shape.is_fully_defined():
    return np.prod(x.shape.as_list(), dtype=int)
  else:
    return array_ops.size_v2(x)


@np_utils.np_doc('sum')
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
  return _reduce(
      math_ops.reduce_sum,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_any)


@np_utils.np_doc('prod')
def prod(a, axis=None, dtype=None, keepdims=None):
  return _reduce(
      math_ops.reduce_prod,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_all)


@np_utils.np_doc('mean', unsupported_params=['out'])
def mean(a, axis=None, dtype=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_mean,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('amax', unsupported_params=['out'])
def amax(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_max,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_any,
      preserve_bool=True)


@np_utils.np_doc('amin', unsupported_params=['out'])
def amin(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_min,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_all,
      preserve_bool=True)


@np_utils.np_doc('var')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
  if dtype:
    working_dtype = np_utils.result_type(a, dtype)
  else:
    working_dtype = None
  if out is not None:
    raise ValueError('Setting out is not supported.')
  if ddof != 0:
    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
    def reduce_fn(input_tensor, axis, keepdims):
      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
      centered = input_tensor - means
      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
        centered = math_ops.cast(
            math_ops.real(centered * math_ops.conj(centered)),
            input_tensor.dtype)
      else:
        centered = math_ops.square(centered)
      squared_deviations = math_ops.reduce_sum(
          centered, axis=axis, keepdims=keepdims)

      if axis is None:
        n = array_ops.size(input_tensor)
      else:
        if axis < 0:
          axis += array_ops.rank(input_tensor)
        n = math_ops.reduce_prod(
            array_ops.gather(array_ops.shape(input_tensor), axis))
      n = math_ops.cast(n - ddof, input_tensor.dtype)

      return math_ops.cast(math_ops.divide(squared_deviations, n), dtype)
  else:
    reduce_fn = math_ops.reduce_variance

  result = _reduce(
      reduce_fn,
      a,
      axis=axis,
      dtype=working_dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result
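
# Editor's note: the ddof branch above computes sum((x - mean)**2) / (n - ddof)
# (comments only):
#   var([1., 2., 3., 4.])         -> 1.25       (divides by n = 4)
#   var([1., 2., 3., 4.], ddof=1) -> 5/3 ~ 1.67 (divides by n - 1 = 3)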


@np_utils.np_doc('std')
def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
  return _reduce(
      math_ops.reduce_std,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('ravel')
def ravel(a):  # pylint: disable=missing-docstring
  a = asarray(a)
  return array_ops.reshape(a, [-1])


@np_utils.np_doc('real')
def real(val):
  val = asarray(val)
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  return math_ops.real(val)


@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_shape = a._shape_as_list()  # pylint: disable=protected-access
  # Best effort recovery of the shape.
  known_shape = original_shape is not None and None not in original_shape
  if known_shape:
    if not original_shape:
      original_shape = (repeats,)
    else:
      repeats_np = np.ravel(np.array(repeats))
      if repeats_np.size == 1:
        repeats_np = repeats_np.item()
        if axis is None:
          original_shape = (repeats_np * np.prod(original_shape),)
        else:
          original_shape[axis] = repeats_np * original_shape[axis]
      else:
        if axis is None:
          original_shape = (repeats_np.sum(),)
        else:
          original_shape[axis] = repeats_np.sum()

  repeats = asarray(repeats)
  result = array_ops.repeat(a, repeats, axis)
  if known_shape:
    result.set_shape(original_shape)

  return result
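
# Editor's note: illustrative shape recovery (comments only). When the input
# shape is fully known, the result shape is set statically:
#   repeat([[1, 2], [3, 4]], 3, axis=1) -> known result shape (2, 6)
#   repeat([1, 2, 3], [1, 2, 0])        -> known result shape (3,) == sum(repeats)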


@np_utils.np_doc('around')
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = asarray(a)
  dtype = a.dtype.as_numpy_dtype
  factor = math.pow(10, decimals)
  if np.issubdtype(dtype, np.inexact):
    factor = math_ops.cast(factor, dtype)
  else:
    # Use float as the working dtype when a.dtype is exact (e.g. integer),
    # because `decimals` can be negative.
    float_dtype = np_dtypes.default_float_type()
    a = a.astype(float_dtype)
    factor = math_ops.cast(factor, float_dtype)
  a = math_ops.multiply(a, factor)
  a = math_ops.round(a)
  a = math_ops.divide(a, factor)
  return a.astype(dtype)
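
# Editor's note: the scale-round-unscale trick above, illustrated (comments
# only):
#   around(1.2345, decimals=2) -> round(1.2345 * 100) / 100 = 1.23
#   around(1234, decimals=-2)  -> round(1234 * 0.01) / 0.01 = 1200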


setattr(np_arrays.ndarray, '__round__', around)


@np_utils.np_doc('reshape')
def reshape(a, newshape, order='C'):
  """order argument can only be 'C' or 'F'."""
  if order not in {'C', 'F'}:
    raise ValueError('Unsupported order argument {}'.format(order))

  a = asarray(a)
  if isinstance(newshape, int):
    newshape = [newshape]

  if order == 'F':
    r = array_ops.transpose(
        array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
  else:
    r = array_ops.reshape(a, newshape)

  return r


def _reshape_method_wrapper(a, *newshape, **kwargs):
  order = kwargs.pop('order', 'C')
  if kwargs:
    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))

  if len(newshape) == 1 and not isinstance(newshape[0], int):
    newshape = newshape[0]

  return reshape(a, newshape, order=order)


@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
  a = asarray(a)
  return array_ops.expand_dims(a, axis=axis)


@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
  a = asarray(a)
  return array_ops.squeeze(a, axis)


@np_utils.np_doc('flatten', link=np_utils.NoLink())
def flatten(a, order='C'):
  a = asarray(a)
  if order == 'C' or order == 'A' or order == 'K':
    # Row major.
    return array_ops.reshape(a, [-1])
  elif order == 'F':
    # Column major
    return array_ops.reshape(array_ops.transpose(a), [-1])
  else:
    raise ValueError('order can only be C, A, K (all row major) or F '
                     '(column major).')


@np_utils.np_doc('transpose')
def transpose(a, axes=None):
  a = asarray(a)
  if axes is not None:
    axes = asarray(axes)
  return array_ops.transpose(a=a, perm=axes)


@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)

  def adjust_axes(axes, rank):
    def f(x):
      if isinstance(x, int):
        if x < 0:
          x = x + rank
      else:
        x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
      return x
    return nest.map_structure(f, axes)

  if (a.shape.rank is not None and
      isinstance(axis1, int) and isinstance(axis2, int)):
    # This branch makes sure `perm` is statically known, to avoid a
    # not-compile-time-constant XLA error.
    a_rank = a.shape.rank
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = list(range(a_rank))
    perm[axis1] = axis2
    perm[axis2] = axis1
  else:
    a_rank = array_ops.rank(a)
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = math_ops.range(a_rank)
    perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
                                           [axis2, axis1])
  a = array_ops.transpose(a, perm)
  return a
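
# Editor's note: illustrative static-perm branch (comments only). For a rank-3
# input with integer axes, the permutation is built in Python:
#   swapaxes(a, 0, 2) on shape (2, 3, 4)
#   -> perm [2, 1, 0], result shape (4, 3, 2)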


@np_utils.np_doc('moveaxis')
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source or destination is not in (-ndim(a), ndim(a))."""
  if not source and not destination:
    return a

  a = asarray(a)

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  if len(source) != len(destination):
    raise ValueError('The lengths of source and destination must be equal')

  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = math_ops.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = array_ops_stack.unstack(
          sort_ops.sort(array_ops_stack.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return array_ops.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = array_ops.scatter_nd(
        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = array_ops.tensor_scatter_update(
        perm, array_ops.expand_dims(destination, 1), source)
  a = array_ops.transpose(a, perm)

  return a
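
# Editor's note: illustrative perm construction (comments only). With a known
# rank, the permutation is assembled in Python:
#   moveaxis(a, 0, -1) on rank 3 -> source (0,), destination (2,);
#   perm starts as [1, 2], then 0 is inserted at position 2 -> [1, 2, 0]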


@np_utils.np_doc('pad')
def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-name
  """Only supports modes 'constant', 'reflect' and 'symmetric' currently."""
  constant_values = kwargs.get('constant_values', 0)
  if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'):
    raise ValueError('Unsupported padding mode: ' + mode)
  mode = mode.upper()
  array = asarray(array)
  pad_width = asarray(pad_width, dtype=dtypes.int32)
  return array_ops.pad(
      tensor=array,
      paddings=pad_width,
      mode=mode,
      constant_values=constant_values)


@np_utils.np_doc('take')
def take(a, indices, axis=None, out=None, mode='clip'):
  """out argument is not supported, and default mode is clip."""
  if out is not None:
    raise ValueError('out argument is not supported in take.')

  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    a = array_ops.reshape(a, [-1])
    axis = 0

  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
  if mode == 'clip':
    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
  elif mode == 'wrap':
    indices = math_ops.floormod(indices, axis_size)
  else:
    raise ValueError("The 'raise' mode to take is not supported.")

  return array_ops.gather(a, indices, axis=axis)
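
# Editor's note: illustrative index handling (comments only). For a length-4
# axis:
#   mode='clip': index 9 -> 3, index -2 -> 0
#   mode='wrap': index 9 -> 1, index -2 -> 2  (floormod against axis_size)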


@np_utils.np_doc_only('where')
def where(condition, x=None, y=None):
  """Raises ValueError if exactly one of x or y is not None."""
  condition = asarray(condition, dtype=np.bool_)
  if x is None and y is None:
    return nonzero(condition)
  elif x is not None and y is not None:
    x, y = _promote_dtype(x, y)
    return array_ops.where_v2(condition, x, y)
  raise ValueError('Both x and y must be ndarrays, or both must be None.')


@np_utils.np_doc('select')
def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
  if len(condlist) != len(choicelist):
    msg = 'condlist must have length equal to choicelist ({} vs {})'
    raise ValueError(msg.format(len(condlist), len(choicelist)))
  if not condlist:
    raise ValueError('condlist must be non-empty')
  choices = _promote_dtype(default, *choicelist)
  choicelist = choices[1:]
  output = choices[0]
  # The traversal is in reverse order so we can return the first value in
  # choicelist where condlist is True.
  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    output = where(cond, choice, output)
  return output


@np_utils.np_doc('shape', link=np_utils.Link(
    'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'))
def shape(a):
  a = asarray(a)
  return a.shape


@np_utils.np_doc('ndim', link=np_utils.NoLink())
def ndim(a):
  a = asarray(a)
  return a.ndim


@np_utils.np_doc('isscalar')
def isscalar(num):
  return ndim(num) == 0


def _boundaries_to_sizes(a, boundaries, axis):
  """Converts boundaries of splits to sizes of splits.

  Args:
    a: the array to be split.
    boundaries: the boundaries, as in np.split.
    axis: the axis along which to split.

  Returns:
    A list of sizes of the splits, as in tf.split.
  """
  if axis >= len(a.shape):
    raise ValueError('axis %s is out of bound for shape %s' % (axis, a.shape))
  total_size = a.shape[axis]
  sizes = []
  sizes_sum = 0
  prev = 0
  for i, b in enumerate(boundaries):
    size = b - prev
    if size < 0:
      raise ValueError('The %s-th boundary %s is smaller than the previous '
                       'boundary %s' % (i, b, prev))
    size = min(size, max(0, total_size - sizes_sum))
    sizes.append(size)
    sizes_sum += size
    prev = b
  sizes.append(max(0, total_size - sizes_sum))
  return sizes
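
# Editor's note: illustrative boundary-to-size conversion (comments only),
# mirroring np.split semantics on an axis of length 10:
#   boundaries [2, 5]    -> sizes [2, 3, 5]
#   boundaries [3, 3, 8] -> sizes [3, 0, 5, 2]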


@np_utils.np_doc('split')
def split(ary, indices_or_sections, axis=0):
  ary = asarray(ary)
  if not isinstance(indices_or_sections, int):
    indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
  return array_ops.split(ary, indices_or_sections, axis=axis)


def _split_on_axis(np_fun_name, axis):  # pylint: disable=missing-function-docstring

  @np_utils.np_doc(np_fun_name)
  def f(ary, indices_or_sections):
    # for 1-D array, hsplit becomes vsplit
    new_axis = np_utils.cond(
        math_ops.equal(axis, 1),
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            math_ops.equal(array_ops.rank(ary), 1), lambda: 0, lambda: axis
        ),
        lambda: axis,
    )
    if isinstance(indices_or_sections, int):
      ary_shape = ary.shape[new_axis]
      if ary_shape is not None and ary_shape % indices_or_sections:
        raise ValueError(
            'array split does not result in an equal division')
    return split(ary, indices_or_sections, axis=new_axis)

  return f


vsplit = _split_on_axis('vsplit', axis=0)
hsplit = _split_on_axis('hsplit', axis=1)
dsplit = _split_on_axis('dsplit', axis=2)


@np_utils.np_doc('broadcast_to')
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
  return full(shape, array)


@np_utils.np_doc('stack')
def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
    arrays = asarray(arrays)
    if axis == 0:
      return arrays
    else:
      return swapaxes(arrays, 0, axis)
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return asarray(array_ops_stack.stack(unwrapped_arrays, axis))


@np_utils.np_doc('hstack')
def hstack(tup):
  arrays = [atleast_1d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  rank = array_ops.rank(unwrapped_arrays[0])
  return np_utils.cond(
      math_ops.equal(rank, 1),
      lambda: array_ops.concat(unwrapped_arrays, axis=0),
      lambda: array_ops.concat(unwrapped_arrays, axis=1))


@np_utils.np_doc('vstack')
def vstack(tup):
  arrays = [atleast_2d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return array_ops.concat(unwrapped_arrays, axis=0)


@np_utils.np_doc('dstack')
def dstack(tup):
  arrays = [atleast_3d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = [
      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
  ]
  return array_ops.concat(unwrapped_arrays, axis=2)


def _pad_left_to(n, old_shape):
  old_shape = asarray(old_shape, dtype=np.int32)
  new_shape = array_ops.pad(
      old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
      constant_values=1)
  return asarray(new_shape)


def _atleast_nd(n, new_shape, *arys):
  """Reshape arrays to be at least `n`-dimensional.

  Args:
    n: The minimal rank.
    new_shape: a function that takes `n` and the old shape and returns the
      desired new shape.
    *arys: ndarray(s) to be reshaped.

  Returns:
    The reshaped array(s).
  """

  def f(x):
    # pylint: disable=g-long-lambda
    x = asarray(x)
    return asarray(
        np_utils.cond(
            np_utils.greater(n, array_ops.rank(x)),
            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
            lambda: x))

  arys = list(map(f, arys))
  if len(arys) == 1:
    return arys[0]
  else:
    return arys


@np_utils.np_doc('atleast_1d')
def atleast_1d(*arys):
  return _atleast_nd(1, _pad_left_to, *arys)


@np_utils.np_doc('atleast_2d')
def atleast_2d(*arys):
  return _atleast_nd(2, _pad_left_to, *arys)


@np_utils.np_doc('atleast_3d')
def atleast_3d(*arys):  # pylint: disable=missing-docstring

  def new_shape(_, old_shape):
    # pylint: disable=g-long-lambda
    ndim_ = array_ops.size(old_shape)
    return np_utils.cond(
        math_ops.equal(ndim_, 0),
        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
        lambda: np_utils.cond(
            math_ops.equal(ndim_, 1), lambda: array_ops.pad(
                old_shape, [[1, 1]], constant_values=1), lambda: array_ops.pad(
                    old_shape, [[0, 1]], constant_values=1)))

  return _atleast_nd(3, new_shape, *arys)


@np_utils.np_doc('nonzero')
def nonzero(a):
  a = atleast_1d(a)
  if a.shape.rank is None:
    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                     'arrays to return.')
  return array_ops_stack.unstack(
      array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
      a.shape.rank,
      axis=1)


@np_utils.np_doc('diag_indices')
def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
  if n < 0:
    raise ValueError(
        'n argument to diag_indices must be nonnegative, got {}'.format(n))
  if ndim < 0:
    raise ValueError(
        'ndim argument to diag_indices must be nonnegative, got {}'.format(
            ndim))

  return (math_ops.range(n),) * ndim


@np_utils.np_doc('tri')
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
  M = M if M is not None else N
  if dtype is not None:
    dtype = np_utils.result_type(dtype)
  else:
    dtype = np_dtypes.default_float_type()

  if k < 0:
    lower = -k - 1
    if lower > N:
      r = array_ops.zeros([N, M], dtype)
    else:
      # Keep as tf bool, since we create an upper triangular matrix and invert
      # it.
      o = array_ops.ones([N, M], dtype=dtypes.bool)
      r = math_ops.cast(
          math_ops.logical_not(array_ops.matrix_band_part(o, lower, -1)), dtype)
  else:
    o = array_ops.ones([N, M], dtype)
    if k > M:
      r = o
    else:
      r = array_ops.matrix_band_part(o, -1, k)
  return r
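
# Editor's note: illustrative band-part trick (comments only). tri builds the
# lower triangle via matrix_band_part; for k < 0 it inverts an upper-band mask:
#   tri(3, k=0)  -> [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
#   tri(3, k=-1) -> [[0, 0, 0], [1, 0, 0], [1, 1, 0]]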


@np_utils.np_doc('tril')
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to tril should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)


@np_utils.np_doc('triu')
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to triu should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to triu must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)


@np_utils.np_doc('flip')
def flip(m, axis=None):  # pylint: disable=missing-docstring
  m = asarray(m)

  if axis is None:
    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))

  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access

  return array_ops.reverse(m, [axis])


@np_utils.np_doc('flipud')
def flipud(m):  # pylint: disable=missing-docstring
  return flip(m, 0)


@np_utils.np_doc('fliplr')
def fliplr(m):  # pylint: disable=missing-docstring
  return flip(m, 1)


@np_utils.np_doc('roll')
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)

  if axis is not None:
    return manip_ops.roll(a, shift, axis)

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = array_ops.shape(a)
  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
  return array_ops.reshape(a, original_shape)


@np_utils.np_doc('rot90')
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = array_ops.rank(m)
  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = math_ops.range(m_rank)
    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)


@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
  x = asarray(x)

  x_shape = array_ops.shape(x)
  N = N or x_shape[0]

  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
  if N_temp is not None:
    N = N_temp
    if N < 0:
      raise ValueError('N must be nonnegative')
  else:
    control_flow_assert.Assert(N >= 0, [N])

  rank = array_ops.rank(x)
  rank_temp = np_utils.get_static_value(rank)
  if rank_temp is not None:
    rank = rank_temp
    if rank != 1:
      raise ValueError('x must be a one-dimensional array')
  else:
    control_flow_assert.Assert(math_ops.equal(rank, 1), [rank])

  if increasing:
    start = 0
    limit = N
    delta = 1
  else:
    start = N - 1
    limit = -1
    delta = -1

  x = array_ops.expand_dims(x, -1)
  return math_ops.pow(
      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))


@np_utils.np_doc('ix_')
def ix_(*args):  # pylint: disable=missing-docstring
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    a_rank = array_ops.rank(a)
    a_rank_temp = np_utils.get_static_value(a_rank)
    if a_rank_temp is not None:
      a_rank = a_rank_temp
      if a_rank != 1:
        raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
            i, a_rank))
    else:
      control_flow_assert.Assert(math_ops.equal(a_rank, 1), [a_rank])

    new_shape = [1] * n
    new_shape[i] = -1
    dtype = a.dtype
    if dtype == dtypes.bool:
      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
    elif dtype.is_integer:
      output.append(array_ops.reshape(a, new_shape))
    else:
      raise ValueError(
          'Only integer and bool dtypes are supported, got {}'.format(dtype))

  return output


@np_utils.np_doc('broadcast_arrays')
def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
  subok = kwargs.pop('subok', False)
  if subok:
    raise ValueError('subok=True is not supported.')
  if kwargs:
    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))

  args = [asarray(arg) for arg in args]
  return np_utils.tf_broadcast(*args)


@np_utils.np_doc_only('sign')
def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
  if out:
    raise ValueError("tf.numpy doesn't support setting out.")
  if where:
    raise ValueError("tf.numpy doesn't support setting where.")
  if kwargs:
    raise ValueError(
        "tf.numpy doesn't support setting {}".format(kwargs.keys()))

  x = asarray(x)
  dtype = x.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.complexfloating):
    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
  else:
    result = math_ops.sign(x)

  return result


# Note that np.take_along_axis may not be present in some supported versions
# of numpy.
@np_utils.np_doc('take_along_axis')
def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  arr = asarray(arr)
  indices = asarray(indices)

  if axis is None:
    return take_along_axis(arr.ravel(), indices, 0)

  rank = array_ops.rank(arr)
  axis = axis + rank if axis < 0 else axis

  # Broadcast shapes to match, ensure that the axis of interest is not
  # broadcast.
  arr_shape_original = array_ops.shape(arr)
  indices_shape_original = array_ops.shape(indices)
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
                                                  [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
                                                        indices_shape)
  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
                                              [arr_shape_original[axis]])
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
  arr = array_ops.broadcast_to(arr, arr_shape)
  indices = array_ops.broadcast_to(indices, indices_shape)

  # Save indices shape so we can restore it later.
  possible_result_shape = indices.shape

  # Correct indices since gather doesn't correctly handle negative indices.
  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)

  swapaxes_ = lambda t: swapaxes(t, axis, -1)

  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
                      lambda: swapaxes_(arr))
  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
                          lambda: swapaxes_(indices))

  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

  indices_shape = array_ops.shape(indices)
  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])

  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)

  return result
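
# Editor's note: illustrative semantics (comments only), matching
# np.take_along_axis:
#   arr     = [[10, 30, 20]]
#   indices = [[0, 2, 1]]
#   take_along_axis(arr, indices, axis=1) -> [[10, 20, 30]]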


_SLICE_ERORR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')


def _as_index(idx, need_scalar=True):
  """Helper function to parse idx as an index.

  Args:
    idx: index
    need_scalar: If idx needs to be a scalar value.

  Returns:
    A pair, (index, bool). First one is the parsed index and can be a tensor,
    or scalar integer / Dimension. Second one is True if rank is known to be 0.

  Raises:
    IndexError: For incorrect indices.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True
  data = asarray(idx)
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    data = array_ops.where_v2(data)
    data = array_ops.reshape(data, [-1])
  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing can only handle int32/int64. So we need to cast.
    promoted_dtype = np.promote_types(np.int32, np_dtype)
    if promoted_dtype == np.int32:
      data = math_ops.cast(data, dtypes.int32)
    elif promoted_dtype == np.int64:
      data = math_ops.cast(data, dtypes.int64)
    else:
      raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0
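
# Editor's note: illustrative parses (comments only):
#   _as_index(3)             -> (3, True)               (Python scalar passes through)
#   _as_index([1, 2], False) -> (integer tensor, False) (advanced index)
#   _as_index([True, False, True], False)
#       -> (tensor [0, 2], False)                       (bool mask -> positions)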

1524 

1525 

1526class _UpdateMethod(enum.Enum): 

1527 UPDATE = 0 

1528 ADD = 1 

1529 MIN = 2 

1530 MAX = 3 

1531 

1532 

1533def _slice_helper(tensor, slice_spec, update_method=None, updates=None): 

1534 """Helper function for __getitem__ and _with_index_update_helper. 

1535 

1536 This function collects the indices in `slice_spec` into two buckets, which we 

1537 can call "idx1" and "idx2" here. idx1 is intended for `strided_slice`, idx2 

1538 `gather`. They also correspond to "basic indices" and "advanced indices" in 

1539 numpy. This function supports both reading and writing at the indices. The 

1540 reading path can be summarized as `gather(stride_slice(tensor, idx1), 

1541 idx2)`. The writing path can be summarized as `strided_slice_update(tensor, 

1542 idx1, scatter(strided_slice(tensor, idx1), idx2, updates))`. (`gather` here 

1543 means `tf.gather` or `tf.gather_nd`; `scatter` here means 

1544 `tf.tensor_scatter_update`.) The writing path is inefficient because it needs 

1545 to first read out a portion (probably much larger than `updates`) of `tensor` 

1546 using `strided_slice`, update it, and then write the portion back. An 

1547 alternative approach is to only use `scatter`, which amounts to using the 

1548 indexing mechanism of gather/scatter to implement 

1549 strided_slice/strided_slice_update. This is feasible for XLA Gather/Scatter 

1550 because they support spans (e.g. `2:5`) in indices (as begin/end pairs), but 

1551 not TF gather/scatter because they don't support spans (except those that 

1552 cover entire dimensions, i.e. `:`). If we materialize spans into individual 

1553 indices, the size of the index tensor would explode. (Note that XLA 

1554 Gather/Scatter have a similar problem for stride > 1 because they don't 

1555 support strides. Indices such as `1:2:8` will need to be materialized into 

1556 individual indices such as [1, 3, 5, 7].) 

1557 

1558 Args: 

1559 tensor: the tensor to be read from or write into. 

1560 slice_spec: the indices. 

1561 update_method: (optional) a member of `_UpdateMethod`, indicating how to 

1562 update the values (replacement, add, etc.). `None` indicates just reading. 

1563 updates: (optional) the new values to write into `tensor`. It must have the 

1564 same dtype as `tensor`. 

1565 

1566 Returns: 

1567 The result of reading (if `update_method` is `None`) or the updated `tensor` 

1568 after writing. 

1569 """ 

1570 begin, end, strides = [], [], [] 

1571 new_axis_mask, shrink_axis_mask = 0, 0 

1572 begin_mask, end_mask = 0, 0 

1573 ellipsis_mask = 0 

1574 advanced_indices = [] 

1575 shrink_indices = [] 

1576 for index, s in enumerate(slice_spec): 

1577 if isinstance(s, slice): 

1578 if s.start is not None: 

1579 begin.append(_as_index(s.start)[0]) 

1580 else: 

1581 begin.append(0) 

1582 begin_mask |= (1 << index) 

1583 if s.stop is not None: 

1584 end.append(_as_index(s.stop)[0]) 

1585 else: 

1586 end.append(0) 

1587 end_mask |= (1 << index) 

1588 if s.step is not None: 

1589 strides.append(_as_index(s.step)[0]) 

1590 else: 

1591 strides.append(1) 

1592 elif s is Ellipsis: 

1593 begin.append(0) 

1594 end.append(0) 

1595 strides.append(1) 

1596 ellipsis_mask |= (1 << index) 

1597 elif s is array_ops.newaxis: 

1598 begin.append(0) 

1599 end.append(0) 

1600 strides.append(1) 

1601 new_axis_mask |= (1 << index) 

1602 else: 

1603 s, is_scalar = _as_index(s, False) 

1604 if is_scalar: 

1605 begin.append(s) 

1606 end.append(s + 1) 

1607 strides.append(1) 

1608 shrink_axis_mask |= (1 << index) 

1609 shrink_indices.append(index) 

1610 else: 

1611 begin.append(0) 

1612 end.append(0) 

1613 strides.append(1) 

1614 begin_mask |= (1 << index) 

1615 end_mask |= (1 << index) 

1616 advanced_indices.append((index, s, ellipsis_mask != 0)) 

1617 

1618 # stack possibly involves no tensors, so we must use op_scope correct graph. 

1619 with ops.name_scope( 

1620 None, 

1621 'strided_slice', [tensor] + begin + end + strides, 

1622 skip_on_eager=False) as name: 

1623 if begin: 

1624 packed_begin, packed_end, packed_strides = ( 

1625 array_ops_stack.stack(begin), 

1626 array_ops_stack.stack(end), 

1627 array_ops_stack.stack(strides)) 

1628 if (packed_begin.dtype == dtypes.int64 or 

1629 packed_end.dtype == dtypes.int64 or 

1630 packed_strides.dtype == dtypes.int64): 

1631 if packed_begin.dtype != dtypes.int64: 

1632 packed_begin = math_ops.cast(packed_begin, dtypes.int64) 

1633 if packed_end.dtype != dtypes.int64: 

1634 packed_end = math_ops.cast(packed_end, dtypes.int64) 

1635 if packed_strides.dtype != dtypes.int64: 

1636 packed_strides = math_ops.cast(packed_strides, dtypes.int64) 

1637 else: 

1638 var_empty = constant_op.constant([], dtype=dtypes.int32) 

1639 packed_begin = packed_end = packed_strides = var_empty 

1640 if update_method == _UpdateMethod.UPDATE and not advanced_indices: 

1641 return array_ops.tensor_strided_slice_update( 

1642 tensor, 

1643 packed_begin, 

1644 packed_end, 

1645 packed_strides, 

1646 updates, 

1647 begin_mask=begin_mask, 

1648 end_mask=end_mask, 

1649 shrink_axis_mask=shrink_axis_mask, 

1650 new_axis_mask=new_axis_mask, 

1651 ellipsis_mask=ellipsis_mask, 

1652 name=name) 

1653 else: 

1654 # TODO(b/164251540): Find a better way to support update that does not 

1655 # involve one read + two writes. 

1656 if updates is not None: 

1657 original_tensor = tensor 

1658 # TODO(agarwal): set_shape on tensor to set rank. 

1659 tensor = array_ops.strided_slice( 

1660 tensor, 

1661 packed_begin, 

1662 packed_end, 

1663 packed_strides, 

1664 begin_mask=begin_mask, 

1665 end_mask=end_mask, 

1666 shrink_axis_mask=shrink_axis_mask, 

1667 new_axis_mask=new_axis_mask, 

1668 ellipsis_mask=ellipsis_mask, 

1669 name=name) 

1670 if not advanced_indices: 

1671 if update_method is None: 

1672 return tensor 

1673 assert update_method != _UpdateMethod.UPDATE 

1674 # TF lacks TensorStridedSliceAdd and the like, so we need to do 

1675 # read+add+update. 

1676 if update_method == _UpdateMethod.ADD: 

1677 update_op = math_ops.add 

1678 elif update_method == _UpdateMethod.MIN: 

1679 update_op = math_ops.minimum 

1680 elif update_method == _UpdateMethod.MAX: 

1681 update_op = math_ops.maximum 

1682 return array_ops.tensor_strided_slice_update( 

1683 original_tensor, 

1684 packed_begin, 

1685 packed_end, 

1686 packed_strides, 

1687 update_op(tensor, updates), 

1688 begin_mask=begin_mask, 

1689 end_mask=end_mask, 

1690 shrink_axis_mask=shrink_axis_mask, 

1691 new_axis_mask=new_axis_mask, 

1692 ellipsis_mask=ellipsis_mask, 

1693 name=name + '_2') 

1694 advanced_indices_map = {} 

1695 for index, data, had_ellipsis in advanced_indices: 

1696 if had_ellipsis: 

1697 num_shrink = len([x for x in shrink_indices if x > index]) 

1698 dim = index - len(slice_spec) + num_shrink 

1699 else: 

1700 num_shrink = len([x for x in shrink_indices if x < index]) 

1701 dim = index - num_shrink 

1702 advanced_indices_map[dim] = data 
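 # For example (illustrative): in `x[2, [0, 1]]` the scalar 2 shrinks axis 0,
 # so the advanced index at spec position 1 maps to dimension 1 - 1 = 0 of
 # the strided-slice output.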

1703 dims = sorted(advanced_indices_map.keys()) 

1704 dims_contiguous = True 

1705 if len(dims) > 1: 

1706 if dims[0] < 0 and dims[-1] >= 0: # not all same sign 

1707 dims_contiguous = False 

1708 else: 

1709 for i in range(len(dims) - 1): 

1710 if dims[i] + 1 != dims[i + 1]: 

1711 dims_contiguous = False 

1712 break 
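 # E.g. dims [0, 1] are contiguous while [0, 2] are not; only contiguous
 # dims can be served by the single-`gather` fast path further below.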

1713 indices = [advanced_indices_map[x] for x in dims] 

1714 indices = _promote_dtype(*indices) 

1715 indices = np_utils.tf_broadcast(*indices) 

1716 stacked_indices = array_ops_stack.stack(indices, axis=-1) 

1717 # Skip the contiguous-dims optimization for update because there is no 

1718 # tf.*scatter* op that supports the `axis` argument. 

1719 if not dims_contiguous or updates is not None: 

1720 if list(range(len(dims))) != dims:  # a bare range() never equals a list 

1721 tensor = moveaxis(tensor, dims, range(len(dims))) 

1722 tensor_shape_prefix = array_ops.shape( 

1723 tensor, out_type=stacked_indices.dtype)[:len(dims)] 

1724 stacked_indices = array_ops.where_v2( 

1725 stacked_indices < 0, stacked_indices + tensor_shape_prefix, 

1726 stacked_indices) 
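 # E.g. (illustrative) an index of -1 on an axis of size 4 is rewritten to 3
 # here, so the gather_nd/scatter calls below only ever see non-negative
 # indices.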

1727 if updates is None: 

1728 return array_ops.gather_nd(tensor, stacked_indices) 

1729 else: 

1730 # We only need to move-axis `updates` in the contiguous case because 

1731 # only in this case the result dimensions of advanced indexing are in 

1732 # the middle of `updates`. In the non-contiguous case, those dimensions 

1733 # are always at the front. 

1734 if dims_contiguous: 

1735 # TODO(wangpeng): Support unknown rank (e.g. by partially flattening 

1736 # `updates`) 

1737 if stacked_indices.shape.rank is None: 

1738 raise NotImplementedError( 

1739 'Rank of the advanced indices must currently be known') 

1740 batch_size = stacked_indices.shape.rank - 1 

1741 batch_start = dims[0] 

1742 if batch_start < 0: 

1743 batch_start += len(dims) - batch_size 

1744 def range_(start, length): 

1745 return range(start, start + length) 

1746 updates = moveaxis(updates, range_(batch_start, batch_size), 

1747 range(batch_size)) 
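 # E.g. (illustrative) for an update `x[:, idx]` with a 2-D `idx`, dims is
 # [1] and the advanced-index result occupies dimensions 1..2 of `updates`,
 # so those dimensions are moved to the front to line up with
 # `stacked_indices` for the scatter below.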

1748 if update_method == _UpdateMethod.UPDATE: 

1749 update_op = array_ops.tensor_scatter_update 

1750 elif update_method == _UpdateMethod.ADD: 

1751 update_op = array_ops.tensor_scatter_add 

1752 elif update_method == _UpdateMethod.MIN: 

1753 update_op = array_ops.tensor_scatter_min 

1754 elif update_method == _UpdateMethod.MAX: 

1755 update_op = array_ops.tensor_scatter_max 

1756 tensor = update_op( 

1757 tensor, stacked_indices, updates) 

1758 if list(range(len(dims))) != dims: 

1759 tensor = moveaxis(tensor, range(len(dims)), dims) 

1760 return array_ops.tensor_strided_slice_update( 

1761 original_tensor, 

1762 packed_begin, 

1763 packed_end, 

1764 packed_strides, 

1765 tensor, 

1766 begin_mask=begin_mask, 

1767 end_mask=end_mask, 

1768 shrink_axis_mask=shrink_axis_mask, 

1769 new_axis_mask=new_axis_mask, 

1770 ellipsis_mask=ellipsis_mask, 

1771 name=name + '_2') 

1772 # Note that gather_nd only gathers from the leading (outermost) dimensions 

1773 # and cannot gather along an interior axis. To avoid shuffling data back 

1774 # and forth, we linearize the indices and do a single gather instead. 

1775 rank = np_utils._maybe_static(array_ops.rank(tensor)) # pylint: disable=protected-access 

1776 dims = [(x + rank if x < 0 else x) for x in dims] 

1777 shape_tensor = array_ops.shape(tensor) 

1778 dim_sizes = array_ops.gather(shape_tensor, dims) 

1779 if len(dims) == 1: 

1780 stacked_indices = indices[0] 

1781 stacked_indices = math_ops.cast(stacked_indices, dtypes.int32) 

1782 stacked_indices = array_ops.where_v2(stacked_indices < 0, 

1783 stacked_indices + dim_sizes, 

1784 stacked_indices) 

1785 axis = dims[0] 

1786 if len(dims) > 1: 

1787 index_scaling = math_ops.cumprod( 

1788 dim_sizes, reverse=True, exclusive=True) 

1789 def _tensordot(a, b): 

1790 # TODO(b/168657656): This function should be replaced by 

1791 # tensordot(axis=1) once MatMul has int32 XLA kernel. 

1792 b = array_ops.broadcast_to(b, array_ops.shape(a)) 

1793 return math_ops.reduce_sum(a * b, axis=-1) 

1794 stacked_indices = _tensordot(stacked_indices, index_scaling) 

1795 flat_shape = array_ops.concat( 

1796 [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]], 

1797 axis=0) 

1798 tensor = array_ops.reshape(tensor, flat_shape) 

1799 
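 # Linearization example for the multi-dim case (illustrative): for dims
 # [1, 2] of a tensor of shape [4, 5, 6], dim_sizes is [5, 6] and
 # index_scaling is [6, 1], so an advanced index pair (i, j) becomes
 # 6*i + j into the single flattened axis of size 30 produced by the
 # reshape above.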

1800 return array_ops.gather(tensor, stacked_indices, axis=axis) 

1801 

1802 

1803def _as_spec_tuple(slice_spec): 

1804 """Convert slice_spec to tuple.""" 

1805 if isinstance(slice_spec, 

1806 (list, tuple)) and not isinstance(slice_spec, np.ndarray): 

1807 is_index = True 

1808 for s in slice_spec: 

1809 if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)): 

1810 is_index = False 

1811 break 

1812 elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0: 

1813 is_index = False 

1814 break 

1815 if not is_index: 

1816 return tuple(slice_spec) 

1817 return (slice_spec,) 
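# For example (illustrative): `_as_spec_tuple([1, 2])` treats the list of
# scalars as one advanced index and returns `([1, 2],)`, whereas
# `_as_spec_tuple([1, slice(None)])` contains a slice and is returned as the
# multi-axis spec `(1, slice(None))`.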

1818 

1819 

1820def _getitem(self, slice_spec): 

1821 """Implementation of ndarray.__getitem__.""" 

1822 if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and 

1823 slice_spec.dtype == dtypes.bool) or 

1824 (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and 

1825 slice_spec.dtype == np.bool_)): 

1826 return array_ops.boolean_mask(tensor=self, mask=slice_spec) 

1827 

1828 if not isinstance(slice_spec, tuple): 

1829 slice_spec = _as_spec_tuple(slice_spec) 

1830 

1831 result_t = _slice_helper(self, slice_spec) 

1832 return result_t 
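# For example (illustrative): `x[x > 0]` arrives here with a boolean
# `slice_spec`, so `_getitem` dispatches straight to `array_ops.boolean_mask`
# rather than `_slice_helper`.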

1833 

1834 

1835def _with_index_update_helper(update_method, a, slice_spec, updates): 

1836 """Implementation of ndarray._with_index_*.""" 

1837 if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and 

1838 slice_spec.dtype == dtypes.bool) or 

1839 (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and 

1840 slice_spec.dtype == np.bool_)): 

1841 slice_spec = nonzero(slice_spec) 

1842 

1843 if not isinstance(slice_spec, tuple): 

1844 slice_spec = _as_spec_tuple(slice_spec) 

1845 

1846 a_dtype = a.dtype 

1847 a, updates = _promote_dtype_binary(a, updates) 

1848 result_t = _slice_helper(a, slice_spec, update_method, updates) 

1849 return result_t.astype(a_dtype) 

1850 

1851 

1852setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem) 

1853setattr(np_arrays.ndarray, '_with_index_update', 

1854 functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE)) 

1855setattr(np_arrays.ndarray, '_with_index_add', 

1856 functools.partial(_with_index_update_helper, _UpdateMethod.ADD)) 

1857setattr(np_arrays.ndarray, '_with_index_min', 

1858 functools.partial(_with_index_update_helper, _UpdateMethod.MIN)) 

1859setattr(np_arrays.ndarray, '_with_index_max', 

1860 functools.partial(_with_index_update_helper, _UpdateMethod.MAX))
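# A minimal end-to-end sketch (assumes numpy behavior has been enabled;
# outputs are illustrative):
#
#   import tensorflow.experimental.numpy as tnp
#   tnp.experimental_enable_numpy_behavior()
#   x = tnp.asarray([[1, 2, 3], [4, 5, 6]])
#   x[0, 1:]       # basic indexing -> [2 3], via _getitem/_slice_helper
#   x[:, [0, 2]]   # basic + advanced -> [[1 3], [4 6]], served by gather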