Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/audio_dataset.py: 15%
91 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras audio dataset loading utilities."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src.utils import dataset_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


tfio = None  # Import as-needed.

ALLOWED_FORMATS = (".wav",)


@keras_export("keras.utils.audio_dataset_from_directory", v1=[])
def audio_dataset_from_directory(
    directory,
    labels="inferred",
    label_mode="int",
    class_names=None,
    batch_size=32,
    sampling_rate=None,
    output_sequence_length=None,
    ragged=False,
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    follow_links=False,
):
47 """Generates a `tf.data.Dataset` from audio files in a directory.
49 If your directory structure is:
51 ```
52 main_directory/
53 ...class_a/
54 ......a_audio_1.wav
55 ......a_audio_2.wav
56 ...class_b/
57 ......b_audio_1.wav
58 ......b_audio_2.wav
59 ```
61 Then calling `audio_dataset_from_directory(main_directory,
62 labels='inferred')`
63 will return a `tf.data.Dataset` that yields batches of audio files from
64 the subdirectories `class_a` and `class_b`, together with labels
65 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).
67 Only `.wav` files are supported at this time.
69 Args:
70 directory: Directory where the data is located. If `labels` is "inferred",
71 it should contain subdirectories, each containing audio files for a
72 class. Otherwise, the directory structure is ignored.
73 labels: Either "inferred" (labels are generated from the directory
74 structure), None (no labels), or a list/tuple of integer labels of the
75 same size as the number of audio files found in the directory. Labels
76 should be sorted according to the alphanumeric order of the audio file
77 paths (obtained via `os.walk(directory)` in Python).
78 label_mode: String describing the encoding of `labels`. Options are:
79 - 'int': means that the labels are encoded as integers (e.g. for
80 `sparse_categorical_crossentropy` loss). - 'categorical' means that
81 the labels are encoded as a categorical vector (e.g. for
82 `categorical_crossentropy` loss). - 'binary' means that the labels
83 (there can be only 2) are encoded as `float32` scalars with values 0
84 or 1 (e.g. for `binary_crossentropy`). - None (no labels).
85 class_names: Only valid if "labels" is "inferred". This is the explicit
86 list of class names (must match names of subdirectories). Used to
87 control the order of the classes (otherwise alphanumerical order is
88 used).
89 batch_size: Size of the batches of data. Default: 32. If `None`, the data
90 will not be batched (the dataset will yield individual samples).
91 sampling_rate: Audio sampling rate (in samples per second).
92 output_sequence_length: Maximum length of an audio sequence. Audio files
93 longer than this will be truncated to `output_sequence_length`. If set
94 to `None`, then all sequences in the same batch will be padded to the
95 length of the longest sequence in the batch.
96 ragged: Whether to return a Ragged dataset (where each sequence has its
97 own length). Default: False.
98 shuffle: Whether to shuffle the data. Default: True. If set to False,
99 sorts the data in alphanumeric order.
100 seed: Optional random seed for shuffling and transformations.
101 validation_split: Optional float between 0 and 1, fraction of data to
102 reserve for validation.
103 subset: Subset of the data to return. One of "training", "validation" or
104 "both". Only used if `validation_split` is set.
105 follow_links: Whether to visits subdirectories pointed to by symlinks.
106 Defaults to False.
108 Returns:
109 A `tf.data.Dataset` object.
110 - If `label_mode` is None, it yields `string` tensors of shape
111 `(batch_size,)`, containing the contents of a batch of audio files.
112 - Otherwise, it yields a tuple `(audio, labels)`, where `audio`
113 has shape `(batch_size, sequence_length, num_channels)` and `labels`
114 follows the format described
115 below.
117 Rules regarding labels format:
118 - if `label_mode` is `int`, the labels are an `int32` tensor of shape
119 `(batch_size,)`.
120 - if `label_mode` is `binary`, the labels are a `float32` tensor of
121 1s and 0s of shape `(batch_size, 1)`.
122 - if `label_mode` is `categorical`, the labels are a `float32` tensor
123 of shape `(batch_size, num_classes)`, representing a one-hot
124 encoding of the class index.
125 """
    if labels not in ("inferred", None):
        if not isinstance(labels, (list, tuple)):
            raise ValueError(
                "The `labels` argument should be a list/tuple of integer "
                "labels, of the same size as the number of audio files in "
                "the target directory. If you wish to infer the labels from "
                "the subdirectory names in the target directory,"
                ' pass `labels="inferred"`. '
                "If you wish to get a dataset that only contains audio "
                f"samples (no labels), pass `labels=None`. "
                f"Received: labels={labels}"
            )
        if class_names:
            raise ValueError(
                "You can only pass `class_names` if "
                f'`labels="inferred"`. Received: labels={labels}, and '
                f"class_names={class_names}"
            )
    if label_mode not in {"int", "categorical", "binary", None}:
        raise ValueError(
            '`label_mode` argument must be one of "int", "categorical", '
            '"binary", '
            f"or None. Received: label_mode={label_mode}"
        )

    if ragged and output_sequence_length is not None:
        raise ValueError(
            "Cannot set both `ragged` and `output_sequence_length`"
        )

    if sampling_rate is not None:
        if not isinstance(sampling_rate, int):
            raise ValueError(
                "`sampling_rate` should have an integer value. "
                f"Received: sampling_rate={sampling_rate}"
            )

        if sampling_rate <= 0:
            raise ValueError(
                "`sampling_rate` should be higher than 0. "
                f"Received: sampling_rate={sampling_rate}"
            )

        global tfio
        if tfio is None:
            try:
                import tensorflow_io as tfio
            except ImportError:
                raise ImportError(
                    "To use the argument `sampling_rate`, you should install "
                    "tensorflow_io. You can install it via `pip install "
                    "tensorflow-io`."
                )

    if labels is None or label_mode is None:
        labels = None
        label_mode = None

    dataset_utils.check_validation_split_arg(
        validation_split, subset, shuffle, seed
    )

    if seed is None:
        seed = np.random.randint(1e6)

    file_paths, labels, class_names = dataset_utils.index_directory(
        directory,
        labels,
        formats=ALLOWED_FORMATS,
        class_names=class_names,
        shuffle=shuffle,
        seed=seed,
        follow_links=follow_links,
    )

    if label_mode == "binary" and len(class_names) != 2:
        raise ValueError(
            'When passing `label_mode="binary"`, there must be exactly 2 '
            f"class_names. Received: class_names={class_names}"
        )

    if subset == "both":
        train_dataset, val_dataset = get_training_and_validation_dataset(
            file_paths=file_paths,
            labels=labels,
            validation_split=validation_split,
            directory=directory,
            label_mode=label_mode,
            class_names=class_names,
            sampling_rate=sampling_rate,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )

        train_dataset = prepare_dataset(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            class_names=class_names,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )
        val_dataset = prepare_dataset(
            dataset=val_dataset,
            batch_size=batch_size,
            shuffle=False,
            seed=seed,
            class_names=class_names,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )
        return train_dataset, val_dataset

    else:
        dataset = get_dataset(
            file_paths=file_paths,
            labels=labels,
            directory=directory,
            validation_split=validation_split,
            subset=subset,
            label_mode=label_mode,
            class_names=class_names,
            sampling_rate=sampling_rate,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )

        dataset = prepare_dataset(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            class_names=class_names,
            output_sequence_length=output_sequence_length,
            ragged=ragged,
        )
        return dataset
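
# --- Editor's note: illustrative usage sketch, not part of the original ------
# module. A minimal example of calling the function above, assuming a
# hypothetical directory "speech_commands/" that contains one subdirectory of
# `.wav` files per class:
#
#     train_ds, val_ds = tf.keras.utils.audio_dataset_from_directory(
#         "speech_commands/",
#         labels="inferred",
#         label_mode="int",
#         batch_size=32,
#         output_sequence_length=16000,  # pad/truncate each clip to 16k samples
#         validation_split=0.2,
#         subset="both",
#         seed=1337,
#     )
#     print(train_ds.class_names)        # inferred from the subdirectory names
#     for audio, labels in train_ds.take(1):
#         print(audio.shape)   # (32, 16000, num_channels), float32 in [-1, 1]
#         print(labels.shape)  # (32,), int32 class indices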


def prepare_dataset(
    dataset,
    batch_size,
    shuffle,
    seed,
    class_names,
    output_sequence_length,
    ragged,
):
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    if batch_size is not None:
        if shuffle:
            dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)

        if output_sequence_length is None and not ragged:
            dataset = dataset.padded_batch(
                batch_size, padded_shapes=([None, None], [])
            )
        else:
            dataset = dataset.batch(batch_size)
    else:
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1024, seed=seed)

    # Users may need to reference `class_names`.
    dataset.class_names = class_names
    return dataset
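
# Editor's note (illustration, not in the original module): when batching with
# `output_sequence_length=None` and `ragged=False`, `prepare_dataset` above
# uses `padded_batch`, so clips within a batch are zero-padded up to the
# longest clip in that batch. For example, two mono clips of 14000 and 16000
# samples batched together come out with audio shape (2, 16000, 1).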


def get_training_and_validation_dataset(
    file_paths,
    labels,
    validation_split,
    directory,
    label_mode,
    class_names,
    sampling_rate,
    output_sequence_length,
    ragged,
):
    (
        file_paths_train,
        labels_train,
    ) = dataset_utils.get_training_or_validation_split(
        file_paths, labels, validation_split, "training"
    )
    if not file_paths_train:
        raise ValueError(
            f"No training audio files found in directory {directory}. "
            f"Allowed format(s): {ALLOWED_FORMATS}"
        )

    file_paths_val, labels_val = dataset_utils.get_training_or_validation_split(
        file_paths, labels, validation_split, "validation"
    )
    if not file_paths_val:
        raise ValueError(
            f"No validation audio files found in directory {directory}. "
            f"Allowed format(s): {ALLOWED_FORMATS}"
        )

    train_dataset = paths_and_labels_to_dataset(
        file_paths=file_paths_train,
        labels=labels_train,
        label_mode=label_mode,
        num_classes=len(class_names),
        sampling_rate=sampling_rate,
        output_sequence_length=output_sequence_length,
        ragged=ragged,
    )

    val_dataset = paths_and_labels_to_dataset(
        file_paths=file_paths_val,
        labels=labels_val,
        label_mode=label_mode,
        num_classes=len(class_names),
        sampling_rate=sampling_rate,
        output_sequence_length=output_sequence_length,
        ragged=ragged,
    )

    return train_dataset, val_dataset


def get_dataset(
    file_paths,
    labels,
    directory,
    validation_split,
    subset,
    label_mode,
    class_names,
    sampling_rate,
    output_sequence_length,
    ragged,
):
    file_paths, labels = dataset_utils.get_training_or_validation_split(
        file_paths, labels, validation_split, subset
    )
    if not file_paths:
        raise ValueError(
            f"No audio files found in directory {directory}. "
            f"Allowed format(s): {ALLOWED_FORMATS}"
        )

    dataset = paths_and_labels_to_dataset(
        file_paths=file_paths,
        labels=labels,
        label_mode=label_mode,
        num_classes=len(class_names),
        sampling_rate=sampling_rate,
        output_sequence_length=output_sequence_length,
        ragged=ragged,
    )

    return dataset


def read_and_decode_audio(
    path, sampling_rate=None, output_sequence_length=None
):
    """Reads and decodes audio file."""
    audio = tf.io.read_file(path)

    if output_sequence_length is None:
        output_sequence_length = -1

    audio, default_audio_rate = tf.audio.decode_wav(
        contents=audio, desired_samples=output_sequence_length
    )
    if sampling_rate is not None:
        # default_audio_rate should have dtype=int64
        default_audio_rate = tf.cast(default_audio_rate, tf.int64)
        audio = tfio.audio.resample(
            input=audio, rate_in=default_audio_rate, rate_out=sampling_rate
        )
    return audio
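
# Editor's note (illustration, not in the original module): `tf.audio.decode_wav`
# returns float32 samples scaled to [-1.0, 1.0] with shape (samples, channels),
# plus the file's sample rate as an int32 scalar; `desired_samples=-1` decodes
# the whole clip. The cast above turns that int32 rate into the int64 value
# that `tfio.audio.resample` expects for `rate_in`, per the inline comment.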


def paths_and_labels_to_dataset(
    file_paths,
    labels,
    label_mode,
    num_classes,
    sampling_rate,
    output_sequence_length,
    ragged,
):
    """Constructs a fixed-size dataset of audio and labels."""
    path_ds = tf.data.Dataset.from_tensor_slices(file_paths)
    audio_ds = path_ds.map(
        lambda x: read_and_decode_audio(
            x, sampling_rate, output_sequence_length
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    if ragged:
        audio_ds = audio_ds.map(
            lambda x: tf.RaggedTensor.from_tensor(x),
            num_parallel_calls=tf.data.AUTOTUNE,
        )

    if label_mode:
        label_ds = dataset_utils.labels_to_dataset(
            labels, label_mode, num_classes
        )
        audio_ds = tf.data.Dataset.zip((audio_ds, label_ds))
    return audio_ds
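
# Editor's note (illustration, not in the original module): the `ragged=True`
# branch above converts each dense (samples, channels) clip into a
# `tf.RaggedTensor`, so `prepare_dataset` can batch clips of different lengths
# without padding; each element of a batch keeps its own length.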