{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "7765UFHoyGx6"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"execution": {
"iopub.execute_input": "2021-10-08T22:39:34.098155Z",
"iopub.status.busy": "2021-10-08T22:39:34.097601Z",
"iopub.status.idle": "2021-10-08T22:39:34.100153Z",
"shell.execute_reply": "2021-10-08T22:39:34.099579Z"
},
"id": "KsOkK8O69PyT"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZS8z-_KeywY9"
},
"source": [
"# TF Lattice 自定义 Estimator"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r61fkA2i9Y3_"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ur6yCw7YVvr8"
},
"source": [
"## 概述\n",
"\n",
"您可以使用自定义 Estimator 通过 TFL 层创建任意单调模型。本指南概述了创建此类 Estimator 所需的步骤。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x769lI12IZXB"
},
"source": [
"## 设置"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fbBVAR6UeRN5"
},
"source": [
"安装 TF Lattice 软件包:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2021-10-08T22:39:34.109444Z",
"iopub.status.busy": "2021-10-08T22:39:34.106636Z",
"iopub.status.idle": "2021-10-08T22:39:36.081903Z",
"shell.execute_reply": "2021-10-08T22:39:36.082277Z"
},
"id": "bpXjJKpSd3j4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting tensorflow-lattice\r\n",
" Using cached tensorflow_lattice-2.0.9-py2.py3-none-any.whl (235 kB)\r\n",
"Requirement already satisfied: pandas in /home/kbuilder/.local/lib/python3.7/site-packages (from tensorflow-lattice) (1.3.3)\r\n",
"Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.7/site-packages (from tensorflow-lattice) (1.15.0)\r\n",
"Requirement already satisfied: matplotlib in /home/kbuilder/.local/lib/python3.7/site-packages (from tensorflow-lattice) (3.4.3)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting graphviz\r\n",
" Using cached graphviz-0.17-py3-none-any.whl (18 kB)\r\n",
"Requirement already satisfied: absl-py in /home/kbuilder/.local/lib/python3.7/site-packages (from tensorflow-lattice) (0.12.0)\r\n",
"Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.7/site-packages (from tensorflow-lattice) (1.19.5)\r\n",
"Requirement already satisfied: scikit-learn in /home/kbuilder/.local/lib/python3.7/site-packages (from tensorflow-lattice) (1.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting dm-sonnet\r\n",
" Using cached dm_sonnet-2.0.0-py3-none-any.whl (254 kB)\r\n",
"Requirement already satisfied: wrapt>=1.11.1 in /tmpfs/src/tf_docs_env/lib/python3.7/site-packages (from dm-sonnet->tensorflow-lattice) (1.12.1)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting dm-tree>=0.1.1\r\n",
" Using cached dm_tree-0.1.6-cp37-cp37m-manylinux_2_24_x86_64.whl (93 kB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting tabulate>=0.7.5\r\n",
" Using cached tabulate-0.8.9-py3-none-any.whl (25 kB)\r\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (1.3.2)\r\n",
"Requirement already satisfied: python-dateutil>=2.7 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (2.8.2)\r\n",
"Requirement already satisfied: pillow>=6.2.0 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (8.3.2)\r\n",
"Requirement already satisfied: pyparsing>=2.2.1 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (2.4.7)\r\n",
"Requirement already satisfied: cycler>=0.10 in /home/kbuilder/.local/lib/python3.7/site-packages (from matplotlib->tensorflow-lattice) (0.10.0)\r\n",
"Requirement already satisfied: pytz>=2017.3 in /home/kbuilder/.local/lib/python3.7/site-packages (from pandas->tensorflow-lattice) (2021.3)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: threadpoolctl>=2.0.0 in /home/kbuilder/.local/lib/python3.7/site-packages (from scikit-learn->tensorflow-lattice) (3.0.0)\r\n",
"Requirement already satisfied: scipy>=1.1.0 in /home/kbuilder/.local/lib/python3.7/site-packages (from scikit-learn->tensorflow-lattice) (1.7.1)\r\n",
"Requirement already satisfied: joblib>=0.11 in /home/kbuilder/.local/lib/python3.7/site-packages (from scikit-learn->tensorflow-lattice) (1.1.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Installing collected packages: tabulate, dm-tree, graphviz, dm-sonnet, tensorflow-lattice\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully installed dm-sonnet-2.0.0 dm-tree-0.1.6 graphviz-0.17 tabulate-0.8.9 tensorflow-lattice-2.0.9\r\n"
]
}
],
"source": [
"#@test {\"skip\": true}\n",
"# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
"%pip install tensorflow-lattice"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jSVl9SHTeSGX"
},
"source": [
"导入所需的软件包:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"cellView": "both",
"execution": {
"iopub.execute_input": "2021-10-08T22:39:36.087864Z",
"iopub.status.busy": "2021-10-08T22:39:36.087299Z",
"iopub.status.idle": "2021-10-08T22:39:38.352148Z",
"shell.execute_reply": "2021-10-08T22:39:38.351540Z"
},
"id": "P9rMpg1-ASY3"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"import logging\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sys\n",
"import tensorflow_lattice as tfl\n",
"from tensorflow import feature_column as fc\n",
"\n",
"from tensorflow_estimator.python.estimator.canned import optimizers\n",
"from tensorflow_estimator.python.estimator.head import binary_class_head\n",
"# Disable Python `logging` records at every severity level (up to sys.maxsize).\n",
"logging.disable(sys.maxsize)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "svPuM6QNxlrH"
},
"source": [
"下载 UCI Statlog (Heart) 数据集:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"cellView": "both",
"execution": {
"iopub.execute_input": "2021-10-08T22:39:38.357907Z",
"iopub.status.busy": "2021-10-08T22:39:38.357341Z",
"iopub.status.idle": "2021-10-08T22:39:38.372608Z",
"shell.execute_reply": "2021-10-08T22:39:38.372994Z"
},
"id": "M0CmH1gPASZF"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" sex | \n",
" cp | \n",
" trestbps | \n",
" chol | \n",
" fbs | \n",
" restecg | \n",
" thalach | \n",
" exang | \n",
" oldpeak | \n",
" slope | \n",
" ca | \n",
" thal | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 63 | \n",
" 1 | \n",
" 1 | \n",
" 145 | \n",
" 233 | \n",
" 1 | \n",
" 2 | \n",
" 150 | \n",
" 0 | \n",
" 2.3 | \n",
" 3 | \n",
" 0 | \n",
" fixed | \n",
"
\n",
" \n",
" 1 | \n",
" 67 | \n",
" 1 | \n",
" 4 | \n",
" 160 | \n",
" 286 | \n",
" 0 | \n",
" 2 | \n",
" 108 | \n",
" 1 | \n",
" 1.5 | \n",
" 2 | \n",
" 3 | \n",
" normal | \n",
"
\n",
" \n",
" 2 | \n",
" 67 | \n",
" 1 | \n",
" 4 | \n",
" 120 | \n",
" 229 | \n",
" 0 | \n",
" 2 | \n",
" 129 | \n",
" 1 | \n",
" 2.6 | \n",
" 2 | \n",
" 2 | \n",
" reversible | \n",
"
\n",
" \n",
" 3 | \n",
" 37 | \n",
" 1 | \n",
" 3 | \n",
" 130 | \n",
" 250 | \n",
" 0 | \n",
" 0 | \n",
" 187 | \n",
" 0 | \n",
" 3.5 | \n",
" 3 | \n",
" 0 | \n",
" normal | \n",
"
\n",
" \n",
" 4 | \n",
" 41 | \n",
" 0 | \n",
" 2 | \n",
" 130 | \n",
" 204 | \n",
" 0 | \n",
" 2 | \n",
" 172 | \n",
" 0 | \n",
" 1.4 | \n",
" 1 | \n",
" 0 | \n",
" normal | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\\n",
"0 63 1 1 145 233 1 2 150 0 2.3 3 \n",
"1 67 1 4 160 286 0 2 108 1 1.5 2 \n",
"2 67 1 4 120 229 0 2 129 1 2.6 2 \n",
"3 37 1 3 130 250 0 0 187 0 3.5 3 \n",
"4 41 0 2 130 204 0 2 172 0 1.4 1 \n",
"\n",
" ca thal \n",
"0 0 fixed \n",
"1 3 normal \n",
"2 2 reversible \n",
"3 0 normal \n",
"4 0 normal "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Download over HTTPS rather than plain HTTP.\n",
"csv_file = tf.keras.utils.get_file(\n",
"    'heart.csv',\n",
"    'https://storage.googleapis.com/download.tensorflow.org/data/heart.csv')\n",
"df = pd.read_csv(csv_file)\n",
"# Remove the label column so `df` holds only the input features.\n",
"target = df.pop('target')\n",
"# Split by row order without shuffling: the first 80% of rows (rounded down)\n",
"# are used for training, the remaining rows for testing.\n",
"train_size = int(len(df) * 0.8)\n",
"train_x = df[:train_size]\n",
"train_y = target[:train_size]\n",
"test_x = df[train_size:]\n",
"test_y = target[train_size:]\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nKkAw12SxvGG"
},
"source": [
"设置本指南中用于训练的默认值:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"cellView": "both",
"execution": {
"iopub.execute_input": "2021-10-08T22:39:38.377145Z",
"iopub.status.busy": "2021-10-08T22:39:38.376554Z",
"iopub.status.idle": "2021-10-08T22:39:38.378495Z",
"shell.execute_reply": "2021-10-08T22:39:38.378067Z"
},
"id": "1T6GFI9F6mcG"
},
"outputs": [],
"source": [
"# Defaults used for training throughout this guide.\n",
"BATCH_SIZE = 128     # examples per batch in both input_fns\n",
"NUM_EPOCHS = 1000    # passes over the training data in train_input_fn\n",
"LEARNING_RATE = 0.1  # Adagrad learning rate used in model_fn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0TGfzhPHzpix"
},
"source": [
"## 特征列\n",
"\n",
"与任何其他 TF Estimator 一样,数据通常需要通过 input_fn 传递给 Estimator,并使用 [FeatureColumns](https://tensorflow.google.cn/guide/feature_columns) 进行解析。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2021-10-08T22:39:38.383125Z",
"iopub.status.busy": "2021-10-08T22:39:38.382520Z",
"iopub.status.idle": "2021-10-08T22:39:38.384707Z",
"shell.execute_reply": "2021-10-08T22:39:38.384168Z"
},
"id": "DCIUz8apzs0l"
},
"outputs": [],
"source": [
"# Feature columns.\n",
"# - age: numeric.\n",
"# - sex: categorical, vocabulary [0, 1].\n",
"# - ca: number of major vessels (0-3) colored by fluoroscopy.\n",
"# - thal: categorical string: 'normal', 'fixed' (fixed defect) or\n",
"#   'reversible' (reversible defect).\n",
"# The position of a value in its vocabulary list is the integer category id\n",
"# that the categorical calibrators in model_fn receive.\n",
"feature_columns = [\n",
"    fc.numeric_column('age', default_value=-1),\n",
"    fc.categorical_column_with_vocabulary_list('sex', [0, 1]),\n",
"    fc.numeric_column('ca'),\n",
"    fc.categorical_column_with_vocabulary_list(\n",
"        'thal', ['normal', 'fixed', 'reversible']),\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hEZstmtT2CA3"
},
"source": [
"请注意,分类特征不需要用密集特征列包装,因为 `tfl.layers.CategoricalCalibration` 层可以直接使用分类索引。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H_LoW_9m5OFL"
},
"source": [
"## 创建 input_fn\n",
"\n",
"与任何其他 Estimator 一样,您可以使用 input_fn 将数据馈送给模型进行训练和评估。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2021-10-08T22:39:38.389638Z",
"iopub.status.busy": "2021-10-08T22:39:38.389044Z",
"iopub.status.idle": "2021-10-08T22:39:38.645699Z",
"shell.execute_reply": "2021-10-08T22:39:38.645102Z"
},
"id": "lFVy1Efy5NKD"
},
"outputs": [],
"source": [
"def make_pandas_input_fn(x, y, shuffle, num_epochs):\n",
"  \"\"\"Returns a single-threaded pandas input_fn batched with BATCH_SIZE.\"\"\"\n",
"  return tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
"      x=x,\n",
"      y=y,\n",
"      shuffle=shuffle,\n",
"      batch_size=BATCH_SIZE,\n",
"      num_epochs=num_epochs,\n",
"      num_threads=1)\n",
"\n",
"\n",
"# Training data is shuffled and repeated for NUM_EPOCHS epochs.\n",
"train_input_fn = make_pandas_input_fn(\n",
"    train_x, train_y, shuffle=True, num_epochs=NUM_EPOCHS)\n",
"# Test data is read once, in order.\n",
"test_input_fn = make_pandas_input_fn(\n",
"    test_x, test_y, shuffle=False, num_epochs=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kbrgSr9KaRg0"
},
"source": [
"## 创建 model_fn\n",
"\n",
"您可以通过多种方式创建自定义 Estimator。在这里,我们将构造一个在已解析的输入张量上调用 Keras 模型的 `model_fn`。要解析输入特征,您可以使用 `tf.feature_column.input_layer`、`tf.keras.layers.DenseFeatures` 或 `tfl.estimators.transform_features`。如果使用后者,则不需要使用密集特征列包装分类特征,并且生成的张量不会串联,这样可以更轻松地在校准层中使用特征。\n",
"\n",
"要构造模型,您可以搭配使用 TFL 层或任何其他 Keras 层。在这里,我们从 TFL 层创建一个校准点阵 Keras 模型,并施加一些单调性约束。随后,我们使用 Keras 模型创建自定义 Estimator。\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2021-10-08T22:39:38.656023Z",
"iopub.status.busy": "2021-10-08T22:39:38.655425Z",
"iopub.status.idle": "2021-10-08T22:39:38.656901Z",
"shell.execute_reply": "2021-10-08T22:39:38.657225Z"
},
"id": "n2Zrv6OPaQO2"
},
"outputs": [],
"source": [
"def model_fn(features, labels, mode, config):\n",
"  \"\"\"model_fn for the custom estimator.\n",
"\n",
"  Parses `features` with `feature_columns`, feeds them through a calibrated\n",
"  lattice Keras model built from TFL layers, and uses the model output as\n",
"  the logits of a binary classification head.\n",
"\n",
"  Args:\n",
"    features: Dict of input features produced by the input_fn.\n",
"    labels: Labels produced by the input_fn, consumed by the binary head.\n",
"    mode: A `tf.estimator.ModeKeys` value; an optimizer is built only in\n",
"      TRAIN mode.\n",
"    config: Estimator run config; unused.\n",
"\n",
"  Returns:\n",
"    The `EstimatorSpec` returned by `BinaryClassHead.create_estimator_spec`.\n",
"  \"\"\"\n",
"  del config\n",
"  # Categorical features stay as integer ids and the per-feature tensors are\n",
"  # not concatenated, so each one can be routed to its own calibrator.\n",
"  input_tensors = tfl.estimators.transform_features(features, feature_columns)\n",
"  inputs = {\n",
"      key: tf.keras.layers.Input(shape=(1,), name=key) for key in input_tensors\n",
"  }\n",
"\n",
"  # Lattice dimension i is fed by the i-th calibrator in the Concatenate\n",
"  # below: 0 = age, 1 = sex, 2 = ca, 3 = thal. Each calibrator's output range\n",
"  # must be [0, lattice_sizes[i] - 1] for its own dimension i.\n",
"  lattice_sizes = [3, 2, 2, 2]\n",
"  lattice_monotonicities = ['increasing', 'none', 'increasing', 'increasing']\n",
"  lattice_input = tf.keras.layers.Concatenate(axis=1)([\n",
"      tfl.layers.PWLCalibration(\n",
"          input_keypoints=np.linspace(10, 100, num=8, dtype=np.float32),\n",
"          # The output range of the calibrator should be the input range of\n",
"          # the following lattice dimension.\n",
"          output_min=0.0,\n",
"          output_max=lattice_sizes[0] - 1.0,\n",
"          monotonicity='increasing',\n",
"      )(inputs['age']),\n",
"      tfl.layers.CategoricalCalibration(\n",
"          # Number of categories including any missing/default category.\n",
"          num_buckets=2,\n",
"          output_min=0.0,\n",
"          output_max=lattice_sizes[1] - 1.0,\n",
"      )(inputs['sex']),\n",
"      tfl.layers.PWLCalibration(\n",
"          input_keypoints=[0.0, 1.0, 2.0, 3.0],\n",
"          output_min=0.0,\n",
"          # 'ca' feeds lattice dimension 2.\n",
"          output_max=lattice_sizes[2] - 1.0,\n",
"          # You can specify TFL regularizers as tuple\n",
"          # ('regularizer name', l1, l2).\n",
"          kernel_regularizer=('hessian', 0.0, 1e-4),\n",
"          monotonicity='increasing',\n",
"      )(inputs['ca']),\n",
"      tfl.layers.CategoricalCalibration(\n",
"          num_buckets=3,\n",
"          output_min=0.0,\n",
"          # 'thal' feeds lattice dimension 3.\n",
"          output_max=lattice_sizes[3] - 1.0,\n",
"          # Categorical monotonicity can be partial order.\n",
"          # (i, j) indicates that we must have output(i) <= output(j).\n",
"          # Make sure to set the lattice monotonicity to 'increasing' for this\n",
"          # dimension.\n",
"          monotonicities=[(0, 1), (0, 2)],\n",
"      )(inputs['thal']),\n",
"  ])\n",
"  output = tfl.layers.Lattice(\n",
"      lattice_sizes=lattice_sizes, monotonicities=lattice_monotonicities)(\n",
"          lattice_input)\n",
"\n",
"  training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
"  model = tf.keras.Model(inputs=inputs, outputs=output)\n",
"  # The lattice output is used directly as the head's logits.\n",
"  logits = model(input_tensors, training=training)\n",
"\n",
"  # Only TRAIN mode gets an optimizer; every other mode passes None.\n",
"  if training:\n",
"    optimizer = optimizers.get_optimizer_instance_v2('Adagrad', LEARNING_RATE)\n",
"  else:\n",
"    optimizer = None\n",
"\n",
"  head = binary_class_head.BinaryClassHead()\n",
"  return head.create_estimator_spec(\n",
"      features=features,\n",
"      mode=mode,\n",
"      labels=labels,\n",
"      optimizer=optimizer,\n",
"      logits=logits,\n",
"      trainable_variables=model.trainable_variables,\n",
"      update_ops=model.updates)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mng-VtsSbVtQ"
},
"source": [
"## 训练 Estimator\n",
"\n",
"使用 `model_fn`,我们可以创建和训练 Estimator。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2021-10-08T22:39:38.661528Z",
"iopub.status.busy": "2021-10-08T22:39:38.660924Z",
"iopub.status.idle": "2021-10-08T22:40:04.554406Z",
"shell.execute_reply": "2021-10-08T22:40:04.554845Z"
},
"id": "j38GaEbKbZju"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC: 0.5701754689216614\n"
]
}
],
"source": [
"estimator = tf.estimator.Estimator(model_fn=model_fn)\n",
"# `train` returns the estimator itself, so evaluation on the test split is\n",
"# chained directly onto it.\n",
"results = estimator.train(input_fn=train_input_fn).evaluate(\n",
"    input_fn=test_input_fn)\n",
"print('AUC: {}'.format(results['auc']))"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "custom_estimators.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 0
}