Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_assert.py: 38%
48 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Assert functions for Control Flow Operations."""
17from tensorflow.python.eager import context
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import errors
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import cond
23from tensorflow.python.ops import gen_control_flow_ops
24from tensorflow.python.ops import gen_logging_ops
25from tensorflow.python.ops import gen_math_ops
26from tensorflow.python.util import dispatch
27from tensorflow.python.util import tf_should_use
28from tensorflow.python.util.tf_export import tf_export
31def _summarize_eager(tensor, summarize=None):
32 """Returns a summarized string representation of eager `tensor`.
34 Args:
35 tensor: EagerTensor to summarize
36 summarize: Include these many first elements of `array`
37 """
38 # Emulate the behavior of Tensor::SummarizeValue()
39 if summarize is None:
40 summarize = 3
41 elif summarize < 0:
42 summarize = array_ops.size(tensor)
44 # reshape((-1,)) is the fastest way to get a flat array view
45 if tensor._rank(): # pylint: disable=protected-access
46 flat = tensor.numpy().reshape((-1,))
47 lst = [str(x) for x in flat[:summarize]]
48 if len(lst) < flat.size:
49 lst.append("...")
50 else:
51 # tensor.numpy() returns a scalar for zero dimensional arrays
52 if gen_math_ops.not_equal(summarize, 0):
53 lst = [str(tensor.numpy())]
54 else:
55 lst = []
57 return ", ".join(lst)
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@dispatch.add_dispatch_support
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    tf.errors.InvalidArgumentError: When executing eagerly and `condition` is
      false. The error is raised immediately and its message includes a
      summary of each tensor in `data`.

  @compatibility(TF1)
  When in TF V1 mode (that is, outside `tf.function`) Assert needs a control
  dependency on the output to ensure the assertion executes:

  ```python
  # Ensure maximum element of x is smaller or equal to 1
  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  with tf.control_dependencies([assert_op]):
    ... code using x ...
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    # Eager mode: check the condition in Python right away.
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  # Graph mode from here on; the eager case returned above.
  with ops.name_scope(name, "Assert", [condition, data]):
    xs = ops.convert_n_to_tensor(data)
    # Build the dtype set once instead of once per element of `xs`.
    host_dtypes = frozenset((dtypes.string, dtypes.int32))
    if all(x.dtype in host_dtypes for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If it is not case,
      # we will pay the price copying the tensor to host memory.
      return gen_logging_ops._assert(  # pylint: disable=protected-access
          condition, data, summarize, name="Assert")

    condition = ops.convert_to_tensor(condition, name="Condition")

    def true_assert():
      return gen_logging_ops._assert(  # pylint: disable=protected-access
          condition, data, summarize, name="Assert")

    # Only run the (possibly host-copying) assert when `condition` is false;
    # the true branch is a no-op.
    guarded_assert = cond.cond(
        condition,
        gen_control_flow_ops.no_op,
        true_assert,
        name="AssertGuard")
    return guarded_assert.op