Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/check_ops.py: 35%

666 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=g-short-docstring-punctuation 

16"""Asserts and Boolean Checks.""" 

17 

18import collections 

19 

20import numpy as np 

21 

22from tensorflow.python.eager import context 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.framework import errors 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import sparse_tensor 

27from tensorflow.python.framework import tensor_shape 

28from tensorflow.python.framework import tensor_util 

29from tensorflow.python.ops import array_ops 

30from tensorflow.python.ops import cond 

31from tensorflow.python.ops import control_flow_assert 

32from tensorflow.python.ops import control_flow_ops 

33from tensorflow.python.ops import math_ops 

34from tensorflow.python.util import compat 

35from tensorflow.python.util import deprecation 

36from tensorflow.python.util import dispatch 

37from tensorflow.python.util.tf_export import tf_export 

38 

# Dtypes regarded as numeric: floating point, signed/unsigned integer,
# quantized integer and complex types. Order is irrelevant (frozenset).
NUMERIC_TYPES = frozenset([
    # Floating point.
    dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64,
    # Signed integers.
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    # Unsigned integers.
    dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64,
    # Quantized integers.
    dtypes.qint8, dtypes.qint16, dtypes.qint32, dtypes.quint8, dtypes.quint16,
    # Complex.
    dtypes.complex64, dtypes.complex128,
])

# Names exported by `from check_ops import *`.
__all__ = [
    'assert_negative',
    'assert_positive',
    'assert_proper_iterable',
    'assert_non_negative',
    'assert_non_positive',
    'assert_equal',
    'assert_none_equal',
    'assert_near',
    'assert_integer',
    'assert_less',
    'assert_less_equal',
    'assert_greater',
    'assert_greater_equal',
    'assert_rank',
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_type',
    'assert_shapes',
    'is_non_decreasing',
    'is_numeric_tensor',
    'is_strictly_increasing',
]

71 

72 

def _maybe_constant_value_string(t):
  """Best-effort string form of `t` for static error messages.

  Args:
    t: Any object; typically a string, Python value or `Tensor`.

  Returns:
    `str(t)` when `t` is not a `Tensor`; the string form of the tensor's
    statically known value when `tensor_util.constant_value` can compute it;
    otherwise the tensor `t` itself, unchanged.
  """
  if isinstance(t, ops.Tensor):
    static_value = tensor_util.constant_value(t)
    return t if static_value is None else str(static_value)
  return str(t)

80 

81 

def _assert_static(condition, data):
  """Raises a InvalidArgumentError with as much information as possible.

  Args:
    condition: Value tested for truthiness; nothing happens when it is true.
    data: Iterable of items describing the failure (strings, Python values or
      tensors). Tensors whose value is statically known are rendered as that
      value.

  Raises:
    InvalidArgumentError: If `condition` is falsy. The message is the
      newline-joined string form of every item in `data`.
  """
  if condition:
    return
  # `_maybe_constant_value_string` returns non-constant tensors unchanged.
  # Stringify every item so `str.join` cannot fail with a TypeError and mask
  # the InvalidArgumentError this function exists to raise.
  data_static = [str(_maybe_constant_value_string(x)) for x in data]
  raise errors.InvalidArgumentError(
      node_def=None, op=None, message='\n'.join(data_static))

88 

89 

def _shape_and_dtype_str(tensor):
  """Describes `tensor` as `'shape=<shape> dtype=<dtype name>'`.

  Args:
    tensor: Object exposing `.shape` and `.dtype.name` (e.g. a `Tensor`).

  Returns:
    The formatted description string.
  """
  shape_part = 'shape=' + str(tensor.shape)
  dtype_part = 'dtype=' + str(tensor.dtype.name)
  return ' '.join([shape_part, dtype_part])

93 

94 

