{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "d6p8EySq1zXZ" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "KsOkK8O69PyT" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "F1xIRPtY0E1w" }, "source": [ "# Create an Estimator from a Keras model" ] }, { "cell_type": "markdown", "metadata": { "id": "r61fkA2i9Y3_" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Dhcq8Ds4mCtm" }, "source": [ "> Warning: TensorFlow 2.15 included the final release of the `tf-estimator` package. Estimators will not be available in TensorFlow 2.16 or later. See the [migration guide](https://tensorflow.org/guide/migrate/migrating_estimator) for more information about how to migrate away from Estimators." ] }, { "cell_type": "markdown", "metadata": { "id": "ZaGcclVLwqDS" }, "source": [ "## Overview\n", "\n", "TensorFlow Estimators are supported in TensorFlow 2.15 and earlier, and can be created from new and existing `tf.keras` models. This tutorial contains a complete, minimal example of that process.\n", "\n", "Note: If you have a Keras model, you can use it directly with [`tf.distribute` strategies](https://tensorflow.org/guide/migrate/guide/distributed_training) without converting it to an estimator. As such, `model_to_estimator` is no longer recommended." ] }, { "cell_type": "markdown", "metadata": { "id": "epgfaZmO2vF0" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qmq4FzaztASN" }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "import numpy as np\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "9ZUATGJGtQIU" }, "source": [ "### Create a simple Keras model." ] }, { "cell_type": "markdown", "metadata": { "id": "rR-zPidHyzcb" }, "source": [ "In Keras, you assemble *layers* to build *models*. A model is (usually) a graph\n", "of layers. The most common type of model is a stack of layers: the\n", "`tf.keras.Sequential` model.\n", "\n", "To build a simple, fully-connected network (i.e., a 
multi-layer perceptron):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p5NSx38itD1a" }, "outputs": [], "source": [ "model = tf.keras.models.Sequential([\n", "    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),\n", "    tf.keras.layers.Dropout(0.2),\n", "    tf.keras.layers.Dense(3)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "ABgo9-8BtYNs" }, "source": [ "Compile the model and get a summary." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nViACuBDtVEC" }, "outputs": [], "source": [ "model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "              optimizer='adam')\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "pM3Cx5Fm_sHI" }, "source": [ "### Create an input function\n", "\n", "Use the [Datasets API](../../guide/data.md) to scale to large datasets\n", "or multi-device training.\n", "\n", "Estimators need control of when and how their input pipeline is built. To allow this, they require an \"Input function\" or `input_fn`. The `Estimator` will call this function with no arguments. The `input_fn` must return a `tf.data.Dataset`."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "H0DpLEop_x0o" }, "outputs": [], "source": [ "def input_fn():\n", "  split = tfds.Split.TRAIN\n", "  dataset = tfds.load('iris', split=split, as_supervised=True)\n", "  # The feature key must match the Keras model's input name ('dense_input' here).\n", "  dataset = dataset.map(lambda features, labels: ({'dense_input': features}, labels))\n", "  dataset = dataset.batch(32).repeat()\n", "  return dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "UR1vRw1bBFjo" }, "source": [ "Test out your `input_fn`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WO94bGYKBKRv" }, "outputs": [], "source": [ "for features_batch, labels_batch in input_fn().take(1):\n", "  print(features_batch)\n", "  print(labels_batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "svdhkQ4Otcv0" }, "source": [ "### Create an Estimator from the tf.keras model.\n", "\n", "A `tf.keras.Model` can be trained with the `tf.estimator` API by converting the\n", "model to a `tf.estimator.Estimator` object with\n", "`tf.keras.estimator.model_to_estimator`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "roChngg8t7il" }, "outputs": [], "source": [ "import tempfile\n", "model_dir = tempfile.mkdtemp()\n", "keras_estimator = tf.keras.estimator.model_to_estimator(\n", "    keras_model=model, model_dir=model_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "U-8ekW5It_2w" }, "source": [ "Train and evaluate the estimator." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ouIkVtp9uAg5" }, "outputs": [], "source": [ "keras_estimator.train(input_fn=input_fn, steps=500)\n", "eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)\n", "print('Eval result: {}'.format(eval_result))" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "keras_model_to_estimator.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }