{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "EOgDUDMAG6mn" }, "source": [ "##### Copyright 2022 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-07T17:28:45.546703Z", "iopub.status.busy": "2023-11-07T17:28:45.546277Z", "iopub.status.idle": "2023-11-07T17:28:45.550316Z", "shell.execute_reply": "2023-11-07T17:28:45.549694Z" }, "id": "B3PsBDmGG_W8" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "ifkGYxdCHIof" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
View on TensorFlow.org\n", " Run in Google Colab\n", " View source on GitHub\n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "sWxDDkRwLVMC" }, "source": [ "# Transfer learning for video classification with MoViNet\n", "\n", "MoViNets (Mobile Video Networks) provide a family of efficient video classification models that support inference on streaming video. In this tutorial, you will use a pre-trained MoViNet model to classify videos from the [UCF101 dataset](https://www.crcv.ucf.edu/data/UCF101.php), specifically for an action recognition task. A pre-trained model is a saved network that was previously trained on a larger dataset. You can find more details about MoViNets in the [MoViNets: Mobile Video Networks for Efficient Video Recognition](https://arxiv.org/abs/2103.11511) paper by Kondratyuk, D. et al. (2021). In this tutorial, you will complete the following tasks:\n", "\n", "- Learn how to download a pre-trained MoViNet model\n", "- Create a new model by using the pre-trained model with a new classifier, freezing the convolutional base of the MoViNet model\n", "- Replace the classifier head with the number of labels of a new dataset\n", "- Perform transfer learning on the [UCF101 dataset](https://www.crcv.ucf.edu/data/UCF101.php)\n", "\n", "The model downloaded in this tutorial is from [official/projects/movinet](https://github.com/tensorflow/models/tree/master/official/projects/movinet). This repository contains a collection of MoViNet models that TF Hub uses in the TensorFlow 2 SavedModel format.\n", "\n", "This transfer learning tutorial is the third part in a series of TensorFlow video tutorials. Here are the other three tutorials:\n", "\n", "- [Load video data](https://tensorflow.google.cn/tutorials/load_data/video): This tutorial explains much of the code used in this document; in particular, it explains in more detail how to preprocess and load data through the `FrameGenerator` class.\n", "- [Build a 3D CNN model for video classification](https://tensorflow.google.cn/tutorials/video/video_classification). Note that this tutorial uses a (2+1)D CNN that decomposes the spatial and temporal aspects of 3D data; if you are using volumetric data such as an MRI scan, consider using a 3D CNN instead of a (2+1)D CNN.\n", "- [MoViNet for streaming action recognition](https://tensorflow.google.cn/hub/tutorials/movinet): Get familiar with the MoViNet models available on TF Hub." ] }, { "cell_type": "markdown", "metadata": { "id": "GidiisyXwK--" }, "source": [ "## Setup\n", "\n", "Begin by installing and importing some necessary libraries, including: [remotezip](https://github.com/gtsystem/python-remotezip) to inspect the contents of a ZIP file, [tqdm](https://github.com/tqdm/tqdm) to use a progress bar, [OpenCV](https://opencv.org/) to process video files (ensure that `opencv-python` and `opencv-python-headless` are the same version), and TensorFlow models ([`tf-models-official`](https://github.com/tensorflow/models/tree/master/official)) to download the pre-trained MoViNet model. The TensorFlow models package is a collection of models that use TensorFlow's high-level APIs." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T17:28:45.553919Z", "iopub.status.busy": "2023-11-07T17:28:45.553413Z", 
"iopub.status.idle": "2023-11-07T17:29:24.815353Z", "shell.execute_reply": "2023-11-07T17:29:24.814315Z" }, "id": "nubWhqYdwEXD" }, "outputs": [], "source": [ "!pip install remotezip tqdm opencv-python==4.5.2.52 opencv-python-headless==4.5.2.52 tf-models-official" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T17:29:24.819696Z", "iopub.status.busy": "2023-11-07T17:29:24.819421Z", "iopub.status.idle": "2023-11-07T17:29:28.936954Z", "shell.execute_reply": "2023-11-07T17:29:28.936207Z" }, "id": "QImPsudoK9JI" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 17:29:26.513969: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for 
plugin cuDNN when one has already been registered\n", "2023-11-07 17:29:26.514019: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-11-07 17:29:26.514050: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tqdm\n", "import random\n", "import pathlib\n", "import itertools\n", "import collections\n", "\n", "import cv2\n", "import numpy as np\n", "import remotezip as rz\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "import keras\n", "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "from tensorflow.keras import layers\n", "from tensorflow.keras.optimizers import Adam\n", "from tensorflow.keras.losses import SparseCategoricalCrossentropy\n", "\n", "# Import the MoViNet model from TensorFlow Models (tf-models-official)\n", "from official.projects.movinet.modeling import movinet\n", "from official.projects.movinet.modeling import movinet_model" ] }, { "cell_type": "markdown", "metadata": { "id": "2w3H4dfOPfnm" }, "source": [ "## Load data\n", "\n", "The hidden cell below defines helper functions that download a slice of data from the UCF-101 dataset and load it into a `tf.data.Dataset`. The [Loading video data tutorial](https://tensorflow.google.cn/tutorials/load_data/video) provides a detailed walkthrough of this code.\n", "\n", "The `FrameGenerator` class at the end of the hidden block is the most important utility here. It creates an iterable object that can feed data into the TensorFlow data pipeline. Specifically, this class contains a Python generator that loads the video frames along with their encoded labels. The generator (`__call__`) function yields the frame array produced by `frames_from_video_file` and a one-hot encoded vector of the label associated with the set of frames.\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-07T17:29:28.941750Z", "iopub.status.busy": "2023-11-07T17:29:28.941348Z", "iopub.status.idle": "2023-11-07T17:29:28.959725Z", "shell.execute_reply": "2023-11-07T17:29:28.959155Z" }, "id": "fwEhJ13_PSy6" }, 
"outputs": [], "source": [ "#@title \n", "\n", "def list_files_per_class(zip_url):\n", " \"\"\"\n", " List the files in each class of the dataset given the zip URL.\n", "\n", " Args:\n", " zip_url: URL from which the files can be unzipped. \n", "\n", " Return:\n", " files: List of files in each of the classes.\n", " \"\"\"\n", " files = []\n", " with rz.RemoteZip(zip_url) as zip:\n", " for zip_info in zip.infolist():\n", " files.append(zip_info.filename)\n", " return files\n", "\n", "def get_class(fname):\n", " \"\"\"\n", " Retrieve the name of the class given a filename.\n", "\n", " Args:\n", " fname: Name of the file in the UCF101 dataset.\n", "\n", " Return:\n", " Class that the file belongs to.\n", " \"\"\"\n", " return fname.split('_')[-3]\n", "\n", "def get_files_per_class(files):\n", " \"\"\"\n", " Retrieve the files that belong to each class. \n", "\n", " Args:\n", " files: List of files in the dataset.\n", "\n", " Return:\n", " Dictionary of class names (key) and files (values).\n", " \"\"\"\n", " files_for_class = collections.defaultdict(list)\n", " for fname in files:\n", " class_name = get_class(fname)\n", " files_for_class[class_name].append(fname)\n", " return files_for_class\n", "\n", "def download_from_zip(zip_url, to_dir, file_names):\n", " \"\"\"\n", " Download the contents of the zip file from the zip URL.\n", "\n", " Args:\n", " zip_url: Zip URL containing data.\n", " to_dir: Directory to download data to.\n", " file_names: Names of files to download.\n", " \"\"\"\n", " with rz.RemoteZip(zip_url) as zip:\n", " for fn in tqdm.tqdm(file_names):\n", " class_name = get_class(fn)\n", " zip.extract(fn, str(to_dir / class_name))\n", " unzipped_file = to_dir / class_name / fn\n", "\n", " fn = pathlib.Path(fn).parts[-1]\n", " output_file = to_dir / class_name / fn\n", " unzipped_file.rename(output_file)\n", "\n", "def split_class_lists(files_for_class, count):\n", " \"\"\"\n", " Returns the list of files belonging to a subset of data as well as the 
remainder of\n", " files that need to be downloaded.\n", "\n", " Args:\n", " files_for_class: Files belonging to a particular class of data.\n", " count: Number of files to download.\n", "\n", " Return:\n", " split_files: Files belonging to the subset of data.\n", " remainder: Dictionary of the remainder of files that need to be downloaded.\n", " \"\"\"\n", " split_files = []\n", " remainder = {}\n", " for cls in files_for_class:\n", " split_files.extend(files_for_class[cls][:count])\n", " remainder[cls] = files_for_class[cls][count:]\n", " return split_files, remainder\n", "\n", "def download_ufc_101_subset(zip_url, num_classes, splits, download_dir):\n", " \"\"\"\n", " Download a subset of the UCF101 dataset and split it into various parts, such as\n", " training, validation, and test. \n", "\n", " Args:\n", " zip_url: Zip URL containing data.\n", " num_classes: Number of labels.\n", " splits: Dictionary specifying the training, validation, test, etc. (key) division of data \n", " (value is number of files per split).\n", " download_dir: Directory to download data to.\n", "\n", " Return:\n", " dir: Posix path of the resulting directories containing the splits of data.\n", " \"\"\"\n", " files = list_files_per_class(zip_url)\n", " # Keep only entries that include a filename; building a new list avoids the\n", " # skipped items caused by removing from a list while iterating over it\n", " files = [f for f in files if len(f.split('/')) > 2]\n", "\n", " files_for_class = get_files_per_class(files)\n", "\n", " classes = list(files_for_class.keys())[:num_classes]\n", "\n", " for cls in classes:\n", " new_files_for_class = files_for_class[cls]\n", " random.shuffle(new_files_for_class)\n", " files_for_class[cls] = new_files_for_class\n", "\n", " # Only use the number of classes you want in the dictionary\n", " files_for_class = {x: files_for_class[x] for x in list(files_for_class)[:num_classes]}\n", "\n", " dirs = {}\n", " for split_name, split_count in splits.items():\n", " print(split_name, \":\")\n", " split_dir = 
download_dir / split_name\n", " split_files, files_for_class = split_class_lists(files_for_class, split_count)\n", " download_from_zip(zip_url, split_dir, split_files)\n", " dirs[split_name] = split_dir\n", "\n", " return dirs\n", "\n", "def format_frames(frame, output_size):\n", " \"\"\"\n", " Pad and resize an image from a video.\n", "\n", " Args:\n", " frame: Image that needs to be resized and padded. \n", " output_size: Pixel size of the output frame image.\n", "\n", " Return:\n", " Formatted frame with padding of specified output size.\n", " \"\"\"\n", " frame = tf.image.convert_image_dtype(frame, tf.float32)\n", " frame = tf.image.resize_with_pad(frame, *output_size)\n", " return frame\n", "\n", "def frames_from_video_file(video_path, n_frames, output_size = (224,224), frame_step = 15):\n", " \"\"\"\n", " Creates frames from each video file present for each category.\n", "\n", " Args:\n", " video_path: File path to the video.\n", " n_frames: Number of frames to be created per video file.\n", " output_size: Pixel size of the output frame image.\n", " frame_step: Number of frames to advance between sampled frames.\n", "\n", " Return:\n", " A NumPy array of frames in the shape of (n_frames, height, width, channels).\n", " \"\"\"\n", " # Read each video frame by frame\n", " result = []\n", " src = cv2.VideoCapture(str(video_path)) \n", "\n", " video_length = src.get(cv2.CAP_PROP_FRAME_COUNT)\n", "\n", " need_length = 1 + (n_frames - 1) * frame_step\n", "\n", " if need_length > video_length:\n", " start = 0\n", " else:\n", " max_start = video_length - need_length\n", " # randint is inclusive on both ends; cast because CAP_PROP_FRAME_COUNT is a float\n", " start = random.randint(0, int(max_start))\n", "\n", " src.set(cv2.CAP_PROP_POS_FRAMES, start)\n", " # ret is a boolean indicating whether read was successful, frame is the image itself\n", " ret, frame = src.read()\n", " result.append(format_frames(frame, output_size))\n", "\n", " for _ in range(n_frames - 1):\n", " for _ in range(frame_step):\n", " ret, frame = src.read()\n", " if ret:\n", " frame = format_frames(frame, output_size)\n", " result.append(frame)\n", " 
else:\n", " result.append(np.zeros_like(result[0]))\n", " src.release()\n", " result = np.array(result)[..., [2, 1, 0]]\n", "\n", " return result\n", "\n", "class FrameGenerator:\n", " def __init__(self, path, n_frames, training = False):\n", " \"\"\" Returns a set of frames with their associated label. \n", "\n", " Args:\n", " path: Video file paths.\n", " n_frames: Number of frames. \n", " training: Boolean to determine if training dataset is being created.\n", " \"\"\"\n", " self.path = path\n", " self.n_frames = n_frames\n", " self.training = training\n", " self.class_names = sorted(set(p.name for p in self.path.iterdir() if p.is_dir()))\n", " self.class_ids_for_name = dict((name, idx) for idx, name in enumerate(self.class_names))\n", "\n", " def get_files_and_class_names(self):\n", " video_paths = list(self.path.glob('*/*.avi'))\n", " classes = [p.parent.name for p in video_paths] \n", " return video_paths, classes\n", "\n", " def __call__(self):\n", " video_paths, classes = self.get_files_and_class_names()\n", "\n", " pairs = list(zip(video_paths, classes))\n", "\n", " if self.training:\n", " random.shuffle(pairs)\n", "\n", " for path, name in pairs:\n", " video_frames = frames_from_video_file(path, self.n_frames) \n", " label = self.class_ids_for_name[name] # Encode labels\n", " yield video_frames, label" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T17:29:28.963072Z", "iopub.status.busy": "2023-11-07T17:29:28.962497Z", "iopub.status.idle": "2023-11-07T17:30:22.525001Z", "shell.execute_reply": "2023-11-07T17:30:22.524225Z" }, "id": "vDHrNLZkPSR9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train :\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/300 [00:00" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_id = 'a0'\n", "resolution = 224\n", "\n", "tf.keras.backend.clear_session()\n", 
"\n", "backbone = movinet.Movinet(model_id=model_id)\n", "backbone.trainable = False\n", "\n", "# Set num_classes=600 to load the pre-trained weights from the original model\n", "model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600)\n", "model.build([None, None, None, None, 3])\n", "\n", "# Load pre-trained weights\n", "!wget https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_base.tar.gz -O movinet_a0_base.tar.gz -q\n", "!tar -xvf movinet_a0_base.tar.gz\n", "\n", "checkpoint_dir = f'movinet_{model_id}_base'\n", "checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)\n", "checkpoint = tf.train.Checkpoint(model=model)\n", "status = checkpoint.restore(checkpoint_path)\n", "status.assert_existing_objects_matched()" ] }, { "cell_type": "markdown", "metadata": { "id": "BW23HVNtCXff" }, "source": [ "要构建分类器,请创建一个采用主干和数据集中的类数的函数。`build_classifier` 函数将采用主干和数据集中的类数来构建分类器。在这种情况下,新分类器将采用 `num_classes` 个输出(UCF101 的此子集有 10 个类)。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T17:30:36.249610Z", "iopub.status.busy": "2023-11-07T17:30:36.249312Z", "iopub.status.idle": "2023-11-07T17:30:36.253653Z", "shell.execute_reply": "2023-11-07T17:30:36.252982Z" }, "id": "6cfAelbU5Gi3" }, "outputs": [], "source": [ "def build_classifier(batch_size, num_frames, resolution, backbone, num_classes):\n", " \"\"\"Builds a classifier on top of a backbone model.\"\"\"\n", " model = movinet_model.MovinetClassifier(\n", " backbone=backbone,\n", " num_classes=num_classes)\n", " model.build([batch_size, num_frames, resolution, resolution, 3])\n", "\n", " return model" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T17:30:36.256915Z", "iopub.status.busy": "2023-11-07T17:30:36.256452Z", "iopub.status.idle": "2023-11-07T17:30:37.999139Z", "shell.execute_reply": "2023-11-07T17:30:37.998329Z" }, "id": "9HWSk-u7oPUZ" }, "outputs": 
[], "source": [ "model = build_classifier(batch_size, num_frames, resolution, backbone, 10)" ] }, { "cell_type": "markdown", "metadata": { "id": "JhbX7qdTN8lc" }, "source": [ "对于本教程,选择 `tf.keras.optimizers.Adam` 优化器和 `tf.keras.losses.SparseCategoricalCrossentropy` 损失函数。使用指标参数查看每个步骤中模型性能的准确率。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T17:30:38.003340Z", "iopub.status.busy": "2023-11-07T17:30:38.002834Z", "iopub.status.idle": "2023-11-07T17:30:38.022158Z", "shell.execute_reply": "2023-11-07T17:30:38.021570Z" }, "id": "dVqBLrn1tBsd" }, "outputs": [], "source": [ "num_epochs = 2\n", "\n", "loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", "\n", "optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001)\n", "\n", "model.compile(loss=loss_obj, optimizer=optimizer, metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "VflEr_t6CuQu" }, "source": [ "训练模型。两个周期后,观察训练集和测试集的低损失和高准确率。 " ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T17:30:38.025944Z", "iopub.status.busy": "2023-11-07T17:30:38.025264Z", "iopub.status.idle": "2023-11-07T17:56:26.980044Z", "shell.execute_reply": "2023-11-07T17:56:26.979216Z" }, "id": "9ZeiYzI0tqQG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 17:30:49.251128: E ./tensorflow/compiler/xla/stream_executor/stream_executor_internal.h:124] SetPriority unimplemented for this stream.\n", "2023-11-07 17:30:49.387365: E ./tensorflow/compiler/xla/stream_executor/stream_executor_internal.h:124] SetPriority unimplemented for this stream.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/Unknown - 24s 24s/step - loss: 2.2871 - accuracy: 0.1250" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 2/Unknown - 37s 12s/step - loss: 2.2236 - accuracy: 0.3125" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 3/Unknown - 49s 12s/step - loss: 2.1787 - accuracy: 0.3750" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4/Unknown - 62s 12s/step - loss: 2.1110 - accuracy: 0.4062" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/Unknown - 74s 12s/step - loss: 2.0345 - accuracy: 0.4000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 6/Unknown - 86s 12s/step - loss: 1.9710 - accuracy: 0.4375" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 7/Unknown - 99s 12s/step - loss: 1.9378 - accuracy: 0.4643" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 8/Unknown - 111s 12s/step - loss: 1.9194 - accuracy: 0.4688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/Unknown - 123s 12s/step - loss: 1.8709 - accuracy: 0.5000" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 10/Unknown - 136s 12s/step - loss: 1.8010 - accuracy: 0.5375" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 11/Unknown - 148s 12s/step - loss: 1.7230 - accuracy: 0.5568" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 12/Unknown - 160s 12s/step - loss: 1.6488 - accuracy: 0.5833" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 13/Unknown - 173s 12s/step - loss: 1.5995 - accuracy: 0.6154" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 14/Unknown - 185s 12s/step - loss: 1.5441 - accuracy: 0.6429" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 15/Unknown - 197s 12s/step - loss: 1.5073 - accuracy: 0.6583" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 16/Unknown - 210s 12s/step - loss: 1.4792 - accuracy: 0.6719" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 17/Unknown - 222s 12s/step - loss: 1.4502 - accuracy: 0.6838" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 18/Unknown - 234s 12s/step - loss: 1.4147 - accuracy: 0.6944" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 19/Unknown - 247s 12s/step - loss: 1.3872 - accuracy: 0.6974" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 20/Unknown - 259s 12s/step - loss: 1.3603 - accuracy: 0.6938" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 21/Unknown - 271s 12s/step - loss: 1.3240 - accuracy: 0.7024" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 22/Unknown - 284s 12s/step - loss: 1.2887 - accuracy: 0.7159" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 23/Unknown - 296s 12s/step - loss: 1.2520 - accuracy: 0.7283" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 24/Unknown - 309s 12s/step - loss: 1.2208 
- accuracy: 0.7344" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 25/Unknown - 321s 12s/step - loss: 1.1851 - accuracy: 0.7400" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 26/Unknown - 334s 12s/step - loss: 1.1598 - accuracy: 0.7404" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 27/Unknown - 346s 12s/step - loss: 1.1294 - accuracy: 0.7500" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 28/Unknown - 359s 12s/step - loss: 1.1045 - accuracy: 0.7545" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 29/Unknown - 371s 12s/step - loss: 1.0767 - accuracy: 0.7629" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 30/Unknown - 383s 12s/step - loss: 1.0498 - accuracy: 0.7708" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 31/Unknown - 396s 12s/step - loss: 1.0223 - accuracy: 0.7782" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 32/Unknown - 408s 12s/step - loss: 0.9927 - accuracy: 0.7852" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 33/Unknown - 420s 12s/step - loss: 0.9854 - accuracy: 0.7803" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 34/Unknown - 433s 12s/step - loss: 0.9604 - accuracy: 0.7868" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 35/Unknown - 445s 12s/step - loss: 0.9376 - accuracy: 0.7929" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 36/Unknown - 457s 12s/step - loss: 0.9195 - accuracy: 0.7951" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 37/Unknown - 470s 12s/step - loss: 0.8998 - accuracy: 0.8007" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 38/Unknown - 479s 12s/step - loss: 0.8940 - accuracy: 0.8000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "38/38 [==============================] - 
785s 21s/step - loss: 0.8940 - accuracy: 0.8000 - val_loss: 0.1508 - val_accuracy: 0.9900\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/38 [..............................] - ETA: 8:07 - loss: 0.0802 - accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 2/38 [>.............................] - ETA: 7:24 - loss: 0.0697 - accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 3/38 [=>............................] - ETA: 7:11 - loss: 0.0536 - accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4/38 [==>...........................] - ETA: 6:59 - loss: 0.1014 - accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/38 [==>...........................] - ETA: 6:45 - loss: 0.1108 - accuracy: 0.9750" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 6/38 [===>..........................] 
- ETA: 6:33 - loss: 0.0981 - accuracy: 0.9792" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 7/38 [====>.........................] - ETA: 6:20 - loss: 0.1141 - accuracy: 0.9643" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 8/38 [=====>........................] - ETA: 6:08 - loss: 0.1098 - accuracy: 0.9688" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/38 [======>.......................] - ETA: 5:55 - loss: 0.1111 - accuracy: 0.9722" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/38 [======>.......................] - ETA: 5:43 - loss: 0.1072 - accuracy: 0.9750" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "11/38 [=======>......................] - ETA: 5:31 - loss: 0.1037 - accuracy: 0.9773" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "12/38 [========>.....................] 
- ETA: 5:19 - loss: 0.1087 - accuracy: 0.9792" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "13/38 [=========>....................] - ETA: 5:07 - loss: 0.1055 - accuracy: 0.9808" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "14/38 [==========>...................] - ETA: 4:54 - loss: 0.1035 - accuracy: 0.9821" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "15/38 [==========>...................] - ETA: 4:42 - loss: 0.1061 - accuracy: 0.9833" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "16/38 [===========>..................] - ETA: 4:30 - loss: 0.1110 - accuracy: 0.9766" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "17/38 [============>.................] - ETA: 4:17 - loss: 0.1326 - accuracy: 0.9706" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "18/38 [=============>................] 
- ETA: 4:05 - loss: 0.1337 - accuracy: 0.9653" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "19/38 [==============>...............] - ETA: 3:53 - loss: 0.1407 - accuracy: 0.9605" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20/38 [==============>...............] - ETA: 3:40 - loss: 0.1366 - accuracy: 0.9625" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "21/38 [===============>..............] - ETA: 3:28 - loss: 0.1308 - accuracy: 0.9643" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "22/38 [================>.............] - ETA: 3:16 - loss: 0.1312 - accuracy: 0.9602" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23/38 [=================>............] - ETA: 3:04 - loss: 0.1285 - accuracy: 0.9620" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "24/38 [=================>............] 
- ETA: 2:51 - loss: 0.1241 - accuracy: 0.9635" ] }, { "name": "stdout", "output_type": "stream", "text": [ "38/38 [==============================] - 764s 20s/step - loss: 0.1018 - accuracy: 0.9733 - val_loss: 0.0713 - val_accuracy: 0.9850\n" ] } ], "source": [ "results = model.fit(train_ds,\n", " validation_data=test_ds,\n", " epochs=num_epochs,\n", " validation_freq=1,\n", " verbose=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "KkLl2zF8G9W0" }, "source": [ "## Evaluate the model\n", "\n", "The model achieved high accuracy on the training dataset. Next, use Keras `Model.evaluate` to evaluate it on the test set." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T17:56:26.983712Z", "iopub.status.busy": "2023-11-07T17:56:26.983454Z", "iopub.status.idle": "2023-11-07T18:01:28.033252Z", "shell.execute_reply": "2023-11-07T18:01:28.032504Z" }, "id": "NqgbzOiKuxxT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/Unknown - 13s 13s/step - loss: 0.0174 - accuracy: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"25/25 [==============================] - 301s 12s/step - loss: 0.0820 - accuracy: 0.9800\n" ] }, { "data": { "text/plain": [ "{'loss': 0.08196067810058594, 'accuracy': 0.9800000190734863}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(test_ds, return_dict=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "OkFst2gsHBwD" }, "source": [ "To visualize model performance further, use a [confusion matrix](https://tensorflow.google.cn/api_docs/python/tf/math/confusion_matrix). A confusion matrix lets you assess the performance of a classification model beyond accuracy alone. To build the confusion matrix for this multi-class classification problem, get the actual values in the test set and the predicted values." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T18:01:28.037079Z", "iopub.status.busy": "2023-11-07T18:01:28.036459Z", "iopub.status.idle": "2023-11-07T18:01:28.041530Z", "shell.execute_reply": "2023-11-07T18:01:28.040792Z" }, "id": "hssSdW9XHF_j" }, "outputs": [], "source": [ "def get_actual_predicted_labels(dataset):\n", " \"\"\"\n", " Create a list of actual ground truth values and the predictions from the model.\n", "\n", " Args:\n", " dataset: An iterable data structure, such as a TensorFlow Dataset, with features and labels.\n", "\n", " Returns:\n", " Ground truth and predicted values for a particular dataset.\n", " \"\"\"\n", " actual = [labels for _, labels in dataset.unbatch()]\n", " predicted = model.predict(dataset)\n", "\n", " actual = tf.stack(actual, axis=0)\n", " predicted = tf.concat(predicted, axis=0)\n", " predicted = tf.argmax(predicted, axis=1)\n", "\n", " return actual, predicted" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T18:01:28.044979Z", "iopub.status.busy": "2023-11-07T18:01:28.044402Z", "iopub.status.idle": "2023-11-07T18:01:28.049296Z", "shell.execute_reply": "2023-11-07T18:01:28.048691Z" }, "id": "2TmTue6THGWO" }, "outputs": [], "source": [ "def plot_confusion_matrix(actual, predicted, labels, ds_type):\n", " cm = tf.math.confusion_matrix(actual, predicted)\n", " ax = sns.heatmap(cm, annot=True, fmt='g')\n", " sns.set(rc={'figure.figsize':(12, 12)})\n", " sns.set(font_scale=1.4)\n", " ax.set_title('Confusion matrix of action recognition for ' + ds_type)\n", " ax.set_xlabel('Predicted Action')\n", " ax.set_ylabel('Actual Action')\n", " plt.xticks(rotation=90)\n", " plt.yticks(rotation=0)\n", " ax.xaxis.set_ticklabels(labels)\n", " ax.yaxis.set_ticklabels(labels)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T18:01:28.052527Z", "iopub.status.busy": "2023-11-07T18:01:28.051929Z", "iopub.status.idle": "2023-11-07T18:01:28.055688Z", "shell.execute_reply": "2023-11-07T18:01:28.055069Z" }, "id": "4RK1A1C1HH6V" }, "outputs": [], "source": [ "fg = FrameGenerator(subset_paths['train'], num_frames, training = True)\n", "label_names = list(fg.class_ids_for_name.keys())" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T18:01:28.058905Z", "iopub.status.busy": "2023-11-07T18:01:28.058311Z", "iopub.status.idle": "2023-11-07T18:06:48.880105Z", "shell.execute_reply": "2023-11-07T18:06:48.879332Z" }, "id": "r4AFi2e5HKEO" }, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "25/25 [==============================] - 305s 12s/step\n" ] } ], "source": [ "actual, predicted = get_actual_predicted_labels(test_ds)\n", "plot_confusion_matrix(actual, predicted, label_names, 'test')" ] }, { "cell_type": "markdown", "metadata": { "id": "ddQG9sYxa1Ib" }, "source": [ "## Next steps\n", "\n", 
"Now that you have some familiarity with the MoViNet model and how to leverage various TensorFlow APIs (for example, for transfer learning), try using the code in this tutorial with your own dataset. The data does not have to be limited to video data. Volumetric data, such as MRI scans, can also be used with 3D CNNs. The NUSDAT and IMH datasets mentioned in [Brain MRI-based 3D Convolutional Neural Networks for Classification of Schizophrenia and Controls](https://arxiv.org/pdf/2003.08818.pdf) could be two such sources of MRI data.\n", "\n", "In particular, the `FrameGenerator` class used in this tutorial and the other video data and classification tutorials can help you load data into your models.\n", "\n", "To learn more about working with video data in TensorFlow, check out the following tutorials:\n", "\n", "- [Load video data](https://tensorflow.google.cn/tutorials/load_data/video)\n", "- [Build a 3D CNN model for video classification](https://tensorflow.google.cn/tutorials/video/video_classification)\n", "- [MoViNet for streaming action recognition](https://tensorflow.google.cn/hub/tutorials/movinet)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "transfer_learning_with_movinet.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }