Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/numpy_ops/np_utils.py: 57%
307 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for internal use."""
16# pylint: disable=g-direct-tensorflow-import
18import inspect
19import numbers
20import os
21import re
22import numpy as np
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import indexed_slices
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import cond as tf_cond
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.numpy_ops import np_arrays
31from tensorflow.python.ops.numpy_ops import np_dtypes
32from tensorflow.python.ops.numpy_ops import np_export
33from tensorflow.python.types import core
34from tensorflow.python.util import nest
def _canonicalize_axis(axis, rank):
  """Canonicalizes a single axis against `rank`.

  Args:
    axis: the axis to canonicalize. It is wrapped in a one-element list and
      handed to `_canonicalize_axes`.
    rank: the rank the axis refers to.

  Returns:
    The single element of `_canonicalize_axes([axis], rank)`.
  """
  (canonical_axis,) = _canonicalize_axes([axis], rank)
  return canonical_axis
def _canonicalize_axes(axes, rank):
  """Shifts every negative axis in `axes` by `rank`.

  Args:
    axes: an iterable of axes.
    rank: the rank the axes refer to. It is first passed through
      `_maybe_static`; if it is still a tensor afterwards, the sign test on
      each axis goes through this module's `cond`.

  Returns:
    A list with one entry per element of `axes`: `axis + rank` when
    `axis < 0`, otherwise `axis` unchanged.
  """
  rank = _maybe_static(rank)

  if isinstance(rank, core.Tensor):

    def canonicalize(axis):
      return cond(axis < 0, lambda: axis + rank, lambda: axis)

  else:

    def canonicalize(axis):
      if axis < 0:
        return axis + rank
      return axis

  return [canonicalize(axis) for axis in axes]
53def _supports_signature():
54 return hasattr(inspect, 'signature')
def _to_tf_type(dtype):
  """Converts `dtype` to a TF DType via `dtypes.as_dtype`.

  Args:
    dtype: Could be a python type, a numpy type or a TF DType.

  Returns:
    A tensorflow `DType`.
  """
  tf_dtype = dtypes.as_dtype(dtype)
  return tf_dtype
def _to_numpy_type(dtype):
  """Converts a native python type or TF DType to its numpy counterpart.

  Args:
    dtype: Could be a python type, a numpy type or a TF DType.

  Returns:
    `dtype.as_numpy_dtype` when `dtype` is a TF `DType`, otherwise
    `np.dtype(dtype)`.
  """
  if not isinstance(dtype, dtypes.DType):
    return np.dtype(dtype)
  return dtype.as_numpy_dtype
def isscalar(val):
  """Returns whether `val` is a scalar value or scalar Tensor.

  Args:
    val: the value to test. An `np_arrays.ndarray` is replaced by its `data`
      attribute before testing.

  Returns:
    For tensors with a statically known number of dimensions, a Python bool.
    For tensors whose number of dimensions is unknown, the result of comparing
    `array_ops.rank(val)` with 0. For everything else, `np.isscalar(val)`.
  """
  if isinstance(val, np_arrays.ndarray):
    val = val.data
  if not isinstance(val, core.Tensor):
    return np.isscalar(val)
  static_ndims = val.shape.ndims
  if static_ndims is None:
    # The rank is only known at run time, so the answer is itself a tensor.
    return math_ops.equal(array_ops.rank(val), 0)
  return static_ndims == 0
97def _has_docstring(f):
98 return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and
99 f.__doc__)
102def _add_blank_line(s):
103 if s.endswith('\n'):
104 return s + '\n'
105 else:
106 return s + '\n\n'
109def _np_signature(f):
110 """An enhanced inspect.signature that can handle numpy.ufunc."""
111 # TODO(wangpeng): consider migrating away from inspect.signature.
112 # inspect.signature is supported in Python 3.3.
113 if not hasattr(inspect, 'signature'):
114 return None
115 if f is None:
116 return None
117 if not isinstance(f, np.ufunc):
118 try:
119 return inspect.signature(f)
120 except ValueError:
121 return None
123 def names_from_num(prefix, n):
124 if n <= 0:
125 return []
126 elif n == 1:
127 return [prefix]
128 else:
129 return [prefix + str(i + 1) for i in range(n)]
131 input_names = names_from_num('x', f.nin)
132 output_names = names_from_num('out', f.nout)
133 keyword_only_params = [('where', True), ('casting', 'same_kind'),
134 ('order', 'K'), ('dtype', None), ('subok', True),
135 ('signature', None), ('extobj', None)]
136 params = []
137 params += [
138 inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
139 for name in input_names
140 ]
141 if f.nout > 1:
142 params += [
143 inspect.Parameter(
144 name, inspect.Parameter.POSITIONAL_ONLY, default=None)
145 for name in output_names
146 ]
147 params += [
148 inspect.Parameter(
149 'out',
150 inspect.Parameter.POSITIONAL_OR_KEYWORD,
151 default=None if f.nout == 1 else (None,) * f.nout)
152 ]
153 params += [
154 inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
155 for name, default in keyword_only_params
156 ]
157 return inspect.Signature(params)
160# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't
161# allow positional-only argument. So we conflate positional-only, keyword-only
162# and positional-or-keyword arguments here.
163def _is_compatible_param_kind(a, b):
165 def relax(k):
166 if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY):
167 return inspect.Parameter.POSITIONAL_OR_KEYWORD
168 return k
170 return relax(a) == relax(b)
173def _prepare_np_fun_name_and_fun(np_fun_name, np_fun):
174 """Mutually propagates information between `np_fun_name` and `np_fun`.
176 If one is None and the other is not, we'll try to make the former not None in
177 a best effort.
179 Args:
180 np_fun_name: name for the np_fun symbol. At least one of np_fun or
181 np_fun_name shoud be set.
182 np_fun: the numpy function whose docstring will be used.
184 Returns:
185 Processed `np_fun_name` and `np_fun`.
186 """
187 if np_fun_name is not None:
188 assert isinstance(np_fun_name, str)
189 if np_fun is not None:
190 assert not isinstance(np_fun, str)
191 if np_fun is None:
192 assert np_fun_name is not None
193 try:
194 np_fun = getattr(np, str(np_fun_name))
195 except AttributeError:
196 np_fun = None
197 if np_fun_name is None:
198 assert np_fun is not None
199 np_fun_name = np_fun.__name__
200 return np_fun_name, np_fun
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None,
                   link=None):
  """Builds the docstring for the tf-numpy function `f`.

  The result consists of a header naming the numpy function, an optional list
  of unsupported arguments, the existing docstring of `f` (if any), and
  whatever `_add_np_doc` appends.

  Args:
    f: the function being documented.
    np_f: the corresponding numpy function; may be falsy if `np_fun_name` is
      given.
    np_fun_name: (optional) the numpy function name; defaults to
      `np_f.__name__`.
    unsupported_params: (optional) names of the unsupported parameters.
    link: (optional) forwarded to `_add_np_doc`.

  Returns:
    The assembled docstring.
  """
  assert np_f or np_fun_name
  np_fun_name = np_fun_name or np_f.__name__
  doc = 'TensorFlow variant of NumPy\'s `%s`.\n\n' % np_fun_name
  if unsupported_params:
    quoted_names = ['`' + name + '`' for name in unsupported_params]
    doc += 'Unsupported arguments: ' + ', '.join(quoted_names) + '.\n\n'
  if _has_docstring(f):
    doc = _add_blank_line(doc + f.__doc__)
  # TODO(wangpeng): Re-enable the following and choose inlined vs. link to numpy
  #   doc according to some global switch.
  return _add_np_doc(doc, np_fun_name, np_f, link=link)
# How tf-numpy docstrings refer to the original numpy docstrings. Initialized
# from the `TF_NP_DOC_FORM` environment variable, defaulting to '1.16'.
_np_doc_form = os.environ.get('TF_NP_DOC_FORM', '1.16')


def get_np_doc_form():
  """Returns the current form of the original numpy docstrings.

  Returns:
    The value of the module-level setting. See `set_np_doc_form` for the list
    of valid values.
  """
  return _np_doc_form


def set_np_doc_form(value):
  r"""Selects the form of the original numpy docstrings.

  This sets the module-level variable controlling how a tf-numpy symbol's
  docstring refers to the original numpy docstring. With `'inlined'`, the numpy
  docstring is copied verbatim into the tf-numpy docstring. Otherwise a link to
  the original numpy docstring is added, and `value` selects the numpy version
  that the link points to:
    * `'stable'`: the current stable version;
    * `'dev'`: the current development version;
    * pattern `\d+(\.\d+(\.\d+)?)?`: `value` is treated as a version number,
      e.g. '1.16'.

  Args:
    value: the value to set the global variable to.
  """
  global _np_doc_form
  _np_doc_form = value
class Link:
  """Wraps a value that `np_doc` uses as the whole documentation link."""

  def __init__(self, v):
    # Stored as-is; read back through the `value` attribute.
    self.value = v
class AliasOf:
  """Wraps a name that `np_doc` uses instead of `np_fun_name` for the link."""

  def __init__(self, v):
    # Stored as-is; read back through the `value` attribute.
    self.value = v
class NoLink:
  """Marker type: passing an instance as `link` to `np_doc` adds no link."""
def generate_link(flag, np_fun_name):
  """Generates link from numpy function name.

  Args:
    flag: the flag to control link form. See `set_np_doc_form`.
    np_fun_name: the numpy function name.

  Returns:
    The URL as a string, or `None` when `flag` is neither `'dev'`, `'stable'`
    nor a version number.
  """
  # Only adds link in this case
  if flag == 'dev':
    template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html'
  elif flag == 'stable':
    template = (
        'https://numpy.org/doc/stable/reference/generated/numpy.%s.html')
  elif re.fullmatch(r'\d+(\.\d+(\.\d+)?)?', flag):
    # `flag` is the version number. `fullmatch` is used rather than `match`
    # with a trailing `$`, because `$` also matches just before a final
    # newline and would let e.g. '1.16\n' end up inside the URL.
    template = ('https://numpy.org/doc/' + flag +
                '/reference/generated/numpy.%s.html')
  else:
    return None
  return template % np_fun_name
# Whether generated documentation links are verified. Initialized from the
# `TF_NP_CHECK_LINK` environment variable; only 'True', 'true' and '1' enable
# it.
_is_check_link = os.environ.get('TF_NP_CHECK_LINK', 'False') in (
    'True', 'true', '1')


def is_check_link():
  """Returns the current value of the link-checking switch."""
  return _is_check_link


def set_check_link(value):
  """Sets the link-checking switch to `value`."""
  global _is_check_link
  _is_check_link = value
def _add_np_doc(doc, np_fun_name, np_f, link):
  """Appends the numpy docstring to `doc`, according to `set_np_doc_form`.

  See `set_np_doc_form` for how it controls the form of the numpy docstring.

  Args:
    doc: the docstring to be appended to.
    np_fun_name: the name of the numpy function.
    np_f: (optional) the numpy function.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    `doc` with numpy docstring appended.

  Raises:
    ValueError: if link checking is enabled (see `set_check_link`) and the
      HEAD request for the link does not return status code 200.
  """
  # Upper bound, in seconds, for the optional link check. Without a timeout an
  # unresponsive server would block the caller indefinitely.
  link_check_timeout_secs = 10

  flag = get_np_doc_form()
  if flag == 'inlined':
    if _has_docstring(np_f):
      doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
      # TODO(wangpeng): It looks like code snippets in numpy doc don't work
      # correctly with doctest. Fix that and remove the reformatting of the np_f
      # comment.
      doc += np_f.__doc__.replace('>>>', '>')
    return doc
  if not isinstance(flag, str):
    return doc

  if link is None:
    url = generate_link(flag, np_fun_name)
  elif isinstance(link, AliasOf):
    url = generate_link(flag, link.value)
  elif isinstance(link, Link):
    url = link.value
  else:
    # Includes `NoLink`: no link is added.
    url = None
  if url is None:
    return doc

  if is_check_link():
    # Imports locally because some builds may not have `requests`
    import requests  # pylint: disable=g-import-not-at-top
    response = requests.head(url, timeout=link_check_timeout_secs)
    if response.status_code != 200:
      raise ValueError(
          f'Check link failed at [{url}] with status code '
          f'{response.status_code}. '
          f'Argument `np_fun_name` is {np_fun_name}.')
  doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % (
      np_fun_name, url)
  return doc
# Whether `np_doc` raises on a signature mismatch. Initialized from the
# `TF_NP_SIG_MISMATCH_IS_ERROR` environment variable; only 'True', 'true' and
# '1' enable it.
_is_sig_mismatch_an_error = os.environ.get(
    'TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1')


def is_sig_mismatch_an_error():
  """Returns the current value of the signature-mismatch switch."""
  return _is_sig_mismatch_an_error


def set_is_sig_mismatch_an_error(value):
  """Sets the signature-mismatch switch to `value`."""
  global _is_sig_mismatch_an_error
  _is_sig_mismatch_an_error = value
def np_doc(np_fun_name, np_fun=None, export=True, unsupported_params=None,
           link=None):
  """Attaches numpy docstring to a function.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    unsupported_params: (optional) the list of parameters not supported
      by tf.numpy. The list is not modified; numpy parameters missing from the
      decorated function's signature are added to a private copy.
    link: (optional) which link to use. If `None`, a default link generated from
      `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will
      be used in place of `np_fun_name` for the link generation. If an instance
      of `Link`, `link.value` will be used as the whole link. If an instance of
      `NoLink`, no link will be added.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.

  Raises:
    TypeError: raised by the returned decorator when
      `is_sig_mismatch_an_error()` is true and the decorated function's
      signature disagrees with the numpy one (unknown parameter, incompatible
      parameter kind, or a default present on only one side).
  """
  np_fun_name_orig, np_fun_orig = np_fun_name, np_fun
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)
  np_sig = _np_signature(np_fun)
  # Snapshot the caller's parameters so that their list is never mutated.
  declared_unsupported = list(unsupported_params or ())

  def decorator(f):
    """The decorator."""
    # A fresh copy per decorated function: names detected for one function
    # must not leak into the documentation of another one.
    unsupported = list(declared_unsupported)
    if hasattr(inspect, 'signature') and np_sig is not None:
      try:
        sig = inspect.signature(f)
      except ValueError:
        sig = None
      if sig is not None:
        strict = is_sig_mismatch_an_error()
        for name, param in sig.parameters.items():
          np_param = np_sig.parameters.get(name)
          if np_param is None:
            if strict:
              raise TypeError(
                  f'Cannot find parameter {name} in the numpy function\'s '
                  f'signature (which has these parameters: '
                  f'{list(np_sig.parameters.keys())}). Argument `np_fun_name` '
                  f'is {np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
            continue
          if strict and not _is_compatible_param_kind(param.kind,
                                                      np_param.kind):
            raise TypeError(
                f'Parameter {name} is of kind {param.kind} while in numpy it '
                f'is of kind {np_param.kind}. Argument `np_fun_name` is '
                f'{np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
          # `Parameter.empty` is a sentinel, so compare by identity; `!=` would
          # dispatch to the default value's own comparison operator.
          has_default = param.default is not inspect.Parameter.empty
          np_has_default = np_param.default is not inspect.Parameter.empty
          if strict and has_default != np_has_default:
            raise TypeError(
                'Parameter {} should{} have a default value. Argument '
                '`np_fun_name` is {}. Argument `np_fun` is {}.'.format(
                    name, '' if np_has_default else ' not', np_fun_name_orig,
                    np_fun_orig))
        # Numpy parameters that `f` does not accept are documented as
        # unsupported, without repeating those already listed.
        for name in np_sig.parameters:
          if name not in sig.parameters and name not in unsupported:
            unsupported.append(name)
    f.__doc__ = _np_doc_helper(
        f, np_fun, np_fun_name=np_fun_name,
        unsupported_params=unsupported, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    return f

  return decorator
def np_doc_only(np_fun_name, np_fun=None, export=True):
  """Attaches numpy docstring to a function.

  This differs from np_doc in that it doesn't check for a match in signature.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)

  def decorator(f):
    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name)
    return np_export.np_export(np_fun_name)(f) if export else f

  return decorator
473# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
@np_doc('finfo')
def finfo(dtype):
  """Note that currently it just forwards to the numpy namesake, while
  tensorflow and numpy dtypes may have different properties."""
  # The docstring above is copied into the documentation generated by
  # `np_doc`, so its text is left untouched.
  numpy_dtype = _to_numpy_type(dtype)
  return np.finfo(numpy_dtype)
479# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
482def _maybe_get_dtype(x):
483 """Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
484 # Don't put np.ndarray in this list, because np.result_type looks at the
485 # value (not just dtype) of np.ndarray to decide the result type.
486 if isinstance(x, numbers.Real):
487 return x
488 if isinstance(x, indexed_slices.IndexedSlices) or tensor_util.is_tf_type(x):
489 return _to_numpy_type(x.dtype)
490 if isinstance(x, dtypes.DType):
491 return x.as_numpy_dtype
492 if isinstance(x, (list, tuple)):
493 raise ValueError(
494 f'Cannot find dtype for type inference from argument `x` of a sequence '
495 f'type {type(x)}. For sequences, please call this function on each '
496 f'element individually.')
497 return x
# Can't use np_doc because np.result_type is a builtin function.
@np_doc_only('result_type')
def result_type(*arrays_and_dtypes):  # pylint: disable=missing-function-docstring
  # A docstring here would be copied into the documentation generated by
  # `np_doc_only`, so the function is described in comments instead: every
  # flattened input is reduced by `_maybe_get_dtype` and the results are
  # handed to `np_dtypes._result_type`.
  flat_inputs = nest.flatten(arrays_and_dtypes)
  dtype_hints = list(map(_maybe_get_dtype, flat_inputs))
  if not dtype_hints:
    # If there are no inputs at all, let numpy decide what the dtype is.
    dtype_hints = [np.asarray([])]
  return np_dtypes._result_type(*dtype_hints)  # pylint: disable=protected-access
def result_type_unary(a, dtype):  # pylint: disable=missing-function-docstring
  """Find the result type from a single input and a dtype.

  Args:
    a: the input value.
    dtype: the requested dtype. When truthy it alone determines the result.

  Returns:
    `result_type(dtype)` if `dtype` is truthy; otherwise `np.str_` for `str`
    inputs, `np.bytes_` for `bytes` inputs and `result_type(a)` for everything
    else.
  """
  if dtype:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
    return result_type(dtype)

  # np_utils.result_type treats string inputs as dtype strings, not as strings.
  # but for unary we want to treat it as a string input.
  if isinstance(a, str):
    # `np.str_` rather than the `np.unicode_` alias: the alias was removed in
    # NumPy 2.0, and both names refer to the same type on NumPy 1.x.
    return np.str_
  if isinstance(a, bytes):
    return np.bytes_

  # TF and numpy has different interpretations of Python types such as
  # `float`, so we let `np_utils.result_type` decide.
  return result_type(a)
def _result_type_binary(t1, t2):  # pylint: disable=missing-function-docstring
  """A specialization of result_type for 2 arguments for performance reasons.

  Args:
    t1: the first value or dtype.
    t2: the second value or dtype.

  Returns:
    The result of `np_dtypes._result_type` on the two inputs after
    `_maybe_get_dtype`, or `result_type(t1, t2)` if that raised `ValueError`.
  """
  try:
    first = _maybe_get_dtype(t1)
    second = _maybe_get_dtype(t2)
    return np_dtypes._result_type(first, second)  # pylint: disable=protected-access
  except ValueError:
    return result_type(t1, t2)
@np_doc('promote_types')
def promote_types(type1, type2):  # pylint: disable=missing-function-docstring
  # A docstring here would be copied into the documentation generated by
  # `np_doc`, so the steps are described in comments instead: both inputs are
  # converted to numpy types, promoted by numpy, and the result is passed
  # through `np_dtypes.canonicalize_dtype`.
  numpy_type1 = _to_numpy_type(type1)
  numpy_type2 = _to_numpy_type(type2)
  promoted = np.promote_types(numpy_type1, numpy_type2)
  return np_dtypes.canonicalize_dtype(promoted)
def tf_broadcast(*args):
  """Broadcast tensors.

  Args:
    *args: a list of tensors whose shapes are broadcastable against each other.

  Returns:
    Tensors broadcasted to the common shape, as a list. With zero or one
    argument, the `args` tuple is returned unchanged.
  """
  if len(args) <= 1:
    return args
  first, *others = args
  # Fold the dynamic shapes of all arguments into one common shape.
  common_shape = array_ops.shape(first)
  for other in others:
    common_shape = array_ops.broadcast_dynamic_shape(
        common_shape, array_ops.shape(other))
  return [array_ops.broadcast_to(arg, common_shape) for arg in args]
563# TODO(wangpeng): Move the following functions to a separate file and check for
564# float dtypes in each of them.
def get_static_value(x):
  """A version of tf.get_static_value that returns None on float dtypes.

  It returns None on floating-point and complex dtypes in order to avoid
  breaking gradients.

  Args:
    x: a tensor.

  Returns:
    Same as `tf.get_static_value`, except that it returns None when `x` is a
    tensor with a floating-point or complex dtype.
  """
  if isinstance(x, core.Tensor):
    dtype = x.dtype
    if dtype.is_floating or dtype.is_complex:
      return None
  return tensor_util.constant_value(x)
def _maybe_static(x):
  """Returns the static value of `x` if `get_static_value` finds one, else `x`."""
  static_value = get_static_value(x)
  return x if static_value is None else static_value
# All the following functions exist because get_static_value can't handle
# their TF counterparts.
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition.

  Args:
    pred: the condition.
    true_fn: callable invoked when the condition holds.
    false_fn: callable invoked when the condition does not hold.

  Returns:
    The result of `tf_cond.cond` when `pred` has no static value; otherwise
    the result of calling the branch selected by the static value.
  """
  static_pred = get_static_value(pred)
  if static_pred is None:
    return tf_cond.cond(pred, true_fn, false_fn)
  chosen_branch = true_fn if static_pred else false_fn
  return chosen_branch()
def add(a, b):
  """A version of tf.add that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs + rhs
def subtract(a, b):
  """A version of tf.subtract that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs - rhs
def greater(a, b):
  """A version of tf.greater that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs > rhs
def greater_equal(a, b):
  """A version of tf.greater_equal that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs >= rhs
def less_equal(a, b):
  """A version of tf.less_equal that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs <= rhs
def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible.

  Args:
    a: the first operand.
    b: the second operand.

  Returns:
    If `a` has a static scalar value: that value when it is falsy, otherwise
    `_maybe_static(b)`. In all other cases the `&` of `a` (or its static
    value) with `_maybe_static(b)`.
  """
  a_static = get_static_value(a)
  if a_static is None:
    return a & _maybe_static(b)
  if not np.isscalar(a_static):
    return a_static & _maybe_static(b)
  # Static scalar: short-circuit without touching `b` when `a` is falsy.
  return _maybe_static(b) if a_static else a_static
def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible.

  Args:
    a: the first operand.
    b: the second operand.

  Returns:
    If `a` has a static scalar value: that value when it is truthy, otherwise
    `_maybe_static(b)`. In all other cases the `|` of `a` (or its static
    value) with `_maybe_static(b)`.
  """
  a_static = get_static_value(a)
  if a_static is None:
    return a | _maybe_static(b)
  if not np.isscalar(a_static):
    return a_static | _maybe_static(b)
  # Static scalar: short-circuit without touching `b` when `a` is truthy.
  return a_static if a_static else _maybe_static(b)
def getitem(a, slice_spec):
  """A version of __getitem__ that eagerly evaluates if possible."""
  indexable = _maybe_static(a)
  return indexable[slice_spec]
def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: forwarded to the reduction.
    keepdims: forwarded to the reduction.

  Returns:
    The `all` method of the static value of `input_tensor` when one is
    available, otherwise `math_ops.reduce_all` of `input_tensor`.
  """
  static_value = get_static_value(input_tensor)
  if static_value is not None:
    return static_value.all(axis=axis, keepdims=keepdims)
  return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: forwarded to the reduction.
    keepdims: forwarded to the reduction.

  Returns:
    The `any` method of the static value of `input_tensor` when one is
    available, otherwise `math_ops.reduce_any` of `input_tensor`.
  """
  static_value = get_static_value(input_tensor)
  if static_value is not None:
    return static_value.any(axis=axis, keepdims=keepdims)
  return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
def tf_rank(t):
  """Returns `t.shape.rank` when it is known, else `array_ops.rank(t)`."""
  static_rank = t.shape.rank
  return array_ops.rank(t) if static_rank is None else static_rank