Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/check_ops.py: 35%
666 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-short-docstring-punctuation
16"""Asserts and Boolean Checks."""
18import collections
20import numpy as np
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import cond
31from tensorflow.python.ops import control_flow_assert
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.util import compat
35from tensorflow.python.util import deprecation
36from tensorflow.python.util import dispatch
37from tensorflow.python.util.tf_export import tf_export
# The set of dtypes this module treats as numeric, grouped by family.
NUMERIC_TYPES = frozenset((
    # Floating point.
    dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
    # Signed integers.
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    # Unsigned integers.
    dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64,
    # Quantized integers.
    dtypes.qint8, dtypes.qint16, dtypes.qint32, dtypes.quint8, dtypes.quint16,
    # Complex.
    dtypes.complex64, dtypes.complex128,
))

# Public names exported by `from check_ops import *`.
__all__ = [
    'assert_negative',
    'assert_positive',
    'assert_proper_iterable',
    'assert_non_negative',
    'assert_non_positive',
    'assert_equal',
    'assert_none_equal',
    'assert_near',
    'assert_integer',
    'assert_less',
    'assert_less_equal',
    'assert_greater',
    'assert_greater_equal',
    'assert_rank',
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_type',
    'assert_shapes',
    'is_non_decreasing',
    'is_numeric_tensor',
    'is_strictly_increasing',
]
def _maybe_constant_value_string(t):
  """Returns a printable form of `t`, using its constant value when known.

  Args:
    t: A `Tensor` or any other object.

  Returns:
    `str(t)` for non-Tensors, `str(<constant value>)` for Tensors whose value
    `tensor_util.constant_value` can determine, and otherwise `t` itself
    (the Tensor is returned unchanged, not stringified).
  """
  if isinstance(t, ops.Tensor):
    const_t = tensor_util.constant_value(t)
    return t if const_t is None else str(const_t)
  return str(t)
def _assert_static(condition, data):
  """Raises a InvalidArgumentError with as much information as possible.

  Args:
    condition: Statically evaluated truth value of the assertion.
    data: Items describing the failure; each is rendered via
      `_maybe_constant_value_string` and joined with newlines.

  Raises:
    InvalidArgumentError: if `condition` is falsy.
  """
  if condition:
    return
  # _maybe_constant_value_string() hands back non-constant Tensors unchanged,
  # so stringify every item; otherwise str.join() would raise TypeError and
  # hide the InvalidArgumentError this function exists to raise.
  data_static = [str(_maybe_constant_value_string(x)) for x in data]
  raise errors.InvalidArgumentError(node_def=None, op=None,
                                    message='\n'.join(data_static))