def _unary_assert_doc(sym, sym_name):
  """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.

  Args:
    sym: Mathematical symbol for the check performed on each element, i.e. "> 0"
    sym_name: English-language name for the op described by sym

  Returns:
    Decorator that adds the appropriate docstring to the function for symbol
    `sym`.
  """

  def _decorator(func):
    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.

    Args:
      func: Function for a TensorFlow op

    Returns:
      Version of `func` with documentation attached (the same object, with
      `__doc__` replaced).
    """
    opname = func.__name__
    cap_sym_name = sym_name.capitalize()

    # The usage example passes only `x`: unary asserts have no `y` argument.
    func.__doc__ = """
    Assert the condition `x {sym}` holds element-wise.

    When running in graph mode, you should add a dependency on this operation
    to ensure that it runs. Example of adding a dependency to an operation:

    ```python
    with tf.control_dependencies([tf.debugging.{opname}(x)]):
      output = tf.reduce_sum(x)
    ```

    {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
    If `x` is empty this is trivially satisfied.

    Args:
      x: Numeric `Tensor`.
      data: The tensors to print out if the condition is False. Defaults to
        error message and first few entries of `x`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional). Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym}` is False.
      @compatibility(eager)
        returns None
      @end_compatibility

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym}` is False. The check can be performed immediately during
        eager execution or if `x` is statically known.
    """.format(
        sym=sym, sym_name=cap_sym_name, opname=opname)
    return func

  return _decorator

156 

157 

def _binary_assert_doc(sym, test_var):
  """Builds a decorator that documents a v1 element-wise binary assert op.

  Args:
    sym: Binary operation symbol, i.e. "=="
    test_var: String spliced into the usage examples as the value of the
      right-hand operand (e.g. "[1, 2]").

  Returns:
    Decorator that sets `__doc__` on the decorated function (using its
    `__name__` as the op name) and returns the same function.
  """

  def _decorator(func):
    """Attaches the generated docstring to `func` and returns `func`.

    Args:
      func: Function for a TensorFlow op

    Returns:
      `func` itself, with `__doc__` replaced.
    """
    template = """
    Assert the condition `x {sym} y` holds element-wise.

    This condition holds if for every pair of (possibly broadcast) elements
    `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
    If both `x` and `y` are empty, this is trivially satisfied.

    When running in graph mode, you should add a dependency on this operation
    to ensure that it runs. Example of adding a dependency to an operation:

    ```python
    with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
      output = tf.reduce_sum(x)
    ```

    Args:
      x: Numeric `Tensor`.
      y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
      data: The tensors to print out if the condition is False. Defaults to
        error message and first few entries of `x`, `y`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional). Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym} y` is False.

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym} y` is False. The check can be performed immediately during
        eager execution or if `x` and `y` are statically known.

    @compatibility(TF2)
    `tf.compat.v1.{opname}` is compatible with eager execution and
    `tf.function`.
    Please use `tf.debugging.{opname}` instead when migrating to TF2. Apart
    from `data`, all arguments are supported with the same argument name.

    If you want to ensure the assert statements run before the
    potentially-invalid computation, please use `tf.control_dependencies`,
    as tf.function auto-control dependencies are insufficient for assert
    statements.

    #### Structural Mapping to Native TF2

    Before:

    ```python
    tf.compat.v1.{opname}(
      x=x, y=y, data=data, summarize=summarize,
      message=message, name=name)
    ```

    After:

    ```python
    tf.debugging.{opname}(
      x=x, y=y, message=message,
      summarize=summarize, name=name)
    ```

    #### TF1 & TF2 Usage Example

    TF1:

    >>> g = tf.Graph()
    >>> with g.as_default():
    ...   a = tf.compat.v1.placeholder(tf.float32, [2])
    ...   b = tf.compat.v1.placeholder(tf.float32, [2])
    ...   result = tf.compat.v1.{opname}(a, b,
    ...     message='"a {sym} b" does not hold for the given inputs')
    ...   with tf.compat.v1.control_dependencies([result]):
    ...     sum_node = a + b
    >>> sess = tf.compat.v1.Session(graph=g)
    >>> val = sess.run(sum_node, feed_dict={{a: [1, 2], b:{test_var}}})


    TF2:

    >>> a = tf.Variable([1, 2], dtype=tf.float32)
    >>> b = tf.Variable({test_var}, dtype=tf.float32)
    >>> assert_op = tf.debugging.{opname}(a, b, message=
    ...   '"a {sym} b" does not hold for the given inputs')
    >>> # When working with tf.control_dependencies
    >>> with tf.control_dependencies([assert_op]):
    ...   val = a + b

    @end_compatibility
    """
    # `{{` / `}}` in the template are literal braces for the feed_dict example.
    func.__doc__ = template.format(
        test_var=test_var, opname=func.__name__, sym=sym)
    return func

  return _decorator

275 

276 

def _binary_assert_doc_v2(sym, opname, test_var):
  """Common docstring for v2 assert_* ops that compare two tensors element-wise.

  Args:
    sym: Binary operation symbol, i.e. "=="
    opname: Name for the symbol, i.e. "assert_equal"
    test_var: A number used in the docstring example

  Returns:
    Decorator that adds the appropriate docstring to the function for
    symbol `sym`.
  """

  def _decorator(func):
    """Decorator that adds docstring to the function for symbol `sym`.

    Args:
      func: Function for a TensorFlow op

    Returns:
      A version of `func` with documentation attached (the same object, with
      `__doc__` replaced).
    """

    # The Raises section uses `{sym}`; it previously hard-coded `==`, which
    # was wrong for every op other than assert_equal.
    func.__doc__ = """
    Assert the condition `x {sym} y` holds element-wise.

    This Op checks that `x[i] {sym} y[i]` holds for every pair of (possibly
    broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
    trivially satisfied.

    If `x` {sym} `y` does not hold, `message`, as well as the first `summarize`
    entries of `x` and `y` are printed, and `InvalidArgumentError` is raised.

    When using inside `tf.function`, this API takes effect during execution.
    It's recommended to use this API with `tf.control_dependencies` to
    ensure the correct execution order.

    In the following example, without `tf.control_dependencies`, errors may
    not be raised at all.
    Check `tf.control_dependencies` for more details.

    >>> def check_size(x):
    ...   with tf.control_dependencies([
    ...       tf.debugging.{opname}(tf.size(x), {test_var},
    ...                       message='Bad tensor size')]):
    ...     return x

    >>> check_size(tf.ones([2, 3], tf.float32))
    Traceback (most recent call last):
       ...
    InvalidArgumentError: ...

    Args:
      x: Numeric `Tensor`.
      y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
      message: A string to prefix to the default message. (optional)
      summarize: Print this many entries of each tensor. (optional)
      name: A name for this operation (optional). Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym} y` is False. This can
        be used with `tf.control_dependencies` inside of `tf.function`s to
        block followup computation until the check has executed.
      @compatibility(eager)
      returns None
      @end_compatibility

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym} y` is False. The check can be performed immediately during eager
        execution or if `x` and `y` are statically known.
    """.format(
        sym=sym, opname=opname, test_var=test_var)
    return func

  return _decorator

353 

354 

def _make_assert_msg_data(sym, x, y, summarize, test_op):
  """Builds the pieces of the default eager-mode error message for _binary_assert.

  Args:
    sym: Mathematical symbol for the test applied to pairs of tensor elements,
      i.e. "=="
    x: First input to the assertion after applying `convert_to_tensor()`
    y: Second input to the assertion
    summarize: Value of the "summarize" parameter to the original assert_* call;
      tells how many elements of each tensor to print.
    test_op: TensorFlow op that returns a Boolean tensor with True in each
      position where the assertion is satisfied.

  Returns:
    List of strings and numpy arrays that, when stringified and joined,
    produce the error message. Only the header line is returned unless
    `summarize > 0`.
  """
  data = ['Condition x %s y did not hold.' % sym]
  if not summarize > 0:
    return data

  if x.shape == y.shape and x.shape.as_list():
    # Same non-scalar shape: report the failing positions and their values.
    # With mismatched (broadcast) shapes this would be more confusing than
    # useful, so it is skipped.
    failed = math_ops.logical_not(test_op)
    failed_indices = array_ops.where(failed).numpy()
    failed_x = array_ops.boolean_mask(x, failed)
    failed_y = array_ops.boolean_mask(y, failed)
    count = min(summarize, failed_indices.shape[0])
    data.extend([
        'Indices of first %d different values:' % count,
        failed_indices[:count],
        'Corresponding x values:',
        failed_x.numpy().reshape((-1,))[:count],
        'Corresponding y values:',
        failed_y.numpy().reshape((-1,))[:count],
    ])

  # reshape((-1,)) gives a flat view without copying when possible.
  flat_x = x.numpy().reshape((-1,))
  flat_y = y.numpy().reshape((-1,))
  shown_x = min(flat_x.size, summarize)
  shown_y = min(flat_y.size, summarize)
  data.extend([
      'First %d elements of x:' % shown_x,
      flat_x[:shown_x],
      'First %d elements of y:' % shown_y,
      flat_y[:shown_y],
  ])
  return data

407 

408 

def _pretty_print(data_item, summarize):
  """Format a data item for use in an error message in eager mode.

  Args:
    data_item: One of the items in the "data" argument to an assert_* function.
      Can be a Tensor or a scalar value.
    summarize: How many elements to retain of each tensor-valued entry in data.

  Returns:
    `str(data_item)` for non-tensors and scalar tensors; for other tensors, the
    string form of a list holding the first `summarize` flattened elements
    (each stringified), with a trailing '...' entry if any were omitted.
  """
  if not isinstance(data_item, ops.Tensor):
    return str(data_item)

  value = data_item.numpy()
  if np.isscalar(value):
    # Tensor.numpy() yields a numpy scalar for zero-dimensional tensors.
    return str(value)

  flat = value.reshape((-1,))
  pieces = [str(element) for element in flat[:summarize]]
  if flat.size > len(pieces):
    pieces.append('...')
  return str(pieces)

433 

434 

def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc() above.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      i.e. "=="
    opname: Name of the assert op in the public API, i.e. "assert_equal"
    op_func: Function that, if passed the two Tensor inputs to the assertion (x
      and y), will return the test to be passed to reduce_all(), i.e.
      math_ops.equal for assert_equal()
    static_func: Function that, if passed numpy ndarray versions of the two
      inputs to the assertion, will return a Boolean ndarray containing
      True in all positions where the assertion PASSES,
      i.e. np.equal for assert_equal()
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to the value of
      `opname`.

  Returns:
    None in eager mode (the check runs immediately). In graph mode, the op
    returned by `control_flow_assert.Assert`.

  Raises:
    InvalidArgumentError: In eager mode if the check fails, or in graph mode
      if both `x` and `y` have statically known values that fail the check.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Must be an int: it is used as a numpy slice bound in _pretty_print,
        # and numpy rejects float slice indices. _make_assert_msg_data clamps
        # it to the actual sizes of x and y.
        summarize = int(1e9)

      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)

      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join(_pretty_print(d, summarize) for d in data)))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      # Fail at graph-construction time when both inputs are constant.
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      if x_static is not None and y_static is not None:
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_assert.Assert(condition, data, summarize=summarize)

508 

509 

@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  `Ops` that expect iterables of `Tensor` can call this to validate input.
  Useful since `Tensor`, `ndarray`, byte/text type are all iterables themselves.

  Args:
    values:  Object to be checked.

  Raises:
    TypeError:  If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  # Types that technically support iteration but are rejected here.
  rejected_types = (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray)
  rejected_types += compat.bytes_or_text_types

  if isinstance(values, rejected_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable. Found: %s' %
        type(values))

  if hasattr(values, '__iter__'):
    return
  raise TypeError(
      'Expected argument "values" to be iterable. Found: %s' % type(values))

540 

541 

@tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  This Op checks that `x[i] < 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not negative everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # TF2 entry point: forwards to the v1 op, leaving `data` at its default.
  return assert_negative(x=x, summarize=summarize, message=message, name=name)

