{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "t09eeeR5prIJ" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T23:58:14.802183Z", "iopub.status.busy": "2022-12-14T23:58:14.801923Z", "iopub.status.idle": "2022-12-14T23:58:14.806449Z", "shell.execute_reply": "2022-12-14T23:58:14.805800Z" }, "id": "GCCk8_dHpuNf" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "ovpZyIhNIgoq" }, "source": [ "# RNN によるテキスト生成" ] }, { "cell_type": "markdown", "metadata": { "id": "hcD2nPQvPOFM" }, "source": [ "
\n", "QUEENE:\n", "I had thought thou hadst a Roman; for the oracle,\n", "Thus by All bids the man against the word,\n", "Which are so weak of care, by old care done;\n", "Your children were in your holy love,\n", "And the precipitation through the bleeding throne.\n", "\n", "BISHOP OF ELY:\n", "Marry, and will, my lord, to weep in such a one were prettiest;\n", "Yet now I was adopted heir\n", "Of the world's lamentable day,\n", "To watch the next way with his father with his face?\n", "\n", "ESCALUS:\n", "The cause why then we are all resolved more sons.\n", "\n", "VOLUMNIA:\n", "O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,\n", "And love and pale as any will to that word.\n", "\n", "QUEEN ELIZABETH:\n", "But how long have I heard the soul for this world,\n", "And show his hands of life be proved to stand.\n", "\n", "PETRUCHIO:\n", "I say he look'd on, if I must be content\n", "To stay him from the fatal of our country's bliss.\n", "His lordship pluck'd from this sentence then for prey,\n", "And then let us twain, being the moon,\n", "were she such a case as fills m\n", "\n", "\n", "いくつかは文法にあったものがある一方で、ほとんどは意味をなしていません。このモデルは、単語の意味を学習していませんが、次のことを考えてみてください。\n", "\n", "* このモデルは文字ベースです。訓練が始まった時に、モデルは英語の単語のスペルも知りませんし、単語がテキストの単位であることも知らないのです。\n", "\n", "* 出力の構造は戯曲に似ています。だいたいのばあい、データセットとおなじ大文字で書かれた話し手の名前で始まっています。\n", "\n", "* 以下に示すように、モデルはテキストの小さなバッチ(各100文字)で訓練されていますが、一貫した構造のより長いテキストのシーケンスを生成できます。" ] }, { "cell_type": "markdown", "metadata": { "id": "srXC6pLGLwS6" }, "source": [ "## 設定" ] }, { "cell_type": "markdown", "metadata": { "id": "WGyKZj3bzf9p" }, "source": [ "### TensorFlow 等のライブラリインポート" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:14.811023Z", "iopub.status.busy": "2022-12-14T23:58:14.810506Z", "iopub.status.idle": "2022-12-14T23:58:16.873320Z", "shell.execute_reply": "2022-12-14T23:58:16.872429Z" }, "id": 
"yG_n40gFzf9s" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 23:58:15.841952: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 23:58:15.842046: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 23:58:15.842056: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "import numpy as np\n", "import os\n", "import time" ] }, { "cell_type": "markdown", "metadata": { "id": "EHDoRoc5PKWz" }, "source": [ "### シェイクスピアデータセットのダウンロード\n", "\n", "独自のデータで実行するためには下記の行を変更してください。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:16.877621Z", "iopub.status.busy": "2022-12-14T23:58:16.877206Z", "iopub.status.idle": "2022-12-14T23:58:16.881288Z", "shell.execute_reply": "2022-12-14T23:58:16.880673Z" }, "id": "pD_55cOxLkAb" }, "outputs": [], "source": [ "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" ] }, { "cell_type": "markdown", "metadata": { "id": "UHjdCjDuSvX_" }, "source": [ "### データの読み込み\n", "\n", "まずはテキストをのぞいてみましょう。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:16.884336Z", "iopub.status.busy": "2022-12-14T23:58:16.884084Z", "iopub.status.idle": "2022-12-14T23:58:16.889927Z", "shell.execute_reply": 
"2022-12-14T23:58:16.889296Z" }, "id": "aavnuByVymwK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Length of text: 1115394 characters\n" ] } ], "source": [ "# 読み込んだのち、Python 2 との互換性のためにデコード\n", "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", "# テキストの長さは含まれる文字数\n", "print ('Length of text: {} characters'.format(len(text)))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:16.893070Z", "iopub.status.busy": "2022-12-14T23:58:16.892551Z", "iopub.status.idle": "2022-12-14T23:58:16.896005Z", "shell.execute_reply": "2022-12-14T23:58:16.895329Z" }, "id": "Duhg9NrUymwO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "First Citizen:\n", "Before we proceed any further, hear me speak.\n", "\n", "All:\n", "Speak, speak.\n", "\n", "First Citizen:\n", "You are all resolved rather to die than to famish?\n", "\n", "All:\n", "Resolved. resolved.\n", "\n", "First Citizen:\n", "First, you know Caius Marcius is chief enemy to the people.\n", "\n" ] } ], "source": [ "# テキストの最初の 250文字を参照\n", "print(text[:250])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:16.899000Z", "iopub.status.busy": "2022-12-14T23:58:16.898502Z", "iopub.status.idle": "2022-12-14T23:58:16.912133Z", "shell.execute_reply": "2022-12-14T23:58:16.911501Z" }, "id": "IlCgQBRVymwR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "65 unique characters\n" ] } ], "source": [ "# ファイル中のユニークな文字の数\n", "vocab = sorted(set(text))\n", "print ('{} unique characters'.format(len(vocab)))" ] }, { "cell_type": "markdown", "metadata": { "id": "rNnrKn_lL-IJ" }, "source": [ "## テキストの処理" ] }, { "cell_type": "markdown", "metadata": { "id": "LFjSVAlWzf-N" }, "source": [ "### テキストのベクトル化\n", "\n", "訓練をする前に、文字列を数値表現に変換する必要があります。2つの参照テーブルを作成します。一つは文字を数字に変換するもの、もう一つは数字を文字に変換するものです。" ] }, { "cell_type": "code", 
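"execution_count": null, "metadata": {}, "outputs": [], "source": [
"# (仮の小さな例)参照テーブルの考え方を、'aabbc' という架空のテキストで確かめる\n",
"# 実際のデータセットに対する対応表は、この次のセルで作成します\n",
"demo_vocab = sorted(set('aabbc'))  # ['a', 'b', 'c']\n",
"demo_char2idx = {u: i for i, u in enumerate(demo_vocab)}  # {'a': 0, 'b': 1, 'c': 2}\n",
"demo_idx2char = np.array(demo_vocab)\n",
"print(demo_idx2char[[0, 2]])  # ['a' 'c']"
] }, { "cell_type": "code",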
"execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:16.915619Z", "iopub.status.busy": "2022-12-14T23:58:16.915078Z", "iopub.status.idle": "2022-12-14T23:58:17.025025Z", "shell.execute_reply": "2022-12-14T23:58:17.024260Z" }, "id": "IalZLbvOzf-F" }, "outputs": [], "source": [ "# それぞれの文字からインデックスへの対応表を作成\n", "char2idx = {u:i for i, u in enumerate(vocab)}\n", "idx2char = np.array(vocab)\n", "\n", "text_as_int = np.array([char2idx[c] for c in text])" ] }, { "cell_type": "markdown", "metadata": { "id": "tZfqhkYCymwX" }, "source": [ "これで、それぞれの文字を整数で表現できました。文字を、0 から`len(unique)` までのインデックスに変換していることに注意してください。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:17.029444Z", "iopub.status.busy": "2022-12-14T23:58:17.028842Z", "iopub.status.idle": "2022-12-14T23:58:17.033228Z", "shell.execute_reply": "2022-12-14T23:58:17.032627Z" }, "id": "FYyNlCNXymwY" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " '\\n': 0,\n", " ' ' : 1,\n", " '!' : 2,\n", " '$' : 3,\n", " '&' : 4,\n", " \"'\" : 5,\n", " ',' : 6,\n", " '-' : 7,\n", " '.' : 8,\n", " '3' : 9,\n", " ':' : 10,\n", " ';' : 11,\n", " '?' 
: 12,\n", " 'A' : 13,\n", " 'B' : 14,\n", " 'C' : 15,\n", " 'D' : 16,\n", " 'E' : 17,\n", " 'F' : 18,\n", " 'G' : 19,\n", " ...\n", "}\n" ] } ], "source": [ "print('{')\n", "for char,_ in zip(char2idx, range(20)):\n", " print(' {:4s}: {:3d},'.format(repr(char), char2idx[char]))\n", "print(' ...\\n}')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:17.036175Z", "iopub.status.busy": "2022-12-14T23:58:17.035925Z", "iopub.status.idle": "2022-12-14T23:58:17.040102Z", "shell.execute_reply": "2022-12-14T23:58:17.039518Z" }, "id": "l1VKcQHcymwb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "'First Citizen' ---- characters mapped to int ---- > [18 47 56 57 58 1 15 47 58 47 64 43 52]\n" ] } ], "source": [ "# テキストの最初の 13 文字がどのように整数に変換されるかを見てみる\n", "print ('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))" ] }, { "cell_type": "markdown", "metadata": { "id": "bbmsf23Bymwe" }, "source": [ "### 予測タスク" ] }, { "cell_type": "markdown", "metadata": { "id": "wssHQ1oGymwe" }, "source": [ "ある文字、あるいは文字列が与えられたとき、もっともありそうな次の文字はなにか?これが、モデルを訓練してやらせたいタスクです。モデルへの入力は文字列であり、モデルが出力、つまりそれぞれの時点での次の文字を予測するようにモデルを訓練します。\n", "\n", "RNN はすでに見た要素に基づく内部状態を保持しているため、この時点までに計算されたすべての文字を考えると、次の文字は何でしょうか?"
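, "\n",
"\n",
"この予測タスクを、ごく小さな例で書き下すと次のようになります(`\"Hello\"` は説明用の仮のテキストです)。\n",
"\n",
"```python\n",
"# 仮の例: 入力文字列と、それを 1 文字シフトしたターゲット\n",
"text = \"Hello\"\n",
"input_seq = text[:-1]   # \"Hell\" (各時点でモデルに与える入力)\n",
"target_seq = text[1:]   # \"ello\" (各時点で予測したい次の文字)\n",
"```"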
] }, { "cell_type": "markdown", "metadata": { "id": "hgsVvVxnymwf" }, "source": [ "### 訓練用サンプルとターゲットを作成\n", "\n", "つぎに、テキストをサンプルシーケンスに分割します。それぞれの入力シーケンスは、元のテキストからの `seq_length` 個の文字を含みます。\n", "\n", "入力シーケンスそれぞれに対して、対応するターゲットは同じ長さのテキストを含みますが、1文字ずつ右にシフトしたものです。\n", "\n", "そのため、テキストを `seq_length+1` のかたまりに分割します。たとえば、 `seq_length` が 4 で、テキストが \"Hello\" だとします。入力シーケンスは \"Hell\" で、ターゲットシーケンスは \"ello\" となります。\n", "\n", "これを行うために、最初に `tf.data.Dataset.from_tensor_slices` 関数を使ってテキストベクトルを文字インデックスの連続に変換します。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:17.043557Z", "iopub.status.busy": "2022-12-14T23:58:17.043029Z", "iopub.status.idle": "2022-12-14T23:58:20.723108Z", "shell.execute_reply": "2022-12-14T23:58:20.722128Z" }, "id": "0UHJDA39zf-O" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "F\n", "i\n", "r\n", "s\n", "t\n" ] } ], "source": [ "# ひとつの入力としたいシーケンスの文字数としての最大の長さ\n", "seq_length = 100\n", "examples_per_epoch = len(text)//(seq_length+1)\n", "\n", "# 訓練用サンプルとターゲットを作る\n", "char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)\n", "\n", "for i in char_dataset.take(5):\n", " print(idx2char[i.numpy()])" ] }, { "cell_type": "markdown", "metadata": { "id": "-ZSYAcQV8OGP" }, "source": [ "`batch` メソッドを使うと、個々の文字を求める長さのシーケンスに簡単に変換できます。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:20.727397Z", "iopub.status.busy": "2022-12-14T23:58:20.726704Z", "iopub.status.idle": "2022-12-14T23:58:20.746936Z", "shell.execute_reply": "2022-12-14T23:58:20.745990Z" }, "id": "l4hkDU3i7ozi" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n", "'are all resolved rather to die than to famish?\\n\\nAll:\\nResolved. 
resolved.\\n\\nFirst Citizen:\\nFirst, you k'\n", "\"now Caius Marcius is chief enemy to the people.\\n\\nAll:\\nWe know't, we know't.\\n\\nFirst Citizen:\\nLet us ki\"\n", "\"ll him, and we'll have corn at our own price.\\nIs't a verdict?\\n\\nAll:\\nNo more talking on't; let it be d\"\n", "'one: away, away!\\n\\nSecond Citizen:\\nOne word, good citizens.\\n\\nFirst Citizen:\\nWe are accounted poor citi'\n" ] } ], "source": [ "sequences = char_dataset.batch(seq_length+1, drop_remainder=True)\n", "\n", "for item in sequences.take(5):\n", " print(repr(''.join(idx2char[item.numpy()])))" ] }, { "cell_type": "markdown", "metadata": { "id": "UbLcIPBj_mWZ" }, "source": [ "シーケンスそれぞれに対して、`map` メソッドを使って各バッチに単純な関数を適用することで、複製とシフトを行い、入力テキストとターゲットテキストを生成します。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:20.750738Z", "iopub.status.busy": "2022-12-14T23:58:20.750479Z", "iopub.status.idle": "2022-12-14T23:58:20.793037Z", "shell.execute_reply": "2022-12-14T23:58:20.792327Z" }, "id": "9NGu-FkO_kYU" }, "outputs": [], "source": [ "def split_input_target(chunk):\n", " input_text = chunk[:-1]\n", " target_text = chunk[1:]\n", " return input_text, target_text\n", "\n", "dataset = sequences.map(split_input_target)" ] }, { "cell_type": "markdown", "metadata": { "id": "hiCopyGZymwi" }, "source": [ "最初のサンプルの入力とターゲットを出力します。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:20.796680Z", "iopub.status.busy": "2022-12-14T23:58:20.796407Z", "iopub.status.idle": "2022-12-14T23:58:20.824572Z", "shell.execute_reply": "2022-12-14T23:58:20.823624Z" }, "id": "GNbw-iR0ymwj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input data: 'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou'\n", "Target data: 'irst Citizen:\\nBefore we proceed any further, hear me 
speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n" ] } ], "source": [ "for input_example, target_example in dataset.take(1):\n", " print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))\n", " print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))" ] }, { "cell_type": "markdown", "metadata": { "id": "_33OHL3b84i0" }, "source": [ "これらのベクトルのインデックスそれぞれが一つのタイムステップとして処理されます。タイムステップ 0 の入力として、モデルは \"F\" のインデックスを受け取り、次の文字として \"i\" のインデックスを予測しようとします。次のタイムステップでもおなじことをしますが、`RNN` は現在の入力文字に加えて、過去のステップのコンテキストも考慮します。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:20.828725Z", "iopub.status.busy": "2022-12-14T23:58:20.828137Z", "iopub.status.idle": "2022-12-14T23:58:20.845958Z", "shell.execute_reply": "2022-12-14T23:58:20.845083Z" }, "id": "0eBu9WZG84i0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 0\n", " input: 18 ('F')\n", " expected output: 47 ('i')\n", "Step 1\n", " input: 47 ('i')\n", " expected output: 56 ('r')\n", "Step 2\n", " input: 56 ('r')\n", " expected output: 57 ('s')\n", "Step 3\n", " input: 57 ('s')\n", " expected output: 58 ('t')\n", "Step 4\n", " input: 58 ('t')\n", " expected output: 1 (' ')\n" ] } ], "source": [ "for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):\n", " print(\"Step {:4d}\".format(i))\n", " print(\" input: {} ({:s})\".format(input_idx, repr(idx2char[input_idx])))\n", " print(\" expected output: {} ({:s})\".format(target_idx, repr(idx2char[target_idx])))" ] }, { "cell_type": "markdown", "metadata": { "id": "MJdfPmdqzf-R" }, "source": [ "### 訓練用バッチの作成\n", "\n", "`tf.data` を使ってテキストを分割し、扱いやすいシーケンスにします。しかし、このデータをモデルに供給する前に、データをシャッフルしてバッチにまとめる必要があります。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:58:20.850008Z", "iopub.status.busy": "2022-12-14T23:58:20.849444Z", "iopub.status.idle": 
"2022-12-14T23:58:20.861858Z", "shell.execute_reply": "2022-12-14T23:58:20.861119Z" }, "id": "p2pGotuNzf-S" }, "outputs": [ { "data": { "text/plain": [ "