90def _shape_and_dtype_str(tensor):
91 """Returns a string containing tensor's shape and dtype."""
92 return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
95def _unary_assert_doc(sym, sym_name):
96 """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.
98 Args:
99 sym: Mathematical symbol for the check performed on each element, i.e. "> 0"
100 sym_name: English-language name for the op described by sym
102 Returns:
103 Decorator that adds the appropriate docstring to the function for symbol
104 `sym`.
105 """
107 def _decorator(func):
108 """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
110 Args:
111 func: Function for a TensorFlow op
113 Returns:
114 Version of `func` with documentation attached.
115 """
116 opname = func.__name__
117 cap_sym_name = sym_name.capitalize()
119 func.__doc__ = """
120 Assert the condition `x {sym}` holds element-wise.
122 When running in graph mode, you should add a dependency on this operation
123 to ensure that it runs. Example of adding a dependency to an operation:
125 ```python
126 with tf.control_dependencies([tf.debugging.{opname}(x, y)]):
127 output = tf.reduce_sum(x)
128 ```
130 {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
131 If `x` is empty this is trivially satisfied.
133 Args:
134 x: Numeric `Tensor`.
135 data: The tensors to print out if the condition is False. Defaults to
136 error message and first few entries of `x`.
137 summarize: Print this many entries of each tensor.
138 message: A string to prefix to the default message.
139 name: A name for this operation (optional). Defaults to "{opname}".
141 Returns:
142 Op that raises `InvalidArgumentError` if `x {sym}` is False.
143 @compatibility(eager)
144 returns None
145 @end_compatibility
147 Raises:
148 InvalidArgumentError: if the check can be performed immediately and
149 `x {sym}` is False. The check can be performed immediately during
150 eager execution or if `x` is statically known.
151 """.format(
152 sym=sym, sym_name=cap_sym_name, opname=opname)
153 return func
155 return _decorator
158def _binary_assert_doc(sym, test_var):
159 """Common docstring for most of the v1 assert_* ops that compare two tensors element-wise.
161 Args:
162 sym: Binary operation symbol, i.e. "=="
163 test_var: a string that represents the variable in the right-hand side of
164 binary operator of the test case
166 Returns:
167 Decorator that adds the appropriate docstring to the function for
168 symbol `sym`.
169 """
171 def _decorator(func):
172 """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
174 Args:
175 func: Function for a TensorFlow op
177 Returns:
178 A version of `func` with documentation attached.
179 """
180 opname = func.__name__
182 func.__doc__ = """
183 Assert the condition `x {sym} y` holds element-wise.
185 This condition holds if for every pair of (possibly broadcast) elements
186 `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
187 If both `x` and `y` are empty, this is trivially satisfied.
189 When running in graph mode, you should add a dependency on this operation
190 to ensure that it runs. Example of adding a dependency to an operation:
192 ```python
193 with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
194 output = tf.reduce_sum(x)
195 ```
197 Args:
198 x: Numeric `Tensor`.
199 y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
200 data: The tensors to print out if the condition is False. Defaults to
201 error message and first few entries of `x`, `y`.
202 summarize: Print this many entries of each tensor.
203 message: A string to prefix to the default message.
204 name: A name for this operation (optional). Defaults to "{opname}".
206 Returns:
207 Op that raises `InvalidArgumentError` if `x {sym} y` is False.
209 Raises:
210 InvalidArgumentError: if the check can be performed immediately and
211 `x {sym} y` is False. The check can be performed immediately during
212 eager execution or if `x` and `y` are statically known.
214 @compatibility(TF2)
215 `tf.compat.v1.{opname}` is compatible with eager execution and
216 `tf.function`.
217 Please use `tf.debugging.{opname}` instead when migrating to TF2. Apart
218 from `data`, all arguments are supported with the same argument name.
220 If you want to ensure the assert statements run before the
221 potentially-invalid computation, please use `tf.control_dependencies`,
222 as tf.function auto-control dependencies are insufficient for assert
223 statements.
225 #### Structural Mapping to Native TF2
227 Before:
229 ```python
230 tf.compat.v1.{opname}(
231 x=x, y=y, data=data, summarize=summarize,
232 message=message, name=name)
233 ```
235 After:
237 ```python
238 tf.debugging.{opname}(
239 x=x, y=y, message=message,
240 summarize=summarize, name=name)
241 ```
243 #### TF1 & TF2 Usage Example
245 TF1:
247 >>> g = tf.Graph()
248 >>> with g.as_default():
249 ... a = tf.compat.v1.placeholder(tf.float32, [2])
250 ... b = tf.compat.v1.placeholder(tf.float32, [2])
251 ... result = tf.compat.v1.{opname}(a, b,
252 ... message='"a {sym} b" does not hold for the given inputs')
253 ... with tf.compat.v1.control_dependencies([result]):
254 ... sum_node = a + b
255 >>> sess = tf.compat.v1.Session(graph=g)
256 >>> val = sess.run(sum_node, feed_dict={{a: [1, 2], b:{test_var}}})
259 TF2:
261 >>> a = tf.Variable([1, 2], dtype=tf.float32)
262 >>> b = tf.Variable({test_var}, dtype=tf.float32)
263 >>> assert_op = tf.debugging.{opname}(a, b, message=
264 ... '"a {sym} b" does not hold for the given inputs')
265 >>> # When working with tf.control_dependencies
266 >>> with tf.control_dependencies([assert_op]):
267 ... val = a + b
269 @end_compatibility
270 """.format(
271 sym=sym, opname=opname, test_var=test_var)
272 return func
274 return _decorator
277def _binary_assert_doc_v2(sym, opname, test_var):
278 """Common docstring for v2 assert_* ops that compare two tensors element-wise.
280 Args:
281 sym: Binary operation symbol, i.e. "=="
282 opname: Name for the symbol, i.e. "assert_equal"
283 test_var: A number used in the docstring example
285 Returns:
286 Decorator that adds the appropriate docstring to the function for
287 symbol `sym`.
288 """
290 def _decorator(func):
291 """Decorator that adds docstring to the function for symbol `sym`.
293 Args:
294 func: Function for a TensorFlow op
296 Returns:
297 A version of `func` with documentation attached.
298 """
300 func.__doc__ = """
301 Assert the condition `x {sym} y` holds element-wise.
303 This Op checks that `x[i] {sym} y[i]` holds for every pair of (possibly
304 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
305 trivially satisfied.
307 If `x` {sym} `y` does not hold, `message`, as well as the first `summarize`
308 entries of `x` and `y` are printed, and `InvalidArgumentError` is raised.
310 When using inside `tf.function`, this API takes effects during execution.
311 It's recommended to use this API with `tf.control_dependencies` to
312 ensure the correct execution order.
314 In the following example, without `tf.control_dependencies`, errors may
315 not be raised at all.
316 Check `tf.control_dependencies` for more details.
318 >>> def check_size(x):
319 ... with tf.control_dependencies([
320 ... tf.debugging.{opname}(tf.size(x), {test_var},
321 ... message='Bad tensor size')]):
322 ... return x
324 >>> check_size(tf.ones([2, 3], tf.float32))
325 Traceback (most recent call last):
326 ...
327 InvalidArgumentError: ...
329 Args:
330 x: Numeric `Tensor`.
331 y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
332 message: A string to prefix to the default message. (optional)
333 summarize: Print this many entries of each tensor. (optional)
334 name: A name for this operation (optional). Defaults to "{opname}".
336 Returns:
337 Op that raises `InvalidArgumentError` if `x {sym} y` is False. This can
338 be used with `tf.control_dependencies` inside of `tf.function`s to
339 block followup computation until the check has executed.
340 @compatibility(eager)
341 returns None
342 @end_compatibility
344 Raises:
345 InvalidArgumentError: if the check can be performed immediately and
346 `x == y` is False. The check can be performed immediately during eager
347 execution or if `x` and `y` are statically known.
348 """.format(
349 sym=sym, opname=opname, test_var=test_var)
350 return func
352 return _decorator
def _make_assert_msg_data(sym, x, y, summarize, test_op):
  """Builds the default eager-mode error message pieces for _binary_assert.

  Args:
    sym: Comparison symbol of the failed test, e.g. "==".
    x: First assertion input, already passed through `convert_to_tensor()`.
    y: Second assertion input.
    summarize: Maximum number of elements of each tensor to include.
    test_op: Boolean tensor that is True wherever the assertion holds.

  Returns:
    A list of strings and numpy arrays; stringified and joined, they form the
    error message.
  """
  data = ['Condition x %s y did not hold.' % sym]
  if not summarize > 0:
    return data

  if x.shape == y.shape and x.shape.as_list():
    # Same, non-scalar shapes: report the positions that failed along with
    # the offending values. With differing shapes this is more confusing
    # than useful, so it is skipped.
    mask = math_ops.logical_not(test_op)
    indices_np = array_ops.where(mask).numpy()
    x_vals = array_ops.boolean_mask(x, mask)
    y_vals = array_ops.boolean_mask(y, mask)
    num_vals = min(summarize, indices_np.shape[0])
    data.extend([
        'Indices of first %d different values:' % num_vals,
        indices_np[:num_vals],
        'Corresponding x values:',
        x_vals.numpy().reshape((-1,))[:num_vals],
        'Corresponding y values:',
        y_vals.numpy().reshape((-1,))[:num_vals],
    ])

  # reshape((-1,)) is the fastest way to get a flat array view.
  x_np = x.numpy().reshape((-1,))
  y_np = y.numpy().reshape((-1,))
  x_sum = min(x_np.size, summarize)
  y_sum = min(y_np.size, summarize)
  data.extend([
      'First %d elements of x:' % x_sum,
      x_np[:x_sum],
      'First %d elements of y:' % y_sum,
      y_np[:y_sum],
  ])
  return data
def _pretty_print(data_item, summarize):
  """Renders one entry of an assert_* `data` list for an eager error message.

  Args:
    data_item: A `Tensor` or any other value from the `data` argument.
    summarize: Maximum number of elements of a non-scalar Tensor to show.

  Returns:
    `str(data_item)` for non-Tensors and scalar Tensors; for other Tensors,
    the string form of a list of the first `summarize` flattened elements,
    ending with '...' if any were left out.
  """
  if not isinstance(data_item, ops.Tensor):
    return str(data_item)

  arr = data_item.numpy()
  if np.isscalar(arr):
    # Tensor.numpy() returns a scalar for zero-dimensional tensors
    return str(arr)

  flat = arr.reshape((-1,))
  shown = [str(elem) for elem in flat[:summarize]]
  if len(shown) < flat.size:
    shown.append('...')
  return str(shown)
def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc() above.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      e.g. "=="
    opname: Name of the assert op in the public API, e.g. "assert_equal"
    op_func: Function that, if passed the two Tensor inputs to the assertion (x
      and y), will return the Boolean test to be passed to reduce_all(), e.g.
      math_ops.equal for assert_equal().
    static_func: Function that, if passed numpy ndarray versions of the two
      inputs to the assertion, will return a Boolean ndarray containing True
      in all positions where the assertion PASSES, e.g. np.equal for
      assert_equal().
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor. In eager mode `None`
      means 3 and a negative value means "print everything".
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to the value of
      `opname`.

  Returns:
    See docstring template in _binary_assert_doc(). In eager mode this returns
    None when the assertion holds.

  Raises:
    InvalidArgumentError: in eager mode if the assertion fails, or in graph
      mode if both inputs are statically known and the assertion fails.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Must be an int: _pretty_print() uses it as a numpy slice bound, and
        # numpy rejects float slice indices (1e9 raised TypeError there).
        # Slicing and the min() calls in _make_assert_msg_data clamp it to
        # the real sizes of x and y.
        summarize = 10**9

      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)

      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join(_pretty_print(d, summarize) for d in data)))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      # When both inputs are compile-time constants, fail right away instead
      # of deferring the check to graph execution.
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      if x_static is not None and y_static is not None:
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_assert.Assert(condition, data, summarize=summarize)
@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  `Ops` that expect iterables of `Tensor` can call this to validate input.
  This matters because `Tensor`, `ndarray` and byte/text types are all
  iterable themselves, yet are almost never what such an op wants.

  Args:
    values: Object to be checked.

  Raises:
    TypeError: If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  # Types that are technically iterable but rejected here.
  disallowed_types = (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray)
  disallowed_types += compat.bytes_or_text_types

  if isinstance(values, disallowed_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable. Found: %s' %
        type(values))

  if not hasattr(values, '__iter__'):
    raise TypeError(
        'Expected argument "values" to be iterable. Found: %s' % type(values))
@tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  This Op checks that `x[i] < 0` holds for every element of `x`. An empty `x`
  satisfies the condition trivially.

  If some element of `x` is not negative, `message` is printed together with
  the first `summarize` entries of `x`, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin TF2 wrapper: the v1 implementation has the same arguments plus `data`.
  return assert_negative(x=x, name=name, summarize=summarize, message=message)
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the @_unary_assert_doc decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Describe x by shape/dtype when eager, by its graph name otherwise.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [
          message,
          'Condition x < 0 did not hold element-wise:',
          'x (%s) = ' % x_desc,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less(x, zero, data=data, summarize=summarize)
@tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  This Op checks that `x[i] > 0` holds for every element of `x`. An empty `x`
  satisfies the condition trivially.

  If some element of `x` is not positive, `message` is printed together with
  the first `summarize` entries of `x`, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin TF2 wrapper: the v1 implementation has the same arguments plus `data`.
  return assert_positive(x=x, name=name, message=message, summarize=summarize)
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the @_unary_assert_doc decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Describe x by shape/dtype when eager, by its graph name otherwise.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [
          message,
          'Condition x > 0 did not hold element-wise:',
          'x (%s) = ' % x_desc,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x > 0 is checked as 0 < x.
    return assert_less(zero, x, data=data, summarize=summarize)
@tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  This Op checks that `x[i] >= 0` holds for every element of `x`. An empty `x`
  satisfies the condition trivially.

  If some element of `x` is below 0, `message` is printed together with the
  first `summarize` entries of `x`, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. This can
    be used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin TF2 wrapper: the v1 implementation has the same arguments plus `data`.
  return assert_non_negative(x=x, name=name, message=message,
                             summarize=summarize)
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the @_unary_assert_doc decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Describe x by shape/dtype when eager, by its graph name otherwise.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [
          message,
          'Condition x >= 0 did not hold element-wise:',
          'x (%s) = ' % x_desc,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x >= 0 is checked as 0 <= x.
    return assert_less_equal(zero, x, data=data, summarize=summarize)
@tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  This Op checks that `x[i] <= 0` holds for every element of `x`. An empty `x`
  satisfies the condition trivially.

  If some element of `x` is above 0, `message` is printed together with the
  first `summarize` entries of `x`, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. This can
    be used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Thin TF2 wrapper: the v1 implementation has the same arguments plus `data`.
  return assert_non_positive(x=x, name=name, message=message,
                             summarize=summarize)
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the @_unary_assert_doc decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Describe x by shape/dtype when eager, by its graph name otherwise.
      if context.executing_eagerly():
        x_desc = _shape_and_dtype_str(x)
      else:
        x_desc = x.name
      # The comma after the condition string matters: without it Python
      # concatenates the two adjacent literals into a single data item,
      # unlike the sibling assert_* ops.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % x_desc,
          x,
      ]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less_equal(x, zero, data=data, summarize=summarize)
@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('==', 'assert_equal', 3)
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  # Docstring comes from @_binary_assert_doc_v2; delegate to the v1 op.
  return assert_equal(x=x, y=y, name=name, message=message, summarize=summarize)
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('==', '[1, 2]')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the @_binary_assert_doc decorator.
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    # The very same object is trivially equal to itself: skip the comparison.
    if x is y:
      if context.executing_eagerly():
        return None
      return control_flow_ops.no_op()
    return _binary_assert(
        sym='==', opname='assert_equal', op_func=math_ops.equal,
        static_func=np.equal, x=x, y=y, data=data, summarize=summarize,
        message=message, name=name)
@tf_export('debugging.assert_none_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('!=', 'assert_none_equal', 6)
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  # Docstring comes from @_binary_assert_doc_v2; delegate to the v1 op.
  return assert_none_equal(x=x, y=y, name=name, message=message,
                           summarize=summarize)
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=', '[2, 1]')
def assert_none_equal(
    x, y, data=None, summarize=None, message=None, name=None):
  # The public docstring is generated by the @_binary_assert_doc decorator.
  return _binary_assert(
      sym='!=', opname='assert_none_equal', op_func=math_ops.not_equal,
      static_func=np.not_equal, x=x, y=y, data=data, summarize=summarize,
      message=message, name=name)
@tf_export('debugging.assert_near', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])` holds
  for every pair of (possibly broadcast) elements of `x` and `y`. If both `x`
  and `y` are empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` are `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`. This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance. Default is `10 * eps`.
    atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance. Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
      This can be used with `tf.control_dependencies` inside of `tf.function`s
      to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close for some pair of elements. The check can be
      performed immediately during eager execution or if `x` and `y` are
      statically known.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  # Thin TF2 wrapper around the v1 op, which additionally accepts `data`.
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)
@tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```.

  If both `x` and `y` are empty, this is trivially satisfied.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`. `10 * eps` is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance. Default is `10 * eps`.
    atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance. Default is `10 * eps`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Tolerances are real-valued even for complex inputs, so derive `eps`
    # and the tolerance dtype from the real component type.
    dtype = x.dtype
    if dtype.is_complex:
      dtype = dtype.real_dtype
    eps = np.finfo(dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)

    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    # Use `<=` (not `<`) so the check matches the documented condition and
    # `numpy.testing.assert_allclose`; with strict `<`, exactly equal inputs
    # would fail whenever both tolerances are zero.
    condition = math_ops.reduce_all(math_ops.less_equal(diff, tol))
    return control_flow_assert.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('<', 'assert_less', 3)
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards to the v1 `assert_less` without a `data` argument.
  # The public docstring is presumably attached by `_binary_assert_doc_v2`.
  return assert_less(x=x, y=y, message=message, summarize=summarize,
                     name=name)
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('<', '[2, 3]')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  # Delegates to `_binary_assert` with the `<` symbol and the TensorFlow and
  # NumPy implementations of the comparison.
  tf_compare, np_compare = math_ops.less, np.less
  return _binary_assert('<', 'assert_less', tf_compare, np_compare, x, y,
                        data, summarize, message, name)
@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('<=', 'assert_less_equal', 3)
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards to the v1 `assert_less_equal` without `data`.
  # The public docstring is presumably attached by `_binary_assert_doc_v2`.
  return assert_less_equal(x=x, y=y, message=message, summarize=summarize,
                           name=name)
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=', '[1, 3]')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  # Delegates to `_binary_assert` with the `<=` symbol and the TensorFlow and
  # NumPy implementations of the comparison.
  tf_compare, np_compare = math_ops.less_equal, np.less_equal
  return _binary_assert('<=', 'assert_less_equal', tf_compare, np_compare,
                        x, y, data, summarize, message, name)
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('>', 'assert_greater', 9)
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards to the v1 `assert_greater` without `data`.
  # The public docstring is presumably attached by `_binary_assert_doc_v2`.
  return assert_greater(x=x, y=y, message=message, summarize=summarize,
                        name=name)
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('>', '[0, 1]')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # Delegates to `_binary_assert` with the `>` symbol and the TensorFlow and
  # NumPy implementations of the comparison.
  tf_compare, np_compare = math_ops.greater, np.greater
  return _binary_assert('>', 'assert_greater', tf_compare, np_compare, x, y,
                        data, summarize, message, name)
@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc_v2('>=', 'assert_greater_equal', 9)
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  # TF2 endpoint: forwards to the v1 `assert_greater_equal` without `data`.
  # The public docstring is presumably attached by `_binary_assert_doc_v2`.
  return assert_greater_equal(x=x, y=y, message=message, summarize=summarize,
                              name=name)
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=', '[1, 0]')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  # Delegates to `_binary_assert` with the `>=` symbol and the TensorFlow and
  # NumPy implementations of the comparison.
  tf_compare, np_compare = math_ops.greater_equal, np.greater_equal
  return _binary_assert('>=', 'assert_greater_equal', tf_compare, np_compare,
                        x, y, data, summarize, message, name)
def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Checks that the rank of `x` satisfies a condition relative to `rank`.

  When both `rank` and the rank of `x` are known statically, the check is done
  in Python and a no-op is returned on success. Otherwise a graph assertion is
  built from `dynamic_condition`. If `rank` is not statically known, that
  assertion additionally depends on a check that `rank` is a scalar.

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor` of dtype int32.
    static_condition: Python callable `(actual_rank, given_rank) -> bool`, used
      when both ranks are known statically.
    dynamic_condition: Op-building callable `(actual_rank, given_rank)`
      returning a boolean `Tensor`, used otherwise.
    data: The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the static check passed, otherwise an `Assert` op that raises
    `InvalidArgumentError` when `dynamic_condition` fails.

  Raises:
    ValueError: If `rank` is statically known but not a scalar, or if the
      static check fails. In the latter case the error's args are
      `('Static rank condition failed', actual_rank, given_rank)`.
  """
  assert_type(rank, dtypes.int32)

  # Try to resolve the expected rank without executing anything.
  known_rank = tensor_util.constant_value(rank)
  rank_is_static = known_rank is not None

  if rank_is_static:
    if known_rank.ndim != 0:
      raise ValueError('Rank must be a scalar.')
    actual_static_rank = x.get_shape().ndims
    if actual_static_rank is not None:
      if static_condition(actual_static_rank, known_rank):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', actual_static_rank, known_rank)

  condition = dynamic_condition(array_ops.rank(x), rank)

  if not rank_is_static:
    # `rank` itself must be a scalar; this catches calls such as
    # assert_rank(x, [n]) made instead of assert_rank(x, n).
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_assert.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  This Op verifies that the rank of `x` equals `rank`.

  When the rank differs, `message` and the shape of `x` are printed and an
  `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` does not have rank `rank`. The check can be performed immediately
      during eager execution or if the shape of `x` is statically known.
  """
  # Same check as the v1 endpoint, without the `data`/`summarize` knobs.
  return assert_rank(x=x, rank=rank, name=name, message=message)
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor` or `SparseTensor`.
    rank: Scalar integer `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    # SparseTensors are passed through unconverted.
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
    dynamic_condition = math_ops.equal

    # Tensor names are unavailable in eager mode and for SparseTensors.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % name, rank, 'Received shape: ',
          array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank %d. Received rank %d, shape %s' %
            (message, name, e.args[2], e.args[1], x.get_shape()))
      # Re-raise unchanged (as assert_rank_at_least and assert_rank_in do).
      # Rebuilding with only e.args[0] dropped the remaining args and the
      # original traceback.
      raise

  return assert_op
@tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  This Op verifies that the rank of `x` is greater than or equal to `rank`.

  When `x` has a lower rank, `message` and the shape of `x` are printed and an
  `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Same check as the v1 endpoint, without the `data`/`summarize` knobs.
  return assert_rank_at_least(x=x, rank=rank, name=name, message=message)
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(
      name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    # Tensor names are not meaningful in eager mode.
    name = '' if context.executing_eagerly() else x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x),
      ]

    try:
      assert_op = _assert_rank_condition(
          x, rank,
          lambda actual_rank, given_rank: actual_rank >= given_rank,
          math_ops.greater_equal,
          data, summarize)
    except ValueError as e:
      if e.args[0] != 'Static rank condition failed':
        raise
      # args are ('Static rank condition failed', actual_rank, given_rank).
      raise ValueError(
          '%sTensor %s must have rank at least %d. Received rank %d, '
          'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))

  return assert_op
1262def _static_rank_in(actual_rank, given_ranks):
1263 return actual_rank in given_ranks
def _dynamic_rank_in(actual_rank, given_ranks):
  """Builds a boolean tensor that is true iff `actual_rank` is in `given_ranks`.

  The result is the logical OR of `equal(rank, actual_rank)` over every entry of
  `given_ranks`. An empty `given_ranks` yields a constant `False` tensor.
  """
  if len(given_ranks) == 0:
    return ops.convert_to_tensor(False)
  first_rank, *other_ranks = given_ranks
  matches_any = math_ops.equal(first_rank, actual_rank)
  for candidate in other_ranks:
    matches_any = math_ops.logical_or(
        matches_any, math_ops.equal(candidate, actual_rank))
  return matches_any
def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Checks that the rank of `x` satisfies a condition relative to `ranks`.

  If every entry of `ranks` and the rank of `x` are statically known, the
  check is done in Python and a no-op is returned on success. Otherwise a
  graph assertion is built from `dynamic_condition`. Each entry of `ranks`
  that is not statically known adds a dependency checking that it is a scalar.

  Args:
    x: Numeric `Tensor`.
    ranks: Sequence of scalar int32 `Tensor`s. It is iterated several times,
      so it must not be a one-shot iterator.
    static_condition: Python callable `(actual_rank, given_ranks) -> bool`,
      used when everything is known statically.
    dynamic_condition: Op-building callable `(actual_rank, given_ranks)`
      returning a boolean `Tensor`, used otherwise.
    data: The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the static check passed, otherwise an `Assert` op that raises
    `InvalidArgumentError` when `dynamic_condition` fails.

  Raises:
    ValueError: If any rank is statically known but not a scalar, or if the
      static check fails. In the latter case the error's args are
      `('Static rank condition failed', actual_rank, static_ranks)`.
  """
  for candidate in ranks:
    assert_type(candidate, dtypes.int32)

  # Try to resolve every expected rank without executing anything.
  static_values = tuple(tensor_util.constant_value(r) for r in ranks)

  if all(value is not None for value in static_values):
    if any(value.ndim != 0 for value in static_values):
      raise ValueError('Rank must be a scalar.')
    actual_static_rank = x.get_shape().ndims
    if actual_static_rank is not None:
      if static_condition(actual_static_rank, static_values):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', actual_static_rank, static_values)

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # Each rank whose value is unknown statically must be checked to be a
  # scalar; this catches calls like assert_rank(x, [n]) instead of
  # assert_rank(x, n).
  for rank_tensor, static_value in zip(ranks, static_values):
    if static_value is not None:
      continue
    scalar_check = assert_rank(
        rank_tensor, 0,
        data=['Rank must be a scalar. Received rank: ', rank_tensor])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_assert.Assert(condition, data, summarize=summarize)
@tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  This Op verifies that the rank of `x` is one of `ranks`.

  When the rank is not among them, `message` and the shape of `x` are printed
  and an `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Same check as the v1 endpoint, without the `data`/`summarize` knobs.
  return assert_rank_in(x=x, ranks=ranks, name=name, message=message)
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor` or `SparseTensor`.
    ranks: Iterable of scalar `Tensor` objects. One-shot iterables such as
      generators are accepted.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Materialize `ranks` once. It is iterated both for the name-scope values and
  # for tensor conversion below. A one-shot iterator such as a generator would
  # otherwise be exhausted by the first pass, leaving an empty tuple of ranks
  # and an assertion that can never pass.
  ranks = tuple(ranks)
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + ranks + tuple(data or [])):
    # SparseTensors are passed through unconverted.
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
    message = _message_prefix(message)

    # Tensor names are unavailable in eager mode and for SparseTensors.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)

    except ValueError as e:
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank in %s. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op
@tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  When `x` does not have an integer type, `message` and the dtype of `x` are
  printed and an `InvalidArgumentError` is raised.

  The dtype is always known statically, so nothing is returned.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is not a non-quantized integer type.
  """
  # The no_op produced by the v1 endpoint is intentionally discarded.
  assert_integer(x=x, name=name, message=message)
@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is anything other than non-quantized integer.

  Returns:
    A `no_op` that does nothing. Type can be determined statically.
  """
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if x.dtype.is_integer:
      return control_flow_ops.no_op('statically_determined_was_integer')
    # Eager tensors have no meaningful name; use a generic label instead.
    tensor_label = 'tensor' if context.executing_eagerly() else x.name
    raise TypeError(
        '%sExpected "x" to be integer type. Found: %s of dtype %s'
        % (_message_prefix(message), tensor_label, x.dtype))
@tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  The dtype is always known statically, so nothing is returned.

  Example:

  >>> a = tf.Variable(1.0)
  >>> tf.debugging.assert_type(a, tf_type= tf.float32)

  >>> b = tf.constant(21)
  >>> tf.debugging.assert_type(b, tf_type=tf.bool)
  Traceback (most recent call last):
  ...
  TypeError: ...

  >>> c = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
  ...                     dense_shape=[3, 4])
  >>> tf.debugging.assert_type(c, tf_type= tf.int32)

  Args:
    tensor: A `Tensor`, `SparseTensor` or `tf.Variable` .
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  # The no_op produced by the v1 endpoint is intentionally discarded.
  assert_type(tensor=tensor, tf_type=tf_type, name=name, message=message)
@tf_export(v1=['debugging.assert_type', 'assert_type'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` is of the specified type.

  Args:
    tensor: A `Tensor` or `SparseTensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name to give this `Op`. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.

  Returns:
    A `no_op` that does nothing. Type can be determined statically.
  """
  expected_dtype = dtypes.as_dtype(tf_type)
  with ops.name_scope(name, 'assert_type', [tensor]):
    # SparseTensors are checked as-is; everything else is converted first.
    if not isinstance(tensor, sparse_tensor.SparseTensor):
      tensor = ops.convert_to_tensor(tensor, name='tensor')
    if tensor.dtype != expected_dtype:
      prefix = _message_prefix(message)
      label = getattr(tensor, "name", "tensor")
      raise TypeError(
          f'{prefix}{label} must be of type {expected_dtype!r}; '
          f'got {tensor.dtype!r}')
    return control_flow_ops.no_op('statically_determined_correct_type')
def _dimension_sizes(x):
  """Returns the dimension sizes of tensor `x`.

  Sizes known statically are returned as Python ints. Unknown sizes are
  returned as scalar slices of the dynamic shape. A scalar `x` is treated as
  having a single dimension of size one.

  Args:
    x: A `Tensor`.

  Returns:
    `(1,)` if `x` is statically a scalar; a list mixing ints and scalar
    tensors if the rank is statically positive; otherwise a 1-D tensor
    selected at run time by `cond`.
  """
  dynamic_shape = array_ops.shape(x)
  static_rank = x.get_shape().rank

  if static_rank is not None:
    if static_rank == 0:
      return (1,)
    if static_rank > 0:
      return [
          dynamic_shape[axis] if dim is None else int(dim)
          for axis, dim in enumerate(x.get_shape().as_list())
      ]

  # Rank unknown statically: decide between scalar and non-scalar in-graph.
  is_scalar = math_ops.equal(array_ops.rank(x), 0)
  return cond.cond(
      is_scalar, lambda: array_ops.constant([1]), lambda: dynamic_shape)
1583def _symbolic_dimension_sizes(symbolic_shape):
1584 # If len(symbolic_shape) == 0 construct a tuple
1585 if not symbolic_shape:
1586 return tuple([1])
1588 return symbolic_shape
1591def _has_known_value(dimension_size):
1592 not_none = dimension_size is not None
1593 try:
1594 int(dimension_size)
1595 can_be_parsed_as_int = True
1596 except (ValueError, TypeError):
1597 can_be_parsed_as_int = False
1598 return not_none and can_be_parsed_as_int
1601def _is_symbol_for_any_size(symbol):
1602 return symbol in [None, '.']
1605_TensorDimSizes = collections.namedtuple(
1606 '_TensorDimSizes',
1607 ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])
@tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op verifies that the shapes of a collection of tensors satisfy the
  given constraints, including relationships between dimensions of different
  tensors.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  When any of `x`, `y`, `param` or `scalar` violates a constraint, `message`
  and the first `summarize` entries of the first violating tensor are printed,
  and `InvalidArgumentError` is raised.

  Size entries in the specified shapes are matched against one another by
  their __hash__, except:
    - an entry that can be parsed as an integer primitive is an explicit size.
    - an entry that is None or '.' matches *any* size.

  A first shape entry of `...` (type `Ellipsis`) or '*' stands for a variable
  number of outer dimensions of unspecified size, so the constraint applies
  only to the inner-most dimensions.

  Scalar tensors and specified shapes of length zero (excluding the
  'inner-most' prefix) are both treated as a single dimension of size one.

  Args:
    shapes: dictionary with (`Tensor` to shape) items, or a list of
      (`Tensor`, shape) tuples. A shape must be an iterable.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  # The op produced by the v1 endpoint is intentionally discarded.
  assert_shapes(
      shapes, name=name, message=message, summarize=summarize, data=data)
