Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/audio_dataset.py: 15%

91 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras audio dataset loading utilities.""" 

16 

17import numpy as np 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src.utils import dataset_utils 

21 

22# isort: off 

23from tensorflow.python.util.tf_export import keras_export 

24 

25 

# `tensorflow_io` is only required when `sampling_rate` is passed; it is
# imported lazily inside `audio_dataset_from_directory` and cached in this
# module-level slot so the import happens at most once.
tfio = None  # Import as-needed.

# Only WAV files are supported (decoding goes through `tf.audio.decode_wav`).
ALLOWED_FORMATS = (".wav",)

29 

30 

@keras_export("keras.utils.audio_dataset_from_directory", v1=[])
def audio_dataset_from_directory(
    directory,
    labels="inferred",
    label_mode="int",
    class_names=None,
    batch_size=32,
    sampling_rate=None,
    output_sequence_length=None,
    ragged=False,
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    follow_links=False,
):
    """Generates a `tf.data.Dataset` from audio files in a directory.

    If your directory structure is:

    ```
    main_directory/
    ...class_a/
    ......a_audio_1.wav
    ......a_audio_2.wav
    ...class_b/
    ......b_audio_1.wav
    ......b_audio_2.wav
    ```

    Then calling `audio_dataset_from_directory(main_directory,
    labels='inferred')`
    will return a `tf.data.Dataset` that yields batches of audio files from
    the subdirectories `class_a` and `class_b`, together with labels
    0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).

    Only `.wav` files are supported at this time.

    Args:
      directory: Directory where the data is located. If `labels` is
        "inferred", it should contain subdirectories, each containing audio
        files for a class. Otherwise, the directory structure is ignored.
      labels: Either "inferred" (labels are generated from the directory
        structure), None (no labels), or a list/tuple of integer labels of the
        same size as the number of audio files found in the directory. Labels
        should be sorted according to the alphanumeric order of the audio file
        paths (obtained via `os.walk(directory)` in Python).
      label_mode: String describing the encoding of `labels`. Options are:
        - 'int': means that the labels are encoded as integers (e.g. for
          `sparse_categorical_crossentropy` loss). - 'categorical' means that
          the labels are encoded as a categorical vector (e.g. for
          `categorical_crossentropy` loss). - 'binary' means that the labels
          (there can be only 2) are encoded as `float32` scalars with values 0
          or 1 (e.g. for `binary_crossentropy`). - None (no labels).
      class_names: Only valid if "labels" is "inferred". This is the explicit
        list of class names (must match names of subdirectories). Used to
        control the order of the classes (otherwise alphanumerical order is
        used).
      batch_size: Size of the batches of data. Default: 32. If `None`, the
        data will not be batched (the dataset will yield individual samples).
      sampling_rate: Audio sampling rate (in samples per second).
      output_sequence_length: Maximum length of an audio sequence. Audio files
        longer than this will be truncated to `output_sequence_length`. If set
        to `None`, then all sequences in the same batch will be padded to the
        length of the longest sequence in the batch.
      ragged: Whether to return a Ragged dataset (where each sequence has its
        own length). Default: False.
      shuffle: Whether to shuffle the data. Default: True. If set to False,
        sorts the data in alphanumeric order.
      seed: Optional random seed for shuffling and transformations.
      validation_split: Optional float between 0 and 1, fraction of data to
        reserve for validation.
      subset: Subset of the data to return. One of "training", "validation" or
        "both". Only used if `validation_split` is set.
      follow_links: Whether to visits subdirectories pointed to by symlinks.
        Defaults to False.

    Returns:
      A `tf.data.Dataset` object.
      - If `label_mode` is None, it yields `string` tensors of shape
        `(batch_size,)`, containing the contents of a batch of audio files.
      - Otherwise, it yields a tuple `(audio, labels)`, where `audio`
        has shape `(batch_size, sequence_length, num_channels)` and `labels`
        follows the format described
        below.

    Rules regarding labels format:
      - if `label_mode` is `int`, the labels are an `int32` tensor of shape
        `(batch_size,)`.
      - if `label_mode` is `binary`, the labels are a `float32` tensor of
        1s and 0s of shape `(batch_size, 1)`.
      - if `label_mode` is `categorical`, the labels are a `float32` tensor
        of shape `(batch_size, num_classes)`, representing a one-hot
        encoding of the class index.
    """
    # --- Argument validation (fail fast, before touching the filesystem). ---
    if labels not in ("inferred", None):
        # Explicit labels must be a list/tuple aligned with the file listing.
        if not isinstance(labels, (list, tuple)):
            raise ValueError(
                "The `labels` argument should be a list/tuple of integer "
                "labels, of the same size as the number of audio files in "
                "the target directory. If you wish to infer the labels from "
                "the subdirectory names in the target directory,"
                ' pass `labels="inferred"`. '
                "If you wish to get a dataset that only contains audio samples "
                f"(no labels), pass `labels=None`. Received: labels={labels}"
            )
        # `class_names` only makes sense when labels are inferred from
        # subdirectory names.
        if class_names:
            raise ValueError(
                "You can only pass `class_names` if "
                f'`labels="inferred"`. Received: labels={labels}, and '
                f"class_names={class_names}"
            )
    if label_mode not in {"int", "categorical", "binary", None}:
        raise ValueError(
            '`label_mode` argument must be one of "int", "categorical", '
            '"binary", '
            f"or None. Received: label_mode={label_mode}"
        )

    # A ragged dataset keeps per-sample lengths, while
    # `output_sequence_length` forces a fixed length — mutually exclusive.
    if ragged and output_sequence_length is not None:
        raise ValueError(
            "Cannot set both `ragged` and `output_sequence_length`"
        )

    if sampling_rate is not None:
        if not isinstance(sampling_rate, int):
            raise ValueError(
                "`sampling_rate` should have an integer value. "
                f"Received: sampling_rate={sampling_rate}"
            )

        if sampling_rate <= 0:
            raise ValueError(
                "`sampling_rate` should be higher than 0. "
                f"Received: sampling_rate={sampling_rate}"
            )

        # Resampling requires `tensorflow_io`; import it lazily and cache it
        # in the module-level `tfio` slot so this runs at most once.
        global tfio
        if tfio is None:
            try:
                import tensorflow_io as tfio
            except ImportError:
                raise ImportError(
                    "To use the argument `sampling_rate`, you should install "
                    "tensorflow_io. You can install it via `pip install "
                    "tensorflow-io`."
                )

    # Either side being None disables labels entirely; normalize both.
    if labels is None or label_mode is None:
        labels = None
        label_mode = None

    dataset_utils.check_validation_split_arg(
        validation_split, subset, shuffle, seed
    )

    # A concrete seed is needed so that the "training" and "validation"
    # splits of the same directory are consistent across calls.
    if seed is None:
        seed = np.random.randint(1e6)

    file_paths, labels, class_names = dataset_utils.index_directory(
        directory,
        labels,
        formats=ALLOWED_FORMATS,
        class_names=class_names,
        shuffle=shuffle,
        seed=seed,
        follow_links=follow_links,
    )

    if label_mode == "binary" and len(class_names) != 2:
        raise ValueError(
            'When passing `label_mode="binary"`, there must be exactly 2 '
            f"class_names. Received: class_names={class_names}"
        )

    if subset == "both":
        # "both" returns a (train, val) pair of datasets.
        train_dataset, val_dataset = get_training_and_validation_dataset(
            file_paths=file_paths,
            labels=labels,
            validation_split=validation_split,
            directory=directory,
            label_mode=label_mode,
            class_names=class_names,
            sampling_rate=sampling_rate,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )

        train_dataset = prepare_dataset(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            class_names=class_names,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )
        # The validation set is never shuffled.
        val_dataset = prepare_dataset(
            dataset=val_dataset,
            batch_size=batch_size,
            shuffle=False,
            seed=seed,
            class_names=class_names,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )
        return train_dataset, val_dataset

    else:
        dataset = get_dataset(
            file_paths=file_paths,
            labels=labels,
            directory=directory,
            validation_split=validation_split,
            subset=subset,
            label_mode=label_mode,
            class_names=class_names,
            sampling_rate=sampling_rate,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )

        dataset = prepare_dataset(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            class_names=class_names,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )
        return dataset

263 

264 

def prepare_dataset(
    dataset,
    batch_size,
    shuffle,
    seed,
    class_names,
    output_sequence_length,
    ragged,
):
    """Shuffles, batches and prefetches a raw audio dataset.

    Args:
      dataset: A `tf.data.Dataset` yielding either audio tensors or
        `(audio, label)` tuples, as produced by
        `paths_and_labels_to_dataset`.
      batch_size: Batch size, or `None` to yield individual samples.
      shuffle: Whether to shuffle the dataset.
      seed: Random seed used for shuffling.
      class_names: List of class names attached to the returned dataset so
        users can reference `dataset.class_names`.
      output_sequence_length: Fixed sequence length, or `None` for
        variable-length clips.
      ragged: Whether elements are ragged tensors.

    Returns:
      The prepared `tf.data.Dataset`, with a `class_names` attribute set.
    """
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    if batch_size is not None:
        if shuffle:
            # A buffer of 8 batches gives reasonable shuffling without
            # keeping too many decoded clips in memory at once.
            dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)

        if output_sequence_length is None and not ragged:
            # Clips have variable length, so pad each batch up to the
            # longest clip in it. Let `padded_batch` derive the padded
            # shapes from the element structure: the previous hard-coded
            # `padded_shapes=([None, None], [])` assumed elements are
            # `(audio, scalar_label)` tuples, which broke when
            # `label_mode` is None (no label component) or when labels
            # are non-scalar (e.g. one-hot "categorical" labels).
            dataset = dataset.padded_batch(batch_size)
        else:
            # Fixed-length or ragged elements batch without padding.
            dataset = dataset.batch(batch_size)
    else:
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1024, seed=seed)

    # Users may need to reference `class_names`.
    dataset.class_names = class_names
    return dataset

292 

293 

def get_training_and_validation_dataset(
    file_paths,
    labels,
    validation_split,
    directory,
    label_mode,
    class_names,
    sampling_rate,
    output_sequence_length,
    ragged,
):
    """Builds the (train, validation) dataset pair for `subset="both"`.

    Splits `file_paths`/`labels` according to `validation_split`, validates
    that neither split is empty, and converts each split into a
    `tf.data.Dataset`.
    """
    # Compute and validate both splits up front so an empty split raises
    # before any dataset is constructed.
    split_paths = {}
    split_labels = {}
    for split in ("training", "validation"):
        paths, lbls = dataset_utils.get_training_or_validation_split(
            file_paths, labels, validation_split, split
        )
        if not paths:
            raise ValueError(
                f"No {split} audio files found in directory {directory}. "
                f"Allowed format(s): {ALLOWED_FORMATS}"
            )
        split_paths[split] = paths
        split_labels[split] = lbls

    def _build(split):
        # Turn one split's paths/labels into a dataset.
        return paths_and_labels_to_dataset(
            file_paths=split_paths[split],
            labels=split_labels[split],
            label_mode=label_mode,
            num_classes=len(class_names),
            sampling_rate=sampling_rate,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )

    return _build("training"), _build("validation")

347 

348 

def get_dataset(
    file_paths,
    labels,
    directory,
    validation_split,
    subset,
    label_mode,
    class_names,
    sampling_rate,
    output_sequence_length,
    ragged,
):
    """Builds a single audio dataset for the requested `subset`.

    Applies the validation split (a no-op when `validation_split` is not
    set), verifies the resulting file list is non-empty, and converts it
    into a `tf.data.Dataset`.
    """
    file_paths, labels = dataset_utils.get_training_or_validation_split(
        file_paths, labels, validation_split, subset
    )
    # Guard clause: an empty split means nothing matched ALLOWED_FORMATS.
    if not file_paths:
        raise ValueError(
            f"No audio files found in directory {directory}. "
            f"Allowed format(s): {ALLOWED_FORMATS}"
        )

    return paths_and_labels_to_dataset(
        file_paths=file_paths,
        labels=labels,
        label_mode=label_mode,
        num_classes=len(class_names),
        sampling_rate=sampling_rate,
        output_sequence_length=output_sequence_length,
        ragged=ragged,
    )

381 

382 

def read_and_decode_audio(
    path, sampling_rate=None, output_sequence_length=None
):
    """Reads a WAV file from `path` and decodes it to an audio tensor.

    When `output_sequence_length` is given, the decoded clip is truncated
    (or padded by the decoder) to that many samples. When `sampling_rate`
    is given, the clip is resampled to it via `tensorflow_io`.
    """
    # `decode_wav` interprets -1 as "decode the full clip".
    desired_samples = (
        -1 if output_sequence_length is None else output_sequence_length
    )

    contents = tf.io.read_file(path)
    audio, file_rate = tf.audio.decode_wav(
        contents=contents, desired_samples=desired_samples
    )

    if sampling_rate is not None:
        # The resampler expects the input rate as int64.
        file_rate = tf.cast(file_rate, tf.int64)
        audio = tfio.audio.resample(
            input=audio, rate_in=file_rate, rate_out=sampling_rate
        )
    return audio

402 

403 

def paths_and_labels_to_dataset(
    file_paths,
    labels,
    label_mode,
    num_classes,
    sampling_rate,
    output_sequence_length,
    ragged,
):
    """Constructs a fixed-size dataset of audio and labels."""
    # Decode each path into an audio tensor, in parallel.
    audio_ds = tf.data.Dataset.from_tensor_slices(file_paths).map(
        lambda p: read_and_decode_audio(
            p, sampling_rate, output_sequence_length
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    if ragged:
        # Preserve per-clip lengths by converting to ragged tensors.
        audio_ds = audio_ds.map(
            tf.RaggedTensor.from_tensor,
            num_parallel_calls=tf.data.AUTOTUNE,
        )

    if not label_mode:
        return audio_ds

    # Pair each clip with its encoded label.
    label_ds = dataset_utils.labels_to_dataset(
        labels, label_mode, num_classes
    )
    return tf.data.Dataset.zip((audio_ds, label_ds))

434