{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "0DH9bjZD_Cfi" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "JO1GUwC1_T2x" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "M4xOsFiu-1-c" }, "source": [ "# Generate music with an RNN" ] }, { "cell_type": "markdown", "metadata": { "id": "OyzAxV7Vu_9Y" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "hr78EkAY-FFg" }, "source": [ "This tutorial shows you how to generate musical notes using a simple recurrent neural network (RNN). You will train a model using a collection of piano MIDI files from the [MAESTRO dataset](https://magenta.tensorflow.org/datasets/maestro). Given a sequence of notes, your model will learn to predict the next note in the sequence. You can generate longer sequences of notes by calling the model repeatedly.\n", "\n", "This tutorial contains complete code to parse and create MIDI files. You can learn more about how RNNs work by visiting the [Text generation with an RNN](https://tensorflow.google.cn/text/tutorials/text_generation) tutorial." ] }, { "cell_type": "markdown", "metadata": { "id": "4ZniYb7Y_0Ey" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "3ks8__E_WUGt" }, "source": [ "This tutorial uses the [`pretty_midi`](https://github.com/craffel/pretty-midi) library to create and parse MIDI files, and [`pyfluidsynth`](https://github.com/nwhitehead/pyfluidsynth) to generate audio playback in Colab." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kahm6Z8v_TqC" }, "outputs": [], "source": [ "!sudo apt install -y fluidsynth" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M0lAReB7_Vqb" }, "outputs": [], "source": [ "!pip install --upgrade pyfluidsynth" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G46kKoQZmIa8" }, "outputs": [], "source": [ "!pip install pretty_midi" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GsLFq7nsiqcq" }, "outputs": [], "source": [ "import collections\n", "import datetime\n", "import fluidsynth\n", "import glob\n", "import numpy as np\n", "import pathlib\n", "import pandas as pd\n", "import pretty_midi\n", "import seaborn as sns\n", "import tensorflow as tf\n", "\n", "from IPython import display\n", "from matplotlib import pyplot as plt\n", "from typing import Optional" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Efja_OtJNzAM" }, "outputs": [], "source": [ "seed = 42\n", "tf.random.set_seed(seed)\n", "np.random.seed(seed)\n", "\n", "# Sampling rate for audio playback\n", "_SAMPLING_RATE = 16000" ] }, { "cell_type": "markdown", "metadata": { "id": "FzIbfb-Ikgg7" }, "source": [ "## Download the Maestro dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mwja4SWmibrL" }, "outputs": [], "source": [ "data_dir = pathlib.Path('data/maestro-v2.0.0')\n", "if not data_dir.exists():\n", "  tf.keras.utils.get_file(\n", "      'maestro-v2.0.0-midi.zip',\n", "      origin='https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip',\n", "      extract=True,\n", "      cache_dir='.', cache_subdir='data',\n", "  )" ] }, { "cell_type": "markdown", "metadata": { "id": "k7UYBSxcINqJ" }, "source": [ "The dataset contains about 1,200 MIDI files." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "72iFI1bPB9o1" }, "outputs": [], "source": [ "filenames = glob.glob(str(data_dir/'**/*.mid*'))\n", "print('Number of files:', len(filenames))" ] }, { "cell_type": "markdown", "metadata": { "id": "8BlRafYDIRgA" }, "source": [ "## Process a MIDI file" ] }, { "cell_type": "markdown", "metadata": { "id": "oFsmG87gXSbh" }, "source": [ "First, use `pretty_midi` to parse a single MIDI file and inspect the format of the notes. If you would like to download the MIDI file below to play it on your computer, you can do so in Colab by running `from google.colab import files` and then `files.download(sample_file)`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6oSCbHvJNbci" }, "outputs": [], "source": [ "sample_file = filenames[1]\n", "print(sample_file)" ] }, { "cell_type": "markdown", "metadata": { "id": "A48VdGEpXnLp" }, "source": [ "Generate a `PrettyMIDI` object for the sample MIDI file." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1YSQ5DjRI2md" }, "outputs": [], "source": [ "pm = pretty_midi.PrettyMIDI(sample_file)" ] }, { "cell_type": "markdown", "metadata": { "id": "FZNVsZuA_lef" }, "source": [ "Play the sample file. The playback widget may take several seconds to load." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vzoHAaVY_kyY" }, "outputs": [], "source": [ "def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30):\n", "  waveform = pm.fluidsynth(fs=_SAMPLING_RATE)\n", "  # Take a sample of the generated waveform to mitigate kernel resets\n", "  waveform_short = waveform[:seconds*_SAMPLING_RATE]\n", "  return 
display.Audio(waveform_short, rate=_SAMPLING_RATE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GOe-3AAi_sRw" }, "outputs": [], "source": [ "display_audio(pm)" ] }, { "cell_type": "markdown", "metadata": { "id": "7Lqe7nOsIyh1" }, "source": [ "Inspect the MIDI file. What kinds of instruments are used?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SIGHYQPZQnRo" }, "outputs": [], "source": [ "print('Number of instruments:', len(pm.instruments))\n", "instrument = pm.instruments[0]\n", "instrument_name = pretty_midi.program_to_instrument_name(instrument.program)\n", "print('Instrument name:', instrument_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "KVQfV2hVKB28" }, "source": [ "## Extract notes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nYZm_VehYOTZ" }, "outputs": [], "source": [ "for i, note in enumerate(instrument.notes[:10]):\n", "  note_name = pretty_midi.note_number_to_name(note.pitch)\n", "  duration = note.end - note.start\n", "  print(f'{i}: pitch={note.pitch}, note_name={note_name},'\n", "        f' duration={duration:.4f}')" ] }, { "cell_type": "markdown", "metadata": { "id": "jutzynyqX_GC" }, "source": [ "You will use three variables to represent a note when training the model: `pitch`, `step`, and `duration`. `pitch` is the perceived quality of the sound, expressed as a MIDI note number. `step` is the time elapsed since the previous note or the start of the track. `duration` is how long the note plays, in seconds: the difference between the note's end and start times.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "KGn7Juv_PTi6" }, "source": [ "Extract the notes from the sample MIDI file." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wyp_wdcEPWby" }, "outputs": [], "source": [ "def midi_to_notes(midi_file: str) -> pd.DataFrame:\n", "  pm = pretty_midi.PrettyMIDI(midi_file)\n", "  instrument = pm.instruments[0]\n", "  notes = collections.defaultdict(list)\n", "\n", "  # Sort the notes by start time\n", "  sorted_notes = sorted(instrument.notes, key=lambda note: note.start)\n", "  prev_start = sorted_notes[0].start\n", "\n", "  for note in sorted_notes:\n", "    start = note.start\n", "    end = note.end\n", "    
notes['pitch'].append(note.pitch)\n", "    notes['start'].append(start)\n", "    notes['end'].append(end)\n", "    notes['step'].append(start - prev_start)\n", "    notes['duration'].append(end - start)\n", "    prev_start = start\n", "\n", "  return pd.DataFrame({name: np.array(value) for name, value in notes.items()})" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X0kPjLBlcnY6" }, "outputs": [], "source": [ "raw_notes = midi_to_notes(sample_file)\n", "raw_notes.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "-71LPvjubOSO" }, "source": [ "Note names can be easier to interpret than pitch numbers, so you can use the function below to convert numeric pitch values to note names. A note name shows the type of note, the accidental, and the octave number (for example, C#4)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WE9YXrGZbY2X" }, "outputs": [], "source": [ "get_note_names = np.vectorize(pretty_midi.note_number_to_name)\n", "sample_note_names = get_note_names(raw_notes['pitch'])\n", "sample_note_names[:10]" ] }, { "cell_type": "markdown", "metadata": { "id": "Q7sjqbp1e_f-" }, "source": [ "To visualize the musical piece, plot each note's pitch against its start and end times across the length of the track (that is, a piano roll). Start with the first 100 notes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "liD2N7x_WOTp" }, "outputs": [], "source": [ "def plot_piano_roll(notes: pd.DataFrame, count: Optional[int] = None):\n", "  if count:\n", "    title = f'First {count} notes'\n", "  else:\n", "    title = 'Whole track'\n", "    count = len(notes['pitch'])\n", "  plt.figure(figsize=(20, 4))\n", "  plot_pitch = np.stack([notes['pitch'], notes['pitch']], axis=0)\n", "  plot_start_stop = np.stack([notes['start'], notes['end']], axis=0)\n", "  plt.plot(\n", "      plot_start_stop[:, :count], plot_pitch[:, :count], color=\"b\", marker=\".\")\n", "  plt.xlabel('Time [s]')\n", "  plt.ylabel('Pitch')\n", "  _ = plt.title(title)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vWeUbqmAXjOs" }, "outputs": [], "source": [ "plot_piano_roll(raw_notes, count=100)" ] }, { "cell_type": "markdown", "metadata": { "id": "gcUyCXYhXeVA" }, "source": [ "Plot the notes for the entire track." ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "id": "G7l76hEDZX8Z" }, "outputs": [], "source": [ "plot_piano_roll(raw_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "5GM1bi3aX8rd" }, "source": [ "Check the distribution of each note variable." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pq9C9XBBaK7W" }, "outputs": [], "source": [ "def plot_distributions(notes: pd.DataFrame, drop_percentile=2.5):\n", "  plt.figure(figsize=[15, 5])\n", "  plt.subplot(1, 3, 1)\n", "  sns.histplot(notes, x=\"pitch\", bins=20)\n", "\n", "  plt.subplot(1, 3, 2)\n", "  max_step = np.percentile(notes['step'], 100 - drop_percentile)\n", "  sns.histplot(notes, x=\"step\", bins=np.linspace(0, max_step, 21))\n", "\n", "  plt.subplot(1, 3, 3)\n", "  max_duration = np.percentile(notes['duration'], 100 - drop_percentile)\n", "  sns.histplot(notes, x=\"duration\", bins=np.linspace(0, max_duration, 21))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-Nu2Pw24acFD" }, "outputs": [], "source": [ "plot_distributions(raw_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "poIivompcfS4" }, "source": [ "## Create a MIDI file\n", "\n", "You can generate your own MIDI file from a list of notes using the function below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BD5rsMRARYoV" }, "outputs": [], "source": [ "def notes_to_midi(\n", "  notes: pd.DataFrame,\n", "  out_file: str,\n", "  instrument_name: str,\n", "  velocity: int = 100,  # note loudness\n", ") -> pretty_midi.PrettyMIDI:\n", "\n", "  pm = pretty_midi.PrettyMIDI()\n", "  instrument = pretty_midi.Instrument(\n", "      program=pretty_midi.instrument_name_to_program(\n", "          instrument_name))\n", "\n", "  prev_start = 0\n", "  for i, note in notes.iterrows():\n", "    start = float(prev_start + note['step'])\n", "    end = float(start + note['duration'])\n", "    note = pretty_midi.Note(\n", "        velocity=velocity,\n", "        pitch=int(note['pitch']),\n", "        start=start,\n", "        end=end,\n", "    )\n", "    instrument.notes.append(note)\n", "    prev_start = start\n", "\n", "  
pm.instruments.append(instrument)\n", "  pm.write(out_file)\n", "  return pm" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTazLbuWPIPF" }, "outputs": [], "source": [ "example_file = 'example.midi'\n", "example_pm = notes_to_midi(\n", "    raw_notes, out_file=example_file, instrument_name=instrument_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "XG0N9zZV_4Gp" }, "source": [ "Play the generated MIDI file and listen for any differences from the original." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fGRLs-eR_4uK" }, "outputs": [], "source": [ "display_audio(example_pm)" ] }, { "cell_type": "markdown", "metadata": { "id": "CLrUscjhBzYc" }, "source": [ "As before, you can write `files.download(example_file)` to download and play this file." ] }, { "cell_type": "markdown", "metadata": { "id": "pfRNk9tEScuf" }, "source": [ "## Create the training dataset\n" ] }, { "cell_type": "markdown", "metadata": { "id": "b77zHR1udDrK" }, "source": [ "Create the training dataset by extracting notes from the MIDI files. You can start with a small number of files and experiment with more later. This may take a few minutes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GiaQiTnXSW-T" }, "outputs": [], "source": [ "num_files = 5\n", "all_notes = []\n", "for f in filenames[:num_files]:\n", "  notes = midi_to_notes(f)\n", "  all_notes.append(notes)\n", "\n", "all_notes = pd.concat(all_notes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F4bMDeRvgWqx" }, "outputs": [], "source": [ "n_notes = len(all_notes)\n", "print('Number of notes parsed:', n_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "xIBLvj-cODWS" }, "source": [ "Next, create a [tf.data.Dataset](https://tensorflow.google.cn/datasets) from the parsed notes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mvNHCHZdXG2P" }, "outputs": [], "source": [ "key_order = ['pitch', 'step', 'duration']\n", "train_notes = np.stack([all_notes[key] for key in key_order], axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PLC_19tshyFk" }, "outputs": [], "source": [ "notes_ds = 
tf.data.Dataset.from_tensor_slices(train_notes)\n", "notes_ds.element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "Sj9SXRCjt3I7" }, "source": [ "You will train the model on batches of sequences of notes. Each example consists of a sequence of notes as the input features and the next note as the label. In this way, the model is trained to predict the next note in a sequence. You can find a diagram describing this process (and more details) in [Text generation with an RNN](https://tensorflow.google.cn/text/tutorials/text_generation).\n", "\n", "You can use the handy [window](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#window) function with size `seq_length` to create the features and labels in this format." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZkEC-5s6wJJV" }, "outputs": [], "source": [ "def create_sequences(\n", "    dataset: tf.data.Dataset,\n", "    seq_length: int,\n", "    vocab_size = 128,\n", ") -> tf.data.Dataset:\n", "  \"\"\"Returns TF Dataset of sequence and label examples.\"\"\"\n", "  seq_length = seq_length+1\n", "\n", "  # Take 1 extra for the labels\n", "  windows = dataset.window(seq_length, shift=1, stride=1,\n", "                           drop_remainder=True)\n", "\n", "  # `flat_map` flattens the \"dataset of datasets\" into a dataset of tensors\n", "  flatten = lambda x: x.batch(seq_length, drop_remainder=True)\n", "  sequences = windows.flat_map(flatten)\n", "\n", "  # Normalize note pitch\n", "  def scale_pitch(x):\n", "    x = x/[vocab_size,1.0,1.0]\n", "    return x\n", "\n", "  # Split the labels\n", "  def split_labels(sequences):\n", "    inputs = sequences[:-1]\n", "    labels_dense = sequences[-1]\n", "    labels = {key:labels_dense[i] for i,key in enumerate(key_order)}\n", "\n", "    return scale_pitch(inputs), labels\n", "\n", "  return sequences.map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "2xDX5pVkegrv" }, "source": [ "Set the sequence length for each example. Experiment with different lengths (for example 50, 100, 150) to see which one works best for the data, or use [hyperparameter tuning](https://tensorflow.google.cn/tutorials/keras/keras_tuner). The size of the vocabulary (`vocab_size`) is set to 128, representing all the pitches supported by `pretty_midi`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fGA3VxcFXZ4T" }, "outputs": [], "source": [ "seq_length = 25\n", "vocab_size = 
128\n", "seq_ds = create_sequences(notes_ds, seq_length, vocab_size)\n", "seq_ds.element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "AX9nKmSYetGo" }, "source": [ "Each input sequence has shape `(seq_length, 3)`, which is `(25, 3)` here: the model takes 25 notes as input, each described by its `pitch`, `step`, and `duration`, and learns to predict the following note as output." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ESK9cL7__TF3" }, "outputs": [], "source": [ "for seq, target in seq_ds.take(1):\n", "  print('sequence shape:', seq.shape)\n", "  print('sequence elements (first 10):', seq[0: 10])\n", "  print()\n", "  print('target:', target)" ] }, { "cell_type": "markdown", "metadata": { "id": "kR3TVZZGk5Qq" }, "source": [ "Batch the examples, and configure the dataset for performance." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fTpFoiM_AV_Y" }, "outputs": [], "source": [ "batch_size = 64\n", "buffer_size = n_notes - seq_length  # the number of items in the dataset\n", "train_ds = (seq_ds\n", "            .shuffle(buffer_size)\n", "            .batch(batch_size, drop_remainder=True)\n", "            .cache()\n", "            .prefetch(tf.data.AUTOTUNE))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LySbjV0GzXQu" }, "outputs": [], "source": [ "train_ds.element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "cWZmfkshqP8G" }, "source": [ "## Create and train the model" ] }, { "cell_type": "markdown", "metadata": { "id": "iGQn32q-hdK2" }, "source": [ "The model has three outputs, one for each note variable. For `step` and `duration`, you will use a custom loss function based on mean squared error that penalizes negative predictions, which encourages the model to output non-negative values." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "erxLOif08e8v" }, "outputs": [], "source": [ "def mse_with_positive_pressure(y_true: tf.Tensor, y_pred: tf.Tensor):\n", "  mse = (y_true - y_pred) ** 2\n", "  positive_pressure = 10 * tf.maximum(-y_pred, 0.0)\n", "  return tf.reduce_mean(mse + positive_pressure)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kNaVWcCzAm5V" }, "outputs": [], "source": [ "input_shape = (seq_length, 3)\n", "learning_rate = 0.005\n", "\n", "inputs = tf.keras.Input(input_shape)\n", "x = 
tf.keras.layers.LSTM(128)(inputs)\n", "\n", "outputs = {\n", "  'pitch': tf.keras.layers.Dense(128, name='pitch')(x),\n", "  'step': tf.keras.layers.Dense(1, name='step')(x),\n", "  'duration': tf.keras.layers.Dense(1, name='duration')(x),\n", "}\n", "\n", "model = tf.keras.Model(inputs, outputs)\n", "\n", "loss = {\n", "  'pitch': tf.keras.losses.SparseCategoricalCrossentropy(\n", "      from_logits=True),\n", "  'step': mse_with_positive_pressure,\n", "  'duration': mse_with_positive_pressure,\n", "}\n", "\n", "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n", "\n", "model.compile(loss=loss, optimizer=optimizer)\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "VDL0Jypt3eU5" }, "source": [ "Testing the `model.evaluate` function, you can see that the `pitch` loss is significantly greater than the `step` and `duration` losses. Note that `loss` is the total loss, computed by summing all the other losses, and is currently dominated by the `pitch` loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BlATt7Rl0XJl" }, "outputs": [], "source": [ "losses = model.evaluate(train_ds, return_dict=True)\n", "losses" ] }, { "cell_type": "markdown", "metadata": { "id": "KLvNLvtR3W59" }, "source": [ "One way to balance this is to pass the `loss_weights` argument to `compile`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9fQB5SiN3ufX" }, "outputs": [], "source": [ "model.compile(\n", "    loss=loss,\n", "    loss_weights={\n", "        'pitch': 0.05,\n", "        'step': 1.0,\n", "        'duration':1.0,\n", "    },\n", "    optimizer=optimizer,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "nPMUnIMelHgR" }, "source": [ "The `loss` then becomes the weighted sum of the individual losses." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T7CzWmFR38ut" }, "outputs": [], "source": [ "model.evaluate(train_ds, return_dict=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "SJbn7HZgfosr" }, "source": [ "Train the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uQA_rwKEgPjp" }, "outputs": [], "source": [ "callbacks = [\n", "    tf.keras.callbacks.ModelCheckpoint(\n", "        
filepath='./training_checkpoints/ckpt_{epoch}',\n", "        save_weights_only=True),\n", "    tf.keras.callbacks.EarlyStopping(\n", "        monitor='loss',\n", "        patience=5,\n", "        verbose=1,\n", "        restore_best_weights=True),\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aLoYY8-XaPFN" }, "outputs": [], "source": [ "%%time\n", "epochs = 50\n", "\n", "history = model.fit(\n", "    train_ds,\n", "    epochs=epochs,\n", "    callbacks=callbacks,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PYBSjgDWiUfT" }, "outputs": [], "source": [ "plt.plot(history.epoch, history.history['loss'], label='total loss')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "aPWI94lQ8uQA" }, "source": [ "## Generate notes" ] }, { "cell_type": "markdown", "metadata": { "id": "Wbaoiy4Hf-n5" }, "source": [ "To use the model to generate notes, you first need to provide a starting sequence of notes. The function below generates one note from a sequence of notes.\n", "\n", "For note pitch, it draws a sample from the softmax distribution of notes produced by the model rather than simply picking the note with the highest probability. Always picking the note with the highest probability would lead to repetitive sequences of notes being generated.\n", "\n", "The `temperature` parameter controls the randomness of the generated notes. You can find more details on temperature in [Text generation with an RNN](https://tensorflow.google.cn/text/tutorials/text_generation)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1mil8ZyJNe1w" }, "outputs": [], "source": [ "def predict_next_note(\n", "    notes: np.ndarray,\n", "    keras_model: tf.keras.Model,\n", "    temperature: float = 1.0) -> tuple[int, float, float]:\n", "  \"\"\"Generates a note as a tuple of (pitch, step, duration), using a trained sequence model.\"\"\"\n", "\n", "  assert temperature > 0\n", "\n", "  # Add batch dimension\n", "  inputs = tf.expand_dims(notes, 0)\n", "\n", "  # Use the model passed in as an argument, not the global `model`\n", "  predictions = keras_model.predict(inputs)\n", "  pitch_logits = predictions['pitch']\n", "  step = predictions['step']\n", "  duration = predictions['duration']\n", "\n", "  pitch_logits /= temperature\n", "  pitch = tf.random.categorical(pitch_logits, num_samples=1)\n", "  pitch = tf.squeeze(pitch, axis=-1)\n", "  duration = tf.squeeze(duration, axis=-1)\n", "  step = tf.squeeze(step, axis=-1)\n", "\n", "  # 
`step` and `duration` values should be non-negative\n", "  step = tf.maximum(0, step)\n", "  duration = tf.maximum(0, duration)\n", "\n", "  return int(pitch), float(step), float(duration)" ] }, { "cell_type": "markdown", "metadata": { "id": "W64K-EX3hxU_" }, "source": [ "Now generate some notes. You can experiment with the temperature and with the starting sequence in `input_notes` to see what happens." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "87fPl4auPdR3" }, "outputs": [], "source": [ "temperature = 2.0\n", "num_predictions = 120\n", "\n", "sample_notes = np.stack([raw_notes[key] for key in key_order], axis=1)\n", "\n", "# The initial sequence of notes; pitch is normalized similar to training\n", "# sequences\n", "input_notes = (\n", "    sample_notes[:seq_length] / np.array([vocab_size, 1, 1]))\n", "\n", "generated_notes = []\n", "prev_start = 0\n", "for _ in range(num_predictions):\n", "  pitch, step, duration = predict_next_note(input_notes, model, temperature)\n", "  start = prev_start + step\n", "  end = start + duration\n", "  input_note = (pitch, step, duration)\n", "  generated_notes.append((*input_note, start, end))\n", "  input_notes = np.delete(input_notes, 0, axis=0)\n", "  # Feed the note back with its pitch normalized, as in the training inputs\n", "  next_input = np.array([[pitch / vocab_size, step, duration]])\n", "  input_notes = np.append(input_notes, next_input, axis=0)\n", "  prev_start = start\n", "\n", "generated_notes = pd.DataFrame(\n", "    generated_notes, columns=(*key_order, 'start', 'end'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0MK7HmqLuqka" }, "outputs": [], "source": [ "generated_notes.head(10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e9K9KHPaTNnK" }, "outputs": [], "source": [ "out_file = 'output.mid'\n", "out_pm = notes_to_midi(\n", "    generated_notes, out_file=out_file, instrument_name=instrument_name)\n", "display_audio(out_pm)" ] }, { "cell_type": "markdown", "metadata": { "id": "u4N9_Y03Kw-3" }, "source": [ "You can also download the audio file by adding the two lines below:\n", "\n", "```\n", "from google.colab import files\n", "files.download(out_file)\n", "```" ] }, { "cell_type": "markdown", 
"metadata": { "id": "trp82gTqskPR" }, "source": [ "Visualize the generated notes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NlNsxcnhvbcK" }, "outputs": [], "source": [ "plot_piano_roll(generated_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "p5_yA9lvvitC" }, "source": [ "Check the distributions of `pitch`, `step`, and `duration`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "j5bco2WVRkAa" }, "outputs": [], "source": [ "plot_distributions(generated_notes)" ] }, { "cell_type": "markdown", "metadata": { "id": "iAyxR7Itw3Wh" }, "source": [ "In the plots above, you will notice that the distributions of the note variables have changed. Because there is a feedback loop between the model's outputs and inputs, the model tends to generate similar sequences of outputs to reduce the loss. This is particularly relevant for `step` and `duration`, which use the MSE loss. For `pitch`, you can increase the randomness by increasing the `temperature` in `predict_next_note`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Bkfe3GYZEu4l" }, "source": [ "## Next steps\n", "\n", "This tutorial demonstrated the mechanics of using an RNN to generate sequences of notes from a dataset of MIDI files. To learn more, you can visit the closely related [Text generation with an RNN](https://tensorflow.google.cn/text/tutorials/text_generation) tutorial, which contains additional diagrams and explanations.\n", "\n", "One alternative to using RNNs for music generation is using GANs. Rather than generating a sequence one step at a time, a GAN-based approach can generate an entire sequence in parallel. The Magenta team has done impressive work on this approach with [GANSynth](https://magenta.tensorflow.org/gansynth). You can also find many wonderful music and art projects and open-source code on the [Magenta project website](https://magenta.tensorflow.org/)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "music_generation.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }