{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "yCs7P9JTMlzV" }, "source": [ "##### Copyright 2021 The TensorFlow Hub Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jqn-HYw-Mkea" }, "outputs": [], "source": [ "#@title Copyright 2021 The TensorFlow Hub Authors. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "metadata": { "id": "stRetE8gMlmZ" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View on GitHub\n", " \n", " Download notebook\n", " \n", " See TF Hub model\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "ndG8MjmJeicp" }, "source": [ "# Fine-tuning Wav2Vec2 with an LM head\n", "\n", "In this notebook, we will load the pre-trained wav2vec2 model from [TFHub](https://tfhub.dev) and will fine-tune it on [LibriSpeech dataset](https://huggingface.co/datasets/librispeech_asr) by appending Language Modeling head (LM) over the top of our pre-trained model. The underlying task is to build a model for **Automatic Speech Recognition** i.e. given some speech, the model should be able to transcribe it into text." ] }, { "cell_type": "markdown", "metadata": { "id": "rWk8nL6Ui-_0" }, "source": [ "## Setting Up\n", "\n", "Before running this notebook, please ensure that you are on GPU runtime (`Runtime` > `Change runtime type` > `GPU`). The following cell will install [`gsoc-wav2vec2`](https://github.com/vasudevgupta7/gsoc-wav2vec2) package & its dependencies." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "seqTlMyeZvM4" }, "outputs": [], "source": [ "!pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main\n", "!sudo apt-get install -y libsndfile1-dev\n", "!pip3 install -q SoundFile" ] }, { "cell_type": "markdown", "metadata": { "id": "wvuJL8-f0zn5" }, "source": [ "## Model setup using `TFHub`\n", "\n", "We will start by importing some libraries/modules." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M3_fgx4eZvM7" }, "outputs": [], "source": [ "import os\n", "\n", "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "from wav2vec2 import Wav2Vec2Config\n", "\n", "config = Wav2Vec2Config()\n", "\n", "print(\"TF version:\", tf.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "y0rVUxyWsS5f" }, "source": [ "First, we will download our model from TFHub & will wrap our model signature with [`hub.KerasLayer`](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer) to be able to use this model like any other Keras layer. Fortunately, `hub.KerasLayer` can do both in just 1 line.\n", "\n", "**Note:** When loading model with `hub.KerasLayer`, model becomes a bit opaque but sometimes we need finer controls over the model, then we can load the model with `tf.keras.models.load_model(...)`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NO6QRC7KZvM9" }, "outputs": [], "source": [ "pretrained_layer = hub.KerasLayer(\"https://tfhub.dev/vasudevgupta7/wav2vec2/1\", trainable=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "pCputyVBv2e9" }, "source": [ "You can refer to this [script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/export2hub.py) in case you are interested in the model exporting script. Object `pretrained_layer` is the freezed version of [`Wav2Vec2Model`](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/wav2vec2/modeling.py). These pre-trained weights were converted from HuggingFace PyTorch [pre-trained weights](https://huggingface.co/facebook/wav2vec2-base) using [this script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/convert_torch_to_tf.py).\n", "\n", "Originally, wav2vec2 was pre-trained with a masked language modelling approach with the objective to identify the true quantized latent speech representation for a masked time step. You can read more about the training objective in the paper- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477)." 
] }, { "cell_type": "markdown", "metadata": { "id": "SseDnCr7hyhC" }, "source": [ "Now, we will define a few constants and hyper-parameters which will be useful in the next few cells. `AUDIO_MAXLEN` is intentionally set to `246000` as the model signature only accepts a static sequence length of `246000`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eiILuMBERxlO" }, "outputs": [], "source": [ "AUDIO_MAXLEN = 246000\n", "LABEL_MAXLEN = 256\n", "BATCH_SIZE = 2" ] }, { "cell_type": "markdown", "metadata": { "id": "1V4gTgGLgXvO" }, "source": [ "In the following cell, we will wrap `pretrained_layer` & a dense layer (LM head) using the [Keras Functional API](https://www.tensorflow.org/guide/keras/functional)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "a3CUN1KEB10Q" }, "outputs": [], "source": [ "inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))\n", "hidden_states = pretrained_layer(inputs)\n", "outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)\n", "\n", "model = tf.keras.Model(inputs=inputs, outputs=outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "5zDXuoMXhDMo" }, "source": [ "The dense layer (defined above) has an output dimension of `vocab_size`, as we want to predict the probability of each token in the vocabulary at each time step." ] }, { "cell_type": "markdown", "metadata": { "id": "oPp18ZHRtnq-" }, "source": [ "## Setting up training state" ] }, { "cell_type": "markdown", "metadata": { "id": "ATQy1ZK3vFr7" }, "source": [ "In TensorFlow, model weights are built only when `model.call` or `model.build` is called for the first time, so the following cell will build the model weights for us. Further, we will run `model.summary()` to check the total number of trainable parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZgL5wyaXZvM-" }, "outputs": [], "source": [ "model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "EQxxA4Fevp7m" }, "source": [ "Now, we need to define the `loss_fn` and optimizer to be able to train the model. The following cell will do that for us. We will be using the `Adam` optimizer for simplicity. `CTCLoss` is a common loss type used for tasks (like `ASR`) where input sub-parts can't be easily aligned with output sub-parts. You can read more about CTC loss in this amazing [blog post](https://distill.pub/2017/ctc/).\n", "\n", "`CTCLoss` (from the [`gsoc-wav2vec2`](https://github.com/vasudevgupta7/gsoc-wav2vec2) package) accepts 3 arguments: `config`, `model_input_shape` & `division_factor`. If `division_factor=1`, the loss will simply be summed, so pass `division_factor` accordingly to get the mean over the batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "glDepVEHZvM_" }, "outputs": [], "source": [ "from wav2vec2 import CTCLoss\n", "\n", "LEARNING_RATE = 5e-5\n", "\n", "loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)\n", "optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)" ] }, { "cell_type": "markdown", "metadata": { "id": "1mvTuOXpwsQe" }, "source": [ "## Loading & Pre-processing data\n", "\n", "Let's now download the LibriSpeech dataset from the [official website](http://www.openslr.org/12) and set it up (see the Python alternative below if shell commands are not available in your environment)."
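] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If `wget`/`tar` are not available in your environment, the same download & extraction can be done from Python. This is just an optional sketch that mirrors the paths used in the next cell:\n", "\n", "```python\n", "import os\n", "import tarfile\n", "import urllib.request\n", "\n", "os.makedirs(\"./data/train/\", exist_ok=True)\n", "archive_path = \"./data/train/dev-clean.tar.gz\"\n", "if not os.path.exists(archive_path):\n", "    urllib.request.urlretrieve(\"https://www.openslr.org/resources/12/dev-clean.tar.gz\", archive_path)\n", "with tarfile.open(archive_path, \"r:gz\") as tar:\n", "    tar.extractall(\"./data/train/\")\n", "```"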
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I4kIEC77cBCM" }, "outputs": [], "source": [ "!wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/\n", "!tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/" ] }, { "cell_type": "markdown", "metadata": { "id": "LsQpmpn6jrMI" }, "source": [ "**Note:** We are using the `dev-clean` configuration as this notebook is just for demonstration purposes, so we only need a small amount of data. The complete training data can be easily downloaded from the [LibriSpeech website](http://www.openslr.org/12)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ynxAjtGHGFpM" }, "outputs": [], "source": [ "ls ./data/train/" ] }, { "cell_type": "markdown", "metadata": { "id": "yBMiORo0xJD0" }, "source": [ "Our dataset now lies in the `LibriSpeech` directory. Let's explore these files." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jkIu_Wt4ZvNA" }, "outputs": [], "source": [ "data_dir = \"./data/train/LibriSpeech/dev-clean/2428/83705/\"\n", "all_files = os.listdir(data_dir)\n", "\n", "flac_files = [f for f in all_files if f.endswith(\".flac\")]\n", "txt_files = [f for f in all_files if f.endswith(\".txt\")]\n", "\n", "print(\"Transcription files:\", txt_files, \"\\nSound files:\", flac_files)" ] }, { "cell_type": "markdown", "metadata": { "id": "XEObi_Apk3ZD" }, "source": [ "Alright, so each sub-directory has many `.flac` files and a `.txt` file. The `.txt` file contains text transcriptions for all the speech samples (i.e. `.flac` files) present in that sub-directory." ] }, { "cell_type": "markdown", "metadata": { "id": "WYW6WKJflO2e" }, "source": [ "We can load this text data as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cEBKxQblHPwq" }, "outputs": [], "source": [ "def read_txt_file(f):\n", "    with open(f, \"r\") as f:\n", "        samples = f.read().split(\"\\n\")\n", "        samples = {s.split()[0]: \" \".join(s.split()[1:]) for s in samples if len(s.split()) > 2}\n", "        return samples" ] }, { "cell_type": "markdown", "metadata": { "id": "Ldkf_ceb0_YW" }, "source": [ "Similarly, we will define a function for loading a speech sample from a `.flac` file.\n", "\n", "`REQUIRED_SAMPLE_RATE` is set to `16000` as wav2vec2 was pre-trained on 16 kHz audio, and it's recommended to fine-tune it without any major shift in the data distribution caused by a different sampling frequency." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YOJ3OzPsTyXv" }, "outputs": [], "source": [ "import soundfile as sf\n", "\n", "REQUIRED_SAMPLE_RATE = 16000\n", "\n", "def read_flac_file(file_path):\n", "    with open(file_path, \"rb\") as f:\n", "        audio, sample_rate = sf.read(f)\n", "    if sample_rate != REQUIRED_SAMPLE_RATE:\n", "        raise ValueError(\n", "            f\"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}\"\n", "        )\n", "    file_id = os.path.split(file_path)[-1][:-len(\".flac\")]\n", "    return {file_id: audio}" ] }, { "cell_type": "markdown", "metadata": { "id": "2sxDN8P4nWkW" }, "source": [ "Now, we will pick some random samples & try to visualize them."
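] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before that, here is a quick sanity check of the two helpers on a single pair of files (optional; the printed ids and counts depend on whatever the download produced):\n", "\n", "```python\n", "sample_transcriptions = read_txt_file(os.path.join(data_dir, txt_files[0]))\n", "sample_audio = read_flac_file(os.path.join(data_dir, flac_files[0]))\n", "\n", "print(\"number of transcriptions:\", len(sample_transcriptions))\n", "print(\"audio id:\", list(sample_audio.keys()), \"num samples:\", len(list(sample_audio.values())[0]))\n", "```"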
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HI5J-2Dfm_wT" }, "outputs": [], "source": [ "from IPython.display import Audio\n", "import random\n", "\n", "file_id = random.choice([f[:-len(\".flac\")] for f in flac_files])\n", "flac_file_path, txt_file_path = os.path.join(data_dir, f\"{file_id}.flac\"), os.path.join(data_dir, \"2428-83705.trans.txt\")\n", "\n", "print(\"Text Transcription:\", read_txt_file(txt_file_path)[file_id], \"\\nAudio:\")\n", "Audio(filename=flac_file_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "M8jJ7Ed81p_A" }, "source": [ "Now, we will combine all the speech & text samples, using the function defined in the next cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MI-5YCzaTsei" }, "outputs": [], "source": [ "def fetch_sound_text_mapping(data_dir):\n", "    all_files = os.listdir(data_dir)\n", "\n", "    flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(\".flac\")]\n", "    txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(\".txt\")]\n", "\n", "    txt_samples = {}\n", "    for f in txt_files:\n", "        txt_samples.update(read_txt_file(f))\n", "\n", "    speech_samples = {}\n", "    for f in flac_files:\n", "        speech_samples.update(read_flac_file(f))\n", "\n", "    assert len(txt_samples) == len(speech_samples)\n", "\n", "    samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]\n", "    return samples" ] }, { "cell_type": "markdown", "metadata": { "id": "mx95Lxvu0nT4" }, "source": [ "It's time to have a look at a few samples ..." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Ls7X_jqIz4R" }, "outputs": [], "source": [ "samples = fetch_sound_text_mapping(data_dir)\n", "samples[:5]" ] }, { "cell_type": "markdown", "metadata": { "id": "TUjhSWfsnlCL" }, "source": [ "Note: We are loading this data into memory as we are working with only a small amount of data in this notebook. But for training on the complete dataset (~300 GB), you will have to load the data lazily. You can refer to [this script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/data_utils.py) to learn more about that." ] }, { "cell_type": "markdown", "metadata": { "id": "xg8Zia1kzw0J" }, "source": [ "Let's pre-process the data now!\n", "\n", "We will first define the tokenizer & processor using the `gsoc-wav2vec2` package. Then, we will do very simple pre-processing: `processor` will normalize the raw speech along the frames axis, and `tokenizer` will convert our model outputs into a string (using the defined vocabulary) & take care of removing special tokens (depending on your tokenizer configuration)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gaat_hMLNVHF" }, "outputs": [], "source": [ "from wav2vec2 import Wav2Vec2Processor\n", "tokenizer = Wav2Vec2Processor(is_tokenizer=True)\n", "processor = Wav2Vec2Processor(is_tokenizer=False)\n", "\n", "def preprocess_text(text):\n", "    label = tokenizer(text)\n", "    return tf.constant(label, dtype=tf.int32)\n", "\n", "def preprocess_speech(audio):\n", "    audio = tf.constant(audio, dtype=tf.float32)\n", "    return processor(tf.transpose(audio))" ] }, { "cell_type": "markdown", "metadata": { "id": "GyKl8QP-zRFC" }, "source": [ "Now, we will define a Python generator that calls the preprocessing functions we defined in the cells above."
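] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A small round-trip shows what the tokenizer does with a plain string (illustration only; the exact token ids depend on the vocabulary shipped with the `gsoc-wav2vec2` package):\n", "\n", "```python\n", "ids = tokenizer(\"HELLO WORLD\")\n", "print(\"token ids:\", ids)\n", "print(\"decoded back:\", tokenizer.decode(ids, group_tokens=False))\n", "```"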
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PoQrRalwMpQ6" }, "outputs": [], "source": [ "def inputs_generator():\n", "    for speech, text in samples:\n", "        yield preprocess_speech(speech), preprocess_text(text)" ] }, { "cell_type": "markdown", "metadata": { "id": "7Vlm3ySFULsG" }, "source": [ "## Setting up `tf.data.Dataset`\n", "\n", "The following cell will set up the `tf.data.Dataset` object using its `.from_generator(...)` method, with the generator we defined in the cell above.\n", "\n", "**Note:** For distributed training (especially on TPUs), `.from_generator(...)` doesn't currently work and it is recommended to train on data stored in the `.tfrecord` format (ideally, the TFRecords should be stored inside a GCS bucket so that the TPUs can be used to the fullest extent).\n", "\n", "You can refer to [this script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/make_tfrecords.py) for more details on how to convert LibriSpeech data into TFRecords." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LbQ_dMwGO62h" }, "outputs": [], "source": [ "output_signature = (\n", "    tf.TensorSpec(shape=(None), dtype=tf.float32),\n", "    tf.TensorSpec(shape=(None), dtype=tf.int32),\n", ")\n", "\n", "dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HXBbNsRyPyw3" }, "outputs": [], "source": [ "BUFFER_SIZE = len(flac_files)\n", "SEED = 42\n", "\n", "dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)" ] }, { "cell_type": "markdown", "metadata": { "id": "9DAUmns3pXfr" }, "source": [ "We will split the dataset into multiple batches, so let's prepare the batches in the following cell. All the sequences in a batch should be padded to a constant length. We will use the `.padded_batch(...)` method for that purpose." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Okhko1IWRida" }, "outputs": [], "source": [ "dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))" ] }, { "cell_type": "markdown", "metadata": { "id": "A45CjQG5qSbV" }, "source": [ "Accelerators (like GPUs/TPUs) are very fast, and data-loading (& pre-processing) often becomes the bottleneck during training, as the data-loading part happens on CPUs. This can increase the training time significantly, especially when there is a lot of online pre-processing involved or data is streamed online from GCS buckets. To handle those issues, `tf.data.Dataset` offers the `.prefetch(...)` method. It helps in preparing the next few batches in parallel (on CPUs) while the model is making predictions (on GPUs/TPUs) on the current batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f-bKu2YjRior" }, "outputs": [], "source": [ "dataset = dataset.prefetch(tf.data.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "Lqk2cs6LxVIh" }, "source": [ "Since this notebook is made for demonstration purposes, we will take only the first `num_train_batches` and perform training over them. You are encouraged to train on the whole dataset though. Similarly, we will evaluate only `num_val_batches`."
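] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you want to confirm what the pipeline produces before training, a quick check like this works (optional; it pulls one batch through the generator, so it takes a few seconds):\n", "\n", "```python\n", "print(dataset.element_spec)\n", "\n", "speech_batch, label_batch = next(iter(dataset))\n", "print(\"speech:\", speech_batch.shape, \"labels:\", label_batch.shape)\n", "# expected: speech (BATCH_SIZE, AUDIO_MAXLEN) and labels (BATCH_SIZE, LABEL_MAXLEN)\n", "```"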
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z6GO5oYUxXtz" }, "outputs": [], "source": [ "num_train_batches = 10\n", "num_val_batches = 4\n", "\n", "train_dataset = dataset.take(num_train_batches)\n", "val_dataset = dataset.skip(num_train_batches).take(num_val_batches)" ] }, { "cell_type": "markdown", "metadata": { "id": "CzAOI78tky08" }, "source": [ "## Model training\n", "\n", "For training our model, we will directly call the `.fit(...)` method after compiling our model with `.compile(...)`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vuBY2sZElgwg" }, "outputs": [], "source": [ "model.compile(optimizer, loss=loss_fn)" ] }, { "cell_type": "markdown", "metadata": { "id": "qswxafSl0HjO" }, "source": [ "The above cell will set up our training state. Now we can initiate training with the `.fit(...)` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vtuSfnj1l-I_" }, "outputs": [], "source": [ "history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)\n", "history.history" ] }, { "cell_type": "markdown", "metadata": { "id": "ySvp8r2E1q_V" }, "source": [ "Let's save our model with the `.save(...)` method to be able to perform inference later. You can also export this SavedModel to TFHub by following the [TFHub documentation](https://www.tensorflow.org/hub/publish)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C0KEYcwydwjF" }, "outputs": [], "source": [ "save_dir = \"finetuned-wav2vec2\"\n", "model.save(save_dir, include_optimizer=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "MkOpp9rZ211t" }, "source": [ "Note: We are setting `include_optimizer=False` as we want to use this model for inference only." ] }, { "cell_type": "markdown", "metadata": { "id": "SJfPlTgezD0i" }, "source": [ "## Evaluation\n", "\n", "Now we will compute the Word Error Rate over the validation dataset.\n", "\n", "**Word error rate** (WER) is a common metric for measuring the performance of an automatic speech recognition system. The WER is derived from the Levenshtein distance, working at the word level. It can be computed as:\n", "\n", "WER = (S + D + I) / N = (S + D + I) / (S + D + C)\n", "\n", "where S is the number of substitutions, D is the number of deletions, I is the number of insertions, C is the number of correct words, and N is the number of words in the reference (N = S + D + C). This value indicates the proportion of words that were incorrectly predicted.\n", "\n", "You can refer to [this paper](https://www.isca-speech.org/archive_v0/interspeech_2004/i04_2765.html) to learn more about WER." ] }, { "cell_type": "markdown", "metadata": { "id": "Io_91Y7-r3xu" }, "source": [ "We will use the `load_metric(...)` function from the [HuggingFace datasets](https://huggingface.co/docs/datasets/) library. Let's first install the `datasets` library using `pip` and then define the `metric` object."
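] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A tiny worked example of the WER formula above (illustration only; the actual evaluation below uses the `wer` metric from HuggingFace `datasets`):\n", "\n", "```python\n", "reference = \"the cat sat on the mat\".split()   # N = 6 words\n", "prediction = \"the cat sit on mat\".split()      # 1 substitution (sat -> sit), 1 deletion (the)\n", "\n", "S, D, I = 1, 1, 0\n", "wer = (S + D + I) / len(reference)\n", "print(wer)  # 0.333...\n", "```"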
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GW9F_oVDU1TZ" }, "outputs": [], "source": [ "!pip3 install -q datasets\n", "\n", "from datasets import load_metric\n", "metric = load_metric(\"wer\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ssWXWc7CZvNB" }, "outputs": [], "source": [ "@tf.function(jit_compile=True)\n", "def eval_fwd(batch):\n", "    logits = model(batch, training=False)\n", "    return tf.argmax(logits, axis=-1)" ] }, { "cell_type": "markdown", "metadata": { "id": "NFh1myg1x4ua" }, "source": [ "It's time to run the evaluation on the validation data now." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EQTFVjZghckJ" }, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "\n", "for speech, labels in tqdm(val_dataset, total=num_val_batches):\n", "    predictions = eval_fwd(speech)\n", "    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]\n", "    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]\n", "    metric.add_batch(references=references, predictions=predictions)" ] }, { "cell_type": "markdown", "metadata": { "id": "WWCc8qBesv3e" }, "source": [ "We are using the `tokenizer.decode(...)` method to decode our predictions and labels back into text, and adding them to the metric for the `WER` computation later." ] }, { "cell_type": "markdown", "metadata": { "id": "XI_URj8Wtb2g" }, "source": [ "Now, let's calculate the metric value in the following cell:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "a83wekLgWMod" }, "outputs": [], "source": [ "metric.compute()" ] }, { "cell_type": "markdown", "metadata": { "id": "c_cD1OgVEjl4" }, "source": [ "**Note:** Here the metric value doesn't mean much, as the model was trained on very little data, and ASR-like tasks often require a large amount of data to learn a mapping from speech to text. You should probably train on more data to get good results. This notebook gives you a template for fine-tuning a pre-trained speech model." ] }, { "cell_type": "markdown", "metadata": { "id": "G14o706kdTE1" }, "source": [ "## Inference\n", "\n", "Now that we are satisfied with the training process & have saved the model in `save_dir`, we will see how this model can be used for inference.\n", "\n", "First, we will load our model using `tf.keras.models.load_model(...)`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wrTrExiUdaED" }, "outputs": [], "source": [ "finetuned_model = tf.keras.models.load_model(save_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "luodSroz20SR" }, "source": [ "Let's download a speech sample for performing inference. You can also replace the following sample with your own speech sample." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HUE0shded6Ej" }, "outputs": [], "source": [ "!wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav" ] }, { "cell_type": "markdown", "metadata": { "id": "ycBjU_U53FjL" }, "source": [ "Now, we will read the speech sample using `soundfile.read(...)` and pad it to `AUDIO_MAXLEN` to satisfy the model signature. Then we will normalize that speech sample using the `Wav2Vec2Processor` instance & feed it into the model."
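] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you substitute your own recording for `SA2.wav`, it must be sampled at 16 kHz and should not be longer than `AUDIO_MAXLEN` samples, since the next cell pads (but does not truncate) the audio. A minimal guard (sketch; the file name here is just an example):\n", "\n", "```python\n", "audio, sr = sf.read(\"my_recording.wav\")\n", "assert sr == 16000, \"resample your audio to 16 kHz before feeding it to the model\"\n", "audio = audio[:AUDIO_MAXLEN]  # truncate if it is too long\n", "```"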
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z7CARje4d5_H" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "speech, _ = sf.read(\"SA2.wav\")\n", "speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))\n", "speech = tf.expand_dims(processor(tf.constant(speech)), 0)\n", "\n", "outputs = finetuned_model(speech)\n", "outputs" ] }, { "cell_type": "markdown", "metadata": { "id": "lUSttSPa30qP" }, "source": [ "Let's decode the numbers back into a text sequence using the tokenizer (the `Wav2Vec2Processor` instance with `is_tokenizer=True`) we defined above." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RYdJqxQ4llgI" }, "outputs": [], "source": [ "predictions = tf.argmax(outputs, axis=-1)\n", "predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]\n", "predictions" ] }, { "cell_type": "markdown", "metadata": { "id": "7DXC757bztJc" }, "source": [ "This prediction is quite random, as the model was never trained on much data in this notebook (this notebook is not meant for complete training). You will get good predictions if you train this model on the complete LibriSpeech dataset.\n", "\n", "Finally, we have reached the end of this notebook. But it's not the end of learning TensorFlow for speech-related tasks; this [repository](https://github.com/tulasiram58827/TTS_TFLite) contains some more amazing tutorials. In case you encounter any bug in this notebook, please create an issue [here](https://github.com/vasudevgupta7/gsoc-wav2vec2/issues)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "rWk8nL6Ui-_0", "wvuJL8-f0zn5", "oPp18ZHRtnq-", "1mvTuOXpwsQe", "7Vlm3ySFULsG", "CzAOI78tky08", "SJfPlTgezD0i", "G14o706kdTE1" ], "name": "wav2vec2_saved_model_finetuning.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }