Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/layers/utils.py: 14%

92 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15 

16"""Contains layer utilities for input validation and format conversion.""" 

17from tensorflow.python.framework import smart_cond as smart_module 

18from tensorflow.python.ops import cond 

19from tensorflow.python.ops import variables 

20 

21 

def convert_data_format(data_format, ndim):
  """Maps a `data_format` name plus an input rank to a layout string.

  Args:
    data_format: Either 'channels_last' or 'channels_first'.
    ndim: Rank of the input; only 3, 4 and 5 are accepted.

  Returns:
    For 'channels_last': 'NWC', 'NHWC' or 'NDHWC' for ranks 3, 4, 5.
    For 'channels_first': 'NCW', 'NCHW' or 'NCDHW' for ranks 3, 4, 5.

  Raises:
    ValueError: If `data_format` is not one of the two supported names
      (checked first), or if `ndim` is not 3, 4 or 5.
  """
  if data_format == 'channels_last':
    layouts = ('NWC', 'NHWC', 'NDHWC')
  elif data_format == 'channels_first':
    layouts = ('NCW', 'NCHW', 'NCDHW')
  else:
    raise ValueError(f'Invalid data_format: {data_format}. We only support '
                     '"channels_first" or "channels_last"')
  # One layout per supported rank, tried in ascending rank order.
  for rank, layout in zip((3, 4, 5), layouts):
    if ndim == rank:
      return layout
  raise ValueError(f'Input rank: {ndim} not supported. We only support '
                   'input rank 3, 4 or 5.')

46 

47 

def normalize_tuple(value, n, name):
  """Transforms a single integer or iterable of integers into an integer tuple.

  Args:
    value: The value to validate and convert. Could be an int (or any scalar
      registered as `numbers.Integral`, such as NumPy integer scalars), or
      any iterable of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of n elements. A single integer is repeated `n` times. For an
    iterable, its elements are returned unchanged in a tuple; each element is
    only checked to be convertible with `int()`, not converted.

  Raises:
    ValueError: If something else than an integer or iterable thereof was
      passed, if the iterable does not have exactly `n` elements, or if an
      element cannot be passed to `int()`. The underlying exception is
      chained as `__cause__`.
  """
  import numbers  # pylint: disable=g-import-not-at-top

  # `numbers.Integral` accepts `int`/`bool` exactly as before, and also
  # integer scalars that are not `int` subclasses (e.g. NumPy integers),
  # which would otherwise fail the `tuple()` call below.
  if isinstance(value, numbers.Integral):
    return (value,) * n

  try:
    value_tuple = tuple(value)
  except TypeError as e:
    raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} '
                     f'integers. Received: {str(value)}') from e
  if len(value_tuple) != n:
    raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} '
                     f'integers. Received: {str(value)}')
  for single_value in value_tuple:
    try:
      int(single_value)
    except (ValueError, TypeError) as e:
      raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} '
                       f'integers. Received: {str(value)} including element '
                       f'{str(single_value)} of type '
                       f'{str(type(single_value))}') from e
  return value_tuple

85 

86 

def normalize_data_format(value):
  """Lower-cases `value` and checks that it names a supported data format.

  Args:
    value: A string such as 'channels_last' or 'Channels_First'.

  Returns:
    The lower-cased string, which is 'channels_first' or 'channels_last'.

  Raises:
    ValueError: If the lower-cased string is neither of those; the message
      echoes the original `value`.
  """
  lowered = value.lower()
  if lowered in {'channels_first', 'channels_last'}:
    return lowered
  raise ValueError('The `data_format` argument must be one of '
                   '"channels_first", "channels_last". Received: '
                   f'{str(value)}.')

94 

95 

def normalize_padding(value):
  """Lower-cases `value` and checks that it names a supported padding mode.

  Args:
    value: A padding string, e.g. 'same' or 'VALID'.

  Returns:
    The lower-cased string, which is 'valid' or 'same'.

  Raises:
    ValueError: If the lower-cased string is neither 'valid' nor 'same'.
      The message echoes the caller's original `value` rather than its
      lower-cased form, so it shows exactly what was passed.
  """
  padding = value.lower()
  if padding not in {'valid', 'same'}:
    raise ValueError('The `padding` argument must be one of "valid", "same". '
                     f'Received: {str(value)}.')
  return padding