1677@tf_export(v1=['debugging.assert_shapes'])
1678@dispatch.add_dispatch_support
1679def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
1680 """Assert tensor shapes and dimension size relationships between tensors.
1682 This Op checks that a collection of tensors shape relationships
1683 satisfies given constraints.
1685 Example:
1687 >>> n = 10
1688 >>> q = 3
1689 >>> d = 7
1690 >>> x = tf.zeros([n,q])
1691 >>> y = tf.ones([n,d])
1692 >>> param = tf.Variable([1.0, 2.0, 3.0])
1693 >>> scalar = 1.0
1694 >>> tf.debugging.assert_shapes([
1695 ... (x, ('N', 'Q')),
1696 ... (y, ('N', 'D')),
1697 ... (param, ('Q',)),
1698 ... (scalar, ()),
1699 ... ])
1701 >>> tf.debugging.assert_shapes([
1702 ... (x, ('N', 'D')),
1703 ... (y, ('N', 'D'))
1704 ... ])
1705 Traceback (most recent call last):
1706 ...
1707 ValueError: ...
1709 Example of adding a dependency to an operation:
1711 ```python
1712 with tf.control_dependencies([tf.assert_shapes(shapes)]):
1713 output = tf.matmul(x, y, transpose_a=True)
1714 ```
1716 If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
1717 all specified constraints, `message`, as well as the first `summarize` entries
1718 of the first encountered violating tensor are printed, and
1719 `InvalidArgumentError` is raised.
1721 Size entries in the specified shapes are checked against other entries by
1722 their __hash__, except:
1723 - a size entry is interpreted as an explicit size if it can be parsed as an
1724 integer primitive.
1725 - a size entry is interpreted as *any* size if it is None or '.'.
1727 If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
1728 a variable number of outer dimensions of unspecified size, i.e. the constraint
1729 applies to the inner-most dimensions only.
1731 Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
1732 prefix) are both treated as having a single dimension of size one.
1734 Args:
1735 shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
1736 expected shape of `Tensor`. See the example code above. The `shape` must
1737 be an iterable. Each element of the iterable can be either a concrete
1738 integer value or a string that abstractly represents the dimension.
1739 For example,
1740 - `('N', 'Q')` specifies a 2D shape wherein the first and second
1741 dimensions of shape may or may not be equal.
1742 - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
1743 dimensions are equal.
1744 - `(1, 'N')` specifies a 2D shape wherein the first dimension is
1745 exactly 1 and the second dimension can be any value.
1746 Note that the abstract dimension letters take effect across different
1747 tuple elements of the list. For example,
1748 `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))]` asserts
1749 that both `x` and `y` are rank-2 tensors and their first dimensions are
1750 equal (`N`).
1751 `shape` can also be a `tf.TensorShape`.
1752 data: The tensors to print out if the condition is False. Defaults to error
1753 message and first few entries of the violating tensor.
1754 summarize: Print this many entries of the tensor.
1755 message: A string to prefix to the default message.
1756 name: A name for this operation (optional). Defaults to "assert_shapes".
1758 Returns:
1759 Op raising `InvalidArgumentError` unless all shape constraints are
1760 satisfied.
1761 If static checks determine all constraints are satisfied, a `no_op` is
1762 returned.
1764 Raises:
1765 ValueError: If static checks determine any shape constraint is violated.
1766 """
1767 # If the user manages to assemble a dict containing tensors (possible in
1768 # Graph mode only), make sure we still accept that.
1769 if isinstance(shapes, dict):
1770 shapes = shapes.items()
1772 message_prefix = _message_prefix(message)
1773 with ops.name_scope(name, 'assert_shapes', [shapes, data]):
1774 # Shape specified as None implies no constraint
1775 shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
1776 ops.convert_to_tensor(x), s)
1777 for x, s in shapes if s is not None]
1779 executing_eagerly = context.executing_eagerly()
1781 def tensor_name(x):
1782 if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
1783 return _shape_and_dtype_str(x)
1784 return x.name
1786 tensor_dim_sizes = []
1787 for tensor, symbolic_shape in shape_constraints:
1788 is_iterable = (
1789 hasattr(symbolic_shape, '__iter__') or
1790 hasattr(symbolic_shape, '__getitem__') # For Python 2 compat.
1791 )
1792 if not is_iterable:
1793 raise ValueError(
1794 '%s'
1795 'Tensor %s. Specified shape must be an iterable. '
1796 'An iterable has the attribute `__iter__` or `__getitem__`. '
1797 'Received specified shape: %s' %
1798 (message_prefix, tensor_name(tensor), symbolic_shape))
1800 # We convert this into a tuple to handle strings, lists and numpy arrays
1801 symbolic_shape_tuple = tuple(symbolic_shape)
1803 tensors_specified_innermost = False
1804 for i, symbol in enumerate(symbolic_shape_tuple):
1805 if symbol not in [Ellipsis, '*']:
1806 continue
1808 if i != 0:
1809 raise ValueError(
1810 '%s'
1811 'Tensor %s specified shape index %d. '
1812 'Symbol `...` or `*` for a variable number of '
1813 'unspecified dimensions is only allowed as the first entry' %
1814 (message_prefix, tensor_name(tensor), i))
1816 tensors_specified_innermost = True
1818 # Only include the size of the specified dimensions since the 0th symbol
1819 # is either ellipsis or *
1820 tensor_dim_sizes.append(
1821 _TensorDimSizes(
1822 tensor, tensors_specified_innermost, _dimension_sizes(tensor),
1823 _symbolic_dimension_sizes(
1824 symbolic_shape_tuple[1:]
1825 if tensors_specified_innermost else symbolic_shape_tuple)))
1827 rank_assertions = []
1828 for sizes in tensor_dim_sizes:
1829 rank = len(sizes.symbolic_sizes)
1830 rank_zero_or_one = rank in [0, 1]
1831 if sizes.unspecified_dim:
1832 if rank_zero_or_one:
1833 # No assertion of rank needed as `x` only need to have rank at least
1834 # 0. See elif rank_zero_or_one case comment.
1835 continue
1836 assertion = assert_rank_at_least(
1837 x=sizes.x,
1838 rank=rank,
1839 data=data,
1840 summarize=summarize,
1841 message=message,
1842 name=name)
1843 elif rank_zero_or_one:
1844 # Rank 0 is treated as rank 1 size 1, i.e. there is
1845 # no distinction between the two in terms of rank.
1846 # See _dimension_sizes.
1847 assertion = assert_rank_in(
1848 x=sizes.x,
1849 ranks=[0, 1],
1850 data=data,
1851 summarize=summarize,
1852 message=message,
1853 name=name)
1854 else:
1855 assertion = assert_rank(
1856 x=sizes.x,
1857 rank=rank,
1858 data=data,
1859 summarize=summarize,
1860 message=message,
1861 name=name)
1862 rank_assertions.append(assertion)
1864 size_assertions = []
1865 size_specifications = {}
1866 for sizes in tensor_dim_sizes:
1867 for i, size_symbol in enumerate(sizes.symbolic_sizes):
1869 if _is_symbol_for_any_size(size_symbol):
1870 # Size specified as any implies no constraint
1871 continue
1873 if sizes.unspecified_dim:
1874 tensor_dim = i - len(sizes.symbolic_sizes)
1875 else:
1876 tensor_dim = i
1878 if size_symbol in size_specifications or _has_known_value(size_symbol):
1879 if _has_known_value(size_symbol):
1880 specified_size = int(size_symbol)
1881 size_check_message = 'Specified explicitly'
1882 else:
1883 specified_size, specified_by_y, specified_at_dim = (
1884 size_specifications[size_symbol])
1885 size_check_message = (
1886 'Specified by tensor %s dimension %d' %
1887 (tensor_name(specified_by_y), specified_at_dim))
1889 # This is extremely subtle. If actual_sizes is dynamic, we must
1890 # make sure a control dependency is inserted here so that this slice
1891 # can not execute until the rank is asserted to be enough for the
1892 # slice to not fail.
1893 with ops.control_dependencies(rank_assertions):
1894 actual_size = sizes.actual_sizes[tensor_dim]
1895 if _has_known_value(actual_size) and _has_known_value(specified_size):
1896 if int(actual_size) != int(specified_size):
1897 raise ValueError(
1898 '%s%s. Tensor %s dimension %s must have size %d. '
1899 'Received size %d, shape %s' %
1900 (message_prefix, size_check_message, tensor_name(sizes.x),
1901 tensor_dim, specified_size, actual_size,
1902 sizes.x.get_shape()))
1903 # No dynamic assertion needed
1904 continue
1906 condition = math_ops.equal(
1907 ops.convert_to_tensor(actual_size),
1908 ops.convert_to_tensor(specified_size))
1909 data_ = data
1910 if data is None:
1911 data_ = [
1912 message_prefix, size_check_message,
1913 'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
1914 'must have size', specified_size, 'Received shape: ',
1915 array_ops.shape(sizes.x)
1916 ]
1917 size_assertions.append(
1918 control_flow_assert.Assert(condition, data_, summarize=summarize))
1919 else:
1920 # Not sure if actual_sizes is a constant, but for safety, guard
1921 # on rank. See explanation above about actual_sizes need for safety.
1922 with ops.control_dependencies(rank_assertions):
1923 size = sizes.actual_sizes[tensor_dim]
1924 size_specifications[size_symbol] = (size, sizes.x, tensor_dim)
1926 # Ensure both assertions actually occur.
1927 with ops.control_dependencies(rank_assertions):
1928 shapes_assertion = control_flow_ops.group(size_assertions)
1930 return shapes_assertion
1933# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Flattens `x` and returns its first differences `x[1:] - x[:-1]`.

  Args:
    x: A numeric `Tensor` (or a value convertible to one). It is reshaped to
      1-D in row-major order before differencing.

  Returns:
    A 1-D `Tensor` with the same dtype as the flattened `x`. When `x` has fewer
    than two elements the result is an empty tensor.

  Raises:
    TypeError: If the flattened `x` is not a numeric tensor.
  """
  flat = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(flat):
    raise TypeError('Expected x to be numeric, instead found: %s' % flat)

  # With fewer than two elements there are no adjacent pairs to difference.
  too_short = math_ops.less(array_ops.size(flat), 2)

  def empty_result():
    return ops.convert_to_tensor([], dtype=flat.dtype)

  # `last` is a length-1 vector holding size(flat) - 1.
  last = array_ops.shape(flat) - 1

  def adjacent_differences():
    successors = array_ops.strided_slice(flat, [1], [1] + last)
    predecessors = array_ops.strided_slice(flat, [0], last)
    return successors - predecessors

  return cond.cond(too_short, empty_result, adjacent_differences)
@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Checks whether `tensor` is a `tf.Tensor` whose elements are numbers.

  A `tf.Tensor` is considered numeric when its dtype is one of:

  * `tf.float16`
  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.uint16`
  * `tf.uint32`
  * `tf.uint64`
  * `tf.qint8`
  * `tf.qint16`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.quint16`
  * `tf.complex64`
  * `tf.complex128`
  * `tf.bfloat16`

  Args:
    tensor: The object to inspect. It does not have to be a `tf.Tensor`.

  Returns:
    `True` if `tensor` is a `tf.Tensor` with a numeric dtype. `False` if its
    dtype is non-numeric, or if `tensor` is not a `tf.Tensor` at all.
  """
  if not isinstance(tensor, ops.Tensor):
    return False
  return tensor.dtype in NUMERIC_TYPES
