Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/input_spec.py: 15%

116 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16 

17"""Contains the InputSpec class.""" 

18 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src import backend 

22 

23# isort: off 

24from tensorflow.python.util.tf_export import keras_export 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

@keras_export(
    "keras.layers.InputSpec",
    v1=["keras.layers.InputSpec", "keras.__internal__.legacy.layers.InputSpec"],
)
@tf_export(v1=["layers.InputSpec"])
class InputSpec:
    """Specifies the rank, dtype and shape of every input to a layer.

    Layers can expose (if appropriate) an `input_spec` attribute:
    an instance of `InputSpec`, or a nested structure of `InputSpec` instances
    (one per input tensor). These objects enable the layer to run input
    compatibility checks for input structure, input rank, input shape, and
    input dtype.

    A None entry in a shape is compatible with any dimension,
    a None shape is compatible with any shape.

    Args:
      dtype: Expected DataType of the input.
      shape: Shape tuple, expected shape of the input
        (may include None for unchecked axes). Includes the batch size.
      ndim: Integer, expected rank of the input. Ignored when `shape` has a
        known rank, in which case the rank of `shape` is used instead.
      max_ndim: Integer, maximum rank of the input.
      min_ndim: Integer, minimum rank of the input.
      axes: Dictionary mapping integer axes to
        a specific dimension value.
      allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
        as the last axis of the input is 1, as well as inputs of rank N-1
        as long as the last axis of the spec is 1.
      name: Expected key corresponding to this input when passing data as
        a dictionary.

    Raises:
      TypeError: if the keys of `axes` cannot be converted to integers.
      ValueError: if the largest key of `axes` exceeds the highest axis index
        allowed by `ndim` (or by `max_ndim` when `ndim` is unset).

    Example:

    ```python
    class MyLayer(Layer):
        def __init__(self):
            super(MyLayer, self).__init__()
            # The layer will accept inputs with
            # shape (?, 28, 28) & (?, 28, 28, 1)
            # and raise an appropriate error message otherwise.
            self.input_spec = InputSpec(
                shape=(None, 28, 28, 1),
                allow_last_axis_squeeze=True)
    ```
    """

    def __init__(
        self,
        dtype=None,
        shape=None,
        ndim=None,
        max_ndim=None,
        min_ndim=None,
        axes=None,
        allow_last_axis_squeeze=False,
        name=None,
    ):
        # The dtype is stored by name (a string) rather than as a DType object.
        self.dtype = tf.as_dtype(dtype).name if dtype is not None else None

        # Normalize `shape` through TensorShape; a shape of unknown rank is
        # stored as None, otherwise as a plain tuple.
        shape = tf.TensorShape(shape)
        if shape.rank is None:
            self.shape = None
            self.ndim = ndim
        else:
            # A known shape takes precedence over the `ndim` argument.
            self.shape = tuple(shape.as_list())
            self.ndim = len(self.shape)

        self.max_ndim = max_ndim
        self.min_ndim = min_ndim
        self.name = name
        self.allow_last_axis_squeeze = allow_last_axis_squeeze
        try:
            axes = axes or {}
            self.axes = {int(k): axes[k] for k in axes}
        except (ValueError, TypeError):
            raise TypeError(
                "Argument `axes` must be a dict with integer keys. "
                f"Received: axes={axes}"
            )

        if self.axes and (self.ndim is not None or self.max_ndim is not None):
            # Compare against None explicitly: a rank of 0 is a legitimate
            # value and must not fall through to `max_ndim` (which may be
            # None and would make the subtraction below fail).
            rank_limit = self.ndim if self.ndim is not None else self.max_ndim
            max_dim = rank_limit - 1
            max_axis = max(self.axes)
            if max_axis > max_dim:
                raise ValueError(
                    "Axis {} is greater than the maximum "
                    "allowed value: {}".format(max_axis, max_dim)
                )

    def __repr__(self):
        """Returns `InputSpec(...)` listing only the fields that are set."""
        # Use `is not None` for shape and the rank fields so that a scalar
        # spec (shape=(), ndim=0) is not rendered as if nothing was set.
        spec = [
            ("dtype=" + str(self.dtype)) if self.dtype else "",
            ("shape=" + str(self.shape)) if self.shape is not None else "",
            ("ndim=" + str(self.ndim)) if self.ndim is not None else "",
            ("max_ndim=" + str(self.max_ndim))
            if self.max_ndim is not None
            else "",
            ("min_ndim=" + str(self.min_ndim))
            if self.min_ndim is not None
            else "",
            ("axes=" + str(self.axes)) if self.axes else "",
        ]
        return f"InputSpec({', '.join(x for x in spec if x)})"

    def get_config(self):
        """Returns a dict of constructor arguments describing this spec.

        NOTE: `allow_last_axis_squeeze` and `name` are not included in the
        returned dict, so `from_config(get_config())` resets them to their
        constructor defaults.
        """
        return {
            "dtype": self.dtype,
            "shape": self.shape,
            "ndim": self.ndim,
            "max_ndim": self.max_ndim,
            "min_ndim": self.min_ndim,
            "axes": self.axes,
        }

    @classmethod
    def from_config(cls, config):
        """Builds an instance by passing `config` as keyword arguments."""
        return cls(**config)

144 

145 

def to_tensor_shape(spec):
    """Builds the `tf.TensorShape` described by an InputSpec.

    The result depends on which fields of `spec` are set:

    - `spec.shape` set: that shape is wrapped as-is.
    - only `spec.ndim` set: a shape of that rank whose dimensions are all
      unknown, except for the axes pinned in `spec.axes`.
    - neither set: a TensorShape of unknown rank.

    Args:
      spec: an InputSpec object.

    Returns:
      a tf.TensorShape object
    """
    if spec.shape is not None:
        return tf.TensorShape(spec.shape)

    if spec.ndim is None:
        return tf.TensorShape(None)

    # Only the rank is known: start from all-unknown dimensions and fill in
    # the sizes that `axes` pins down (`axes` is assumed to be a dict).
    dims = [None] * spec.ndim
    for axis, size in spec.axes.items():
        dims[axis] = size
    return tf.TensorShape(dims)

167 

168 

def assert_input_compatibility(input_spec, inputs, layer_name):
    """Checks compatibility between the layer and provided inputs.

    This checks that the tensor(s) `inputs` verify the input assumptions
    of a layer (if any). If not, a clear and actionable exception gets raised.

    Args:
      input_spec: An InputSpec instance, list of InputSpec instances, a nested
        structure of InputSpec instances, or None.
      inputs: Input tensor, list of input tensors, or a nested structure of
        input tensors.
      layer_name: String, name of the layer (for error message formatting).

    Raises:
      ValueError: in case of mismatch between
        the provided inputs and the expectations of the layer.
      TypeError: if one of the inputs has no `shape` attribute.
    """
    if not input_spec:
        return

    input_spec = tf.nest.flatten(input_spec)
    if isinstance(inputs, dict):
        # Flatten `inputs` by reference order if input spec names are provided.
        # A `None` spec carries no name; it is tolerated here just as it is
        # in the per-input loop below.
        names = [
            spec.name if spec is not None else None for spec in input_spec
        ]
        if all(names):
            list_inputs = []
            for name in names:
                if name not in inputs:
                    raise ValueError(
                        f'Missing data for input "{name}". '
                        "You passed a data dictionary with keys "
                        f"{list(inputs.keys())}. "
                        f"Expected the following keys: {names}"
                    )
                list_inputs.append(inputs[name])
            inputs = list_inputs

    inputs = tf.nest.flatten(inputs)
    for x in inputs:
        # Having a shape/dtype is the only commonality of the various
        # tensor-like objects that may be passed. The most common kind of
        # invalid type we are guarding for is a Layer instance (Functional API),
        # which does not have a `shape` attribute.
        if not hasattr(x, "shape"):
            raise TypeError(
                f"Inputs to a layer should be tensors. Got '{x}' "
                f"(of type {type(x)}) as input for layer '{layer_name}'."
            )

    if len(inputs) != len(input_spec):
        raise ValueError(
            f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
            f" but it received {len(inputs)} input tensors. "
            f"Inputs received: {inputs}"
        )
    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
        if spec is None:
            continue

        # Normalize once; every check below uses `shape` rather than
        # `x.shape`, which is not guaranteed to be a TensorShape.
        shape = tf.TensorShape(x.shape)
        ndim = shape.rank
        if ndim is None:
            # Nothing can be verified for an input of unknown rank. Skip only
            # this input so that the remaining inputs are still validated.
            continue
        # Check ndim.
        if spec.ndim is not None and not spec.allow_last_axis_squeeze:
            if ndim != spec.ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected ndim={spec.ndim}, found ndim={ndim}. "
                    f"Full shape received: {tuple(shape)}"
                )
        if spec.max_ndim is not None and ndim > spec.max_ndim:
            raise ValueError(
                f'Input {input_index} of layer "{layer_name}" '
                "is incompatible with the layer: "
                f"expected max_ndim={spec.max_ndim}, "
                f"found ndim={ndim}"
            )
        if spec.min_ndim is not None and ndim < spec.min_ndim:
            raise ValueError(
                f'Input {input_index} of layer "{layer_name}" '
                "is incompatible with the layer: "
                f"expected min_ndim={spec.min_ndim}, "
                f"found ndim={ndim}. "
                f"Full shape received: {tuple(shape)}"
            )
        # Check dtype.
        if spec.dtype is not None:
            if x.dtype.name != spec.dtype:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected dtype={spec.dtype}, "
                    f"found dtype={x.dtype}"
                )

        # Check specific shape axes.
        shape_as_list = shape.as_list()
        if spec.axes:
            for axis, value in spec.axes.items():
                # Unwrap objects exposing a `.value` attribute
                # (presumably legacy Dimension objects; verify against callers).
                if hasattr(value, "value"):
                    value = value.value
                # An unknown (None) input dimension is compatible with any
                # expected value.
                if value is not None and shape_as_list[int(axis)] not in {
                    value,
                    None,
                }:
                    raise ValueError(
                        f'Input {input_index} of layer "{layer_name}" is '
                        f"incompatible with the layer: expected axis {axis} "
                        f"of input shape to have value {value}, "
                        "but received input with "
                        f"shape {display_shape(shape)}"
                    )
        # Check shape.
        if spec.shape is not None:
            spec_shape = spec.shape
            if spec.allow_last_axis_squeeze:
                # Drop a trailing dimension of size 1 on either side before
                # comparing the remaining dimensions pairwise.
                if shape_as_list and shape_as_list[-1] == 1:
                    shape_as_list = shape_as_list[:-1]
                if spec_shape and spec_shape[-1] == 1:
                    spec_shape = spec_shape[:-1]
            for spec_dim, dim in zip(spec_shape, shape_as_list):
                if spec_dim is not None and dim is not None:
                    if spec_dim != dim:
                        raise ValueError(
                            f'Input {input_index} of layer "{layer_name}" is '
                            "incompatible with the layer: "
                            f"expected shape={spec.shape}, "
                            f"found shape={display_shape(shape)}"
                        )

304 

305 

def display_shape(shape):
    """Formats a shape as the string form of a plain Python tuple.

    Args:
      shape: an object exposing `as_list()` (e.g. a `tf.TensorShape` of
        known rank).

    Returns:
      A string such as `"(None, 28, 28)"`.
    """
    dims = tuple(shape.as_list())
    return str(dims)

308 

309 

def to_tensor_spec(input_spec, default_dtype=None):
    """Converts a Keras InputSpec object to a TensorSpec.

    Args:
      input_spec: an InputSpec instance. Any other value produces a
        TensorSpec of unknown shape.
      default_dtype: dtype used when the spec does not define one. When
        falsy, `backend.floatx()` is used.

    Returns:
      a tf.TensorSpec object
    """
    fallback_dtype = default_dtype or backend.floatx()

    if not isinstance(input_spec, InputSpec):
        return tf.TensorSpec(None, fallback_dtype)

    resolved_dtype = input_spec.dtype or fallback_dtype
    resolved_shape = to_tensor_shape(input_spec)
    return tf.TensorSpec(resolved_shape, resolved_dtype)

317