Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/utils/keras_utils.py: 21%

53 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for tf.keras.""" 

16 

17import tensorflow as tf 

18 

19 

def is_tensor_or_variable(x):
    """Reports whether `x` is a TensorFlow tensor or a `tf.Variable`.

    Args:
      x: Any Python object.

    Returns:
      True if `tf.is_tensor(x)` holds or `x` is a `tf.Variable` instance,
      otherwise the result of the `isinstance` check (False).
    """
    if tf.is_tensor(x):
        return True
    return isinstance(x, tf.Variable)

22 

23 

class LossFunctionWrapper(tf.keras.losses.Loss):
    """Adapts a plain loss function to the `tf.keras.losses.Loss` API."""

    def __init__(
        self, fn, reduction=tf.keras.losses.Reduction.AUTO, name=None, **kwargs
    ):
        """Initializes `LossFunctionWrapper` class.

        Args:
          fn: The loss function to wrap, with signature `fn(y_true, y_pred,
            **kwargs)`.
          reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
            loss. Default value is `AUTO`. `AUTO` indicates that the reduction
            option will be determined by the usage context. For almost all cases
            this defaults to `SUM_OVER_BATCH_SIZE`. When used with
            `tf.distribute.Strategy`, outside of built-in training loops such as
            `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
            will raise an error. Please see this custom training [tutorial](
            https://www.tensorflow.org/tutorials/distribute/custom_training)
            for more details.
          name: (Optional) name for the loss.
          **kwargs: The keyword arguments that are passed on to `fn`.
        """
        super().__init__(reduction=reduction, name=name)
        self.fn = fn
        # Extra keyword arguments forwarded to `fn` on every call.
        self._fn_kwargs = kwargs

    def call(self, y_true, y_pred):
        """Evaluates the wrapped loss function.

        Args:
          y_true: Ground truth values.
          y_pred: The predicted values.

        Returns:
          Whatever `fn(y_true, y_pred, **kwargs)` returns (per-sample losses).
        """
        return self.fn(y_true, y_pred, **self._fn_kwargs)

    def get_config(self):
        """Returns the base `Loss` config merged with the wrapped `fn` kwargs.

        Tensor or variable kwargs are converted to concrete values with
        `tf.keras.backend.eval` so the config holds plain values. On a key
        collision, the `fn` kwarg overrides the base config entry.
        """
        fn_config = {
            key: tf.keras.backend.eval(val) if is_tensor_or_variable(val) else val
            for key, val in self._fn_kwargs.items()
        }
        return {**super().get_config(), **fn_config}

69 

70 

def normalize_data_format(value):
    """Validates a data-format string and returns it in lower case.

    Args:
      value: The requested data format, compared case-insensitively. If None,
        `tf.keras.backend.image_data_format()` is used instead.

    Returns:
      Either "channels_first" or "channels_last".

    Raises:
      ValueError: If the (lower-cased) format is neither of the two above.
    """
    if value is None:
        # The fallback value is also the one echoed in the error message.
        value = tf.keras.backend.image_data_format()
    lowered = value.lower()
    if lowered in {"channels_first", "channels_last"}:
        return lowered
    raise ValueError(
        "The `data_format` argument must be one of "
        '"channels_first", "channels_last". Received: ' + str(value)
    )

81 

82 

def normalize_tuple(value, n, name):
    """Transforms an integer or iterable of integers into an integer tuple.

    A copy of tensorflow.python.keras.util.

    Args:
      value: The value to validate and convert. Could be an int, or any
        iterable of ints.
      n: The size of the tuple to be returned.
      name: The name of the argument being validated, e.g. "strides" or
        "kernel_size". This is only used to format error messages.

    Returns:
      A tuple of `n` elements. An int `value` is repeated `n` times. For an
      iterable, its elements are returned unchanged: each one is only checked
      to be accepted by `int()`, not converted.

    Raises:
      TypeError: If `value` is neither an int nor an iterable.
      ValueError: If `value` does not have exactly `n` elements, or if one of
        its elements is rejected by `int()`.
    """
    if isinstance(value, int):
        return (value,) * n

    def _error_message():
        # Built lazily so that `str(value)` and the `name` concatenation only
        # run on a failure path, never when validation succeeds.
        return (
            "The `"
            + name
            + "` argument must be a tuple of "
            + str(n)
            + " integers. Received: "
            + str(value)
        )

    try:
        value_tuple = tuple(value)
    except TypeError as err:
        # Chain the original error so the root cause stays in the traceback.
        raise TypeError(_error_message()) from err

    if len(value_tuple) != n:
        raise ValueError(_error_message())

    for single_value in value_tuple:
        try:
            int(single_value)
        except (ValueError, TypeError) as err:
            raise ValueError(
                _error_message()
                + " including element "
                + str(single_value)
                + " of type "
                + str(type(single_value))
            ) from err
    return value_tuple

144 

145 

146def _hasattr(obj, attr_name): 

147 # If possible, avoid retrieving the attribute as the object might run some 

148 # lazy computation in it. 

149 if attr_name in dir(obj): 

150 return True 

151 try: 

152 getattr(obj, attr_name) 

153 except AttributeError: 

154 return False 

155 else: 

156 return True 

157 

158 

def assert_like_rnncell(cell_name, cell):
    """Raises a TypeError unless `cell` duck-types as an RNN cell.

    The cell must provide `output_size`, `state_size` and `get_initial_state`,
    and must be callable, in the manner of
    `tf.keras.layers.AbstractRNNCell`.

    Args:
      cell_name: Name of the function argument being checked; used to make
        the error message point at the offending argument.
      cell: The object expected to behave like an RNN cell.

    Raises:
      TypeError: Lists every requirement the cell fails, in a fixed order.
    """
    # All checks are evaluated up front, in this order, before any reporting.
    checks = [
        (_hasattr(cell, "output_size"), "'output_size' property is missing"),
        (_hasattr(cell, "state_size"), "'state_size' property is missing"),
        (
            _hasattr(cell, "get_initial_state"),
            "'get_initial_state' method is required",
        ),
        (callable(cell), "is not callable"),
    ]
    failures = [message for passed, message in checks if not passed]
    if failures:
        raise TypeError(
            "The argument {!r} ({}) is not an RNNCell: {}.".format(
                cell_name, cell, ", ".join(failures)
            )
        )