573 

574 

@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_unary_assert_doc` decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager mode describes `x` by shape/dtype; graph mode by its name.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [
          message,
          'Condition x < 0 did not hold element-wise:',
          'x (%s) = ' % x_desc,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less(x, zero, data=data, summarize=summarize)

594 

595 

@tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  This Op checks that `x[i] > 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not positive everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # TF2 entry point: forwards to the v1 op, leaving `data` at its default.
  return assert_positive(x=x, message=message, summarize=summarize, name=name)

627 

628 

@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_unary_assert_doc` decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      if context.executing_eagerly():
        x_desc = _shape_and_dtype_str(x)
      else:
        x_desc = x.name
      data = [message, 'Condition x > 0 did not hold element-wise:',
              'x (%s) = ' % x_desc, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # `x > 0` is checked as `0 < x`.
    return assert_less(zero, x, data=data, summarize=summarize)

647 

648 

@tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  This Op checks that `x[i] >= 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not >= 0 everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. This can
      be used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # TF2 entry point: forwards to the v1 op, leaving `data` at its default.
  return assert_non_negative(
      x=x, message=message, summarize=summarize, name=name)

682 

683 

@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_unary_assert_doc` decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager mode describes `x` by shape/dtype; graph mode by its name.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [
          message,
          'Condition x >= 0 did not hold element-wise:',
          'x (%s) = ' % x_desc,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # `x >= 0` is checked as `0 <= x`.
    return assert_less_equal(zero, x, data=data, summarize=summarize)

703 

704 

@tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  This Op checks that `x[i] <= 0` holds for every element of `x`. If `x` is
  empty, this is trivially satisfied.

  If `x` is not <= 0 everywhere, `message`, as well as the first `summarize`
  entries of `x` are printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. This can
      be used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # TF2 entry point: forwards to the v1 op, leaving `data` at its default.
  return assert_non_positive(
      x=x, message=message, summarize=summarize, name=name)

738 

739 

@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is attached by the `_unary_assert_doc` decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager mode describes `x` by shape/dtype; graph mode by its name.
      if context.executing_eagerly():
        x_desc = _shape_and_dtype_str(x)
      else:
        x_desc = x.name
      # The condition line and the 'x (...) = ' label are separate entries,
      # matching assert_negative/assert_positive/assert_non_negative. A
      # missing comma previously fused them into one malformed line via
      # implicit string-literal concatenation.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % x_desc, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less_equal(x, zero, data=data, summarize=summarize)

759 

760 

@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('==', 'assert_equal', 3)
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  # Docstring is attached by `_binary_assert_doc_v2`. Forwards to the v1 op
  # with `data` left at its default.
  return assert_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)

767 

768 

@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('==', '[1, 2]')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # Docstring is attached by `_binary_assert_doc`.
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    if x is not y:
      return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal,
                            x, y, data, summarize, message, name)
    # Short-circuit: the same object passed twice is treated as equal without
    # any element-wise comparison.
    if context.executing_eagerly():
      return None
    return control_flow_ops.no_op()

780 

781 

@tf_export('debugging.assert_none_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('!=', 'assert_none_equal', 6)
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  # Docstring is attached by `_binary_assert_doc_v2`. Forwards to the v1 op
  # with `data` left at its default.
  return assert_none_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)

789 

790 

@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=', '[2, 1]')
def assert_none_equal(
    x, y, data=None, summarize=None, message=None, name=None):
  # Docstring is attached by `_binary_assert_doc`.
  return _binary_assert(
      sym='!=',
      opname='assert_none_equal',
      op_func=math_ops.not_equal,
      static_func=np.not_equal,
      x=x,
      y=y,
      data=data,
      summarize=summarize,
      message=message,
      name=name)

800 

801 

@tf_export('debugging.assert_near', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])` holds
  for every pair of (possibly broadcast) elements of `x` and `y`. If both `x`
  and `y` are empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` are `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`. This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
      This can be used with `tf.control_dependencies` inside of `tf.function`s
      to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close for some pair of elements. The check can be
      performed immediately during eager execution or if `x` and `y` are
      statically known.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  # TF2 entry point: forwards to the v1 op, leaving `data` at its default.
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)

855 

856 

@tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```.

  If both `x` and `y` are empty, this is trivially satisfied.

  The default `atol` and `rtol` are both `10 * eps`, where `eps` is the
  smallest representable positive number such that `1 + eps != 1`. The default
  tolerance is therefore about `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and
  `0.00977` in `16bit`. See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance. Default is `10 * eps`.
    atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance. Default is `10 * eps`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Tolerances use the real counterpart of a complex dtype, since they are
    # compared against absolute values.
    dtype = x.dtype
    if dtype.is_complex:
      dtype = dtype.real_dtype
    eps = np.finfo(dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)

    # In eager mode tensors are described by shape and dtype; otherwise the
    # graph tensor name is used.
    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    # Inclusive comparison, as documented above and as in
    # `numpy.testing.assert_allclose`. A strict `<` would make identical
    # inputs fail whenever `atol` and `rtol` are both zero.
    condition = math_ops.reduce_all(math_ops.less_equal(diff, tol))
    return control_flow_assert.Assert(condition, data, summarize=summarize)

939 

940 

@tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('<', 'assert_less', 3)
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards every argument by keyword to the v1 implementation
  # (which additionally accepts `data`).
  return assert_less(
      x=x, y=y, message=message, summarize=summarize, name=name)

947 

948 

@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('<', '[2, 3]')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  # Element-wise `x < y` check, delegated to the shared `_binary_assert`
  # helper together with the matching TF op and NumPy function.
  tf_op, np_op = math_ops.less, np.less
  return _binary_assert('<', 'assert_less', tf_op, np_op,
                        x, y, data, summarize, message, name)

956 

957 

@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('<=', 'assert_less_equal', 3)
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards every argument by keyword to the v1 implementation
  # (which additionally accepts `data`).
  return assert_less_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)

965 

966 

@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=', '[1, 3]')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  # Element-wise `x <= y` check, delegated to the shared `_binary_assert`
  # helper together with the matching TF op and NumPy function.
  tf_op, np_op = math_ops.less_equal, np.less_equal
  return _binary_assert('<=', 'assert_less_equal', tf_op, np_op,
                        x, y, data, summarize, message, name)

975 

976 

@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('>', 'assert_greater', 9)
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards every argument by keyword to the v1 implementation
  # (which additionally accepts `data`).
  return assert_greater(
      x=x, y=y, message=message, summarize=summarize, name=name)

984 

985 

@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('>', '[0, 1]')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # Element-wise `x > y` check, delegated to the shared `_binary_assert`
  # helper together with the matching TF op and NumPy function.
  tf_op, np_op = math_ops.greater, np.greater
  return _binary_assert('>', 'assert_greater', tf_op, np_op,
                        x, y, data, summarize, message, name)

993 

994 

@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('>=', 'assert_greater_equal', 9)
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards every argument by keyword to the v1 implementation
  # (which additionally accepts `data`).
  return assert_greater_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)

1002 

1003 

@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=', '[1, 0]')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  # Element-wise `x >= y` check, delegated to the shared `_binary_assert`
  # helper together with the matching TF op and NumPy function.
  tf_op, np_op = math_ops.greater_equal, np.greater_equal
  return _binary_assert('>=', 'assert_greater_equal', tf_op, np_op,
                        x, y, data, summarize, message, name)

1013 

1014 

def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Checks the rank of `x` against `rank`, statically when possible.

  If both `rank` and the rank of `x` are known at graph-construction time, the
  check is decided immediately and a `no_op` is returned on success. Otherwise
  an `Assert` op evaluating `dynamic_condition` is built.

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor` (must be `int32`).
    static_condition: Python callable `(actual_rank, given_rank) -> bool`,
      used when both ranks are statically known.
    dynamic_condition: Op-building callable `(actual_rank, given_rank)` that
      yields a boolean tensor, used when a rank is only known at run time.
    data: The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    An `Assert` op enforcing `dynamic_condition`, or a `no_op` when the static
    check already succeeded.

  Raises:
    ValueError: `('Rank must be a scalar.',)` if `rank` is statically known and
      not a scalar, or `('Static rank condition failed', actual, given)` if
      the static check fails. Callers match on `args[0]`.
  """
  assert_type(rank, dtypes.int32)

  # Try to determine `rank` at graph-construction time.
  known_rank = tensor_util.constant_value(rank)
  rank_is_static = known_rank is not None

  if rank_is_static:
    if known_rank.ndim != 0:
      raise ValueError('Rank must be a scalar.')
    actual_rank = x.get_shape().ndims
    if actual_rank is not None:
      # Both sides are known: decide now, no runtime assertion needed.
      if not static_condition(actual_rank, known_rank):
        raise ValueError(
            'Static rank condition failed', actual_rank, known_rank)
      return control_flow_ops.no_op(name='static_checks_determined_all_ok')

  condition = dynamic_condition(array_ops.rank(x), rank)

  if not rank_is_static:
    # Guard against `assert_rank(x, [n])` being used instead of
    # `assert_rank(x, n)`: `rank` itself must have rank zero.
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_assert.Assert(condition, data, summarize=summarize)

1061 

1062 

@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  This Op checks that the rank of `x` is equal to `rank`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and an error is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank `rank`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has wrong rank. This includes
      eager execution, where both ranks are always statically known.
  """
  # TF2 wrapper: `data` and `summarize` are left at their defaults.
  return assert_rank(x=x, rank=rank, message=message, name=name)

1095 

1096 

@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor` or `SparseTensor`.
    rank: Scalar integer `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank, or `rank` is
      statically known not to be a scalar.
  """
  with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
    dynamic_condition = math_ops.equal

    # No tensor name is used in messages in eager mode or for SparseTensors.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % name, rank, 'Received shape: ',
          array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      if e.args and e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank %d. Received rank %d, shape %s' %
            (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        # Re-raise the original error (e.g. a non-scalar `rank`) with its full
        # args and traceback, as `assert_rank_at_least`/`assert_rank_in` do.
        raise

  return assert_op

1159 

1160 

@tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  This Op checks that the rank of `x` is greater than or equal to `rank`.

  If the rank of `x` is lower than `rank`, `message` and the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # TF2 wrapper: `data` and `summarize` are left at their defaults.
  forwarded = dict(x=x, rank=rank, message=message, name=name)
  return assert_rank_at_least(**forwarded)

1193 

1194 

@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor` or `SparseTensor`.
    rank: Scalar `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(
      name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    # Accept SparseTensor like `assert_rank` and `assert_rank_in` do; the
    # shape/rank helpers used below are the same ones those functions use.
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
    dynamic_condition = math_ops.greater_equal

    # No tensor name is used in messages in eager mode or for SparseTensors.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      if e.args and e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank at least %d. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op

1260 

1261 

1262def _static_rank_in(actual_rank, given_ranks): 

1263 return actual_rank in given_ranks 

1264 

1265 

def _dynamic_rank_in(actual_rank, given_ranks):
  """Builds a boolean tensor that is True iff `actual_rank` is in `given_ranks`.

  An empty `given_ranks` yields a constant False tensor. Otherwise the result
  is the logical OR of `equal(rank, actual_rank)` over every given rank.
  """
  if len(given_ranks) == 0:  # pylint: disable=g-explicit-length-test
    return ops.convert_to_tensor(False)
  any_match = math_ops.equal(given_ranks[0], actual_rank)
  for candidate in given_ranks[1:]:
    candidate_match = math_ops.equal(candidate, actual_rank)
    any_match = math_ops.logical_or(any_match, candidate_match)
  return any_match

1274 

1275 

def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition against `ranks`.

  Args:
    x: Numeric `Tensor`.
    ranks: Sequence (e.g. tuple) of scalar `int32` `Tensor`s.
    static_condition: A python function that takes
      `[actual_rank, given_ranks]` and returns `True` if the condition is
      satisfied, `False` otherwise. `given_ranks` is a tuple of the statically
      known rank values.
    dynamic_condition: An `op` that takes [actual_rank, given_ranks]
      and return `True` if the condition is satisfied, `False` otherwise.
    data: The tensors to print out if the condition is false. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition, or a
    `no_op` if the static check already succeeded.

  Raises:
    ValueError: If static checks determine `x` fails static_condition
      (args: `'Static rank condition failed', actual_rank, ranks_static`), or
      if a statically known entry of `ranks` is not a scalar.
  """
  for rank in ranks:
    assert_type(rank, dtypes.int32)

  # Attempt to determine all the ranks statically.
  ranks_static = tuple(tensor_util.constant_value(rank) for rank in ranks)
  if not any(r is None for r in ranks_static):
    for rank_static in ranks_static:
      if rank_static.ndim != 0:
        raise ValueError('Rank must be a scalar.')

    x_rank_static = x.get_shape().ndims
    if x_rank_static is not None:
      if not static_condition(x_rank_static, ranks_static):
        raise ValueError(
            'Static rank condition failed', x_rank_static, ranks_static)
      return control_flow_ops.no_op(name='static_checks_determined_all_ok')

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # Add the condition that each `rank` must have rank zero. Prevents the bug
  # where someone does assert_rank(x, [n]), rather than assert_rank(x, n).
  for rank, rank_static in zip(ranks, ranks_static):
    if rank_static is None:
      this_data = ['Rank must be a scalar. Received rank: ', rank]
      rank_check = assert_rank(rank, 0, data=this_data)
      condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_assert.Assert(condition, data, summarize=summarize)

