{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "0DH9bjZD_Cfi" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "JO1GUwC1_T2x" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "M4xOsFiu-1-c" }, "source": [ "# Generate music with an RNN" ] }, { "cell_type": "markdown", "metadata": { "id": "OyzAxV7Vu_9Y" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "hr78EkAY-FFg" }, "source": [ "This tutorial shows you how to generate musical notes using a simple recurrent neural network (RNN). You will train a model using a collection of piano MIDI files from the [MAESTRO dataset](https://magenta.tensorflow.org/datasets/maestro). Given a sequence of notes, your model will learn to predict the next note in the sequence. You can generate longer sequences of notes by calling the model repeatedly.\n", "\n", "This tutorial contains complete code to parse and create MIDI files. You can learn more about how RNNs work by visiting the [Text generation with an RNN](https://www.tensorflow.org/text/tutorials/text_generation) tutorial." ] }, { "cell_type": "markdown", "metadata": { "id": "4ZniYb7Y_0Ey" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "3ks8__E_WUGt" }, "source": [ "This tutorial uses the [`pretty_midi`](https://github.com/craffel/pretty-midi) library to create and parse MIDI files, and [`pyfluidsynth`](https://github.com/nwhitehead/pyfluidsynth) for generating audio playback in Colab." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kahm6Z8v_TqC" }, "outputs": [], "source": [ "!sudo apt install -y fluidsynth" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M0lAReB7_Vqb" }, "outputs": [], "source": [ "!pip install --upgrade pyfluidsynth" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G46kKoQZmIa8" }, "outputs": [], "source": [ "!pip install pretty_midi" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GsLFq7nsiqcq" }, "outputs": [], "source": [ "import collections\n", "import datetime\n", "import fluidsynth\n", "import glob\n", "import numpy as np\n", "import pathlib\n", "import pandas as pd\n", "import pretty_midi\n", "import seaborn as sns\n", "import tensorflow as tf\n", "\n", "from IPython import display\n", "from matplotlib import pyplot as plt\n", "from typing import Optional" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Efja_OtJNzAM" }, "outputs": [], "source": [ "seed = 42\n", "tf.random.set_seed(seed)\n", "np.random.seed(seed)\n", "\n", "# Sampling rate for audio playback\n", "_SAMPLING_RATE = 16000" ] }, { "cell_type": "markdown", "metadata": { "id": "FzIbfb-Ikgg7" }, "source": [ "## Download the Maestro dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mwja4SWmibrL" }, "outputs": [], "source": [ "data_dir = pathlib.Path('data/maestro-v2.0.0')\n", "if not data_dir.exists():\n", " tf.keras.utils.get_file(\n", " 'maestro-v2.0.0-midi.zip',\n", " origin='https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip',\n", " extract=True,\n", " cache_dir='.', cache_subdir='data',\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "k7UYBSxcINqJ" }, "source": [ "The dataset contains about 1,200 MIDI files." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "72iFI1bPB9o1" }, "outputs": [], "source": [ "filenames = glob.glob(str(data_dir/'**/*.mid*'))\n", "print('Number of files:', len(filenames))" ] }, { "cell_type": "markdown", "metadata": { "id": "8BlRafYDIRgA" }, "source": [ "## Process a MIDI file" ] }, { "cell_type": "markdown", "metadata": { "id": "oFsmG87gXSbh" }, "source": [ "First, use ```pretty_midi``` to parse a single MIDI file and inspect the format of the notes. 
If you would like to download the MIDI file below to play on your computer, you can do so in Colab by writing ```files.download(sample_file)```.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6oSCbHvJNbci" }, "outputs": [], "source": [ "sample_file = filenames[1]\n", "print(sample_file)" ] }, { "cell_type": "markdown", "metadata": { "id": "A48VdGEpXnLp" }, "source": [ "Generate a `PrettyMIDI` object for the sample MIDI file." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1YSQ5DjRI2md" }, "outputs": [], "source": [ "pm = pretty_midi.PrettyMIDI(sample_file)" ] }, { "cell_type": "markdown", "metadata": { "id": "FZNVsZuA_lef" }, "source": [ "Play the sample file. The playback widget may take several seconds to load." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vzoHAaVY_kyY" }, "outputs": [], "source": [ "def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30):\n", "  waveform = pm.fluidsynth(fs=_SAMPLING_RATE)\n", "  # Take a sample of the generated waveform to mitigate kernel resets\n", "  waveform_short = waveform[:seconds*_SAMPLING_RATE]\n", "  return display.Audio(waveform_short, rate=_SAMPLING_RATE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GOe-3AAi_sRw" }, "outputs": [], "source": [ "display_audio(pm)" ] }, { "cell_type": "markdown", "metadata": { "id": "7Lqe7nOsIyh1" }, "source": [ "Inspect the MIDI file. What kinds of instruments are used?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SIGHYQPZQnRo" }, "outputs": [], "source": [ "print('Number of instruments:', len(pm.instruments))\n", "instrument = pm.instruments[0]\n", "instrument_name = pretty_midi.program_to_instrument_name(instrument.program)\n", "print('Instrument name:', instrument_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "KVQfV2hVKB28" }, "source": [ "## Extract notes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nYZm_VehYOTZ" }, "outputs": [], "source": [ "for i, note in enumerate(instrument.notes[:10]):\n", "  note_name = pretty_midi.note_number_to_name(note.pitch)\n", "  duration = note.end - note.start\n", "  print(f'{i}: pitch={note.pitch}, note_name={note_name},'\n", "        f' duration={duration:.4f}')" ] }, { "cell_type": "markdown", "metadata": { "id": "jutzynyqX_GC" }, "source": [ "You will use three variables to represent a note when training the model: `pitch`, `step` and `duration`. The `pitch` is the perceptual quality of the sound, encoded as a MIDI note number.\n", "The `step` is the time elapsed from the previous note or the start of the track.\n", "The `duration` is how long the note plays, in seconds: the difference between the note's end and start times.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "KGn7Juv_PTi6" }, "source": [ "Extract the notes from the sample MIDI file."
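, "\n", "\n", "To make the representation concrete, here is a minimal sketch of how `step` and `duration` fall out of the note start and end times (illustrative only; the function in the next cell also sorts the notes by start time first):\n", "\n", "```\n", "# Illustrative sketch: `step` and `duration` for the second note\n", "first, second = instrument.notes[0], instrument.notes[1]\n", "step = second.start - first.start  # time since the previous note started\n", "duration = second.end - second.start  # how long the note sounds\n", "print(f'step={step:.4f}, duration={duration:.4f}')\n", "```"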
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wyp_wdcEPWby" }, "outputs": [], "source": [ "def midi_to_notes(midi_file: str) -> pd.DataFrame:\n", " pm = pretty_midi.PrettyMIDI(midi_file)\n", " instrument = pm.instruments[0]\n", " notes = collections.defaultdict(list)\n", "\n", " # Sort the notes by start time\n", " sorted_notes = sorted(instrument.notes, key=lambda note: note.start)\n", " prev_start = sorted_notes[0].start\n", "\n", " for note in sorted_notes:\n", " start = note.start\n", " end = note.end\n", " notes['pitch'].append(note.pitch)\n", " notes['start'].append(start)\n", " notes['end'].append(end)\n", " notes['step'].append(start - prev_start)\n", " notes['duration'].append(end - start)\n", " prev_start = start\n", "\n", " return pd.DataFrame({name: np.array(value) for name, value in notes.items()})" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X0kPjLBlcnY6" }, "outputs": [], "source": [ "raw_notes = midi_to_notes(sample_file)\n", "raw_notes.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "-71LPvjubOSO" }, "source": [ "It may be easier to interpret the note names rather than the pitches, so you can use the function below to convert from the numeric pitch values to note names. \n", "The note name shows the type of note, accidental and octave number\n", "(e.g. C#4). " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WE9YXrGZbY2X" }, "outputs": [], "source": [ "get_note_names = np.vectorize(pretty_midi.note_number_to_name)\n", "sample_note_names = get_note_names(raw_notes['pitch'])\n", "sample_note_names[:10]" ] }, { "cell_type": "markdown", "metadata": { "id": "Q7sjqbp1e_f-" }, "source": [ "To visualize the musical piece, plot the note pitch, start and end across the length of the track (i.e. piano roll). Start with the first 100 notes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "liD2N7x_WOTp" }, "outputs": [], "source": [ "def plot_piano_roll(notes: pd.DataFrame, count: Optional[int] = None):\n", " if count:\n", " title = f'First {count} notes'\n", " else:\n", " title = f'Whole track'\n", " count = len(notes['pitch'])\n", " plt.figure(figsize=(20, 4))\n", " plot_pitch = np.stack([notes['pitch'], notes['pitch']], axis=0)\n", " plot_start_stop = np.stack([notes['start'], notes['end']], axis=0)\n", " plt.plot(\n", " plot_start_stop[:, :count], plot_pitch[:, :count], color=\"b\", marker=\".\")\n", " plt.xlabel('Time [s]')\n", " plt.ylabel('Pitch')\n", " _ = plt.title(title)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vWeUbqmAXjOs" }, "outputs": [], "source": [ "plot_piano_roll(raw_notes, count=100)" ] }, { "cell_type": "markdown", "metadata": { "id": "gcUyCXYhXeVA" }, "source": [ "Plot the notes for the entire track." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G7l76hEDZX8Z" }, "outputs": [], "source": [ "plot_piano_roll(raw_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "5GM1bi3aX8rd" }, "source": [ "Check the distribution of each note variable." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pq9C9XBBaK7W" }, "outputs": [], "source": [ "def plot_distributions(notes: pd.DataFrame, drop_percentile=2.5):\n", " plt.figure(figsize=[15, 5])\n", " plt.subplot(1, 3, 1)\n", " sns.histplot(notes, x=\"pitch\", bins=20)\n", "\n", " plt.subplot(1, 3, 2)\n", " max_step = np.percentile(notes['step'], 100 - drop_percentile)\n", " sns.histplot(notes, x=\"step\", bins=np.linspace(0, max_step, 21))\n", " \n", " plt.subplot(1, 3, 3)\n", " max_duration = np.percentile(notes['duration'], 100 - drop_percentile)\n", " sns.histplot(notes, x=\"duration\", bins=np.linspace(0, max_duration, 21))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-Nu2Pw24acFD" }, "outputs": [], "source": [ "plot_distributions(raw_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "poIivompcfS4" }, "source": [ "## Create a MIDI file\n", "\n", "You can generate your own MIDI file from a list of notes using the function below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BD5rsMRARYoV" }, "outputs": [], "source": [ "def notes_to_midi(\n", " notes: pd.DataFrame,\n", " out_file: str, \n", " instrument_name: str,\n", " velocity: int = 100, # note loudness\n", ") -> pretty_midi.PrettyMIDI:\n", "\n", " pm = pretty_midi.PrettyMIDI()\n", " instrument = pretty_midi.Instrument(\n", " program=pretty_midi.instrument_name_to_program(\n", " instrument_name))\n", "\n", " prev_start = 0\n", " for i, note in notes.iterrows():\n", " start = float(prev_start + note['step'])\n", " end = float(start + note['duration'])\n", " note = pretty_midi.Note(\n", " velocity=velocity,\n", " pitch=int(note['pitch']),\n", " start=start,\n", " end=end,\n", " )\n", " instrument.notes.append(note)\n", " prev_start = start\n", "\n", " pm.instruments.append(instrument)\n", " pm.write(out_file)\n", " return pm" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTazLbuWPIPF" }, "outputs": [], "source": [ "example_file = 'example.midi'\n", "example_pm = notes_to_midi(\n", " raw_notes, out_file=example_file, instrument_name=instrument_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "XG0N9zZV_4Gp" }, "source": [ "Play the generated MIDI file and see if there is any difference." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fGRLs-eR_4uK" }, "outputs": [], "source": [ "display_audio(example_pm)" ] }, { "cell_type": "markdown", "metadata": { "id": "CLrUscjhBzYc" }, "source": [ "As before, you can write ```files.download(example_file)``` to download and play this file." ] }, { "cell_type": "markdown", "metadata": { "id": "pfRNk9tEScuf" }, "source": [ "## Create the training dataset\n" ] }, { "cell_type": "markdown", "metadata": { "id": "b77zHR1udDrK" }, "source": [ "Create the training dataset by extracting notes from the MIDI files. You can start by using a small number of files, and experiment later with more. This may take a couple minutes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GiaQiTnXSW-T" }, "outputs": [], "source": [ "num_files = 5\n", "all_notes = []\n", "for f in filenames[:num_files]:\n", " notes = midi_to_notes(f)\n", " all_notes.append(notes)\n", "\n", "all_notes = pd.concat(all_notes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F4bMDeRvgWqx" }, "outputs": [], "source": [ "n_notes = len(all_notes)\n", "print('Number of notes parsed:', n_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "xIBLvj-cODWS" }, "source": [ "Next, create a `tf.data.Dataset` from the parsed notes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mvNHCHZdXG2P" }, "outputs": [], "source": [ "key_order = ['pitch', 'step', 'duration']\n", "train_notes = np.stack([all_notes[key] for key in key_order], axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PLC_19tshyFk" }, "outputs": [], "source": [ "notes_ds = tf.data.Dataset.from_tensor_slices(train_notes)\n", "notes_ds.element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "Sj9SXRCjt3I7" }, "source": [ "You will train the model on batches of sequences of notes. Each example will consist of a sequence of notes as the input features, and the next note as the label. In this way, the model will be trained to predict the next note in a sequence. You can find a diagram describing this process (and more details) in [Text classification with an RNN](https://www.tensorflow.org/text/tutorials/text_generation).\n", "\n", "You can use the handy [window](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#window) function with size `seq_length` to create the features and labels in this format." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZkEC-5s6wJJV" }, "outputs": [], "source": [ "def create_sequences(\n", " dataset: tf.data.Dataset, \n", " seq_length: int,\n", " vocab_size = 128,\n", ") -> tf.data.Dataset:\n", " \"\"\"Returns TF Dataset of sequence and label examples.\"\"\"\n", " seq_length = seq_length+1\n", "\n", " # Take 1 extra for the labels\n", " windows = dataset.window(seq_length, shift=1, stride=1,\n", " drop_remainder=True)\n", "\n", " # `flat_map` flattens the\" dataset of datasets\" into a dataset of tensors\n", " flatten = lambda x: x.batch(seq_length, drop_remainder=True)\n", " sequences = windows.flat_map(flatten)\n", " \n", " # Normalize note pitch\n", " def scale_pitch(x):\n", " x = x/[vocab_size,1.0,1.0]\n", " return x\n", "\n", " # Split the labels\n", " def split_labels(sequences):\n", " inputs = sequences[:-1]\n", " labels_dense = sequences[-1]\n", " labels = {key:labels_dense[i] for i,key in enumerate(key_order)}\n", "\n", " return scale_pitch(inputs), labels\n", "\n", " return sequences.map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "2xDX5pVkegrv" }, "source": [ "Set the sequence length for each example. Experiment with different lengths (e.g. 50, 100, 150) to see which one works best for the data, or use [hyperparameter tuning](https://www.tensorflow.org/tutorials/keras/keras_tuner). The size of the vocabulary (`vocab_size`) is set to 128 representing all the pitches supported by `pretty_midi`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fGA3VxcFXZ4T" }, "outputs": [], "source": [ "seq_length = 25\n", "vocab_size = 128\n", "seq_ds = create_sequences(notes_ds, seq_length, vocab_size)\n", "seq_ds.element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "AX9nKmSYetGo" }, "source": [ "The shape of the dataset is ```(100,1)```, meaning that the model will take 100 notes as input, and learn to predict the following note as output." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ESK9cL7__TF3" }, "outputs": [], "source": [ "for seq, target in seq_ds.take(1):\n", " print('sequence shape:', seq.shape)\n", " print('sequence elements (first 10):', seq[0: 10])\n", " print()\n", " print('target:', target)" ] }, { "cell_type": "markdown", "metadata": { "id": "kR3TVZZGk5Qq" }, "source": [ "Batch the examples, and configure the dataset for performance." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fTpFoiM_AV_Y" }, "outputs": [], "source": [ "batch_size = 64\n", "buffer_size = n_notes - seq_length # the number of items in the dataset\n", "train_ds = (seq_ds\n", " .shuffle(buffer_size)\n", " .batch(batch_size, drop_remainder=True)\n", " .cache()\n", " .prefetch(tf.data.experimental.AUTOTUNE))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LySbjV0GzXQu" }, "outputs": [], "source": [ "train_ds.element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "cWZmfkshqP8G" }, "source": [ "## Create and train the model" ] }, { "cell_type": "markdown", "metadata": { "id": "iGQn32q-hdK2" }, "source": [ "The model will have three outputs, one for each note variable. For `step` and `duration`, you will use a custom loss function based on mean squared error that encourages the model to output non-negative values." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "erxLOif08e8v" }, "outputs": [], "source": [ "def mse_with_positive_pressure(y_true: tf.Tensor, y_pred: tf.Tensor):\n", " mse = (y_true - y_pred) ** 2\n", " positive_pressure = 10 * tf.maximum(-y_pred, 0.0)\n", " return tf.reduce_mean(mse + positive_pressure)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kNaVWcCzAm5V" }, "outputs": [], "source": [ "input_shape = (seq_length, 3)\n", "learning_rate = 0.005\n", "\n", "inputs = tf.keras.Input(input_shape)\n", "x = tf.keras.layers.LSTM(128)(inputs)\n", "\n", "outputs = {\n", " 'pitch': tf.keras.layers.Dense(128, name='pitch')(x),\n", " 'step': tf.keras.layers.Dense(1, name='step')(x),\n", " 'duration': tf.keras.layers.Dense(1, name='duration')(x),\n", "}\n", "\n", "model = tf.keras.Model(inputs, outputs)\n", "\n", "loss = {\n", " 'pitch': tf.keras.losses.SparseCategoricalCrossentropy(\n", " from_logits=True),\n", " 'step': mse_with_positive_pressure,\n", " 'duration': mse_with_positive_pressure,\n", "}\n", "\n", "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n", "\n", "model.compile(loss=loss, optimizer=optimizer)\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "VDL0Jypt3eU5" }, "source": [ "Testing the `model.evaluate` function, you can see that the `pitch` loss is significantly greater than the `step` and `duration` losses. \n", "Note that `loss` is the total loss computed by summing all the other losses and is currently dominated by the `pitch` loss." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BlATt7Rl0XJl" }, "outputs": [], "source": [ "losses = model.evaluate(train_ds, return_dict=True)\n", "losses" ] }, { "cell_type": "markdown", "metadata": { "id": "KLvNLvtR3W59" }, "source": [ "One way balance this is to use the `loss_weights` argument to compile:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9fQB5SiN3ufX" }, "outputs": [], "source": [ "model.compile(\n", " loss=loss,\n", " loss_weights={\n", " 'pitch': 0.05,\n", " 'step': 1.0,\n", " 'duration':1.0,\n", " },\n", " optimizer=optimizer,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "nPMUnIMelHgR" }, "source": [ "The `loss` then becomes the weighted sum of the individual losses." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T7CzWmFR38ut" }, "outputs": [], "source": [ "model.evaluate(train_ds, return_dict=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "SJbn7HZgfosr" }, "source": [ "Train the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uQA_rwKEgPjp" }, "outputs": [], "source": [ "callbacks = [\n", " tf.keras.callbacks.ModelCheckpoint(\n", " filepath='./training_checkpoints/ckpt_{epoch}',\n", " save_weights_only=True),\n", " tf.keras.callbacks.EarlyStopping(\n", " monitor='loss',\n", " patience=5,\n", " verbose=1,\n", " restore_best_weights=True),\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aLoYY8-XaPFN" }, "outputs": [], "source": [ "%%time\n", "epochs = 50\n", "\n", "history = model.fit(\n", " train_ds,\n", " epochs=epochs,\n", " callbacks=callbacks,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PYBSjgDWiUfT" }, "outputs": [], "source": [ "plt.plot(history.epoch, history.history['loss'], label='total loss')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "aPWI94lQ8uQA" }, "source": [ "## Generate notes" ] }, { "cell_type": "markdown", "metadata": { "id": "Wbaoiy4Hf-n5" }, "source": [ "To use the model to generate notes, you will first need to provide a starting sequence of notes. The function below generates one note from a sequence of notes. \n", "\n", "For note pitch, it draws a sample from the softmax distribution of notes produced by the model, and does not simply pick the note with the highest probability.\n", "Always picking the note with the highest probability would lead to repetitive sequences of notes being generated.\n", "\n", "The `temperature` parameter can be used to control the randomness of notes generated. You can find more details on temperature in [Text generation with an RNN](https://www.tensorflow.org/text/tutorials/text_generation)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1mil8ZyJNe1w" }, "outputs": [], "source": [ "def predict_next_note(\n", " notes: np.ndarray, \n", " model: tf.keras.Model, \n", " temperature: float = 1.0) -> tuple[int, float, float]:\n", " \"\"\"Generates a note as a tuple of (pitch, step, duration), using a trained sequence model.\"\"\"\n", "\n", " assert temperature > 0\n", "\n", " # Add batch dimension\n", " inputs = tf.expand_dims(notes, 0)\n", "\n", " predictions = model.predict(inputs)\n", " pitch_logits = predictions['pitch']\n", " step = predictions['step']\n", " duration = predictions['duration']\n", " \n", " pitch_logits /= temperature\n", " pitch = tf.random.categorical(pitch_logits, num_samples=1)\n", " pitch = tf.squeeze(pitch, axis=-1)\n", " duration = tf.squeeze(duration, axis=-1)\n", " step = tf.squeeze(step, axis=-1)\n", "\n", " # `step` and `duration` values should be non-negative\n", " step = tf.maximum(0, step)\n", " duration = tf.maximum(0, duration)\n", "\n", " return int(pitch), float(step), float(duration)" ] }, { "cell_type": "markdown", "metadata": { "id": "W64K-EX3hxU_" }, "source": [ "Now generate some notes. You can play around with temperature and the starting sequence in `next_notes` and see what happens." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "87fPl4auPdR3" }, "outputs": [], "source": [ "temperature = 2.0\n", "num_predictions = 120\n", "\n", "sample_notes = np.stack([raw_notes[key] for key in key_order], axis=1)\n", "\n", "# The initial sequence of notes; pitch is normalized similar to training\n", "# sequences\n", "input_notes = (\n", " sample_notes[:seq_length] / np.array([vocab_size, 1, 1]))\n", "\n", "generated_notes = []\n", "prev_start = 0\n", "for _ in range(num_predictions):\n", " pitch, step, duration = predict_next_note(input_notes, model, temperature)\n", " start = prev_start + step\n", " end = start + duration\n", " input_note = (pitch, step, duration)\n", " generated_notes.append((*input_note, start, end))\n", " input_notes = np.delete(input_notes, 0, axis=0)\n", " input_notes = np.append(input_notes, np.expand_dims(input_note, 0), axis=0)\n", " prev_start = start\n", "\n", "generated_notes = pd.DataFrame(\n", " generated_notes, columns=(*key_order, 'start', 'end'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0MK7HmqLuqka" }, "outputs": [], "source": [ "generated_notes.head(10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e9K9KHPaTNnK" }, "outputs": [], "source": [ "out_file = 'output.mid'\n", "out_pm = notes_to_midi(\n", " generated_notes, out_file=out_file, instrument_name=instrument_name)\n", "display_audio(out_pm)" ] }, { "cell_type": "markdown", "metadata": { "id": "u4N9_Y03Kw-3" }, "source": [ "You can also download the audio file by adding the two lines below:\n", "\n", "```\n", "from google.colab import files\n", "files.download(out_file)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "trp82gTqskPR" }, "source": [ "Visualize the generated notes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NlNsxcnhvbcK" }, "outputs": [], "source": [ "plot_piano_roll(generated_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "p5_yA9lvvitC" }, "source": [ "Check the distributions of `pitch`, `step` and `duration`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "j5bco2WVRkAa" }, "outputs": [], "source": [ "plot_distributions(generated_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "iAyxR7Itw3Wh" }, "source": [ "In the above plots, you will notice the change in distribution of the note variables.\n", "Since there is a feedback loop between the model's outputs and inputs, the model tends to generate similar sequences of outputs to reduce the loss. \n", "This is particularly relevant for `step` and `duration`, which uses the MSE loss.\n", "For `pitch`, you can increase the randomness by increasing the `temperature` in `predict_next_note`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Bkfe3GYZEu4l" }, "source": [ "## Next steps\n", "\n", "This tutorial demonstrated the mechanics of using an RNN to generate sequences of notes from a dataset of MIDI files. To learn more, you can visit the closely related [Text generation with an RNN](https://www.tensorflow.org/text/tutorials/text_generation) tutorial, which contains additional diagrams and explanations. \n", "\n", "One of the alternatives to using RNNs for music generation is using GANs. Rather than generating audio, a GAN-based approach can generate an entire sequence in parallel. The Magenta team has done impressive work on this approach with [GANSynth](https://magenta.tensorflow.org/gansynth). You can also find many wonderful music and art projects and open-source code on [Magenta project website](https://magenta.tensorflow.org/)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "music_generation.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }