# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Contains the InputSpec class."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export


@keras_export('keras.layers.InputSpec')
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension;
  a None shape is compatible with any shape.

  Args:
    dtype: Expected DataType of the input.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). Includes the batch size.
    ndim: Integer, expected rank of the input.
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value.
    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
      as the last axis of the input is 1, as well as inputs of rank N-1
      as long as the last axis of the spec is 1.
    name: Expected key corresponding to this input when passing data as
      a dictionary.

  Example:

  ```python
  class MyLayer(Layer):
    def __init__(self):
      super(MyLayer, self).__init__()
      # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
      # and raise an appropriate error message otherwise.
      self.input_spec = InputSpec(
          shape=(None, 28, 28, 1),
          allow_last_axis_squeeze=True)
  ```
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None,
               allow_last_axis_squeeze=False,
               name=None):
    self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
    shape = tensor_shape.TensorShape(shape)
    if shape.rank is None:
      shape = None
    else:
      shape = tuple(shape.as_list())
    if shape is not None:
      self.ndim = len(shape)
      self.shape = shape
    else:
      self.ndim = ndim
      self.shape = None
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.name = name
    self.allow_last_axis_squeeze = allow_last_axis_squeeze
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))
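
  # Illustrative notes (hypothetical values): passing `shape` also fixes
  # `ndim`, and axis keys are validated against the rank, e.g.
  #   InputSpec(shape=(None, 28, 28, 1)).ndim  -> 4
  #   InputSpec(ndim=2, axes={3: 10})          -> raises ValueError (axis 3 > 1)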

  def __repr__(self):
    spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
            ('shape=' + str(self.shape)) if self.shape else '',
            ('ndim=' + str(self.ndim)) if self.ndim else '',
            ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
            ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
            ('axes=' + str(self.axes)) if self.axes else '']
    return 'InputSpec(%s)' % ', '.join(x for x in spec if x)

  def get_config(self):
    return {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes}

  @classmethod
  def from_config(cls, config):
    return cls(**config)
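
  # Illustrative round trip (hypothetical values): `get_config` omits `name`
  # and `allow_last_axis_squeeze`; `from_config` rebuilds the spec from the
  # remaining fields, e.g.
  #   spec = InputSpec(dtype='float32', ndim=4, axes={-1: 3})
  #   InputSpec.from_config(spec.get_config())
  #   -> InputSpec(dtype=float32, ndim=4, axes={-1: 3})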


def to_tensor_shape(spec):
  """Returns a tf.TensorShape object that matches the shape specifications.

  If the InputSpec's shape or ndim is defined, this method will return a fully
  or partially-known shape. Otherwise, the returned TensorShape is completely
  unknown, i.e. TensorShape(None).

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.ndim is None and spec.shape is None:
    return tensor_shape.TensorShape(None)
  elif spec.shape is not None:
    return tensor_shape.TensorShape(spec.shape)
  else:
    shape = [None] * spec.ndim
    for a in spec.axes:
      shape[a] = spec.axes[a]  # Assume that axes is defined.
    return tensor_shape.TensorShape(shape)
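
# Illustrative examples (hypothetical specs):
#   to_tensor_shape(InputSpec(shape=(None, 28, 28)))  -> TensorShape([None, 28, 28])
#   to_tensor_shape(InputSpec(ndim=3, axes={-1: 64}))  -> TensorShape([None, None, 64])
#   to_tensor_shape(InputSpec())                        -> TensorShape(None)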


def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` satisfy the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  Args:
    input_spec: An InputSpec instance, list of InputSpec instances, a nested
      structure of InputSpec instances, or None.
    inputs: Input tensor, list of input tensors, or a nested structure of
      input tensors.
    layer_name: String, name of the layer (for error message formatting).

  Raises:
    ValueError: in case of mismatch between
      the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return

  input_spec = nest.flatten(input_spec)
  if isinstance(inputs, dict):
    # Flatten `inputs` by reference order if input spec names are provided.
    names = [spec.name for spec in input_spec]
    if all(names):
      list_inputs = []
      for name in names:
        if name not in inputs:
          raise ValueError('Missing data for input "%s". '
                           'You passed a data dictionary with keys %s. '
                           'Expected the following keys: %s' %
                           (name, list(inputs.keys()), names))
        list_inputs.append(inputs[name])
      inputs = list_inputs

  inputs = nest.flatten(inputs)
  for x in inputs:
    # Having a shape/dtype is the only commonality of the various tensor-like
    # objects that may be passed. The most common kind of invalid type we are
    # guarding for is a Layer instance (Functional API), which does not
    # have a `shape` attribute.
    if not hasattr(x, 'shape'):
      raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))

  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' input(s), '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    shape = tensor_shape.TensorShape(x.shape)
    if shape.rank is None:
      return
    # Check ndim.
    if spec.ndim is not None and not spec.allow_last_axis_squeeze:
      ndim = shape.rank
      if ndim != spec.ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                         str(ndim) + '. Full shape received: ' +
                         str(tuple(shape)))
    if spec.max_ndim is not None:
      ndim = x.shape.rank
      if ndim is not None and ndim > spec.max_ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected max_ndim=' + str(spec.max_ndim) +
                         ', found ndim=' + str(ndim))
    if spec.min_ndim is not None:
      ndim = x.shape.rank
      if ndim is not None and ndim < spec.min_ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected min_ndim=' + str(spec.min_ndim) +
                         ', found ndim=' + str(ndim) +
                         '. Full shape received: ' +
                         str(tuple(shape)))
    # Check dtype.
    if spec.dtype is not None:
      if x.dtype.name != spec.dtype:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected dtype=' + str(spec.dtype) +
                         ', found dtype=' + str(x.dtype))

    # Check specific shape axes.
    shape_as_list = shape.as_list()
    if spec.axes:
      for axis, value in spec.axes.items():
        if hasattr(value, 'value'):
          value = value.value
        if value is not None and shape_as_list[int(axis)] not in {value, None}:
          raise ValueError(
              'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
              ' incompatible with the layer: expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + display_shape(x.shape))
    # Check shape.
    if spec.shape is not None and shape.rank is not None:
      spec_shape = spec.shape
      if spec.allow_last_axis_squeeze:
        if shape_as_list and shape_as_list[-1] == 1:
          shape_as_list = shape_as_list[:-1]
        if spec_shape and spec_shape[-1] == 1:
          spec_shape = spec_shape[:-1]
      for spec_dim, dim in zip(spec_shape, shape_as_list):
        if spec_dim is not None and dim is not None:
          if spec_dim != dim:
            raise ValueError('Input ' + str(input_index) +
                             ' is incompatible with layer ' + layer_name +
                             ': expected shape=' + str(spec.shape) +
                             ', found shape=' + display_shape(x.shape))
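
# Illustrative usage (hypothetical eager tensors; `tf` here stands for an
# imported TensorFlow module):
#   spec = InputSpec(ndim=2, axes={-1: 3})
#   assert_input_compatibility(spec, tf.zeros([8, 3]), 'dense_1')     # passes
#   assert_input_compatibility(spec, tf.zeros([8, 3, 1]), 'dense_1')  # raises
#   # ValueError: Input 0 of layer dense_1 is incompatible with the layer:
#   # expected ndim=2, found ndim=3. Full shape received: (8, 3, 1)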


def display_shape(shape):
  return str(tuple(shape.as_list()))


def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec."""
  default_dtype = default_dtype or backend.floatx()
  if isinstance(input_spec, InputSpec):
    dtype = input_spec.dtype or default_dtype
    return tensor_spec.TensorSpec(to_tensor_shape(input_spec), dtype)
  return tensor_spec.TensorSpec(None, default_dtype)
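
# Illustrative examples (hypothetical specs; assumes the backend float type is
# float32):
#   to_tensor_spec(InputSpec(shape=(None, 10), dtype='int32'))
#   -> TensorSpec(shape=(None, 10), dtype=tf.int32, name=None)
#   to_tensor_spec(InputSpec(ndim=2))
#   -> TensorSpec(shape=(None, None), dtype=tf.float32, name=None)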