Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/numerics.py: 46%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Connects all half, float and double tensors to CheckNumericsOp.""" 

17 

18from tensorflow.python.eager import context 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.framework import ops 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import control_flow_ops 

23from tensorflow.python.util import deprecation 

24from tensorflow.python.util import dispatch 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

@tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  This is the TF1 endpoint. It accepts both the legacy argument names
  (`t`, `msg`) and the current ones (`x`, `message`), resolves them via
  `deprecation.deprecated_argument_lookup`, and forwards to
  `verify_tensor_all_finite_v2`.

  Args:
    t: Tensor to check.
    msg: Message to log on failure.
    name: A name for this operation (optional).
    x: Alias for t.
    message: Alias for msg.

  Returns:
    Same tensor as `t`.
  """
  # Resolve the tensor alias first, then the message alias (same order as the
  # historical implementation), so any lookup error surfaces identically.
  checked_tensor = deprecation.deprecated_argument_lookup("x", x, "t", t)
  failure_text = deprecation.deprecated_argument_lookup(
      "message", message, "msg", msg)
  return verify_tensor_all_finite_v2(checked_tensor, failure_text, name)

48 

49 

@tf_export("debugging.assert_all_finite", v1=[])
@dispatch.add_dispatch_support
def verify_tensor_all_finite_v2(x, message, name=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  >>> @tf.function
  ... def f(x):
  ...   x = tf.debugging.assert_all_finite(x, 'Input x must be all finite')
  ...   return x + 1

  >>> f(tf.constant([np.inf, 1, 2]))
  Traceback (most recent call last):
     ...
  InvalidArgumentError: ...

  The input is converted to a tensor, a `check_numerics` op is created next to
  it (colocated), and the returned tensor carries a control dependency on that
  check, so consuming the result forces the check to run.

  Args:
    x: Tensor to check.
    message: Message to log on failure.
    name: A name for this operation (optional).

  Returns:
    Same tensor as `x`.
  """
  with ops.name_scope(name, "VerifyFinite", [x]):
    tensor = ops.convert_to_tensor(x, name="x")
    # Place the check on the same device as the tensor it inspects.
    with ops.colocate_with(tensor):
      finite_check = array_ops.check_numerics(tensor, message=message)
      return control_flow_ops.with_dependencies([finite_check], tensor)

79 

80 

@tf_export(v1=["add_check_numerics_ops"])
def add_check_numerics_ops():
  """Connect a `tf.debugging.check_numerics` to every floating point tensor.

  Every `half`, `float`, or `double` tensor in the current default graph gets
  its own `check_numerics` op. The checks are chained through control
  dependencies in graph-construction order. Because a tensor's producer is
  always created before its consumers, the checks on an op's (`half`,
  `float`, or `double`) inputs run before the checks on its outputs.

  Note: This API is not compatible with the use of `tf.cond` or
  `tf.while_loop`, and will raise a `ValueError` if you attempt to call it
  in such a graph.

  Returns:
    A `group` op that depends (transitively, through the chained control
    dependencies) on all `check_numerics` ops added.

  Raises:
    ValueError: If the graph contains any numeric operations in a control flow
      structure.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Not compatible with eager execution. To check for `Inf`s and `NaN`s under
  eager execution, call `tf.debugging.enable_check_numerics()` once before
  executing the checked operations.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "add_check_numerics_ops() is not compatible with eager execution. "
        "To check for Inf's and NaN's under eager execution, call "
        "tf.debugging.enable_check_numerics() once before executing the "
        "checked operations.")

  float_dtypes = (dtypes.float16, dtypes.float32, dtypes.float64)
  # Each new check is created under a control dependency on the previous one,
  # forming a single chain. This relies on get_operations() listing ops in the
  # order they were added, so a tensor's producer is visited before any of
  # its consumers.
  latest_check = []
  for graph_op in ops.get_default_graph().get_operations():
    float_outputs = [
        out for out in graph_op.outputs if out.dtype in float_dtypes
    ]
    if not float_outputs:
      continue
    # Ops without floating-point outputs are never inspected for control flow.
    if graph_op._get_control_flow_context() is not None:  # pylint: disable=protected-access
      raise ValueError("`tf.add_check_numerics_ops() is not compatible "
                       "with TensorFlow control flow operations such as "
                       "`tf.cond()` or `tf.while_loop()`.")
    for out in float_outputs:
      # Tag failures with "<op name>:<output index>".
      tag = "{}:{}".format(graph_op.name, out.value_index)
      with ops.control_dependencies(latest_check):
        latest_check = [array_ops.check_numerics(out, message=tag)]
  return control_flow_ops.group(*latest_check)