{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "OoasdhSAp0zJ" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T21:53:57.110235Z", "iopub.status.busy": "2024-01-11T21:53:57.110023Z", "iopub.status.idle": "2024-01-11T21:53:57.113971Z", "shell.execute_reply": "2024-01-11T21:53:57.113278Z" }, "id": "cIrwotvGqsYh" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "C81KT2D_j-xR" }, "source": [ "# Build a linear model with Estimators\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "JOccPOFMm5Tc" }, "source": [ "> Warning: Estimators are not recommended for new code. Estimators run `v1.Session`-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators are covered by the [compatibility guarantees](https://tensorflow.org/guide/versions), but will receive no fixes other than for security vulnerabilities. See the [migration guide](https://tensorflow.org/guide/migrate) for details." ] }, { "cell_type": "markdown", "metadata": { "id": "tUP8LMdYtWPz" }, "source": [ "## Overview\n", "\n", "This end-to-end walkthrough trains a logistic regression model using the `tf.estimator` API. This kind of model is often used as a baseline for other, more complex algorithms.\n", "\n", "Note: A Keras logistic regression example is available [here](https://tensorflow.org/guide/migrate/tutorials/keras/regression) and is recommended over this tutorial.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vkC_j6VpqrDw" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:53:57.117673Z", "iopub.status.busy": "2024-01-11T21:53:57.117423Z", "iopub.status.idle": "2024-01-11T21:53:57.934590Z", "shell.execute_reply": "2024-01-11T21:53:57.933800Z" }, "id": "rutbJGmpqvm3" }, "outputs": [], "source": [ "!pip install scikit-learn\n" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:53:57.938673Z", "iopub.status.busy": "2024-01-11T21:53:57.938384Z", "iopub.status.idle": "2024-01-11T21:53:58.586191Z", "shell.execute_reply": "2024-01-11T21:53:58.585529Z" }, "id": "54mb4J9PqqDh" }, "outputs": [], "source": [ "import os\n", "import sys\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from IPython.display import clear_output\n", "from six.moves import urllib" ] }, { "cell_type": "markdown", "metadata": { "id": "fsjkwfsGOBMT" }, "source": [ "## Load the Titanic dataset\n", "\n", "You will use the Titanic dataset with the (rather morbid) goal of predicting passenger survival from characteristics such as sex, age, and class." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:53:58.590601Z", "iopub.status.busy": "2024-01-11T21:53:58.590089Z", "iopub.status.idle": "2024-01-11T21:54:00.532818Z", "shell.execute_reply": "2024-01-11T21:54:00.531876Z" }, "id": "bNiwh-APcRVD" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-11 21:53:58.890761: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-11 21:53:58.890806: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-11 21:53:58.892379: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow.compat.v2.feature_column as fc\n", "\n", "import tensorflow as tf" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:00.536979Z", "iopub.status.busy": "2024-01-11T21:54:00.536521Z", "iopub.status.idle": "2024-01-11T21:54:00.640919Z", "shell.execute_reply": "2024-01-11T21:54:00.640309Z" }, "id": "DSeMKcx03d5R" }, "outputs": [], "source": [ "# Load dataset.\n", "dftrain = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')\n", "dfeval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')\n", "y_train = dftrain.pop('survived')\n", "y_eval = dfeval.pop('survived')" ] }, { "cell_type": "markdown", "metadata": { "id": "jjm4Qj0u7_cp" }, "source": [ "## Explore the data" ] }, { "cell_type": "markdown", "metadata": { "id": "UrQzxKKh4d6u" }, "source": [ "The dataset contains the following features:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:00.644566Z", "iopub.status.busy": "2024-01-11T21:54:00.644098Z", "iopub.status.idle": "2024-01-11T21:54:00.657256Z", "shell.execute_reply": "2024-01-11T21:54:00.656626Z" }, "id": "rTjugo3n308g" }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " sex age n_siblings_spouses parch fare class deck \\\n", "0 male 22.0 1 0 7.2500 Third unknown \n", "1 female 38.0 1 0 71.2833 First C \n", "2 female 26.0 0 0 7.9250 Third unknown \n", "3 female 35.0 1 0 53.1000 First C \n", "4 male 28.0 0 0 8.4583 Third unknown \n", "\n", " embark_town alone \n", "0 Southampton n \n", "1 Cherbourg n \n", "2 Southampton y \n", "3 Southampton n \n", "4 Queenstown y " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dftrain.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:00.660299Z", "iopub.status.busy": "2024-01-11T21:54:00.659922Z", "iopub.status.idle": "2024-01-11T21:54:00.674346Z", "shell.execute_reply": "2024-01-11T21:54:00.673716Z" }, "id": "y86q1fj44lZs" }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age n_siblings_spouses parch fare\n", "count 627.000000 627.000000 627.000000 627.000000\n", "mean 29.631308 0.545455 0.379585 34.385399\n", "std 12.511818 1.151090 0.792999 54.597730\n", "min 0.750000 0.000000 0.000000 0.000000\n", "25% 23.000000 0.000000 0.000000 7.895800\n", "50% 28.000000 0.000000 0.000000 15.045800\n", "75% 35.000000 1.000000 0.000000 31.387500\n", "max 80.000000 8.000000 5.000000 512.329200" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dftrain.describe()" ] }, { "cell_type": "markdown", "metadata": { "id": "8JSa_duD4tFZ" }, "source": [ "The training and evaluation sets contain 627 and 264 examples, respectively." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:00.677689Z", "iopub.status.busy": "2024-01-11T21:54:00.677466Z", "iopub.status.idle": "2024-01-11T21:54:00.681720Z", "shell.execute_reply": "2024-01-11T21:54:00.681117Z" }, "id": "Fs3Nu5pV4v5J" }, "outputs": [ { "data": { "text/plain": [ "(627, 264)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dftrain.shape[0], dfeval.shape[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "RxCA4Nr45AfF" }, "source": [ "The majority of passengers are in their 20s and 30s." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:00.684877Z", "iopub.status.busy": "2024-01-11T21:54:00.684507Z", "iopub.status.idle": "2024-01-11T21:54:00.867254Z", "shell.execute_reply": "2024-01-11T21:54:00.866677Z" }, "id": "RYeCMm7K40ZN" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dftrain.age.hist(bins=20)" ] }, { "cell_type": "markdown", "metadata": { "id": "DItSwJ_B5B0f" }, "source": [ "There are approximately twice as many male passengers as female passengers aboard." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:00.870721Z", "iopub.status.busy": "2024-01-11T21:54:00.870316Z", "iopub.status.idle": "2024-01-11T21:54:00.984998Z", "shell.execute_reply": "2024-01-11T21:54:00.984417Z" }, "id": "b03dVV9q5Dv2" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dftrain.sex.value_counts().plot(kind='barh')" ] }, { "cell_type": "markdown", "metadata": { "id": "rK6WQ29q5Jf5" }, "source": [ "The majority of passengers were in the \"Third\" class." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:00.988232Z", "iopub.status.busy": "2024-01-11T21:54:00.987764Z", "iopub.status.idle": "2024-01-11T21:54:01.099370Z", "shell.execute_reply": "2024-01-11T21:54:01.098688Z" }, "id": "dgpJVeCq5Fgd" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dftrain['class'].value_counts().plot(kind='barh')" ] }, { "cell_type": "markdown", "metadata": { "id": "FXJhGGL85TLp" }, "source": [ "Females have a much higher chance of surviving than males. This is clearly a predictive feature for the model." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:01.102868Z", "iopub.status.busy": "2024-01-11T21:54:01.102232Z", "iopub.status.idle": "2024-01-11T21:54:01.220936Z", "shell.execute_reply": "2024-01-11T21:54:01.220365Z" }, "id": "lSZYa7c45Ttt" }, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, '% survive')" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pd.concat([dftrain, y_train], axis=1).groupby('sex').survived.mean().plot(kind='barh').set_xlabel('% survive')" ] }, { "cell_type": "markdown", "metadata": { "id": "qCHvgeorEsHa" }, "source": [ "## Feature engineering for the model" ] }, { "cell_type": "markdown", "metadata": { "id": "Dhcq8Ds4mCtm" }, "source": [ "> Warning: The `tf.feature_column` module described in this tutorial is not recommended for new code. Keras preprocessing layers cover this functionality; for migration instructions, see the [Migrating feature columns](https://www.tensorflow.org/guide/migrate/migrating_feature_columns) guide. The `tf.feature_column` module was designed for use with TF1 Estimators. It is covered by the [compatibility guarantees](https://tensorflow.org/guide/versions), but will receive no fixes other than for security vulnerabilities." ] }, { "cell_type": "markdown", "metadata": { "id": "VqDKQLZn8L-B" }, "source": [ "Estimators use a system called [feature columns](https://www.tensorflow.org/tutorials/structured_data/feature_columns) to describe how the model should interpret each of the raw input features. An Estimator expects a vector of numeric inputs, and *feature columns* describe how the model should convert each feature.\n", "\n", "Selecting and crafting the right set of feature columns is key to learning an effective model. A feature column can be either one of the raw inputs in the original features `dict` (a *base feature column*), or any new column created using transformations defined over one or more base columns (a *derived feature column*).\n", "\n", "The linear Estimator uses both numeric and categorical features. Feature columns work with TensorFlow Estimators, and their purpose is to define the features used for modeling. They also provide some feature engineering capabilities, such as one-hot encoding, normalization, and bucketization." ] }, { "cell_type": "markdown", "metadata": { "id": "puZFOhTDkblt" }, "source": [ "### Base feature columns" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:01.224755Z", "iopub.status.busy": "2024-01-11T21:54:01.224137Z", "iopub.status.idle": "2024-01-11T21:54:01.231109Z", "shell.execute_reply": "2024-01-11T21:54:01.230549Z" }, "id": "GpveXYSsADS6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_978618/567449645.py:8: categorical_column_with_vocabulary_list (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_978618/567449645.py:11: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] } ], "source": [ "CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',\n", " 'embark_town', 'alone']\n", "NUMERIC_COLUMNS = ['age', 'fare']\n", "\n", "feature_columns = []\n", "for feature_name in CATEGORICAL_COLUMNS:\n", " vocabulary = dftrain[feature_name].unique()\n", " feature_columns.append(tf.feature_column.categorical_column_with_vocabulary_list(feature_name, vocabulary))\n", "\n", "for feature_name in NUMERIC_COLUMNS:\n", " feature_columns.append(tf.feature_column.numeric_column(feature_name, dtype=tf.float32))" ] }, { "cell_type": "markdown", "metadata": { "id": "Gt8HMtwOh9lJ" }, "source": [ "The `input_function` specifies how data is converted to a `tf.data.Dataset` that feeds the input pipeline in a streaming fashion. A `tf.data.Dataset` can take in multiple sources, such as a dataframe or a CSV-formatted file." ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:01.234094Z", "iopub.status.busy": "2024-01-11T21:54:01.233664Z", "iopub.status.idle": "2024-01-11T21:54:01.238237Z", "shell.execute_reply": "2024-01-11T21:54:01.237629Z" }, "id": "qVtrIHFnAe7w" }, "outputs": [], "source": [ "def make_input_fn(data_df, label_df, num_epochs=10, shuffle=True, batch_size=32):\n", " def input_function():\n", " ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))\n", " if shuffle:\n", " ds = ds.shuffle(1000)\n", " ds = ds.batch(batch_size).repeat(num_epochs)\n", " return ds\n", " return input_function\n", "\n", "train_input_fn = make_input_fn(dftrain, y_train)\n", "eval_input_fn = make_input_fn(dfeval, y_eval, num_epochs=1, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "P7UMVkQnkrgb" }, "source": [ "You can inspect the dataset:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:01.241676Z", "iopub.status.busy": "2024-01-11T21:54:01.241074Z", "iopub.status.idle": "2024-01-11T21:54:03.529403Z", "shell.execute_reply": "2024-01-11T21:54:03.528563Z" }, "id": "8ZcG_3KiCb1M" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Some feature keys: ['sex', 'age', 'n_siblings_spouses', 'parch', 'fare', 'class', 'deck', 'embark_town', 'alone']\n", "\n", "A batch of class: [b'Second' b'Third' b'Third' b'Third' b'First' b'Third' b'Second' b'Third'\n", " b'Third' b'First']\n", "\n", "A batch of Labels: [1 0 0 0 1 1 0 0 0 1]\n" ] } ], "source": [ "ds = make_input_fn(dftrain, y_train, batch_size=10)()\n", "for feature_batch, label_batch in ds.take(1):\n", " print('Some feature keys:', list(feature_batch.keys()))\n", " print()\n", " print('A batch of class:', feature_batch['class'].numpy())\n", " print()\n", " print('A batch of Labels:', label_batch.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "lMNBMyodjlW3" }, "source": [ "You can also inspect the result of a specific feature column using the `tf.keras.layers.DenseFeatures` layer:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:03.533142Z", "iopub.status.busy": "2024-01-11T21:54:03.532472Z", "iopub.status.idle": "2024-01-11T21:54:03.570649Z", 
"shell.execute_reply": "2024-01-11T21:54:03.570043Z" }, "id": "IMjlmbPlDmkB" }, "outputs": [ { "data": { "text/plain": [ "array([[ 3.],\n", " [21.],\n", " [32.],\n", " [59.],\n", " [44.],\n", " [ 4.],\n", " [24.],\n", " [10.],\n", " [28.],\n", " [28.]], dtype=float32)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "age_column = feature_columns[7]\n", "tf.keras.layers.DenseFeatures([age_column])(feature_batch).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "f4zrAdCIjr3s" }, "source": [ "`DenseFeatures` only accepts dense tensors. To inspect a categorical column, you first need to transform it into an indicator column:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:03.574161Z", "iopub.status.busy": "2024-01-11T21:54:03.573676Z", "iopub.status.idle": "2024-01-11T21:54:03.610497Z", "shell.execute_reply": "2024-01-11T21:54:03.609903Z" }, "id": "1VXmXFTSFEvv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_978618/1523458592.py:2: indicator_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. 
Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "data": { "text/plain": [ "array([[1., 0.],\n", " [1., 0.],\n", " [0., 1.],\n", " [1., 0.],\n", " [0., 1.],\n", " [0., 1.],\n", " [1., 0.],\n", " [1., 0.],\n", " [0., 1.],\n", " [0., 1.]], dtype=float32)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gender_column = feature_columns[0]\n", "tf.keras.layers.DenseFeatures([tf.feature_column.indicator_column(gender_column)])(feature_batch).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "MEp59g5UkHYY" }, "source": [ "After adding all the base features to the model, train it. With the `tf.estimator` API, training the model takes a single command:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:03.613898Z", "iopub.status.busy": "2024-01-11T21:54:03.613326Z", "iopub.status.idle": "2024-01-11T21:54:12.597772Z", "shell.execute_reply": "2024-01-11T21:54:12.597117Z" }, "id": "aGXjdnqqdgIs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'accuracy': 0.7386364, 'accuracy_baseline': 0.625, 'auc': 0.8361494, 'auc_precision_recall': 0.78821784, 'average_loss': 0.47708422, 'label/mean': 0.375, 'loss': 0.47104964, 'precision': 0.64705884, 'prediction/mean': 0.3915048, 'recall': 0.6666667, 'global_step': 200}\n" ] } ], "source": [ "linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns)\n", "linear_est.train(train_input_fn)\n", "result = linear_est.evaluate(eval_input_fn)\n", "\n", "clear_output()\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "3tOan4hDsG6d" }, "source": [ "### Derived feature columns" ] },
{ "cell_type": "markdown", "metadata": { "id": "NOG2FSTHlAMu" }, "source": [ "The model reached an accuracy of about 74% (`0.7386364` in the output above). Using each base feature column on its own may not be enough to explain the data. For example, the correlation between age and the label may differ by gender. If you only learn a single model weight each for `gender=\"Male\"` and `gender=\"Female\"`, you won't capture every age-gender combination (for example, distinguishing `gender=\"Male\"` with `age=\"30\"` from `gender=\"Male\"` with `age=\"40\"`).\n", "\n", "To learn the differences between feature combinations, you can add *crossed feature columns* to the model (you can also bucketize the age column before crossing it)." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:12.600998Z", "iopub.status.busy": "2024-01-11T21:54:12.600757Z", "iopub.status.idle": "2024-01-11T21:54:12.604858Z", "shell.execute_reply": "2024-01-11T21:54:12.604270Z" }, "id": "AM-RsDzNfGlu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_978618/476100734.py:1: crossed_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use `tf.keras.layers.experimental.preprocessing.HashedCrossing` instead for feature crossing when preprocessing data to train a Keras model.\n" ] } ], "source": [ "age_x_gender = tf.feature_column.crossed_column(['age', 'sex'], hash_bucket_size=100)" ] }, { "cell_type": "markdown", "metadata": { "id": "DqDFyPKQmGTN" }, "source": [ "After adding the crossed feature to the model, train the model again:" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:12.607958Z", "iopub.status.busy": "2024-01-11T21:54:12.607532Z", "iopub.status.idle": "2024-01-11T21:54:21.647807Z", "shell.execute_reply": "2024-01-11T21:54:21.647133Z" }, "id": "s8FV9oPQfS-g" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'accuracy': 0.7462121, 'accuracy_baseline': 0.625, 'auc': 0.8426691, 'auc_precision_recall': 0.79368746, 'average_loss': 0.47478184, 'label/mean': 0.375, 'loss': 0.46697915, 'precision': 0.6666667, 'prediction/mean': 0.40963757, 'recall': 0.64646465, 'global_step': 200}\n" ] } ], "source": [ "derived_feature_columns = [age_x_gender]\n", "linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns+derived_feature_columns)\n", "linear_est.train(train_input_fn)\n", "result = linear_est.evaluate(eval_input_fn)\n", "\n", "clear_output()\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "rwfdZj7ImLwb" }, "source": [ "The model now reaches an accuracy of about 75% (`0.7462121`), a slight improvement over training on the base features alone. Try other features and transformations to see whether you can do even better!" ] }, { "cell_type": "markdown", "metadata": { "id": "8_eyb9d-ncjH" }, "source": [ "Now you can use the trained model to make predictions on passengers from the evaluation set. TensorFlow models are optimized to make predictions on a batch, or collection, of examples at once. Earlier, `eval_input_fn` was defined using the entire evaluation set, so passing it to `predict` returns a prediction for every passenger in that set." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:21.651282Z", "iopub.status.busy": "2024-01-11T21:54:21.650993Z", "iopub.status.idle": "2024-01-11T21:54:23.004845Z", "shell.execute_reply": "2024-01-11T21:54:23.004168Z" }, "id": "wiScyBcef6Dq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/base_head.py:786: ClassificationOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/binary_class_head.py:561: RegressionOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/head/binary_class_head.py:563: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpsa7vq6kw/model.ckpt-200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pred_dicts = list(linear_est.predict(eval_input_fn))\n", "probs = pd.Series([pred['probabilities'][1] for pred in pred_dicts])\n", "\n", "probs.plot(kind='hist', bins=20, title='predicted probabilities')" ] }, { "cell_type": "markdown", "metadata": { "id": "UEHRCd4sqrLs" }, "source": [ "Finally, look at the receiver operating characteristic (ROC) curve of the results, which gives a clearer picture of the tradeoff between the true positive rate and the false positive rate." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T21:54:23.008201Z", "iopub.status.busy": "2024-01-11T21:54:23.007952Z", "iopub.status.idle": "2024-01-11T21:54:23.566207Z", "shell.execute_reply": "2024-01-11T21:54:23.565429Z" }, "id": "kqEjsezIokIe" }, "outputs": [ { "data": { "text/plain": [ "(0.0, 1.05)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import roc_curve\n", "from matplotlib import pyplot as plt\n", "\n", "fpr, tpr, _ = roc_curve(y_eval, probs)\n", "plt.plot(fpr, tpr)\n", "plt.title('ROC curve')\n", "plt.xlabel('false positive rate')\n", "plt.ylabel('true positive rate')\n", "plt.xlim(0,)\n", "plt.ylim(0,)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "linear.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }