{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "l-23gBrt4x2B" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-17T02:27:01.875880Z", "iopub.status.busy": "2024-01-17T02:27:01.875646Z", "iopub.status.idle": "2024-01-17T02:27:01.879293Z", "shell.execute_reply": "2024-01-17T02:27:01.878760Z" }, "id": "HMUDt0CiUJk9" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# Migrate `tf.feature_column`s to Keras preprocessing layers\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "-5jGPDA2PDPI" }, "source": [ "Training a model usually comes with some amount of feature preprocessing, particularly when dealing with structured data. When training a `tf.estimator.Estimator` in TensorFlow 1, you usually perform feature preprocessing with the `tf.feature_column` API. In TensorFlow 2, you can do this directly with Keras preprocessing layers.\n", "\n", "This migration guide demonstrates common feature transformations using both feature columns and preprocessing layers, followed by training a complete model with both APIs.\n", "\n", "First, start with a couple of necessary imports:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:01.882605Z", "iopub.status.busy": "2024-01-17T02:27:01.882348Z", "iopub.status.idle": "2024-01-17T02:27:04.250967Z", "shell.execute_reply": "2024-01-17T02:27:04.250220Z" }, "id": "iE0vSfMXumKI" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-17 02:27:02.309609: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-17 02:27:02.309651: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-17 02:27:02.311142: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1\n", "import math" ] }, { "cell_type": "markdown", "metadata": { "id": "NVPYTQAWtDwH" }, "source": [ "Now, add a utility function for calling a feature column for demonstration:" ] }, { "cell_type": "code", 
"execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:04.255121Z", "iopub.status.busy": "2024-01-17T02:27:04.254689Z", "iopub.status.idle": "2024-01-17T02:27:04.258374Z", "shell.execute_reply": "2024-01-17T02:27:04.257802Z" }, "id": "LAaifuuytJjM" }, "outputs": [], "source": [ "def call_feature_columns(feature_columns, inputs):\n", " # This is a convenient way to call a `feature_column` outside of an estimator\n", " # to display its output.\n", " feature_layer = tf1.keras.layers.DenseFeatures(feature_columns)\n", " return feature_layer(inputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZJnw07hYDGYt" }, "source": [ "## Input handling\n", "\n", "To use feature columns with an estimator, model inputs are always expected to be a dictionary of tensors:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:04.261774Z", "iopub.status.busy": "2024-01-17T02:27:04.261156Z", "iopub.status.idle": "2024-01-17T02:27:06.525904Z", "shell.execute_reply": "2024-01-17T02:27:06.525047Z" }, "id": "y0WUpQxsKEzf" }, "outputs": [], "source": [ "input_dict = {\n", " 'foo': tf.constant([1]),\n", " 'bar': tf.constant([0]),\n", " 'baz': tf.constant([-1])\n", "}" ] }, { "cell_type": "markdown", "metadata": { "id": "xYsC6H_BJ8l3" }, "source": [ "Each feature column needs to be created with a key to index into the source data. The output of all feature columns is concatenated and used by the estimator model." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:06.529342Z", "iopub.status.busy": "2024-01-17T02:27:06.529072Z", "iopub.status.idle": "2024-01-17T02:27:06.575826Z", "shell.execute_reply": "2024-01-17T02:27:06.575225Z" }, "id": "3fvIe3V8Ffjt" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/3124623333.py:2: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "columns = [\n", " tf1.feature_column.numeric_column('foo'),\n", " tf1.feature_column.numeric_column('bar'),\n", " tf1.feature_column.numeric_column('baz'),\n", "]\n", "call_feature_columns(columns, input_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "hvPfCK2XGTyl" }, "source": [ "In Keras, model input is much more flexible. A `tf.keras.Model` can handle a single tensor input, a list of tensor features, or a dictionary of tensor features. You can handle dictionary input by passing a dictionary of `tf.keras.Input` on model creation. Inputs are not concatenated automatically, so each one can be transformed on its own before being combined. To concatenate them, use `tf.keras.layers.Concatenate`." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:06.578884Z", "iopub.status.busy": "2024-01-17T02:27:06.578661Z", "iopub.status.idle": "2024-01-17T02:27:06.613416Z", "shell.execute_reply": "2024-01-17T02:27:06.612779Z" }, "id": "5sYWENkgLWJ2" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = {\n", " 'foo': tf.keras.Input(shape=()),\n", " 'bar': tf.keras.Input(shape=()),\n", " 'baz': tf.keras.Input(shape=()),\n", "}\n", "# Inputs are typically transformed by preprocessing layers before concatenation.\n", "outputs = tf.keras.layers.Concatenate()(inputs.values())\n", "model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", "model(input_dict)" ] }, { "cell_type": "markdown", "metadata": { "id": "GXkmiuwXTS-B" }, "source": [ "## One-hot encoding integer IDs\n", "\n", "A common feature transformation is one-hot encoding integer inputs of a known range. Here is an example using feature columns:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:06.616575Z", "iopub.status.busy": "2024-01-17T02:27:06.616329Z", "iopub.status.idle": "2024-01-17T02:27:06.660374Z", "shell.execute_reply": "2024-01-17T02:27:06.659802Z" }, "id": "XasXzOgatgRF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/1369923821.py:1: categorical_column_with_identity (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. 
Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/1369923821.py:3: indicator_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "categorical_col = tf1.feature_column.categorical_column_with_identity(\n", " 'type', num_buckets=3)\n", "indicator_col = tf1.feature_column.indicator_column(categorical_col)\n", "call_feature_columns(indicator_col, {'type': [0, 1, 2]})" ] }, { "cell_type": "markdown", "metadata": { "id": "iSCkJEQ6U-ru" }, "source": [ "Using Keras preprocessing layers, these columns can be replaced by a single `tf.keras.layers.CategoryEncoding` layer with `output_mode` set to `'one_hot'`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:06.663879Z", "iopub.status.busy": "2024-01-17T02:27:06.663319Z", "iopub.status.idle": "2024-01-17T02:27:07.021599Z", "shell.execute_reply": "2024-01-17T02:27:07.020883Z" }, "id": "799lbMNNuAVz" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "one_hot_layer = tf.keras.layers.CategoryEncoding(\n", " num_tokens=3, output_mode='one_hot')\n", "one_hot_layer([0, 1, 2])" ] }, { "cell_type": "markdown", "metadata": { "id": "kNzRtESU7tga" }, "source": [ "Note: For large one-hot encodings, it is much more 
efficient to use a sparse representation of the output. If you pass `sparse=True` to the `CategoryEncoding` layer, the output of the layer will be a `tf.sparse.SparseTensor`, which can be efficiently handled as input to a `tf.keras.layers.Dense` layer." ] }, { "cell_type": "markdown", "metadata": { "id": "Zf7kjhTiAErK" }, "source": [ "## Normalizing numeric features\n", "\n", "When handling continuous, floating-point features with feature columns, you need to use a `tf.feature_column.numeric_column`. If the input is already normalized, converting this to Keras is trivial: pass a `tf.keras.Input` directly into your model, as shown above.\n", "\n", "A `numeric_column` can also be used to normalize input:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:07.025387Z", "iopub.status.busy": "2024-01-17T02:27:07.024809Z", "iopub.status.idle": "2024-01-17T02:27:07.037159Z", "shell.execute_reply": "2024-01-17T02:27:07.036552Z" }, "id": "HbTMGB9XctGx" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def normalize(x):\n", " mean, variance = (2.0, 1.0)\n", " return (x - mean) / math.sqrt(variance)\n", "numeric_col = tf1.feature_column.numeric_column('col', normalizer_fn=normalize)\n", "call_feature_columns(numeric_col, {'col': tf.constant([[0.], [1.], [2.]])})" ] }, { "cell_type": "markdown", "metadata": { "id": "M9cyhPR_drOz" }, "source": [ "In contrast, with Keras, this normalization can be done with `tf.keras.layers.Normalization`." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:07.040195Z", "iopub.status.busy": "2024-01-17T02:27:07.039953Z", "iopub.status.idle": "2024-01-17T02:27:07.462275Z", "shell.execute_reply": "2024-01-17T02:27:07.461557Z" }, "id": "8bcgG-yOdqUH" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "normalization_layer = tf.keras.layers.Normalization(mean=2.0, variance=1.0)\n", "normalization_layer(tf.constant([[0.], [1.], [2.]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "d1InD_4QLKU-" }, "source": [ "## Bucketizing and one-hot encoding numeric features" ] }, { "cell_type": "markdown", "metadata": { "id": "k5e0b8iOLRzd" }, "source": [ "Another common transformation of continuous, floating-point inputs is to bucketize them into integers of a fixed range.\n", "\n", "In feature columns, this can be achieved with a `tf.feature_column.bucketized_column`:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:07.465930Z", "iopub.status.busy": "2024-01-17T02:27:07.465668Z", "iopub.status.idle": "2024-01-17T02:27:07.478832Z", "shell.execute_reply": "2024-01-17T02:27:07.478262Z" }, "id": "_rbx6qQ-LQx7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/3043215186.py:2: bucketized_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. 
Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numeric_col = tf1.feature_column.numeric_column('col')\n", "bucketized_col = tf1.feature_column.bucketized_column(numeric_col, [1, 4, 5])\n", "call_feature_columns(bucketized_col, {'col': tf.constant([1., 2., 3., 4., 5.])})\n" ] }, { "cell_type": "markdown", "metadata": { "id": "PCYu-XtwXahx" }, "source": [ "In Keras, this can be replaced by `tf.keras.layers.Discretization`:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:07.482361Z", "iopub.status.busy": "2024-01-17T02:27:07.481781Z", "iopub.status.idle": "2024-01-17T02:27:08.152302Z", "shell.execute_reply": "2024-01-17T02:27:08.151545Z" }, "id": "QK1WOG2uVVsL" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discretization_layer = tf.keras.layers.Discretization(bin_boundaries=[1, 4, 5])\n", "one_hot_layer = tf.keras.layers.CategoryEncoding(\n", " num_tokens=4, output_mode='one_hot')\n", "one_hot_layer(discretization_layer([1., 2., 3., 4., 5.]))" ] }, { "cell_type": "markdown", "metadata": { "id": "5bm9tJZAgpt4" }, "source": [ "## One-hot encoding string data with a vocabulary\n", "\n", "Handling string features often requires a vocabulary lookup to translate strings into indices. 
Here is an example using feature columns to look up strings and then one-hot encode the indices:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.156109Z", "iopub.status.busy": "2024-01-17T02:27:08.155857Z", "iopub.status.idle": "2024-01-17T02:27:08.177359Z", "shell.execute_reply": "2024-01-17T02:27:08.176789Z" }, "id": "3fG_igjhukCO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/2845961037.py:1: categorical_column_with_vocabulary_list (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(\n", " 'sizes',\n", " vocabulary_list=['small', 'medium', 'large'],\n", " num_oov_buckets=0)\n", "indicator_col = tf1.feature_column.indicator_column(vocab_col)\n", "call_feature_columns(indicator_col, {'sizes': ['small', 'medium', 'large']})" ] }, { "cell_type": "markdown", "metadata": { "id": "8rBgllRtY738" }, "source": [ "With Keras preprocessing layers, use the `tf.keras.layers.StringLookup` layer with `output_mode` set to `'one_hot'`:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.180868Z", "iopub.status.busy": "2024-01-17T02:27:08.180371Z", "iopub.status.idle": "2024-01-17T02:27:08.198639Z", "shell.execute_reply": "2024-01-17T02:27:08.198021Z" }, "id": "arnPlSrWvDMe" }, "outputs": [ { "data": { "text/plain": [ "" ] }, 
"execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "string_lookup_layer = tf.keras.layers.StringLookup(\n", " vocabulary=['small', 'medium', 'large'],\n", " num_oov_indices=0,\n", " output_mode='one_hot')\n", "string_lookup_layer(['small', 'medium', 'large'])" ] }, { "cell_type": "markdown", "metadata": { "id": "f76MVVYO8LB5" }, "source": [ "Note: For large one-hot encodings, it is much more efficient to use a sparse representation of the output. If you pass `sparse=True` to the `StringLookup` layer, the output of the layer will be a `tf.sparse.SparseTensor`, which can be efficiently handled as input to a `tf.keras.layers.Dense` layer." ] }, { "cell_type": "markdown", "metadata": { "id": "c1CmfSXQZHE5" }, "source": [ "## Embedding string data with a vocabulary\n", "\n", "For larger vocabularies, an embedding is often needed for good performance. Here is an example embedding a string feature using feature columns:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.202553Z", "iopub.status.busy": "2024-01-17T02:27:08.201882Z", "iopub.status.idle": "2024-01-17T02:27:08.401915Z", "shell.execute_reply": "2024-01-17T02:27:08.401166Z" }, "id": "C3RK4HFazxlU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/999372599.py:5: embedding_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. 
Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(\n", " 'col',\n", " vocabulary_list=['small', 'medium', 'large'],\n", " num_oov_buckets=0)\n", "embedding_col = tf1.feature_column.embedding_column(vocab_col, 4)\n", "call_feature_columns(embedding_col, {'col': ['small', 'medium', 'large']})" ] }, { "cell_type": "markdown", "metadata": { "id": "3aTRVJ6qZZH0" }, "source": [ "Using Keras preprocessing layers, this can be achieved by combining a `tf.keras.layers.StringLookup` layer and a `tf.keras.layers.Embedding` layer. The default output of the `StringLookup` layer is integer indices, which can be fed directly into an embedding.\n", "\n", "Note: The `Embedding` layer contains trainable parameters. While the `StringLookup` layer can be applied to data inside or outside of a model, the `Embedding` must always be part of a trainable Keras model to train correctly." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.405821Z", "iopub.status.busy": "2024-01-17T02:27:08.405261Z", "iopub.status.idle": "2024-01-17T02:27:08.425895Z", "shell.execute_reply": "2024-01-17T02:27:08.425206Z" }, "id": "8resGZPo0Fho" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "string_lookup_layer = tf.keras.layers.StringLookup(\n", " vocabulary=['small', 'medium', 'large'], num_oov_indices=0)\n", "embedding = tf.keras.layers.Embedding(3, 4)\n", "embedding(string_lookup_layer(['small', 'medium', 'large']))" ] }, { "cell_type": "markdown", "metadata": { "id": "UwqvADV6HRdC" }, "source": [ "## Summing weighted categorical data\n", "\n", "In some cases, you need to deal with categorical data where each occurrence of a category comes with an associated weight. In feature columns, this is handled with `tf.feature_column.weighted_categorical_column`. When paired with an `indicator_column`, this has the effect of summing weights per category." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.429345Z", "iopub.status.busy": "2024-01-17T02:27:08.428814Z", "iopub.status.idle": "2024-01-17T02:27:08.494783Z", "shell.execute_reply": "2024-01-17T02:27:08.494167Z" }, "id": "02HqjPLMRxWn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/3529191023.py:6: weighted_categorical_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. 
Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4033: sparse_merge (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "No similar op available at this time.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ids = tf.constant([[5, 11, 5, 17, 17]])\n", "weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])\n", "\n", "categorical_col = tf1.feature_column.categorical_column_with_identity(\n", " 'ids', num_buckets=20)\n", "weighted_categorical_col = tf1.feature_column.weighted_categorical_column(\n", " categorical_col, 'weights')\n", "indicator_col = tf1.feature_column.indicator_column(weighted_categorical_col)\n", "call_feature_columns(indicator_col, {'ids': ids, 'weights': weights})" ] }, { "cell_type": "markdown", "metadata": { "id": "98jaq7Q3S9aG" }, "source": [ "In Keras, this can be done by passing a `count_weights` input to `tf.keras.layers.CategoryEncoding` with `output_mode='count'`." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.498304Z", "iopub.status.busy": "2024-01-17T02:27:08.497600Z", "iopub.status.idle": "2024-01-17T02:27:08.514747Z", "shell.execute_reply": "2024-01-17T02:27:08.514143Z" }, "id": "JsoYUUgRS7hu" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ids = tf.constant([[5, 11, 5, 17, 17]])\n", "weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])\n", "\n", "# Using sparse output is more efficient when `num_tokens` is large.\n", "count_layer = tf.keras.layers.CategoryEncoding(\n", " num_tokens=20, output_mode='count', sparse=True)\n", "tf.sparse.to_dense(count_layer(ids, count_weights=weights))" ] }, { "cell_type": "markdown", "metadata": { "id": "gBJxb6y2GasI" }, "source": [ "## Embedding weighted categorical data\n", "\n", "You might alternatively want to embed weighted categorical inputs. In feature columns, the `embedding_column` contains a `combiner` argument. If any sample\n", "contains multiple entries for a category, they will be combined according to the argument setting (by default `'mean'`)." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.517806Z", "iopub.status.busy": "2024-01-17T02:27:08.517576Z", "iopub.status.idle": "2024-01-17T02:27:08.589968Z", "shell.execute_reply": "2024-01-17T02:27:08.589387Z" }, "id": "AjOt1wgmT5mM" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ids = tf.constant([[5, 11, 5, 17, 17]])\n", "weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])\n", "\n", "categorical_col = tf1.feature_column.categorical_column_with_identity(\n", " 'ids', num_buckets=20)\n", "weighted_categorical_col = tf1.feature_column.weighted_categorical_column(\n", " categorical_col, 'weights')\n", "embedding_col = tf1.feature_column.embedding_column(\n", " weighted_categorical_col, 4, combiner='mean')\n", "call_feature_columns(embedding_col, {'ids': ids, 'weights': weights})" ] }, { "cell_type": "markdown", "metadata": { "id": "fd6eluARXndC" }, "source": [ "In Keras, there is no `combiner` option to `tf.keras.layers.Embedding`, but you can achieve the same effect with `tf.keras.layers.Dense`. The `embedding_column` above simply combines embedding vectors linearly according to category weight. Though not obvious at first, this is exactly equivalent to representing your categorical inputs as a sparse weight vector of size `(num_tokens)` and multiplying it by a `Dense` kernel of shape `(num_tokens, embedding_size)`." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.593373Z", "iopub.status.busy": "2024-01-17T02:27:08.592780Z", "iopub.status.idle": "2024-01-17T02:27:08.613326Z", "shell.execute_reply": "2024-01-17T02:27:08.612727Z" }, "id": "Y-vZvPyiYilE" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ids = tf.constant([[5, 11, 5, 17, 17]])\n", "weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])\n", "\n", "# For `combiner='mean'`, normalize your weights to sum to 1. Removing this line\n", "# would be equivalent to an `embedding_column` with `combiner='sum'`.\n", "weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)\n", "\n", "count_layer = tf.keras.layers.CategoryEncoding(\n", " num_tokens=20, output_mode='count', sparse=True)\n", "embedding_layer = tf.keras.layers.Dense(4, use_bias=False)\n", "embedding_layer(count_layer(ids, count_weights=weights))" ] }, { "cell_type": "markdown", "metadata": { "id": "3I5loEx80MVm" }, "source": [ "## Complete training example\n", "\n", "To show a complete training workflow, first prepare some data with three features of different types:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.616464Z", "iopub.status.busy": "2024-01-17T02:27:08.615930Z", "iopub.status.idle": "2024-01-17T02:27:08.619465Z", "shell.execute_reply": "2024-01-17T02:27:08.618920Z" }, "id": "D_7nyBee0ZBV" }, "outputs": [], "source": [ "features = {\n", " 'type': [0, 1, 1],\n", " 'size': ['small', 'small', 'medium'],\n", " 'weight': [2.7, 1.8, 1.6],\n", "}\n", "labels = [1, 1, 0]\n", "predict_features = {'type': [0], 'size': ['foo'], 'weight': [-0.7]}" ] }, { "cell_type": "markdown", "metadata": { "id": "e_4Xx2c37lqD" }, "source": [ "Define some common constants for both TensorFlow 1 and TensorFlow 2 workflows:" ] }, { 
"cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.622689Z", "iopub.status.busy": "2024-01-17T02:27:08.622081Z", "iopub.status.idle": "2024-01-17T02:27:08.625362Z", "shell.execute_reply": "2024-01-17T02:27:08.624768Z" }, "id": "3cyfQZ7z8jZh" }, "outputs": [], "source": [ "vocab = ['small', 'medium', 'large']\n", "one_hot_dims = 3\n", "embedding_dims = 4\n", "weight_mean = 2.0\n", "weight_variance = 1.0" ] }, { "cell_type": "markdown", "metadata": { "id": "ywCgU7CMIfTH" }, "source": [ "### With feature columns\n", "\n", "Feature columns must be passed as a list to the estimator on creation, and will be called implicitly during training." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:08.628510Z", "iopub.status.busy": "2024-01-17T02:27:08.628017Z", "iopub.status.idle": "2024-01-17T02:27:11.187434Z", "shell.execute_reply": "2024-01-17T02:27:11.186821Z" }, "id": "Wsdhlm-uipr1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_19805/1997355744.py:17: DNNClassifier.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:807: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from 
tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpm02fidwe\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpm02fidwe', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:446: dnn_logit_fn_builder (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", 
"Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Create CheckpointSaverHook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-01-17 02:27:09.322220: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. 
Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:\n", "type_id: TFT_OPTIONAL\n", "args {\n", " type_id: TFT_PRODUCT\n", " args {\n", " type_id: TFT_TENSOR\n", " args {\n", " type_id: TFT_INT64\n", " }\n", " }\n", "}\n", " is neither a subtype nor a supertype of the combined inputs preceding it:\n", "type_id: TFT_OPTIONAL\n", "args {\n", " type_id: TFT_PRODUCT\n", " args {\n", " type_id: TFT_TENSOR\n", " args {\n", " type_id: TFT_INT32\n", " }\n", " }\n", "}\n", "\n", "\tfor Tuple type infernce function 0\n", "\twhile inferring type of node 'dnn/zero_fraction/cond/output/_18'\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpm02fidwe/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: 
SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 0.6289518, step = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpm02fidwe/model.ckpt.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 0.81661654.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "categorical_col = tf1.feature_column.categorical_column_with_identity(\n", "    'type', num_buckets=one_hot_dims)\n", "# Convert index to one-hot; e.g., [2] -> [0,0,1].\n", "indicator_col = tf1.feature_column.indicator_column(categorical_col)\n", "\n", "# Convert strings to indices; e.g., ['small'] -> [0]. Vocabulary terms map to\n", "# 0..len(vocab)-1, and out-of-vocabulary strings map to the extra bucket.\n", "vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(\n", "    'size', vocabulary_list=vocab, num_oov_buckets=1)\n", "# Embed the indices.\n", "embedding_col = tf1.feature_column.embedding_column(vocab_col, embedding_dims)\n", "\n", "normalizer_fn = lambda x: (x - weight_mean) / math.sqrt(weight_variance)\n", "# Normalize the numeric inputs; e.g., [2.0] -> [0.0].\n", "numeric_col = tf1.feature_column.numeric_column(\n", "    'weight', normalizer_fn=normalizer_fn)\n", "\n", "estimator = tf1.estimator.DNNClassifier(\n", "    feature_columns=[indicator_col, embedding_col, numeric_col],\n", "    hidden_units=[1])\n", "\n", "def _input_fn():\n", "  return tf1.data.Dataset.from_tensor_slices((features, 
labels)).batch(1)\n", "\n", "estimator.train(_input_fn)" ] }, { "cell_type": "markdown", "metadata": { "id": "qPIeG_YtfNV1" }, "source": [ "The feature columns will also be used to transform input data when running inference on the model." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:11.190723Z", "iopub.status.busy": "2024-01-17T02:27:11.190426Z", "iopub.status.idle": "2024-01-17T02:27:12.489244Z", "shell.execute_reply": "2024-01-17T02:27:12.488471Z" }, "id": "K-AIIB8CfSqt" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:596: ClassificationOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:1307: RegressionOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:1309: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpm02fidwe/model.ckpt-3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "data": { "text/plain": [ "{'logits': array([0.57964283], dtype=float32),\n", " 'logistic': array([0.6409852], dtype=float32),\n", " 'probabilities': array([0.35901475, 0.6409852 ], dtype=float32),\n", " 'class_ids': array([1]),\n", " 'classes': array([b'1'], dtype=object),\n", " 'all_class_ids': array([0, 1], dtype=int32),\n", " 'all_classes': array([b'0', b'1'], dtype=object)}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def _predict_fn():\n", " return tf1.data.Dataset.from_tensor_slices(predict_features).batch(1)\n", "\n", "next(estimator.predict(_predict_fn))" ] }, { "cell_type": "markdown", "metadata": { "id": "baMA01cBIivo" }, "source": [ "### With Keras preprocessing layers\n", "\n", "Keras preprocessing layers are more flexible in where they can be called. A layer can be applied directly to tensors, used inside a `tf.data` input pipeline, or built directly into a trainable Keras model.\n", "\n", "In this example, you will apply preprocessing layers inside a `tf.data` input pipeline. To do this, you can define a separate `tf.keras.Model` to preprocess your input features. This model is not trainable, but is a convenient way to group preprocessing layers." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:12.493018Z", "iopub.status.busy": "2024-01-17T02:27:12.492459Z", "iopub.status.idle": "2024-01-17T02:27:12.535791Z", "shell.execute_reply": "2024-01-17T02:27:12.535150Z" }, "id": "NMz8RfMQdCZf" }, "outputs": [], "source": [ "inputs = {\n", " 'type': tf.keras.Input(shape=(), dtype='int64'),\n", " 'size': tf.keras.Input(shape=(), dtype='string'),\n", " 'weight': tf.keras.Input(shape=(), dtype='float32'),\n", "}\n", "# Convert index to one-hot; e.g., [2] -> [0,0,1].\n", "type_output = tf.keras.layers.CategoryEncoding(\n", " one_hot_dims, output_mode='one_hot')(inputs['type'])\n", "# Convert size strings to indices; e.g., ['small'] -> [1].\n", "size_output = tf.keras.layers.StringLookup(vocabulary=vocab)(inputs['size'])\n", "# Normalize the numeric inputs; e.g., [2.0] -> [0.0].\n", "weight_output = tf.keras.layers.Normalization(\n", " axis=None, mean=weight_mean, variance=weight_variance)(inputs['weight'])\n", "outputs = {\n", " 'type': type_output,\n", " 'size': size_output,\n", " 'weight': weight_output,\n", "}\n", "preprocessing_model = tf.keras.Model(inputs, outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "NRfISnj3NGlW" }, "source": [ "Note: As an alternative to supplying a vocabulary and normalization statistics on layer creation, many preprocessing layers provide an `adapt()` method for learning layer state directly from the input data. See the [preprocessing guide](https://www.tensorflow.org/guide/keras/preprocessing_layers#the_adapt_method) for more details.\n", "\n", "You can now apply this model inside a call to `tf.data.Dataset.map`. Please note that the function passed to `map` will automatically be converted into\n", "a `tf.function`, and usual caveats for writing `tf.function` code apply (no side effects)." 
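, "\n", "\n", "To illustrate the `adapt()` alternative from the note above, here is a minimal sketch that learns a vocabulary and normalization statistics from a small in-memory sample and then calls the layers directly on tensors. The `sizes` and `weights` values are made up for illustration and are not the features used elsewhere in this guide.

```python
import tensorflow as tf

# Toy sample data (hypothetical values, for illustration only).
sizes = tf.constant(['small', 'small', 'medium', 'large'])
weights = tf.constant([1.0, 2.0, 3.0, 4.0])

# Learn the vocabulary from the data instead of passing `vocabulary=...`.
lookup = tf.keras.layers.StringLookup()
lookup.adapt(sizes)

# Learn the mean and variance from the data instead of passing them in.
normalizer = tf.keras.layers.Normalization(axis=None)
normalizer.adapt(weights)

# Adapted layers can be called directly on tensors.
size_ids = lookup(sizes)
normalized = normalizer(weights)

# Index 0 of the learned vocabulary is the out-of-vocabulary token.
print(lookup.get_vocabulary())
print(size_ids.numpy())
print(normalized.numpy())
```

Adapted layers behave like layers constructed with explicit state, so they can be grouped into a preprocessing model in the same way. For data that does not fit in memory, `adapt()` also accepts a `tf.data.Dataset`."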
] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:12.539126Z", "iopub.status.busy": "2024-01-17T02:27:12.538887Z", "iopub.status.idle": "2024-01-17T02:27:12.635792Z", "shell.execute_reply": "2024-01-17T02:27:12.635132Z" }, "id": "c_6xAUnbNREh" }, "outputs": [ { "data": { "text/plain": [ "({'type': array([[1., 0., 0.]], dtype=float32),\n", " 'size': array([1]),\n", " 'weight': array([0.70000005], dtype=float32)},\n", " array([1], dtype=int32))" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Apply the preprocessing in tf.data.Dataset.map.\n", "dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n", "dataset = dataset.map(lambda x, y: (preprocessing_model(x), y),\n", " num_parallel_calls=tf.data.AUTOTUNE)\n", "# Display a preprocessed input sample.\n", "next(dataset.take(1).as_numpy_iterator())" ] }, { "cell_type": "markdown", "metadata": { "id": "8_4u3J4NdJ8R" }, "source": [ "Next, you can define a separate `Model` containing the trainable layers. Note how the inputs to this model now reflect the preprocessed feature types and shapes." 
] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:12.639420Z", "iopub.status.busy": "2024-01-17T02:27:12.638794Z", "iopub.status.idle": "2024-01-17T02:27:12.673045Z", "shell.execute_reply": "2024-01-17T02:27:12.672404Z" }, "id": "kC9OZO5ldmP-" }, "outputs": [], "source": [ "inputs = {\n", "  'type': tf.keras.Input(shape=(one_hot_dims,), dtype='float32'),\n", "  'size': tf.keras.Input(shape=(), dtype='int64'),\n", "  'weight': tf.keras.Input(shape=(), dtype='float32'),\n", "}\n", "# Since the embedding is trainable, it needs to be part of the training model.\n", "# `StringLookup` reserves index 0 for out-of-vocabulary values, so the\n", "# embedding table needs len(vocab) + 1 rows.\n", "embedding = tf.keras.layers.Embedding(len(vocab) + 1, embedding_dims)\n", "outputs = tf.keras.layers.Concatenate()([\n", "  inputs['type'],\n", "  embedding(inputs['size']),\n", "  tf.expand_dims(inputs['weight'], -1),\n", "])\n", "outputs = tf.keras.layers.Dense(1)(outputs)\n", "training_model = tf.keras.Model(inputs, outputs)" ] }, { "cell_type": "markdown", "metadata": { "id": "ir-cn2H_d5R7" }, "source": [ "You can now train the `training_model` with `tf.keras.Model.fit`." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:12.676567Z", "iopub.status.busy": "2024-01-17T02:27:12.676060Z", "iopub.status.idle": "2024-01-17T02:27:13.797432Z", "shell.execute_reply": "2024-01-17T02:27:13.796715Z" }, "id": "6TS3YJ2vnvlW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/3 [=========>....................] 
- ETA: 2s - loss: 1.0855" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "3/3 [==============================] - 1s 5ms/step - loss: 0.8194\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1705458433.603835 19973 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train on the preprocessed data.\n", "training_model.compile(\n", " loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))\n", "training_model.fit(dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "pSaEbOE4ecsy" }, "source": [ "Finally, at inference time, it can be useful to combine these separate stages into a single model that handles raw feature inputs." 
] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:13.801287Z", "iopub.status.busy": "2024-01-17T02:27:13.800670Z", "iopub.status.idle": "2024-01-17T02:27:13.958467Z", "shell.execute_reply": "2024-01-17T02:27:13.957809Z" }, "id": "QHjbIZYneboO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 102ms/step\n" ] }, { "data": { "text/plain": [ "array([[-0.95717776]], dtype=float32)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = preprocessing_model.input\n", "outputs = training_model(preprocessing_model(inputs))\n", "inference_model = tf.keras.Model(inputs, outputs)\n", "\n", "predict_dataset = tf.data.Dataset.from_tensor_slices(predict_features).batch(1)\n", "inference_model.predict(predict_dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "O01VQIxCWBxU" }, "source": [ "This composed model can be saved as a `.keras` file for later use." 
] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2024-01-17T02:27:13.962170Z", "iopub.status.busy": "2024-01-17T02:27:13.961591Z", "iopub.status.idle": "2024-01-17T02:27:14.233836Z", "shell.execute_reply": "2024-01-17T02:27:14.233163Z" }, "id": "6tsyVZgh7Pve" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 79ms/step\n" ] }, { "data": { "text/plain": [ "array([[-0.95717776]], dtype=float32)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inference_model.save('model.keras')\n", "restored_model = tf.keras.models.load_model('model.keras')\n", "restored_model.predict(predict_dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "IXMBwzggwUjI" }, "source": [ "Note: Preprocessing layers are not trainable, which allows you to apply them *asynchronously* using `tf.data`. This has performance benefits, as you can both prefetch preprocessed batches, and free up any accelerators to focus on the differentiable parts of a model (learn more in the _Prefetching_ section of the [Better performance with the `tf.data` API](../data_performance.ipynb) guide). As this guide shows, separating preprocessing during training and composing it during inference is a flexible way to leverage these performance gains. However, if your model is small or preprocessing time is negligible, it may be simpler to build preprocessing into a complete model from the start. To do this you can build a single model starting with `tf.keras.Input`, followed by preprocessing layers, followed by trainable layers." 
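, "\n", "\n", "The single-model alternative described in the note above might look like the following sketch. It reuses this guide's vocabulary and normalization constants, but declares the inputs with shape `(1,)` and adds a `Flatten` layer; those two choices, and the names `size_in`, `weight_in` and `end_to_end_model`, are assumptions made here to keep the example self-contained.

```python
import tensorflow as tf

vocab = ['small', 'medium', 'large']

# Raw feature inputs, one value per example.
size_in = tf.keras.Input(shape=(1,), dtype='string', name='size')
weight_in = tf.keras.Input(shape=(1,), dtype='float32', name='weight')

# Preprocessing layers come first.
lookup = tf.keras.layers.StringLookup(vocabulary=vocab)
size_ids = lookup(size_in)
weight_norm = tf.keras.layers.Normalization(
    axis=None, mean=2.0, variance=1.0)(weight_in)

# Trainable layers follow. The embedding needs one row per lookup index,
# including the out-of-vocabulary index, hence `vocabulary_size()`.
embedded = tf.keras.layers.Embedding(lookup.vocabulary_size(), 4)(size_ids)
embedded = tf.keras.layers.Flatten()(embedded)
x = tf.keras.layers.Concatenate()([embedded, weight_norm])
outputs = tf.keras.layers.Dense(1)(x)

end_to_end_model = tf.keras.Model(
    {'size': size_in, 'weight': weight_in}, outputs)

# The model accepts raw strings and floats directly.
preds = end_to_end_model({
    'size': tf.constant([['small'], ['large']]),
    'weight': tf.constant([[2.7], [1.6]]),
})
print(preds.shape)
```

With this layout the preprocessing runs as part of the model on every training step rather than asynchronously in `tf.data`, which is the trade-off the note describes."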
] }, { "cell_type": "markdown", "metadata": { "id": "2pjp7Z18gRCQ" }, "source": [ "## Feature column equivalence table\n", "\n", "For reference, here is an approximate correspondence between feature columns and\n", "Keras preprocessing layers:\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| Feature column | Keras layer |
| --- | --- |
| `tf.feature_column.bucketized_column` | `tf.keras.layers.Discretization` |
| `tf.feature_column.categorical_column_with_hash_bucket` | `tf.keras.layers.Hashing` |
| `tf.feature_column.categorical_column_with_identity` | `tf.keras.layers.CategoryEncoding` |
| `tf.feature_column.categorical_column_with_vocabulary_file` | `tf.keras.layers.StringLookup` or `tf.keras.layers.IntegerLookup` |
| `tf.feature_column.categorical_column_with_vocabulary_list` | `tf.keras.layers.StringLookup` or `tf.keras.layers.IntegerLookup` |
| `tf.feature_column.crossed_column` | `tf.keras.layers.HashedCrossing` |
| `tf.feature_column.embedding_column` | `tf.keras.layers.Embedding` |
| `tf.feature_column.indicator_column` | `output_mode='one_hot'` or `output_mode='multi_hot'`* |
| `tf.feature_column.numeric_column` | `tf.keras.layers.Normalization` |
| `tf.feature_column.sequence_categorical_column_with_hash_bucket` | `tf.keras.layers.Hashing` |
| `tf.feature_column.sequence_categorical_column_with_identity` | `tf.keras.layers.CategoryEncoding` |
| `tf.feature_column.sequence_categorical_column_with_vocabulary_file` | `tf.keras.layers.StringLookup`, `tf.keras.layers.IntegerLookup`, or `tf.keras.layers.TextVectorization`† |
| `tf.feature_column.sequence_categorical_column_with_vocabulary_list` | `tf.keras.layers.StringLookup`, `tf.keras.layers.IntegerLookup`, or `tf.keras.layers.TextVectorization`† |
| `tf.feature_column.sequence_numeric_column` | `tf.keras.layers.Normalization` |
| `tf.feature_column.weighted_categorical_column` | `tf.keras.layers.CategoryEncoding` |
\n", "\n", "\\* The `output_mode` can be passed to `tf.keras.layers.CategoryEncoding`, `tf.keras.layers.StringLookup`, `tf.keras.layers.IntegerLookup`, and `tf.keras.layers.TextVectorization`.\n", "\n", "† `tf.keras.layers.TextVectorization` can handle freeform text input directly (for example, entire sentences or paragraphs). This is not a one-to-one replacement for categorical sequence handling in TensorFlow 1, but may offer a convenient replacement for ad-hoc text preprocessing.\n", "\n", "Note: Linear estimators, such as `tf.estimator.LinearClassifier`, can handle direct categorical input (integer indices) without an `embedding_column` or `indicator_column`. However, integer indices cannot be passed directly to `tf.keras.layers.Dense` or `tf.keras.experimental.LinearModel`. These inputs should first be encoded with `tf.keras.layers.CategoryEncoding` using `output_mode='count'` (and `sparse=True` if the number of categories is large) before calling into `Dense` or `LinearModel`." ] }, { "cell_type": "markdown", "metadata": { "id": "AQCJ6lM3YDq_" }, "source": [ "## Next steps\n", "\n", " - For more information on Keras preprocessing layers, go to the [Working with preprocessing layers](https://www.tensorflow.org/guide/keras/preprocessing_layers) guide.\n", " - For a more in-depth example of applying preprocessing layers to structured data, refer to the [Classify structured data using Keras preprocessing layers](../../tutorials/structured_data/preprocessing_layers.ipynb) tutorial." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "migrating_feature_columns.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }