Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_assert.py: 38%

48 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Assert functions for Control Flow Operations.""" 

16 

17from tensorflow.python.eager import context 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import errors 

20from tensorflow.python.framework import ops 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import cond 

23from tensorflow.python.ops import gen_control_flow_ops 

24from tensorflow.python.ops import gen_logging_ops 

25from tensorflow.python.ops import gen_math_ops 

26from tensorflow.python.util import dispatch 

27from tensorflow.python.util import tf_should_use 

28from tensorflow.python.util.tf_export import tf_export 

29 

30 

31def _summarize_eager(tensor, summarize=None): 

32 """Returns a summarized string representation of eager `tensor`. 

33 

34 Args: 

35 tensor: EagerTensor to summarize 

36 summarize: Include these many first elements of `array` 

37 """ 

38 # Emulate the behavior of Tensor::SummarizeValue() 

39 if summarize is None: 

40 summarize = 3 

41 elif summarize < 0: 

42 summarize = array_ops.size(tensor) 

43 

44 # reshape((-1,)) is the fastest way to get a flat array view 

45 if tensor._rank(): # pylint: disable=protected-access 

46 flat = tensor.numpy().reshape((-1,)) 

47 lst = [str(x) for x in flat[:summarize]] 

48 if len(lst) < flat.size: 

49 lst.append("...") 

50 else: 

51 # tensor.numpy() returns a scalar for zero dimensional arrays 

52 if gen_math_ops.not_equal(summarize, 0): 

53 lst = [str(tensor.numpy())] 

54 else: 

55 lst = [] 

56 

57 return ", ".join(lst) 

58 

59 

# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@dispatch.add_dispatch_support
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    tf.errors.InvalidArgumentError: When executing eagerly and `condition` is
      false. The message contains a summary of each tensor in `data`.

  @compatibility(TF1)
  When in TF V1 mode (that is, outside `tf.function`) Assert needs a control
  dependency on the output to ensure the assertion executes:

  ```python
  # Ensure maximum element of x is smaller or equal to 1
  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  with tf.control_dependencies([assert_op]):
    ... code using x ...
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  # From here on we are building a graph. Eager execution returned above, so
  # no second eager check is needed before returning the guarded op.
  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If it is not case,
      # we will pay the price copying the tensor to host memory.
      # Pass the already-converted `xs` so that non-Tensor entries of `data`
      # are not converted, and added to the graph, a second time.
      return gen_logging_ops._assert(condition, xs, summarize, name="Assert")  # pylint: disable=protected-access

    condition = ops.convert_to_tensor(condition, name="Condition")

    def true_assert():
      # NOTE(review): passes the raw `data` (not `xs`), as before, so that any
      # conversion of non-Tensor entries happens inside the branch. Passing
      # `xs` would make the branch capture outer tensors instead; that was
      # deliberately left unchanged here.
      return gen_logging_ops._assert(  # pylint: disable=protected-access
          condition, data, summarize, name="Assert")

    guarded_assert = cond.cond(
        condition,
        gen_control_flow_ops.no_op,
        true_assert,
        name="AssertGuard")
    return guarded_assert.op