{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "wJcYs_ERTnnI" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-09-16T01:21:19.973335Z", "iopub.status.busy": "2023-09-16T01:21:19.972749Z", "iopub.status.idle": "2023-09-16T01:21:19.976654Z", "shell.execute_reply": "2023-09-16T01:21:19.975963Z" }, "id": "HMUDt0CiUJk9" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# Migration examples: Canned Estimators\n", "\n", "\n", " \n", " \n", " \n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "meUTrR4I6m1C" }, "source": [ "Canned (or Premade) Estimators have traditionally been used in TensorFlow 1 as quick and easy ways to train models for a variety of typical use cases. TensorFlow 2 provides straightforward approximate substitutes for a number of them by way of Keras models. For those canned estimators that do not have built-in TensorFlow 2 substitutes, you can still build your own replacement fairly easily.\n", "\n", "This guide will walk you through a few examples of direct equivalents and custom substitutions to demonstrate how TensorFlow 1's `tf.estimator`-derived models can be migrated to TensorFlow 2 with Keras.\n", "\n", "Namely, this guide includes examples for migrating:\n", "* From `tf.estimator`'s `LinearEstimator`, `Classifier` or `Regressor` in TensorFlow 1 to Keras `tf.compat.v1.keras.models.LinearModel` in TensorFlow 2\n", "* From `tf.estimator`'s `DNNEstimator`, `Classifier` or `Regressor` in TensorFlow 1 to a custom Keras DNN ModelKeras in TensorFlow 2\n", "* From `tf.estimator`'s `DNNLinearCombinedEstimator`, `Classifier` or `Regressor` in TensorFlow 1 to `tf.compat.v1.keras.models.WideDeepModel` in TensorFlow 2\n", "* From `tf.estimator`'s `BoostedTreesEstimator`, `Classifier` or `Regressor` in TensorFlow 1 to `tfdf.keras.GradientBoostedTreesModel` in TensorFlow 2\n", "\n", "A common precursor to the training of a model is feature preprocessing, which is done for TensorFlow 1 Estimator models with `tf.feature_column`. For more information on feature preprocessing in TensorFlow 2, see [this guide on migrating from feature columns to the Keras preprocessing layers API](migrating_feature_columns.ipynb)." 
] }, { "cell_type": "markdown", "metadata": { "id": "YdZSoIXEbhg-" }, "source": [ "## Setup\n", "\n", "Start with a couple of necessary TensorFlow imports," ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:19.980441Z", "iopub.status.busy": "2023-09-16T01:21:19.979916Z", "iopub.status.idle": "2023-09-16T01:21:47.613414Z", "shell.execute_reply": "2023-09-16T01:21:47.612604Z" }, "id": "qsgZp0f-nu9s" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow_decision_forests\r\n", " Obtaining dependency information for tensorflow_decision_forests from https://files.pythonhosted.org/packages/67/84/dc181dc6d4ec2692432bb168119e932a3175ffcfddcca41bc8a1a6d5a8b9/tensorflow_decision_forests-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading tensorflow_decision_forests-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.26.0rc1)\r\n", "Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow~=2.13.0 (from tensorflow_decision_forests)\r\n", " Obtaining dependency information for tensorflow~=2.13.0 from https://files.pythonhosted.org/packages/d6/af/cb5ea6d1a9c83e715e29b45a598ebf542729cd216b43f5deefc27657bd38/tensorflow-2.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata\r\n", " Using cached tensorflow-2.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)\r\n", "Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.16.0)\r\n", 
"Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.4.0)\r\n", "Requirement already satisfied: wheel in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (0.41.1)\r\n", "Collecting wurlitzer (from tensorflow_decision_forests)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading wurlitzer-3.0.3-py3-none-any.whl (7.3 kB)\r\n", "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.1.21 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (23.5.26)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting gast<=0.4.0,>=0.2.1 (from tensorflow~=2.13.0->tensorflow_decision_forests)\r\n", " Using cached gast-0.4.0-py3-none-any.whl (9.8 kB)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (0.2.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (1.58.0)\r\n", "Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (3.9.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting keras<2.14,>=2.13.1 (from tensorflow~=2.13.0->tensorflow_decision_forests)\r\n", " Obtaining dependency information for keras<2.14,>=2.13.1 from https://files.pythonhosted.org/packages/2e/f3/19da7511b45e80216cbbd9467137b2d28919c58ba1ccb971435cb631e470/keras-2.13.1-py3-none-any.whl.metadata\r\n", " Using cached keras-2.13.1-py3-none-any.whl.metadata (2.4 kB)\r\n", "Requirement 
already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (16.0.6)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting numpy (from tensorflow_decision_forests)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Using cached numpy-1.24.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (23.1)\r\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (3.20.3)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (68.2.2)\r\n", "Collecting tensorboard<2.14,>=2.13 (from tensorflow~=2.13.0->tensorflow_decision_forests)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Using cached tensorboard-2.13.0-py3-none-any.whl (5.6 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow-estimator<2.14,>=2.13.0 (from tensorflow~=2.13.0->tensorflow_decision_forests)\r\n", " Obtaining dependency information for tensorflow-estimator<2.14,>=2.13.0 from https://files.pythonhosted.org/packages/72/5c/c318268d96791c6222ad7df1651bbd1b2409139afeb6f468c0f327177016/tensorflow_estimator-2.13.0-py2.py3-none-any.whl.metadata\r\n", " Using cached tensorflow_estimator-2.13.0-py2.py3-none-any.whl.metadata (1.3 kB)\r\n", "Requirement already satisfied: 
termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (2.3.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting typing-extensions<4.6.0,>=3.6.6 (from tensorflow~=2.13.0->tensorflow_decision_forests)\r\n", " Using cached typing_extensions-4.5.0-py3-none-any.whl (27 kB)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (1.14.1)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (0.34.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: python-dateutil>=2.8.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2.8.2)\r\n", "Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2023.3.post1)\r\n", "Requirement already satisfied: tzdata>=2022.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2023.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2.23.0)\r\n", "Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (1.0.0)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.4.4)\r\n", "Requirement already satisfied: 
requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2.31.0)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (0.7.1)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2.3.7)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (5.3.1)\r\n", "Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (0.3.0)\r\n", "Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (4.9)\r\n", "Requirement already satisfied: urllib3<2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (1.26.16)\r\n", "Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (1.3.1)\r\n", "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (6.8.0)\r\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.2.0)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.4)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2023.7.22)\r\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2.1.3)\r\n", "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.16.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (0.5.0)\r\n", "Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.2.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading tensorflow_decision_forests-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.8 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached 
tensorflow-2.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (524.1 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached keras-2.13.1-py3-none-any.whl (1.7 MB)\r\n", "Using cached tensorflow_estimator-2.13.0-py2.py3-none-any.whl (440 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: wurlitzer, typing-extensions, tensorflow-estimator, numpy, keras, gast, tensorboard, tensorflow, tensorflow_decision_forests\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: typing-extensions\r\n", " Found existing installation: typing_extensions 4.8.0rc1\r\n", " Uninstalling typing_extensions-4.8.0rc1:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled typing_extensions-4.8.0rc1\r\n", " Attempting uninstall: tensorflow-estimator\r\n", " Found existing installation: tensorflow-estimator 2.14.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Uninstalling tensorflow-estimator-2.14.0:\r\n", " Successfully uninstalled tensorflow-estimator-2.14.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: numpy\r\n", " Found existing installation: numpy 1.26.0rc1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Uninstalling numpy-1.26.0rc1:\r\n", " Successfully uninstalled numpy-1.26.0rc1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: keras\r\n", " Found existing installation: keras 2.14.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Uninstalling keras-2.14.0:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled keras-2.14.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: gast\r\n", " Found existing installation: gast 0.5.4\r\n", " Uninstalling gast-0.5.4:\r\n", " Successfully uninstalled gast-0.5.4\r\n", " 
Attempting uninstall: tensorboard\r\n", " Found existing installation: tensorboard 2.14.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Uninstalling tensorboard-2.14.0:\r\n", " Successfully uninstalled tensorboard-2.14.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: tensorflow\r\n", " Found existing installation: tensorflow 2.14.0rc1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Uninstalling tensorflow-2.14.0rc1:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled tensorflow-2.14.0rc1\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed gast-0.4.0 keras-2.13.1 numpy-1.24.3 tensorboard-2.13.0 tensorflow-2.13.0 tensorflow-estimator-2.13.0 tensorflow_decision_forests-1.5.0 typing-extensions-4.5.0 wurlitzer-3.0.3\r\n" ] } ], "source": [ "!pip install tensorflow_decision_forests" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:47.617447Z", "iopub.status.busy": "2023-09-16T01:21:47.617196Z", "iopub.status.idle": "2023-09-16T01:21:50.072540Z", "shell.execute_reply": "2023-09-16T01:21:50.071872Z" }, "id": "iE0vSfMXumKI" }, "outputs": [], "source": [ "import pandas as pd\n", "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1\n", "import tensorflow_decision_forests as tfdf\n", "from tensorflow import keras\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Jsm9Rxx7s1OZ" }, "source": [ "prepare some simple data for demonstration from the standard Titanic dataset," ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:50.077063Z", "iopub.status.busy": "2023-09-16T01:21:50.076480Z", "iopub.status.idle": "2023-09-16T01:21:50.215728Z", "shell.execute_reply": "2023-09-16T01:21:50.215076Z" }, "id": "wC6i_bEZPrPY" }, "outputs": [], "source": [ "x_train = 
pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')\n", "x_eval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')\n", "\n", "# Encode the categorical string columns as integers. Assigning the result\n", "# back to the column avoids relying on chained in-place modification, which\n", "# does not update the DataFrame when pandas Copy-on-Write is enabled.\n", "x_train['sex'] = x_train['sex'].replace(('male', 'female'), (0, 1))\n", "x_eval['sex'] = x_eval['sex'].replace(('male', 'female'), (0, 1))\n", "\n", "x_train['alone'] = x_train['alone'].replace(('n', 'y'), (0, 1))\n", "x_eval['alone'] = x_eval['alone'].replace(('n', 'y'), (0, 1))\n", "\n", "x_train['class'] = x_train['class'].replace(('First', 'Second', 'Third'), (1, 2, 3))\n", "x_eval['class'] = x_eval['class'].replace(('First', 'Second', 'Third'), (1, 2, 3))\n", "\n", "# Drop the columns that are not used as features in this guide.\n", "x_train.drop(['embark_town', 'deck'], axis=1, inplace=True)\n", "x_eval.drop(['embark_town', 'deck'], axis=1, inplace=True)\n", "\n", "# Separate the label column from the features.\n", "y_train = x_train.pop('survived')\n", "y_eval = x_eval.pop('survived')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:50.219695Z", "iopub.status.busy": "2023-09-16T01:21:50.219081Z", "iopub.status.idle": "2023-09-16T01:21:50.224523Z", "shell.execute_reply": "2023-09-16T01:21:50.223949Z" }, "id": "lqe9obf7suIj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/2801132002.py:16: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. 
Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] } ], "source": [ "# Data setup for TensorFlow 1 with `tf.estimator`\n", "def _input_fn():\n", "  return tf1.data.Dataset.from_tensor_slices((dict(x_train), y_train)).batch(32)\n", "\n", "\n", "def _eval_input_fn():\n", "  return tf1.data.Dataset.from_tensor_slices((dict(x_eval), y_eval)).batch(32)\n", "\n", "\n", "FEATURE_NAMES = [\n", "    'age', 'fare', 'sex', 'n_siblings_spouses', 'parch', 'class', 'alone'\n", "]\n", "\n", "# Treat every feature as a numeric column.\n", "feature_columns = []\n", "for fn in FEATURE_NAMES:\n", "  feat_col = tf1.feature_column.numeric_column(fn, dtype=tf.float32)\n", "  feature_columns.append(feat_col)" ] }, { "cell_type": "markdown", "metadata": { "id": "bYSgoezeMrpI" }, "source": [ "and create a function to instantiate a simplistic sample optimizer to use with various TensorFlow 1 Estimator and TensorFlow 2 Keras models." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:50.228124Z", "iopub.status.busy": "2023-09-16T01:21:50.227540Z", "iopub.status.idle": "2023-09-16T01:21:50.232198Z", "shell.execute_reply": "2023-09-16T01:21:50.231629Z" }, "id": "YHB_nuzVLVLe" }, "outputs": [], "source": [ "def create_sample_optimizer(tf_version):\n", "  if tf_version == 'tf1':\n", "    # Estimators expect a callable that builds the optimizer lazily.\n", "    optimizer = lambda: tf.keras.optimizers.legacy.Ftrl(\n", "        l1_regularization_strength=0.001,\n", "        learning_rate=tf1.train.exponential_decay(\n", "            learning_rate=0.1,\n", "            global_step=tf1.train.get_global_step(),\n", "            decay_steps=10000,\n", "            decay_rate=0.9))\n", "  elif tf_version == 'tf2':\n", "    optimizer = tf.keras.optimizers.legacy.Ftrl(\n", "        l1_regularization_strength=0.001,\n", "        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(\n", "            initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.9))\n", "  else:\n", "    # Fail early instead of hitting an unbound `optimizer` variable below.\n", "    raise ValueError(f'Unknown tf_version: {tf_version!r}')\n", "  return optimizer" ] }, { "cell_type": "markdown", "metadata": { "id": "4uXff1BEssdE" }, "source": 
[ "## Example 1: Migrating from LinearEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "_O7fyhCnpvED" }, "source": [ "### TensorFlow 1: Using LinearEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "A9560BqEOTpb" }, "source": [ "In TensorFlow 1, you can use `tf.estimator.LinearEstimator` to create a baseline linear model for regression and classification problems." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:50.235642Z", "iopub.status.busy": "2023-09-16T01:21:50.235220Z", "iopub.status.idle": "2023-09-16T01:21:50.314104Z", "shell.execute_reply": "2023-09-16T01:21:50.313497Z" }, "id": "oWfh0QW4IXTn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/2944250643.py:2: BinaryClassHead.__init__ (from tensorflow_estimator.python.estimator.head.binary_class_head) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/2944250643.py:1: LinearEstimatorV2.__init__ (from tensorflow_estimator.python.estimator.canned.linear) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1124: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from 
tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpcvrw6s1d\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpcvrw6s1d', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] } ], "source": [ "linear_estimator = tf.estimator.LinearEstimator(\n", " head=tf.estimator.BinaryClassHead(),\n", " feature_columns=feature_columns,\n", " optimizer=create_sample_optimizer('tf1'))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:50.317334Z", "iopub.status.busy": "2023-09-16T01:21:50.316889Z", "iopub.status.idle": "2023-09-16T01:21:56.151959Z", "shell.execute_reply": "2023-09-16T01:21:56.151305Z" }, "id": "hi77Sg4k-0TR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/legacy/ftrl.py:173: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from 
tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpcvrw6s1d/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.6931472, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpcvrw6s1d/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 0.55268794.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 2023-09-16T01:21:55\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpcvrw6s1d/model.ckpt-20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [1/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [2/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [3/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [4/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": 
[ "INFO:tensorflow:Evaluation [5/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [6/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [7/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [8/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [9/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 0.58102s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-09-16-01:21:55\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 20: accuracy = 0.70075756, accuracy_baseline = 0.625, auc = 0.75472915, auc_precision_recall = 0.65362054, average_loss = 0.5759378, global_step = 20, label/mean = 0.375, loss = 0.5704812, precision = 0.6388889, prediction/mean = 0.41331062, recall = 0.46464646\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpcvrw6s1d/model.ckpt-20\n" ] }, { "data": { "text/plain": [ "{'accuracy': 0.70075756,\n", " 'accuracy_baseline': 0.625,\n", " 'auc': 0.75472915,\n", " 'auc_precision_recall': 0.65362054,\n", " 'average_loss': 0.5759378,\n", " 'label/mean': 0.375,\n", " 'loss': 0.5704812,\n", " 'precision': 0.6388889,\n", " 'prediction/mean': 0.41331062,\n", " 'recall': 0.46464646,\n", " 'global_step': 20}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linear_estimator.train(input_fn=_input_fn, steps=100)\n", "linear_estimator.evaluate(input_fn=_eval_input_fn, steps=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "KEmzBjfnsxwT" }, "source": [ "### TensorFlow 2: Using Keras LinearModel" ] }, { "cell_type": "markdown", "metadata": { "id": "fkgkGf_AOaRR" }, "source": [ "In 
TensorFlow 2, you can create an instance of the Keras `tf.compat.v1.keras.models.LinearModel`, which is the substitute for `tf.estimator.LinearEstimator`. The `tf.compat.v1.keras` path is used to signify that the pre-made model exists for compatibility." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:56.155381Z", "iopub.status.busy": "2023-09-16T01:21:56.155102Z", "iopub.status.idle": "2023-09-16T01:21:57.370719Z", "shell.execute_reply": "2023-09-16T01:21:57.369951Z" }, "id": "Kip65sYBlKiu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 6s - loss: 0.3438 - accuracy: 0.6562" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 3.6712 - accuracy: 0.6077\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2399 - accuracy: 0.6250" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.2146 - accuracy: 0.6715\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 0.2113 - accuracy: 0.6875" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1980 - accuracy: 0.6874\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2068 - accuracy: 0.6250" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1880 - accuracy: 0.7129\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2236 - accuracy: 0.5000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "13/20 [==================>...........] - ETA: 0s - loss: 0.1866 - accuracy: 0.7091" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 9ms/step - loss: 0.1805 - accuracy: 0.7337\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 0.2111 - accuracy: 0.6875" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1807 - accuracy: 0.7624\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2339 - accuracy: 0.7500" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1692 - accuracy: 0.7783\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.1658 - accuracy: 0.7812" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1691 - accuracy: 0.7927\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 0.2556 - accuracy: 0.6562" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1717 - accuracy: 0.7911\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2777 - accuracy: 0.6562" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1596 - accuracy: 0.7990\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/9 [==>...........................] 
- ETA: 1s - loss: 0.1653 - accuracy: 0.7500" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "9/9 [==============================] - 0s 2ms/step - loss: 0.1853 - accuracy: 0.7348\n" ] }, { "data": { "text/plain": [ "{'loss': 0.185323566198349, 'accuracy': 0.7348484992980957}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linear_model = tf.compat.v1.keras.experimental.LinearModel()\n", "linear_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])\n", "linear_model.fit(x_train, y_train, epochs=10)\n", "linear_model.evaluate(x_eval, y_eval, return_dict=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "RRrj78Lqplni" }, "source": [ "## Example 2: Migrating from DNNEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "YKl6XZ7Bp1t5" }, "source": [ "### TensorFlow 1: Using DNNEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "J7wJUmgypln8" }, "source": [ "In TensorFlow 1, you can use `tf.estimator.DNNEstimator` to create a baseline deep neural network (DNN) model for regression and classification problems." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:57.374172Z", "iopub.status.busy": "2023-09-16T01:21:57.373775Z", "iopub.status.idle": "2023-09-16T01:21:57.380981Z", "shell.execute_reply": "2023-09-16T01:21:57.380389Z" }, "id": "qHbgXCzfpln9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/1828606501.py:1: DNNEstimatorV2.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpyih539cq\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpyih539cq', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] } ], "source": [ "dnn_estimator = tf.estimator.DNNEstimator(\n", " head=tf.estimator.BinaryClassHead(),\n", " 
feature_columns=feature_columns,\n", " hidden_units=[128],\n", " activation_fn=tf.nn.relu,\n", " optimizer=create_sample_optimizer('tf1'))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:21:57.384056Z", "iopub.status.busy": "2023-09-16T01:21:57.383569Z", "iopub.status.idle": "2023-09-16T01:22:00.019470Z", "shell.execute_reply": "2023-09-16T01:22:00.018769Z" }, "id": "6DTnXxU2pln-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-09-16 01:21:57.950353: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. 
Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:\n", "type_id: TFT_OPTIONAL\n", "args {\n", " type_id: TFT_PRODUCT\n", " args {\n", " type_id: TFT_TENSOR\n", " args {\n", " type_id: TFT_INT64\n", " }\n", " }\n", "}\n", " is neither a subtype nor a supertype of the combined inputs preceding it:\n", "type_id: TFT_OPTIONAL\n", "args {\n", " type_id: TFT_PRODUCT\n", " args {\n", " type_id: TFT_TENSOR\n", " args {\n", " type_id: TFT_INT32\n", " }\n", " }\n", "}\n", "\n", "\tfor Tuple type infernce function 0\n", "\twhile inferring type of node 'dnn/zero_fraction/cond/output/_18'\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpyih539cq/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.9991276, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpyih539cq/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 0.5818331.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 
2023-09-16T01:21:59\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpyih539cq/model.ckpt-20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [1/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [2/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [3/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [4/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [5/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [6/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [7/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [8/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [9/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 0.52606s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-09-16-01:21:59\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 20: accuracy = 0.70454544, accuracy_baseline = 0.625, auc = 0.6964494, auc_precision_recall = 0.60180384, average_loss = 0.5988959, global_step = 20, label/mean = 0.375, loss = 0.59320897, precision = 0.6363636, prediction/mean = 0.38547936, recall = 0.4949495\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpyih539cq/model.ckpt-20\n" ] }, { "data": { "text/plain": [ "{'accuracy': 0.70454544,\n", " 'accuracy_baseline': 0.625,\n", " 'auc': 0.6964494,\n", " 'auc_precision_recall': 0.60180384,\n", " 'average_loss': 0.5988959,\n", " 'label/mean': 0.375,\n", " 'loss': 0.59320897,\n", " 'precision': 0.6363636,\n", " 'prediction/mean': 0.38547936,\n", " 'recall': 0.4949495,\n", " 'global_step': 20}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dnn_estimator.train(input_fn=_input_fn, steps=100)\n", "dnn_estimator.evaluate(input_fn=_eval_input_fn, steps=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "6xJz6px6pln-" }, "source": [ "### TensorFlow 2: Using Keras to create a custom DNN model" ] }, { "cell_type": "markdown", "metadata": { "id": "7cgc72rzpln-" }, "source": [ "In TensorFlow 2, you can create a custom DNN model to substitute for one generated by `tf.estimator.DNNEstimator`, with similar levels of user-specified customization (for instance, the ability to customize the model optimizer, as in the previous example).\n", "\n", "A similar workflow can be used to replace `tf.estimator.experimental.RNNEstimator` with a Keras recurrent neural network (RNN) model. Keras provides a number of built-in, customizable choices by way of `tf.keras.layers.RNN`, `tf.keras.layers.LSTM`, and `tf.keras.layers.GRU`. To learn more, check out the _Built-in RNN layers: a simple example_ section of the [RNN with Keras guide](https://www.tensorflow.org/guide/keras/rnn)." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:00.022693Z", "iopub.status.busy": "2023-09-16T01:22:00.022445Z", "iopub.status.idle": "2023-09-16T01:22:00.038869Z", "shell.execute_reply": "2023-09-16T01:22:00.038267Z" }, "id": "B5SdsjlL49RG" }, "outputs": [], "source": [ "# One hidden layer of 128 ReLU units mirrors hidden_units=[128] and\n", "# activation_fn=tf.nn.relu in the DNNEstimator above; the final Dense(1)\n", "# layer produces a single output for the binary label.\n", "dnn_model = tf.keras.models.Sequential(\n", " [tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(1)])\n", "\n", "dnn_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:00.041779Z", "iopub.status.busy": "2023-09-16T01:22:00.041561Z", "iopub.status.idle": "2023-09-16T01:22:01.109413Z", "shell.execute_reply": "2023-09-16T01:22:01.108657Z" }, "id": "JQmRw9_Upln_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 6s - loss: 1.8427 - accuracy: 0.8438" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 581.7921 - accuracy: 0.4338\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 1.1434 - accuracy: 0.4062" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 1.2502 - accuracy: 0.4450\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.7506 - accuracy: 0.4688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.7286 - accuracy: 0.5183\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.3871 - accuracy: 0.5938" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.5274 - accuracy: 0.5486\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 0.6149 - accuracy: 0.4375" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.4200 - accuracy: 0.5550\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.3201 - accuracy: 0.5625" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.3405 - accuracy: 0.5821\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.3388 - accuracy: 0.4375" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.2975 - accuracy: 0.6124\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 0.3408 - accuracy: 0.5938" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.2628 - accuracy: 0.6507\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2660 - accuracy: 0.6250" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.2382 - accuracy: 0.7002\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2402 - accuracy: 0.7188" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.2218 - accuracy: 0.7273\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/9 [==>...........................] 
- ETA: 0s - loss: 0.1989 - accuracy: 0.7188" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "9/9 [==============================] - 0s 2ms/step - loss: 0.2414 - accuracy: 0.6894\n" ] }, { "data": { "text/plain": [ "{'loss': 0.2413530945777893, 'accuracy': 0.689393937587738}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dnn_model.fit(x_train, y_train, epochs=10)\n", "dnn_model.evaluate(x_eval, y_eval, return_dict=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "UeBHZ0cd1Pl2" }, "source": [ "## Example 3: Migrating from DNNLinearCombinedEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "GfRaObf5g4TU" }, "source": [ "### TensorFlow 1: Using DNNLinearCombinedEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "2r13RMX-g4TV" }, "source": [ "In TensorFlow 1, you can use `tf.estimator.DNNLinearCombinedEstimator` to create a baseline combined model for regression and classification problems with customization capacity for both its linear and DNN components." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:01.113318Z", "iopub.status.busy": "2023-09-16T01:22:01.112687Z", "iopub.status.idle": "2023-09-16T01:22:01.120078Z", "shell.execute_reply": "2023-09-16T01:22:01.119439Z" }, "id": "OyyDCqc5j7rf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10080/1505653152.py:3: DNNLinearCombinedEstimatorV2.__init__ (from tensorflow_estimator.python.estimator.canned.dnn_linear_combined) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpun15otq5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpun15otq5', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] } ], "source": [ "optimizer = create_sample_optimizer('tf1')\n", "\n", 
"combined_estimator = tf.estimator.DNNLinearCombinedEstimator(\n", " head=tf.estimator.BinaryClassHead(),\n", " # Wide settings\n", " linear_feature_columns=feature_columns,\n", " linear_optimizer=optimizer,\n", " # Deep settings\n", " dnn_feature_columns=feature_columns,\n", " dnn_hidden_units=[128],\n", " dnn_optimizer=optimizer)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:01.123149Z", "iopub.status.busy": "2023-09-16T01:22:01.122685Z", "iopub.status.idle": "2023-09-16T01:22:04.930300Z", "shell.execute_reply": "2023-09-16T01:22:04.929588Z" }, "id": "aXN-BxwzmRaf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpun15otq5/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 4.244113, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20...\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpun15otq5/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 0.5406369.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 2023-09-16T01:22:04\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpun15otq5/model.ckpt-20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [1/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [2/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [3/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [4/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [5/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [6/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [7/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [8/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation 
[9/10]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 0.58801s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-09-16-01:22:04\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 20: accuracy = 0.71590906, accuracy_baseline = 0.625, auc = 0.7440466, auc_precision_recall = 0.6447197, average_loss = 0.5923795, global_step = 20, label/mean = 0.375, loss = 0.5745624, precision = 0.65384614, prediction/mean = 0.3921669, recall = 0.5151515\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpun15otq5/model.ckpt-20\n" ] }, { "data": { "text/plain": [ "{'accuracy': 0.71590906,\n", " 'accuracy_baseline': 0.625,\n", " 'auc': 0.7440466,\n", " 'auc_precision_recall': 0.6447197,\n", " 'average_loss': 0.5923795,\n", " 'label/mean': 0.375,\n", " 'loss': 0.5745624,\n", " 'precision': 0.65384614,\n", " 'prediction/mean': 0.3921669,\n", " 'recall': 0.5151515,\n", " 'global_step': 20}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "combined_estimator.train(input_fn=_input_fn, steps=100)\n", "combined_estimator.evaluate(input_fn=_eval_input_fn, steps=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "BeMikL5ug4TX" }, "source": [ "### TensorFlow 2: Using Keras WideDeepModel" ] }, { "cell_type": "markdown", "metadata": { "id": "CYByxxBhg4TX" }, "source": [ "In TensorFlow 2, you can create an instance of the Keras `tf.compat.v1.keras.models.WideDeepModel` to substitute for one generated by `tf.estimator.DNNLinearCombinedEstimator`, with similar levels of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).\n", "\n", "This `WideDeepModel` is constructed on the basis of a constituent `LinearModel` and a custom 
DNN Model, both of which are discussed in the preceding two examples. A custom linear model can also be used in place of the built-in Keras `LinearModel` if desired.\n", "\n", "If you would like to build your own model instead of using a canned estimator, check out the [Keras Sequential model](https://www.tensorflow.org/guide/keras/sequential_model) guide. For more information on custom training and optimizers, check out the [Custom training: walkthrough](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough) guide." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:04.934074Z", "iopub.status.busy": "2023-09-16T01:22:04.933471Z", "iopub.status.idle": "2023-09-16T01:22:05.655301Z", "shell.execute_reply": "2023-09-16T01:22:05.654469Z" }, "id": "mIFM3e-_RLSX" }, "outputs": [], "source": [ "# Create LinearModel and DNN Model as in Examples 1 and 2\n", "optimizer = create_sample_optimizer('tf2')\n", "\n", "linear_model = tf.compat.v1.keras.experimental.LinearModel()\n", "linear_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])\n", "linear_model.fit(x_train, y_train, epochs=10, verbose=0)\n", "\n", "dnn_model = tf.keras.models.Sequential(\n", " [tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(1)])\n", "dnn_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:05.659341Z", "iopub.status.busy": "2023-09-16T01:22:05.658679Z", "iopub.status.idle": "2023-09-16T01:22:07.077239Z", "shell.execute_reply": "2023-09-16T01:22:07.076569Z" }, "id": "mFmQz9kjmMSx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 11s - loss: 99.1739 - accuracy: 0.4688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 1s 3ms/step - loss: 682.2249 - accuracy: 0.6858\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2859 - accuracy: 0.6875" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.2909 - accuracy: 0.7193\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2286 - accuracy: 0.7188" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.2140 - accuracy: 0.7400\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 0.3173 - accuracy: 0.8125" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.2084 - accuracy: 0.7671\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.1924 - accuracy: 0.7500" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1809 - accuracy: 0.7719\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.1665 - accuracy: 0.7812" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1699 - accuracy: 0.7911\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 0.2044 - accuracy: 0.7812" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1745 - accuracy: 0.7687\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.1953 - accuracy: 0.7500" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1665 - accuracy: 0.7974\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] - ETA: 0s - loss: 0.2248 - accuracy: 0.7812" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1693 - accuracy: 0.7911\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10/10\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/20 [>.............................] 
- ETA: 0s - loss: 0.3363 - accuracy: 0.6875" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/20 [==============================] - 0s 2ms/step - loss: 0.1748 - accuracy: 0.7847\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/9 [==>...........................] - ETA: 1s - loss: 0.2207 - accuracy: 0.7188" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "9/9 [==============================] - 0s 2ms/step - loss: 0.2050 - accuracy: 0.7159\n" ] }, { "data": { "text/plain": [ "{'loss': 0.2049505114555359, 'accuracy': 0.7159090638160706}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "combined_model = tf.compat.v1.keras.experimental.WideDeepModel(linear_model,\n", " dnn_model)\n", "combined_model.compile(\n", " optimizer=[optimizer, optimizer], loss='mse', metrics=['accuracy'])\n", "combined_model.fit([x_train, x_train], y_train, epochs=10)\n", "combined_model.evaluate(x_eval, y_eval, return_dict=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "wP1DBRhpeOJn" }, "source": [ "## Example 4: Migrating from BoostedTreesEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "_3mCQVDSeOKD" }, "source": [ "### TensorFlow 1: Using BoostedTreesEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "oEWYHNt4eOKD" }, "source": [ "In TensorFlow 1, you could use `tf.estimator.BoostedTreesEstimator` to create a baseline Gradient Boosting model using an ensemble of decision trees for regression and classification problems. This functionality is no longer included in TensorFlow 2." 
] }, { "cell_type": "markdown", "metadata": { "id": "wliVIER1jLnA" }, "source": [ "```\n", "bt_estimator = tf1.estimator.BoostedTreesEstimator(\n", " head=tf.estimator.BinaryClassHead(),\n", " n_batches_per_layer=1,\n", " max_depth=10,\n", " n_trees=1000,\n", " feature_columns=feature_columns)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "-K87uBrZjR0u" }, "source": [ "```\n", "bt_estimator.train(input_fn=_input_fn, steps=1000)\n", "bt_estimator.evaluate(input_fn=_eval_input_fn, steps=100)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "eNuLP6BeeOKF" }, "source": [ "### TensorFlow 2: Using TensorFlow Decision Forests" ] }, { "cell_type": "markdown", "metadata": { "id": "m3EVq388eOKF" }, "source": [ "In TensorFlow 2, `tf.estimator.BoostedTreesEstimator` is replaced by [tfdf.keras.GradientBoostedTreesModel](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel#attributes) from the [TensorFlow Decision Forests](https://www.tensorflow.org/decision_forests) package.\n", "\n", "TensorFlow Decision Forests provides various advantages over the `tf.estimator.BoostedTreesEstimator`, notably regarding quality, speed, ease of use and flexibility. To learn about TensorFlow Decision Forests, start with the [beginner colab](https://www.tensorflow.org/decision_forests/tutorials/beginner_colab).\n", "\n", "The following example shows how to train a Gradient Boosted Trees model using TensorFlow 2:" ] }, { "cell_type": "markdown", "metadata": { "id": "UB90fXJdVWC5" }, "source": [ "Install TensorFlow Decision Forests." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:07.081296Z", "iopub.status.busy": "2023-09-16T01:22:07.081027Z", "iopub.status.idle": "2023-09-16T01:22:09.075158Z", "shell.execute_reply": "2023-09-16T01:22:09.074304Z" }, "id": "9097mTCIVVE9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tensorflow_decision_forests in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (1.5.0)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.24.3)\r\n", "Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.1.0)\r\n", "Requirement already satisfied: tensorflow~=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.13.0)\r\n", "Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.16.0)\r\n", "Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.4.0)\r\n", "Requirement already satisfied: wheel in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (0.41.1)\r\n", "Requirement already satisfied: wurlitzer in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (3.0.3)\r\n", "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.1.21 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (23.5.26)\r\n", "Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
tensorflow~=2.13.0->tensorflow_decision_forests) (0.4.0)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (0.2.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (1.58.0)\r\n", "Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (3.9.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: keras<2.14,>=2.13.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (2.13.1)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (16.0.6)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (23.1)\r\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (3.20.3)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (68.2.2)\r\n", "Requirement already satisfied: tensorboard<2.14,>=2.13 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (2.13.0)\r\n", "Requirement already satisfied: tensorflow-estimator<2.14,>=2.13.0 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (2.13.0)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (2.3.0)\r\n", "Requirement already satisfied: typing-extensions<4.6.0,>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (4.5.0)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (1.14.1)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.13.0->tensorflow_decision_forests) (0.34.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: python-dateutil>=2.8.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2.8.2)\r\n", "Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2023.3.post1)\r\n", "Requirement already satisfied: tzdata>=2022.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2023.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2.23.0)\r\n", "Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (1.0.0)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.4.4)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2.31.0)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (0.7.1)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2.3.7)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (5.3.1)\r\n", "Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (0.3.0)\r\n", "Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (4.9)\r\n", "Requirement already satisfied: urllib3<2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (1.26.16)\r\n", "Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (1.3.1)\r\n", "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
markdown>=2.6.8->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (6.8.0)\r\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.2.0)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.4)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2023.7.22)\r\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (2.1.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.16.2)\r\n", "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (0.5.0)\r\n", "Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.14,>=2.13->tensorflow~=2.13.0->tensorflow_decision_forests) (3.2.2)\r\n" ] } ], "source": [ "!pip install tensorflow_decision_forests" ] }, { "cell_type": "markdown", "metadata": { "id": "B1qTdAS-VpXk" }, "source": [ "Create a TensorFlow dataset. 
Note that Decision Forests natively support many types of features and do not need pre-processing." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:09.079684Z", "iopub.status.busy": "2023-09-16T01:22:09.079158Z", "iopub.status.idle": "2023-09-16T01:22:09.286939Z", "shell.execute_reply": "2023-09-16T01:22:09.286219Z" }, "id": "jkjFHmDTVswY" }, "outputs": [], "source": [ "train_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')\n", "eval_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')\n", "\n", "# Convert the Pandas Dataframes into TensorFlow datasets.\n", "train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(train_dataframe, label=\"survived\")\n", "eval_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(eval_dataframe, label=\"survived\")" ] }, { "cell_type": "markdown", "metadata": { "id": "7fPa-LfDWDzB" }, "source": [ "Train the model on the `train_dataset` dataset." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:09.290999Z", "iopub.status.busy": "2023-09-16T01:22:09.290295Z", "iopub.status.idle": "2023-09-16T01:22:14.816087Z", "shell.execute_reply": "2023-09-16T01:22:14.815482Z" }, "id": "JO0yCH9hWPvJ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. 
Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmpinxrd9bl as temporary training directory\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[WARNING 23-09-16 01:22:09.3074 UTC gradient_boosted_trees.cc:1818] \"goss_alpha\" set but \"sampling_method\" not equal to \"GOSS\".\n", "[WARNING 23-09-16 01:22:09.3074 UTC gradient_boosted_trees.cc:1829] \"goss_beta\" set but \"sampling_method\" not equal to \"GOSS\".\n", "[WARNING 23-09-16 01:22:09.3074 UTC gradient_boosted_trees.cc:1843] \"selective_gradient_boosting_ratio\" set but \"sampling_method\" not equal to \"SELGB\".\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:03.607059. Found 627 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:00.226305\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 23-09-16 01:22:13.1488 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpinxrd9bl/model/ with prefix b67347ee49794d62\n", "[INFO 23-09-16 01:22:13.1525 UTC abstract_model.cc:1311] Engine \"GradientBoostedTreesQuickScorerExtended\" built\n", "[INFO 23-09-16 01:22:13.1525 UTC kernel.cc:1075] Use fast generic engine\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:AutoGraph could not transform and will run it as-is.\n", "Please report this to the TensorFlow team. 
When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n", "Cause: could not get source code\n", "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:AutoGraph could not transform and will run it as-is.\n", "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n", "Cause: could not get source code\n", "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING: AutoGraph could not transform and will run it as-is.\n", "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n", "Cause: could not get source code\n", "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use the default hyper-parameters of the model.\n", "gbt_model = tfdf.keras.GradientBoostedTreesModel()\n", "gbt_model.fit(train_dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "2Y5xm29AWGxt" }, "source": [ "Evaluate the quality of the model on the `eval_dataset` dataset." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:14.819749Z", "iopub.status.busy": "2023-09-16T01:22:14.819135Z", "iopub.status.idle": "2023-09-16T01:22:15.131572Z", "shell.execute_reply": "2023-09-16T01:22:15.130770Z" }, "id": "JLS_2vKKeOKF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 0.0000e+00 - accuracy: 0.8295" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 296ms/step - loss: 0.0000e+00 - accuracy: 0.8295\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.0, 'accuracy': 0.8295454382896423}\n" ] } ], "source": [ "gbt_model.compile(metrics=['accuracy'])\n", "gbt_evaluation = gbt_model.evaluate(eval_dataset, return_dict=True)\n", "print(gbt_evaluation)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z22UJ5SUqToQ" }, "source": [ "Gradient Boosted Trees is just one of the many decision forest algorithms available in TensorFlow Decision Forests. For example, Random Forests (available as [tfdf.keras.RandomForestModel](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/RandomForestModel)) is very resistant to overfitting, while CART (available as [tfdf.keras.CartModel](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/CartModel)) is great for model interpretation.\n", "\n", "In the next example, train and evaluate a Random Forest model." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:15.135069Z", "iopub.status.busy": "2023-09-16T01:22:15.134552Z", "iopub.status.idle": "2023-09-16T01:22:15.866018Z", "shell.execute_reply": "2023-09-16T01:22:15.865328Z" }, "id": "W3slOhn4Zi9X" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmpdvqhwuwo as temporary training directory\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:00.187950. 
Found 627 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:00.191396\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 23-09-16 01:22:15.4265 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpdvqhwuwo/model/ with prefix 39c681f57c12496f\n", "[INFO 23-09-16 01:22:15.5264 UTC decision_forest.cc:660] Model loaded with 300 root(s), 34556 node(s), and 9 input feature(s).\n", "[INFO 23-09-16 01:22:15.5265 UTC kernel.cc:1075] Use fast generic engine\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 0.0000e+00 - accuracy: 0.8333" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 141ms/step - loss: 0.0000e+00 - accuracy: 0.8333\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.0, 'accuracy': 0.8333333134651184}\n" ] } ], "source": [ "# Train a Random Forest model\n", "rf_model = tfdf.keras.RandomForestModel()\n", "rf_model.fit(train_dataset)\n", "\n", "# Evaluate the Random Forest model\n", "rf_model.compile(metrics=['accuracy'])\n", "rf_evaluation = rf_model.evaluate(eval_dataset, return_dict=True)\n", "print(rf_evaluation)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z0QYolhoZb_k" }, "source": [ "In the final example, train and plot a CART model." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2023-09-16T01:22:15.869784Z", "iopub.status.busy": "2023-09-16T01:22:15.869172Z", "iopub.status.idle": "2023-09-16T01:22:16.252261Z", "shell.execute_reply": "2023-09-16T01:22:16.251571Z" }, "id": "027bGnCork_W" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmpmgmlcvs6 as temporary training directory\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:00.187860. Found 627 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:00.018191\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 23-09-16 01:22:16.0918 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpmgmlcvs6/model/ with prefix efef116800e041d8\n", "[INFO 23-09-16 01:22:16.0921 UTC decision_forest.cc:660] Model loaded with 1 root(s), 21 node(s), and 5 input feature(s).\n", "[INFO 23-09-16 01:22:16.0922 UTC kernel.cc:1075] Use fast generic engine\n" ] }, { "data": { "text/html": [ "\n", "\n", "
\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train a CART model\n", "cart_model = tfdf.keras.CartModel()\n", "cart_model.fit(train_dataset)\n", "\n", "# Plot the CART model\n", "tfdf.model_plotter.plot_model_in_colab(cart_model, max_depth=2)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "canned_estimators.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }