Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/numpy_ops/np_utils.py: 57%

307 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utility functions for internal use.""" 

16# pylint: disable=g-direct-tensorflow-import 

17 

18import inspect 

19import numbers 

20import os 

21import re 

22import numpy as np 

23 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import indexed_slices 

26from tensorflow.python.framework import tensor_util 

27from tensorflow.python.ops import array_ops 

28from tensorflow.python.ops import cond as tf_cond 

29from tensorflow.python.ops import math_ops 

30from tensorflow.python.ops.numpy_ops import np_arrays 

31from tensorflow.python.ops.numpy_ops import np_dtypes 

32from tensorflow.python.ops.numpy_ops import np_export 

33from tensorflow.python.types import core 

34from tensorflow.python.util import nest 

35 

36 

def _canonicalize_axis(axis, rank):
  """Canonicalizes a single axis index; see `_canonicalize_axes`."""
  (canonical_axis,) = _canonicalize_axes([axis], rank)
  return canonical_axis

39 

40 

def _canonicalize_axes(axes, rank):
  """Maps negative axis indices to non-negative ones by adding `rank`.

  Non-negative axes are returned unchanged; no range check is performed.

  Args:
    axes: an iterable of axis indices.
    rank: the rank added to negative axes; replaced by its static value when
      one is available.

  Returns:
    A list with one canonicalized axis per element of `axes`.
  """
  rank = _maybe_static(rank)

  if isinstance(rank, core.Tensor):
    # The rank is only known at run time, so the sign test is routed through
    # this module's `cond` helper.
    def canonicalize(axis):
      return cond(axis < 0, lambda: axis + rank, lambda: axis)
  else:
    def canonicalize(axis):
      return axis + rank if axis < 0 else axis

  return list(map(canonicalize, axes))

51 

52 

53def _supports_signature(): 

54 return hasattr(inspect, 'signature') 

55 

56 

def _to_tf_type(dtype):
  """Converts a native Python or NumPy type to a TF `DType`.

  Args:
    dtype: a Python type, a NumPy type, or a TF `DType`.

  Returns:
    The TensorFlow `DType` produced by `dtypes.as_dtype`.
  """
  tf_dtype = dtypes.as_dtype(dtype)
  return tf_dtype

67 

68 

def _to_numpy_type(dtype):
  """Converts a native Python type or TF `DType` to a NumPy type.

  Args:
    dtype: a Python type, a NumPy type, or a TF `DType`.

  Returns:
    `dtype.as_numpy_dtype` for a TF `DType`; otherwise `np.dtype(dtype)`.
  """
  if not isinstance(dtype, dtypes.DType):
    return np.dtype(dtype)
  return dtype.as_numpy_dtype

81 

82 

def isscalar(val):
  """Returns whether `val` is a scalar value or scalar Tensor.

  Args:
    val: a tf-numpy `ndarray`, a Tensor, or any value `np.isscalar` accepts.

  Returns:
    A Python bool, except for a Tensor of unknown static rank, for which a
    boolean Tensor (`rank(val) == 0`) is returned.
  """
  if isinstance(val, np_arrays.ndarray):
    val = val.data
  if not isinstance(val, core.Tensor):
    return np.isscalar(val)
  ndims = val.shape.ndims
  if ndims is None:
    # The static rank is unknown, so fall back to a dynamic rank check.
    return math_ops.equal(array_ops.rank(val), 0)
  return ndims == 0

95 

96 

97def _has_docstring(f): 

98 return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and 

99 f.__doc__) 

100 

101 

102def _add_blank_line(s): 

103 if s.endswith('\n'): 

104 return s + '\n' 

105 else: 

106 return s + '\n\n' 

107 

108 

109def _np_signature(f): 

110 """An enhanced inspect.signature that can handle numpy.ufunc.""" 

111 # TODO(wangpeng): consider migrating away from inspect.signature. 

112 # inspect.signature is supported in Python 3.3. 

113 if not hasattr(inspect, 'signature'): 

114 return None 

115 if f is None: 

116 return None 

117 if not isinstance(f, np.ufunc): 

118 try: 

119 return inspect.signature(f) 

120 except ValueError: 

121 return None 

122 

123 def names_from_num(prefix, n): 

124 if n <= 0: 

125 return [] 

126 elif n == 1: 

127 return [prefix] 

128 else: 

129 return [prefix + str(i + 1) for i in range(n)] 

130 

131 input_names = names_from_num('x', f.nin) 

132 output_names = names_from_num('out', f.nout) 

133 keyword_only_params = [('where', True), ('casting', 'same_kind'), 

134 ('order', 'K'), ('dtype', None), ('subok', True), 

135 ('signature', None), ('extobj', None)] 

136 params = [] 

137 params += [ 

138 inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY) 

139 for name in input_names 

140 ] 

141 if f.nout > 1: 

142 params += [ 

143 inspect.Parameter( 

144 name, inspect.Parameter.POSITIONAL_ONLY, default=None) 

145 for name in output_names 

146 ] 

147 params += [ 

148 inspect.Parameter( 

149 'out', 

150 inspect.Parameter.POSITIONAL_OR_KEYWORD, 

151 default=None if f.nout == 1 else (None,) * f.nout) 

152 ] 

153 params += [ 

154 inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default) 

155 for name, default in keyword_only_params 

156 ] 

157 return inspect.Signature(params) 

158 

159 

160# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't 

161# allow positional-only argument. So we conflate positional-only, keyword-only 

162# and positional-or-keyword arguments here. 

163def _is_compatible_param_kind(a, b): 

164 

165 def relax(k): 

166 if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY): 

167 return inspect.Parameter.POSITIONAL_OR_KEYWORD 

168 return k 

169 

170 return relax(a) == relax(b) 

171 

172 

173def _prepare_np_fun_name_and_fun(np_fun_name, np_fun): 

174 """Mutually propagates information between `np_fun_name` and `np_fun`. 

