Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/numerics.py: 46%
39 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Connects all half, float and double tensors to CheckNumericsOp."""
18from tensorflow.python.eager import context
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import control_flow_ops
23from tensorflow.python.util import deprecation
24from tensorflow.python.util import dispatch
25from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  TF1 endpoint. The old/new argument spellings (`t`/`x` and `msg`/`message`)
  are reconciled with `deprecation.deprecated_argument_lookup`, and the work
  is then delegated to `verify_tensor_all_finite_v2`.

  Args:
    t: Tensor to check.
    msg: Message to log on failure.
    name: A name for this operation (optional).
    x: Alias for t.
    message: Alias for msg.

  Returns:
    Same tensor as `t`.
  """
  # Resolve the tensor argument first, then the message, matching the
  # original lookup order.
  tensor = deprecation.deprecated_argument_lookup("x", x, "t", t)
  failure_message = deprecation.deprecated_argument_lookup(
      "message", message, "msg", msg)
  return verify_tensor_all_finite_v2(tensor, failure_message, name)
@tf_export("debugging.assert_all_finite", v1=[])
@dispatch.add_dispatch_support
def verify_tensor_all_finite_v2(x, message, name=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  >>> @tf.function
  ... def f(x):
  ...   x = tf.debugging.assert_all_finite(x, 'Input x must be all finite')
  ...   return x + 1

  >>> f(tf.constant([np.inf, 1, 2]))
  Traceback (most recent call last):
     ...
  InvalidArgumentError: ...

  Args:
    x: Tensor to check.
    message: Message to log on failure.
    name: A name for this operation (optional).

  Returns:
    Same tensor as `x`.
  """
  # The scope string yielded by name_scope is not needed, so it is not bound.
  with ops.name_scope(name, "VerifyFinite", [x]):
    tensor = ops.convert_to_tensor(x, name="x")
    # Build the numerics check and the pass-through under colocate_with(tensor)
    # so they are grouped with the tensor being inspected.
    with ops.colocate_with(tensor):
      numerics_check = array_ops.check_numerics(tensor, message=message)
      # Return `tensor` gated on the check, so consumers force the check to run.
      guarded = control_flow_ops.with_dependencies([numerics_check], tensor)
  return guarded
@tf_export(v1=["add_check_numerics_ops"])
def add_check_numerics_ops():
  """Connect a `tf.debugging.check_numerics` to every floating point tensor.

  `check_numerics` operations themselves are added for each `half`, `float`,
  or `double` tensor in the current default graph. For all ops in the graph, the
  `check_numerics` op for all of its (`half`, `float`, or `double`) inputs
  is guaranteed to run before the `check_numerics` op on any of its outputs.

  Note: This API is not compatible with the use of `tf.cond` or
  `tf.while_loop`, and will raise a `ValueError` if you attempt to call it
  in such a graph.

  Returns:
    A `group` op depending on all `check_numerics` ops added.

  Raises:
    ValueError: If the graph contains any numeric operations in a control flow
      structure.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Not compatible with eager execution. To check for `Inf`s and `NaN`s under
  eager execution, call `tf.debugging.enable_check_numerics()` once before
  executing the checked operations.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "add_check_numerics_ops() is not compatible with eager execution. "
        "To check for Inf's and NaN's under eager execution, call "
        "tf.debugging.enable_check_numerics() once before executing the "
        "checked operations.")

  # Only tensors of exactly these dtypes receive a check.
  checked_dtypes = [dtypes.float16, dtypes.float32, dtypes.float64]

  # Each new check is created under a control dependency on the check made
  # just before it, forming a single chain. This relies on get_operations()
  # listing ops in insertion order: an op is added only after its inputs,
  # so a producer's checks are always chained ahead of its consumers' checks.
  previous_checks = []
  for op in ops.get_default_graph().get_operations():
    for output in op.outputs:
      if output.dtype not in checked_dtypes:
        continue
      if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
        # NOTE(review): the opening backtick below is never closed; the text
        # is left unchanged in case callers/tests match it verbatim.
        raise ValueError("`tf.add_check_numerics_ops() is not compatible "
                         "with TensorFlow control flow operations such as "
                         "`tf.cond()` or `tf.while_loop()`.")
      label = f"{op.name}:{output.value_index}"
      with ops.control_dependencies(previous_checks):
        latest_check = array_ops.check_numerics(output, message=label)
      previous_checks = [latest_check]
  return control_flow_ops.group(*previous_checks)