{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "7765UFHoyGx6" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2021-10-08T22:39:34.098155Z", "iopub.status.busy": "2021-10-08T22:39:34.097601Z", "iopub.status.idle": "2021-10-08T22:39:34.100153Z", "shell.execute_reply": "2021-10-08T22:39:34.099579Z" }, "id": "KsOkK8O69PyT" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "ZS8z-_KeywY9" }, "source": [ "# TF Lattice Custom Estimators" ] }, { "cell_type": "markdown", "metadata": { "id": "r61fkA2i9Y3_" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Ur6yCw7YVvr8" }, "source": [ "## 概述\n", "\n", "您可以使用自定义 Estimator 通过 TFL 层创建任意单调模型。本指南概述了创建此类 Estimator 所需的步骤。" ] }, { "cell_type": "markdown", "metadata": { "id": "x769lI12IZXB" }, "source": [ "## 设置" ] }, { "cell_type": "markdown", "metadata": { "id": "fbBVAR6UeRN5" }, "source": [ "安装 TF Lattice 软件包:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2021-10-08T22:39:34.109444Z", "iopub.status.busy": "2021-10-08T22:39:34.106636Z", "iopub.status.idle": "2021-10-08T22:39:36.081903Z", "shell.execute_reply": "2021-10-08T22:39:36.082277Z" }, "id": "bpXjJKpSd3j4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow-lattice\r\n", " Using cached tensorflow_lattice-2.0.9-py2.py3-none-any.whl (235 kB)\r\n", "Requirement already satisfied: pandas in /home/kbuilder/.local/lib/python3.7/site-packages (from tensorflow-lattice) (1.3.3)\r\n", "Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.7/site-packages (from tensorflow-lattice) (1.15.0)\r\n", "Requirement already satisfied: matplotlib in /home/kbuilder/.local/lib/python3.7/site-packages (from tensorflow-lattice) (3.4.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting graphviz\r\n", " Using cached graphviz-0.17-py3-none-any.whl (18 kB)\r\n", "Requirement already satisfied: absl-py in /home/kbuilder/.local/lib/python3.7/site-packages (from tensorflow-lattice) (0.12.0)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.7/site-packages (from tensorflow-lattice) (1.19.5)\r\n", "Requirement already satisfied: scikit-learn in /home/kbuilder/.local/lib/python3.7/site-packages (from tensorflow-lattice) (1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting dm-sonnet\r\n", " Using cached dm_sonnet-2.0.0-py3-none-any.whl (254 kB)\r\n", "Requirement already satisfied: 
wrapt>=1.11.1 in /tmpfs/src/tf_docs_env/lib/python3.7/site-packages (from dm-sonnet->tensorflow-lattice) (1.12.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting dm-tree>=0.1.1\r\n", " Using cached dm_tree-0.1.6-cp37-cp37m-manylinux_2_24_x86_64.whl (93 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tabulate>=0.7.5\r\n", " Using cached tabulate-0.8.9-py3-none-any.whl (25 kB)\r\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (1.3.2)\r\n", "Requirement already satisfied: python-dateutil>=2.7 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (2.8.2)\r\n", "Requirement already satisfied: pillow>=6.2.0 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (8.3.2)\r\n", "Requirement already satisfied: pyparsing>=2.2.1 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (2.4.7)\r\n", "Requirement already satisfied: cycler>=0.10 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (0.10.0)\r\n", "Requirement already satisfied: pytz>=2017.3 in /home/kbuilder/.local/lib/python3.7/site-packages (from pandas->tensorflow-lattice) (2021.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/kbuilder/.local/lib/python3.7/site-packages (from scikit-learn->tensorflow-lattice) (3.0.0)\r\n", "Requirement already satisfied: scipy>=1.1.0 in /home/kbuilder/.local/lib/python3.7/site-packages (from scikit-learn->tensorflow-lattice) (1.7.1)\r\n", "Requirement already satisfied: joblib>=0.11 in /home/kbuilder/.local/lib/python3.7/site-packages (from scikit-learn->tensorflow-lattice) (1.1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: tabulate, dm-tree, 
graphviz, dm-sonnet, tensorflow-lattice\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed dm-sonnet-2.0.0 dm-tree-0.1.6 graphviz-0.17 tabulate-0.8.9 tensorflow-lattice-2.0.9\r\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "!pip install tensorflow-lattice" ] }, { "cell_type": "markdown", "metadata": { "id": "jSVl9SHTeSGX" }, "source": [ "Importing the required packages:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-10-08T22:39:36.087864Z", "iopub.status.busy": "2021-10-08T22:39:36.087299Z", "iopub.status.idle": "2021-10-08T22:39:38.352148Z", "shell.execute_reply": "2021-10-08T22:39:38.351540Z" }, "id": "P9rMpg1-ASY3" }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "import logging\n", "import numpy as np\n", "import pandas as pd\n", "import sys\n", "import tensorflow_lattice as tfl\n", "from tensorflow import feature_column as fc\n", "\n", "from tensorflow_estimator.python.estimator.canned import optimizers\n", "from tensorflow_estimator.python.estimator.head import binary_class_head\n", "logging.disable(sys.maxsize)" ] }, { "cell_type": "markdown", "metadata": { "id": "svPuM6QNxlrH" }, "source": [ "Downloading the UCI Statlog (Heart) dataset:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-10-08T22:39:38.357907Z", "iopub.status.busy": "2021-10-08T22:39:38.357341Z", "iopub.status.idle": "2021-10-08T22:39:38.372608Z", "shell.execute_reply": "2021-10-08T22:39:38.372994Z" }, "id": "M0CmH1gPASZF" }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>age</th><th>sex</th><th>cp</th><th>trestbps</th><th>chol</th><th>fbs</th><th>restecg</th><th>thalach</th><th>exang</th><th>oldpeak</th><th>slope</th><th>ca</th><th>thal</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>0</th><td>63</td><td>1</td><td>1</td><td>145</td><td>233</td><td>1</td><td>2</td><td>150</td><td>0</td><td>2.3</td><td>3</td><td>0</td><td>fixed</td></tr>\n", "    <tr><th>1</th><td>67</td><td>1</td><td>4</td><td>160</td><td>286</td><td>0</td><td>2</td><td>108</td><td>1</td><td>1.5</td><td>2</td><td>3</td><td>normal</td></tr>\n", "    <tr><th>2</th><td>67</td><td>1</td><td>4</td><td>120</td><td>229</td><td>0</td><td>2</td><td>129</td><td>1</td><td>2.6</td><td>2</td><td>2</td><td>reversible</td></tr>\n", "    <tr><th>3</th><td>37</td><td>1</td><td>3</td><td>130</td><td>250</td><td>0</td><td>0</td><td>187</td><td>0</td><td>3.5</td><td>3</td><td>0</td><td>normal</td></tr>\n", "    <tr><th>4</th><td>41</td><td>0</td><td>2</td><td>130</td><td>204</td><td>0</td><td>2</td><td>172</td><td>0</td><td>1.4</td><td>1</td><td>0</td><td>normal</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\\n", "0 63 1 1 145 233 1 2 150 0 2.3 3 \n", "1 67 1 4 160 286 0 2 108 1 1.5 2 \n", "2 67 1 4 120 229 0 2 129 1 2.6 2 \n", "3 37 1 3 130 250 0 0 187 0 3.5 3 \n", "4 41 0 2 130 204 0 2 172 0 1.4 1 \n", "\n", " ca thal \n", "0 0 fixed \n", "1 3 normal \n", "2 2 reversible \n", "3 0 normal \n", "4 0 normal " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "csv_file = tf.keras.utils.get_file(\n", " 'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')\n", "df = pd.read_csv(csv_file)\n", "target = df.pop('target')\n", "train_size = int(len(df) * 0.8)\n", "train_x = df[:train_size]\n", "train_y = target[:train_size]\n", "test_x = df[train_size:]\n", "test_y = target[train_size:]\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "nKkAw12SxvGG" }, "source": [ "设置用于在本指南中进行训练的默认值:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2021-10-08T22:39:38.377145Z", "iopub.status.busy": "2021-10-08T22:39:38.376554Z", "iopub.status.idle": "2021-10-08T22:39:38.378495Z", "shell.execute_reply": "2021-10-08T22:39:38.378067Z" }, "id": "1T6GFI9F6mcG" }, "outputs": [], "source": [ "LEARNING_RATE = 0.1\n", "BATCH_SIZE = 128\n", "NUM_EPOCHS = 1000" ] }, { "cell_type": "markdown", "metadata": { "id": "0TGfzhPHzpix" }, "source": [ "## 特征列\n", "\n", "与任何其他 TF Estimator 一样,数据通常需要通过 input_fn 传递给 Estimator,并使用 [FeatureColumns](https://tensorflow.google.cn/guide/feature_columns) 进行解析。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2021-10-08T22:39:38.383125Z", "iopub.status.busy": "2021-10-08T22:39:38.382520Z", "iopub.status.idle": "2021-10-08T22:39:38.384707Z", "shell.execute_reply": "2021-10-08T22:39:38.384168Z" }, "id": "DCIUz8apzs0l" }, "outputs": [], "source": [ "# Feature columns.\n", 
"# - age\n", "# - sex\n", "# - ca number of major vessels (0-3) colored by flourosopy\n", "# - thal 3 = normal; 6 = fixed defect; 7 = reversable defect\n", "feature_columns = [\n", " fc.numeric_column('age', default_value=-1),\n", " fc.categorical_column_with_vocabulary_list('sex', [0, 1]),\n", " fc.numeric_column('ca'),\n", " fc.categorical_column_with_vocabulary_list(\n", " 'thal', ['normal', 'fixed', 'reversible']),\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "hEZstmtT2CA3" }, "source": [ "请注意,分类特征不需要用密集特征列包装,因为 `tfl.laysers.CategoricalCalibration` 层可以直接使用分类索引。" ] }, { "cell_type": "markdown", "metadata": { "id": "H_LoW_9m5OFL" }, "source": [ "## 创建 input_fn\n", "\n", "与任何其他 Estimator 一样,您可以使用 input_fn 将数据馈送给模型进行训练和评估。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2021-10-08T22:39:38.389638Z", "iopub.status.busy": "2021-10-08T22:39:38.389044Z", "iopub.status.idle": "2021-10-08T22:39:38.645699Z", "shell.execute_reply": "2021-10-08T22:39:38.645102Z" }, "id": "lFVy1Efy5NKD" }, "outputs": [], "source": [ "train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n", " x=train_x,\n", " y=train_y,\n", " shuffle=True,\n", " batch_size=BATCH_SIZE,\n", " num_epochs=NUM_EPOCHS,\n", " num_threads=1)\n", "\n", "test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n", " x=test_x,\n", " y=test_y,\n", " shuffle=False,\n", " batch_size=BATCH_SIZE,\n", " num_epochs=1,\n", " num_threads=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "kbrgSr9KaRg0" }, "source": [ "## 创建 model_fn\n", "\n", "您可以通过多种方式创建自定义 Estimator。在这里,我们将构造一个在已解析的输入张量上调用 Keras 模型的 `model_fn`。要解析输入特征,您可以使用 `tf.feature_column.input_layer`、`tf.keras.layers.DenseFeatures` 或 `tfl.estimators.transform_features`。如果使用后者,则不需要使用密集特征列包装分类特征,并且生成的张量不会串联,这样可以更轻松地在校准层中使用特征。\n", "\n", "要构造模型,您可以搭配使用 TFL 层或任何其他 Keras 层。在这里,我们从 TFL 层创建一个校准点阵 Keras 模型,并施加一些单调性约束。随后,我们使用 Keras 模型创建自定义 Estimator。\n" ] }, { "cell_type": "code", 
"execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2021-10-08T22:39:38.656023Z", "iopub.status.busy": "2021-10-08T22:39:38.655425Z", "iopub.status.idle": "2021-10-08T22:39:38.656901Z", "shell.execute_reply": "2021-10-08T22:39:38.657225Z" }, "id": "n2Zrv6OPaQO2" }, "outputs": [], "source": [ "def model_fn(features, labels, mode, config):\n", " \"\"\"model_fn for the custom estimator.\"\"\"\n", " del config\n", " input_tensors = tfl.estimators.transform_features(features, feature_columns)\n", " inputs = {\n", " key: tf.keras.layers.Input(shape=(1,), name=key) for key in input_tensors\n", " }\n", "\n", " lattice_sizes = [3, 2, 2, 2]\n", " lattice_monotonicities = ['increasing', 'none', 'increasing', 'increasing']\n", " lattice_input = tf.keras.layers.Concatenate(axis=1)([\n", " tfl.layers.PWLCalibration(\n", " input_keypoints=np.linspace(10, 100, num=8, dtype=np.float32),\n", " # The output range of the calibrator should be the input range of\n", " # the following lattice dimension.\n", " output_min=0.0,\n", " output_max=lattice_sizes[0] - 1.0,\n", " monotonicity='increasing',\n", " )(inputs['age']),\n", " tfl.layers.CategoricalCalibration(\n", " # Number of categories including any missing/default category.\n", " num_buckets=2,\n", " output_min=0.0,\n", " output_max=lattice_sizes[1] - 1.0,\n", " )(inputs['sex']),\n", " tfl.layers.PWLCalibration(\n", " input_keypoints=[0.0, 1.0, 2.0, 3.0],\n", " output_min=0.0,\n", " output_max=lattice_sizes[0] - 1.0,\n", " # You can specify TFL regularizers as tuple\n", " # ('regularizer name', l1, l2).\n", " kernel_regularizer=('hessian', 0.0, 1e-4),\n", " monotonicity='increasing',\n", " )(inputs['ca']),\n", " tfl.layers.CategoricalCalibration(\n", " num_buckets=3,\n", " output_min=0.0,\n", " output_max=lattice_sizes[1] - 1.0,\n", " # Categorical monotonicity can be partial order.\n", " # (i, j) indicates that we must have output(i) <= output(j).\n", " # Make sure to set the lattice monotonicity to 
'increasing' for this\n", " # dimension.\n", " monotonicities=[(0, 1), (0, 2)],\n", " )(inputs['thal']),\n", " ])\n", " output = tfl.layers.Lattice(\n", " lattice_sizes=lattice_sizes, monotonicities=lattice_monotonicities)(\n", " lattice_input)\n", "\n", " training = (mode == tf.estimator.ModeKeys.TRAIN)\n", " model = tf.keras.Model(inputs=inputs, outputs=output)\n", " logits = model(input_tensors, training=training)\n", "\n", " if training:\n", " optimizer = optimizers.get_optimizer_instance_v2('Adagrad', LEARNING_RATE)\n", " else:\n", " optimizer = None\n", "\n", " head = binary_class_head.BinaryClassHead()\n", " return head.create_estimator_spec(\n", " features=features,\n", " mode=mode,\n", " labels=labels,\n", " optimizer=optimizer,\n", " logits=logits,\n", " trainable_variables=model.trainable_variables,\n", " update_ops=model.updates)" ] }, { "cell_type": "markdown", "metadata": { "id": "mng-VtsSbVtQ" }, "source": [ "## 训练和 Estimator\n", "\n", "使用 `model_fn`,我们可以创建和训练 Estimator。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2021-10-08T22:39:38.661528Z", "iopub.status.busy": "2021-10-08T22:39:38.660924Z", "iopub.status.idle": "2021-10-08T22:40:04.554406Z", "shell.execute_reply": "2021-10-08T22:40:04.554845Z" }, "id": "j38GaEbKbZju" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC: 0.5701754689216614\n" ] } ], "source": [ "estimator = tf.estimator.Estimator(model_fn=model_fn)\n", "estimator.train(input_fn=train_input_fn)\n", "results = estimator.evaluate(input_fn=test_input_fn)\n", "print('AUC: {}'.format(results['auc']))" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "custom_estimators.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 0 }