Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/numpy_ops/np_math_ops.py: 33%

830 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mathematical operations."""
# pylint: disable=g-direct-tensorflow-import

17 

18import numbers 

19import sys 

20 

21import numpy as np 

22 

23from tensorflow.python.framework import constant_op 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import errors 

26from tensorflow.python.framework import ops 

27from tensorflow.python.ops import array_ops 

28from tensorflow.python.ops import array_ops_stack 

29from tensorflow.python.ops import bitwise_ops 

30from tensorflow.python.ops import clip_ops 

31from tensorflow.python.ops import control_flow_assert 

32from tensorflow.python.ops import gen_math_ops 

33from tensorflow.python.ops import math_ops 

34from tensorflow.python.ops import nn_ops 

35from tensorflow.python.ops import sort_ops 

36from tensorflow.python.ops import special_math_ops 

37from tensorflow.python.ops import while_loop 

38from tensorflow.python.ops.numpy_ops import np_array_ops 

39from tensorflow.python.ops.numpy_ops import np_arrays 

40from tensorflow.python.ops.numpy_ops import np_dtypes 

41from tensorflow.python.ops.numpy_ops import np_export 

42from tensorflow.python.ops.numpy_ops import np_utils 

43 

44 

# NumPy-compatible module constants, created through
# `np_export.np_export_constant` under this module's name.
pi = np_export.np_export_constant(__name__, 'pi', np.pi)
e = np_export.np_export_constant(__name__, 'e', np.e)
inf = np_export.np_export_constant(__name__, 'inf', np.inf)

48 

49 

@np_utils.np_doc_only('dot')
def dot(a, b):  # pylint: disable=missing-docstring
  # Rank-based dispatch:
  #   * either operand is rank 0 -> elementwise product;
  #   * `b` is rank 1            -> contract the last axis of both operands;
  #   * otherwise                -> contract the last axis of `a` with the
  #                                 second-to-last axis of `b`.

  def _dot_impl(lhs, rhs):
    either_scalar = np_utils.logical_or(
        math_ops.equal(array_ops.rank(lhs), 0),
        math_ops.equal(array_ops.rank(rhs), 0))

    def _scalar_product():
      return lhs * rhs

    def _contract_with_vector():
      return math_ops.tensordot(lhs, rhs, axes=[[-1], [-1]])

    def _contract_with_matrix():
      return math_ops.tensordot(lhs, rhs, axes=[[-1], [-2]])

    def _tensor_product():
      # The rank-1 test is only built when no operand is a scalar.
      rhs_is_vector = math_ops.equal(array_ops.rank(rhs), 1)
      return np_utils.cond(rhs_is_vector, _contract_with_vector,
                           _contract_with_matrix)

    return np_utils.cond(either_scalar, _scalar_product, _tensor_product)

  return _bin_op(_dot_impl, a, b)

65 

66 

# TODO(wangpeng): Make element-wise ops `ufunc`s
def _bin_op(tf_fun, a, b, promote=True):
  """Converts both operands and applies the binary TF function `tf_fun`.

  Args:
    tf_fun: callable taking the two converted operands.
    a: first operand (array_like).
    b: second operand (array_like).
    promote: when True, both operands are converted together through
      `np_array_ops._promote_dtype_binary`; when False, each operand is
      converted on its own with `np_array_ops.array`.

  Returns:
    Whatever `tf_fun` returns.
  """
  if not promote:
    return tf_fun(np_array_ops.array(a), np_array_ops.array(b))
  a, b = np_array_ops._promote_dtype_binary(a, b)  # pylint: disable=protected-access
  return tf_fun(a, b)

75 

76 

@np_utils.np_doc('add')
def add(x1, x2):

  def _add_or_logical_or(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return math_ops.add(lhs, rhs)
    # Boolean operands are OR-ed instead of summed.
    assert rhs.dtype == dtypes.bool
    return math_ops.logical_or(lhs, rhs)

  return _bin_op(_add_or_logical_or, x1, x2)


@np_utils.np_doc('subtract')
def subtract(x1, x2):
  return _bin_op(tf_fun=math_ops.subtract, a=x1, b=x2)


@np_utils.np_doc('multiply')
def multiply(x1, x2):

  def _multiply_or_logical_and(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return math_ops.multiply(lhs, rhs)
    # Boolean operands are AND-ed instead of multiplied.
    assert rhs.dtype == dtypes.bool
    return math_ops.logical_and(lhs, rhs)

  return _bin_op(_multiply_or_logical_and, x1, x2)

104 

105 

@np_utils.np_doc('true_divide')
def true_divide(x1, x2):  # pylint: disable=missing-function-docstring

  def _divide_impl(num, den):
    if num.dtype == dtypes.bool:
      # Booleans are divided as the default float type.
      assert den.dtype == dtypes.bool
      target = np_dtypes.default_float_type()
      num, den = math_ops.cast(num, target), math_ops.cast(den, target)
    if not np_dtypes.is_allow_float64():
      # math_ops.truediv in Python3 produces float64 when both inputs are
      # int32 or int64, so pre-cast such pairs to float32 when float64 is
      # disallowed.
      if num.dtype == den.dtype and num.dtype in (dtypes.int32, dtypes.int64):
        num = math_ops.cast(num, dtype=dtypes.float32)
        den = math_ops.cast(den, dtype=dtypes.float32)
    return math_ops.truediv(num, den)

  return _bin_op(_divide_impl, x1, x2)


@np_utils.np_doc('divide')
def divide(x1, x2):  # pylint: disable=missing-function-docstring
  # `divide` is an alias of `true_divide`.
  return true_divide(x1, x2)

133 

134 

def _cast_bool_pair_to_int8(x1, x2):
  """Returns `(x1, x2)` cast to int8 when they are bool, else unchanged."""
  if x1.dtype != dtypes.bool:
    return x1, x2
  assert x2.dtype == dtypes.bool
  return math_ops.cast(x1, dtypes.int8), math_ops.cast(x2, dtypes.int8)


@np_utils.np_doc('floor_divide')
def floor_divide(x1, x2):  # pylint: disable=missing-function-docstring
  return _bin_op(
      lambda lhs, rhs: math_ops.floordiv(*_cast_bool_pair_to_int8(lhs, rhs)),
      x1, x2)


@np_utils.np_doc('mod')
def mod(x1, x2):  # pylint: disable=missing-function-docstring
  return _bin_op(
      lambda lhs, rhs: math_ops.mod(*_cast_bool_pair_to_int8(lhs, rhs)),
      x1, x2)


@np_utils.np_doc('remainder')
def remainder(x1, x2):  # pylint: disable=missing-function-docstring
  # `remainder` is an alias of `mod`.
  return mod(x1, x2)


@np_utils.np_doc('divmod')
def divmod(x1, x2):  # pylint: disable=redefined-builtin
  quotient = floor_divide(x1, x2)
  rem = mod(x1, x2)
  return quotient, rem

169 

170 

@np_utils.np_doc('maximum')
def maximum(x1, x2):  # pylint: disable=missing-function-docstring

  # Fast path for when maximum is used as relu: a non-bool ndarray `x1`
  # compared against a real, non-bool zero `x2`.
  x2_is_real_zero = (
      isinstance(x2, numbers.Real) and not isinstance(x2, bool) and x2 == 0)
  if (x2_is_real_zero and isinstance(x1, np_arrays.ndarray) and
      x1.dtype != dtypes.bool):
    return nn_ops.relu(np_array_ops.asarray(x1))

  def _max_or_logical_or(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return math_ops.maximum(lhs, rhs)
    # For booleans the maximum is logical OR.
    assert rhs.dtype == dtypes.bool
    return math_ops.logical_or(lhs, rhs)

  return _bin_op(_max_or_logical_or, x1, x2)


@np_utils.np_doc('minimum')
def minimum(x1, x2):

  def _min_or_logical_and(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return math_ops.minimum(lhs, rhs)
    # For booleans the minimum is logical AND.
    assert rhs.dtype == dtypes.bool
    return math_ops.logical_and(lhs, rhs)

  return _bin_op(_min_or_logical_and, x1, x2)

199 

200 

@np_utils.np_doc('clip')
def clip(a, a_min, a_max):  # pylint: disable=missing-docstring
  # A single missing bound degrades to a one-sided `minimum`/`maximum`.
  if a_min is None:
    if a_max is None:
      raise ValueError('Not more than one of `a_min` and `a_max` may be '
                       '`None`.')
    return minimum(a, a_max)
  if a_max is None:
    return maximum(a, a_min)
  # Both bounds present: promote to a common dtype, broadcast, then clip.
  a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max)  # pylint: disable=protected-access
  broadcast = np_utils.tf_broadcast(a, a_min, a_max)
  return clip_ops.clip_by_value(*broadcast)

212 

213 

@np_utils.np_doc('matmul')
def matmul(x1, x2):  # pylint: disable=missing-docstring

  def _matmul_impl(lhs, rhs):
    try:
      if lhs._rank() == 2 and rhs._rank() == 2:  # pylint: disable=protected-access
        # Fast path for known ranks.
        return gen_math_ops.mat_mul(lhs, rhs)

      def _rhs_is_vector():
        return math_ops.tensordot(lhs, rhs, axes=1)

      def _lhs_is_vector():
        return math_ops.tensordot(lhs, rhs, axes=[[0], [-2]])

      def _batched():
        return math_ops.matmul(lhs, rhs)

      def _rhs_not_vector():
        lhs_rank_one = math_ops.equal(np_utils.tf_rank(lhs), 1)
        return np_utils.cond(lhs_rank_one, _lhs_is_vector, _batched)

      rhs_rank_one = math_ops.equal(np_utils.tf_rank(rhs), 1)
      return np_utils.cond(rhs_rank_one, _rhs_is_vector, _rhs_not_vector)
    except errors.InvalidArgumentError as err:
      # Surface TF shape errors as ValueError, keeping the traceback.
      raise ValueError(str(err)).with_traceback(sys.exc_info()[2])

  return _bin_op(_matmul_impl, x1, x2)


# Exported so it can be called from Tensor.__matmul__. NumPy's matmul handles
# batched matmul as well, so simply including promotion in TF's current
# __matmul__ implementation was not sufficient.
np_arrays.ndarray._matmul = matmul  # pylint: disable=protected-access

239 

240 

@np_utils.np_doc('tensordot')
def tensordot(a, b, axes=2):

  def _contract(lhs, rhs):
    return math_ops.tensordot(lhs, rhs, axes=axes)

  return _bin_op(_contract, a, b)

244 

245 

@np_utils.np_doc_only('inner')
def inner(a, b):  # pylint: disable=missing-function-docstring

  def _inner_impl(lhs, rhs):
    # A rank-0 operand yields an elementwise product; otherwise the last axes
    # of both operands are contracted.
    has_scalar = np_utils.logical_or(
        math_ops.equal(array_ops.rank(lhs), 0),
        math_ops.equal(array_ops.rank(rhs), 0))

    def _product():
      return lhs * rhs

    def _contract_last_axes():
      return math_ops.tensordot(lhs, rhs, axes=[[-1], [-1]])

    return np_utils.cond(has_scalar, _product, _contract_last_axes)

  return _bin_op(_inner_impl, a, b)

257 

258 

@np_utils.np_doc('cross')
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):  # pylint: disable=missing-docstring
  # Outline of the implementation below:
  #   1. Resolve the vector axis of `a`, `b` and the output; a non-None
  #      `axis` overrides `axisa`, `axisb` and `axisc`.
  #   2. Move each input's vector axis to the last position.
  #   3. Zero-pad length-2 vectors to length 3 so `math_ops.cross` applies.
  #   4. Broadcast and take the cross product. If both inputs were
  #      2-vectors, keep only component 2 of the result; otherwise move the
  #      result's last axis to `axis_c`.

  def f(a, b):  # pylint: disable=missing-docstring
    # We can't assign to captured variable `axisa`, so make a new variable
    if axis is None:
      axis_a = axisa
      axis_b = axisb
      axis_c = axisc
    else:
      axis_a = axis
      axis_b = axis
      axis_c = axis
    # Normalize negative input axes against the (possibly dynamic) rank.
    if axis_a < 0:
      axis_a = np_utils.add(axis_a, array_ops.rank(a))
    if axis_b < 0:
      axis_b = np_utils.add(axis_b, array_ops.rank(b))

    def maybe_move_axis_to_last(a, axis):

      # Transposes with permutation [0..axis-1, axis+1..rank-1, axis].
      def move_axis_to_last(a, axis):
        return array_ops.transpose(
            a,
            array_ops.concat([
                math_ops.range(axis),
                math_ops.range(axis + 1, array_ops.rank(a)), [axis]
            ],
                             axis=0))

      # Skip the transpose when `axis` is already the last one.
      return np_utils.cond(axis == np_utils.subtract(array_ops.rank(a), 1),
                           lambda: a, lambda: move_axis_to_last(a, axis))

    a = maybe_move_axis_to_last(a, axis_a)
    b = maybe_move_axis_to_last(b, axis_b)
    # Length of each operand's vector axis (now the last axis).
    a_dim = np_utils.getitem(array_ops.shape(a), -1)
    b_dim = np_utils.getitem(array_ops.shape(b), -1)

    def maybe_pad_0(a, size_of_last_dim):

      # Appends one zero to the last axis only (padding [0, 1] there and
      # [0, 0] on every other axis).
      def pad_0(a):
        return array_ops.pad(
            a,
            array_ops.concat([
                array_ops.zeros([array_ops.rank(a) - 1, 2], dtypes.int32),
                constant_op.constant([[0, 1]], dtypes.int32)
            ],
                             axis=0))

      return np_utils.cond(
          math_ops.equal(size_of_last_dim, 2), lambda: pad_0(a), lambda: a)

    a = maybe_pad_0(a, a_dim)
    b = maybe_pad_0(b, b_dim)
    c = math_ops.cross(*np_utils.tf_broadcast(a, b))
    if axis_c < 0:
      axis_c = np_utils.add(axis_c, array_ops.rank(c))

    # Transposes with permutation [0..axis-1, r-1, axis..r-2], i.e. the
    # inverse of `move_axis_to_last`.
    def move_last_to_axis(a, axis):
      r = array_ops.rank(a)
      return array_ops.transpose(
          a,
          array_ops.concat(
              [math_ops.range(axis), [r - 1],
               math_ops.range(axis, r - 1)],
              axis=0))

    # Two 2-vectors: the zero-padded cross product only has a nonzero
    # component at index 2, which is returned on its own.
    c = np_utils.cond(
        (a_dim == 2) & (b_dim == 2),
        lambda: c[..., 2],
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            axis_c == np_utils.subtract(array_ops.rank(c), 1), lambda: c,
            lambda: move_last_to_axis(c, axis_c)))
    return c

  return _bin_op(f, a, b)

334 

335 

@np_utils.np_doc_only('vdot')
def vdot(a, b):  # pylint: disable=missing-docstring
  a, b = np_array_ops._promote_dtype(a, b)  # pylint: disable=protected-access
  # Both operands are flattened before the dot product.
  flat_a = np_array_ops.reshape(a, [-1])
  flat_b = np_array_ops.reshape(b, [-1])
  # For complex inputs the first operand is conjugated.
  is_complex = (flat_a.dtype == np_dtypes.complex128 or
                flat_a.dtype == np_dtypes.complex64)
  if is_complex:
    flat_a = conj(flat_a)
  return dot(flat_a, flat_b)

344 

345 

@np_utils.np_doc('power')
def power(x1, x2):
  return _bin_op(tf_fun=math_ops.pow, a=x1, b=x2)


@np_utils.np_doc('float_power')
def float_power(x1, x2):
  # Delegates to `power` without any extra float promotion.
  # NOTE(review): NumPy's float_power computes in at least float64; this
  # alias does not force that.
  return power(x1, x2)


@np_utils.np_doc('arctan2')
def arctan2(x1, x2):
  return _bin_op(tf_fun=math_ops.atan2, a=x1, b=x2)


@np_utils.np_doc('nextafter')
def nextafter(x1, x2):
  return _bin_op(tf_fun=math_ops.nextafter, a=x1, b=x2)

364 

365 

@np_utils.np_doc('heaviside')
def heaviside(x1, x2):  # pylint: disable=missing-function-docstring

  def _step(x, at_zero):
    # 0 where x < 0, 1 where x > 0, and `at_zero` elsewhere (x == 0 or NaN).
    zero = constant_op.constant(0, dtype=at_zero.dtype)
    one = constant_op.constant(1, dtype=at_zero.dtype)
    return array_ops.where_v2(
        x < 0, zero, array_ops.where_v2(x > 0, one, at_zero))

  result = _bin_op(_step, x1, x2)
  # Non-inexact results are cast to the default float type.
  if np.issubdtype(result.dtype.as_numpy_dtype, np.inexact):
    return result
  return result.astype(np_dtypes.default_float_type())

378 

379 

@np_utils.np_doc('hypot')
def hypot(x1, x2):
  # Computed directly as sqrt(x1**2 + x2**2).
  # NOTE(review): unlike NumPy's hypot, squaring can overflow for large
  # magnitudes — confirm whether a scaled formulation is needed.
  sum_of_squares = square(x1) + square(x2)
  return sqrt(sum_of_squares)

383 

384 

@np_utils.np_doc('kron')
def kron(a, b):  # pylint: disable=missing-function-docstring
  # Kronecker product via broadcasting:
  #   * left-pad the lower-rank operand's shape with ones so ranks match;
  #   * reshape `a` to [a0, 1, a1, 1, ...] and `b` to [1, b0, 1, b1, ...];
  #   * multiply (broadcasting to [a0, b0, a1, b1, ...]) and reshape to
  #     [a0*b0, a1*b1, ...].
  # NOTE(review): relies on static ranks (`x.shape.rank`); unknown ranks are
  # presumably unsupported — confirm.
  # pylint: disable=protected-access,g-complex-comprehension
  a, b = np_array_ops._promote_dtype(a, b)
  t_a = np_utils.cond(
      a.shape.rank < b.shape.rank,
      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
          a, np_array_ops._pad_left_to(b.shape.rank, a.shape)),
      lambda: a)
  t_b = np_utils.cond(
      b.shape.rank < a.shape.rank,
      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
          b, np_array_ops._pad_left_to(a.shape.rank, b.shape)),
      lambda: b)

  # Interleaves `shape` with ones: [s0, 1, s1, 1, ...] when prepend=False,
  # or [1, s0, 1, s1, ...] when prepend=True.
  def _make_shape(shape, prepend):
    ones = array_ops.ones_like(shape)
    if prepend:
      shapes = [ones, shape]
    else:
      shapes = [shape, ones]
    return array_ops.reshape(array_ops_stack.stack(shapes, axis=1), [-1])

  a_shape = array_ops.shape(t_a)
  b_shape = array_ops.shape(t_b)
  a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
  b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
  # Each output dimension is the product of the matching input dimensions.
  out_shape = a_shape * b_shape
  return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)

414 

415 

@np_utils.np_doc('outer')
def outer(a, b):

  def _outer_product(lhs, rhs):
    # Column of flattened `a` times row of flattened `b`.
    column = array_ops.reshape(lhs, [-1, 1])
    row = array_ops.reshape(rhs, [-1])
    return column * row

  return _bin_op(_outer_product, a, b)

423 

424 

# This can also be implemented via tf.reduce_logsumexp
@np_utils.np_doc('logaddexp')
def logaddexp(x1, x2):
  larger = maximum(x1, x2)
  gap = x1 - x2
  # Stable form: max(x1, x2) + log1p(exp(-|x1 - x2|)).
  stable = larger + log1p(exp(-abs(gap)))
  # Where the gap is NaN (NaN inputs or infinities of the same sign), fall
  # back to x1 + x2.
  return np_array_ops.where(isnan(gap), x1 + x2, stable)


@np_utils.np_doc('logaddexp2')
def logaddexp2(x1, x2):
  larger = maximum(x1, x2)
  gap = x1 - x2
  # Base-2 analogue: max(x1, x2) + log1p(2**-|x1 - x2|) / ln(2).
  stable = larger + log1p(exp2(-abs(gap))) / np.log(2)
  # NaNs or infinities of the same sign.
  return np_array_ops.where(isnan(gap), x1 + x2, stable)

444 

445 

@np_utils.np_doc('polyval')
def polyval(p, x):  # pylint: disable=missing-function-docstring

  def _evaluate(coeffs, points):
    if coeffs.shape.rank == 0:
      coeffs = array_ops.reshape(coeffs, [1])
    # TODO(wangpeng): Make tf version take a tensor for p instead of a list.
    coeff_list = array_ops_stack.unstack(coeffs)
    value = math_ops.polyval(coeff_list, points)
    # If the polynomial is 0-order, numpy requires the result to be broadcast
    # to `x`'s shape.
    if len(coeff_list) == 1:
      value = array_ops.broadcast_to(value, points.shape)
    return value

  return _bin_op(_evaluate, p, x)

462 

463 

@np_utils.np_doc('isclose')
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):  # pylint: disable=missing-docstring

  def _isclose_impl(lhs, rhs):
    dtype = lhs.dtype
    if not np.issubdtype(dtype.as_numpy_dtype, np.inexact):
      # Non-inexact dtypes are compared for exact equality.
      return lhs == rhs
    # Tolerances are expressed in the real counterpart of the dtype.
    rel_tol = ops.convert_to_tensor(rtol, dtype.real_dtype)
    abs_tol = ops.convert_to_tensor(atol, dtype.real_dtype)
    close = (math_ops.abs(lhs - rhs) <= abs_tol + rel_tol * math_ops.abs(rhs))
    if equal_nan:
      close = close | (math_ops.is_nan(lhs) & math_ops.is_nan(rhs))
    return close

  return _bin_op(_isclose_impl, a, b)


@np_utils.np_doc('allclose')
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  elementwise = isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
  return np_array_ops.all(elementwise)

486 

487 

def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring
  """Elementwise GCD of integer tensors via a vectorized Euclid loop.

  `x1` and `x2` are broadcast to a common shape, their absolute values are
  taken, and `while_loop` repeats the step (a, b) -> (b, a mod b) until every
  `b` is zero.

  Args:
    x1: integer tensor.
    x2: integer tensor, broadcast-compatible with `x1`.

  Returns:
    A tensor of the broadcast shape holding the GCDs.

  Raises:
    ValueError: if either argument's dtype is not an integer type.
  """

  # Keep iterating while any element still has a nonzero remainder.
  def _gcd_cond_fn(_, x2):
    return math_ops.reduce_any(x2 != 0)

  def _gcd_body_fn(x1, x2):
    # math_ops.mod will raise an error when any element of x2 is 0. To avoid
    # that, we change those zeros to ones. Their values don't matter because
    # they won't be used.
    x2_safe = array_ops.where_v2(x2 != 0, x2, constant_op.constant(1, x2.dtype))
    # Euclid step on elements with x2 != 0; finished elements stay (x1, 0).
    x1, x2 = (array_ops.where_v2(x2 != 0, x2, x1),
              array_ops.where_v2(x2 != 0, math_ops.mod(x1, x2_safe),
                                 constant_op.constant(0, x2.dtype)))
    # Order each pair so the first element is the larger one.
    return (array_ops.where_v2(x1 < x2, x2,
                               x1), array_ops.where_v2(x1 < x2, x1, x2))

  if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
      not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
    raise ValueError('Arguments to gcd must be integers.')
  shape = array_ops.broadcast_dynamic_shape(
      array_ops.shape(x1), array_ops.shape(x2))
  x1 = array_ops.broadcast_to(x1, shape)
  x2 = array_ops.broadcast_to(x2, shape)
  # The first loop variable holds the GCD once all remainders are zero.
  value, _ = while_loop.while_loop(_gcd_cond_fn, _gcd_body_fn,
                                   (math_ops.abs(x1), math_ops.abs(x2)))
  return value

514 

515 

# Note that np.gcd may not be present in some supported versions of numpy.
@np_utils.np_doc('gcd')
def gcd(x1, x2):
  # Operands go through `_bin_op` (dtype promotion) before Euclid's loop.
  return _bin_op(tf_fun=_tf_gcd, a=x1, b=x2)

520 

521 

# Note that np.lcm may not be present in some supported versions of numpy.
@np_utils.np_doc('lcm')
def lcm(x1, x2):  # pylint: disable=missing-function-docstring

  def _lcm_impl(lhs, rhs):
    divisor = _tf_gcd(lhs, rhs)
    gcd_is_zero = math_ops.equal(divisor, 0)
    # Same trick as `x2_safe` in `_tf_gcd`: replace zero divisors by one so
    # the floor-division never divides by zero; those positions are set to
    # 0 below.
    safe_divisor = array_ops.where_v2(
        gcd_is_zero, constant_op.constant(1, divisor.dtype), divisor)
    product = math_ops.abs(lhs) * (math_ops.abs(rhs) // safe_divisor)
    return array_ops.where_v2(
        gcd_is_zero, constant_op.constant(0, divisor.dtype), product)

  return _bin_op(_lcm_impl, x1, x2)

538 

539 

def _bitwise_binary_op(tf_fn, x1, x2):
  """Applies bitwise TF op `tf_fn` to `x1` and `x2` after dtype promotion.

  Bool operands are cast to int8 for the op and the result is cast back to
  bool.
  """

  def _apply(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return tf_fn(lhs, rhs)
    # Round-trip booleans through int8 (presumably because the bitwise
    # kernels do not accept bool).
    assert rhs.dtype == dtypes.bool
    out = tf_fn(math_ops.cast(lhs, dtypes.int8), math_ops.cast(rhs, dtypes.int8))
    return math_ops.cast(out, dtypes.bool)

  return _bin_op(_apply, x1, x2)


@np_utils.np_doc('bitwise_and')
def bitwise_and(x1, x2):
  return _bitwise_binary_op(tf_fn=bitwise_ops.bitwise_and, x1=x1, x2=x2)


@np_utils.np_doc('bitwise_or')
def bitwise_or(x1, x2):
  return _bitwise_binary_op(tf_fn=bitwise_ops.bitwise_or, x1=x1, x2=x2)


@np_utils.np_doc('bitwise_xor')
def bitwise_xor(x1, x2):
  return _bitwise_binary_op(tf_fn=bitwise_ops.bitwise_xor, x1=x1, x2=x2)

569 

570 

@np_utils.np_doc('bitwise_not', link=np_utils.AliasOf('invert'))
def bitwise_not(x):

  def _invert(t):
    # Booleans are negated logically; other dtypes are bit-flipped.
    return (math_ops.logical_not(t) if t.dtype == dtypes.bool
            else bitwise_ops.invert(t))

  return _scalar(_invert, x)

580 

581 

def _scalar(tf_fn, x, promote_to_float=False):
  """Converts `x` to an ndarray and applies the unary TF function `tf_fn`.

  Args:
    tf_fn: function that takes a single Tensor argument.
    x: array_like. Could be an ndarray, a Tensor or any object accepted by
      `np_array_ops.asarray`.
    promote_to_float: if True and `x`'s dtype is not inexact (floating or
      complex), `x` is first cast to `np_dtypes.default_float_type()`.

  Returns:
    Whatever `tf_fn` returns for the converted (and possibly promoted) `x`;
    its shape and dtype are determined by `tf_fn`.
  """
  arr = np_array_ops.asarray(x)
  needs_float = promote_to_float and not np.issubdtype(
      arr.dtype.as_numpy_dtype, np.inexact)
  if needs_float:
    arr = arr.astype(np_dtypes.default_float_type())
  return tf_fn(arr)

601 

602 

@np_utils.np_doc('log')
def log(x):
  return _scalar(math_ops.log, x, promote_to_float=True)


@np_utils.np_doc('exp')
def exp(x):
  return _scalar(math_ops.exp, x, promote_to_float=True)


@np_utils.np_doc('sqrt')
def sqrt(x):
  return _scalar(math_ops.sqrt, x, promote_to_float=True)

616 

617 

@np_utils.np_doc('abs', link=np_utils.AliasOf('absolute'))
def abs(x):  # pylint: disable=redefined-builtin
  # No float promotion is requested, so the dtype is left to math_ops.abs.
  return _scalar(math_ops.abs, x)


@np_utils.np_doc('absolute')
def absolute(x):
  # Alias of `abs`.
  return abs(x)


@np_utils.np_doc('fabs')
def fabs(x):
  # Unlike `abs`, NumPy's `fabs` always returns floating-point values, so
  # non-inexact (integer/bool) inputs are promoted to the default float type
  # before taking the absolute value. Previously this simply aliased `abs`
  # and returned integers for integer input.
  return _scalar(math_ops.abs, x, True)

631 

632 

@np_utils.np_doc('ceil')
def ceil(x):
  return _scalar(math_ops.ceil, x, promote_to_float=True)


@np_utils.np_doc('floor')
def floor(x):
  return _scalar(math_ops.floor, x, promote_to_float=True)

641 

642 

@np_utils.np_doc('conj')
def conj(x):
  return _scalar(tf_fn=math_ops.conj, x=x)


@np_utils.np_doc('negative')
def negative(x):
  return _scalar(tf_fn=math_ops.negative, x=x)


@np_utils.np_doc('reciprocal')
def reciprocal(x):
  return _scalar(tf_fn=math_ops.reciprocal, x=x)

656 

657 

@np_utils.np_doc('signbit')
def signbit(x):

  def _is_negative(t):
    if t.dtype != dtypes.bool:
      return t < 0
    # Booleans never have the sign bit set.
    return array_ops.fill(array_ops.shape(t), False)

  return _scalar(_is_negative, x)

667 

668 

@np_utils.np_doc('sin')
def sin(x):
  return _scalar(math_ops.sin, x, promote_to_float=True)


@np_utils.np_doc('cos')
def cos(x):
  return _scalar(math_ops.cos, x, promote_to_float=True)


@np_utils.np_doc('tan')
def tan(x):
  return _scalar(math_ops.tan, x, promote_to_float=True)


@np_utils.np_doc('sinh')
def sinh(x):
  return _scalar(math_ops.sinh, x, promote_to_float=True)


@np_utils.np_doc('cosh')
def cosh(x):
  return _scalar(math_ops.cosh, x, promote_to_float=True)


@np_utils.np_doc('tanh')
def tanh(x):
  return _scalar(math_ops.tanh, x, promote_to_float=True)


@np_utils.np_doc('arcsin')
def arcsin(x):
  return _scalar(math_ops.asin, x, promote_to_float=True)


@np_utils.np_doc('arccos')
def arccos(x):
  return _scalar(math_ops.acos, x, promote_to_float=True)


@np_utils.np_doc('arctan')
def arctan(x):
  return _scalar(math_ops.atan, x, promote_to_float=True)


@np_utils.np_doc('arcsinh')
def arcsinh(x):
  return _scalar(math_ops.asinh, x, promote_to_float=True)


@np_utils.np_doc('arccosh')
def arccosh(x):
  return _scalar(math_ops.acosh, x, promote_to_float=True)


@np_utils.np_doc('arctanh')
def arctanh(x):
  return _scalar(math_ops.atanh, x, promote_to_float=True)

727 

728 

@np_utils.np_doc('deg2rad')
def deg2rad(x):
  radians_per_degree = np.pi / 180.0
  return _scalar(lambda t: t * radians_per_degree, x, True)

736 

737 

@np_utils.np_doc('rad2deg')
def rad2deg(x):
  # Route through `_scalar` (as `deg2rad` does) so array_like inputs such as
  # Python lists are converted to ndarrays first, and non-inexact inputs are
  # promoted to the default float type before scaling. The previous bare
  # `x * (180.0 / np.pi)` failed on lists (`list * float`).

  def f(x):
    return x * (180.0 / np.pi)

  return _scalar(f, x, True)

741 

742 

# Real floating-point TF dtypes. `angle` tests membership in this list to
# treat real inputs separately from complex ones.
_tf_float_types = [
    dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
]

746 

747 

@np_utils.np_doc('angle')
def angle(z, deg=False):  # pylint: disable=missing-function-docstring

  def f(x):
    if x.dtype in _tf_float_types:
      # Workaround for b/147515503: for real inputs the angle is pi where
      # x < 0 and 0 elsewhere. Both branch values are built in `x`'s dtype.
      # Bare Python scalars would be converted to a default dtype
      # (presumably float32), so float64/float16/bfloat16 inputs would not
      # keep their dtype.
      return array_ops.where_v2(
          x < 0,
          constant_op.constant(np.pi, dtype=x.dtype),
          constant_op.constant(0, dtype=x.dtype))
    else:
      return math_ops.angle(x)

  y = _scalar(f, z, True)
  if deg:
    y = rad2deg(y)
  return y

762 

763 

@np_utils.np_doc('cbrt')
def cbrt(x):

  def _cube_root(t):
    # __pow__ can't handle negative base, so take the root of |t| and
    # restore the sign afterwards.
    magnitude = math_ops.abs(t)**(1.0 / 3)
    return array_ops.where_v2(t < 0, -magnitude, magnitude)

  return _scalar(_cube_root, x, True)

773 

774 

@np_utils.np_doc('conjugate', link=np_utils.AliasOf('conj'))
def conjugate(x):
  # Same computation as `conj`.
  return _scalar(tf_fn=math_ops.conj, x=x)

778 

779 

@np_utils.np_doc('exp2')
def exp2(x):
  # 2 raised elementwise to `x`, after promotion to a float dtype.
  return _scalar(lambda t: 2**t, x, promote_to_float=True)

787 

788 

@np_utils.np_doc('expm1')
def expm1(x):
  return _scalar(math_ops.expm1, x, promote_to_float=True)

792 

793 

@np_utils.np_doc('fix')
def fix(x):

  def _round_toward_zero(t):
    # ceil for negatives, floor otherwise.
    return array_ops.where_v2(t < 0, math_ops.ceil(t), math_ops.floor(t))

  return _scalar(_round_toward_zero, x, True)

801 

802 

@np_utils.np_doc('iscomplex')
def iscomplex(x):
  imaginary = np_array_ops.imag(x)
  return imaginary != 0


@np_utils.np_doc('isreal')
def isreal(x):
  imaginary = np_array_ops.imag(x)
  return imaginary == 0

811 

812 

@np_utils.np_doc('iscomplexobj')
def iscomplexobj(x):
  # Decided by dtype alone, not by the values.
  arr = np_array_ops.array(x)
  return np.issubdtype(arr.dtype.as_numpy_dtype, np.complexfloating)


@np_utils.np_doc('isrealobj')
def isrealobj(x):
  is_complex = iscomplexobj(x)
  return not is_complex

822 

823 

@np_utils.np_doc('isnan')
def isnan(x):
  return _scalar(math_ops.is_nan, x, promote_to_float=True)

827 

828 

def _make_nan_reduction(np_fun_name, reduction, init_val):
  """Builds a NaN-ignoring variant of `reduction` (e.g. `nansum`).

  Args:
    np_fun_name: name of the NumPy function whose docs are attached.
    reduction: reduction callable accepting `axis`, `dtype` and `keepdims`.
    init_val: value substituted for NaN entries before reducing (0 for sum
      and 1 for product in the uses below).

  Returns:
    The generated reduction function.
  """

  @np_utils.np_doc(np_fun_name)
  def nan_reduction(a, axis=None, dtype=None, keepdims=False):
    arr = np_array_ops.array(a)
    fill = np_array_ops.array(init_val, dtype=arr.dtype)
    cleaned = np_array_ops.where(isnan(arr), fill, arr)
    return reduction(cleaned, axis=axis, dtype=dtype, keepdims=keepdims)

  return nan_reduction


nansum = _make_nan_reduction('nansum', np_array_ops.sum, 0)
nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1)

847 

848 

@np_utils.np_doc('nanmean')
def nanmean(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=missing-docstring
  arr = np_array_ops.array(a)
  np_dtype = arr.dtype.as_numpy_dtype
  # Boolean and integer arrays go straight to a plain mean.
  if np.issubdtype(np_dtype, np.bool_) or np.issubdtype(np_dtype, np.integer):
    return np_array_ops.mean(arr, axis=axis, dtype=dtype, keepdims=keepdims)
  not_nan = logical_not(isnan(arr))
  out_dtype = np_dtype if dtype is None else dtype
  # Mean over non-NaN entries: nansum divided by the count of non-NaNs.
  count = np_array_ops.sum(
      not_nan, axis=axis, dtype=out_dtype, keepdims=keepdims)
  total = nansum(arr, axis=axis, dtype=out_dtype, keepdims=keepdims)
  return total / count

861 

862 

@np_utils.np_doc('isfinite')
def isfinite(x):
  return _scalar(math_ops.is_finite, x, promote_to_float=True)


@np_utils.np_doc('isinf')
def isinf(x):
  return _scalar(math_ops.is_inf, x, promote_to_float=True)


@np_utils.np_doc('isneginf')
def isneginf(x):
  neg_inf = np_array_ops.full_like(x, -np.inf)
  return x == neg_inf


@np_utils.np_doc('isposinf')
def isposinf(x):
  pos_inf = np_array_ops.full_like(x, np.inf)
  return x == pos_inf


@np_utils.np_doc('log2')
def log2(x):
  # Change of base from the natural logarithm.
  natural_log = log(x)
  return natural_log / np.log(2)


@np_utils.np_doc('log10')
def log10(x):
  # Change of base from the natural logarithm.
  natural_log = log(x)
  return natural_log / np.log(10)


@np_utils.np_doc('log1p')
def log1p(x):
  return _scalar(math_ops.log1p, x, promote_to_float=True)

896 

897 

@np_utils.np_doc('positive')
def positive(x):

  def _identity(t):
    return t

  # Converts `x` to an ndarray and returns it unchanged.
  return _scalar(_identity, x)

901 

902 

@np_utils.np_doc('sinc')
def sinc(x):

  def _normalized_sinc(t):
    scaled = t * np.pi
    ratio = math_ops.sin(scaled) / scaled
    # sin(0)/0 is NaN, so the value at zero is taken to be 1.
    return array_ops.where_v2(t == 0, array_ops.ones_like(t), ratio)

  return _scalar(_normalized_sinc, x, True)

912 

913 

@np_utils.np_doc('square')
def square(x):
  return _scalar(tf_fn=math_ops.square, x=x)

917 

918 

@np_utils.np_doc('diff')
def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring

  def f(a):
    # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
    # TODO(agarwal): avoid depending on static rank.
    nd = a.shape.rank
    if nd is None:
      raise ValueError(
          'Function `diff` currently requires a known rank for input `a`. '
          f'Received: a={a} (unknown rank)')
    # Valid axes lie in [-nd, nd). The previous check only rejected axes
    # that were too large, so e.g. axis=-(nd + 1), or any axis on a rank-0
    # input, slipped through and failed later with an IndexError when
    # indexing the slice lists below.
    if not -nd <= axis < nd:
      raise ValueError(
          f'Argument `axis` (received axis={axis}) is out of bounds '
          f'for input {a} of rank {nd}.')
    if n < 0:
      raise ValueError('Argument `n` must be a non-negative integer. '
                       f'Received: n={n}')
    # a[slice1] drops the first element along `axis`, a[slice2] the last.
    slice1 = [slice(None)] * nd
    slice2 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)
    # Booleans are differenced with inequality (XOR) instead of subtraction.
    op = math_ops.not_equal if a.dtype == dtypes.bool else math_ops.subtract
    for _ in range(n):
      a = op(a[slice1], a[slice2])
    return a

  return _scalar(f, a)

949 

950 

def _wrap(f, reverse=False):
  """Adapts binary op `f` for use as an operator overload on ndarray.

  Args:
    f: binary callable.
    reverse: if True, the two operands are swapped before anything else
      (presumably for reflected `__r*__` operators).

  Returns:
    A two-argument function. It returns `NotImplemented` when the second
    operand (after any swap) has a higher `__array_priority__` than
    `np_arrays.ndarray`, and `f(a, b)` otherwise.
  """

  def _f(a, b):
    lhs, rhs = (b, a) if reverse else (a, b)
    defers = (getattr(rhs, '__array_priority__', 0) >
              np_arrays.ndarray.__array_priority__)
    return NotImplemented if defers else f(lhs, rhs)

  return _f

965 

966 

def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
  """Applies comparison `tf_fun` after converting operands to a common dtype.

  Args:
    tf_fun: binary TF comparison function.
    x1: first operand (array_like).
    x2: second operand (array_like).
    cast_bool_to_int: if True and the common dtype is bool, both operands are
      cast to int32 before calling `tf_fun`.

  Returns:
    The result of `tf_fun`.
  """
  common = np_utils.result_type(x1, x2)
  # Cast x1 and x2 to the result_type if needed.
  lhs = np_array_ops.array(x1, dtype=common)
  rhs = np_array_ops.array(x2, dtype=common)
  if cast_bool_to_int and lhs.dtype == dtypes.bool:
    lhs = math_ops.cast(lhs, dtypes.int32)
    rhs = math_ops.cast(rhs, dtypes.int32)
  return tf_fun(lhs, rhs)

977 

978 

# Equality comparisons use the common dtype as-is; ordering comparisons
# below cast bool operands to int32 first (cast_bool_to_int=True).


@np_utils.np_doc('equal')
def equal(x1, x2):
  return _comparison(math_ops.equal, x1, x2, cast_bool_to_int=False)


@np_utils.np_doc('not_equal')
def not_equal(x1, x2):
  return _comparison(math_ops.not_equal, x1, x2, cast_bool_to_int=False)


@np_utils.np_doc('greater')
def greater(x1, x2):
  return _comparison(math_ops.greater, x1, x2, cast_bool_to_int=True)


@np_utils.np_doc('greater_equal')
def greater_equal(x1, x2):
  return _comparison(math_ops.greater_equal, x1, x2, cast_bool_to_int=True)


@np_utils.np_doc('less')
def less(x1, x2):
  return _comparison(math_ops.less, x1, x2, cast_bool_to_int=True)


@np_utils.np_doc('less_equal')
def less_equal(x1, x2):
  return _comparison(math_ops.less_equal, x1, x2, cast_bool_to_int=True)

1007 

1008 

@np_utils.np_doc('array_equal')
def array_equal(a1, a2):  # pylint: disable=missing-function-docstring

  def _arrays_equal(lhs, rhs):
    # True only if ranks match, then shapes match, then all elements match;
    # each later check is built only when the earlier one passes.

    def _false():
      return constant_op.constant(False)

    def _all_elements_equal():
      return math_ops.reduce_all(math_ops.equal(lhs, rhs))

    def _compare_same_rank():
      same_shape = np_utils.reduce_all(
          math_ops.equal(array_ops.shape(lhs), array_ops.shape(rhs)))
      return np_utils.cond(same_shape, _all_elements_equal, _false)

    same_rank = math_ops.equal(array_ops.rank(lhs), array_ops.rank(rhs))
    return np_utils.cond(same_rank, _compare_same_rank, _false)

  return _comparison(_arrays_equal, a1, a2)

1024 

1025 

def _logical_binary_op(tf_fun, x1, x2):
  """Converts both operands to bool ndarrays and applies `tf_fun`."""
  lhs = np_array_ops.array(x1, dtype=np.bool_)
  rhs = np_array_ops.array(x2, dtype=np.bool_)
  return tf_fun(lhs, rhs)


@np_utils.np_doc('logical_and')
def logical_and(x1, x2):
  return _logical_binary_op(tf_fun=math_ops.logical_and, x1=x1, x2=x2)


@np_utils.np_doc('logical_or')
def logical_or(x1, x2):
  return _logical_binary_op(tf_fun=math_ops.logical_or, x1=x1, x2=x2)


@np_utils.np_doc('logical_xor')
def logical_xor(x1, x2):
  return _logical_binary_op(tf_fun=math_ops.logical_xor, x1=x1, x2=x2)


@np_utils.np_doc('logical_not')
def logical_not(x):
  return math_ops.logical_not(np_array_ops.array(x, dtype=np.bool_))

1051 

1052 

@np_utils.np_doc('linspace')
def linspace(  # pylint: disable=missing-docstring
    start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0):
  # Resolve the requested dtype (if any) and coerce both endpoints to it.
  if dtype:
    dtype = np_utils.result_type(dtype)
  start = np_array_ops.array(start, dtype=dtype)
  stop = np_array_ops.array(stop, dtype=dtype)
  if num < 0:
    raise ValueError(
        'Argument `num` (number of samples) must be a non-negative integer. '
        f'Received: num={num}')

  # The reported step stays NaN when it is not computed below: fewer than two
  # samples with endpoint=True, or zero samples with endpoint=False.
  step = ops.convert_to_tensor(np.nan)
  seq_start, seq_stop = start, stop
  if endpoint:
    if num > 1:
      step = (stop - start) / (num - 1)
  else:
    # math_ops.linspace has no endpoint=False mode, so emulate it by spacing
    # `num` points from `start` up to one step short of `stop`.
    if num > 0:
      step = (stop - start) / num
    if num > 1:
      seq_stop = math_ops.cast(stop, step.dtype) - step
      seq_start = math_ops.cast(start, seq_stop.dtype)
  result = math_ops.linspace(seq_start, seq_stop, num, axis=axis)

  if dtype and dtype.is_integer:
    # Since numpy 1.20, linspace's rounding is towards -inf instead of 0
    result = math_ops.floor(result)
  if dtype:
    result = math_ops.cast(result, dtype)
  return (result, step) if retstep else result

1095 

1096 

@np_utils.np_doc('logspace')
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):  # pylint: disable=missing-docstring
  # Output dtype: explicit `dtype` if given, otherwise inferred from the
  # endpoints by np_utils.result_type.
  dtype = np_utils.result_type(start, stop, dtype)
  # Space the exponents and evaluate the power in (at least) a floating
  # dtype, and only cast to `dtype` at the very end, as numpy.logspace does.
  # Spacing the exponents directly in an integer `dtype` would floor them
  # (linspace rounds integer results towards -inf) and truncate `base`, e.g.
  # logspace(0, 1, 3, dtype=int) would give [1, 1, 10] instead of [1, 3, 10].
  computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
  result = linspace(
      start,
      stop,
      num=num,
      endpoint=endpoint,
      dtype=computation_dtype,
      axis=axis)
  result = math_ops.pow(math_ops.cast(base, result.dtype), result)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result

1106 

1107 

@np_utils.np_doc('geomspace')
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint: disable=missing-docstring
  # Output dtype: explicit `dtype`, or inferred from the endpoints, the sample
  # count and a default-dtype zero.
  if dtype:
    dtype = dtypes.as_dtype(dtype)
  else:
    dtype = np_utils.result_type(start, stop, float(num),
                                 np_array_ops.zeros((), dtype))
  # All intermediate math happens in at least float32.
  compute_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
  lo = np_array_ops.asarray(start, dtype=compute_dtype)
  hi = np_array_ops.asarray(stop, dtype=compute_dtype)

  # follow the numpy geomspace convention for negative and complex endpoints.
  # Each *_sign is 0 for a positive real part, 1 for zero, 2 for negative, so
  # `signflip` is -1 when both real parts are negative (the logs are then
  # taken of the negated endpoints and the result is negated back), 0 when one
  # is zero and the other negative, and 1 otherwise.
  start_sign = 1 - np_array_ops.sign(np_array_ops.real(lo))
  stop_sign = 1 - np_array_ops.sign(np_array_ops.real(hi))
  signflip = 1 - (start_sign * stop_sign) // 2

  log_lo = log10(signflip * lo)
  log_hi = log10(signflip * hi)
  res = signflip * logspace(
      log_lo,
      log_hi,
      num,
      endpoint=endpoint,
      base=10.0,
      dtype=compute_dtype,
      axis=0)
  # The samples are generated along axis 0, then moved into place.
  if axis != 0:
    res = np_array_ops.moveaxis(res, 0, axis)
  return math_ops.cast(res, dtype)

1130 

1131 

@np_utils.np_doc('ptp')
def ptp(a, axis=None, keepdims=None):
  # Peak-to-peak range: maximum minus minimum along `axis`.
  highest = np_array_ops.amax(a, axis=axis, keepdims=keepdims)
  lowest = np_array_ops.amin(a, axis=axis, keepdims=keepdims)
  return highest - lowest

1136 

1137 

@np_utils.np_doc_only('concatenate')
def concatenate(arys, axis=0):
  # Anything that is not a list/tuple is treated as a single array to join.
  if not isinstance(arys, (list, tuple)):
    arys = [arys]
  if len(arys) == 0:
    raise ValueError('Need at least one array to concatenate. Received empty '
                     f'input: arys={arys}')
  # Promote every input to one common dtype before concatenating.
  common_dtype = np_utils.result_type(*arys)
  converted = [np_array_ops.array(item, dtype=common_dtype) for item in arys]
  return array_ops.concat(converted, axis)

1148 

1149 

@np_utils.np_doc_only('tile')
def tile(a, reps):  # pylint: disable=missing-function-docstring

  def _pad_front_with_ones(vec, amount):
    # Prepend max(amount, 0) ones to the 1-D tensor `vec`.
    return array_ops.pad(
        vec, [[math_ops.maximum(amount, 0), 0]], constant_values=1)

  arr = np_array_ops.array(a)
  multiples = array_ops.reshape(
      np_array_ops.array(reps, dtype=dtypes.int32), [-1])

  arr_rank = array_ops.rank(arr)
  num_multiples = array_ops.size(multiples)
  # Make `multiples` and the array's shape the same length by left-padding
  # whichever is shorter with ones.
  multiples = _pad_front_with_ones(multiples, arr_rank - num_multiples)
  padded_shape = _pad_front_with_ones(
      array_ops.shape(arr), num_multiples - arr_rank)

  return array_ops.tile(array_ops.reshape(arr, padded_shape), multiples)

1165 

1166 

@np_utils.np_doc('count_nonzero')
def count_nonzero(a, axis=None):
  # Convert to an array first, then delegate to the TF reduction.
  arr = np_array_ops.array(a)
  return math_ops.count_nonzero(arr, axis)

1170 

1171 

@np_utils.np_doc('argsort')
def argsort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  # TODO(nareshmodi): make string tensors also work.
  if kind not in ('quicksort', 'stable'):
    raise ValueError(
        'Invalid value for argument `kind`. '
        'Only kind="quicksort" and kind="stable" are supported. '
        f'Received: kind={kind}')
  if order is not None:
    raise ValueError('The `order` argument is not supported. Pass order=None')
  use_stable = kind == 'stable'

  arr = np_array_ops.array(a)

  def _scalar_case():
    # A rank-0 input yields the single index [0].
    return constant_op.constant([0])

  def _general_case():
    # axis=None sorts the flattened array along its only axis.
    if axis is None:
      return sort_ops.argsort(
          array_ops.reshape(arr, [-1]), 0, stable=use_stable)
    return sort_ops.argsort(arr, axis, stable=use_stable)

  indices = np_utils.cond(
      math_ops.equal(array_ops.rank(arr), 0), _scalar_case, _general_case)
  return np_array_ops.array(indices, dtype=np.intp)

1198 

1199 

@np_utils.np_doc('sort')
def sort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  if kind != 'quicksort':
    raise ValueError(
        'Invalid value for argument `kind`. '
        'Only kind="quicksort" is supported. '
        f'Received: kind={kind}')
  if order is not None:
    raise ValueError('The `order` argument is not supported. Pass order=None')

  arr = np_array_ops.array(a)
  if axis is not None:
    return sort_ops.sort(arr, axis)
  # axis=None: sort the flattened array.
  return sort_ops.sort(array_ops.reshape(arr, [-1]), 0)

1216 

1217 

def _argminmax(fn, a, axis=None):
  """Shared implementation of `argmax` / `argmin`.

  Args:
    fn: the TF reduction to apply; called as `fn(input=..., axis=axis)`.
    a: array-like input.
    axis: axis to reduce over; `None` reduces over the flattened array.

  Returns:
    Whatever `fn` returns for the prepared input.
  """
  arr = np_array_ops.array(a)
  # When axis is None numpy flattens the array; otherwise guarantee at least
  # one dimension to reduce over.
  prepared = (array_ops.reshape(arr, [-1]) if axis is None
              else np_array_ops.atleast_1d(arr))
  return fn(input=prepared, axis=axis)

1226 

1227 

@np_utils.np_doc('argmax')
def argmax(a, axis=None):
  # Index of the maximum; `_argminmax` flattens the input when axis is None.
  return _argminmax(math_ops.argmax, a, axis=axis)


@np_utils.np_doc('argmin')
def argmin(a, axis=None):
  # Index of the minimum; `_argminmax` flattens the input when axis is None.
  return _argminmax(math_ops.argmin, a, axis=axis)

1236 

1237 

@np_utils.np_doc('append')
def append(arr, values, axis=None):
  if axis is not None:
    return concatenate([arr, values], axis=axis)
  # Without an axis both inputs are flattened and joined end to end.
  flattened = [np_array_ops.ravel(arr), np_array_ops.ravel(values)]
  return concatenate(flattened, 0)

1244 

1245 

@np_utils.np_doc('average')
def average(a, axis=None, weights=None, returned=False):  # pylint: disable=missing-docstring
  if axis is not None and not isinstance(axis, int):
    # TODO(wangpeng): Support tuple of ints as `axis`
    raise ValueError('Argument `axis` must be an integer. '
                     f'Received axis={axis} (of type {type(axis)})')
  a = np_array_ops.array(a)

  if weights is None:
    # Unweighted: every element carries weight 1. Non-inexact inputs are first
    # promoted together with the default float type.
    if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      a = a.astype(
          np_utils.result_type(a.dtype, np_dtypes.default_float_type()))
    avg = math_ops.reduce_mean(a, axis=axis)
    if returned:
      count = array_ops.size(a) if axis is None else array_ops.shape(a)[axis]
      weights_sum = math_ops.cast(count, a.dtype)
  else:
    # Weighted: promote `a` and `weights` to a common output dtype (including
    # the default float type when `a` is not already inexact).
    promote_args = [a.dtype, weights]
    if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      promote_args.append(np_dtypes.default_float_type())
    out_dtype = np_utils.result_type(*promote_args)
    a = np_array_ops.array(a, out_dtype)
    weights = np_array_ops.array(weights, out_dtype)

    def same_rank():
      # Weights must match `a`'s shape exactly.
      control_flow_assert.Assert(
          math_ops.reduce_all(array_ops.shape(a) == array_ops.shape(weights)),
          [array_ops.shape(a), array_ops.shape(weights)])
      total = math_ops.reduce_sum(weights, axis=axis)
      return math_ops.reduce_sum(a * weights, axis=axis) / total, total

    if axis is None:
      avg, weights_sum = same_rank()
    else:

      def one_dim_weights():
        # 1-D weights are contracted against `a` along `axis`.
        control_flow_assert.Assert(
            array_ops.rank(weights) == 1, [array_ops.rank(weights)])
        total = math_ops.reduce_sum(weights)
        contraction_axes = ops.convert_to_tensor([[axis], [0]])
        return math_ops.tensordot(a, weights, contraction_axes) / total, total

      # We condition on rank rather than shape equality, because if we do the
      # latter, when the shapes are partially unknown but the ranks are known
      # and different, np_utils.cond will run shape checking on the true
      # branch, which will raise a shape-checking error.
      avg, weights_sum = np_utils.cond(
          math_ops.equal(array_ops.rank(a), array_ops.rank(weights)),
          same_rank, one_dim_weights)

  avg = np_array_ops.array(avg)
  if not returned:
    return avg
  weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg))
  return avg, weights_sum

1306 

1307 

@np_utils.np_doc('trace')
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  a = np_array_ops.asarray(a, dtype)

  # Fast path: a zero-offset trace over the two innermost axes maps directly
  # onto math_ops.trace. math_ops.trace sums in the input dtype, while the
  # general path below sums via np_array_ops.sum, which follows numpy.sum's
  # promotion of bool / narrow-integer inputs. Restrict the fast path to
  # floating and complex dtypes, where both paths agree. Otherwise the result
  # (or a bool-input failure in the TF sum) would depend only on which axes
  # were requested.
  if offset == 0 and (a.dtype.is_floating or a.dtype.is_complex):
    a_shape = a.shape
    if a_shape.rank is not None:
      rank = len(a_shape)
      if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or
                                                 axis2 == rank - 1):
        return math_ops.trace(a)

  # General path: extract the requested diagonal (moved to the last axis) and
  # sum it.
  a = np_array_ops.diagonal(a, offset, axis1, axis2)
  return np_array_ops.sum(a, -1, dtype)

1324 

1325 

@np_utils.np_doc('meshgrid')
def meshgrid(*xi, **kwargs):
  """This currently requires copy=True and sparse=False."""
  # Only dense, copying grids are supported; reject the other modes up front.
  sparse = kwargs.get('sparse', False)
  if sparse:
    raise ValueError(
        'Function `meshgrid` does not support returning sparse arrays yet. '
        f'Received: sparse={sparse}')

  copy = kwargs.get('copy', True)
  if not copy:
    raise ValueError('Function `meshgrid` only supports copy=True. '
                     f'Received: copy={copy}')

  # Only `indexing` is forwarded to the TF implementation.
  indexing = kwargs.get('indexing', 'xy')
  arrays = [np_array_ops.asarray(x) for x in xi]
  return array_ops.meshgrid(*arrays, indexing=indexing)

1348 

1349 

# Uses np_doc_only here because np.einsum (in 1.16) doesn't have argument
# `subscripts`, even though the doc says it has.
@np_utils.np_doc_only('einsum')
def einsum(subscripts, *operands, **kwargs):  # pylint: disable=missing-docstring
  casting = kwargs.get('casting', 'safe')
  optimize = kwargs.get('optimize', False)

  # 'safe' promotes all operands to a common dtype; 'no' converts them as-is.
  if casting == 'safe':
    operands = np_array_ops._promote_dtype(*operands)  # pylint: disable=protected-access
  elif casting == 'no':
    operands = [np_array_ops.asarray(x) for x in operands]
  else:
    raise ValueError(
        'Invalid value for argument `casting`. '
        f'Expected casting="safe" or casting="no". Received: casting={casting}')

  # TF doesn't have a "no optimization" option, so a falsy `optimize` maps to
  # 'greedy' as well.
  # TODO(wangpeng): Print a warning that np and tf use different
  # optimizations.
  if not optimize or optimize == True or optimize == 'greedy':  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
    tf_optimize = 'greedy'
  elif optimize == 'optimal':
    tf_optimize = 'optimal'
  else:
    raise ValueError(
        'Invalid value for argument `optimize`. '
        'Expected one of {True, "greedy", "optimal"}. '
        f'Received: optimize={optimize}')

  return special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)

1383 

1384 

1385def _tensor_t(self): 

1386 """Returns a Tensor which is the transpose of this Tensor.""" 

1387 return self.transpose() 

1388 

1389 

1390def _tensor_ndim(self): 

1391 """Returns the rank of the Tensor.""" 

1392 return self.shape.ndims 

1393 

1394 

1395def _tensor_pos(self): 

1396 """Returns self, for unary operator `+`.""" 

1397 return self 

1398 

1399 

1400def _tensor_size(self): 

1401 """Returns the number of elements in this Tensor, if fully known.""" 

1402 if not self.shape.is_fully_defined(): 

1403 return None 

1404 return np.prod(self.shape.as_list()) 

1405 

1406 

def _tensor_tolist(self):
  """Returns the tensor's value via numpy's `tolist()`; eager tensors only.

  Raises:
    ValueError: if `self` is not an `ops.EagerTensor`.
  """
  if not isinstance(self, ops.EagerTensor):
    raise ValueError('Symbolic Tensors do not support the tolist API.')
  return self._numpy().tolist()  # pylint: disable=protected-access

1412 

1413 

def enable_numpy_methods_on_tensor():
  """Adds additional NumPy methods on tf.Tensor class.

  Installs, via `setattr` on `ops.Tensor`, the read-only properties `T`,
  `ndim`, `size` and `data`, plus the plain methods listed in the table below.
  """
  # Read-only properties backed by the module-level getters.
  for attr_name, getter in (('T', _tensor_t),
                            ('ndim', _tensor_ndim),
                            ('size', _tensor_size)):
    setattr(ops.Tensor, attr_name, property(getter))

  # TODO(b/178540516): Make a custom `setattr` that changes the method's
  # docstring to the TF one.
  method_table = (
      ('__pos__', _tensor_pos),
      ('tolist', _tensor_tolist),
      ('transpose', np_array_ops.transpose),
      ('flatten', np_array_ops.flatten),
      ('reshape', np_array_ops._reshape_method_wrapper),  # pylint: disable=protected-access
      ('ravel', np_array_ops.ravel),
      ('clip', clip),
      ('astype', math_ops.cast),
      ('__round__', np_array_ops.around),
      ('max', np_array_ops.amax),
      ('mean', np_array_ops.mean),
      ('min', np_array_ops.amin),
  )
  for attr_name, impl in method_table:
    setattr(ops.Tensor, attr_name, impl)

  # TODO(wangpeng): Remove `data` when all uses of it are removed
  setattr(ops.Tensor, 'data', property(lambda self: self))