102 

103 

def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
  """Determines output length of a convolution given input length.

  Args:
    input_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.
    dilation: dilation rate, integer.

  Returns:
    The output length (integer), or None if `input_length` is None. With
    "valid" padding and an input shorter than the dilated filter, the result
    can be zero or negative; it is not clamped.

  Raises:
    ValueError: If `padding` is not one of "same", "valid", "full".
  """
  if input_length is None:
    return None
  # Explicit check instead of `assert`: asserts are stripped under
  # `python -O`, which would leave `output_length` unbound below.
  if padding not in {'same', 'valid', 'full'}:
    raise ValueError('Argument `padding` must be one of "same", "valid", '
                     f'"full". Received: {padding}.')
  # Effective extent of the kernel once `dilation - 1` gaps are inserted
  # between each pair of adjacent taps.
  dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
  if padding == 'same':
    output_length = input_length
  elif padding == 'valid':
    output_length = input_length - dilated_filter_size + 1
  else:  # 'full'
    output_length = input_length + dilated_filter_size - 1
  # Integer ceiling division by the stride.
  return (output_length + stride - 1) // stride

128 

129 

def conv_input_length(output_length, filter_size, padding, stride):
  """Determines input length of a convolution given output length.

  This helper takes no dilation argument; the formula uses `filter_size`
  directly.

  Args:
    output_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.

  Returns:
    The input length (integer), or None if `output_length` is None.

  Raises:
    ValueError: If `padding` is not one of "same", "valid", "full".
  """
  if output_length is None:
    return None
  # Explicit check instead of `assert`: asserts are stripped under
  # `python -O`, which would leave `pad` unbound below.
  if padding not in {'same', 'valid', 'full'}:
    raise ValueError('Argument `padding` must be one of "same", "valid", '
                     f'"full". Received: {padding}.')
  # Per-side padding implied by each mode.
  if padding == 'same':
    pad = filter_size // 2
  elif padding == 'valid':
    pad = 0
  else:  # 'full'
    pad = filter_size - 1
  return (output_length - 1) * stride - 2 * pad + filter_size

152 

153 

def deconv_output_length(input_length, filter_size, padding, stride):
  """Determines output length of a transposed convolution given input length.

  Args:
    input_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.

  Returns:
    The output length (integer), or None if `input_length` is None.

  Raises:
    ValueError: If `padding` is not one of "same", "valid", "full".
  """
  if input_length is None:
    return None
  # Validate explicitly: without this, an unrecognized string (e.g. an
  # upper-case 'VALID') would silently fall through to the "same" result.
  if padding not in {'same', 'valid', 'full'}:
    raise ValueError('Argument `padding` must be one of "same", "valid", '
                     f'"full". Received: {padding}.')
  # "same" is just the input scaled by the stride; the other modes adjust it.
  output_length = input_length * stride
  if padding == 'valid':
    output_length += max(filter_size - stride, 0)
  elif padding == 'full':
    output_length -= (stride + filter_size - 2)
  return output_length

174 

175 

def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
  A `variables.Variable` predicate is always routed through `cond.cond` and
  is never folded into a constant here.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  # Choose the implementation first, then make a single call with the same
  # keyword arguments either way.
  if isinstance(pred, variables.Variable):
    dispatch = cond.cond
  else:
    dispatch = smart_module.smart_cond
  return dispatch(pred, true_fn=true_fn, false_fn=false_fn, name=name)

200 

201 

def constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  A `variables.Variable` is always reported as dynamic (None). The Python
  integers 1 and 0 are mapped to True and False before the remaining work is
  delegated to `smart_module.smart_constant_value`.

  Args:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor, or the Python integer 1 or 0.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0 (raised by the delegated `smart_constant_value` call).
  """
  if isinstance(pred, variables.Variable):
    return None
  if isinstance(pred, int):
    # Integer booleans: 1 -> True, 0 -> False; any other int is passed on
    # unchanged.
    pred = True if pred == 1 else (False if pred == 0 else pred)
  return smart_module.smart_constant_value(pred)