1326 

1327 

@tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  This Op checks that the rank of `x` is one of the values in `ranks`.

  If the rank of `x` is not in `ranks`, `message` and the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # TF2 wrapper: `data` and `summarize` are left at their defaults.
  forwarded = dict(x=x, ranks=ranks, message=message, name=name)
  return assert_rank_in(**forwarded)

1359 

1360 

@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor` or `SparseTensor`.
    ranks: Iterable of scalar `Tensor` objects. May be a one-shot iterator
      such as a generator.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Materialize `ranks` exactly once. It is iterated twice (for the name scope
  # inputs and for tensor conversion), and a generator would otherwise be
  # exhausted by the first pass, leaving no ranks to check against.
  ranks = tuple(ranks)
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + ranks + tuple(data or [])):
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    ranks = tuple(ops.convert_to_tensor(rank, name='rank') for rank in ranks)
    message = _message_prefix(message)

    # No tensor name is used in messages in eager mode or for SparseTensors.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)

    except ValueError as e:
      if e.args and e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank in %s. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op

1424 

1425 

@tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  When the dtype of `x` is not an integer type, an error mentioning `message`
  and the dtype of `x` is raised.

  The dtype is always known statically, so this check happens immediately and
  the function returns nothing.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is not a non-quantized integer type.
  """
  # The v1 implementation returns a no_op; the TF2 endpoint discards it.
  assert_integer(x=x, message=message, name=name)

1445 

1446 

@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is not a non-quantized integer type.

  Returns:
    A `no_op` that does nothing, since the dtype is known statically.
  """
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if x.dtype.is_integer:
      return control_flow_ops.no_op('statically_determined_was_integer')
    # Eager tensors are labelled generically; graph tensors by their name.
    tensor_label = 'tensor' if context.executing_eagerly() else x.name
    raise TypeError(
        '%sExpected "x" to be integer type. Found: %s of dtype %s'
        % (_message_prefix(message), tensor_label, x.dtype))

1484 

1485 

@tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  The dtype is always known statically, so this check happens immediately and
  the function returns nothing.

  Example:

  >>> a = tf.Variable(1.0)
  >>> tf.debugging.assert_type(a, tf_type= tf.float32)

  >>> b = tf.constant(21)
  >>> tf.debugging.assert_type(b, tf_type=tf.bool)
  Traceback (most recent call last):
  ...
  TypeError: ...

  >>> c = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
  ...          dense_shape=[3, 4])
  >>> tf.debugging.assert_type(c, tf_type= tf.int32)

  Args:
    tensor: A `Tensor`, `SparseTensor` or `tf.Variable`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  # The v1 implementation returns a no_op; the TF2 endpoint discards it.
  assert_type(tensor=tensor, tf_type=tf_type, message=message, name=name)

1519 

1520 

@tf_export(v1=['debugging.assert_type', 'assert_type'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` is of the specified type.

  Args:
    tensor: A `Tensor` or `SparseTensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name to give this `Op`. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.

  Returns:
    A `no_op` that does nothing, since the dtype is known statically.
  """
  expected = dtypes.as_dtype(tf_type)
  with ops.name_scope(name, 'assert_type', [tensor]):
    # SparseTensors are checked as-is; everything else becomes a Tensor.
    if not isinstance(tensor, sparse_tensor.SparseTensor):
      tensor = ops.convert_to_tensor(tensor, name='tensor')
    if tensor.dtype != expected:
      prefix = _message_prefix(message)
      # Objects without a `name` attribute are reported as "tensor".
      label = getattr(tensor, "name", "tensor")
      raise TypeError(
          f'{prefix}{label}'
          f' must be of type {expected!r}; got {tensor.dtype!r}')

    return control_flow_ops.no_op('statically_determined_correct_type')

1550 

1551 

def _dimension_sizes(x):
  """Returns the per-dimension sizes of tensor `x`.

  Statically known sizes are returned as Python ints and unknown ones as
  scalar tensors sliced from the dynamic shape. A scalar is treated as a
  rank-1 tensor of size 1.

  Args:
    x: A `Tensor`.

  Returns:
    `(1,)` for a known scalar; a list mixing ints and tensors for a known
    positive rank; otherwise (unknown rank) a tensor that evaluates to `[1]`
    for a scalar and to the dynamic shape otherwise.
  """
  dynamic_shape = array_ops.shape(x)
  static_rank = x.get_shape().rank
  if static_rank is not None:
    if static_rank == 0:
      return (1,)
    if static_rank > 0:
      return [
          dynamic_shape[axis] if dim is None else int(dim)
          for axis, dim in enumerate(x.get_shape().as_list())
      ]
  # Rank not known at graph-construction time: choose at run time.
  is_scalar = math_ops.equal(array_ops.rank(x), 0)
  return cond.cond(
      is_scalar, lambda: array_ops.constant([1]), lambda: dynamic_shape)

1581 

1582 

1583def _symbolic_dimension_sizes(symbolic_shape): 

1584 # If len(symbolic_shape) == 0 construct a tuple 

1585 if not symbolic_shape: 

1586 return tuple([1]) 

1587 

1588 return symbolic_shape 

1589 

1590 

1591def _has_known_value(dimension_size): 

1592 not_none = dimension_size is not None 

1593 try: 

1594 int(dimension_size) 

1595 can_be_parsed_as_int = True 

1596 except (ValueError, TypeError): 

1597 can_be_parsed_as_int = False 

1598 return not_none and can_be_parsed_as_int 

1599 

1600 

1601def _is_symbol_for_any_size(symbol): 

1602 return symbol in [None, '.'] 

1603 

1604 

1605_TensorDimSizes = collections.namedtuple( 

1606 '_TensorDimSizes', 

1607 ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes']) 

1608 

1609 

@tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that the shapes of a collection of tensors satisfy the given
  constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: dictionary with (`Tensor` to shape) items, or a list of
      (`Tensor`, shape) tuples. A shape must be an iterable.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  # TF2 endpoint: runs the v1 implementation and discards its return value.
  assert_shapes(
      shapes,
      data=data,
      summarize=summarize,
      message=message,
      name=name,
  )

1675 

1676 

@tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that the shapes of a collection of tensors satisfy the given
  constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
      expected shape of `Tensor`. See the example code above. The `shape` must
      be an iterable. Each element of the iterable can be either a concrete
      integer value or a string that abstractly represents the dimension.
      For example,
        - `('N', 'Q')` specifies a 2D shape wherein the first and second
          dimensions of shape may or may not be equal.
        - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
          dimensions are equal.
        - `(1, 'N')` specifies a 2D shape wherein the first dimension is
          exactly 1 and the second dimension can be any value.
      Note that the abstract dimension letters take effect across different
      tuple elements of the list. For example,
      `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))])` asserts
      that both `x` and `y` are rank-2 tensors and their first dimensions are
      equal (`N`).
      `shape` can also be a `tf.TensorShape`.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  # If the user manages to assemble a dict containing tensors (possible in
  # Graph mode only), make sure we still accept that.
  if isinstance(shapes, dict):
    shapes = shapes.items()

  message_prefix = _message_prefix(message)
  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
    # Shape specified as None implies no constraint
    # SparseTensors are kept as-is; everything else is converted to a Tensor.
    shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
                          ops.convert_to_tensor(x), s)
                         for x, s in shapes if s is not None]

    executing_eagerly = context.executing_eagerly()

    # Human-readable identifier for error messages: eager tensors and
    # SparseTensors are described by shape/dtype, graph tensors by `.name`.
    def tensor_name(x):
      if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
        return _shape_and_dtype_str(x)
      return x.name

    # Pass 1: validate each specified shape and record, per tensor, whether it
    # uses a leading `...`/`*`, its actual sizes and its symbolic sizes.
    tensor_dim_sizes = []
    for tensor, symbolic_shape in shape_constraints:
      is_iterable = (
          hasattr(symbolic_shape, '__iter__') or
          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
      )
      if not is_iterable:
        raise ValueError(
            '%s'
            'Tensor %s. Specified shape must be an iterable. '
            'An iterable has the attribute `__iter__` or `__getitem__`. '
            'Received specified shape: %s' %
            (message_prefix, tensor_name(tensor), symbolic_shape))

      # We convert this into a tuple to handle strings, lists and numpy arrays
      symbolic_shape_tuple = tuple(symbolic_shape)

      # `...`/`*` is only legal at index 0; anywhere else is an error.
      tensors_specified_innermost = False
      for i, symbol in enumerate(symbolic_shape_tuple):
        if symbol not in [Ellipsis, '*']:
          continue

        if i != 0:
          raise ValueError(
              '%s'
              'Tensor %s specified shape index %d. '
              'Symbol `...` or `*` for a variable number of '
              'unspecified dimensions is only allowed as the first entry' %
              (message_prefix, tensor_name(tensor), i))

        tensors_specified_innermost = True

      # Only include the size of the specified dimensions since the 0th symbol
      # is either ellipsis or *
      tensor_dim_sizes.append(
          _TensorDimSizes(
              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
              _symbolic_dimension_sizes(
                  symbolic_shape_tuple[1:]
                  if tensors_specified_innermost else symbolic_shape_tuple)))

    # Pass 2: build rank assertions. They must precede any size slicing below,
    # which is guarded on them via control dependencies.
    rank_assertions = []
    for sizes in tensor_dim_sizes:
      rank = len(sizes.symbolic_sizes)
      rank_zero_or_one = rank in [0, 1]
      if sizes.unspecified_dim:
        if rank_zero_or_one:
          # No assertion of rank needed as `x` only need to have rank at least
          # 0. See elif rank_zero_or_one case comment.
          continue
        assertion = assert_rank_at_least(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      elif rank_zero_or_one:
        # Rank 0 is treated as rank 1 size 1, i.e. there is
        # no distinction between the two in terms of rank.
        # See _dimension_sizes.
        assertion = assert_rank_in(
            x=sizes.x,
            ranks=[0, 1],
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      else:
        assertion = assert_rank(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      rank_assertions.append(assertion)

    # Pass 3: size checks. `size_specifications` maps a symbolic size to the
    # (size, tensor, dimension) of its first occurrence; later occurrences are
    # checked against it. Known sizes are checked statically when possible,
    # otherwise a runtime Assert is emitted.
    size_assertions = []
    size_specifications = {}
    for sizes in tensor_dim_sizes:
      for i, size_symbol in enumerate(sizes.symbolic_sizes):

        if _is_symbol_for_any_size(size_symbol):
          # Size specified as any implies no constraint
          continue

        # With a leading `...`/`*`, index from the end so the constraint lines
        # up with the inner-most dimensions of the tensor.
        if sizes.unspecified_dim:
          tensor_dim = i - len(sizes.symbolic_sizes)
        else:
          tensor_dim = i

        if size_symbol in size_specifications or _has_known_value(size_symbol):
          if _has_known_value(size_symbol):
            specified_size = int(size_symbol)
            size_check_message = 'Specified explicitly'
          else:
            specified_size, specified_by_y, specified_at_dim = (
                size_specifications[size_symbol])
            size_check_message = (
                'Specified by tensor %s dimension %d' %
                (tensor_name(specified_by_y), specified_at_dim))

          # This is extremely subtle. If actual_sizes is dynamic, we must
          # make sure a control dependency is inserted here so that this slice
          # can not execute until the rank is asserted to be enough for the
          # slice to not fail.
          with ops.control_dependencies(rank_assertions):
            actual_size = sizes.actual_sizes[tensor_dim]
          if _has_known_value(actual_size) and _has_known_value(specified_size):
            if int(actual_size) != int(specified_size):
              raise ValueError(
                  '%s%s. Tensor %s dimension %s must have size %d. '
                  'Received size %d, shape %s' %
                  (message_prefix, size_check_message, tensor_name(sizes.x),
                   tensor_dim, specified_size, actual_size,
                   sizes.x.get_shape()))
            # No dynamic assertion needed
            continue

          condition = math_ops.equal(
              ops.convert_to_tensor(actual_size),
              ops.convert_to_tensor(specified_size))
          data_ = data
          if data is None:
            data_ = [
                message_prefix, size_check_message,
                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
                'must have size', specified_size, 'Received shape: ',
                array_ops.shape(sizes.x)
            ]
          size_assertions.append(
              control_flow_assert.Assert(condition, data_, summarize=summarize))
        else:
          # First occurrence of this symbol: remember where its size comes from.
          # Not sure if actual_sizes is a constant, but for safety, guard
          # on rank. See explanation above about actual_sizes need for safety.
          with ops.control_dependencies(rank_assertions):
            size = sizes.actual_sizes[tensor_dim]
          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)

    # Ensure both assertions actually occur.
    with ops.control_dependencies(rank_assertions):
      shapes_assertion = control_flow_ops.group(size_assertions)

    return shapes_assertion

1931 

1932 

# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Gets the difference x[1:] - x[:-1] of the flattened `x`.

  `x` is reshaped to a vector first, so this yields the differences between
  consecutive elements of that vector. When it has fewer than two elements,
  an empty vector of the same dtype is returned instead.

  Args:
    x: `Tensor` (or value accepted by `array_ops.reshape`).

  Returns:
    1-D `Tensor` of consecutive differences.

  Raises:
    TypeError: if the flattened `x` is not a numeric tensor.
  """
  flat_x = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(flat_x):
    raise TypeError('Expected x to be numeric, instead found: %s' % flat_x)

  # Decided at run time: with fewer than two elements there are no pairs.
  too_short = math_ops.less(array_ops.size(flat_x), 2)

  def _no_pairs():
    return ops.convert_to_tensor([], dtype=flat_x.dtype)

  # Number of adjacent pairs, as a 1-element shape vector.
  pair_count = array_ops.shape(flat_x) - 1

  def _pairwise_diff():
    later = array_ops.strided_slice(flat_x, [1], [1] + pair_count)
    earlier = array_ops.strided_slice(flat_x, [0], pair_count)
    return later - earlier

  return cond.cond(too_short, _no_pairs, _pairwise_diff)

1948 

1949 

@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if the elements of `tensor` are numbers.

  Specifically, returns `True` if the dtype of `tensor` is one of the following:

  * `tf.float16`
  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.uint16`
  * `tf.uint32`
  * `tf.uint64`
  * `tf.qint8`
  * `tf.qint16`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.quint16`
  * `tf.complex64`
  * `tf.complex128`
  * `tf.bfloat16`

  Returns `False` if `tensor` has a non-numeric dtype, or if `tensor` is not a
  `tf.Tensor` object at all.
  """
  if not isinstance(tensor, ops.Tensor):
    # Anything that is not a `tf.Tensor` is rejected, whatever its contents.
    return False
  return tensor.dtype in NUMERIC_TYPES

1983 

1984 

@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
  If `x` has fewer than two elements, it is trivially non-decreasing.

  See also:  `is_strictly_increasing`

  >>> x1 = tf.constant([1.0, 1.0, 3.0])
  >>> tf.math.is_non_decreasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_non_decreasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).  Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    deltas = _get_diff_for_monotonic_comparison(x)
    zero = ops.convert_to_tensor(0, dtype=deltas.dtype)
    # Every step must satisfy 0 <= x[i+1] - x[i]. For fewer than two elements
    # `deltas` is empty and reduce_all([]) is True, so no special case needed.
    steps_ok = math_ops.less_equal(zero, deltas)
    return math_ops.reduce_all(steps_ok)

2025 

2026 

@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
  If `x` has fewer than two elements, it is trivially strictly increasing.

  See also:  `is_non_decreasing`

  >>> x1 = tf.constant([1.0, 2.0, 3.0])
  >>> tf.math.is_strictly_increasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_strictly_increasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    deltas = _get_diff_for_monotonic_comparison(x)
    zero = ops.convert_to_tensor(0, dtype=deltas.dtype)
    # Every step must satisfy 0 < x[i+1] - x[i]. For fewer than two elements
    # `deltas` is empty and reduce_all([]) is True, so no special case needed.
    steps_ok = math_ops.less(zero, deltas)
    return math_ops.reduce_all(steps_ok)

2068 

2069 

2070def _assert_same_base_type(items, expected_type=None): 

2071 r"""Asserts all items are of the same base type. 

2072 

2073 Args: 

2074 items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, 

2075 `Operation`, or `IndexedSlices`). Can include `None` elements, which 

2076 will be ignored. 

2077 expected_type: Expected type. If not specified, assert all items are 

2078 of the same base type. 

2079 

2080 Returns: 

2081 Validated type, or none if neither expected_type nor items provided. 

2082 

2083 Raises: 

2084 ValueError: If any types do not match. 

2085 """ 

2086 original_expected_type = expected_type 

2087 mismatch = False 

2088 for item in items: 

2089 if item is not None: 

2090 item_type = item.dtype.base_dtype 

2091 if not expected_type: 

2092 expected_type = item_type 

2093 elif expected_type != item_type: 

2094 mismatch = True 

2095 break 

2096 if mismatch: 

2097 # Loop back through and build up an informative error message (this is very 

2098 # slow, so we don't do it unless we found an error above). 

2099 expected_type = original_expected_type 

2100 original_item_str = None 

2101 for item in items: 

2102 if item is not None: 

2103 item_type = item.dtype.base_dtype 

2104 if not expected_type: 

2105 expected_type = item_type 

2106 original_item_str = item.name if hasattr(item, 'name') else str(item) 

2107 elif expected_type != item_type: 

2108 raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( 

2109 item.name if hasattr(item, 'name') else str(item), 

2110 item_type, expected_type, 

2111 (' as %s' % original_item_str) if original_item_str else '')) 

2112 return expected_type # Should be unreachable 

2113 else: 

2114 return expected_type 

2115 

2116 

@tf_export(
    'debugging.assert_same_float_dtype',
    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
  """Validate and return float type based on `tensors` and `dtype`.

  For ops such as matrix multiplication, inputs and weights must be of the
  same float type. This function validates that all `tensors` are the same type,
  validates that type is `dtype` (if supplied), and returns the type. Type must
  be a floating point type. If neither `tensors` nor `dtype` is supplied,
  the function will return `dtypes.float32`.

  Args:
    tensors: Tensors of input values. Can include `None` elements, which will be
        ignored.
    dtype: Expected type.

  Returns:
    Validated type, or `dtypes.float32` if no type could be determined.

  Raises:
    ValueError: if the `tensors` do not all share one base type, if that type
      differs from a supplied `dtype`, or if the resulting type is not a
      floating point type.
  """
  if tensors:
    dtype = _assert_same_base_type(tensors, dtype)
  if not dtype:
    # Nothing to infer from: fall back to the default float type.
    return dtypes.float32
  if not dtype.is_floating:
    raise ValueError('Expected floating point type, got %s.' % dtype)
  return dtype

2150 

2151 

@tf_export('debugging.assert_scalar', v1=[])
@dispatch.add_dispatch_support
def assert_scalar_v2(tensor, message=None, name=None):
  """Asserts that the given `tensor` is a scalar.

  Raises `ValueError` unless it can be statically established that `tensor`
  is a scalar; a `tensor` whose shape is unknown is rejected as well.

  Because the check happens entirely at trace/call time, nothing is returned.

  Args:
    tensor: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_scalar"

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  # The v1 function either raises or hands back the converted tensor; the
  # latter is intentionally discarded by this v2 endpoint.
  assert_scalar(tensor=tensor, name=name, message=message)

2173 

2174 

@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
  """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).

  Raises `ValueError` unless the static shape of `tensor` proves it is a
  scalar; a `tensor` whose rank is unknown is rejected as well.

  Args:
    tensor: A `Tensor`.
    name:  A name for this operation. Defaults to "assert_scalar"
    message: A string to prefix to the default message.

  Returns:
    The input tensor (potentially converted to a `Tensor`).

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
    tensor = ops.convert_to_tensor(tensor, name=name_scope)
    shape = tensor.get_shape()
    message = _message_prefix(message)
    if shape.ndims == 0:
      return tensor
    # Unknown rank (ndims is None) also ends up here. Eager tensors have no
    # meaningful name, so only graph-mode messages mention it.
    if context.executing_eagerly():
      raise ValueError('%sExpected scalar shape, saw shape: %s.'
                       % (message, shape))
    raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
                     % (message, tensor.name, shape))

2209 

2210 

2211def _message_prefix(message): 

2212 if message: 

2213 return '%s. ' % message 

2214 return '' 

2215 

2216 

@tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None):
  """Updates the shape of a tensor and checks at runtime that the shape holds.

  When executed, this operation asserts that the input tensor `x`'s shape
  is compatible with the `shape` argument.
  See `tf.TensorShape.is_compatible_with` for details.

  >>> x = tf.constant([[1, 2, 3],
  ...                  [4, 5, 6]])
  >>> x = tf.ensure_shape(x, [2, 3])

  Use `None` for unknown dimensions:

  >>> x = tf.ensure_shape(x, [None, 3])
  >>> x = tf.ensure_shape(x, [2, None])

  If the tensor's shape is not compatible with the `shape` argument, an error
  is raised:

  >>> x = tf.ensure_shape(x, [5])
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
    compatible with expected shape [5]. [Op:EnsureShape]

  During graph construction (typically tracing a `tf.function`),
  `tf.ensure_shape` updates the static-shape of the **result** tensor by
  merging the two shapes. See `tf.TensorShape.merge_with` for details.

  This is most useful when **you** know a shape that can't be determined
  statically by TensorFlow.

  The following trivial `tf.function` prints the input tensor's
  static-shape before and after `ensure_shape` is applied.

  >>> @tf.function
  ... def f(tensor):
  ...   print("Static-shape before:", tensor.shape)
  ...   tensor = tf.ensure_shape(tensor, [None, 3])
  ...   print("Static-shape after:", tensor.shape)
  ...   return tensor

  This lets you see the effect of `tf.ensure_shape` when the function is traced:

  >>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
  Static-shape before: (None, None)
  Static-shape after: (None, 3)

  >>> cf(tf.zeros([3, 3])) # Passes
  >>> cf(tf.constant([1, 2, 3])) # fails
  Traceback (most recent call last):
  ...
  InvalidArgumentError:  Shape of tensor x [3] is not compatible with expected shape [3,3].

  The above example raises `tf.errors.InvalidArgumentError`, because `x`'s
  shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)`.

  Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
  runtime shapes. This is stricter than `tf.Tensor.set_shape` which only
  checks the buildtime shape.

  Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape
  of the resulting tensor and enforces it at runtime, raising an error if the
  tensor's runtime shape is incompatible with the specified shape.
  `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it
  at runtime, which may result in inconsistencies between the statically-known
  shape of tensors and the runtime value of tensors.

  For example, when loading images of a known size:

  >>> @tf.function
  ... def decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image = tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  When tracing a function, no ops are being executed, shapes may be unknown.
  See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
  for details.

  >>> concrete_decode = decode_image.get_concrete_function(
  ...     tf.TensorSpec([], dtype=tf.string))
  Initial shape:  (None, None, 3)
  Final shape:  (28, 28, 3)

  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
  >>> image = tf.cast(image,tf.uint8)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  >>> print(image2.shape)
  (28, 28, 3)

  >>> image = tf.concat([image,image], axis=0)
  >>> print(image.shape)
  (56, 28, 3)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError:  Shape of tensor DecodePng [56,28,3] is not
    compatible with expected shape [28,28,3].

  Caution: if you don't use the result of `tf.ensure_shape` the check may not
  run.

  >>> @tf.function
  ... def bad_decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   # BAD: forgot to use the returned tensor.
  ...   tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  >>> image = bad_decode_image(png)
  Initial shape:  (None, None, 3)
  Final shape:  (None, None, 3)
  >>> print(image.shape)
  (56, 28, 3)

  Args:
    x: A `Tensor`.
    shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`.

  Raises:
    tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape
    of `x`.
  """
  # `as_shape` passes a `TensorShape` through untouched and wraps anything
  # else (list, tuple, proto, None) in a new `TensorShape`.
  expected_shape = tensor_shape.as_shape(shape)
  return array_ops.ensure_shape(x, expected_shape, name=name)

2359 

2360 

@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient of `EnsureShape`: the incoming gradient, unchanged.

  `ensure_shape` returns a tensor with the same type and contents as its
  input, so the gradient passes straight through.
  """
  # Nothing about the forward op is needed to compute the gradient.
  del op
  passthrough = grad
  return passthrough