175 

176 If one is None and the other is not, we'll try to make the former not None in 

177 a best effort. 

178 

179 Args: 

180 np_fun_name: name for the np_fun symbol. At least one of np_fun or 

181 np_fun_name shoud be set. 

182 np_fun: the numpy function whose docstring will be used. 

183 

184 Returns: 

185 Processed `np_fun_name` and `np_fun`. 

186 """ 

187 if np_fun_name is not None: 

188 assert isinstance(np_fun_name, str) 

189 if np_fun is not None: 

190 assert not isinstance(np_fun, str) 

191 if np_fun is None: 

192 assert np_fun_name is not None 

193 try: 

194 np_fun = getattr(np, str(np_fun_name)) 

195 except AttributeError: 

196 np_fun = None 

197 if np_fun_name is None: 

198 assert np_fun is not None 

199 np_fun_name = np_fun.__name__ 

200 return np_fun_name, np_fun 

201 

202 

def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None,
                   link=None):
  """Builds the docstring for a tf-numpy symbol.

  The result consists of the following parts, in order:
  - a header naming the NumPy counterpart;
  - an optional list of unsupported arguments;
  - `f`'s own docstring, if any;
  - the NumPy documentation or link chosen by `_add_np_doc`.

  Args:
    f: the tf-numpy function whose docstring is being built.
    np_f: the NumPy function, or None.
    np_fun_name: the NumPy function name; defaults to `np_f.__name__`.
    unsupported_params: optional iterable of parameter names to list as
      unsupported.
    link: forwarded to `_add_np_doc`.

  Returns:
    The assembled docstring.
  """
  assert np_f or np_fun_name
  name = np_fun_name or np_f.__name__
  doc = 'TensorFlow variant of NumPy\'s `%s`.\n\n' % name
  if unsupported_params:
    quoted = ', '.join('`' + param + '`' for param in unsupported_params)
    doc += 'Unsupported arguments: ' + quoted + '.\n\n'
  if _has_docstring(f):
    doc = _add_blank_line(doc + f.__doc__)
  # Append the NumPy documentation (inlined text or a link), in the form
  # selected via `set_np_doc_form`.
  return _add_np_doc(doc, name, np_f, link=link)

220 

221 

# Form of the NumPy docs referenced from tf-numpy docstrings, initialized from
# the TF_NP_DOC_FORM environment variable (see `set_np_doc_form`).
_np_doc_form = os.environ.get('TF_NP_DOC_FORM', '1.16')

223 

224 

def get_np_doc_form():
  """Returns the current form of the original numpy docstrings.

  Returns:
    The value last passed to `set_np_doc_form`, or else the module's initial
    value. See `set_np_doc_form` for the list of valid values.
  """
  current_form = _np_doc_form
  return current_form

232 

233 

def set_np_doc_form(value):
  r"""Selects how tf-numpy docstrings refer to the original numpy docstrings.

  The value is stored in a module-level global:
  * `'inlined'`: the numpy docstring is copied verbatim into the tf-numpy
    docstring.
  * Any other value: a link to the numpy docstring is added instead. The
    numpy version the link points to depends on `value`:
    * `'stable'`: the current stable version;
    * `'dev'`: the current development version;
    * pattern `\d+(\.\d+(\.\d+)?)?`: `value` is treated as a version number,
      e.g. '1.16'.

  Args:
    value: the new value of the global setting.
  """
  global _np_doc_form
  _np_doc_form = value

252 

253 

class Link:
  """An explicit, complete URL to use as the NumPy documentation link."""

  def __init__(self, v):
    # The full URL; used as-is by `_add_np_doc`.
    self.value = v

258 

259 

class AliasOf:
  """Names the NumPy symbol whose generated link replaces the default one."""

  def __init__(self, v):
    # NumPy function name used in place of `np_fun_name` for link generation.
    self.value = v

264 

265 

class NoLink:
  """Marker requesting that no NumPy documentation link be added."""

268 

269 

def generate_link(flag, np_fun_name):
  """Generates a link to a numpy function's documentation.

  Args:
    flag: the flag controlling the link form: `'dev'`, `'stable'`, or a
      version number such as `'1.16'`. See `set_np_doc_form`.
    np_fun_name: the numpy function name.

  Returns:
    The URL string, or None if `flag` is not one of the link forms above.
  """
  if flag == 'dev':
    template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html'
  elif flag == 'stable':
    template = (
        'https://numpy.org/doc/stable/reference/generated/numpy.%s.html')
  elif re.fullmatch(r'\d+(\.\d+(\.\d+)?)?', flag, flags=re.ASCII):
    # `flag` is a version number. `fullmatch` (unlike `match` with `$`) rejects
    # a trailing newline, and `re.ASCII` rejects non-ASCII digits; either one
    # would otherwise end up inside the URL.
    template = ('https://numpy.org/doc/' + flag +
                '/reference/generated/numpy.%s.html')
  else:
    return None
  return template % np_fun_name

293 

294 

# Whether generated NumPy doc links are verified over HTTP; initialized from
# the TF_NP_CHECK_LINK environment variable.
_is_check_link = (
    os.environ.get('TF_NP_CHECK_LINK', 'False') in ('True', 'true', '1'))

297 

298 

def is_check_link():
  """Returns the current setting for HTTP checking of NumPy doc links."""
  current = _is_check_link
  return current

301 

302 

def set_check_link(value):
  """Sets whether NumPy doc links are checked over HTTP.

  Args:
    value: the new setting; it is stored as given in a module-level global.
  """
  global _is_check_link
  _is_check_link = value

306 

307 

def _add_np_doc(doc, np_fun_name, np_f, link):
  """Appends the numpy docstring to `doc`, according to `set_np_doc_form`.

  See `set_np_doc_form` for how it controls the form of the numpy docstring.

  Args:
    doc: the docstring to be appended to.
    np_fun_name: the name of the numpy function.
    np_f: (optional) the numpy function.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    `doc` with numpy docstring appended.

  Raises:
    ValueError: if link checking is enabled (see `set_check_link`) and the
      HEAD request for the link does not return status 200.
  """
  flag = get_np_doc_form()
  if flag == 'inlined':
    if _has_docstring(np_f):
      doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
      # TODO(wangpeng): It looks like code snippets in numpy doc don't work
      # correctly with doctest. Fix that and remove the reformatting of the np_f
      # comment.
      doc += np_f.__doc__.replace('>>>', '>')
  elif isinstance(flag, str):
    if link is None:
      url = generate_link(flag, np_fun_name)
    elif isinstance(link, AliasOf):
      url = generate_link(flag, link.value)
    elif isinstance(link, Link):
      url = link.value
    else:
      # Any other value (e.g. `NoLink`): no link is added.
      url = None
    if url is not None:
      if is_check_link():
        # Imports locally because some builds may not have `requests`
        import requests  # pylint: disable=g-import-not-at-top
        # Bound the wait: without a timeout `requests` can block indefinitely
        # on an unresponsive server. That would hang decoration, and with it
        # the import of the decorated module.
        r = requests.head(url, timeout=60)
        if r.status_code != 200:
          raise ValueError(
              f'Check link failed at [{url}] with status code {r.status_code}. '
              f'Argument `np_fun_name` is {np_fun_name}.')
      doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % (
          np_fun_name, url)
  return doc

351 

352 

# Whether `np_doc` raises (rather than ignores) signature mismatches;
# initialized from the TF_NP_SIG_MISMATCH_IS_ERROR environment variable.
_is_sig_mismatch_an_error = (
    os.environ.get('TF_NP_SIG_MISMATCH_IS_ERROR', 'False')
    in ('True', 'true', '1'))

355 

356 

def is_sig_mismatch_an_error():
  """Returns whether `np_doc` treats signature mismatches as errors."""
  current = _is_sig_mismatch_an_error
  return current

359 

360 

def set_is_sig_mismatch_an_error(value):
  """Sets whether `np_doc` raises `TypeError` on signature mismatches.

  Args:
    value: the new setting; it is stored as given in a module-level global.
  """
  global _is_sig_mismatch_an_error
  _is_sig_mismatch_an_error = value

364 

365 

def np_doc(np_fun_name, np_fun=None, export=True, unsupported_params=None,
           link=None):
  """Attaches numpy docstring to a function.

  When the numpy signature can be determined, the decorated function's
  parameters are compared against it. Mismatches raise `TypeError` if
  `is_sig_mismatch_an_error()` is true. Numpy parameters missing from the
  decorated function are listed as unsupported in the generated docstring.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    unsupported_params: (optional) an iterable of parameters not supported
      by tf.numpy. It is copied, never modified.
    link: (optional) which link to use. If `None`, a default link generated from
      `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will
      be used in place of `np_fun_name` for the link generation. If an instance
      of `Link`, `link.value` will be used as the whole link. If an instance of
      `NoLink`, no link will be added.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.

  Raises:
    TypeError: (raised by the decorator) on a parameter name, kind, or default
      mismatch, when `is_sig_mismatch_an_error()` is true.
  """
  np_fun_name_orig, np_fun_orig = np_fun_name, np_fun
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)
  np_sig = _np_signature(np_fun)

  def decorator(f):
    """The decorator."""
    # Work on a private copy. Appending to `unsupported_params` itself would
    # mutate the caller's list (or fail for a tuple). It would also accumulate
    # entries if this decorator were applied more than once.
    unsupported = ([] if unsupported_params is None
                   else list(unsupported_params))
    if hasattr(inspect, 'signature') and np_sig is not None:
      try:
        sig = inspect.signature(f)
      except ValueError:
        sig = None
      if sig is not None:
        for name, param in sig.parameters.items():
          np_param = np_sig.parameters.get(name)
          if np_param is None:
            if is_sig_mismatch_an_error():
              raise TypeError(
                  f'Cannot find parameter {name} in the numpy function\'s '
                  f'signature (which has these parameters: '
                  f'{list(np_sig.parameters.keys())}). Argument `np_fun_name` '
                  f'is {np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
            else:
              continue
          if (is_sig_mismatch_an_error() and
              not _is_compatible_param_kind(param.kind, np_param.kind)):
            raise TypeError(
                f'Parameter {name} is of kind {param.kind} while in numpy it '
                f'is of kind {np_param.kind}. Argument `np_fun_name` is '
                f'{np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
          # Compare with the sentinel by identity. `!=` would call the default
          # value's own `__ne__`, which for array-like defaults returns an
          # array instead of a bool.
          has_default = param.default is not inspect.Parameter.empty
          np_has_default = np_param.default is not inspect.Parameter.empty
          if is_sig_mismatch_an_error() and has_default != np_has_default:
            raise TypeError(
                'Parameter {} should{} have a default value. Argument '
                '`np_fun_name` is {}. Argument `np_fun` is {}.'.format(
                    name, '' if np_has_default else ' not', np_fun_name_orig,
                    np_fun_orig))
        for name in np_sig.parameters:
          # Skip names the caller already listed, so the doc has no duplicates.
          if name not in sig.parameters and name not in unsupported:
            unsupported.append(name)
    f.__doc__ = _np_doc_helper(
        f, np_fun, np_fun_name=np_fun_name,
        unsupported_params=unsupported, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    else:
      return f

  return decorator

441 

442 

def np_doc_only(np_fun_name, np_fun=None, export=True):
  """Attaches numpy docstring to a function.

  Unlike `np_doc`, this does not compare the decorated function's signature
  with numpy's.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  resolved_name, resolved_fun = _prepare_np_fun_name_and_fun(
      np_fun_name, np_fun)

  def decorator(f):
    f.__doc__ = _np_doc_helper(f, resolved_fun, np_fun_name=resolved_name)
    return np_export.np_export(resolved_name)(f) if export else f

  return decorator

471 

472 

# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
@np_doc('finfo')
def finfo(dtype):
  """Note that currently it just forwards to the numpy namesake, while
  tensorflow and numpy dtypes may have different properties."""
  # Normalize first so that TF `DType`s are accepted as well.
  np_dtype = _to_numpy_type(dtype)
  return np.finfo(np_dtype)
# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args

480 

481 

482def _maybe_get_dtype(x): 

483 """Returns a numpy type if available from x. Skips if x is numpy.ndarray.""" 

484 # Don't put np.ndarray in this list, because np.result_type looks at the 

485 # value (not just dtype) of np.ndarray to decide the result type. 

486 if isinstance(x, numbers.Real): 

487 return x 

488 if isinstance(x, indexed_slices.IndexedSlices) or tensor_util.is_tf_type(x): 

489 return _to_numpy_type(x.dtype) 

490 if isinstance(x, dtypes.DType): 

491 return x.as_numpy_dtype 

492 if isinstance(x, (list, tuple)): 

493 raise ValueError( 

494 f'Cannot find dtype for type inference from argument `x` of a sequence ' 

495 f'type {type(x)}. For sequences, please call this function on each ' 

496 f'element individually.') 

497 return x 

498 

499 

# Can't use np_doc because np.result_type is a builtin function.
@np_doc_only('result_type')
def result_type(*arrays_and_dtypes):  # pylint: disable=missing-function-docstring
  # Flatten nested structures, then turn each leaf into something that
  # np_dtypes._result_type can consume.
  candidates = list(map(_maybe_get_dtype, nest.flatten(arrays_and_dtypes)))
  if not candidates:
    # No arguments at all: let numpy decide what the dtype is.
    candidates.append(np.asarray([]))
  return np_dtypes._result_type(*candidates)  # pylint: disable=protected-access

510 

511 

def result_type_unary(a, dtype):  # pylint: disable=missing-function-docstring
  """Find the result type from a single input and a dtype.

  Args:
    a: the input value; its type is used when `dtype` is None.
    dtype: an explicitly requested dtype, or None.

  Returns:
    `result_type(dtype)` if `dtype` is given. Otherwise `np.str_` for a `str`
    input, `np.bytes_` for a `bytes` input, and `result_type(a)` for anything
    else.
  """
  # Test against None rather than truthiness. A non-structured `np.dtype`
  # (e.g. `np.dtype('float32')`) has `len() == 0` and is therefore falsy, so
  # `if dtype:` would silently ignore it.
  if dtype is not None:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
    return result_type(dtype)

  # np_utils.result_type treats string inputs as dtype strings, not as strings.
  # but for unary we want to treat it as a string input.
  if isinstance(a, str):
    # `np.str_` rather than `np.unicode_`: the alias was removed in NumPy 2.0
    # (and was identical to `np.str_` before that).
    return np.str_
  elif isinstance(a, bytes):
    return np.bytes_

  # TF and numpy has different interpretations of Python types such as
  # `float`, so we let `np_utils.result_type` decide.
  return result_type(a)

528 

529 

def _result_type_binary(t1, t2):  # pylint: disable=missing-function-docstring
  """A specialization of result_type for 2 arguments for performance reasons.

  If either argument cannot be mapped directly (`_maybe_get_dtype` raises
  `ValueError`, e.g. for a list), the general `result_type` is used instead.
  """
  try:
    dtype1 = _maybe_get_dtype(t1)
    dtype2 = _maybe_get_dtype(t2)
    return np_dtypes._result_type(dtype1, dtype2)  # pylint: disable=protected-access
  except ValueError:
    return result_type(t1, t2)

537 

538 

@np_doc('promote_types')
def promote_types(type1, type2):  # pylint: disable=missing-function-docstring
  # Normalize both operands to numpy types, promote them with numpy's rules,
  # then canonicalize the result for tf-numpy.
  np_type1, np_type2 = _to_numpy_type(type1), _to_numpy_type(type2)
  promoted = np.promote_types(np_type1, np_type2)
  return np_dtypes.canonicalize_dtype(promoted)

544 

545 

def tf_broadcast(*args):
  """Broadcast tensors.

  Args:
    *args: a list of tensors whose shapes are broadcastable against each other.

  Returns:
    With fewer than two arguments, the `args` tuple unchanged. Otherwise a list
    of the tensors broadcast to their common (dynamic) shape.
  """
  if len(args) < 2:
    return args
  first, *rest = args
  common_shape = array_ops.shape(first)
  for other in rest:
    common_shape = array_ops.broadcast_dynamic_shape(
        common_shape, array_ops.shape(other))
  return [array_ops.broadcast_to(t, common_shape) for t in args]

561 

562 

563# TODO(wangpeng): Move the following functions to a separate file and check for 

564# float dtypes in each of them. 

565 

566 

def get_static_value(x):
  """A version of tf.get_static_value that returns None on inexact dtypes.

  It returns None for floating-point and complex Tensors in order to avoid
  breaking gradients.

  Args:
    x: a tensor.

  Returns:
    Same as `tf.get_static_value`, except that it returns None when `x` is a
    Tensor with a floating-point or complex dtype.
  """
  is_inexact_tensor = isinstance(x, core.Tensor) and (
      x.dtype.is_floating or x.dtype.is_complex)
  return None if is_inexact_tensor else tensor_util.constant_value(x)

582 

583 

def _maybe_static(x):
  """Returns the static value of `x` if `get_static_value` finds one, else `x`."""
  static_value = get_static_value(x)
  return x if static_value is None else static_value

590 

591 

# All the following functions exist because get_static_value can't constant-fold
# the results of their TF counterparts.

594 

595 

def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition.

  If `pred` has a static value, only the selected branch is called directly;
  otherwise the choice is deferred to `tf_cond.cond`.
  """
  static_pred = get_static_value(pred)
  if static_pred is None:
    return tf_cond.cond(pred, true_fn, false_fn)
  chosen_fn = true_fn if static_pred else false_fn
  return chosen_fn()

605 

606 

def add(a, b):
  """A version of tf.add that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs + rhs

610 

611 

def subtract(a, b):
  """A version of tf.subtract that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs - rhs

615 

616 

def greater(a, b):
  """A version of tf.greater that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs > rhs

620 

621 

def greater_equal(a, b):
  """A version of tf.greater_equal that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs >= rhs

625 

626 

def less_equal(a, b):
  """A version of tf.less_equal that eagerly evaluates if possible."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs <= rhs

630 

631 

def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible.

  When `a` has a static value for which `np.isscalar` is true, the result
  short-circuits. A falsy `a` is returned as is, and a truthy `a` yields
  (the static value of) `b`. Otherwise the operands are combined with `&`.
  """
  a_static = get_static_value(a)
  if a_static is None:
    return a & _maybe_static(b)
  if not np.isscalar(a_static):
    # Note: np.isscalar is False for 0-d arrays, so they take this path.
    return a_static & _maybe_static(b)
  return _maybe_static(b) if a_static else a_static

645 

646 

def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible.

  When `a` has a static value for which `np.isscalar` is true, the result
  short-circuits. A truthy `a` is returned as is, and a falsy `a` yields
  (the static value of) `b`. Otherwise the operands are combined with `|`.
  """
  a_static = get_static_value(a)
  if a_static is None:
    return a | _maybe_static(b)
  if not np.isscalar(a_static):
    # Note: np.isscalar is False for 0-d arrays, so they take this path.
    return a_static | _maybe_static(b)
  return a_static if a_static else _maybe_static(b)

660 

661 

def getitem(a, slice_spec):
  """A version of __getitem__ that eagerly evaluates if possible."""
  container = _maybe_static(a)
  return container[slice_spec]

665 

666 

def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: the dimensions to reduce; None reduces all of them.
    keepdims: whether to retain reduced dimensions with length 1.

  Returns:
    A NumPy value when `input_tensor` has a static value, otherwise the
    result of `math_ops.reduce_all`.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
  # The static value is not necessarily an ndarray: `tensor_util.constant_value`
  # hands non-Tensor inputs (Python bools, lists, ...) back unchanged, and
  # those have no `.all()` method. `np.all` accepts them, and for an ndarray it
  # is equivalent to `v.all(...)`.
  return np.all(v, axis=axis, keepdims=keepdims)

674 

675 

def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: the dimensions to reduce; None reduces all of them.
    keepdims: whether to retain reduced dimensions with length 1.

  Returns:
    A NumPy value when `input_tensor` has a static value, otherwise the
    result of `math_ops.reduce_any`.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
  # The static value is not necessarily an ndarray: `tensor_util.constant_value`
  # hands non-Tensor inputs (Python bools, lists, ...) back unchanged, and
  # those have no `.any()` method. `np.any` accepts them, and for an ndarray it
  # is equivalent to `v.any(...)`.
  return np.any(v, axis=axis, keepdims=keepdims)

683 

684 

def tf_rank(t):
  """Returns `t.shape.rank` when known statically, else `array_ops.rank(t)`."""
  static_rank = t.shape.rank
  return array_ops.rank(t) if static_rank is None else static_rank