Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_utils.py: 16%

82 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training-related utilities."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src.utils import generic_utils


def slice_arrays(arrays, indices, contiguous=True):
    """Slices batches out of provided arrays (workaround for eager tensors).

    Unfortunately eager tensors don't have the same slicing behavior as
    Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
    hence we cannot use `generic_utils.slice_arrays` directly
    and we have to implement this workaround based on `concat`. This has a
    performance cost.

    Args:
        arrays: Single array or list of arrays.
        indices: List of indices in the array that should be included in the
            output batch.
        contiguous: Boolean flag indicating whether the indices are contiguous.

    Returns:
        Slice of data (either single array or list of arrays).
    """
    converted_to_list = False
    if not isinstance(arrays, list):
        converted_to_list = True
        arrays = [arrays]
    if any(tf.is_tensor(x) for x in arrays):
        if not contiguous:
            entries = [[x[i : i + 1] for i in indices] for x in arrays]
            slices = [tf.concat(x, axis=0) for x in entries]
        else:
            slices = [x[indices[0] : indices[-1] + 1] for x in arrays]
    else:
        slices = generic_utils.slice_arrays(arrays, indices)

    if converted_to_list:
        slices = slices[0]
    return slices

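# --- Illustrative usage sketch (editorial addition, not part of the original
# module): how `slice_arrays` treats eager tensors versus NumPy arrays. The
# helper name `_example_slice_arrays` and its local variables are
# hypothetical.
def _example_slice_arrays():
    features = tf.constant([[0.0], [1.0], [2.0], [3.0]])  # eager tensor
    labels = np.array([10, 11, 12, 13])  # NumPy array

    # Contiguous indices: a plain slice x[1:4] is taken from each array.
    contiguous_batch = slice_arrays([features, labels], [1, 2, 3])

    # Non-contiguous indices: each requested row of the eager tensor is sliced
    # out individually and re-assembled with `tf.concat`, which is the
    # workaround (and the performance cost) mentioned in the docstring above.
    gathered_batch = slice_arrays([features, labels], [0, 2], contiguous=False)
    return contiguous_batch, gathered_batch
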

def handle_partial_sample_weights(
    outputs, sample_weights, sample_weight_modes, check_all_flat=False
):
    """Adds 1.0 as sample weights for the outputs for which there is no weight.

    Args:
        outputs: List of model outputs.
        sample_weights: List of sample weight inputs.
        sample_weight_modes: List of sample weight modes or None.
        check_all_flat: Ensure that inputs are not nested structures. This is not
            a free check, so we may not want to run it eagerly every iteration.

    Returns:
        Tuple of sample weights, one sample weight for every output, and booleans
        describing the raw sample weights.
    """
    if not isinstance(sample_weights, (list, tuple)):
        any_sample_weight = sample_weights is not None
        partial_sample_weight = any_sample_weight and sample_weights is None
    else:
        any_sample_weight = sample_weights is not None and any(
            w is not None for w in sample_weights
        )
        partial_sample_weight = any_sample_weight and any(
            w is None for w in sample_weights
        )

    if not any_sample_weight:
        return None, any_sample_weight, partial_sample_weight

    if not partial_sample_weight:
        return sample_weights, any_sample_weight, partial_sample_weight

    if check_all_flat:
        tf.nest.assert_same_structure(
            list_to_tuple(sample_weights),
            list_to_tuple(tf.nest.flatten(sample_weights)),
        )
        tf.nest.assert_same_structure(
            list_to_tuple(outputs), list_to_tuple(tf.nest.flatten(outputs))
        )
        if sample_weight_modes is not None:
            tf.nest.assert_same_structure(
                sample_weight_modes, tf.nest.flatten(sample_weight_modes)
            )

    new_sample_weights = []
    for i, sw in enumerate(sample_weights):
        if sw is None:
            as_numpy = isinstance(outputs[i], np.ndarray)
            output = outputs[i]
            output_shape = output.shape if as_numpy else tf.shape(output)

            is_temporal = (
                sample_weight_modes is not None
                and sample_weight_modes[i] == "temporal"
            )
            sw_shape = (
                (output_shape[0], output_shape[1])
                if is_temporal
                else (output_shape[0],)
            )

            new_sample_weights.append(
                np.ones(sw_shape) if as_numpy else tf.ones(sw_shape)
            )

        else:
            new_sample_weights.append(sw)
    return (
        list_to_tuple(new_sample_weights),
        any_sample_weight,
        partial_sample_weight,
    )

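# --- Illustrative sketch (editorial addition, not part of the original
# module): filling in missing sample weights for a hypothetical two-output
# model where only the first output has an explicit weight.
def _example_handle_partial_sample_weights():
    outputs = [np.zeros((4, 1)), np.zeros((4, 1))]
    sample_weights = [np.array([0.5, 1.0, 1.0, 2.0]), None]
    new_weights, any_sw, partial_sw = handle_partial_sample_weights(
        outputs, sample_weights, sample_weight_modes=None
    )
    # Expected: the second entry becomes all-ones weights of shape
    # (batch_size,), i.e. np.ones((4,)); any_sw and partial_sw are both True.
    return new_weights, any_sw, partial_sw
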

class RespectCompiledTrainableState:
    """Set and restore trainable state if it has changed since compile.

    The keras API guarantees that the value of each Layer's `trainable` property
    at `Model.compile` time will be used when training that model. In order to
    respect this requirement, it may be necessary to set the trainable value of
    layers to their compile time values before beginning a training endpoint and
    restore the values before returning from said endpoint. This scope checks if
    any layer's trainable state has changed since Model compile, and performs
    this set and un-set bookkeeping.

    However, the trainable state of a layer changes quite infrequently, if ever,
    for many kinds of workflows. Moreover, updating every layer in a model is an
    expensive operation. As a result, we will only explicitly set and unset the
    trainable state of a model if a trainable value has changed since compile.
    """

    def __init__(self, model):
        self._model = model
        self._current_trainable_state = None
        self._compiled_trainable_state = None
        self._should_set_trainable = False

    def __enter__(self):
        self._current_trainable_state = self._model._get_trainable_state()
        self._compiled_trainable_state = self._model._compiled_trainable_state

        # Check to see if any layer's trainable state has changed since
        # `compile`.
        for layer, trainable in self._compiled_trainable_state.items():
            if (
                layer in self._current_trainable_state
                and trainable != self._current_trainable_state[layer]
            ):
                self._should_set_trainable = True
                break

        # If so, restore the model to its compiled state.
        if self._should_set_trainable:
            self._model._set_trainable_state(self._compiled_trainable_state)

    def __exit__(self, type_arg, value_arg, traceback_arg):
        # If we set the values to their compiled state in __enter__, we need to
        # restore the original values before leaving the scope.
        if self._should_set_trainable:
            self._model._set_trainable_state(self._current_trainable_state)
        return False  # False values do not suppress exceptions

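# --- Illustrative sketch (editorial addition, not part of the original
# module): the scope only rewrites layer state when a `trainable` flag was
# flipped after `compile`. `model`, `x`, and `y` are hypothetical: a compiled
# `tf.keras` model plus one batch of inputs and targets.
def _example_respect_compiled_trainable_state(model, x, y):
    # Freezing a layer after `compile` changes the current trainable state.
    model.layers[0].trainable = False
    # Inside the scope the compile-time `trainable` values are restored, so
    # this step trains the model as it was configured at compile time; on
    # exit, the frozen state set above is put back.
    with RespectCompiledTrainableState(model):
        model.train_on_batch(x, y)
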

# Allow use of methods not exposed to the user.


def get_input_shape_and_dtype(layer):
    """Retrieves input shape and input dtype of layer if applicable.

    Args:
        layer: Layer (or model) instance.

    Returns:
        Tuple (input_shape, input_dtype). Both could be None if the layer
        does not have a defined input shape.

    Raises:
        ValueError: in case an empty Sequential or Functional model is passed.
    """

    def _is_graph_model(layer):
        return (
            hasattr(layer, "_is_graph_network") and layer._is_graph_network
        ) or layer.__class__.__name__ == "Sequential"

    # In case of nested models: recover the first layer
    # of the deepest model to infer input shape and dtype.
    # Subclassed Models may not have been built so can't be checked.
    while _is_graph_model(layer):
        if not layer.layers:
            raise ValueError("An empty Model cannot be used as a Layer.")
        layer = layer.layers[0]

    if getattr(layer, "_batch_input_shape", None):
        return layer._batch_input_shape, layer.dtype
    return None, None

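# --- Illustrative sketch (editorial addition, not part of the original
# module): the lookup walks into nested Sequential/Functional models until it
# finds a layer with a defined `_batch_input_shape`. The models below are
# hypothetical.
def _example_get_input_shape_and_dtype():
    inner = tf.keras.Sequential(
        [tf.keras.layers.Dense(4, input_shape=(8,), dtype="float32")]
    )
    outer = tf.keras.Sequential([inner, tf.keras.layers.Dense(1)])
    shape, dtype = get_input_shape_and_dtype(outer)
    # Expected: shape == (None, 8) and dtype == "float32", taken from the
    # first layer of the innermost model.
    return shape, dtype
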

def get_static_batch_size(layer):
    """Gets the static batch size of a Layer.

    Args:
        layer: a `Layer` instance.

    Returns:
        The static batch size of a Layer.
    """
    batch_input_shape, _ = get_input_shape_and_dtype(layer)
    if batch_input_shape is not None:
        return tf.compat.v1.Dimension(batch_input_shape[0]).value
    return None

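# --- Illustrative sketch (editorial addition, not part of the original
# module): the static batch size is only known when a fixed batch dimension
# was declared on the input, e.g. via `batch_input_shape`. The models are
# hypothetical.
def _example_get_static_batch_size():
    fixed = tf.keras.Sequential(
        [tf.keras.layers.Dense(2, batch_input_shape=(32, 8))]
    )
    dynamic = tf.keras.Sequential(
        [tf.keras.layers.Dense(2, input_shape=(8,))]
    )
    # Expected: 32 for the fixed model, None for the dynamic one.
    return get_static_batch_size(fixed), get_static_batch_size(dynamic)
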

def list_to_tuple(maybe_list):
    """Datasets will stack a list of tensors, so switch them to tuples."""
    if isinstance(maybe_list, list):
        return tuple(maybe_list)
    return maybe_list

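# --- Illustrative sketch (editorial addition, not part of the original
# module): why lists are converted to tuples before building a
# `tf.data.Dataset`. A tuple is kept as separate dataset components, whereas
# a list would be interpreted as a single stackable tensor. The names below
# are hypothetical.
def _example_list_to_tuple():
    features = np.zeros((4, 3), dtype="float32")
    labels = np.ones((4, 1), dtype="float32")
    dataset = tf.data.Dataset.from_tensor_slices(
        list_to_tuple([features, labels])
    )
    # Each element is a (features, labels) pair with shapes ((3,), (1,)).
    return dataset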