{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "s_qNSzzyaCbD" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T23:10:11.489585Z", "iopub.status.busy": "2022-12-14T23:10:11.489146Z", "iopub.status.idle": "2022-12-14T23:10:11.492857Z", "shell.execute_reply": "2022-12-14T23:10:11.492331Z" }, "id": "jmjh290raIky" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "J0Qjg6vuaHNt" }, "source": [ "# Neural machine translation with attention" ] }, { "cell_type": "markdown", "metadata": { "id": "tudzcncJXetB" }, "source": [ "Note: This document was translated by the TensorFlow community. Community translations are **best-effort**, so there is no guarantee that this translation is accurate or that it reflects the latest state of the [official English documentation](https://www.tensorflow.org/?hl=en). If you have suggestions for improving this translation, please send a pull request to the [tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review community translations, contact the [docs-ja@tensorflow.org mailing list](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ja)." ] }, { "cell_type": "markdown", "metadata": { "id": "AOpGoE2T-YXS" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "CiwtNgENbx2g" }, "source": [ "このノートブックでは、スペイン語から英語への翻訳を行う Sequence to Sequence (seq2seq) モデルを訓練します。このチュートリアルは、 Sequence to Sequence モデルの知識があることを前提にした上級編のサンプルです。\n", "\n", "このノートブックのモデルを訓練すると、_\"¿todavia estan en casa?\"_ のようなスペイン語の文を入力して、英訳: _\"are you still at home?\"_ を得ることができます。\n", "\n", "この翻訳品質はおもちゃとしてはそれなりのものですが、生成されたアテンションの図表の方が面白いかもしれません。これは、翻訳時にモデルが入力文のどの部分に注目しているかを表しています。\n", "\n", "\"spanish-english\n", "\n", "Note: このサンプルは P100 GPU 1基で実行した場合に約 10 分かかります。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:11.496381Z", "iopub.status.busy": "2022-12-14T23:10:11.495946Z", "iopub.status.idle": "2022-12-14T23:10:13.961690Z", "shell.execute_reply": "2022-12-14T23:10:13.961006Z" }, "id": "tnxXKDjq3jEL" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 23:10:12.430372: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 23:10:12.430471: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 23:10:12.430481: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as ticker\n", "from sklearn.model_selection import train_test_split\n", "\n", "import unicodedata\n", "import re\n", "import numpy as np\n", "import os\n", "import io\n", "import time" ] }, { "cell_type": "markdown", "metadata": { "id": "wfodePkj3jEa" }, "source": [ "## Download and prepare the dataset\n", "\n", "We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the following format:\n", "\n", "\n", "```\n", "May I borrow this book?\t¿Puedo tomar prestado este libro?\n", "```\n", "\n", "A variety of languages is available, but we'll use the English-Spanish dataset. For convenience, a copy of this dataset is hosted on Google Cloud, but you can also download your own copy. After downloading the dataset, we take the following steps to prepare the data:\n", "\n", "1. Add a *start* and an *end* token to each sentence.\n", "2. Clean the sentences by removing special characters.\n", "3. Create a word index and a reverse word index (dictionaries mapping word → id and id → word).\n", "4. Pad each sentence to the maximum length." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:13.966226Z", "iopub.status.busy": "2022-12-14T23:10:13.965458Z", "iopub.status.idle": "2022-12-14T23:10:14.080520Z", "shell.execute_reply": "2022-12-14T23:10:14.079926Z" }, "id": "kRVATYOgJs1b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 8192/2638744 [..............................] 
- ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "2638744/2638744 [==============================] - 0s 0us/step\n" ] } ], "source": [ "# Download the file\n", "path_to_zip = tf.keras.utils.get_file(\n", "    'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',\n", "    extract=True)\n", "\n", "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:14.083878Z", "iopub.status.busy": "2022-12-14T23:10:14.083629Z", "iopub.status.idle": "2022-12-14T23:10:14.088456Z", "shell.execute_reply": "2022-12-14T23:10:14.087890Z" }, "id": "rd0jw-eC3jEh" }, "outputs": [], "source": [ "# Convert a unicode string to ASCII by dropping combining accent marks\n", "def unicode_to_ascii(s):\n", "  return ''.join(c for c in unicodedata.normalize('NFD', s)\n", "      if unicodedata.category(c) != 'Mn')\n", "\n", "\n", "def preprocess_sentence(w):\n", "  w = unicode_to_ascii(w.lower().strip())\n", "\n", "  # Insert a space between a word and the punctuation following it\n", "  # e.g. \"he is a boy.\" => \"he is a boy .\"\n", "  # Reference: https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", "  w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", "  w = re.sub(r'[\" \"]+', \" \", w)\n", "\n", "  # Replace everything with a space except (a-z, A-Z, \".\", \"?\", \"!\", \",\", \"¿\")\n", "  w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n", "\n", "  w = w.strip()\n", "\n", "  # Add a start and an end token to the sentence\n", "  # so that the model knows when to start and stop predicting\n", "  w = '<start> ' + w + ' <end>'\n", "  return w" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:14.091708Z", "iopub.status.busy": "2022-12-14T23:10:14.091191Z", "iopub.status.idle": "2022-12-14T23:10:14.095075Z", "shell.execute_reply": "2022-12-14T23:10:14.094499Z" }, "id": 
"opI2GzOt479E" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " may i borrow this book ? \n", "b' \\xc2\\xbf puedo tomar prestado este libro ? '\n" ] } ], "source": [ "en_sentence = u\"May I borrow this book?\"\n", "sp_sentence = u\"¿Puedo tomar prestado este libro?\"\n", "print(preprocess_sentence(en_sentence))\n", "print(preprocess_sentence(sp_sentence).encode('utf-8'))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:14.098358Z", "iopub.status.busy": "2022-12-14T23:10:14.097911Z", "iopub.status.idle": "2022-12-14T23:10:14.101680Z", "shell.execute_reply": "2022-12-14T23:10:14.101127Z" }, "id": "OHn4Dct23jEm" }, "outputs": [], "source": [ "# 1. アクセント記号を除去\n", "# 2. 文をクリーニング\n", "# 3. [ENGLISH, SPANISH] の形で単語のペアを返す\n", "def create_dataset(path, num_examples):\n", " lines = io.open(path, encoding='UTF-8').read().strip().split('\\n')\n", "\n", " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", "\n", " return zip(*word_pairs)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:14.104736Z", "iopub.status.busy": "2022-12-14T23:10:14.104330Z", "iopub.status.idle": "2022-12-14T23:10:17.998563Z", "shell.execute_reply": "2022-12-14T23:10:17.997907Z" }, "id": "cTbSbBz55QtF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " if you want to sound like a native speaker , you must be willing to practice saying the same sentence over and over in the same way that banjo players practice the same phrase over and over until they can play it correctly and at the desired tempo . \n", " si quieres sonar como un hablante nativo , debes estar dispuesto a practicar diciendo la misma frase una y otra vez de la misma manera en que un musico de banjo practica el mismo fraseo una y otra vez hasta que lo puedan tocar correctamente y en el tiempo esperado . 
\n" ] } ], "source": [ "en, sp = create_dataset(path_to_file, None)\n", "print(en[-1])\n", "print(sp[-1])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:18.002280Z", "iopub.status.busy": "2022-12-14T23:10:18.001795Z", "iopub.status.idle": "2022-12-14T23:10:18.005098Z", "shell.execute_reply": "2022-12-14T23:10:18.004540Z" }, "id": "OmMZQpdO60dt" }, "outputs": [], "source": [ "def max_length(tensor):\n", "  return max(len(t) for t in tensor)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:18.008332Z", "iopub.status.busy": "2022-12-14T23:10:18.007825Z", "iopub.status.idle": "2022-12-14T23:10:18.011579Z", "shell.execute_reply": "2022-12-14T23:10:18.011028Z" }, "id": "bIOn8RCNDJXG" }, "outputs": [], "source": [ "def tokenize(lang):\n", "  lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(\n", "      filters='')\n", "  lang_tokenizer.fit_on_texts(lang)\n", "\n", "  tensor = lang_tokenizer.texts_to_sequences(lang)\n", "\n", "  tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,\n", "      padding='post')\n", "\n", "  return tensor, lang_tokenizer" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:18.014661Z", "iopub.status.busy": "2022-12-14T23:10:18.014217Z", "iopub.status.idle": "2022-12-14T23:10:18.017795Z", "shell.execute_reply": "2022-12-14T23:10:18.017249Z" }, "id": "eAY9k49G3jE_" }, "outputs": [], "source": [ "def load_dataset(path, num_examples=None):\n", "  # Create cleaned input, output pairs\n", "  targ_lang, inp_lang = create_dataset(path, num_examples)\n", "\n", "  input_tensor, inp_lang_tokenizer = tokenize(inp_lang)\n", "  target_tensor, targ_lang_tokenizer = tokenize(targ_lang)\n", "\n", "  return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer" ] }, { "cell_type": "markdown", "metadata": { "id": "GOi42V79Ydlr" }, "source": [ "### 
Limit the size of the dataset to experiment faster (optional)\n", "\n", "Training on the complete dataset of more than 100,000 sentences takes a long time. To train faster, you can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data)." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:18.021090Z", "iopub.status.busy": "2022-12-14T23:10:18.020592Z", "iopub.status.idle": "2022-12-14T23:10:19.584221Z", "shell.execute_reply": "2022-12-14T23:10:19.583521Z" }, "id": "cnxC7q-j3jFD" }, "outputs": [], "source": [ "# Try experimenting with the size of the dataset\n", "num_examples = 30000\n", "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)\n", "\n", "# Calculate the max length of the target and input tensors\n", "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:19.588570Z", "iopub.status.busy": "2022-12-14T23:10:19.587903Z", "iopub.status.idle": "2022-12-14T23:10:19.596336Z", "shell.execute_reply": "2022-12-14T23:10:19.595698Z" }, "id": "4QILQkOs3jFG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "24000 24000 6000 6000\n" ] } ], "source": [ "# Create training and validation sets using an 80-20 split\n", "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", "\n", "# Show the lengths\n", "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:19.599578Z", "iopub.status.busy": "2022-12-14T23:10:19.599064Z", "iopub.status.idle": "2022-12-14T23:10:19.602302Z", "shell.execute_reply": "2022-12-14T23:10:19.601766Z" }, "id": "lJPmLZGMeD5q" }, "outputs": [], "source": [ "def convert(lang, tensor):\n", "  for t in tensor:\n", "    if t!=0:\n", "      print (\"%d ----> %s\" % (t, lang.index_word[t]))" ] }, { "cell_type": "code", 
"execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:19.605340Z", "iopub.status.busy": "2022-12-14T23:10:19.604868Z", "iopub.status.idle": "2022-12-14T23:10:19.608746Z", "shell.execute_reply": "2022-12-14T23:10:19.608167Z" }, "id": "VXukARTDd7MT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input Language; index to word mapping\n", "1 ----> \n", "4 ----> tom\n", "7 ----> es\n", "36 ----> muy\n", "2138 ----> capaz\n", "3 ----> .\n", "2 ----> \n", "\n", "Target Language; index to word mapping\n", "1 ----> \n", "5 ----> tom\n", "8 ----> is\n", "48 ----> very\n", "4910 ----> competent\n", "3 ----> .\n", "2 ----> \n" ] } ], "source": [ "print (\"Input Language; index to word mapping\")\n", "convert(inp_lang, input_tensor_train[0])\n", "print ()\n", "print (\"Target Language; index to word mapping\")\n", "convert(targ_lang, target_tensor_train[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "rgCLkfv5uO3d" }, "source": [ "### tf.data データセットの作成" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:19.611971Z", "iopub.status.busy": "2022-12-14T23:10:19.611432Z", "iopub.status.idle": "2022-12-14T23:10:22.968347Z", "shell.execute_reply": "2022-12-14T23:10:22.967632Z" }, "id": "TqHsArVZ3jFS" }, "outputs": [], "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", "steps_per_epoch = len(input_tensor_train)//BATCH_SIZE\n", "embedding_dim = 256\n", "units = 1024\n", "vocab_inp_size = len(inp_lang.word_index)+1\n", "vocab_tar_size = len(targ_lang.word_index)+1\n", "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:22.971669Z", "iopub.status.busy": 
"2022-12-14T23:10:22.971418Z", "iopub.status.idle": "2022-12-14T23:10:23.024947Z", "shell.execute_reply": "2022-12-14T23:10:23.024260Z" }, "id": "qc6-NK1GtWQt" }, "outputs": [ { "data": { "text/plain": [ "(TensorShape([64, 16]), TensorShape([64, 11]))" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example_input_batch, example_target_batch = next(iter(dataset))\n", "example_input_batch.shape, example_target_batch.shape" ] }, { "cell_type": "markdown", "metadata": { "id": "TNfHIF71ulLu" }, "source": [ "## エンコーダー・デコーダーモデルの記述\n", "\n", "TensorFlow の [Neural Machine Translation (seq2seq) tutorial](https://github.com/tensorflow/nmt) に記載されているアテンション付きのエンコーダー・デコーダーモデルを実装します。この例では、最新の API セットを使用します。このノートブックは、上記の seq2seq チュートリアルにある [attention equations](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism) を実装します。下図は、入力の単語ひとつひとつにアテンション機構によって重みが割り当てられ、それを使ってデコーダーが文中の次の単語を予測することを示しています。下記の図と式は [Luong の論文](https://arxiv.org/abs/1508.04025v5) にあるアテンション機構の例です。\n", "\n", "\"attention\n", "\n", "入力がエンコーダーを通過すると、shape が *(batch_size, max_length, hidden_size)* のエンコーダー出力と、shape が *(batch_size, hidden_size)* のエンコーダーの隠れ状態が得られます。\n", "\n", "下記に実装されている式を示します。\n", "\n", "\"attention\n", "\"attention\n", "\n", "このチュートリアルでは、エンコーダーでは [Bahdanau attention](https://arxiv.org/pdf/1409.0473.pdf) を使用します。簡略化した式を書く前に、表記方法を定めましょう。\n", "\n", "* FC = 全結合 (Dense) レイヤー\n", "* EO = エンコーダーの出力\n", "* H = 隠れ状態\n", "* X = デコーダーへの入力\n", "\n", "\n", "擬似コードは下記のとおりです。\n", "\n", "* `score = FC(tanh(FC(EO) + FC(H)))`\n", "* `attention weights = softmax(score, axis = 1)` softmax は既定では最後の軸に対して実行されますが、スコアの shape が *(batch_size, max_length, hidden_size)* であるため、*最初の軸* に適用します。`max_length` は入力の長さです。入力それぞれに重みを割り当てようとしているので、softmax はその軸に適用されなければなりません。\n", "* `context vector = sum(attention weights * EO, axis = 1)`. 
The axis is set to 1 for the same reason as above.\n", "* `embedding output` = The input to the decoder, X, is passed through an Embedding layer.\n", "* `merged vector = concat(embedding output, context vector)`\n", "* This merged vector is then given to the GRU.\n", "\n", "The shapes of the vectors at each step are given in the comments in the code." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.028512Z", "iopub.status.busy": "2022-12-14T23:10:23.027970Z", "iopub.status.idle": "2022-12-14T23:10:23.033200Z", "shell.execute_reply": "2022-12-14T23:10:23.032611Z" }, "id": "nZ2rI24i3jFg" }, "outputs": [], "source": [ "class Encoder(tf.keras.Model):\n", "  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", "    super(Encoder, self).__init__()\n", "    self.batch_sz = batch_sz\n", "    self.enc_units = enc_units\n", "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", "    self.gru = tf.keras.layers.GRU(self.enc_units,\n", "        return_sequences=True,\n", "        return_state=True,\n", "        recurrent_initializer='glorot_uniform')\n", "\n", "  def call(self, x, hidden):\n", "    x = self.embedding(x)\n", "    output, state = self.gru(x, initial_state = hidden)\n", "    return output, state\n", "\n", "  def initialize_hidden_state(self):\n", "    return tf.zeros((self.batch_sz, self.enc_units))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.036483Z", "iopub.status.busy": "2022-12-14T23:10:23.035906Z", "iopub.status.idle": "2022-12-14T23:10:23.655432Z", "shell.execute_reply": "2022-12-14T23:10:23.654734Z" }, "id": "60gSVh05Jl6l" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Encoder output shape: (batch size, sequence length, units) (64, 16, 1024)\n", "Encoder Hidden state shape: (batch size, units) (64, 1024)\n" ] } ], "source": [ "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", "\n", "# sample input\n", "sample_hidden = encoder.initialize_hidden_state()\n", 
"sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)\n", "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n", "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.659318Z", "iopub.status.busy": "2022-12-14T23:10:23.658725Z", "iopub.status.idle": "2022-12-14T23:10:23.664300Z", "shell.execute_reply": "2022-12-14T23:10:23.663742Z" }, "id": "umohpBN2OM94" }, "outputs": [], "source": [ "class BahdanauAttention(tf.keras.layers.Layer):\n", " def __init__(self, units):\n", " super(BahdanauAttention, self).__init__()\n", " self.W1 = tf.keras.layers.Dense(units)\n", " self.W2 = tf.keras.layers.Dense(units)\n", " self.V = tf.keras.layers.Dense(1)\n", "\n", " def call(self, query, values):\n", " # hidden shape == (batch_size, hidden size)\n", " # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n", " # スコアを計算するためにこのように加算を実行する\n", " hidden_with_time_axis = tf.expand_dims(query, 1)\n", "\n", " # score shape == (batch_size, max_length, 1)\n", " # スコアを self.V に適用するために最後の軸は 1 となる\n", " # self.V に適用する前のテンソルの shape は (batch_size, max_length, units)\n", " score = self.V(tf.nn.tanh(\n", " self.W1(values) + self.W2(hidden_with_time_axis)))\n", "\n", " # attention_weights の shape == (batch_size, max_length, 1)\n", " attention_weights = tf.nn.softmax(score, axis=1)\n", "\n", " # context_vector の合計後の shape == (batch_size, hidden_size)\n", " context_vector = attention_weights * values\n", " context_vector = tf.reduce_sum(context_vector, axis=1)\n", "\n", " return context_vector, attention_weights" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.667603Z", "iopub.status.busy": "2022-12-14T23:10:23.667070Z", "iopub.status.idle": "2022-12-14T23:10:23.707933Z", 
"shell.execute_reply": "2022-12-14T23:10:23.707349Z" }, "id": "k534zTHiDjQU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Attention result shape: (batch size, units) (64, 1024)\n", "Attention weights shape: (batch_size, sequence_length, 1) (64, 16, 1)\n" ] } ], "source": [ "attention_layer = BahdanauAttention(10)\n", "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n", "\n", "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n", "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.711232Z", "iopub.status.busy": "2022-12-14T23:10:23.710686Z", "iopub.status.idle": "2022-12-14T23:10:23.716748Z", "shell.execute_reply": "2022-12-14T23:10:23.716130Z" }, "id": "yJ_B3mhW3jFk" }, "outputs": [], "source": [ "class Decoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", " super(Decoder, self).__init__()\n", " self.batch_sz = batch_sz\n", " self.dec_units = dec_units\n", " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", " self.gru = tf.keras.layers.GRU(self.dec_units,\n", " return_sequences=True,\n", " return_state=True,\n", " recurrent_initializer='glorot_uniform')\n", " self.fc = tf.keras.layers.Dense(vocab_size)\n", "\n", " # アテンションのため\n", " self.attention = BahdanauAttention(self.dec_units)\n", "\n", " def call(self, x, hidden, enc_output):\n", " # enc_output の shape == (batch_size, max_length, hidden_size)\n", " context_vector, attention_weights = self.attention(hidden, enc_output)\n", "\n", " # 埋め込み層を通過したあとの x の shape == (batch_size, 1, embedding_dim)\n", " x = self.embedding(x)\n", "\n", " # 結合後の x の shape == (batch_size, 1, embedding_dim + hidden_size)\n", " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", 
"\n", " # 結合したベクトルを GRU 層に渡す\n", " output, state = self.gru(x)\n", "\n", " # output shape == (batch_size * 1, hidden_size)\n", " output = tf.reshape(output, (-1, output.shape[2]))\n", "\n", " # output shape == (batch_size, vocab)\n", " x = self.fc(output)\n", "\n", " return x, state, attention_weights" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.719802Z", "iopub.status.busy": "2022-12-14T23:10:23.719257Z", "iopub.status.idle": "2022-12-14T23:10:23.781576Z", "shell.execute_reply": "2022-12-14T23:10:23.781013Z" }, "id": "P5UY8wko3jFp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Decoder output shape: (batch_size, vocab size) (64, 4935)\n" ] } ], "source": [ "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n", "\n", "sample_decoder_output, _, _ = decoder(tf.random.uniform((64, 1)),\n", " sample_hidden, sample_output)\n", "\n", "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))" ] }, { "cell_type": "markdown", "metadata": { "id": "_ch_71VbIRfK" }, "source": [ "## オプティマイザと損失関数の定義" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.784627Z", "iopub.status.busy": "2022-12-14T23:10:23.784187Z", "iopub.status.idle": "2022-12-14T23:10:23.791050Z", "shell.execute_reply": "2022-12-14T23:10:23.790486Z" }, "id": "WmTHr5iV3jFr" }, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam()\n", "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n", " from_logits=True, reduction='none')\n", "\n", "def loss_function(real, pred):\n", " mask = tf.math.logical_not(tf.math.equal(real, 0))\n", " loss_ = loss_object(real, pred)\n", "\n", " mask = tf.cast(mask, dtype=loss_.dtype)\n", " loss_ *= mask\n", "\n", " return tf.reduce_mean(loss_)" ] }, { "cell_type": "markdown", "metadata": { "id": "DMVWzzsfNl4e" }, "source": [ "## 
Checkpoints (object-based saving)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.794522Z", "iopub.status.busy": "2022-12-14T23:10:23.793979Z", "iopub.status.idle": "2022-12-14T23:10:23.797472Z", "shell.execute_reply": "2022-12-14T23:10:23.796929Z" }, "id": "Zj8bXQTgNwrF" }, "outputs": [], "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", "    encoder=encoder,\n", "    decoder=decoder)" ] }, { "cell_type": "markdown", "metadata": { "id": "hpObfY22IddU" }, "source": [ "## Training\n", "\n", "1. Pass the *input* through the *encoder*, which returns the *encoder output* and the *encoder hidden state*.\n", "2. The encoder output, the encoder hidden state, and the decoder input (which is the *start token*) are passed to the decoder.\n", "3. The decoder returns the *predictions* and the *decoder hidden state*.\n", "4. The decoder hidden state is then passed back into the model, and the predictions are used to calculate the loss.\n", "5. *Teacher forcing* is used to decide the next input to the decoder.\n", "6. *Teacher forcing* is the technique in which the *target word* is passed as the *next input* to the decoder.\n", "7. 
Finally, calculate the gradients and pass them to the optimizer, which applies them to the model variables." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.800778Z", "iopub.status.busy": "2022-12-14T23:10:23.800335Z", "iopub.status.idle": "2022-12-14T23:10:23.805612Z", "shell.execute_reply": "2022-12-14T23:10:23.805020Z" }, "id": "sC9ArXSsVfqn" }, "outputs": [], "source": [ "@tf.function\n", "def train_step(inp, targ, enc_hidden):\n", "  loss = 0\n", "\n", "  with tf.GradientTape() as tape:\n", "    enc_output, enc_hidden = encoder(inp, enc_hidden)\n", "\n", "    dec_hidden = enc_hidden\n", "\n", "    dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)\n", "\n", "    # Teacher forcing: feed the target as the next input\n", "    for t in range(1, targ.shape[1]):\n", "      # passing enc_output to the decoder\n", "      predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n", "\n", "      loss += loss_function(targ[:, t], predictions)\n", "\n", "      # using teacher forcing\n", "      dec_input = tf.expand_dims(targ[:, t], 1)\n", "\n", "  batch_loss = (loss / int(targ.shape[1]))\n", "\n", "  variables = encoder.trainable_variables + decoder.trainable_variables\n", "\n", "  gradients = tape.gradient(loss, variables)\n", "\n", "  optimizer.apply_gradients(zip(gradients, variables))\n", "\n", "  return batch_loss" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:10:23.808713Z", "iopub.status.busy": "2022-12-14T23:10:23.808272Z", "iopub.status.idle": "2022-12-14T23:14:07.201703Z", "shell.execute_reply": "2022-12-14T23:14:07.200849Z" }, "id": "ddefjBMa3jF0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 Batch 0 Loss 4.7353\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 Batch 100 Loss 2.3236\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 Batch 200 Loss 1.6937\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 Batch 300 Loss 
1.7070\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 Loss 2.0369\n", "Time taken for 1 epoch 38.7444052696228 sec\n", "\n", "Epoch 2 Batch 0 Loss 1.6072\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2 Batch 100 Loss 1.4702\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2 Batch 200 Loss 1.4290\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2 Batch 300 Loss 1.3025\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2 Loss 1.4043\n", "Time taken for 1 epoch 21.323351621627808 sec\n", "\n", "Epoch 3 Batch 0 Loss 1.0264\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3 Batch 100 Loss 1.0971\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3 Batch 200 Loss 0.9568\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3 Batch 300 Loss 0.8830\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3 Loss 0.9820\n", "Time taken for 1 epoch 20.229315757751465 sec\n", "\n", "Epoch 4 Batch 0 Loss 0.7010\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4 Batch 100 Loss 0.7372\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4 Batch 200 Loss 0.6545\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4 Batch 300 Loss 0.5471\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4 Loss 0.6590\n", "Time taken for 1 epoch 20.86075496673584 sec\n", "\n", "Epoch 5 Batch 0 Loss 0.3788\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5 Batch 100 Loss 0.4551\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5 Batch 200 Loss 0.5846\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5 Batch 300 Loss 0.4743\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5 Loss 0.4479\n", "Time taken for 1 epoch 20.013705730438232 sec\n", "\n", "Epoch 6 Batch 0 Loss 
0.3106\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6 Batch 100 Loss 0.2062\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6 Batch 200 Loss 0.3290\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6 Batch 300 Loss 0.3195\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6 Loss 0.3081\n", "Time taken for 1 epoch 20.81093120574951 sec\n", "\n", "Epoch 7 Batch 0 Loss 0.2261\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7 Batch 100 Loss 0.3102\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7 Batch 200 Loss 0.1855\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7 Batch 300 Loss 0.2770\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7 Loss 0.2200\n", "Time taken for 1 epoch 20.144649267196655 sec\n", "\n", "Epoch 8 Batch 0 Loss 0.1757\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8 Batch 100 Loss 0.1824\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8 Batch 200 Loss 0.1366\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8 Batch 300 Loss 0.1332\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8 Loss 0.1618\n", "Time taken for 1 epoch 20.600513696670532 sec\n", "\n", "Epoch 9 Batch 0 Loss 0.1050\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9 Batch 100 Loss 0.1242\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9 Batch 200 Loss 0.1232\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9 Batch 300 Loss 0.1080\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9 Loss 0.1271\n", "Time taken for 1 epoch 20.18126344680786 sec\n", "\n", "Epoch 10 Batch 0 Loss 0.0934\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10 Batch 100 Loss 0.1031\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10 
Batch 200 Loss 0.1194\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10 Batch 300 Loss 0.1105\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10 Loss 0.1017\n", "Time taken for 1 epoch 20.478492975234985 sec\n", "\n" ] } ], "source": [ "EPOCHS = 10\n", "\n", "for epoch in range(EPOCHS):\n", "  start = time.time()\n", "\n", "  enc_hidden = encoder.initialize_hidden_state()\n", "  total_loss = 0\n", "\n", "  for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n", "    batch_loss = train_step(inp, targ, enc_hidden)\n", "    total_loss += batch_loss\n", "\n", "    if batch % 100 == 0:\n", "      print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", "                                                   batch,\n", "                                                   batch_loss.numpy()))\n", "  # Save a checkpoint of the model every 2 epochs\n", "  if (epoch + 1) % 2 == 0:\n", "    checkpoint.save(file_prefix=checkpoint_prefix)\n", "\n", "  print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", "                                      total_loss / steps_per_epoch))\n", "  print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" ] }, { "cell_type": "markdown", "metadata": { "id": "mU3Ce8M6I3rz" }, "source": [ "## Translate\n", "\n", "* The evaluate function is similar to the training loop, except that it does not use *teacher forcing*. At each time step, the input to the decoder is its own previous prediction, together with the hidden state and the encoder output.\n", "* Stop predicting when the model predicts the *end token*.\n", "* Also store the *attention weights for every time step*.\n", "\n", "Note: The encoder output is computed only once per input." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:14:07.205521Z", "iopub.status.busy": "2022-12-14T23:14:07.204937Z", "iopub.status.idle": "2022-12-14T23:14:07.211834Z", "shell.execute_reply": "2022-12-14T23:14:07.211228Z" }, "id": "EbQpyYs13jF_" }, "outputs": [], "source": [ "def evaluate(sentence):\n", "  attention_plot = np.zeros((max_length_targ, max_length_inp))\n", "\n", "  sentence = preprocess_sentence(sentence)\n", "\n", "  inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n", "  inputs = 
tf.keras.preprocessing.sequence.pad_sequences([inputs],\n", "                                                         maxlen=max_length_inp,\n", "                                                         padding='post')\n", "  inputs = tf.convert_to_tensor(inputs)\n", "\n", "  result = ''\n", "\n", "  hidden = [tf.zeros((1, units))]\n", "  enc_out, enc_hidden = encoder(inputs, hidden)\n", "\n", "  dec_hidden = enc_hidden\n", "  dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)\n", "\n", "  for t in range(max_length_targ):\n", "    predictions, dec_hidden, attention_weights = decoder(dec_input,\n", "                                                         dec_hidden,\n", "                                                         enc_out)\n", "\n", "    # Store the attention weights so they can be plotted later\n", "    attention_weights = tf.reshape(attention_weights, (-1, ))\n", "    attention_plot[t] = attention_weights.numpy()\n", "\n", "    predicted_id = tf.argmax(predictions[0]).numpy()\n", "\n", "    result += targ_lang.index_word[predicted_id] + ' '\n", "\n", "    if targ_lang.index_word[predicted_id] == '<end>':\n", "      return result, sentence, attention_plot\n", "\n", "    # The predicted ID is fed back into the model\n", "    dec_input = tf.expand_dims([predicted_id], 0)\n", "\n", "  return result, sentence, attention_plot" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:14:07.214977Z", "iopub.status.busy": "2022-12-14T23:14:07.214410Z", "iopub.status.idle": "2022-12-14T23:14:07.219347Z", "shell.execute_reply": "2022-12-14T23:14:07.218805Z" }, "id": "s5hQWlbN3jGF" }, "outputs": [], "source": [ "# Function for plotting the attention weights\n", "def plot_attention(attention, sentence, predicted_sentence):\n", "  fig = plt.figure(figsize=(10,10))\n", "  ax = fig.add_subplot(1, 1, 1)\n", "  ax.matshow(attention, cmap='viridis')\n", "\n", "  fontdict = {'fontsize': 14}\n", "\n", "  ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", "  ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", "\n", "  ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", "  ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", "\n", "  plt.show()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { 
"execution": { "iopub.execute_input": "2022-12-14T23:14:07.222394Z", "iopub.status.busy": "2022-12-14T23:14:07.221960Z", "iopub.status.idle": "2022-12-14T23:14:07.225818Z", "shell.execute_reply": "2022-12-14T23:14:07.225194Z" }, "id": "sl9zUHzg3jGI" }, "outputs": [], "source": [ "def translate(sentence):\n", " result, sentence, attention_plot = evaluate(sentence)\n", "\n", " print('Input: %s' % (sentence))\n", " print('Predicted translation: {}'.format(result))\n", "\n", " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" ] }, { "cell_type": "markdown", "metadata": { "id": "n250XbnjOaqP" }, "source": [ "## 最後のチェックポイントを復元しテストする" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:14:07.228821Z", "iopub.status.busy": "2022-12-14T23:14:07.228402Z", "iopub.status.idle": "2022-12-14T23:14:07.459975Z", "shell.execute_reply": "2022-12-14T23:14:07.459249Z" }, "id": "UJpT9D5_OgP6" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# checkpoint_dir の中の最後のチェックポイントを復元\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:14:07.463620Z", "iopub.status.busy": "2022-12-14T23:14:07.463027Z", "iopub.status.idle": "2022-12-14T23:14:07.758396Z", "shell.execute_reply": "2022-12-14T23:14:07.757563Z" }, "id": "WrAM0FDomq3E" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input: hace mucho frio aqui . \n", "Predicted translation: it s very cold here . 
\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_146998/2662029192.py:9: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", "/tmpfs/tmp/ipykernel_146998/2662029192.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "translate(u'hace mucho frio aqui.')" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:14:07.761513Z", "iopub.status.busy": "2022-12-14T23:14:07.761280Z", "iopub.status.idle": "2022-12-14T23:14:08.033508Z", "shell.execute_reply": "2022-12-14T23:14:08.032669Z" }, "id": "zSx2iM36EZQZ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input: esta es mi vida . \n", "Predicted translation: this is my life . \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_146998/2662029192.py:9: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", "/tmpfs/tmp/ipykernel_146998/2662029192.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "translate(u'esta es mi vida.')" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:14:08.036843Z", "iopub.status.busy": "2022-12-14T23:14:08.036603Z", "iopub.status.idle": "2022-12-14T23:14:08.331874Z", "shell.execute_reply": "2022-12-14T23:14:08.331043Z" }, "id": "A3LLCx3ZE0Ls" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input: ¿ todavia estan en casa ? \n", "Predicted translation: are you still at home ? \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_146998/2662029192.py:9: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", "/tmpfs/tmp/ipykernel_146998/2662029192.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "translate(u'¿todavia estan en casa?')" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T23:14:08.334981Z", "iopub.status.busy": "2022-12-14T23:14:08.334744Z", "iopub.status.idle": "2022-12-14T23:14:08.599511Z", "shell.execute_reply": "2022-12-14T23:14:08.598685Z" }, "id": "DUQVLVqUE1YW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input: trata de averiguarlo . \n", "Predicted translation: try to figure it out . \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_146998/2662029192.py:9: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", "/tmpfs/tmp/ipykernel_146998/2662029192.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 翻訳あやまりの例\n", "translate(u'trata de averiguarlo.')" ] }, { "cell_type": "markdown", "metadata": { "id": "RTe5P5ioMJwN" }, "source": [ "## 次のステップ\n", "\n", "* [異なるデータセットをダウンロード](http://www.manythings.org/anki/)して翻訳の実験を行ってみよう。たとえば英語からドイツ語や、英語からフランス語。\n", "* もっと大きなデータセットで訓練を行ったり、もっと多くのエポックで訓練を行ったりしてみよう。" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "nmt_with_attention.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }