Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/losses/util.py: 31% (84 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for manipulating the loss collections."""


from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export



def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
  """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `confusion_matrix.remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred`.
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed; `sample_weight` could be extended by one
    dimension. If `sample_weight` is None, (y_pred, y_true) is returned.
  """
  y_pred_shape = y_pred.shape
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:

    # If sparse matrix is provided as `y_true`, the last dimension in `y_pred`
    # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: confusion_matrix.remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: cond.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = cond.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

  if sample_weight is None:
    return y_pred, y_true

  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    expand_weights = lambda: array_ops.expand_dims(sample_weight, [-1])
    return cond.cond(
        math_ops.equal(rank_diff, -1), expand_weights, lambda: sample_weight)

  def _maybe_adjust_weights():
    return cond.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # Squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`.
  sample_weight = cond.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight


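# Editor's note: the helper below is an illustrative sketch, not part of the
# original TensorFlow module. It shows one way `squeeze_or_expand_dimensions`
# aligns shapes; the helper name is hypothetical and it is never called by
# library code.
def _example_squeeze_or_expand_dimensions():
  y_pred = constant_op.constant([[0.3], [0.7]])  # shape (2, 1)
  y_true = constant_op.constant([0.0, 1.0])      # shape (2,)
  sample_weight = constant_op.constant(2.0)      # scalar weight is kept scalar
  # The static ranks differ by 1 and the last dim of `y_pred` is 1, so that
  # trailing dimension is squeezed: both tensors come back with shape (2,).
  y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
      y_pred, y_true, sample_weight)
  return y_pred.shape, y_true.shape, sample_weight.shape  # (2,), (2,), ()
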

def scale_losses_by_sample_weight(losses, sample_weight):
  """Scales loss values by the given sample weights.

  `sample_weight` dimensions are updated to match the dimensions of `losses`,
  if possible, by using squeeze/expand/broadcast.

  Args:
    losses: Loss tensor.
    sample_weight: Sample weights tensor.

  Returns:
    `losses` scaled by `sample_weight` with dtype float32.
  """
  # TODO(psv): Handle the casting here in a better way, eg. if losses is float64
  # we do not want to lose precision.
  losses = math_ops.cast(losses, dtypes.float32)
  sample_weight = math_ops.cast(sample_weight, dtypes.float32)

  # Update dimensions of `sample_weight` to match with `losses` if possible.
  losses, _, sample_weight = squeeze_or_expand_dimensions(
      losses, None, sample_weight)
  return math_ops.multiply(losses, sample_weight)


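# Editor's note: an illustrative sketch (not part of the original module) of
# `scale_losses_by_sample_weight`; the helper name is hypothetical and it is
# never called by library code.
def _example_scale_losses_by_sample_weight():
  losses = constant_op.constant([[1.0], [2.0], [3.0]])   # shape (3, 1)
  sample_weight = constant_op.constant([0.5, 1.0, 2.0])  # shape (3,)
  # `sample_weight` is expanded to shape (3, 1) to match `losses`, so the
  # result is [[0.5], [2.0], [6.0]] with dtype float32.
  return scale_losses_by_sample_weight(losses, sample_weight)
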

@tf_contextlib.contextmanager
def check_per_example_loss_rank(per_example_loss):
  """Context manager that checks that the rank of per_example_loss is at least 1.

  Args:
    per_example_loss: Per example loss tensor.

  Yields:
    A context manager.
  """
  loss_rank = per_example_loss.shape.rank
  if loss_rank is not None:
    # Handle static rank.
    if loss_rank == 0:
      raise ValueError(
          "Invalid value passed for `per_example_loss`. Expected a tensor with "
          f"at least rank 1. Received per_example_loss={per_example_loss} with "
          f"rank {loss_rank}")
    yield
  else:
    # Handle dynamic rank.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            array_ops.rank(per_example_loss),
            math_ops.cast(1, dtype=dtypes.int32),
            message="Invalid value passed for `per_example_loss`. Expected a "
            "tensor with at least rank 1.")
    ]):
      yield


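# Editor's note: an illustrative sketch (not part of the original module) of
# the `check_per_example_loss_rank` context manager; the helper name is
# hypothetical and it is never called by library code.
def _example_check_per_example_loss_rank():
  per_example_loss = constant_op.constant([0.1, 0.2, 0.3])  # rank 1, so it passes
  with check_per_example_loss_rank(per_example_loss):
    return math_ops.reduce_mean(per_example_loss)
  # A rank-0 loss such as `constant_op.constant(0.5)` would instead raise a
  # ValueError (static rank) or attach an assertion op (dynamic rank).
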

@tf_export(v1=["losses.add_loss"])
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Adds an externally defined loss to the collection of losses.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Optional collection to add the loss to.
  """
  # Since we have no way of figuring out when a training iteration starts or
  # ends, holding on to a loss when executing eagerly is indistinguishable from
  # leaking memory. We instead leave the collection empty.
  if loss_collection and not context.executing_eagerly():
    ops.add_to_collection(loss_collection, loss)



@tf_export(v1=["losses.get_losses"])
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
  """Gets the list of losses from the loss_collection.

  Args:
    scope: An optional scope name for filtering the losses to return.
    loss_collection: Optional losses collection.

  Returns:
    A list of loss tensors.
  """
  return ops.get_collection(loss_collection, scope)


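# Editor's note: an illustrative sketch (not part of the original module) of
# `add_loss` and `get_losses`. The collection mechanism is a no-op under eager
# execution, so the sketch builds an explicit graph; the helper name is
# hypothetical and it is never called by library code.
def _example_add_and_get_losses():
  with ops.Graph().as_default():
    loss = constant_op.constant(1.5, name="my_loss")
    add_loss(loss)       # appended to the ops.GraphKeys.LOSSES collection
    return get_losses()  # -> [loss]
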

@tf_export(v1=["losses.get_regularization_losses"])
def get_regularization_losses(scope=None):
  """Gets the list of regularization losses.

  Args:
    scope: An optional scope name for filtering the losses to return.

  Returns:
    A list of regularization losses as Tensors.
  """
  return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)



@tf_export(v1=["losses.get_regularization_loss"])
def get_regularization_loss(scope=None, name="total_regularization_loss"):
  """Gets the total regularization loss.

  Args:
    scope: An optional scope name for filtering the losses to return.
    name: The name of the returned tensor.

  Returns:
    A scalar regularization loss.
  """
  losses = get_regularization_losses(scope)
  if losses:
    return math_ops.add_n(losses, name=name)
  else:
    return constant_op.constant(0.0)



@tf_export(v1=["losses.get_total_loss"])
def get_total_loss(add_regularization_losses=True,
                   name="total_loss",
                   scope=None):
  """Returns a tensor whose value represents the total loss.

  In particular, this adds any losses you have added with `tf.add_loss()` to
  any regularization losses that have been added by regularization parameters
  on layers' constructors, e.g. `tf.layers`. Be very sure to use this if you
  are constructing a loss_op manually. Otherwise regularization arguments
  on `tf.layers` methods will not function.

  Args:
    add_regularization_losses: A boolean indicating whether or not to use the
      regularization losses in the sum.
    name: The name of the returned tensor.
    scope: An optional scope name for filtering the losses to return. Note that
      this filters the losses added with `tf.add_loss()` as well as the
      regularization losses to that scope.

  Returns:
    A `Tensor` whose value represents the total loss.

  Raises:
    ValueError: if `losses` is not iterable.
  """
  losses = get_losses(scope=scope)
  if add_regularization_losses:
    losses += get_regularization_losses(scope=scope)
  return math_ops.add_n(losses, name=name)
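

# Editor's note: an illustrative sketch (not part of the original module)
# combining `add_loss`, the regularization-loss collection and
# `get_total_loss` in an explicit graph; the helper name is hypothetical and
# it is never called by library code.
def _example_get_total_loss():
  with ops.Graph().as_default():
    add_loss(constant_op.constant(2.0))
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                          constant_op.constant(0.5))
    # Sums the LOSSES and REGULARIZATION_LOSSES collections: 2.0 + 0.5 = 2.5.
    return get_total_loss(add_regularization_losses=True)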