@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
  If `x` has less than two elements, it is trivially non-decreasing.

  See also: `is_strictly_increasing`

  >>> x1 = tf.constant([1.0, 1.0, 3.0])
  >>> tf.math.is_non_decreasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_non_decreasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional). Defaults to
      "is_non_decreasing".

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    # Flatten so that elements are compared in row-major order.
    x = array_ops.reshape(x, [-1])
    if not is_numeric_tensor(x):
      raise TypeError('Expected x to be numeric, instead found: %s' % x)
    # Compare each element with its successor directly. Testing the sign of
    # x[1:] - x[:-1] instead is wrong when the subtraction overflows (e.g.
    # uint8 1 - 3 wraps to 254) or yields NaN (inf - inf).
    # With fewer than two elements both slices are empty, and
    # reduce_all([]) = True.
    return math_ops.reduce_all(math_ops.less_equal(x[:-1], x[1:]))
@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
  If `x` has less than two elements, it is trivially strictly increasing.

  See also: `is_non_decreasing`

  >>> x1 = tf.constant([1.0, 2.0, 3.0])
  >>> tf.math.is_strictly_increasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_strictly_increasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing".

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    # Flatten so that elements are compared in row-major order.
    x = array_ops.reshape(x, [-1])
    if not is_numeric_tensor(x):
      raise TypeError('Expected x to be numeric, instead found: %s' % x)
    # Compare each element with its successor directly. Testing the sign of
    # x[1:] - x[:-1] instead is wrong when the subtraction overflows (e.g.
    # uint8 1 - 3 wraps to 254, which is > 0).
    # With fewer than two elements both slices are empty, and
    # reduce_all([]) = True.
    return math_ops.reduce_all(math_ops.less(x[:-1], x[1:]))
2070def _assert_same_base_type(items, expected_type=None):
2071 r"""Asserts all items are of the same base type.
2073 Args:
2074 items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
2075 `Operation`, or `IndexedSlices`). Can include `None` elements, which
2076 will be ignored.
2077 expected_type: Expected type. If not specified, assert all items are
2078 of the same base type.
2080 Returns:
2081 Validated type, or none if neither expected_type nor items provided.
2083 Raises:
2084 ValueError: If any types do not match.
2085 """
2086 original_expected_type = expected_type
2087 mismatch = False
2088 for item in items:
2089 if item is not None:
2090 item_type = item.dtype.base_dtype
2091 if not expected_type:
2092 expected_type = item_type
2093 elif expected_type != item_type:
2094 mismatch = True
2095 break
2096 if mismatch:
2097 # Loop back through and build up an informative error message (this is very
2098 # slow, so we don't do it unless we found an error above).
2099 expected_type = original_expected_type
2100 original_item_str = None
2101 for item in items:
2102 if item is not None:
2103 item_type = item.dtype.base_dtype
2104 if not expected_type:
2105 expected_type = item_type
2106 original_item_str = item.name if hasattr(item, 'name') else str(item)
2107 elif expected_type != item_type:
2108 raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
2109 item.name if hasattr(item, 'name') else str(item),
2110 item_type, expected_type,
2111 (' as %s' % original_item_str) if original_item_str else ''))
2112 return expected_type # Should be unreachable
2113 else:
2114 return expected_type
@tf_export(
    'debugging.assert_same_float_dtype',
    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
  """Validate and return float type based on `tensors` and `dtype`.

  For ops such as matrix multiplication, inputs and weights must be of the
  same float type. This function validates that all `tensors` are the same type,
  validates that type is `dtype` (if supplied), and returns the type. Type must
  be a floating point type. If neither `tensors` nor `dtype` is supplied,
  the function will return `dtypes.float32`.

  Args:
    tensors: Tensors of input values. Can include `None` elements, which will be
      ignored.
    dtype: Expected type.

  Returns:
    Validated type.

  Raises:
    ValueError: if the base types of `tensors` differ from each other or from
      `dtype` (when supplied), or if the resulting type is not a floating point
      type.
  """
  if tensors:
    dtype = _assert_same_base_type(tensors, dtype)
  if not dtype:
    # Neither a dtype nor any non-None tensor was given: default to float32.
    dtype = dtypes.float32
  elif not dtype.is_floating:
    raise ValueError('Expected floating point type, got %s.' % dtype)
  return dtype
@tf_export('debugging.assert_scalar', v1=[])
@dispatch.add_dispatch_support
def assert_scalar_v2(tensor, message=None, name=None):
  """Statically asserts that `tensor` is a scalar (rank 0).

  A `ValueError` is raised unless the static shape of `tensor` is known to be
  a scalar shape. A tensor whose shape is unknown therefore also fails.

  The check happens entirely at construction time, so nothing is returned.

  Args:
    tensor: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_scalar".

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  # Delegate to the v1 implementation; its returned tensor is discarded.
  assert_scalar(tensor, name=name, message=message)
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
  """Statically asserts that `tensor` is a scalar (zero-dimensional).

  A `ValueError` is raised unless the static shape of `tensor` is known to be
  a scalar shape; an unknown shape therefore also fails.

  Args:
    tensor: A `Tensor`, or a value convertible to one.
    name: A name for this operation. Defaults to "assert_scalar".
    message: A string to prefix to the default message.

  Returns:
    The input tensor (potentially converted to a `Tensor`).

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
    tensor = ops.convert_to_tensor(tensor, name=name_scope)
    shape = tensor.get_shape()
    prefix = _message_prefix(message)
    # `ndims` is None for an unknown shape, which also falls through to raise.
    if shape.ndims == 0:
      return tensor
    # The graph-mode message additionally identifies the tensor by its name.
    if context.executing_eagerly():
      raise ValueError('%sExpected scalar shape, saw shape: %s.'
                       % (prefix, shape,))
    raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
                     % (prefix, tensor.name, shape))
