# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mathematical operations."""
# pylint: disable=g-direct-tensorflow-import

import numbers
import sys

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import while_loop
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.ops.numpy_ops import np_utils


pi = np_export.np_export_constant(__name__, 'pi', np.pi)
e = np_export.np_export_constant(__name__, 'e', np.e)
inf = np_export.np_export_constant(__name__, 'inf', np.inf)


@np_utils.np_doc_only('dot')
def dot(a, b):  # pylint: disable=missing-docstring

  def f(a, b):  # pylint: disable=missing-docstring
    return np_utils.cond(
        np_utils.logical_or(
            math_ops.equal(array_ops.rank(a), 0),
            math_ops.equal(array_ops.rank(b), 0)),
        lambda: a * b,
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            math_ops.equal(array_ops.rank(b), 1),
            lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]),
            lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]])))

  return _bin_op(f, a, b)
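

# A hedged usage sketch for `dot` (not part of the original file): the calls
# assume these wrappers are exposed as `tnp`, e.g. via
# `import tensorflow.experimental.numpy as tnp`.
#
#   >>> tnp.dot(3, 4)                      # scalar * scalar -> 12
#   >>> tnp.dot([1, 2], [3, 4])            # 1-D . 1-D -> 11
#   >>> tnp.dot([[1, 0], [0, 1]], [5, 6])  # 2-D . 1-D -> [5, 6]
#
# The rank checks run inside the TF graph (`np_utils.cond`), so the dispatch
# also works when ranks are only known at run time.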
# TODO(wangpeng): Make element-wise ops `ufunc`s
def _bin_op(tf_fun, a, b, promote=True):
  if promote:
    a, b = np_array_ops._promote_dtype_binary(a, b)  # pylint: disable=protected-access
  else:
    a = np_array_ops.array(a)
    b = np_array_ops.array(b)
  return tf_fun(a, b)


@np_utils.np_doc('add')
def add(x1, x2):

  def add_or_or(x1, x2):
    if x1.dtype == dtypes.bool:
      assert x2.dtype == dtypes.bool
      return math_ops.logical_or(x1, x2)
    return math_ops.add(x1, x2)

  return _bin_op(add_or_or, x1, x2)
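

# Illustrative sketch (an assumption, mirroring NumPy semantics): boolean
# inputs are combined with logical OR rather than integer addition. With
# `tnp` standing in for `tf.experimental.numpy`:
#
#   >>> tnp.add([True, False], [False, False])  # -> [True, False]
#   >>> tnp.add([1, 2], [3, 4])                 # -> [4, 6]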
@np_utils.np_doc('subtract')
def subtract(x1, x2):
  return _bin_op(math_ops.subtract, x1, x2)


@np_utils.np_doc('multiply')
def multiply(x1, x2):

  def mul_or_and(x1, x2):
    if x1.dtype == dtypes.bool:
      assert x2.dtype == dtypes.bool
      return math_ops.logical_and(x1, x2)
    return math_ops.multiply(x1, x2)

  return _bin_op(mul_or_and, x1, x2)


@np_utils.np_doc('true_divide')
def true_divide(x1, x2):  # pylint: disable=missing-function-docstring

  def _avoid_float64(x1, x2):
    if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64):
      x1 = math_ops.cast(x1, dtype=dtypes.float32)
      x2 = math_ops.cast(x2, dtype=dtypes.float32)
    return x1, x2

  def f(x1, x2):
    if x1.dtype == dtypes.bool:
      assert x2.dtype == dtypes.bool
      float_ = np_dtypes.default_float_type()
      x1 = math_ops.cast(x1, float_)
      x2 = math_ops.cast(x2, float_)
    if not np_dtypes.is_allow_float64():
      # math_ops.truediv in Python3 produces float64 when both inputs are int32
      # or int64. We want to avoid that when is_allow_float64() is False.
      x1, x2 = _avoid_float64(x1, x2)
    return math_ops.truediv(x1, x2)

  return _bin_op(f, x1, x2)
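

# Hedged example of the float64-avoidance path above. When
# `np_dtypes.is_allow_float64()` returns False (e.g. after a hypothetical
# `np_dtypes.set_allow_float64(False)` call), integer inputs are cast to
# float32 before dividing:
#
#   >>> tnp.true_divide(1, 2)  # -> 0.5 as float32 rather than float64
#
# With float64 allowed, `math_ops.truediv` promotes int32/int64 inputs to
# float64, matching plain NumPy.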
@np_utils.np_doc('divide')
def divide(x1, x2):  # pylint: disable=missing-function-docstring
  return true_divide(x1, x2)


@np_utils.np_doc('floor_divide')
def floor_divide(x1, x2):  # pylint: disable=missing-function-docstring

  def f(x1, x2):
    if x1.dtype == dtypes.bool:
      assert x2.dtype == dtypes.bool
      x1 = math_ops.cast(x1, dtypes.int8)
      x2 = math_ops.cast(x2, dtypes.int8)
    return math_ops.floordiv(x1, x2)

  return _bin_op(f, x1, x2)


@np_utils.np_doc('mod')
def mod(x1, x2):  # pylint: disable=missing-function-docstring

  def f(x1, x2):
    if x1.dtype == dtypes.bool:
      assert x2.dtype == dtypes.bool
      x1 = math_ops.cast(x1, dtypes.int8)
      x2 = math_ops.cast(x2, dtypes.int8)
    return math_ops.mod(x1, x2)

  return _bin_op(f, x1, x2)


@np_utils.np_doc('remainder')
def remainder(x1, x2):  # pylint: disable=missing-function-docstring
  return mod(x1, x2)


@np_utils.np_doc('divmod')
def divmod(x1, x2):  # pylint: disable=redefined-builtin
  return floor_divide(x1, x2), mod(x1, x2)


@np_utils.np_doc('maximum')
def maximum(x1, x2):  # pylint: disable=missing-function-docstring

  # Fast path for when maximum is used as relu.
  if isinstance(
      x2, numbers.Real) and not isinstance(x2, bool) and x2 == 0 and isinstance(
          x1, np_arrays.ndarray) and x1.dtype != dtypes.bool:
    return nn_ops.relu(np_array_ops.asarray(x1))

  def max_or_or(x1, x2):
    if x1.dtype == dtypes.bool:
      assert x2.dtype == dtypes.bool
      return math_ops.logical_or(x1, x2)
    return math_ops.maximum(x1, x2)

  return _bin_op(max_or_or, x1, x2)
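

# The fast path above means the common relu-style call skips the generic
# binary-op machinery. A hedged sketch with `tnp` = `tf.experimental.numpy`:
#
#   >>> tnp.maximum(x, 0)    # dispatches to nn_ops.relu when x is an ndarray
#   >>> tnp.maximum(x, 0.5)  # generic elementwise maximum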
@np_utils.np_doc('minimum')
def minimum(x1, x2):

  def min_or_and(x1, x2):
    if x1.dtype == dtypes.bool:
      assert x2.dtype == dtypes.bool
      return math_ops.logical_and(x1, x2)
    return math_ops.minimum(x1, x2)

  return _bin_op(min_or_and, x1, x2)


@np_utils.np_doc('clip')
def clip(a, a_min, a_max):  # pylint: disable=missing-docstring
  if a_min is None and a_max is None:
    raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.')
  if a_min is None:
    return minimum(a, a_max)
  elif a_max is None:
    return maximum(a, a_min)
  else:
    a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max)  # pylint: disable=protected-access
    return clip_ops.clip_by_value(*np_utils.tf_broadcast(a, a_min, a_max))


@np_utils.np_doc('matmul')
def matmul(x1, x2):  # pylint: disable=missing-docstring

  def f(x1, x2):
    try:
      if x1._rank() == 2 and x2._rank() == 2:  # pylint: disable=protected-access
        # Fast path for known ranks.
        return gen_math_ops.mat_mul(x1, x2)
      return np_utils.cond(
          math_ops.equal(np_utils.tf_rank(x2), 1),
          lambda: math_ops.tensordot(x1, x2, axes=1),
          lambda: np_utils.cond(  # pylint: disable=g-long-lambda
              math_ops.equal(np_utils.tf_rank(x1), 1),
              lambda: math_ops.tensordot(  # pylint: disable=g-long-lambda
                  x1, x2, axes=[[0], [-2]]),
              lambda: math_ops.matmul(x1, x2)))
    except errors.InvalidArgumentError as err:
      raise ValueError(str(err)).with_traceback(sys.exc_info()[2])

  return _bin_op(f, x1, x2)
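

# Hedged sketch of `matmul`'s rank dispatch (with `tnp` standing in for
# `tf.experimental.numpy`):
#
#   >>> a = tnp.ones([2, 3]); b = tnp.ones([3, 4])
#   >>> tnp.matmul(a, b).shape        # (2, 4), via the fast mat_mul path
#   >>> tnp.matmul(a, tnp.ones([3]))  # rank-1 rhs -> tensordot over last axis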
# Exported so it can be called from Tensor.__matmul__. NumPy's matmul handles
# batched matmul as well, so simply including promotion in TF's current
# __matmul__ implementation was not sufficient.
setattr(np_arrays.ndarray, '_matmul', matmul)


@np_utils.np_doc('tensordot')
def tensordot(a, b, axes=2):
  return _bin_op(lambda a, b: math_ops.tensordot(a, b, axes=axes), a, b)


@np_utils.np_doc_only('inner')
def inner(a, b):  # pylint: disable=missing-function-docstring

  def f(a, b):
    return np_utils.cond(
        np_utils.logical_or(
            math_ops.equal(array_ops.rank(a), 0),
            math_ops.equal(array_ops.rank(b), 0)), lambda: a * b,
        lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]))

  return _bin_op(f, a, b)


@np_utils.np_doc('cross')
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):  # pylint: disable=missing-docstring

  def f(a, b):  # pylint: disable=missing-docstring
    # We can't assign to captured variable `axisa`, so make a new variable
    if axis is None:
      axis_a = axisa
      axis_b = axisb
      axis_c = axisc
    else:
      axis_a = axis
      axis_b = axis
      axis_c = axis
    if axis_a < 0:
      axis_a = np_utils.add(axis_a, array_ops.rank(a))
    if axis_b < 0:
      axis_b = np_utils.add(axis_b, array_ops.rank(b))

    def maybe_move_axis_to_last(a, axis):

      def move_axis_to_last(a, axis):
        return array_ops.transpose(
            a,
            array_ops.concat([
                math_ops.range(axis),
                math_ops.range(axis + 1, array_ops.rank(a)), [axis]
            ],
                             axis=0))

      return np_utils.cond(axis == np_utils.subtract(array_ops.rank(a), 1),
                           lambda: a, lambda: move_axis_to_last(a, axis))

    a = maybe_move_axis_to_last(a, axis_a)
    b = maybe_move_axis_to_last(b, axis_b)
    a_dim = np_utils.getitem(array_ops.shape(a), -1)
    b_dim = np_utils.getitem(array_ops.shape(b), -1)

    def maybe_pad_0(a, size_of_last_dim):

      def pad_0(a):
        return array_ops.pad(
            a,
            array_ops.concat([
                array_ops.zeros([array_ops.rank(a) - 1, 2], dtypes.int32),
                constant_op.constant([[0, 1]], dtypes.int32)
            ],
                             axis=0))

      return np_utils.cond(
          math_ops.equal(size_of_last_dim, 2), lambda: pad_0(a), lambda: a)

    a = maybe_pad_0(a, a_dim)
    b = maybe_pad_0(b, b_dim)
    c = math_ops.cross(*np_utils.tf_broadcast(a, b))
    if axis_c < 0:
      axis_c = np_utils.add(axis_c, array_ops.rank(c))

    def move_last_to_axis(a, axis):
      r = array_ops.rank(a)
      return array_ops.transpose(
          a,
          array_ops.concat(
              [math_ops.range(axis), [r - 1],
               math_ops.range(axis, r - 1)],
              axis=0))

    c = np_utils.cond(
        (a_dim == 2) & (b_dim == 2),
        lambda: c[..., 2],
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            axis_c == np_utils.subtract(array_ops.rank(c), 1), lambda: c,
            lambda: move_last_to_axis(c, axis_c)))
    return c

  return _bin_op(f, a, b)


@np_utils.np_doc_only('vdot')
def vdot(a, b):  # pylint: disable=missing-docstring
  a, b = np_array_ops._promote_dtype(a, b)  # pylint: disable=protected-access
  a = np_array_ops.reshape(a, [-1])
  b = np_array_ops.reshape(b, [-1])
  if a.dtype == np_dtypes.complex128 or a.dtype == np_dtypes.complex64:
    a = conj(a)
  return dot(a, b)
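

# `vdot` flattens both arguments and conjugates the first when it is complex,
# matching NumPy. A hedged sketch (`tnp` = `tf.experimental.numpy`):
#
#   >>> tnp.vdot([1j, 2j], [1j, 2j])                  # conj([1j, 2j]) . [1j, 2j] -> 5+0j
#   >>> tnp.vdot([[1, 2], [3, 4]], [[1, 1], [1, 1]])  # flattened -> 10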
@np_utils.np_doc('power')
def power(x1, x2):
  return _bin_op(math_ops.pow, x1, x2)


@np_utils.np_doc('float_power')
def float_power(x1, x2):
  return power(x1, x2)


@np_utils.np_doc('arctan2')
def arctan2(x1, x2):
  return _bin_op(math_ops.atan2, x1, x2)


@np_utils.np_doc('nextafter')
def nextafter(x1, x2):
  return _bin_op(math_ops.nextafter, x1, x2)


@np_utils.np_doc('heaviside')
def heaviside(x1, x2):  # pylint: disable=missing-function-docstring

  def f(x1, x2):
    return array_ops.where_v2(
        x1 < 0, constant_op.constant(0, dtype=x2.dtype),
        array_ops.where_v2(x1 > 0, constant_op.constant(1, dtype=x2.dtype), x2))

  y = _bin_op(f, x1, x2)
  if not np.issubdtype(y.dtype.as_numpy_dtype, np.inexact):
    y = y.astype(np_dtypes.default_float_type())
  return y


@np_utils.np_doc('hypot')
def hypot(x1, x2):
  return sqrt(square(x1) + square(x2))


@np_utils.np_doc('kron')
def kron(a, b):  # pylint: disable=missing-function-docstring
  # pylint: disable=protected-access,g-complex-comprehension
  a, b = np_array_ops._promote_dtype(a, b)
  t_a = np_utils.cond(
      a.shape.rank < b.shape.rank,
      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
          a, np_array_ops._pad_left_to(b.shape.rank, a.shape)),
      lambda: a)
  t_b = np_utils.cond(
      b.shape.rank < a.shape.rank,
      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
          b, np_array_ops._pad_left_to(a.shape.rank, b.shape)),
      lambda: b)

  def _make_shape(shape, prepend):
    ones = array_ops.ones_like(shape)
    if prepend:
      shapes = [ones, shape]
    else:
      shapes = [shape, ones]
    return array_ops.reshape(array_ops_stack.stack(shapes, axis=1), [-1])

  a_shape = array_ops.shape(t_a)
  b_shape = array_ops.shape(t_b)
  a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
  b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
  out_shape = a_shape * b_shape
  return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
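

# How `_make_shape` realizes the Kronecker product: each input is reshaped so
# its dimensions interleave with size-1 dimensions (`a` becomes
# [d1, 1, d2, 1, ...], `b` becomes [1, e1, 1, e2, ...]); broadcasting the
# elementwise product then yields every pairwise product, and the final
# reshape to `a_shape * b_shape` flattens each interleaved pair. For example,
# kron([1, 2], [10, 20]) multiplies a [2, 1] view by a [1, 2] view and
# reshapes the [2, 2] result to [10, 20, 20, 40].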
@np_utils.np_doc('outer')
def outer(a, b):

  def f(a, b):
    return array_ops.reshape(a, [-1, 1]) * array_ops.reshape(b, [-1])

  return _bin_op(f, a, b)


# This can also be implemented via tf.reduce_logsumexp
@np_utils.np_doc('logaddexp')
def logaddexp(x1, x2):
  amax = maximum(x1, x2)
  delta = x1 - x2
  return np_array_ops.where(
      isnan(delta),
      x1 + x2,  # NaNs or infinities of the same sign.
      amax + log1p(exp(-abs(delta))))
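

# The identity used above:
#   log(exp(x1) + exp(x2)) = max(x1, x2) + log1p(exp(-|x1 - x2|)),
# which stays finite where naively exponentiating would overflow. The
# isnan(delta) branch catches inf - inf (same-signed infinities, where the
# answer is x1 + x2 = inf) as well as genuine NaNs.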
@np_utils.np_doc('logaddexp2')
def logaddexp2(x1, x2):
  amax = maximum(x1, x2)
  delta = x1 - x2
  return np_array_ops.where(
      isnan(delta),
      x1 + x2,  # NaNs or infinities of the same sign.
      amax + log1p(exp2(-abs(delta))) / np.log(2))


@np_utils.np_doc('polyval')
def polyval(p, x):  # pylint: disable=missing-function-docstring

  def f(p, x):
    if p.shape.rank == 0:
      p = array_ops.reshape(p, [1])
    p = array_ops_stack.unstack(p)
    # TODO(wangpeng): Make tf version take a tensor for p instead of a list.
    y = math_ops.polyval(p, x)
    # If the polynomial is 0-order, numpy requires the result to be broadcast to
    # `x`'s shape.
    if len(p) == 1:
      y = array_ops.broadcast_to(y, x.shape)
    return y

  return _bin_op(f, p, x)


@np_utils.np_doc('isclose')
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):  # pylint: disable=missing-docstring

  def f(a, b):  # pylint: disable=missing-docstring
    dtype = a.dtype
    if np.issubdtype(dtype.as_numpy_dtype, np.inexact):
      rtol_ = ops.convert_to_tensor(rtol, dtype.real_dtype)
      atol_ = ops.convert_to_tensor(atol, dtype.real_dtype)
      result = (math_ops.abs(a - b) <= atol_ + rtol_ * math_ops.abs(b))
      if equal_nan:
        result = result | (math_ops.is_nan(a) & math_ops.is_nan(b))
      return result
    else:
      return a == b

  return _bin_op(f, a, b)


@np_utils.np_doc('allclose')
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  return np_array_ops.all(
      isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))


def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring

  def _gcd_cond_fn(_, x2):
    return math_ops.reduce_any(x2 != 0)

  def _gcd_body_fn(x1, x2):
    # math_ops.mod will raise an error when any element of x2 is 0. To avoid
    # that, we change those zeros to ones. Their values don't matter because
    # they won't be used.
    x2_safe = array_ops.where_v2(x2 != 0, x2, constant_op.constant(1, x2.dtype))
    x1, x2 = (array_ops.where_v2(x2 != 0, x2, x1),
              array_ops.where_v2(x2 != 0, math_ops.mod(x1, x2_safe),
                                 constant_op.constant(0, x2.dtype)))
    return (array_ops.where_v2(x1 < x2, x2,
                               x1), array_ops.where_v2(x1 < x2, x1, x2))

  if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
      not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
    raise ValueError('Arguments to gcd must be integers.')
  shape = array_ops.broadcast_dynamic_shape(
      array_ops.shape(x1), array_ops.shape(x2))
  x1 = array_ops.broadcast_to(x1, shape)
  x2 = array_ops.broadcast_to(x2, shape)
  value, _ = while_loop.while_loop(_gcd_cond_fn, _gcd_body_fn,
                                   (math_ops.abs(x1), math_ops.abs(x2)))
  return value
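

# The while_loop above is a vectorized Euclidean algorithm: each iteration
# maps (x1, x2) -> (x2, x1 mod x2) per element, with finished lanes
# (x2 == 0) frozen via the `x2_safe` masking. A hedged trace for gcd(12, 18):
#   (12, 18) -> (18, 12) -> (12, 6) -> (6, 0) -> result 6.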
# Note that np.gcd may not be present in some supported versions of numpy.
@np_utils.np_doc('gcd')
def gcd(x1, x2):
  return _bin_op(_tf_gcd, x1, x2)


# Note that np.lcm may not be present in some supported versions of numpy.
@np_utils.np_doc('lcm')
def lcm(x1, x2):  # pylint: disable=missing-function-docstring

  def f(x1, x2):
    d = _tf_gcd(x1, x2)
    # Same as the `x2_safe` trick above
    d_safe = array_ops.where_v2(
        math_ops.equal(d, 0), constant_op.constant(1, d.dtype), d)
    x1 = math_ops.abs(x1)
    x2 = math_ops.abs(x2)
    return array_ops.where_v2(
        math_ops.equal(d, 0), constant_op.constant(0, d.dtype),
        x1 * (x2 // d_safe))

  return _bin_op(f, x1, x2)


def _bitwise_binary_op(tf_fn, x1, x2):  # pylint: disable=missing-function-docstring

  def f(x1, x2):
    is_bool = (x1.dtype == dtypes.bool)
    if is_bool:
      assert x2.dtype == dtypes.bool
      x1 = math_ops.cast(x1, dtypes.int8)
      x2 = math_ops.cast(x2, dtypes.int8)
    r = tf_fn(x1, x2)
    if is_bool:
      r = math_ops.cast(r, dtypes.bool)
    return r

  return _bin_op(f, x1, x2)


@np_utils.np_doc('bitwise_and')
def bitwise_and(x1, x2):
  return _bitwise_binary_op(bitwise_ops.bitwise_and, x1, x2)


@np_utils.np_doc('bitwise_or')
def bitwise_or(x1, x2):
  return _bitwise_binary_op(bitwise_ops.bitwise_or, x1, x2)


@np_utils.np_doc('bitwise_xor')
def bitwise_xor(x1, x2):
  return _bitwise_binary_op(bitwise_ops.bitwise_xor, x1, x2)


@np_utils.np_doc('bitwise_not', link=np_utils.AliasOf('invert'))
def bitwise_not(x):

  def f(x):
    if x.dtype == dtypes.bool:
      return math_ops.logical_not(x)
    return bitwise_ops.invert(x)

  return _scalar(f, x)


def _scalar(tf_fn, x, promote_to_float=False):
  """Computes the tf_fn(x) for each element in `x`.

  Args:
    tf_fn: function that takes a single Tensor argument.
    x: array_like. Could be an ndarray, a Tensor or any object that can be
      converted to a Tensor using `ops.convert_to_tensor`.
    promote_to_float: whether to cast the argument to a float dtype
      (`np_dtypes.default_float_type`) if it is not already.

  Returns:
    An ndarray with the same shape as `x`. The default output dtype is
    determined by `np_dtypes.default_float_type`, unless x is an ndarray with a
    floating point type, in which case the output type is same as x.dtype.
  """
  x = np_array_ops.asarray(x)
  if promote_to_float and not np.issubdtype(x.dtype.as_numpy_dtype, np.inexact):
    x = x.astype(np_dtypes.default_float_type())
  return tf_fn(x)
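

# Hedged examples of `_scalar`'s promotion rule, via the public unary
# wrappers defined below (`tnp` = `tf.experimental.numpy`):
#
#   >>> tnp.sqrt(4)  # promote_to_float=True: int promoted to default float -> 2.0
#   >>> tnp.abs(-3)  # promote_to_float=False: stays integer -> 3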
@np_utils.np_doc('log')
def log(x):
  return _scalar(math_ops.log, x, True)


@np_utils.np_doc('exp')
def exp(x):
  return _scalar(math_ops.exp, x, True)


@np_utils.np_doc('sqrt')
def sqrt(x):
  return _scalar(math_ops.sqrt, x, True)


@np_utils.np_doc('abs', link=np_utils.AliasOf('absolute'))
def abs(x):  # pylint: disable=redefined-builtin
  return _scalar(math_ops.abs, x)


@np_utils.np_doc('absolute')
def absolute(x):
  return abs(x)


@np_utils.np_doc('fabs')
def fabs(x):
  return abs(x)


@np_utils.np_doc('ceil')
def ceil(x):
  return _scalar(math_ops.ceil, x, True)


@np_utils.np_doc('floor')
def floor(x):
  return _scalar(math_ops.floor, x, True)


@np_utils.np_doc('conj')
def conj(x):
  return _scalar(math_ops.conj, x)


@np_utils.np_doc('negative')
def negative(x):
  return _scalar(math_ops.negative, x)


@np_utils.np_doc('reciprocal')
def reciprocal(x):
  return _scalar(math_ops.reciprocal, x)


@np_utils.np_doc('signbit')
def signbit(x):

  def f(x):
    if x.dtype == dtypes.bool:
      return array_ops.fill(array_ops.shape(x), False)
    return x < 0

  return _scalar(f, x)


@np_utils.np_doc('sin')
def sin(x):
  return _scalar(math_ops.sin, x, True)


@np_utils.np_doc('cos')
def cos(x):
  return _scalar(math_ops.cos, x, True)


@np_utils.np_doc('tan')
def tan(x):
  return _scalar(math_ops.tan, x, True)


@np_utils.np_doc('sinh')
def sinh(x):
  return _scalar(math_ops.sinh, x, True)


@np_utils.np_doc('cosh')
def cosh(x):
  return _scalar(math_ops.cosh, x, True)


@np_utils.np_doc('tanh')
def tanh(x):
  return _scalar(math_ops.tanh, x, True)


@np_utils.np_doc('arcsin')
def arcsin(x):
  return _scalar(math_ops.asin, x, True)


@np_utils.np_doc('arccos')
def arccos(x):
  return _scalar(math_ops.acos, x, True)


@np_utils.np_doc('arctan')
def arctan(x):
  return _scalar(math_ops.atan, x, True)


@np_utils.np_doc('arcsinh')
def arcsinh(x):
  return _scalar(math_ops.asinh, x, True)


@np_utils.np_doc('arccosh')
def arccosh(x):
  return _scalar(math_ops.acosh, x, True)


@np_utils.np_doc('arctanh')
def arctanh(x):
  return _scalar(math_ops.atanh, x, True)


@np_utils.np_doc('deg2rad')
def deg2rad(x):

  def f(x):
    return x * (np.pi / 180.0)

  return _scalar(f, x, True)


@np_utils.np_doc('rad2deg')
def rad2deg(x):
  return x * (180.0 / np.pi)


_tf_float_types = [
    dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
]


@np_utils.np_doc('angle')
def angle(z, deg=False):  # pylint: disable=missing-function-docstring

  def f(x):
    if x.dtype in _tf_float_types:
      # Workaround for b/147515503
      return array_ops.where_v2(x < 0, np.pi, 0)
    else:
      return math_ops.angle(x)

  y = _scalar(f, z, True)
  if deg:
    y = rad2deg(y)
  return y


@np_utils.np_doc('cbrt')
def cbrt(x):

  def f(x):
    # __pow__ can't handle negative base, so we use `abs` here.
    rt = math_ops.abs(x)**(1.0 / 3)
    return array_ops.where_v2(x < 0, -rt, rt)

  return _scalar(f, x, True)


@np_utils.np_doc('conjugate', link=np_utils.AliasOf('conj'))
def conjugate(x):
  return _scalar(math_ops.conj, x)


@np_utils.np_doc('exp2')
def exp2(x):

  def f(x):
    return 2**x

  return _scalar(f, x, True)


@np_utils.np_doc('expm1')
def expm1(x):
  return _scalar(math_ops.expm1, x, True)


@np_utils.np_doc('fix')
def fix(x):

  def f(x):
    return array_ops.where_v2(x < 0, math_ops.ceil(x), math_ops.floor(x))

  return _scalar(f, x, True)


@np_utils.np_doc('iscomplex')
def iscomplex(x):
  return np_array_ops.imag(x) != 0


@np_utils.np_doc('isreal')
def isreal(x):
  return np_array_ops.imag(x) == 0


@np_utils.np_doc('iscomplexobj')
def iscomplexobj(x):
  x = np_array_ops.array(x)
  return np.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating)


@np_utils.np_doc('isrealobj')
def isrealobj(x):
  return not iscomplexobj(x)


@np_utils.np_doc('isnan')
def isnan(x):
  return _scalar(math_ops.is_nan, x, True)


def _make_nan_reduction(np_fun_name, reduction, init_val):
  """Helper to generate nan* functions."""

  @np_utils.np_doc(np_fun_name)
  def nan_reduction(a, axis=None, dtype=None, keepdims=False):
    a = np_array_ops.array(a)
    v = np_array_ops.array(init_val, dtype=a.dtype)
    return reduction(
        np_array_ops.where(isnan(a), v, a),
        axis=axis,
        dtype=dtype,
        keepdims=keepdims)

  return nan_reduction
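

# The generated reductions replace NaNs with the reduction's identity element
# (0 for sum, 1 for prod) before reducing. A hedged sketch:
#
#   >>> tnp.nansum([1.0, float('nan'), 2.0])   # -> 3.0
#   >>> tnp.nanprod([2.0, float('nan'), 3.0])  # -> 6.0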
nansum = _make_nan_reduction('nansum', np_array_ops.sum, 0)
nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1)


@np_utils.np_doc('nanmean')
def nanmean(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=missing-docstring
  a = np_array_ops.array(a)
  if np.issubdtype(a.dtype.as_numpy_dtype, np.bool_) or np.issubdtype(
      a.dtype.as_numpy_dtype, np.integer):
    return np_array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims)
  nan_mask = logical_not(isnan(a))
  if dtype is None:
    dtype = a.dtype.as_numpy_dtype
  normalizer = np_array_ops.sum(
      nan_mask, axis=axis, dtype=dtype, keepdims=keepdims)
  return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer


@np_utils.np_doc('isfinite')
def isfinite(x):
  return _scalar(math_ops.is_finite, x, True)


@np_utils.np_doc('isinf')
def isinf(x):
  return _scalar(math_ops.is_inf, x, True)


@np_utils.np_doc('isneginf')
def isneginf(x):
  return x == np_array_ops.full_like(x, -np.inf)


@np_utils.np_doc('isposinf')
def isposinf(x):
  return x == np_array_ops.full_like(x, np.inf)


@np_utils.np_doc('log2')
def log2(x):
  return log(x) / np.log(2)


@np_utils.np_doc('log10')
def log10(x):
  return log(x) / np.log(10)


@np_utils.np_doc('log1p')
def log1p(x):
  return _scalar(math_ops.log1p, x, True)


@np_utils.np_doc('positive')
def positive(x):
  return _scalar(lambda x: x, x)


@np_utils.np_doc('sinc')
def sinc(x):

  def f(x):
    pi_x = x * np.pi
    return array_ops.where_v2(x == 0, array_ops.ones_like(x),
                              math_ops.sin(pi_x) / pi_x)

  return _scalar(f, x, True)


@np_utils.np_doc('square')
def square(x):
  return _scalar(math_ops.square, x)


@np_utils.np_doc('diff')
def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring

  def f(a):
    # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
    # TODO(agarwal): avoid depending on static rank.
    nd = a.shape.rank
    if nd is None:
      raise ValueError(
          'Function `diff` currently requires a known rank for input `a`. '
          f'Received: a={a} (unknown rank)')
    if (axis + nd if axis < 0 else axis) >= nd:
      raise ValueError(
          f'Argument `axis` (received axis={axis}) is out of bounds '
          f'for input {a} of rank {nd}.')
    if n < 0:
      raise ValueError('Argument `n` (difference order) must be a '
                       f'non-negative integer. Received: n={n}')
    slice1 = [slice(None)] * nd
    slice2 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)
    op = math_ops.not_equal if a.dtype == dtypes.bool else math_ops.subtract
    for _ in range(n):
      a = op(a[slice1], a[slice2])
    return a

  return _scalar(f, a)
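

# `diff` repeatedly subtracts adjacent slices along `axis`. A hedged sketch:
#
#   >>> tnp.diff([1, 4, 9, 16])       # -> [3, 5, 7]
#   >>> tnp.diff([1, 4, 9, 16], n=2)  # second difference -> [2, 2]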
def _wrap(f, reverse=False):
  """Wraps binary ops so they can be added as operator overloads on ndarray."""

  def _f(a, b):
    if reverse:
      a, b = b, a

    if getattr(b, '__array_priority__',
               0) > np_arrays.ndarray.__array_priority__:
      return NotImplemented

    return f(a, b)

  return _f


def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
  """Helper function for comparison."""
  dtype = np_utils.result_type(x1, x2)
  # Cast x1 and x2 to the result_type if needed.
  x1 = np_array_ops.array(x1, dtype=dtype)
  x2 = np_array_ops.array(x2, dtype=dtype)
  if cast_bool_to_int and x1.dtype == dtypes.bool:
    x1 = math_ops.cast(x1, dtypes.int32)
    x2 = math_ops.cast(x2, dtypes.int32)
  return tf_fun(x1, x2)


@np_utils.np_doc('equal')
def equal(x1, x2):
  return _comparison(math_ops.equal, x1, x2)


@np_utils.np_doc('not_equal')
def not_equal(x1, x2):
  return _comparison(math_ops.not_equal, x1, x2)


@np_utils.np_doc('greater')
def greater(x1, x2):
  return _comparison(math_ops.greater, x1, x2, True)


@np_utils.np_doc('greater_equal')
def greater_equal(x1, x2):
  return _comparison(math_ops.greater_equal, x1, x2, True)


@np_utils.np_doc('less')
def less(x1, x2):
  return _comparison(math_ops.less, x1, x2, True)


@np_utils.np_doc('less_equal')
def less_equal(x1, x2):
  return _comparison(math_ops.less_equal, x1, x2, True)


@np_utils.np_doc('array_equal')
def array_equal(a1, a2):  # pylint: disable=missing-function-docstring

  def f(x1, x2):
    return np_utils.cond(
        math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)),
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            np_utils.reduce_all(
                math_ops.equal(array_ops.shape(x1), array_ops.shape(x2))
            ),
            lambda: math_ops.reduce_all(math_ops.equal(x1, x2)),
            lambda: constant_op.constant(False)),
        lambda: constant_op.constant(False))

  return _comparison(f, a1, a2)


def _logical_binary_op(tf_fun, x1, x2):
  x1 = np_array_ops.array(x1, dtype=np.bool_)
  x2 = np_array_ops.array(x2, dtype=np.bool_)
  return tf_fun(x1, x2)


@np_utils.np_doc('logical_and')
def logical_and(x1, x2):
  return _logical_binary_op(math_ops.logical_and, x1, x2)


@np_utils.np_doc('logical_or')
def logical_or(x1, x2):
  return _logical_binary_op(math_ops.logical_or, x1, x2)


@np_utils.np_doc('logical_xor')
def logical_xor(x1, x2):
  return _logical_binary_op(math_ops.logical_xor, x1, x2)


@np_utils.np_doc('logical_not')
def logical_not(x):
  x = np_array_ops.array(x, dtype=np.bool_)
  return math_ops.logical_not(x)


@np_utils.np_doc('linspace')
def linspace(  # pylint: disable=missing-docstring
    start,
    stop,
    num=50,
    endpoint=True,
    retstep=False,
    dtype=float,
    axis=0):
  if dtype:
    dtype = np_utils.result_type(dtype)
  start = np_array_ops.array(start, dtype=dtype)
  stop = np_array_ops.array(stop, dtype=dtype)
  if num < 0:
    raise ValueError(
        'Argument `num` (number of samples) must be a non-negative integer. '
        f'Received: num={num}')
  step = ops.convert_to_tensor(np.nan)
  if endpoint:
    result = math_ops.linspace(start, stop, num, axis=axis)
    if num > 1:
      step = (stop - start) / (num - 1)
  else:
    # math_ops.linspace does not support endpoint=False so we manually handle it
    # here.
    if num > 0:
      step = ((stop - start) / num)
    if num > 1:
      new_stop = math_ops.cast(stop, step.dtype) - step
      start = math_ops.cast(start, new_stop.dtype)
      result = math_ops.linspace(start, new_stop, num, axis=axis)
    else:
      result = math_ops.linspace(start, stop, num, axis=axis)
  if dtype:
    if dtype.is_integer:
      # Since numpy 1.20, linspace's rounding is towards -inf instead of 0
      result = math_ops.floor(result)
    result = math_ops.cast(result, dtype)
  if retstep:
    return (result, step)
  else:
    return result
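

# Hedged examples of the endpoint handling above (`tnp` as before):
#
#   >>> tnp.linspace(0., 1., num=5)                  # -> [0., 0.25, 0.5, 0.75, 1.]
#   >>> tnp.linspace(0., 1., num=5, endpoint=False)  # -> [0., 0.2, 0.4, 0.6, 0.8]
#
# The endpoint=False case shifts `stop` down by one step and reuses
# math_ops.linspace, which only supports inclusive endpoints.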
@np_utils.np_doc('logspace')
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
  dtype = np_utils.result_type(start, stop, dtype)
  result = linspace(
      start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis)
  result = math_ops.pow(math_ops.cast(base, result.dtype), result)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result


@np_utils.np_doc('geomspace')
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint: disable=missing-docstring
  dtype = dtypes.as_dtype(dtype) if dtype else np_utils.result_type(
      start, stop, float(num), np_array_ops.zeros((), dtype))
  computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
  start = np_array_ops.asarray(start, dtype=computation_dtype)
  stop = np_array_ops.asarray(stop, dtype=computation_dtype)
  # follow the numpy geomspace convention for negative and complex endpoints
  start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
  stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
  signflip = 1 - start_sign * stop_sign // 2
  res = signflip * logspace(
      log10(signflip * start),
      log10(signflip * stop),
      num,
      endpoint=endpoint,
      base=10.0,
      dtype=computation_dtype,
      axis=0)
  if axis != 0:
    res = np_array_ops.moveaxis(res, 0, axis)
  return math_ops.cast(res, dtype)


@np_utils.np_doc('ptp')
def ptp(a, axis=None, keepdims=None):
  return (np_array_ops.amax(a, axis=axis, keepdims=keepdims) -
          np_array_ops.amin(a, axis=axis, keepdims=keepdims))


@np_utils.np_doc_only('concatenate')
def concatenate(arys, axis=0):
  if not isinstance(arys, (list, tuple)):
    arys = [arys]
  if not arys:
    raise ValueError('Need at least one array to concatenate. Received empty '
                     f'input: arys={arys}')
  dtype = np_utils.result_type(*arys)
  arys = [np_array_ops.array(array, dtype=dtype) for array in arys]
  return array_ops.concat(arys, axis)


@np_utils.np_doc_only('tile')
def tile(a, reps):  # pylint: disable=missing-function-docstring
  a = np_array_ops.array(a)
  reps = array_ops.reshape(np_array_ops.array(reps, dtype=dtypes.int32), [-1])

  a_rank = array_ops.rank(a)
  reps_size = array_ops.size(reps)
  reps = array_ops.pad(
      reps, [[math_ops.maximum(a_rank - reps_size, 0), 0]], constant_values=1)
  a_shape = array_ops.pad(
      array_ops.shape(a), [[math_ops.maximum(reps_size - a_rank, 0), 0]],
      constant_values=1)
  a = array_ops.reshape(a, a_shape)

  return array_ops.tile(a, reps)
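

# Unlike tf.tile, np.tile lets `reps` and the input rank disagree; the
# padding above left-pads the shorter of the two with ones so their ranks
# match. A hedged sketch:
#
#   >>> tnp.tile([1, 2], [2, 2])  # 1-D input, 2-D reps -> [[1, 2, 1, 2],
#   ...                           #                         [1, 2, 1, 2]]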
@np_utils.np_doc('count_nonzero')
def count_nonzero(a, axis=None):
  return math_ops.count_nonzero(np_array_ops.array(a), axis)


@np_utils.np_doc('argsort')
def argsort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  # TODO(nareshmodi): make string tensors also work.
  if kind not in ('quicksort', 'stable'):
    raise ValueError(
        'Invalid value for argument `kind`. '
        'Only kind="quicksort" and kind="stable" are supported. '
        f'Received: kind={kind}')
  if order is not None:
    raise ValueError('The `order` argument is not supported. Pass order=None')
  stable = (kind == 'stable')

  a = np_array_ops.array(a)

  def _argsort(a, axis, stable):
    if axis is None:
      a = array_ops.reshape(a, [-1])
      axis = 0

    return sort_ops.argsort(a, axis, stable=stable)

  tf_ans = np_utils.cond(
      math_ops.equal(array_ops.rank(a), 0), lambda: constant_op.constant([0]),
      lambda: _argsort(a, axis, stable))

  return np_array_ops.array(tf_ans, dtype=np.intp)


@np_utils.np_doc('sort')
def sort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  if kind != 'quicksort':
    raise ValueError(
        'Invalid value for argument `kind`. '
        'Only kind="quicksort" is supported. '
        f'Received: kind={kind}')
  if order is not None:
    raise ValueError('The `order` argument is not supported. Pass order=None')

  a = np_array_ops.array(a)

  if axis is None:
    return sort_ops.sort(array_ops.reshape(a, [-1]), 0)
  else:
    return sort_ops.sort(a, axis)


def _argminmax(fn, a, axis=None):
  a = np_array_ops.array(a)
  if axis is None:
    # When axis is None numpy flattens the array.
    a_t = array_ops.reshape(a, [-1])
  else:
    a_t = np_array_ops.atleast_1d(a)
  return fn(input=a_t, axis=axis)


@np_utils.np_doc('argmax')
def argmax(a, axis=None):
  return _argminmax(math_ops.argmax, a, axis)


@np_utils.np_doc('argmin')
def argmin(a, axis=None):
  return _argminmax(math_ops.argmin, a, axis)


@np_utils.np_doc('append')
def append(arr, values, axis=None):
  if axis is None:
    return concatenate([np_array_ops.ravel(arr), np_array_ops.ravel(values)], 0)
  else:
    return concatenate([arr, values], axis=axis)


@np_utils.np_doc('average')
def average(a, axis=None, weights=None, returned=False):  # pylint: disable=missing-docstring
  if axis is not None and not isinstance(axis, int):
    # TODO(wangpeng): Support tuple of ints as `axis`
    raise ValueError('Argument `axis` must be an integer. '
                     f'Received axis={axis} (of type {type(axis)})')
  a = np_array_ops.array(a)
  if weights is None:  # Treat all weights as 1
    if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      a = a.astype(
          np_utils.result_type(a.dtype, np_dtypes.default_float_type()))
    avg = math_ops.reduce_mean(a, axis=axis)
    if returned:
      if axis is None:
        weights_sum = array_ops.size(a)
      else:
        weights_sum = array_ops.shape(a)[axis]
      weights_sum = math_ops.cast(weights_sum, a.dtype)
  else:
    if np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      out_dtype = np_utils.result_type(a.dtype, weights)
    else:
      out_dtype = np_utils.result_type(a.dtype, weights,
                                       np_dtypes.default_float_type())
    a = np_array_ops.array(a, out_dtype)
    weights = np_array_ops.array(weights, out_dtype)

    def rank_equal_case():
      control_flow_assert.Assert(
          math_ops.reduce_all(array_ops.shape(a) == array_ops.shape(weights)),
          [array_ops.shape(a), array_ops.shape(weights)])
      weights_sum = math_ops.reduce_sum(weights, axis=axis)
      avg = math_ops.reduce_sum(a * weights, axis=axis) / weights_sum
      return avg, weights_sum

    if axis is None:
      avg, weights_sum = rank_equal_case()
    else:

      def rank_not_equal_case():
        control_flow_assert.Assert(
            array_ops.rank(weights) == 1, [array_ops.rank(weights)])
        weights_sum = math_ops.reduce_sum(weights)
        axes = ops.convert_to_tensor([[axis], [0]])
        avg = math_ops.tensordot(a, weights, axes) / weights_sum
        return avg, weights_sum

      # We condition on rank rather than shape equality, because if we do the
      # latter, when the shapes are partially unknown but the ranks are known
      # and different, np_utils.cond will run shape checking on the true branch,
      # which will raise a shape-checking error.
      avg, weights_sum = np_utils.cond(
          math_ops.equal(array_ops.rank(a), array_ops.rank(weights)),
          rank_equal_case, rank_not_equal_case)

  avg = np_array_ops.array(avg)
  if returned:
    weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg))
    return avg, weights_sum
  return avg


@np_utils.np_doc('trace')
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  a = np_array_ops.asarray(a, dtype)

  if offset == 0:
    a_shape = a.shape
    if a_shape.rank is not None:
      rank = len(a_shape)
      if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or
                                                 axis2 == rank - 1):
        return math_ops.trace(a)

  a = np_array_ops.diagonal(a, offset, axis1, axis2)
  return np_array_ops.sum(a, -1, dtype)


@np_utils.np_doc('meshgrid')
def meshgrid(*xi, **kwargs):
  """This currently requires copy=True and sparse=False."""
  sparse = kwargs.get('sparse', False)
  if sparse:
    raise ValueError(
        'Function `meshgrid` does not support returning sparse arrays yet. '
        f'Received: sparse={sparse}')

  copy = kwargs.get('copy', True)
  if not copy:
    raise ValueError('Function `meshgrid` only supports copy=True. '
                     f'Received: copy={copy}')

  indexing = kwargs.get('indexing', 'xy')

  xi = [np_array_ops.asarray(arg) for arg in xi]
  kwargs = {'indexing': indexing}

  outputs = array_ops.meshgrid(*xi, **kwargs)

  return outputs


# Uses np_doc_only here because np.einsum (in 1.16) doesn't have argument
# `subscripts`, even though the doc says it does.
@np_utils.np_doc_only('einsum')
def einsum(subscripts, *operands, **kwargs):  # pylint: disable=missing-docstring
  casting = kwargs.get('casting', 'safe')
  optimize = kwargs.get('optimize', False)
  if casting == 'safe':
    operands = np_array_ops._promote_dtype(*operands)  # pylint: disable=protected-access
  elif casting == 'no':
    operands = [np_array_ops.asarray(x) for x in operands]
  else:
    raise ValueError(
        'Invalid value for argument `casting`. '
        f'Expected casting="safe" or casting="no". Received: casting={casting}')
  if not optimize:
    # TF doesn't have a "no optimization" option.
    # TODO(wangpeng): Print a warning that np and tf use different
    # optimizations.
    tf_optimize = 'greedy'
  elif optimize == True:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
    tf_optimize = 'greedy'
  elif optimize == 'greedy':
    tf_optimize = 'greedy'
  elif optimize == 'optimal':
    tf_optimize = 'optimal'
  else:
    raise ValueError(
        'Invalid value for argument `optimize`. '
        'Expected one of {True, "greedy", "optimal"}. '
        f'Received: optimize={optimize}')

  res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
  return res
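

# Hedged einsum sketch (`tnp` = `tf.experimental.numpy`). Per the mapping
# above, every `optimize` value other than 'optimal' resolves to TF's
# 'greedy' contraction strategy:
#
#   >>> tnp.einsum('ij,jk->ik', a, b)  # matrix product
#   >>> tnp.einsum('ii->', a)          # trace of a square matrix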
def _tensor_t(self):
  """Returns a Tensor which is the transpose of this Tensor."""
  return self.transpose()


def _tensor_ndim(self):
  """Returns the rank of the Tensor."""
  return self.shape.ndims


def _tensor_pos(self):
  """Returns self, for unary operator `+`."""
  return self


def _tensor_size(self):
  """Returns the number of elements in this Tensor, if fully known."""
  if not self.shape.is_fully_defined():
    return None
  return np.prod(self.shape.as_list())


def _tensor_tolist(self):
  if isinstance(self, ops.EagerTensor):
    return self._numpy().tolist()  # pylint: disable=protected-access

  raise ValueError('Symbolic Tensors do not support the tolist API.')


def enable_numpy_methods_on_tensor():
  """Adds additional NumPy methods on tf.Tensor class."""
  t = property(_tensor_t)
  setattr(ops.Tensor, 'T', t)

  ndim = property(_tensor_ndim)
  setattr(ops.Tensor, 'ndim', ndim)

  size = property(_tensor_size)
  setattr(ops.Tensor, 'size', size)

  setattr(ops.Tensor, '__pos__', _tensor_pos)
  setattr(ops.Tensor, 'tolist', _tensor_tolist)

  # TODO(b/178540516): Make a custom `setattr` that changes the method's
  # docstring to the TF one.
  setattr(ops.Tensor, 'transpose', np_array_ops.transpose)
  setattr(ops.Tensor, 'flatten', np_array_ops.flatten)
  setattr(ops.Tensor, 'reshape', np_array_ops._reshape_method_wrapper)  # pylint: disable=protected-access
  setattr(ops.Tensor, 'ravel', np_array_ops.ravel)
  setattr(ops.Tensor, 'clip', clip)
  setattr(ops.Tensor, 'astype', math_ops.cast)
  setattr(ops.Tensor, '__round__', np_array_ops.around)
  setattr(ops.Tensor, 'max', np_array_ops.amax)
  setattr(ops.Tensor, 'mean', np_array_ops.mean)
  setattr(ops.Tensor, 'min', np_array_ops.amin)

  # TODO(wangpeng): Remove `data` when all uses of it are removed
  data = property(lambda self: self)
  setattr(ops.Tensor, 'data', data)
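

# Hedged usage sketch: once `enable_numpy_methods_on_tensor()` has run
# (the public entry point is assumed to be
# `tf.experimental.numpy.experimental_enable_numpy_behavior`), plain Tensors
# gain NumPy-style attributes:
#
#   >>> t = tf.constant([[1, 2], [3, 4]])
#   >>> t.T.ndim, t.size       # -> (2, 4)
#   >>> t.reshape(4).tolist()  # -> [1, 2, 3, 4]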