{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "DweYe9FcbMK_"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"execution": {
"iopub.execute_input": "2023-11-07T23:47:12.996675Z",
"iopub.status.busy": "2023-11-07T23:47:12.996423Z",
"iopub.status.idle": "2023-11-07T23:47:13.000668Z",
"shell.execute_reply": "2023-11-07T23:47:13.000059Z"
},
"id": "AVV2e0XKbJeX"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sUtoed20cRJJ"
},
"source": [
"# Load CSV data with tf.data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ap_W4aQcgNT"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C-3Xbt0FfGfs"
},
"source": [
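"This tutorial provides examples of how to use CSV data with TensorFlow.\n",
"\n",
"There are two main parts to this:\n",
"\n",
"1. **Loading the data off disk**\n",
"2. **Pre-processing it into a form suitable for training**\n",
"\n",
"This tutorial focuses on the loading and gives some quick examples of preprocessing. To learn more about the preprocessing aspect, check out the [Working with preprocessing layers](https://tensorflow.google.cn/guide/keras/preprocessing_layers) guide and the [Classify structured data using Keras preprocessing layers](../structured_data/preprocessing_layers.ipynb) tutorial.\n"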
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fgZ9gjmPfSnK"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:13.004907Z",
"iopub.status.busy": "2023-11-07T23:47:13.004295Z",
"iopub.status.idle": "2023-11-07T23:47:15.613213Z",
"shell.execute_reply": "2023-11-07T23:47:15.612330Z"
},
"id": "baYFZMW_bJHh"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-07 23:47:13.690444: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2023-11-07 23:47:13.690500: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2023-11-07 23:47:13.692293: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Make numpy values easier to read.\n",
"np.set_printoptions(precision=3, suppress=True)\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ZhJYbJxHNGJ"
},
"source": [
"## In-memory data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ny5TEgcmHjVx"
},
"source": [
"For any small CSV dataset, the simplest way to train a TensorFlow model on it is to load it into memory as a pandas DataFrame or a NumPy array.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LgpBOuU8PGFf"
},
"source": [
"A relatively simple example is the [abalone dataset](https://archive.ics.uci.edu/ml/datasets/abalone).\n",
"\n",
"- The dataset is small.\n",
"- All the input features are limited-range floating point values.\n",
"\n",
"Here is how to download the data into a [pandas `DataFrame`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html):"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:15.618052Z",
"iopub.status.busy": "2023-11-07T23:47:15.617599Z",
"iopub.status.idle": "2023-11-07T23:47:15.736323Z",
"shell.execute_reply": "2023-11-07T23:47:15.735548Z"
},
"id": "IZVExo9DKoNz"
},
"outputs": [
{
"data": {
"text/plain": [
" Length Diameter Height Whole weight Shucked weight Viscera weight \\\n",
"0 0.435 0.335 0.110 0.334 0.1355 0.0775 \n",
"1 0.585 0.450 0.125 0.874 0.3545 0.2075 \n",
"2 0.655 0.510 0.160 1.092 0.3960 0.2825 \n",
"3 0.545 0.425 0.125 0.768 0.2940 0.1495 \n",
"4 0.545 0.420 0.130 0.879 0.3740 0.1695 \n",
"\n",
" Shell weight Age \n",
"0 0.0965 7 \n",
"1 0.2250 6 \n",
"2 0.3700 14 \n",
"3 0.2600 16 \n",
"4 0.2300 13 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"abalone_train = pd.read_csv(\n",
" \"https://storage.googleapis.com/download.tensorflow.org/data/abalone_train.csv\",\n",
" names=[\"Length\", \"Diameter\", \"Height\", \"Whole weight\", \"Shucked weight\",\n",
" \"Viscera weight\", \"Shell weight\", \"Age\"])\n",
"\n",
"abalone_train.head()"
]
},
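{
"cell_type": "markdown",
"metadata": {},
"source": [
"The abalone CSV file has no header row, which is why the column names are passed through `names=`. The following is a minimal, self-contained sketch of that behavior. It assumes an in-memory `io.StringIO` buffer in place of the download URL, holding two rows copied from the `head()` output above:\n",
"\n",
"```python\n",
"import io\n",
"\n",
"import pandas as pd\n",
"\n",
"# Two rows copied from the head() output above; like the real file, no header row.\n",
"csv_text = '''0.435,0.335,0.110,0.334,0.1355,0.0775,0.0965,7\n",
"0.585,0.450,0.125,0.874,0.3545,0.2075,0.2250,6\n",
"'''\n",
"\n",
"column_names = ['Length', 'Diameter', 'Height', 'Whole weight',\n",
"                'Shucked weight', 'Viscera weight', 'Shell weight', 'Age']\n",
"\n",
"# With names=, every line is treated as data.\n",
"df = pd.read_csv(io.StringIO(csv_text), names=column_names)\n",
"\n",
"# Without names=, pandas uses the first data row as the header,\n",
"# silently losing one example.\n",
"without_names = pd.read_csv(io.StringIO(csv_text))\n",
"\n",
"print(df.shape)             # (2, 8)\n",
"print(without_names.shape)  # (1, 8)\n",
"```\n"
]
},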
{
"cell_type": "markdown",
"metadata": {
"id": "hP22mdyPQ1_t"
},
"source": [
"The dataset contains a set of measurements of [abalone](https://en.wikipedia.org/wiki/Abalone), a type of sea snail.\n",
"\n",
"[\"An abalone shell\"](https://www.flickr.com/photos/thenickster/16641048623/) (by [Nicki Dugan Pogue](https://www.flickr.com/photos/thenickster/), CC BY-SA 2.0)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vlfGrk_9N-wf"
},
"source": [
"The nominal task for this dataset is to predict the age from the other measurements, so separate the features and labels for training:\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:15.740480Z",
"iopub.status.busy": "2023-11-07T23:47:15.740131Z",
"iopub.status.idle": "2023-11-07T23:47:15.745382Z",
"shell.execute_reply": "2023-11-07T23:47:15.744661Z"
},
"id": "udOnDJOxNi7p"
},
"outputs": [],
"source": [
"abalone_features = abalone_train.copy()\n",
"abalone_labels = abalone_features.pop('Age')"
]
},
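{
"cell_type": "markdown",
"metadata": {},
"source": [
"`DataFrame.pop` removes a column from the DataFrame in place and returns it as a Series, which is why the code above calls `.copy()` first: `abalone_train` keeps its `Age` column. A small sketch, using a hypothetical two-row DataFrame `train` as a stand-in for `abalone_train`:\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical stand-in for abalone_train, with values from its first two rows.\n",
"train = pd.DataFrame({'Length': [0.435, 0.585], 'Age': [7, 6]})\n",
"\n",
"features = train.copy()\n",
"labels = features.pop('Age')  # drops 'Age' from features and returns it\n",
"\n",
"print(list(features.columns))  # ['Length']\n",
"print(list(train.columns))     # ['Length', 'Age'] (the original is untouched)\n",
"print(labels.tolist())         # [7, 6]\n",
"```\n"
]
},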
{
"cell_type": "markdown",
"metadata": {
"id": "seK9n71-UBfT"
},
"source": [
"For this dataset you will treat all features identically. Pack the features into a single NumPy array:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:15.749264Z",
"iopub.status.busy": "2023-11-07T23:47:15.748982Z",
"iopub.status.idle": "2023-11-07T23:47:15.754757Z",
"shell.execute_reply": "2023-11-07T23:47:15.754067Z"
},
"id": "Dp3N5McbUMwb"
},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.435, 0.335, 0.11 , ..., 0.136, 0.077, 0.097],\n",
" [0.585, 0.45 , 0.125, ..., 0.354, 0.207, 0.225],\n",
" [0.655, 0.51 , 0.16 , ..., 0.396, 0.282, 0.37 ],\n",
" ...,\n",
" [0.53 , 0.42 , 0.13 , ..., 0.374, 0.167, 0.249],\n",
" [0.395, 0.315, 0.105, ..., 0.118, 0.091, 0.119],\n",
" [0.45 , 0.355, 0.12 , ..., 0.115, 0.067, 0.16 ]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"abalone_features = np.array(abalone_features)\n",
"abalone_features"
]
},
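{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calling `np.array` on a DataFrame stacks its columns into one 2-D array with one row per example and one column per feature; `DataFrame.to_numpy()` is an equivalent, more explicit spelling. A minimal sketch with a hypothetical two-column DataFrame:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical two-feature DataFrame with values from the first rows above.\n",
"features = pd.DataFrame({'Length': [0.435, 0.585, 0.655],\n",
"                         'Diameter': [0.335, 0.450, 0.510]})\n",
"\n",
"packed = np.array(features)\n",
"explicit = features.to_numpy()\n",
"\n",
"print(packed.shape, packed.dtype)        # (3, 2) float64\n",
"print(np.array_equal(packed, explicit))  # True\n",
"```\n"
]
},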
{
"cell_type": "markdown",
"metadata": {
"id": "1C1yFOxLOdxh"
},
"source": [
"Next, make a regression model to predict the age. Since there is only a single input tensor, a `tf.keras.Sequential` model is enough here."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:15.758924Z",
"iopub.status.busy": "2023-11-07T23:47:15.758208Z",
"iopub.status.idle": "2023-11-07T23:47:18.086950Z",
"shell.execute_reply": "2023-11-07T23:47:18.085876Z"
},
"id": "d8zzNrZqOmfB"
},
"outputs": [],
"source": [
"abalone_model = tf.keras.Sequential([\n",
" layers.Dense(64),\n",
" layers.Dense(1)\n",
"])\n",
"\n",
"abalone_model.compile(loss=tf.keras.losses.MeanSquaredError(),\n",
"                      optimizer=tf.keras.optimizers.Adam())"
]
},
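{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `loss` that `Model.fit` reports for this model is the mean squared error between the labels and the predictions, shown as a running average over the current epoch. A hand-computed sketch of that quantity, using the first three ages from the data and hypothetical predictions:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"y_true = np.array([7.0, 6.0, 14.0])  # ages of the first three abalone above\n",
"y_pred = np.array([8.0, 6.0, 11.0])  # hypothetical model outputs\n",
"\n",
"# Mean of the squared differences: (1 + 0 + 9) / 3, about 3.333\n",
"mse = np.mean((y_true - y_pred) ** 2)\n",
"print(mse)\n",
"```\n"
]
},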
{
"cell_type": "markdown",
"metadata": {
"id": "j6IWeP78O2wE"
},
"source": [
"To train that model, pass the features and labels to `Model.fit`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:18.091625Z",
"iopub.status.busy": "2023-11-07T23:47:18.090868Z",
"iopub.status.idle": "2023-11-07T23:47:22.027212Z",
"shell.execute_reply": "2023-11-07T23:47:22.026485Z"
},
"id": "uZdpCD92SN3Z"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1699400839.261974 571141 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 2:21 - loss: 101.9437"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 22/104 [=====>........................] - ETA: 0s - loss: 95.2627 "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 44/104 [===========>..................] - ETA: 0s - loss: 89.1451"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 65/104 [=================>............] - ETA: 0s - loss: 80.6143"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 86/104 [=======================>......] - ETA: 0s - loss: 71.3353"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 2s 2ms/step - loss: 64.6157\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 29.5109"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 22/104 [=====>........................] - ETA: 0s - loss: 18.4533"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 42/104 [===========>..................] - ETA: 0s - loss: 15.2012"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 63/104 [=================>............] - ETA: 0s - loss: 13.3386"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 85/104 [=======================>......] - ETA: 0s - loss: 12.3196"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 11.5237\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 10.9910"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 24/104 [=====>........................] - ETA: 0s - loss: 8.3957 "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 47/104 [============>.................] - ETA: 0s - loss: 8.0881"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 70/104 [===================>..........] - ETA: 0s - loss: 8.3030"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 92/104 [=========================>....] - ETA: 0s - loss: 8.1226"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 8.1954\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 6.1413"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 24/104 [=====>........................] - ETA: 0s - loss: 8.2923"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 47/104 [============>.................] - ETA: 0s - loss: 8.5084"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 69/104 [==================>...........] - ETA: 0s - loss: 8.1870"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 92/104 [=========================>....] - ETA: 0s - loss: 7.6837"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 7.7570\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 10.4424"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 8.0771 "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 45/104 [===========>..................] - ETA: 0s - loss: 7.8526"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 68/104 [==================>...........] - ETA: 0s - loss: 7.6984"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 90/104 [========================>.....] - ETA: 0s - loss: 7.5845"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 7.3885\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 7.1712"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 7.1029"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 45/104 [===========>..................] - ETA: 0s - loss: 6.5226"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 66/104 [==================>...........] - ETA: 0s - loss: 6.7688"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 88/104 [========================>.....] - ETA: 0s - loss: 7.0163"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 7.0926\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 5.6098"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 6.0588"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 46/104 [============>.................] - ETA: 0s - loss: 6.4743"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 68/104 [==================>...........] - ETA: 0s - loss: 6.5212"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 91/104 [=========================>....] - ETA: 0s - loss: 6.8465"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 6.8470\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 5.5962"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 7.2752"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 46/104 [============>.................] - ETA: 0s - loss: 6.8060"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 68/104 [==================>...........] - ETA: 0s - loss: 6.8497"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 90/104 [========================>.....] - ETA: 0s - loss: 6.7589"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 6.7037\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 4.3912"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 24/104 [=====>........................] - ETA: 0s - loss: 6.6435"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 47/104 [============>.................] - ETA: 0s - loss: 6.6816"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 70/104 [===================>..........] - ETA: 0s - loss: 6.3140"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 92/104 [=========================>....] - ETA: 0s - loss: 6.5410"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 6.5731\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 5.8961"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 22/104 [=====>........................] - ETA: 0s - loss: 6.1178"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 44/104 [===========>..................] - ETA: 0s - loss: 6.2430"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 66/104 [==================>...........] - ETA: 0s - loss: 6.1084"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 88/104 [========================>.....] - ETA: 0s - loss: 6.3872"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 6.4726\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"abalone_model.fit(abalone_features, abalone_labels, epochs=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GapLOj1OOTQH"
},
"source": [
"You have just seen the most basic way to train a model using CSV data. Next, you will learn how to apply preprocessing to normalize numeric columns."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B87Rd1SOUv02"
},
"source": [
"## Basic preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yCrB2Jd-U0Vt"
},
"source": [
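"It is good practice to normalize the inputs to your model. The Keras preprocessing layers provide a convenient way to build this normalization into your model.\n",
"\n",
"The `tf.keras.layers.Normalization` layer precomputes the mean and variance of each column, and uses these to normalize the data.\n",
"\n",
"First, create the layer:"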
]
},
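{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, this is per-column standardization: subtract each column's mean and divide by its standard deviation. The following NumPy sketch shows that arithmetic only. It is not the Keras implementation, it uses the population variance (`ddof=0`), and it ignores edge cases such as zero-variance columns. The rows are the `Length` and `Diameter` values of the first three examples above:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Length and Diameter of the first three abalone above.\n",
"x = np.array([[0.435, 0.335],\n",
"              [0.585, 0.450],\n",
"              [0.655, 0.510]])\n",
"\n",
"# What adapt() conceptually computes: per-column statistics.\n",
"mean = x.mean(axis=0)\n",
"variance = x.var(axis=0)\n",
"\n",
"# What the layer conceptually applies at call time.\n",
"normalized = (x - mean) / np.sqrt(variance)\n",
"\n",
"print(normalized.mean(axis=0))  # close to [0. 0.]\n",
"print(normalized.std(axis=0))   # close to [1. 1.]\n",
"```\n"
]
},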
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:22.031010Z",
"iopub.status.busy": "2023-11-07T23:47:22.030725Z",
"iopub.status.idle": "2023-11-07T23:47:22.036180Z",
"shell.execute_reply": "2023-11-07T23:47:22.035416Z"
},
"id": "H2WQpDU5VRk7"
},
"outputs": [],
"source": [
"normalize = layers.Normalization()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hGgEZE-7Vpt6"
},
"source": [
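"Then, use the `Normalization.adapt()` method to adapt the normalization layer to your data.\n",
"\n",
"Note: Only use your training data with the `PreprocessingLayer.adapt` method. Do not use your validation or test data."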
]
},
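{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reason for this rule is that validation and test data should be transformed with statistics learned from the training data alone, just as genuinely unseen data would be. A NumPy sketch of that pattern, assuming a hypothetical random feature matrix split 80/20:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"data = rng.normal(loc=5.0, scale=2.0, size=(100, 3))  # hypothetical features\n",
"\n",
"train, val = data[:80], data[80:]\n",
"\n",
"# Statistics come from the training split only.\n",
"mean = train.mean(axis=0)\n",
"std = train.std(axis=0)\n",
"\n",
"# The same values are reused, unchanged, for the validation split.\n",
"train_norm = (train - mean) / std\n",
"val_norm = (val - mean) / std\n",
"\n",
"print(train_norm.mean(axis=0))  # close to zero by construction\n",
"print(val_norm.shape)           # (20, 3)\n",
"```\n",
"\n",
"The validation columns will generally not have exactly zero mean or unit variance; that is expected.\n"
]
},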
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:22.039905Z",
"iopub.status.busy": "2023-11-07T23:47:22.039207Z",
"iopub.status.idle": "2023-11-07T23:47:22.421960Z",
"shell.execute_reply": "2023-11-07T23:47:22.421141Z"
},
"id": "2WgOPIiOVpLg"
},
"outputs": [],
"source": [
"normalize.adapt(abalone_features)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rE6vh0byV7cE"
},
"source": [
"Next, use the normalization layer in your model:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:22.426294Z",
"iopub.status.busy": "2023-11-07T23:47:22.426017Z",
"iopub.status.idle": "2023-11-07T23:47:26.119049Z",
"shell.execute_reply": "2023-11-07T23:47:26.118251Z"
},
"id": "quPcZ9dTWA9A"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 1:52 - loss: 122.1690"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 22/104 [=====>........................] - ETA: 0s - loss: 104.8419 "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 43/104 [===========>..................] - ETA: 0s - loss: 102.0931"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 64/104 [=================>............] - ETA: 0s - loss: 98.3096 "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 86/104 [=======================>......] - ETA: 0s - loss: 94.3227"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 1s 2ms/step - loss: 92.2746\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 69.8699"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 70.6138"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 44/104 [===========>..................] - ETA: 0s - loss: 67.1762"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 64/104 [=================>............] - ETA: 0s - loss: 61.9383"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 83/104 [======================>.......] - ETA: 0s - loss: 57.5425"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - ETA: 0s - loss: 53.1503"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 3ms/step - loss: 53.1503\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 27.6175"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 22/104 [=====>........................] - ETA: 0s - loss: 27.3872"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 43/104 [===========>..................] - ETA: 0s - loss: 24.2987"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 65/104 [=================>............] - ETA: 0s - loss: 20.5882"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 87/104 [========================>.....] - ETA: 0s - loss: 18.1379"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 16.5479\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 7.6246"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 8.1219"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 45/104 [===========>..................] - ETA: 0s - loss: 7.1057"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 67/104 [==================>...........] - ETA: 0s - loss: 6.4601"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 89/104 [========================>.....] - ETA: 0s - loss: 6.0833"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 5.9912\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 1.5492"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 4.5862"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 45/104 [===========>..................] - ETA: 0s - loss: 4.5351"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 67/104 [==================>...........] - ETA: 0s - loss: 4.9657"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 89/104 [========================>.....] - ETA: 0s - loss: 5.1963"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 5.1468\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 5.9418"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 4.8608"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 46/104 [============>.................] - ETA: 0s - loss: 4.9782"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 68/104 [==================>...........] - ETA: 0s - loss: 4.8786"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 90/104 [========================>.....] - ETA: 0s - loss: 4.9959"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 5.0204\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 5.5926"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 5.3847"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 45/104 [===========>..................] - ETA: 0s - loss: 5.2681"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 66/104 [==================>...........] - ETA: 0s - loss: 5.1667"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 88/104 [========================>.....] - ETA: 0s - loss: 5.0252"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 4.9907\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 3.9144"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 24/104 [=====>........................] - ETA: 0s - loss: 5.1471"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 47/104 [============>.................] - ETA: 0s - loss: 5.2076"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 70/104 [===================>..........] - ETA: 0s - loss: 5.1551"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 93/104 [=========================>....] - ETA: 0s - loss: 5.1030"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 4.9757\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 3.4640"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 24/104 [=====>........................] - ETA: 0s - loss: 4.5150"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 47/104 [============>.................] - ETA: 0s - loss: 4.8335"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 69/104 [==================>...........] - ETA: 0s - loss: 4.9005"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 92/104 [=========================>....] - ETA: 0s - loss: 4.9880"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 4.9543\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/104 [..............................] - ETA: 0s - loss: 5.8363"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 23/104 [=====>........................] - ETA: 0s - loss: 5.1082"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 44/104 [===========>..................] - ETA: 0s - loss: 4.9832"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 66/104 [==================>...........] - ETA: 0s - loss: 4.7715"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
" 88/104 [========================>.....] - ETA: 0s - loss: 4.7601"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"104/104 [==============================] - 0s 2ms/step - loss: 4.9403\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"norm_abalone_model = tf.keras.Sequential([\n",
" normalize,\n",
" layers.Dense(64),\n",
" layers.Dense(1)\n",
"])\n",
"\n",
"norm_abalone_model.compile(loss = tf.keras.losses.MeanSquaredError(),\n",
" optimizer = tf.keras.optimizers.Adam())\n",
"\n",
"norm_abalone_model.fit(abalone_features, abalone_labels, epochs=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wuqj601Qw0Ml"
},
"source": [
"## Mixed data types\n",
"\n",
"The \"Titanic\" dataset contains information about the passengers on the Titanic. The nominal task on this dataset is to predict who survived.\n",
"\n",
"\n",
"\n",
"Image [from Wikimedia](https://commons.wikimedia.org/wiki/File:RMS_Titanic_3.jpg)\n",
"\n",
"The raw data can easily be loaded as a Pandas `DataFrame`, but is not immediately usable as input to a TensorFlow model.\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.122776Z",
"iopub.status.busy": "2023-11-07T23:47:26.122502Z",
"iopub.status.idle": "2023-11-07T23:47:26.194388Z",
"shell.execute_reply": "2023-11-07T23:47:26.193676Z"
},
"id": "GS-dBMpuYMnz"
},
"outputs": [
{
"data": {
"text/plain": [
" survived sex age n_siblings_spouses parch fare class deck \\\n",
"0 0 male 22.0 1 0 7.2500 Third unknown \n",
"1 1 female 38.0 1 0 71.2833 First C \n",
"2 1 female 26.0 0 0 7.9250 Third unknown \n",
"3 1 female 35.0 1 0 53.1000 First C \n",
"4 0 male 28.0 0 0 8.4583 Third unknown \n",
"\n",
" embark_town alone \n",
"0 Southampton n \n",
"1 Cherbourg n \n",
"2 Southampton y \n",
"3 Southampton n \n",
"4 Queenstown y "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"titanic = pd.read_csv(\"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")\n",
"titanic.head()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.197707Z",
"iopub.status.busy": "2023-11-07T23:47:26.197439Z",
"iopub.status.idle": "2023-11-07T23:47:26.201631Z",
"shell.execute_reply": "2023-11-07T23:47:26.200958Z"
},
"id": "D8rCGIK1ZzKx"
},
"outputs": [],
"source": [
"titanic_features = titanic.copy()\n",
"titanic_labels = titanic_features.pop('survived')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "urHOwpCDYtcI"
},
"source": [
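"Because the columns have different data types and ranges, you can't simply stack the features into a NumPy array and pass it to a `tf.keras.Sequential` model. Each column needs to be handled individually.\n",
"\n",
"One option is to preprocess your data offline (using any tool you like) to convert categorical columns to numeric columns, then pass the processed output to your TensorFlow model. The disadvantage of that approach is that if you save and export your model, the preprocessing is not saved with it. Keras preprocessing layers avoid this problem because they're part of the model.\n"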
]
},
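{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the offline approach described above, the following uses pandas `pd.get_dummies` to one-hot encode the non-numeric columns before any model sees the data. The small `df` is a hypothetical stand-in for `titanic_features`, not part of this tutorial's pipeline. Note that the resulting column mapping lives outside the model, so the same transformation would have to be reapplied to every future input:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical miniature stand-in for `titanic_features`\n",
"df = pd.DataFrame({\n",
"  \"sex\": [\"male\", \"female\", \"female\"],\n",
"  \"age\": [22.0, 38.0, 26.0],\n",
"  \"class\": [\"Third\", \"First\", \"Third\"],\n",
"})\n",
"\n",
"# Offline preprocessing: one-hot encode every non-numeric column with pandas\n",
"string_cols = [name for name in df.columns\n",
"               if not pd.api.types.is_numeric_dtype(df[name])]\n",
"encoded = pd.get_dummies(df, columns=string_cols, dtype=np.float32)\n",
"\n",
"# All columns are now numeric, so they can be stacked into a single array\n",
"features = encoded.to_numpy(dtype=np.float32)\n",
"print(list(encoded.columns))\n",
"print(features.shape)\n",
"```"
]
},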
{
"cell_type": "markdown",
"metadata": {
"id": "Bta4Sx0Zau5v"
},
"source": [
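"In this example, you'll build a model that implements the preprocessing logic using the [Keras functional API](https://tensorflow.google.cn/guide/keras/functional). You could also do it by [subclassing](https://tensorflow.google.cn/guide/keras/custom_layers_and_models).\n",
"\n",
"The functional API operates on \"symbolic\" tensors. Normal \"eager\" tensors have a value. In contrast, these \"symbolic\" tensors do not. Instead, they keep track of which operations are run on them and build a representation of the calculation that you can run later. Here's a quick example:"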
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.204829Z",
"iopub.status.busy": "2023-11-07T23:47:26.204579Z",
"iopub.status.idle": "2023-11-07T23:47:26.220657Z",
"shell.execute_reply": "2023-11-07T23:47:26.219964Z"
},
"id": "730F16_97D-3"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a symbolic input\n",
"input = tf.keras.Input(shape=(), dtype=tf.float32)\n",
"\n",
"# Perform a calculation using the input\n",
"result = 2*input + 1\n",
"\n",
"# the result doesn't have a value\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.224282Z",
"iopub.status.busy": "2023-11-07T23:47:26.223613Z",
"iopub.status.idle": "2023-11-07T23:47:26.230109Z",
"shell.execute_reply": "2023-11-07T23:47:26.229451Z"
},
"id": "RtcNXWB18kMJ"
},
"outputs": [],
"source": [
"calc = tf.keras.Model(inputs=input, outputs=result)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.233104Z",
"iopub.status.busy": "2023-11-07T23:47:26.232855Z",
"iopub.status.idle": "2023-11-07T23:47:26.241318Z",
"shell.execute_reply": "2023-11-07T23:47:26.240650Z"
},
"id": "fUGQOUqZ8sa-"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.0\n",
"5.0\n"
]
}
],
"source": [
"print(calc(1).numpy())\n",
"print(calc(2).numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rNS9lT7f6_U2"
},
"source": [
"To build the preprocessing model, start by building a set of symbolic `tf.keras.Input` objects that match the names and data types of the CSV columns."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.244510Z",
"iopub.status.busy": "2023-11-07T23:47:26.244251Z",
"iopub.status.idle": "2023-11-07T23:47:26.260126Z",
"shell.execute_reply": "2023-11-07T23:47:26.259433Z"
},
"id": "5WODe_1da3yw"
},
"outputs": [
{
"data": {
"text/plain": [
"{'sex': ,\n",
" 'age': ,\n",
" 'n_siblings_spouses': ,\n",
" 'parch': ,\n",
" 'fare': ,\n",
" 'class': ,\n",
" 'deck': ,\n",
" 'embark_town': ,\n",
" 'alone': }"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = {}\n",
"\n",
"for name, column in titanic_features.items():\n",
"  dtype = column.dtype\n",
"  if dtype == object:\n",
"    dtype = tf.string\n",
"  else:\n",
"    dtype = tf.float32\n",
"\n",
"  inputs[name] = tf.keras.Input(shape=(1,), name=name, dtype=dtype)\n",
"\n",
"inputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aaheJFmymq8l"
},
"source": [
"The first step in the preprocessing logic is to concatenate the numeric inputs together and run them through a normalization layer:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.263245Z",
"iopub.status.busy": "2023-11-07T23:47:26.262988Z",
"iopub.status.idle": "2023-11-07T23:47:26.542419Z",
"shell.execute_reply": "2023-11-07T23:47:26.541668Z"
},
"id": "wPRC_E6rkp8D"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numeric_inputs = {name:input for name,input in inputs.items()\n",
" if input.dtype==tf.float32}\n",
"\n",
"x = layers.Concatenate()(list(numeric_inputs.values()))\n",
"norm = layers.Normalization()\n",
"norm.adapt(np.array(titanic[numeric_inputs.keys()]))\n",
"all_numeric_inputs = norm(x)\n",
"\n",
"all_numeric_inputs"
]
},
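{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, the following NumPy-only sketch approximates what the `Normalization` layer does: `adapt` records each column's mean and variance, and calling the layer then computes `(x - mean) / sqrt(variance)` for each column. The small `data` array is hypothetical, and the `1e-7` floor on the standard deviation is an assumption meant to mirror the layer's guard against zero variance:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical numeric columns (for example age and fare), one row per passenger\n",
"data = np.array([[22.0, 7.25],\n",
"                 [38.0, 71.2833],\n",
"                 [26.0, 7.925],\n",
"                 [35.0, 53.1]], dtype=np.float32)\n",
"\n",
"# \"adapt\": compute per-column statistics from the data\n",
"mean = data.mean(axis=0)\n",
"var = data.var(axis=0)  # population variance\n",
"\n",
"# \"call\": standardize each column, flooring the std to avoid dividing by zero\n",
"normalized = (data - mean) / np.maximum(np.sqrt(var), 1e-7)\n",
"\n",
"print(normalized.mean(axis=0))  # close to 0 for every column\n",
"print(normalized.std(axis=0))   # close to 1 for every column\n",
"```"
]
},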
{
"cell_type": "markdown",
"metadata": {
"id": "-JoR45Uj712l"
},
"source": [
"Collect all the symbolic preprocessing results, to concatenate them later:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.546479Z",
"iopub.status.busy": "2023-11-07T23:47:26.545766Z",
"iopub.status.idle": "2023-11-07T23:47:26.549480Z",
"shell.execute_reply": "2023-11-07T23:47:26.548793Z"
},
"id": "M7jIJw5XntdN"
},
"outputs": [],
"source": [
"preprocessed_inputs = [all_numeric_inputs]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r0Hryylyosfm"
},
"source": [
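"For the string inputs, use the `tf.keras.layers.StringLookup` layer to map from strings to integer indices in a vocabulary. Next, use `tf.keras.layers.CategoryEncoding` to convert the indices into `float32` data appropriate for the model.\n",
"\n",
"With its default settings, the `tf.keras.layers.CategoryEncoding` layer produces a one-hot vector for each single-value input. A `tf.keras.layers.Embedding` would also work. Check out the [Working with preprocessing layers](https://tensorflow.google.cn/guide/keras/preprocessing_layers) guide and the [Classify structured data using Keras preprocessing layers](../structured_data/preprocessing_layers.ipynb) tutorial for more on this topic."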
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.552979Z",
"iopub.status.busy": "2023-11-07T23:47:26.552470Z",
"iopub.status.idle": "2023-11-07T23:47:26.686094Z",
"shell.execute_reply": "2023-11-07T23:47:26.685245Z"
},
"id": "79fi1Cgan2YV"
},
"outputs": [],
"source": [
"for name, input in inputs.items():\n",
"  if input.dtype == tf.float32:\n",
"    continue\n",
"\n",
"  lookup = layers.StringLookup(vocabulary=np.unique(titanic_features[name]))\n",
"  one_hot = layers.CategoryEncoding(num_tokens=lookup.vocabulary_size())\n",
"\n",
"  x = lookup(input)\n",
"  x = one_hot(x)\n",
"  preprocessed_inputs.append(x)"
]
},
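{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the loop above more concrete, the following NumPy-only sketch imitates one `StringLookup` plus `CategoryEncoding` pair: strings are mapped to integer indices, with index 0 reserved for out-of-vocabulary values (assuming `StringLookup`'s default of one OOV index and no mask token), and each index then becomes a one-hot `float32` row. The helper `lookup_one_hot` is hypothetical and is not the Keras implementation:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def lookup_one_hot(values, vocabulary):\n",
"  # Index 0 is the out-of-vocabulary (OOV) slot; known tokens start at 1\n",
"  index = {token: i + 1 for i, token in enumerate(vocabulary)}\n",
"  ids = np.array([index.get(value, 0) for value in values])\n",
"  # One column per vocabulary token, plus one for the OOV slot\n",
"  num_tokens = len(vocabulary) + 1\n",
"  return np.eye(num_tokens, dtype=np.float32)[ids]\n",
"\n",
"# np.unique returns the sorted, de-duplicated vocabulary: First, Second, Third\n",
"vocab = np.unique([\"Third\", \"First\", \"Second\", \"Third\"])\n",
"one_hot_rows = lookup_one_hot([\"Third\", \"First\", \"Crew\"], vocab)\n",
"print(one_hot_rows)\n",
"```\n",
"\n",
"The unseen value `\"Crew\"` lands in the OOV column, which is why the loop above sizes `CategoryEncoding` with `lookup.vocabulary_size()` rather than with the raw vocabulary length."
]
},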
{
"cell_type": "markdown",
"metadata": {
"id": "Wnhv0T7itnc7"
},
"source": [
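"Using the `inputs` dictionary and the `preprocessed_inputs` list, you can concatenate all the preprocessed inputs together and build a model that handles the preprocessing:"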
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.690729Z",
"iopub.status.busy": "2023-11-07T23:47:26.690042Z",
"iopub.status.idle": "2023-11-07T23:47:26.913294Z",
"shell.execute_reply": "2023-11-07T23:47:26.912362Z"
},
"id": "XJRzUTe8ukXc"
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessed_inputs_cat = layers.Concatenate()(preprocessed_inputs)\n",
"\n",
"titanic_preprocessing = tf.keras.Model(inputs, preprocessed_inputs_cat)\n",
"\n",
"tf.keras.utils.plot_model(model = titanic_preprocessing , rankdir=\"LR\", dpi=72, show_shapes=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PNHxrNW8vdda"
},
"source": [
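"This `model` contains only the input preprocessing. You can run it to see what it does to your data. Keras models don't automatically convert pandas `DataFrame` objects, because it's not clear whether they should be converted to one tensor or to a dictionary of tensors. So, convert the data to a dictionary of NumPy arrays, one per column:"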
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.917898Z",
"iopub.status.busy": "2023-11-07T23:47:26.917590Z",
"iopub.status.idle": "2023-11-07T23:47:26.922779Z",
"shell.execute_reply": "2023-11-07T23:47:26.921855Z"
},
"id": "5YjdYyMEacwQ"
},
"outputs": [],
"source": [
"titanic_features_dict = {name: np.array(value) \n",
" for name, value in titanic_features.items()}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nKJYoPByada"
},
"source": [
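"Slice out the first training example and pass it to this preprocessing model. You'll see the numeric features and the string one-hots all concatenated together:"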
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:26.926556Z",
"iopub.status.busy": "2023-11-07T23:47:26.925996Z",
"iopub.status.idle": "2023-11-07T23:47:28.035502Z",
"shell.execute_reply": "2023-11-07T23:47:28.034729Z"
},
"id": "SjnmU8PSv8T3"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"features_dict = {name:values[:1] for name, values in titanic_features_dict.items()}\n",
"titanic_preprocessing(features_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkBf4LvmzMDp"
},
"source": [
"Now build the model on top of this:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:28.039012Z",
"iopub.status.busy": "2023-11-07T23:47:28.038735Z",
"iopub.status.idle": "2023-11-07T23:47:28.188278Z",
"shell.execute_reply": "2023-11-07T23:47:28.187418Z"
},
"id": "coIPtGaCzUV7"
},
"outputs": [],
"source": [
"def titanic_model(preprocessing_head, inputs):\n",
"  body = tf.keras.Sequential([\n",
"    layers.Dense(64),\n",
"    layers.Dense(1)\n",
"  ])\n",
"\n",
"  preprocessed_inputs = preprocessing_head(inputs)\n",
"  result = body(preprocessed_inputs)\n",
"  model = tf.keras.Model(inputs, result)\n",
"\n",
"  model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
"                optimizer=tf.keras.optimizers.Adam())\n",
"  return model\n",
"\n",
"titanic_model = titanic_model(titanic_preprocessing, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LK5uBQQF2KbZ"
},
"source": [
"When training the model, pass the dictionary of features as `x` and the labels as `y`."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:28.192436Z",
"iopub.status.busy": "2023-11-07T23:47:28.192152Z",
"iopub.status.idle": "2023-11-07T23:47:30.741126Z",
"shell.execute_reply": "2023-11-07T23:47:30.740359Z"
},
"id": "D1gVfwJ61ejz"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 30s - loss: 0.6159"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"13/20 [==================>...........] - ETA: 0s - loss: 0.5816 "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 2s 4ms/step - loss: 0.5669\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.4976"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.5234"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.5039\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.5891"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"15/20 [=====================>........] - ETA: 0s - loss: 0.4731"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4748\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.4022"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4678"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4562\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.3903"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4497"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4436\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.3655"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4231"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4357\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.3896"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4154"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4293\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.4782"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4362"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4253\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.4934"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4070"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4244\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.3007"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4141"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4222\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"titanic_model.fit(x=titanic_features_dict, y=titanic_labels, epochs=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LxgJarZk3bfH"
},
"source": [
"Since the preprocessing is part of the model, you can save the model, reload it somewhere else, and get identical results:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:30.744806Z",
"iopub.status.busy": "2023-11-07T23:47:30.744520Z",
"iopub.status.idle": "2023-11-07T23:47:33.378169Z",
"shell.execute_reply": "2023-11-07T23:47:33.377099Z"
},
"id": "Ay-8ymNA2ZCh"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: test/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: test/assets\n"
]
}
],
"source": [
"titanic_model.save('test')\n",
"reloaded = tf.keras.models.load_model('test')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:33.383554Z",
"iopub.status.busy": "2023-11-07T23:47:33.382910Z",
"iopub.status.idle": "2023-11-07T23:47:33.446479Z",
"shell.execute_reply": "2023-11-07T23:47:33.445704Z"
},
"id": "Qm6jMTpD20lK"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([[-1.899]], shape=(1, 1), dtype=float32)\n",
"tf.Tensor([[-1.899]], shape=(1, 1), dtype=float32)\n"
]
}
],
"source": [
"features_dict = {name:values[:1] for name, values in titanic_features_dict.items()}\n",
"\n",
"before = titanic_model(features_dict)\n",
"after = reloaded(features_dict)\n",
"assert abs(before - after) < 1e-3\n",
"print(before)\n",
"print(after)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7VsPlxIRZpXf"
},
"source": [
"## Using tf.data\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NyVDCwGzR5HW"
},
"source": [
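"In the previous section, you relied on the model's built-in data shuffling and batching while training the model.\n",
"\n",
"If you need more control over the input data pipeline, or need to use data that doesn't easily fit into memory, use `tf.data`.\n",
"\n",
"For more examples, see the [`tf.data`: Build TensorFlow input pipelines](../../guide/data.ipynb) guide."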
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gP5Y1jM2Sor0"
},
"source": [
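"### On in-memory data\n",
"\n",
"As a first example of applying `tf.data` to CSV data, consider the following code to manually slice up the dictionary of features from the previous section. For each index, it takes that index from every feature:\n"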
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:33.451057Z",
"iopub.status.busy": "2023-11-07T23:47:33.450223Z",
"iopub.status.idle": "2023-11-07T23:47:33.455011Z",
"shell.execute_reply": "2023-11-07T23:47:33.454309Z"
},
"id": "i8wE-MVuVu7_"
},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"def slices(features):\n",
" for i in itertools.count():\n",
" # For each feature take index `i`\n",
" example = {name:values[i] for name, values in features.items()}\n",
" yield example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cQ3RTbS9YEal"
},
"source": [
"Run this and print the first example:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:33.458856Z",
"iopub.status.busy": "2023-11-07T23:47:33.458239Z",
"iopub.status.idle": "2023-11-07T23:47:33.462671Z",
"shell.execute_reply": "2023-11-07T23:47:33.461997Z"
},
"id": "Wwq8XK88WwFk"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sex : male\n",
"age : 22.0\n",
"n_siblings_spouses : 1\n",
"parch : 0\n",
"fare : 7.25\n",
"class : Third\n",
"deck : unknown\n",
"embark_town : Southampton\n",
"alone : n\n"
]
}
],
"source": [
"for example in slices(titanic_features_dict):\n",
" for name, value in example.items():\n",
" print(f\"{name:19s}: {value}\")\n",
" break"
]
},
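{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the feature columns are NumPy arrays, the `slices` generator above never stops on its own. `itertools.count()` is unbounded, so once `i` runs past the last row, indexing raises an `IndexError`. The loop above only works because it `break`s after the first example. The following is a minimal bounded variant. It assumes every column has the same length. The two toy columns are made up for illustration and are not the Titanic data:\n",
"\n",
"```python\n",
"def bounded_slices(features):\n",
"  # Assumes all columns have the same length; take it from the first column.\n",
"  n = len(next(iter(features.values())))\n",
"  for i in range(n):\n",
"    # For each feature take index `i`.\n",
"    yield {name: values[i] for name, values in features.items()}\n",
"\n",
"toy_features = {'sex': ['male', 'female'], 'age': [22.0, 38.0]}\n",
"toy_examples = list(bounded_slices(toy_features))\n",
"print(toy_examples)  # [{'sex': 'male', 'age': 22.0}, {'sex': 'female', 'age': 38.0}]\n",
"```"
]
},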
{
"cell_type": "markdown",
"metadata": {
"id": "vvp8Dct6YOIE"
},
"source": [
"The most basic `tf.data.Dataset` in-memory data loader is the `Dataset.from_tensor_slices` constructor. This returns a `tf.data.Dataset` that implements a generalized version of the `slices` function above, in TensorFlow."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:33.466443Z",
"iopub.status.busy": "2023-11-07T23:47:33.465787Z",
"iopub.status.idle": "2023-11-07T23:47:33.474260Z",
"shell.execute_reply": "2023-11-07T23:47:33.473556Z"
},
"id": "2gEJthslYxeV"
},
"outputs": [],
"source": [
"features_ds = tf.data.Dataset.from_tensor_slices(titanic_features_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-ZC0rTpMZMZK"
},
"source": [
"You can iterate over a `tf.data.Dataset` like any other Python iterable:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:33.478347Z",
"iopub.status.busy": "2023-11-07T23:47:33.477679Z",
"iopub.status.idle": "2023-11-07T23:47:33.492288Z",
"shell.execute_reply": "2023-11-07T23:47:33.491504Z"
},
"id": "gOHbiefaY4ag"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sex : b'male'\n",
"age : 22.0\n",
"n_siblings_spouses : 1\n",
"parch : 0\n",
"fare : 7.25\n",
"class : b'Third'\n",
"deck : b'unknown'\n",
"embark_town : b'Southampton'\n",
"alone : b'n'\n"
]
}
],
"source": [
"for example in features_ds:\n",
" for name, value in example.items():\n",
" print(f\"{name:19s}: {value}\")\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uwcFoVJWZY5F"
},
"source": [
"The `from_tensor_slices` function can handle any structure of nested dictionaries or tuples. The following code makes a dataset of `(features_dict, labels)` pairs:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:33.496289Z",
"iopub.status.busy": "2023-11-07T23:47:33.495524Z",
"iopub.status.idle": "2023-11-07T23:47:33.506626Z",
"shell.execute_reply": "2023-11-07T23:47:33.505927Z"
},
"id": "xIHGBy76Zcrx"
},
"outputs": [],
"source": [
"titanic_ds = tf.data.Dataset.from_tensor_slices((titanic_features_dict, titanic_labels))"
]
},
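{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is a pure-Python sketch of slicing such a structure one example at a time. It makes the phrase *any structure of nested dictionaries or tuples* concrete. It is a simplified model that assumes every leaf is a list of the same length. The real `from_tensor_slices` operates on tensors and requires all components to have the same size in their first dimension. The two-example `pair` is hypothetical:\n",
"\n",
"```python\n",
"def slice_structure(structure, i):\n",
"  # Recurse into dicts and tuples; index every leaf list at position `i`.\n",
"  if isinstance(structure, dict):\n",
"    return {k: slice_structure(v, i) for k, v in structure.items()}\n",
"  if isinstance(structure, tuple):\n",
"    return tuple(slice_structure(v, i) for v in structure)\n",
"  return structure[i]\n",
"\n",
"# Hypothetical (features_dict, labels) pair holding two examples.\n",
"pair = ({'sex': ['male', 'female'], 'fare': [7.25, 71.28]}, [0, 1])\n",
"pair_examples = [slice_structure(pair, i) for i in range(len(pair[1]))]\n",
"print(pair_examples[1])  # ({'sex': 'female', 'fare': 71.28}, 1)\n",
"```\n",
"\n",
"Each element is again a `(features_dict, label)` pair with the same nesting as the input. This matches the kind of element that `titanic_ds` yields."
]
},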
{
"cell_type": "markdown",
"metadata": {
"id": "gQwxitt8c2GK"
},
"source": [
"To train a model using this `Dataset`, you'll need to at least `shuffle` and `batch` the data."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:33.510919Z",
"iopub.status.busy": "2023-11-07T23:47:33.510235Z",
"iopub.status.idle": "2023-11-07T23:47:33.521210Z",
"shell.execute_reply": "2023-11-07T23:47:33.520540Z"
},
"id": "SbJcbldhddeC"
},
"outputs": [],
"source": [
"titanic_batches = titanic_ds.shuffle(len(titanic_labels)).batch(32)"
]
},
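{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, here is a rough pure-Python model of `Dataset.shuffle(buffer_size)` followed by `Dataset.batch(batch_size)`. It is not TensorFlow's implementation, and details such as seeding differ. The buffer fills up, then a random element is emitted each time a new one arrives. When `buffer_size` is at least the dataset size, as with `shuffle(len(titanic_labels))`, every element sits in the buffer at once, so the output is a full uniform shuffle. The 70 integers stand in for real examples:\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def buffer_shuffle(items, buffer_size, rng):\n",
"  # Fill a buffer; once it is full, emit a random element for each new arrival.\n",
"  buffer = []\n",
"  for item in items:\n",
"    buffer.append(item)\n",
"    if len(buffer) == buffer_size:\n",
"      yield buffer.pop(rng.randrange(len(buffer)))\n",
"  # At the end of the input, drain the buffer in random order.\n",
"  while buffer:\n",
"    yield buffer.pop(rng.randrange(len(buffer)))\n",
"\n",
"def batch_items(items, batch_size):\n",
"  # Group consecutive items; keep a short final batch (like drop_remainder=False).\n",
"  chunk = []\n",
"  for item in items:\n",
"    chunk.append(item)\n",
"    if len(chunk) == batch_size:\n",
"      yield chunk\n",
"      chunk = []\n",
"  if chunk:\n",
"    yield chunk\n",
"\n",
"rng = random.Random(0)\n",
"batches = list(batch_items(buffer_shuffle(range(70), 70, rng), 32))\n",
"print([len(b) for b in batches])  # [32, 32, 6]\n",
"```\n",
"\n",
"With the 627 Titanic training rows and `batch(32)`, the same arithmetic gives 19 full batches plus a final batch of 19. This matches the `20/20` steps in the `fit` output."
]
},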
{
"cell_type": "markdown",
"metadata": {
"id": "-4FRqhRFuoJx"
},
"source": [
"Instead of passing `features` and `labels` to `Model.fit`, you pass the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:33.524736Z",
"iopub.status.busy": "2023-11-07T23:47:33.524451Z",
"iopub.status.idle": "2023-11-07T23:47:34.393746Z",
"shell.execute_reply": "2023-11-07T23:47:34.393007Z"
},
"id": "8yXkNPumdBtB"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 7s - loss: 0.2701"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4223"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4209\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.2889"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"15/20 [=====================>........] - ETA: 0s - loss: 0.4129"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4204\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.4769"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"15/20 [=====================>........] - ETA: 0s - loss: 0.4060"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4205\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.5053"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4209"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4205\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1/20 [>.............................] - ETA: 0s - loss: 0.5134"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"14/20 [====================>.........] - ETA: 0s - loss: 0.4280"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"20/20 [==============================] - 0s 4ms/step - loss: 0.4193\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"titanic_model.fit(titanic_batches, epochs=5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qXuibiv9exT7"
},
"source": [
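"### From a single file\n",
"\n",
"So far, this tutorial has worked with in-memory data. `tf.data` is a highly scalable toolkit for building data pipelines, and provides a few functions for loading CSV files."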
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:34.397649Z",
"iopub.status.busy": "2023-11-07T23:47:34.396990Z",
"iopub.status.idle": "2023-11-07T23:47:34.445928Z",
"shell.execute_reply": "2023-11-07T23:47:34.445172Z"
},
"id": "Ncf5t6tgL5ZI"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 8192/30874 [======>.......................] - ETA: 0s"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"30874/30874 [==============================] - 0s 0us/step\n"
]
}
],
"source": [
"titanic_file_path = tf.keras.utils.get_file(\"train.csv\", \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t4N-plO4tDXd"
},
"source": [
"Now read the CSV data from the file and create a `tf.data.Dataset`.\n",
"\n",
"(For the full documentation, see `tf.data.experimental.make_csv_dataset`.)\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:34.449555Z",
"iopub.status.busy": "2023-11-07T23:47:34.449286Z",
"iopub.status.idle": "2023-11-07T23:47:34.514055Z",
"shell.execute_reply": "2023-11-07T23:47:34.513306Z"
},
"id": "yIbUscB9sqha"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/experimental/ops/readers.py:573: ignore_errors (from tensorflow.python.data.experimental.ops.error_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.data.Dataset.ignore_errors` instead.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/experimental/ops/readers.py:573: ignore_errors (from tensorflow.python.data.experimental.ops.error_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.data.Dataset.ignore_errors` instead.\n"
]
}
],
"source": [
"titanic_csv_ds = tf.data.experimental.make_csv_dataset(\n",
" titanic_file_path,\n",
" batch_size=5, # Artificially small to make examples easier to show.\n",
" label_name='survived',\n",
" num_epochs=1,\n",
" ignore_errors=True,)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sf3v3BKgy4AG"
},
"source": [
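"This function includes many convenient features that make the data easy to work with. These include:\n",
"\n",
"- Using the column headers as dictionary keys.\n",
"- Automatically determining the type of each column.\n",
"\n",
"Caution: Make sure to set the `num_epochs` argument in `tf.data.experimental.make_csv_dataset`. Otherwise, its default (`num_epochs=None`) makes the resulting `tf.data.Dataset` loop endlessly."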
]
},
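{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two conveniences listed above can be approximated with Python's standard `csv` module. `csv.DictReader` keys each row by the header line. A column's type can be guessed by trying `int`, then `float`, and falling back to `str`. This is a simplified sketch, not the actual rules of `make_csv_dataset`, which produces tensors and infers types from a sample of rows controlled by its `num_rows_for_inference` argument. The three CSV lines are a made-up excerpt in the Titanic format:\n",
"\n",
"```python\n",
"import csv\n",
"\n",
"def infer_type(values):\n",
"  # Pick the narrowest of int, float, str that parses every value in the column.\n",
"  for candidate in (int, float):\n",
"    try:\n",
"      for v in values:\n",
"        candidate(v)\n",
"      return candidate\n",
"    except ValueError:\n",
"      continue\n",
"  return str\n",
"\n",
"# A made-up excerpt in the Titanic format; the header row supplies the keys.\n",
"csv_lines = ['survived,sex,age,fare', '0,male,22.0,7.25', '1,female,38.0,71.2833']\n",
"rows = list(csv.DictReader(csv_lines))\n",
"columns = {name: [row[name] for row in rows] for name in rows[0]}\n",
"column_types = {name: infer_type(values) for name, values in columns.items()}\n",
"print(column_types)\n",
"```\n",
"\n",
"Here, `survived` comes out as `int`, `sex` as `str`, and `age` and `fare` as `float`."
]
},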
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:34.517891Z",
"iopub.status.busy": "2023-11-07T23:47:34.517604Z",
"iopub.status.idle": "2023-11-07T23:47:34.584520Z",
"shell.execute_reply": "2023-11-07T23:47:34.583676Z"
},
"id": "v4oMO9MIxgTG"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sex : [b'female' b'female' b'male' b'male' b'male']\n",
"age : [50. 31. 28. 28. 28.]\n",
"n_siblings_spouses : [0 0 1 0 0]\n",
"parch : [0 0 2 0 0]\n",
"fare : [28.712 7.854 23.45 56.496 7.896]\n",
"class : [b'First' b'Third' b'Third' b'Third' b'Third']\n",
"deck : [b'C' b'unknown' b'unknown' b'unknown' b'unknown']\n",
"embark_town : [b'Cherbourg' b'Southampton' b'Southampton' b'Southampton' b'Southampton']\n",
"alone : [b'y' b'y' b'n' b'y' b'y']\n",
"\n",
"label : [0 0 0 0 0]\n"
]
}
],
"source": [
"for batch, label in titanic_csv_ds.take(1):\n",
" for key, value in batch.items():\n",
" print(f\"{key:20s}: {value}\")\n",
" print()\n",
" print(f\"{'label':20s}: {label}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k-TgA6o2Ja6U"
},
"source": [
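"Note: If you run the cell above twice, it will produce different results. The default settings for `tf.data.experimental.make_csv_dataset` include `shuffle_buffer_size=1000`, which is more than sufficient for this small dataset but may not be for a real-world dataset."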
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6uviU_KCCWD"
},
"source": [
"It can also decompress the data on the fly. Here's a gzipped CSV file containing the [Metro Interstate Traffic Dataset](https://archive.ics.uci.edu/ml/datasets/Metro+Interstate+Traffic+Volume).\n",
"\n",
"\n",
"\n",
"Image [from Wikimedia](https://commons.wikimedia.org/wiki/File:Trafficjam.jpg)\n"
]
},
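{
"cell_type": "markdown",
"metadata": {},
"source": [
"Decompressing on the fly means the file is decompressed as it is read, instead of being extracted to disk first. The sketch below shows the same idea with Python's standard `gzip` and `csv` modules on a tiny hypothetical file. It is an illustration only, not how `tf.data` implements `compression_type`:\n",
"\n",
"```python\n",
"import csv\n",
"import gzip\n",
"import os\n",
"import tempfile\n",
"\n",
"# Write a tiny gzipped CSV with made-up rows to a temporary file.\n",
"path = os.path.join(tempfile.mkdtemp(), 'traffic_sample.csv.gz')\n",
"with gzip.open(path, 'wt', newline='') as f:\n",
"  writer = csv.writer(f)\n",
"  writer.writerow(['holiday', 'temp', 'traffic_volume'])\n",
"  writer.writerow(['None', '288.28', '5545'])\n",
"  writer.writerow(['None', '289.36', '4516'])\n",
"\n",
"# Stream it back: rows are decompressed and parsed one at a time.\n",
"with gzip.open(path, 'rt', newline='') as f:\n",
"  volumes = [int(row['traffic_volume']) for row in csv.DictReader(f)]\n",
"print(volumes)  # [5545, 4516]\n",
"```"
]
},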
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:34.588534Z",
"iopub.status.busy": "2023-11-07T23:47:34.588236Z",
"iopub.status.idle": "2023-11-07T23:47:35.121197Z",
"shell.execute_reply": "2023-11-07T23:47:35.120348Z"
},
"id": "kT7oZI2E46Q8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://archive.ics.uci.edu/ml/machine-learning-databases/00492/Metro_Interstate_Traffic_Volume.csv.gz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 8192/Unknown - 0s 0us/step"
]
}
],
"source": [
"traffic_volume_csv_gz = tf.keras.utils.get_file(\n",
" 'Metro_Interstate_Traffic_Volume.csv.gz', \n",
" \"https://archive.ics.uci.edu/ml/machine-learning-databases/00492/Metro_Interstate_Traffic_Volume.csv.gz\",\n",
" cache_dir='.', cache_subdir='traffic')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F-IOsFHbCw0i"
},
"source": [
"Set the `compression_type` argument to read directly from the compressed file:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:35.125225Z",
"iopub.status.busy": "2023-11-07T23:47:35.124948Z",
"iopub.status.idle": "2023-11-07T23:47:35.389707Z",
"shell.execute_reply": "2023-11-07T23:47:35.388889Z"
},
"id": "ar0MPEVJ5NeA"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"holiday : [b'None' b'None' b'None' b'None' b'None']\n",
"temp : [256.34 296.9 268.41 289.85 266.65]\n",
"rain_1h : [0. 0. 0. 0. 0.]\n",
"snow_1h : [0. 0. 0. 0. 0.]\n",
"clouds_all : [90 80 1 0 90]\n",
"weather_main : [b'Clouds' b'Mist' b'Clear' b'Clear' b'Clouds']\n",
"weather_description : [b'overcast clouds' b'mist' b'sky is clear' b'Sky is Clear'\n",
" b'overcast clouds']\n",
"date_time : [b'2013-01-14 07:00:00' b'2013-08-30 08:00:00' b'2013-02-09 17:00:00'\n",
" b'2013-09-12 09:00:00' b'2012-12-17 19:00:00']\n",
"\n",
"label : [6579 6042 4374 5687 2953]\n"
]
}
],
"source": [
"traffic_volume_csv_gz_ds = tf.data.experimental.make_csv_dataset(\n",
" traffic_volume_csv_gz,\n",
" batch_size=256,\n",
" label_name='traffic_volume',\n",
" num_epochs=1,\n",
" compression_type=\"GZIP\")\n",
"\n",
"for batch, label in traffic_volume_csv_gz_ds.take(1):\n",
" for key, value in batch.items():\n",
" print(f\"{key:20s}: {value[:5]}\")\n",
" print()\n",
" print(f\"{'label':20s}: {label[:5]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p12Y6tGq8D6M"
},
"source": [
"Note: If you need to parse those date-time strings in the `tf.data` pipeline, you can use `tfa.text.parse_time`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EtrAXzYGP3l0"
},
"source": [
"### Caching"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fN2dL_LRP83r"
},
"source": [
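"There is some overhead to parsing the CSV data. For small models, this can be the bottleneck in training.\n",
"\n",
"Depending on your use case, it may be a good idea to use `Dataset.cache` or `tf.data.Dataset.snapshot`, so that the CSV data is only parsed on the first epoch.\n",
"\n",
"The main difference between the `cache` and `snapshot` methods is that `cache` files can only be used by the TensorFlow process that created them, while `snapshot` files can be read by other processes.\n",
"\n",
"For example, iterating over `traffic_volume_csv_gz_ds` 20 times may take around 15 seconds without caching, or around 2 seconds with caching."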
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:35.393800Z",
"iopub.status.busy": "2023-11-07T23:47:35.393521Z",
"iopub.status.idle": "2023-11-07T23:47:46.711854Z",
"shell.execute_reply": "2023-11-07T23:47:46.711001Z"
},
"id": "Qk38Sw4MO4eh"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".\n",
"CPU times: user 15 s, sys: 2.52 s, total: 17.5 s\n",
"Wall time: 11.3 s\n"
]
}
],
"source": [
"%%time\n",
"for i, (batch, label) in enumerate(traffic_volume_csv_gz_ds.repeat(20)):\n",
" if i % 40 == 0:\n",
" print('.', end='')\n",
"print()"
]
},
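{
"cell_type": "markdown",
"metadata": {},
"source": [
"The benefit of caching can be sketched in plain Python. A wrapper runs the expensive parsing step during the first complete pass, stores the parsed rows, and replays them on later passes. This is a simplified in-memory model of `Dataset.cache()` called without a filename. The `CachedIterable` class, the `parse` function, and the call counter are hypothetical:\n",
"\n",
"```python\n",
"class CachedIterable:\n",
"  # Minimal in-memory model of caching: parse on the first pass, replay afterwards.\n",
"  def __init__(self, source, parse):\n",
"    self.source = source\n",
"    self.parse = parse\n",
"    self.cache = None\n",
"\n",
"  def __iter__(self):\n",
"    if self.cache is not None:\n",
"      yield from self.cache\n",
"      return\n",
"    results = []\n",
"    for line in self.source:\n",
"      item = self.parse(line)\n",
"      results.append(item)\n",
"      yield item\n",
"    # Only store the cache once a full pass has completed.\n",
"    self.cache = results\n",
"\n",
"parse_calls = 0\n",
"\n",
"def parse(line):\n",
"  global parse_calls\n",
"  parse_calls += 1\n",
"  return [float(x) for x in line.split(',')]\n",
"\n",
"cached = CachedIterable(['1,2', '3,4', '5,6'], parse)\n",
"for epoch in range(20):\n",
"  for row in cached:\n",
"    pass\n",
"print(parse_calls)  # 3: each line is parsed only during the first epoch\n",
"```\n",
"\n",
"Only the first of the 20 epochs pays the parsing cost. This is the effect behind the drop in wall time between the uncached and cached runs."
]
},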
{
"cell_type": "markdown",
"metadata": {
"id": "pN3HtDONh5TX"
},
"source": [
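"Note: `Dataset.cache` stores the data from the first epoch and replays it in order. So using the `cache` method disables any shuffling earlier in the pipeline. Below, `Dataset.shuffle` is added back in after `Dataset.cache`."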
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:46.715367Z",
"iopub.status.busy": "2023-11-07T23:47:46.715074Z",
"iopub.status.idle": "2023-11-07T23:47:48.589272Z",
"shell.execute_reply": "2023-11-07T23:47:48.588498Z"
},
"id": "r5Jj72MrPbnh"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"................"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"................"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"................."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"................"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"................"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"..............\n",
"CPU times: user 1.9 s, sys: 285 ms, total: 2.19 s\n",
"Wall time: 1.87 s\n"
]
}
],
"source": [
"%%time\n",
"caching = traffic_volume_csv_gz_ds.cache().shuffle(1000)\n",
"\n",
"for i, (batch, label) in enumerate(caching.repeat(20)):\n",
" if i % 40 == 0:\n",
" print('.', end='')\n",
"print()"
]
},
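{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following plain-Python sketch shows why the shuffle has to come after the cache. It uses a simplified model in which the cache records the first epoch verbatim. A shuffle placed *before* the cache is frozen into that recording. A shuffle placed *after* it produces a new order on every epoch: by default, `Dataset.shuffle` reshuffles on each iteration (see its `reshuffle_each_iteration` argument). The ten integers are placeholder data:\n",
"\n",
"```python\n",
"import random\n",
"\n",
"rng = random.Random(42)\n",
"data = list(range(10))\n",
"\n",
"def shuffled(items):\n",
"  # Return a newly shuffled copy of `items`.\n",
"  items = list(items)\n",
"  rng.shuffle(items)\n",
"  return items\n",
"\n",
"# shuffle -> cache: the cache records the first (shuffled) epoch and replays it.\n",
"cache = None\n",
"shuffle_then_cache = []\n",
"for epoch in range(3):\n",
"  if cache is None:\n",
"    cache = shuffled(data)\n",
"  shuffle_then_cache.append(list(cache))\n",
"\n",
"# cache -> shuffle: the cached data is reshuffled on every epoch.\n",
"cache_then_shuffle = [shuffled(data) for epoch in range(3)]\n",
"\n",
"print(shuffle_then_cache[0] == shuffle_then_cache[1] == shuffle_then_cache[2])  # True\n",
"```"
]
},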
{
"cell_type": "markdown",
"metadata": {
"id": "wN7uUBjmgNZ9"
},
"source": [
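"Note: `tf.data.Dataset.snapshot` files are meant for *temporary* storage of a dataset while in use. This is *not* a format for long-term storage. The file format is considered an internal detail and is not guaranteed to be compatible between TensorFlow versions."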
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:48.592913Z",
"iopub.status.busy": "2023-11-07T23:47:48.592618Z",
"iopub.status.idle": "2023-11-07T23:47:50.715544Z",
"shell.execute_reply": "2023-11-07T23:47:50.714777Z"
},
"id": "PHGD1E8ktUvW"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"............."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"............."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"................"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"................"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"................"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"................"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
".....\n",
"CPU times: user 2.68 s, sys: 738 ms, total: 3.42 s\n",
"Wall time: 2.12 s\n"
]
}
],
"source": [
"%%time\n",
"snapshotting = traffic_volume_csv_gz_ds.snapshot('traffic.tfsnap').shuffle(1000)\n",
"\n",
"for i, (batch, label) in enumerate(snapshotting.repeat(20)):\n",
" if i % 40 == 0:\n",
" print('.', end='')\n",
"print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fUSSegnMCGRz"
},
"source": [
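"If your data loading is slowed by loading CSV files, and `Dataset.cache` and `tf.data.Dataset.snapshot` are insufficient for your use case, consider re-encoding your data into a more streamlined format."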
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M0iGXv9pC5kr"
},
"source": [
"### Multiple files"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9FFzHQrCDH4w"
},
"source": [
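"So far, all the examples in this section could easily be done without `tf.data`. One place where `tf.data` can really simplify things is when dealing with collections of files.\n",
"\n",
"For example, the [Character Font Images](https://archive.ics.uci.edu/ml/datasets/Character+Font+Images) dataset is distributed as a collection of CSV files, one per font.\n",
"\n",
"\n",
"\n",
"Image by Willi Heidelbach from Pixabay\n",
"\n",
"Download the dataset and review the files inside:"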
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:50.719800Z",
"iopub.status.busy": "2023-11-07T23:47:50.719157Z",
"iopub.status.idle": "2023-11-07T23:47:58.746732Z",
"shell.execute_reply": "2023-11-07T23:47:58.745897Z"
},
"id": "RmVknMdJh5ks"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://archive.ics.uci.edu/ml/machine-learning-databases/00417/fonts.zip\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 8192/Unknown - 0s 0us/step"
]
}
],
"source": [
"fonts_zip = tf.keras.utils.get_file(\n",
" 'fonts.zip', \"https://archive.ics.uci.edu/ml/machine-learning-databases/00417/fonts.zip\",\n",
" cache_dir='.', cache_subdir='fonts',\n",
" extract=True)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:58.751395Z",
"iopub.status.busy": "2023-11-07T23:47:58.750649Z",
"iopub.status.idle": "2023-11-07T23:47:58.757791Z",
"shell.execute_reply": "2023-11-07T23:47:58.757107Z"
},
"id": "xsDlMCnyi55e"
},
"outputs": [
{
"data": {
"text/plain": [
"['fonts/AGENCY.csv',\n",
" 'fonts/ARIAL.csv',\n",
" 'fonts/BAITI.csv',\n",
" 'fonts/BANKGOTHIC.csv',\n",
" 'fonts/BASKERVILLE.csv',\n",
" 'fonts/BAUHAUS.csv',\n",
" 'fonts/BELL.csv',\n",
" 'fonts/BERLIN.csv',\n",
" 'fonts/BERNARD.csv',\n",
" 'fonts/BITSTREAMVERA.csv']"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pathlib\n",
"font_csvs = sorted(str(p) for p in pathlib.Path('fonts').glob(\"*.csv\"))\n",
"\n",
"font_csvs[:10]"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:58.761096Z",
"iopub.status.busy": "2023-11-07T23:47:58.760619Z",
"iopub.status.idle": "2023-11-07T23:47:58.765317Z",
"shell.execute_reply": "2023-11-07T23:47:58.764651Z"
},
"id": "lRAEJx9ROAGl"
},
"outputs": [
{
"data": {
"text/plain": [
"153"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(font_csvs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "19Udrw9iG-FS"
},
"source": [
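"When dealing with a bunch of files, you can pass a glob-style `file_pattern` to the `tf.data.experimental.make_csv_dataset` function. The order of the files is shuffled each iteration.\n",
"\n",
"Use the `num_parallel_reads` argument to set how many files are read in parallel and interleaved together."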
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:58.768599Z",
"iopub.status.busy": "2023-11-07T23:47:58.768050Z",
"iopub.status.idle": "2023-11-07T23:47:59.645990Z",
"shell.execute_reply": "2023-11-07T23:47:59.644966Z"
},
"id": "6TSUNdT6iG58"
},
"outputs": [],
"source": [
"fonts_ds = tf.data.experimental.make_csv_dataset(\n",
" file_pattern = \"fonts/*.csv\",\n",
" batch_size=10, num_epochs=1,\n",
" num_parallel_reads=20,\n",
" shuffle_buffer_size=10000)"
]
},
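{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a sequential pure-Python sketch of what a glob-style `file_pattern` plus `num_parallel_reads` amounts to. It matches the files, shuffles their order, then keeps several files open and takes one row from each in turn. This is a simplified model: `tf.data` reads the files concurrently, and its exact interleaving order may differ. The three files and their contents are hypothetical:\n",
"\n",
"```python\n",
"import glob\n",
"import os\n",
"import random\n",
"import tempfile\n",
"\n",
"# Three small hypothetical CSV files, one per font.\n",
"root = tempfile.mkdtemp()\n",
"for font in ['AGENCY', 'ARIAL', 'BELL']:\n",
"  with open(os.path.join(root, font + '.csv'), 'w') as f:\n",
"    print('font,m_label', file=f)\n",
"    for label in (65, 66):\n",
"      print(f'{font},{label}', file=f)\n",
"\n",
"def read_rows(path):\n",
"  # Yield the data rows of one CSV file, skipping its header line.\n",
"  with open(path) as f:\n",
"    next(f)\n",
"    for line in f:\n",
"      yield line.rstrip()\n",
"\n",
"def interleave(paths, cycle_length):\n",
"  # Keep up to `cycle_length` files active and take one row from each in turn;\n",
"  # when a file is exhausted, the next pending file takes its place.\n",
"  pending = [read_rows(p) for p in paths]\n",
"  active = [pending.pop(0) for _ in range(min(cycle_length, len(pending)))]\n",
"  while active:\n",
"    for rows in list(active):\n",
"      row = next(rows, None)\n",
"      if row is None:\n",
"        active.remove(rows)\n",
"        if pending:\n",
"          active.append(pending.pop(0))\n",
"      else:\n",
"        yield row\n",
"\n",
"paths = sorted(glob.glob(os.path.join(root, '*.csv')))\n",
"random.Random(0).shuffle(paths)  # Shuffle the file order before reading.\n",
"interleaved = list(interleave(paths, cycle_length=2))\n",
"print(len(interleaved))  # 6\n",
"```\n",
"\n",
"With `cycle_length=2`, rows from two files alternate until they are exhausted, and then the third file is read. All six data rows come out exactly once."
]
},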
{
"cell_type": "markdown",
"metadata": {
"id": "XMoexinLHYFa"
},
"source": [
"These CSV files have the images flattened out into a single row. The column names are formatted `r{row}c{column}`. Here's the first batch:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:47:59.650772Z",
"iopub.status.busy": "2023-11-07T23:47:59.650047Z",
"iopub.status.idle": "2023-11-07T23:48:01.769941Z",
"shell.execute_reply": "2023-11-07T23:48:01.769003Z"
},
"id": "RmFvBWxxi3pq"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"font : [b'BASKERVILLE' b'RAGE' b'COMPLEX' b'MISTRAL' b'COUNTRYBLUEPRINT'\n",
" b'SWIS721' b'RICHARD' b'STYLUS' b'SWIS721' b'SWIS721']\n",
"fontVariant : [b'BASKERVILLE OLD FACE' b'RAGE ITALIC' b'COMPLEX' b'MISTRAL'\n",
" b'COUNTRYBLUEPRINT' b'SWIS721 LTEX BT' b'POOR RICHARD' b'STYLUS BT'\n",
" b'SWIS721 LTEX BT' b'SWIS721 LTEX BT']\n",
"m_label : [ 68 111 9578 383 8225 126 92 93 376 382]\n",
"strength : [0.4 0.4 0.4 0.4 0.4 0.4 0.4 0.4 0.4 0.4]\n",
"italic : [1 0 0 0 0 0 0 0 1 0]\n",
"orientation : [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"m_top : [34 55 28 29 38 60 29 38 23 34]\n",
"m_left : [22 20 25 25 24 26 21 20 37 22]\n",
"originalH : [45 21 74 49 54 9 60 54 61 50]\n",
"originalW : [55 21 38 14 25 44 39 12 47 33]\n",
"h : [20 20 20 20 20 20 20 20 20 20]\n",
"w : [20 20 20 20 20 20 20 20 20 20]\n",
"r0c0 : [ 1 1 1 1 255 1 168 161 1 1]\n",
"r0c1 : [ 1 1 1 1 255 1 255 161 1 1]\n",
"r0c2 : [ 1 1 1 7 255 1 57 161 1 1]\n",
"r0c3 : [ 1 1 1 47 255 1 1 161 1 137]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"...\n",
"[total: 412 features]\n"
]
}
],
"source": [
"for features in fonts_ds.take(1):\n",
" for i, (name, value) in enumerate(features.items()):\n",
" if i>15:\n",
" break\n",
" print(f\"{name:20s}: {value}\")\n",
"print('...')\n",
"print(f\"[total: {len(features)} features]\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xrC3sKdeOhb5"
},
"source": [
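"#### Optional: Packing fields\n",
"\n",
"You probably don't want to work with each pixel in separate columns like this. Before trying to use this dataset, be sure to pack the pixels into an image tensor.\n",
"\n",
"Here is code that parses the column names to build images for each example:"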
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:01.774004Z",
"iopub.status.busy": "2023-11-07T23:48:01.773428Z",
"iopub.status.idle": "2023-11-07T23:48:01.779214Z",
"shell.execute_reply": "2023-11-07T23:48:01.778565Z"
},
"id": "hct5EMEWNyfH"
},
"outputs": [],
"source": [
"import re\n",
"\n",
"def make_images(features):\n",
" image = [None]*400\n",
" new_feats = {}\n",
"\n",
" for name, value in features.items():\n",
"    match = re.match(r'r(\\d+)c(\\d+)', name)\n",
" if match:\n",
" image[int(match.group(1))*20+int(match.group(2))] = value\n",
" else:\n",
" new_feats[name] = value\n",
"\n",
" image = tf.stack(image, axis=0)\n",
" image = tf.reshape(image, [20, 20, -1])\n",
" new_feats['image'] = image\n",
"\n",
" return new_feats"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "61qy8utAwARP"
},
"source": [
"Apply that function to each batch in the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:01.782924Z",
"iopub.status.busy": "2023-11-07T23:48:01.782291Z",
"iopub.status.idle": "2023-11-07T23:48:04.414751Z",
"shell.execute_reply": "2023-11-07T23:48:04.413941Z"
},
"id": "DJnnfIW9baE4"
},
"outputs": [],
"source": [
"fonts_image_ds = fonts_ds.map(make_images)\n",
"\n",
"for features in fonts_image_ds.take(1):\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ThqrthGwHSm"
},
"source": [
"Plot the resulting images:"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:04.419327Z",
"iopub.status.busy": "2023-11-07T23:48:04.418772Z",
"iopub.status.idle": "2023-11-07T23:48:05.242178Z",
"shell.execute_reply": "2023-11-07T23:48:05.241404Z"
},
"id": "I5dcey31T_tk"
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"plt.figure(figsize=(6,6), dpi=120)\n",
"\n",
"for n in range(9):\n",
" plt.subplot(3,3,n+1)\n",
" plt.imshow(features['image'][..., n])\n",
" plt.title(chr(features['m_label'][n]))\n",
" plt.axis('off')"
]
},
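{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 412 features printed earlier can be accounted for with a little arithmetic: 12 metadata columns (`font` through `w`) plus 20 × 20 = 400 pixel columns named `r{row}c{column}`. `make_images` places column `r{row}c{column}` at flat index `row*20 + column` and then reshapes to `[20, 20, -1]`, which is a row-major layout. The following pure-Python check of that bookkeeping needs no TensorFlow. The metadata names are copied from the printed batch:\n",
"\n",
"```python\n",
"import re\n",
"\n",
"metadata = ['font', 'fontVariant', 'm_label', 'strength', 'italic', 'orientation',\n",
"            'm_top', 'm_left', 'originalH', 'originalW', 'h', 'w']\n",
"pixel_names = [f'r{row}c{col}' for row in range(20) for col in range(20)]\n",
"print(len(metadata) + len(pixel_names))  # 412\n",
"\n",
"# Map each pixel column to its flat, row-major index, as `make_images` does.\n",
"flat_index = {}\n",
"for name in pixel_names:\n",
"  match = re.fullmatch(r'r(\\d+)c(\\d+)', name)\n",
"  flat_index[name] = int(match.group(1)) * 20 + int(match.group(2))\n",
"\n",
"# Reading the flat order 20 entries at a time recovers the image rows.\n",
"ordered = sorted(pixel_names, key=flat_index.get)\n",
"image_rows = [ordered[i:i + 20] for i in range(0, 400, 20)]\n",
"print(image_rows[1][:3])  # ['r1c0', 'r1c1', 'r1c2']\n",
"```"
]
},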
{
"cell_type": "markdown",
"metadata": {
"id": "7-nNR0Nncdd1"
},
"source": [
"## Lower-level functions"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3jiGZeUijJNd"
},
"source": [
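"So far, this tutorial has focused on the highest-level utilities for reading CSV data. If your use case doesn't fit the basic patterns, two other APIs may be helpful for advanced users:\n",
"\n",
"- `tf.io.decode_csv`: a function for parsing lines of text into a list of CSV column tensors.\n",
"- `tf.data.experimental.CsvDataset`: a lower-level CSV dataset constructor.\n",
"\n",
"This section recreates functionality provided by `tf.data.experimental.make_csv_dataset`, to demonstrate how this lower-level functionality can be used.\n"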
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LL_ixywomOHW"
},
"source": [
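"### `tf.io.decode_csv`\n",
"\n",
"This function decodes a string, or list of strings, into a list of columns.\n",
"\n",
"Unlike `tf.data.experimental.make_csv_dataset`, this function does not try to guess column data types. You specify the column types by providing a `record_defaults` list containing a value of the correct type for each column.\n",
"\n",
"To read the Titanic data **as strings** using `tf.io.decode_csv`, you can use the following code:"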
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.246494Z",
"iopub.status.busy": "2023-11-07T23:48:05.246013Z",
"iopub.status.idle": "2023-11-07T23:48:05.252473Z",
"shell.execute_reply": "2023-11-07T23:48:05.251757Z"
},
"id": "m1D2C-qdlqeW"
},
"outputs": [
{
"data": {
"text/plain": [
"['', '', '', '', '', '', '', '', '', '']"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = pathlib.Path(titanic_file_path).read_text()\n",
"lines = text.split('\\n')[1:-1]\n",
"\n",
"all_strings = [str()]*10\n",
"all_strings"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.255897Z",
"iopub.status.busy": "2023-11-07T23:48:05.255292Z",
"iopub.status.idle": "2023-11-07T23:48:05.263425Z",
"shell.execute_reply": "2023-11-07T23:48:05.262745Z"
},
"id": "9W4UeJYyHPx5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n"
]
}
],
"source": [
"features = tf.io.decode_csv(lines, record_defaults=all_strings) \n",
"\n",
"for f in features:\n",
" print(f\"type: {f.dtype.name}, shape: {f.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j8TaHSQFoQL4"
},
"source": [
"要使用它们的实际类型解析它们,请创建相应类型的 `record_defaults` 列表: "
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.267016Z",
"iopub.status.busy": "2023-11-07T23:48:05.266553Z",
"iopub.status.idle": "2023-11-07T23:48:05.270431Z",
"shell.execute_reply": "2023-11-07T23:48:05.269794Z"
},
"id": "rzUjR59yoUe1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0,male,22.0,1,0,7.25,Third,unknown,Southampton,n\n"
]
}
],
"source": [
"print(lines[0])"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.273780Z",
"iopub.status.busy": "2023-11-07T23:48:05.273262Z",
"iopub.status.idle": "2023-11-07T23:48:05.278362Z",
"shell.execute_reply": "2023-11-07T23:48:05.277706Z"
},
"id": "7sPTunxwoeWU"
},
"outputs": [
{
"data": {
"text/plain": [
"[0, '', 0.0, 0, 0, 0.0, '', '', '', '']"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"titanic_types = [int(), str(), float(), int(), int(), float(), str(), str(), str(), str()]\n",
"titanic_types"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.281628Z",
"iopub.status.busy": "2023-11-07T23:48:05.281066Z",
"iopub.status.idle": "2023-11-07T23:48:05.289059Z",
"shell.execute_reply": "2023-11-07T23:48:05.288430Z"
},
"id": "n3NlViCzoB7F"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"type: int32, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: float32, shape: (627,)\n",
"type: int32, shape: (627,)\n",
"type: int32, shape: (627,)\n",
"type: float32, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n",
"type: string, shape: (627,)\n"
]
}
],
"source": [
"features = tf.io.decode_csv(lines, record_defaults=titanic_types) \n",
"\n",
"for f in features:\n",
" print(f\"type: {f.dtype.name}, shape: {f.shape}\")"
]
},
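{
"cell_type": "markdown",
"metadata": {},
"source": [
"`record_defaults` does more than set the column types. As a small sketch (the toy lines below are made up for illustration): a scalar default fills in empty fields, while an empty tensor marks a column as required, so `tf.io.decode_csv` raises an error if that field is empty."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Two made-up CSV lines; the second one has an empty middle field.\n",
"toy_lines = ['1,7,0.5', '2,,1.5']\n",
"\n",
"# One entry per column. A scalar sets the column type and the value used for\n",
"# empty fields; an empty tensor marks the column as required instead.\n",
"toy_defaults = [\n",
"    tf.constant([], dtype=tf.int32),  # column 0: required, int32\n",
"    -1,                               # column 1: int32, empty fields become -1\n",
"    0.0,                              # column 2: float32, empty fields become 0.0\n",
"]\n",
"\n",
"ids, counts, scores = tf.io.decode_csv(toy_lines, record_defaults=toy_defaults)\n",
"print(ids.numpy(), counts.numpy(), scores.numpy())"
]
},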
{
"cell_type": "markdown",
"metadata": {
"id": "m-LkTUTnpn2P"
},
"source": [
"注:在大批量行上调用 `tf.io.decode_csv` 比在单个 CSV 文本行上调用更有效。"
]
},
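{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see why batching helps: `tf.io.decode_csv` works elementwise over its input, so a single line (a scalar string) decodes to scalar columns, while a batch of N lines decodes to columns of shape `(N,)` in one op call. A sketch with made-up toy lines:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"pair_defaults = [0, 0.0]\n",
"\n",
"# One line (a scalar string) decodes to scalar column tensors.\n",
"single = tf.io.decode_csv('3,0.25', record_defaults=pair_defaults)\n",
"print([t.shape for t in single])\n",
"\n",
"# A batch of N lines (a 1-D string tensor) decodes to N-element column tensors\n",
"# in a single op call.\n",
"batch = tf.io.decode_csv(['3,0.25', '4,0.5', '5,0.75'], record_defaults=pair_defaults)\n",
"print([t.shape for t in batch])"
]
},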
{
"cell_type": "markdown",
"metadata": {
"id": "Yp1UItJmqGqw"
},
"source": [
"### `tf.data.experimental.CsvDataset`\n",
"\n",
"`tf.data.experimental.CsvDataset` 类提供了一个最小的 CSV `Dataset` 接口,没有 `tf.data.experimental.make_csv_dataset` 函数的便利功能:列标题解析、列类型推断、自动乱序、文件交错。\n",
"\n",
"此构造函数使用 `record_defaults` 的方式与 `tf.io.decode_csv` 相同:\n"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.292639Z",
"iopub.status.busy": "2023-11-07T23:48:05.292130Z",
"iopub.status.idle": "2023-11-07T23:48:05.310279Z",
"shell.execute_reply": "2023-11-07T23:48:05.309620Z"
},
"id": "9OzZLp3krP-t"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']\n"
]
}
],
"source": [
"simple_titanic = tf.data.experimental.CsvDataset(titanic_file_path, record_defaults=titanic_types, header=True)\n",
"\n",
"for example in simple_titanic.take(1):\n",
" print([e.numpy() for e in example])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_HBmfI-Ks7dw"
},
"source": [
"上面的代码基本等价于:"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.313807Z",
"iopub.status.busy": "2023-11-07T23:48:05.313185Z",
"iopub.status.idle": "2023-11-07T23:48:05.404242Z",
"shell.execute_reply": "2023-11-07T23:48:05.403422Z"
},
"id": "E5O5d69Yq7gG"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']\n"
]
}
],
"source": [
"def decode_titanic_line(line):\n",
" return tf.io.decode_csv(line, titanic_types)\n",
"\n",
"manual_titanic = (\n",
" # Load the lines of text\n",
" tf.data.TextLineDataset(titanic_file_path)\n",
" # Skip the header row.\n",
" .skip(1)\n",
" # Decode the line.\n",
" .map(decode_titanic_line)\n",
")\n",
"\n",
"for example in manual_titanic.take(1):\n",
" print([e.numpy() for e in example])"
]
},
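{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tf.data.experimental.CsvDataset` can also parse only a subset of columns through its `select_cols` argument. In this sketch the file and its contents are hypothetical, written to a temporary directory. Note that `record_defaults` then lists one default per *selected* column:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"import tensorflow as tf\n",
"\n",
"# Write a tiny, hypothetical CSV file (with a header row) to a temporary directory.\n",
"toy_dir = tempfile.mkdtemp()\n",
"toy_path = os.path.join(toy_dir, 'toy.csv')\n",
"with open(toy_path, 'w') as f:\n",
"    for toy_line in ['id,name,score', '1,alpha,0.5', '2,beta,1.5']:\n",
"        print(toy_line, file=f)\n",
"\n",
"# Parse only columns 0 and 2; record_defaults has one entry per selected column.\n",
"toy_ds = tf.data.experimental.CsvDataset(\n",
"    toy_path,\n",
"    record_defaults=[0, 0.0],\n",
"    header=True,\n",
"    select_cols=[0, 2])\n",
"\n",
"for toy_row in toy_ds:\n",
"    print([t.numpy() for t in toy_row])"
]
},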
{
"cell_type": "markdown",
"metadata": {
"id": "5R3ralsnt2AC"
},
"source": [
"#### 多个文件\n",
"\n",
"要使用 `tf.data.experimental.CsvDataset` 解析字体数据集,您首先需要确定 `record_defaults` 的列类型。首先检查一个文件的第一行:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.408185Z",
"iopub.status.busy": "2023-11-07T23:48:05.407428Z",
"iopub.status.idle": "2023-11-07T23:48:05.416359Z",
"shell.execute_reply": "2023-11-07T23:48:05.415523Z"
},
"id": "3tlFOTjCvAI5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AGENCY,AGENCY FB,64258,0.400000,0,0.000000,35,21,51,22,20,20,1,1,1,21,101,210,255,255,255,255,255,255,255,255,255,255,255,255,255,255,1,1,1,93,255,255,255,176,146,146,146,146,146,146,146,146,216,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,141,141,141,182,255,255,255,172,141,141,141,115,1,1,1,1,163,255,255,255,255,255,255,255,255,255,255,255,255,255,255,209,1,1,1,1,163,255,255,255,6,6,6,96,255,255,255,74,6,6,6,5,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255,1,1,1,93,255,255,255,70,1,1,1,1,1,1,1,1,163,255,255,255\n"
]
}
],
"source": [
"font_line = pathlib.Path(font_csvs[0]).read_text().splitlines()[1]\n",
"print(font_line)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "etyGu8K_ySRz"
},
"source": [
"只有前两个字段是字符串,其余的都是整数或浮点数,通过计算逗号的个数可以得到特征总数:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.419593Z",
"iopub.status.busy": "2023-11-07T23:48:05.419294Z",
"iopub.status.idle": "2023-11-07T23:48:05.423358Z",
"shell.execute_reply": "2023-11-07T23:48:05.422648Z"
},
"id": "crgZZn0BzkSB"
},
"outputs": [],
"source": [
"num_font_features = font_line.count(',')+1\n",
"font_column_types = [str(), str()] + [float()]*(num_font_features-2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YeK2Pw540RNj"
},
"source": [
"`tf.data.experimental.CsvDataset` 构造函数可以获取输入文件列表,但会按顺序读取它们。CSV 列表中的第一个文件是 `AGENCY.csv`:"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.426885Z",
"iopub.status.busy": "2023-11-07T23:48:05.426305Z",
"iopub.status.idle": "2023-11-07T23:48:05.431260Z",
"shell.execute_reply": "2023-11-07T23:48:05.430574Z"
},
"id": "_SvL5Uvl0r0N"
},
"outputs": [
{
"data": {
"text/plain": [
"'fonts/AGENCY.csv'"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"font_csvs[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EfAX3G8Xywy6"
},
"source": [
"因此,当您将文件列表传递给 `CsvDataset` 时,会首先读取 `AGENCY.csv` 中的记录:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.434747Z",
"iopub.status.busy": "2023-11-07T23:48:05.434112Z",
"iopub.status.idle": "2023-11-07T23:48:05.470038Z",
"shell.execute_reply": "2023-11-07T23:48:05.469359Z"
},
"id": "Gtr1E66VmBqj"
},
"outputs": [],
"source": [
"simple_font_ds = tf.data.experimental.CsvDataset(\n",
" font_csvs, \n",
" record_defaults=font_column_types, \n",
" header=True)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.473568Z",
"iopub.status.busy": "2023-11-07T23:48:05.473059Z",
"iopub.status.idle": "2023-11-07T23:48:05.549053Z",
"shell.execute_reply": "2023-11-07T23:48:05.548279Z"
},
"id": "k750Mgq4yt_o"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b'AGENCY'\n",
"b'AGENCY'\n",
"b'AGENCY'\n",
"b'AGENCY'\n",
"b'AGENCY'\n",
"b'AGENCY'\n",
"b'AGENCY'\n",
"b'AGENCY'\n",
"b'AGENCY'\n",
"b'AGENCY'\n"
]
}
],
"source": [
"for row in simple_font_ds.take(10):\n",
" print(row[0].numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NiqWKQV21FrE"
},
"source": [
"要交错多个文件,请使用 `Dataset.interleave`。\n",
"\n",
"这是一个包含 CSV 文件名的初始数据集: "
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.552726Z",
"iopub.status.busy": "2023-11-07T23:48:05.552171Z",
"iopub.status.idle": "2023-11-07T23:48:05.571282Z",
"shell.execute_reply": "2023-11-07T23:48:05.570550Z"
},
"id": "t9dS3SNb23W8"
},
"outputs": [],
"source": [
"font_files = tf.data.Dataset.list_files(\"fonts/*.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNiLHMXpzHy5"
},
"source": [
"这会在每个周期对文件名进行乱序:"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.574909Z",
"iopub.status.busy": "2023-11-07T23:48:05.574373Z",
"iopub.status.idle": "2023-11-07T23:48:05.617486Z",
"shell.execute_reply": "2023-11-07T23:48:05.616820Z"
},
"id": "zNd-TYyNzIgg"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1:\n",
" b'fonts/TEMPUS.csv'\n",
" b'fonts/PERPETUA.csv'\n",
" b'fonts/KRISTEN.csv'\n",
" b'fonts/MISTRAL.csv'\n",
" b'fonts/CENTURY.csv'\n",
" ...\n",
"\n",
"Epoch 2:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" b'fonts/OCRB.csv'\n",
" b'fonts/ARIAL.csv'\n",
" b'fonts/STYLUS.csv'\n",
" b'fonts/TIMES.csv'\n",
" b'fonts/BOOK.csv'\n",
" ...\n"
]
}
],
"source": [
"print('Epoch 1:')\n",
"for f in list(font_files)[:5]:\n",
" print(\" \", f.numpy())\n",
"print(' ...')\n",
"print()\n",
"\n",
"print('Epoch 2:')\n",
"for f in list(font_files)[:5]:\n",
" print(\" \", f.numpy())\n",
"print(' ...')"
]
},
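{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you need a reproducible file order (for example, while debugging), `tf.data.Dataset.list_files` accepts `shuffle=False` or a `seed`. A sketch using a few hypothetical empty files in a temporary directory:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"import tensorflow as tf\n",
"\n",
"# Create a few empty, hypothetical files, just to have names to match.\n",
"toy_dir = tempfile.mkdtemp()\n",
"for name in ['a.csv', 'b.csv', 'c.csv']:\n",
"    open(os.path.join(toy_dir, name), 'w').close()\n",
"toy_pattern = os.path.join(toy_dir, '*.csv')\n",
"\n",
"def basenames(ds):\n",
"    # Iterate the dataset once and keep only the file names.\n",
"    return [os.path.basename(f.numpy().decode()) for f in ds]\n",
"\n",
"# shuffle=False gives a deterministic order that is the same every epoch.\n",
"ordered = tf.data.Dataset.list_files(toy_pattern, shuffle=False)\n",
"print(basenames(ordered), basenames(ordered))\n",
"\n",
"# A fixed seed should make the shuffled sequence repeatable from run to run,\n",
"# although each new pass over the dataset may still produce a different order.\n",
"seeded = tf.data.Dataset.list_files(toy_pattern, seed=42)\n",
"print(basenames(seeded), basenames(seeded))"
]
},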
{
"cell_type": "markdown",
"metadata": {
"id": "B0QB1PtU3WAN"
},
"source": [
"`interleave` 方法采用 `map_func`,它会为父 `Dataset`的每个元素创建一个子 `Dataset`。\n",
"\n",
"在这里,您要从文件数据集的每个元素创建一个 `tf.data.experimental.CsvDataset`:"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.621041Z",
"iopub.status.busy": "2023-11-07T23:48:05.620374Z",
"iopub.status.idle": "2023-11-07T23:48:05.624344Z",
"shell.execute_reply": "2023-11-07T23:48:05.623596Z"
},
"id": "QWp4rH0Q4uPh"
},
"outputs": [],
"source": [
"def make_font_csv_ds(path):\n",
" return tf.data.experimental.CsvDataset(\n",
" path, \n",
" record_defaults=font_column_types, \n",
" header=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VxRGdLMB5nRF"
},
"source": [
"交错返回的 `Dataset` 通过循环遍历多个子 `Dataset` 来返回元素。请注意,下面的数据集如何在 `cycle_length=3` 三个字体文件中循环:"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.627651Z",
"iopub.status.busy": "2023-11-07T23:48:05.627057Z",
"iopub.status.idle": "2023-11-07T23:48:05.845578Z",
"shell.execute_reply": "2023-11-07T23:48:05.844805Z"
},
"id": "OePMNF_x1_Cc"
},
"outputs": [],
"source": [
"font_rows = font_files.interleave(make_font_csv_ds,\n",
" cycle_length=3)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:05.849722Z",
"iopub.status.busy": "2023-11-07T23:48:05.849100Z",
"iopub.status.idle": "2023-11-07T23:48:06.007320Z",
"shell.execute_reply": "2023-11-07T23:48:06.006519Z"
},
"id": "UORIGWLy54-E"
},
"outputs": [
{
"data": {
"text/plain": [
" font_name character\n",
"0 YI BAITI ?\n",
"1 GUNPLAY €\n",
"2 BAITI ?\n",
"3 YI BAITI ;\n",
"4 GUNPLAY ›\n",
"5 BAITI !\n",
"6 YI BAITI :\n",
"7 GUNPLAY ‹\n",
"8 BAITI ﹈\n",
"9 YI BAITI ,"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fonts_dict = {'font_name':[], 'character':[]}\n",
"\n",
"for row in font_rows.take(10):\n",
" fonts_dict['font_name'].append(row[0].numpy().decode())\n",
" # The character code column is parsed as float32; convert it to int for chr().\n",
" fonts_dict['character'].append(chr(int(row[2].numpy())))\n",
"\n",
"pd.DataFrame(fonts_dict)"
]
},
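{
"cell_type": "markdown",
"metadata": {},
"source": [
"The round-robin behavior of `interleave` is easier to see with tiny in-memory datasets. In this sketch, each parent element expands into a child that repeats it four times (a stand-in for the rows of one file). With `cycle_length=2`, elements from the first two children should alternate before the third child is opened, and `block_length` controls how many consecutive elements are taken from each child per turn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Three parent elements: 1, 2, 3.\n",
"parents = tf.data.Dataset.range(1, 4)\n",
"\n",
"def make_child(x):\n",
"    # Each child dataset repeats its parent element four times.\n",
"    return tf.data.Dataset.from_tensors(x).repeat(4)\n",
"\n",
"# Take one element at a time from each of 2 open children, round-robin.\n",
"one_at_a_time = parents.interleave(make_child, cycle_length=2, block_length=1)\n",
"print(list(one_at_a_time.as_numpy_iterator()))\n",
"\n",
"# Take 2 consecutive elements from each child before moving to the next one.\n",
"two_at_a_time = parents.interleave(make_child, cycle_length=2, block_length=2)\n",
"print(list(two_at_a_time.as_numpy_iterator()))"
]
},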
{
"cell_type": "markdown",
"metadata": {
"id": "mkKZa_HX8zAm"
},
"source": [
"#### 性能\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8BtGHraUApdJ"
},
"source": [
"早些时候,有人注意到 `tf.io.decode_csv` 在一个批次字符串上运行时效率更高。\n",
"\n",
"当使用大批量时,可以利用这一事实来提高 CSV 加载性能(但请先尝试使用[缓存](#caching))。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d35zWMH7MDL1"
},
"source": [
"使用内置加载器 20,2048 个样本批次大约需要 17 秒。 "
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:06.011447Z",
"iopub.status.busy": "2023-11-07T23:48:06.011118Z",
"iopub.status.idle": "2023-11-07T23:48:06.914096Z",
"shell.execute_reply": "2023-11-07T23:48:06.913252Z"
},
"id": "ieUVAPryjpJS"
},
"outputs": [],
"source": [
"BATCH_SIZE=2048\n",
"fonts_ds = tf.data.experimental.make_csv_dataset(\n",
" file_pattern = \"fonts/*.csv\",\n",
" batch_size=BATCH_SIZE, num_epochs=1,\n",
" num_parallel_reads=100)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:06.918451Z",
"iopub.status.busy": "2023-11-07T23:48:06.918160Z",
"iopub.status.idle": "2023-11-07T23:48:30.179708Z",
"shell.execute_reply": "2023-11-07T23:48:30.178966Z"
},
"id": "MUC2KW4LkQIz"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"CPU times: user 50.8 s, sys: 4.58 s, total: 55.4 s\n",
"Wall time: 23.3 s\n"
]
}
],
"source": [
"%%time\n",
"for i,batch in enumerate(fonts_ds.take(20)):\n",
" print('.',end='')\n",
"\n",
"print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5lhnh6rZEDS2"
},
"source": [
"将**批量文本行**传递给 `decode_csv` 运行速度更快,大约需要 5 秒:"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:30.183571Z",
"iopub.status.busy": "2023-11-07T23:48:30.183274Z",
"iopub.status.idle": "2023-11-07T23:48:30.638344Z",
"shell.execute_reply": "2023-11-07T23:48:30.637466Z"
},
"id": "4XbPZV1okVF9"
},
"outputs": [],
"source": [
"fonts_files = tf.data.Dataset.list_files(\"fonts/*.csv\")\n",
"fonts_lines = fonts_files.interleave(\n",
" lambda fname:tf.data.TextLineDataset(fname).skip(1), \n",
" cycle_length=100).batch(BATCH_SIZE)\n",
"\n",
"fonts_fast = fonts_lines.map(lambda x: tf.io.decode_csv(x, record_defaults=font_column_types))"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"execution": {
"iopub.execute_input": "2023-11-07T23:48:30.642557Z",
"iopub.status.busy": "2023-11-07T23:48:30.642273Z",
"iopub.status.idle": "2023-11-07T23:48:31.576293Z",
"shell.execute_reply": "2023-11-07T23:48:31.575330Z"
},
"id": "te9C2km-qO8W"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".............."
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"......"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"CPU times: user 4.21 s, sys: 139 ms, total: 4.35 s\n",
"Wall time: 929 ms\n"
]
}
],
"source": [
"%%time\n",
"for i,batch in enumerate(fonts_fast.take(20)):\n",
" print('.',end='')\n",
"\n",
"print()"
]
},
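{
"cell_type": "markdown",
"metadata": {},
"source": [
"Batching before decoding changes only how the work is grouped, not the result. This sketch, on made-up in-memory lines, checks that decode-then-batch and batch-then-decode yield identical batches:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"pair_defaults = [0, 0.0]\n",
"toy_lines_ds = tf.data.Dataset.from_tensor_slices(['1,0.5', '2,1.5', '3,2.5', '4,3.5'])\n",
"\n",
"# Decode each line, then batch: decode_csv runs once per line.\n",
"per_line = toy_lines_ds.map(\n",
"    lambda line: tf.io.decode_csv(line, record_defaults=pair_defaults)).batch(2)\n",
"\n",
"# Batch first, then decode: decode_csv runs once per batch of lines.\n",
"per_batch = toy_lines_ds.batch(2).map(\n",
"    lambda lines: tf.io.decode_csv(lines, record_defaults=pair_defaults))\n",
"\n",
"# Both pipelines should yield the same column tensors, batch by batch.\n",
"for (a_ids, a_vals), (b_ids, b_vals) in zip(per_line, per_batch):\n",
"    tf.debugging.assert_equal(a_ids, b_ids)\n",
"    tf.debugging.assert_equal(a_vals, b_vals)\n",
"print('Both pipelines produced identical batches.')"
]
},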
{
"cell_type": "markdown",
"metadata": {
"id": "aebC1plsMeOi"
},
"source": [
"有关通过使用大批量提高 CSV 性能的另一个示例,请参阅[过拟合和欠拟合教程](../keras/overfit_and_underfit.ipynb)。\n",
"\n",
"这种方式可能有效,但请考虑其他选项,例如 `Dataset.cache` 和 `tf.data.Dataset.snapshot`,或者将您的数据重新编码为更简化的格式。"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "csv.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 0
}