{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "g_nWetWWd_ns" }, "source": [ "##### Copyright 2023 The TensorFlow Datasets Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "2pHVBk_seED1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "M7vSdG6sAIQn" }, "source": [ "# 适用于 Jax 和 PyTorch 的 TFDS" ] }, { "cell_type": "markdown", "metadata": { "id": "fwc5GKHBASdc" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看\n", " 在 Google Colab 中运行\n", " 在 GitHub 上查看源代码\n", " 下载笔记本\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "9ee074e4" }, "source": [ "TFDS 一直都独立于框架。例如,您可以轻松地加载 [NumPy 格式](https://tensorflow.google.cn/datasets/api_docs/python/tfds/as_numpy)的数据集以在 Jax 和 PyTorch 中使用。\n", "\n", "TensorFlow 及其数据加载解决方案 ([`tf.data`](https://tensorflow.google.cn/guide/data)) 按照设计是我们 API 中的一等公民。\n", "\n", "我们扩展了 TFDS 以支持仅使用 NumPy 而无需 TensorFlow 的数据加载。这对于在 Jax 和 PyTorch 等机器学习框架中使用非常方便。事实上,对于后者的用户来说,TensorFlow:\n", "\n", "- 会保留 GPU/TPU 内存;\n", "- 会在 CI/CD 中增加构建时间;\n", "- 在运行时需要花费时间导入。\n", "\n", "TensorFlow 不再是读取数据集的依赖项。\n", "\n", "机器学习流水线需要一个数据加载器来加载样本,将其解码并呈现给模型。数据加载器使用“源/采样器/加载器”范式:\n", "\n", "```\n", " TFDS dataset ┌────────────────┐\n", " on disk │ │\n", " ┌──────────►│ Data │\n", "|..|... │ | │ source ├─┐\n", "├──┼────┴─────┤ │ │ │\n", "│12│image12 │ └────────────────┘ │ ┌────────────────┐\n", "├──┼──────────┤ │ │ │\n", "│13│image13 │ ├───►│ Data ├───► ML pipeline\n", "├──┼──────────┤ │ │ loader │\n", "│14│image14 │ ┌────────────────┐ │ │ │\n", "├──┼──────────┤ │ │ │ └────────────────┘\n", "|..|... | │ Index ├─┘\n", " │ sampler │\n", " │ │\n", " └────────────────┘\n", "```\n", "\n", "- 数据源负责实时访问和解码来自 TFDS 数据集的样本。\n", "- 索引采样器负责确定记录处理的顺序。在读取任何记录之前,实现全局转换(例如全局重排、分片、重复多个周期)非常重要。\n", "- 数据加载器通过利用数据源和索引采样器来编排加载。它可以实现性能优化(例如,预提取、多进程或多线程)。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "UaWdLA3fQDK2" }, "source": [ "## 速览\n", "\n", "`tfds.data_source` 是一个用于创建数据源的 API:\n", "\n", "1. 用于纯 Python 流水线的快速原型设计;\n", "2. 用于大规模管理数据密集型机器学习流水线。" ] }, { "cell_type": "markdown", "metadata": { "id": "aLho3l_Vd0a5" }, "source": [ "## 安装\n", "\n", "让我们安装并导入所需依赖项:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c4COEsqIdvYH" }, "outputs": [], "source": [ "!pip install array_record\n", "!pip install tfds-nightly\n", "\n", "import os\n", "os.environ.pop('TFDS_DATA_DIR', None)\n", "\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "CjEJeF1Id_JM" }, "source": [ "## 数据源\n", "\n", "数据源基本上是 Python 序列。因此,它们需要实现以下协议:\n", "\n", "```python\n", "class RandomAccessDataSource(Protocol):\n", " \"\"\"Interface for datasources where storage supports efficient random access.\"\"\"\n", "\n", " def __len__(self) -> int:\n", " \"\"\"Number of records in the dataset.\"\"\"\n", "\n", " def __getitem__(self, record_key: int) -> Sequence[Any]:\n", " \"\"\"Retrieves records for the given record_keys.\"\"\"\n", "```\n", "\n", "**警告**:该 API 仍在积极开发中。特别是,`__getitem__` 目前在输入中必须支持 `int` 和 `list[int]`。将来,按照[标准](https://docs.python.org/3/reference/datamodel.html#object.__getitem__),它可能仅支持 `int`。\n", "\n", "底层文件格式需要支持高效的随机访问。目前,TFDS 依赖于 [`array_record`](https://github.com/google/array_record)。\n", "\n", "[`array_record`](https://github.com/google/array_record) 是一种衍生自 [Riegeli](https://github.com/google/riegeli) 的新文件格式,实现了 IO 效率的新前沿。特别是,ArrayRecord 支持按记录索引并行读取、写入和随机访问。ArrayRecord 建立在 Riegeli 之上,并支持相同的压缩算法。\n", "\n", "[`fashion_mnist`](https://tensorflow.google.cn/datasets/catalog/fashion_mnist) 是一个常见的计算机视觉数据集。要使用 TFDS 检索基于 ArrayRecord 的数据源,只需使用以下命令:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9Tslzx0_eEWx" }, "outputs": [], "source": [ "ds = tfds.data_source('fashion_mnist')" ] }, { "cell_type": "markdown", "metadata": { "id": "AlaRrD_SeHLY" }, "source": [ "`tfds.data_source` 是一个方便的包装器。它等同于:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "duHDKzXReIKB" }, "outputs": [], "source": [ "builder = tfds.builder('fashion_mnist', file_format='array_record')\n", "builder.download_and_prepare()\n", "ds = builder.as_data_source()" ] }, { "cell_type": "markdown", "metadata": { "id": "rlyIsd0ueKjQ" }, "source": [ "这将输出一个数据源字典:\n", "\n", "```\n", "{\n", " 'train': DataSource(name=fashion_mnist, split='train', decoders=None),\n", " 'test': DataSource(name=fashion_mnist, split='test', decoders=None),\n", "}\n", "```\n", "\n", "一旦 `download_and_prepare` 运行并在您生成记录文件后,我们就不再需要 TensorFlow 了。一切都将在 Python/NumPy 中完成!\n", "\n", "让我们通过卸载 TensorFlow 并在另一个子进程中重新加载数据源对此进行检查:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mTfSzvaQkSd9" }, "outputs": [], "source": [ "!pip uninstall -y tensorflow" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3sT5AN7neNT9" }, "outputs": [], "source": [ "%%writefile no_tensorflow.py\n", "import os\n", "os.environ.pop('TFDS_DATA_DIR', None)\n", "\n", "import tensorflow_datasets as tfds\n", "\n", "try:\n", " import tensorflow as tf\n", "except ImportError:\n", " print('No TensorFlow found...')\n", "\n", "ds = tfds.data_source('fashion_mnist')\n", "print('...but the data source could still be loaded...')\n", "ds['train'][0]\n", "print('...and the records can be decoded.')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FxohFdb3kSxh" }, "outputs": [], "source": [ "!python no_tensorflow.py" ] }, { "cell_type": "markdown", "metadata": { "id": "1o8n-BhhePYY" }, "source": [ "在未来的版本中,我们还将使数据集准备不再依赖 TensorFlow。\n", "\n", "数据源的长度为:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qtfl17SQeQ7F" }, "outputs": [], "source": [ "len(ds['train'])" ] }, { "cell_type": "markdown", "metadata": { "id": "a-UFBu8leSMp" }, "source": [ "访问数据集的第一个元素:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tFvT2Sx2eToh" }, "outputs": [], "source": [ "%%timeit\n", "ds['train'][0]" ] }, { "cell_type": "markdown", "metadata": { "id": "VTgZskyZeU_D" }, "source": [ "…开销与访问任何其他元素一样低。下面是[随机访问](https://en.wikipedia.org/wiki/Random_access)的定义:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cPJFa6aIeWcY" }, "outputs": [], "source": [ "%%timeit\n", "ds['train'][1000]" ] }, { "cell_type": "markdown", "metadata": { "id": "fs3kafYheX6N" }, "source": [ "特征现在使用 NumPy DType(而不是 TensorFlow DType)。您可以使用以下命令检查特征:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q7x5AEEaeZja" }, "outputs": [], "source": [ "features = tfds.builder('fashion_mnist').info.features" ] }, { "cell_type": "markdown", "metadata": { "id": "VOnLqAZOeiBi" }, "source": [ "您可以在我们的[文档](https://tensorflow.google.cn/datasets/api_docs/python/tfds/features)中找到有关特征的更多信息。在这里,我们特别可以检索图像的形状和类别数量:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xk8Vc-y0edlb" }, "outputs": [], "source": [ "shape = features['image'].shape\n", "num_classes = features['label'].num_classes" ] }, { "cell_type": "markdown", "metadata": { "id": "eFh8pytVemsu" }, "source": [ "## 在纯 Python 中使用\n", "\n", "您可以通过迭代来使用 Python 中的数据源:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ULjO-JDVefNf" }, "outputs": [], "source": [ "for example in ds['train']:\n", " print(example)\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "gZRHZNOkenPb" }, "source": [ "如果您检查元素,还会注意到所有特征都已使用 NumPy 解码。在幕后,我们默认使用 [OpenCV](https://opencv.org),因为它很快。如果您没有安装 OpenCV,我们将默认使用 [Pillow](python-pillow.org) 来提供轻量级和快速的图像解码。\n", "\n", "```\n", "{\n", " 'image': array([[[0], [0], ..., [0]],\n", " [[0], [0], ..., [0]]], dtype=uint8),\n", " 'label': 2,\n", "}\n", "```\n", "\n", "**注**:目前,该功能仅适用于 `Tensor`、`Image` 和 `Scalar` 特征。`Audio` 和 `Video` 特征即将推出。敬请关注!" ] }, { "cell_type": "markdown", "metadata": { "id": "8kLyK5j1enhc" }, "source": [ "## 与 PyTorch 结合使用\n", "\n", "PyTorch 使用源/采样器/加载器范式。在 Torch 中,“数据源”称为“数据集”。[`torch.utils.data`](https://pytorch.org/docs/stable/data.html) 包含构建高效输入流水线所需的所有详细信息。\n", "\n", "TFDS 数据源可以像常规的[映射样式数据集](https://pytorch.org/docs/stable/data.html#map-style-datasets)一样使用。\n", "\n", "首先,我们安装并导入Torch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3aKol1fDeyoK" }, "outputs": [], "source": [ "!pip install torch\n", "\n", "from tqdm import tqdm\n", "import torch" ] }, { "cell_type": "markdown", "metadata": { "id": "HKdJvYywe0YC" }, "source": [ "我们已经为训练和测试分别定义了数据源(分别是 `ds['train']` 和 `ds['test']`)。现在,我们可以定义采样器和加载器:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_4P2JIrie23f" }, "outputs": [], "source": [ "batch_size = 128\n", "train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)\n", "train_loader = torch.utils.data.DataLoader(\n", " ds['train'],\n", " sampler=train_sampler,\n", " batch_size=batch_size,\n", ")\n", "test_loader = torch.utils.data.DataLoader(\n", " ds['test'],\n", " sampler=None,\n", " batch_size=batch_size,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "EVhofOm4e53O" }, "source": [ "使用 PyTorch,我们在第一个样本上进行训练,并评估简单的逻辑回归:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HcAmvMa-e42p" }, "outputs": [], "source": [ "class LinearClassifier(torch.nn.Module):\n", " def __init__(self, shape, num_classes):\n", " super(LinearClassifier, self).__init__()\n", " height, width, channels = shape\n", " self.classifier = torch.nn.Linear(height * width * channels, num_classes)\n", "\n", " def forward(self, image):\n", " image = image.view(image.size()[0], -1).to(torch.float32)\n", " return self.classifier(image)\n", "\n", "\n", "model = LinearClassifier(shape, num_classes)\n", "optimizer = torch.optim.Adam(model.parameters())\n", "loss_function = torch.nn.CrossEntropyLoss()\n", "\n", "print('Training...')\n", "model.train()\n", "for example in tqdm(train_loader):\n", " image, label = example['image'], example['label']\n", " prediction = model(image)\n", " loss = loss_function(prediction, label)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", "print('Testing...')\n", "model.eval()\n", "num_examples = 0\n", "true_positives = 0\n", "for example in tqdm(test_loader):\n", " image, label = example['image'], example['label']\n", " prediction = model(image)\n", " num_examples += image.shape[0]\n", " predicted_label = prediction.argmax(dim=1)\n", " true_positives += (predicted_label == label).sum().item()\n", "print(f'\\nAccuracy: {true_positives/num_examples * 100:.2f}%')" ] }, { "cell_type": "markdown", "metadata": { "id": "ewKJQpwZe6Ik" }, "source": [ "## 即将推出:与 JAX 结合使用\n", "\n", "我们正在与 [Grain](https://github.com/google/grain) 密切合作。Grain 是适用于 Python 的开源、快速和确定性数据加载器。敬请关注!" ] }, { "cell_type": "markdown", "metadata": { "id": "JvLEtCWRvvy8" }, "source": [ "## 阅读更多内容\n", "\n", "有关详情,请参阅 [`tfds.data_source`](https://tensorflow.google.cn/datasets/api_docs/python/tfds/data_source) API 文档。" ] } ], "metadata": { "colab": { "name": "tfless_tfds.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }