Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_utils.py: 20%

82 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Training-related utilities.""" 

16 

17import numpy as np 

18 

19from tensorflow.python.framework import tensor_shape 

20from tensorflow.python.framework import tensor_util 

21from tensorflow.python.keras.utils import generic_utils 

22from tensorflow.python.ops import array_ops 

23from tensorflow.python.util import nest 

24 

25 

def slice_arrays(arrays, indices, contiguous=True):
  """Extracts one batch from a single array or from a list of arrays.

  Eager tensors slice like symbolic TF tensors rather than like Numpy arrays,
  so `generic_utils.slice_arrays` cannot be applied to them directly. When any
  input is a TF value the batch is instead built from basic slices (and, for
  non-contiguous indices, a `concat` of single-row slices), which has a
  performance cost.

  Args:
    arrays: Single array or list of arrays.
    indices: List of indices in the array that should be included in the
      output batch.
    contiguous: Boolean flag indicating whether the indices are contiguous.
      When True, only the first and last index are used to form the slice.

  Returns:
    Slice of data: a single array if a non-list was passed in, otherwise a
    list of arrays.
  """
  was_single = not isinstance(arrays, list)
  if was_single:
    arrays = [arrays]

  has_tf_value = any(tensor_util.is_tf_type(item) for item in arrays)
  if not has_tf_value:
    # Pure Numpy-style inputs: the generic helper handles these.
    batch = generic_utils.slice_arrays(arrays, indices)
  elif contiguous:
    # A contiguous run is one basic slice from the first to the last index.
    batch = [item[indices[0]:indices[-1] + 1] for item in arrays]
  else:
    # Gather one-row slices for every array first, then stitch them together.
    row_groups = [[item[i:i + 1] for i in indices] for item in arrays]
    batch = [array_ops.concat(rows, axis=0) for rows in row_groups]

  # Undo the list wrapping applied to a single input.
  return batch[0] if was_single else batch

60 

61 

def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes,
                                  check_all_flat=False):
  """Fills in all-ones sample weights for outputs that have no weight.

  Args:
    outputs: List of model outputs.
    sample_weights: List of sample weight inputs, or None.
    sample_weight_modes: List of sample weight modes or None. An entry equal
      to 'temporal' makes the generated weight span the first two dimensions
      of the matching output instead of only the first.
    check_all_flat: Ensure that inputs are not nested structures. This is not
      a free check, so we may not want to run it eagerly every iteration.

  Returns:
    A 3-tuple `(weights, any_sample_weight, partial_sample_weight)`:
      * `weights` is None when no weight was supplied at all, the original
        `sample_weights` object when every entry is present, and otherwise a
        tuple in which each missing entry is replaced by an array of ones.
      * `any_sample_weight` is True if at least one weight was supplied.
      * `partial_sample_weight` is True if some, but not all, were supplied.
  """
  if sample_weights is None:
    any_sample_weight = False
  else:
    any_sample_weight = any(
        weight is not None for weight in sample_weights)

  if any_sample_weight:
    partial_sample_weight = any(weight is None for weight in sample_weights)
  else:
    partial_sample_weight = False

  # Nothing supplied: there is nothing to fill in.
  if not any_sample_weight:
    return None, any_sample_weight, partial_sample_weight

  # Everything supplied: hand the caller's object straight back.
  if not partial_sample_weight:
    return sample_weights, any_sample_weight, partial_sample_weight

  if check_all_flat:
    # A structure is flat exactly when it matches its own flattening.
    nest.assert_same_structure(
        list_to_tuple(sample_weights),
        list_to_tuple(nest.flatten(sample_weights)))
    nest.assert_same_structure(
        list_to_tuple(outputs),
        list_to_tuple(nest.flatten(outputs)))
    if sample_weight_modes is not None:
      nest.assert_same_structure(
          sample_weight_modes, nest.flatten(sample_weight_modes))

  filled_weights = []
  for index, weight in enumerate(sample_weights):
    if weight is not None:
      filled_weights.append(weight)
      continue

    # Missing weight: build ones with the same backend (Numpy or TF) as the
    # corresponding output.
    output = outputs[index]
    is_ndarray = isinstance(output, np.ndarray)
    dims = output.shape if is_ndarray else array_ops.shape(output)

    temporal = (
        sample_weight_modes is not None and
        sample_weight_modes[index] == 'temporal')
    if temporal:
      ones_shape = (dims[0], dims[1])
    else:
      ones_shape = (dims[0],)

    make_ones = np.ones if is_ndarray else array_ops.ones
    filled_weights.append(make_ones(ones_shape))

  return (list_to_tuple(filled_weights),
          any_sample_weight, partial_sample_weight)

119 

120 

class RespectCompiledTrainableState(object):
  """Scope that applies the compile-time trainable state, then restores it.

  The keras API guarantees that the value of each Layer's `trainable` property
  at `Model.compile` time will be used when training that model. To honor
  that, it may be necessary to set layers to their compile-time trainable
  values before beginning a training endpoint and to restore the previous
  values before returning from it. This scope checks whether any layer's
  trainable state differs from the compiled one and performs that set and
  un-set bookkeeping.

  The trainable state of a layer changes quite infrequently, if ever, for many
  workflows, and updating every layer in a model is an expensive operation.
  The model is therefore only touched when a difference is actually found.
  """

  def __init__(self, model):
    self._model = model
    self._current_trainable_state = None
    self._compiled_trainable_state = None
    self._should_set_trainable = False

  def __enter__(self):
    model = self._model
    # pylint: disable=protected-access
    self._current_trainable_state = model._get_trainable_state()
    self._compiled_trainable_state = model._compiled_trainable_state
    # pylint: enable=protected-access

    # Only layers present in both snapshots are compared.
    current = self._current_trainable_state
    changed_since_compile = any(
        layer in current and compiled_value != current[layer]
        for layer, compiled_value in self._compiled_trainable_state.items())
    if changed_since_compile:
      self._should_set_trainable = True

    # Put the model into its compiled state for the duration of the scope.
    if self._should_set_trainable:
      model._set_trainable_state(self._compiled_trainable_state)  # pylint: disable=protected-access

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # Undo what __enter__ did, if it did anything, by reinstating the state
    # captured on entry.
    if self._should_set_trainable:
      self._model._set_trainable_state(self._current_trainable_state)  # pylint: disable=protected-access
    # A falsy return lets any in-flight exception propagate.
    return False

165 

166 

167# Allow use of methods not exposed to the user. 

168# pylint: disable=protected-access 

def get_input_shape_and_dtype(layer):
  """Looks up the batch input shape and dtype of a layer, if it has them.

  For nested graph models the lookup descends through the first layer of each
  model until it reaches a layer that is not itself a graph model.

  Args:
    layer: Layer (or model) instance.

  Returns:
    Tuple (input_shape, input_dtype). Both are None if the layer reached does
    not have a (truthy) `_batch_input_shape`.

  Raises:
    ValueError: in case an empty Sequential or Functional model is passed.
  """

  def _wraps_layer_graph(candidate):
    # Graph networks flag themselves; Sequential is recognized by class name.
    if getattr(candidate, '_is_graph_network', False):
      return True
    return candidate.__class__.__name__ == 'Sequential'

  # Walk down to the first layer of the deepest nested model. Subclassed
  # Models may not have been built yet, so they are not descended into.
  innermost = layer
  while _wraps_layer_graph(innermost):
    children = innermost.layers
    if not children:
      raise ValueError('An empty Model cannot be used as a Layer.')
    innermost = children[0]

  if not getattr(innermost, '_batch_input_shape', None):
    return None, None
  return innermost._batch_input_shape, innermost.dtype

198 

199 

200# pylint: enable=protected-access 

201 

202 

def get_static_batch_size(layer):
  """Returns the static batch size of a Layer, if one is known.

  Args:
    layer: a `Layer` instance.

  Returns:
    The leading entry of the layer's batch input shape, converted through
    `tensor_shape.Dimension`, or None when the layer has no batch input shape.
  """
  input_shape = get_input_shape_and_dtype(layer)[0]
  if input_shape is None:
    return None
  leading_dim = tensor_shape.Dimension(input_shape[0])
  return leading_dim.value

216 

217 

def list_to_tuple(maybe_list):
  """Converts a list into a tuple; any other value is returned unchanged.

  Datasets will stack a list of tensors, so lists are switched to tuples
  before being handed over. Only the top level is converted; nested lists are
  left as they are.
  """
  return tuple(maybe_list) if isinstance(maybe_list, list) else maybe_list

222 return maybe_list