2211def _message_prefix(message):
2212 if message:
2213 return '%s. ' % message
2214 return ''
@tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None):
  """Updates the shape of a tensor and checks at runtime that the shape holds.

  When executed, this operation asserts that the input tensor `x`'s shape
  is compatible with the `shape` argument.
  See `tf.TensorShape.is_compatible_with` for details.

  >>> x = tf.constant([[1, 2, 3],
  ...                  [4, 5, 6]])
  >>> x = tf.ensure_shape(x, [2, 3])

  Use `None` for unknown dimensions:

  >>> x = tf.ensure_shape(x, [None, 3])
  >>> x = tf.ensure_shape(x, [2, None])

  If the tensor's shape is not compatible with the `shape` argument, an error
  is raised:

  >>> x = tf.ensure_shape(x, [5])
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [2,3] is not
  compatible with expected shape [5]. [Op:EnsureShape]

  During graph construction (typically tracing a `tf.function`),
  `tf.ensure_shape` updates the static-shape of the **result** tensor by
  merging the two shapes. See `tf.TensorShape.merge_with` for details.

  This is most useful when **you** know a shape that can't be determined
  statically by TensorFlow.

  The following trivial `tf.function` prints the input tensor's
  static-shape before and after `ensure_shape` is applied.

  >>> @tf.function
  ... def f(tensor):
  ...   print("Static-shape before:", tensor.shape)
  ...   tensor = tf.ensure_shape(tensor, [None, 3])
  ...   print("Static-shape after:", tensor.shape)
  ...   return tensor

  This lets you see the effect of `tf.ensure_shape` when the function is traced:

  >>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
  Static-shape before: (None, None)
  Static-shape after: (None, 3)

  >>> _ = cf(tf.zeros([3, 3])) # Passes
  >>> cf(tf.constant([1, 2, 3])) # fails
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [?,3].

  The above example raises `tf.errors.InvalidArgumentError`, because the
  input's shape, `(3,)`, is not compatible with the `shape` argument,
  `(None, 3)`.

  Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
  runtime shapes. This is stricter than `tf.Tensor.set_shape` which only
  checks the buildtime shape.

  Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape
  of the resulting tensor and enforces it at runtime, raising an error if the
  tensor's runtime shape is incompatible with the specified shape.
  `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it
  at runtime, which may result in inconsistencies between the statically-known
  shape of tensors and the runtime value of tensors.

  For example, when loading images of a known size:

  >>> @tf.function
  ... def decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image = tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  When tracing a function, no ops are executed, so shapes may be unknown.
  See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
  for details.

  >>> concrete_decode = decode_image.get_concrete_function(
  ...     tf.TensorSpec([], dtype=tf.string))
  Initial shape:  (None, None, 3)
  Final shape:  (28, 28, 3)

  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
  >>> image = tf.cast(image,tf.uint8)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  >>> print(image2.shape)
  (28, 28, 3)

  >>> image = tf.concat([image,image], axis=0)
  >>> print(image.shape)
  (56, 28, 3)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not
  compatible with expected shape [28,28,3].

  Caution: if you don't use the result of `tf.ensure_shape` the check may not
  run.

  >>> @tf.function
  ... def bad_decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   # BAD: forgot to use the returned tensor.
  ...   tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  >>> image = bad_decode_image(png)
  Initial shape:  (None, None, 3)
  Final shape:  (None, None, 3)
  >>> print(image.shape)
  (56, 28, 3)

  Args:
    x: A `Tensor`.
    shape: The expected shape of `x`: a `TensorShape`, a `TensorShapeProto`,
      a list, a tuple, or None. Anything other than a `TensorShape` is
      converted with `tf.TensorShape(shape)`.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`.

  Raises:
    tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape
      of `x`.
  """
  # Normalize the accepted shape specifications to a TensorShape for the op.
  if not isinstance(shape, tensor_shape.TensorShape):
    shape = tensor_shape.TensorShape(shape)

  return array_ops.ensure_shape(x, shape, name=name)
@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient function registered for the `EnsureShape` op.

  `ensure_shape` returns a tensor with the same type and contents as its input
  (it only checks and sets the shape). The upstream gradient is therefore
  passed through unchanged, and the forward `op` is not needed.
  """
  del op  # Unused.
  return grad