{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "6XvCUmCEd4Dm" }, "source": [ "# TensorFlow Datasets\n", "\n", "TFDS 提供了一组现成的数据集,适合与 TensorFlow、Jax 和其他机器学习框架配合使用。\n", "\n", "它可以确定地处理下载和准备数据并构造 `tf.data.Dataset`(或 `np.array`)。\n", "\n", "注:不要将 [TFDS](https://tensorflow.google.cn/datasets)(此库)与 `tf.data`(用于构建高效数据流水线的 TensorFlow API)混淆。TFDS 是 `tf.data` 的高级封装容器。如果您不熟悉此 API,建议您先阅读[官方 tf.data 指南](https://tensorflow.google.cn/guide/data)。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "J8y9ZkLXmAZc" }, "source": [ "版权所有 2018 TensorFlow 数据集作者,以 Apache License, Version 2.0 授权" ] }, { "cell_type": "markdown", "metadata": { "id": "OGw9EgE0tC0C" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "_7hshda5eaGL" }, "source": [ "## 安装\n", "\n", "TFDS 存在于两个软件包中:\n", "\n", "- `pip install tensorflow-datasets`:稳定版,数月发行一次。\n", "- `pip install tfds-nightly`:每天发行,包含最近版本的数据集。\n", "\n", "此 colab 使用 `tfds-nightly`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "both", "id": "boeZp0sYbO41" }, "outputs": [], "source": [ "!pip install -q tfds-nightly tensorflow matplotlib" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TTBSvHcSLBzc" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import tensorflow as tf\n", "\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "VZZyuO13fPvk" }, "source": [ "## 查找可用的数据集\n", "\n", "所有数据集构建工具都是 `tfds.core.DatasetBuilder` 的子类。要获取可用构建工具的列表,请使用 `tfds.list_builders()` 或查看我们的[目录](https://tensorflow.google.cn/datasets/catalog/overview)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FAvbSVzjLCIb" }, "outputs": [], "source": [ "tfds.list_builders()" ] }, { "cell_type": "markdown", "metadata": { "id": "VjI6VgOBf0v0" }, "source": [ "## 加载数据集\n", "\n", "### tfds.load\n", "\n", "加载数据集最简单的方法是 `tfds.load`。它将执行以下操作:\n", "\n", "1. 下载数据并将其存储为 [`tfrecord`](https://tensorflow.google.cn/tutorials/load_data/tfrecord) 文件。\n", "2. 加载 `tfrecord` 并创建 `tf.data.Dataset`。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dCou80mnLLPV" }, "outputs": [], "source": [ "ds = tfds.load('mnist', split='train', shuffle_files=True)\n", "assert isinstance(ds, tf.data.Dataset)\n", "print(ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "byOXYCEJS7S6" }, "source": [ "一些常见的参数:\n", "\n", "- `split=`:要读取的拆分(例如 `'train'`、`['train', 'test']`、`'train[80%:]'`…)。请参阅我们的[拆分 API 指南](https://tensorflow.google.cn/datasets/splits)。\n", "- `shuffle_files=`:控制是否重排每个周期间的文件顺序(TFDS 以多个较小的文件存储大数据集)\n", "- `data_dir=`:数据集存储的位置(默认为 `~/tensorflow_datasets/`)\n", "- `with_info=True`:返回包含数据集元数据的 `tfds.core.DatasetInfo`\n", "- `download=False`:停用下载\n" ] }, { "cell_type": "markdown", "metadata": { "id": "qeNmFx_1RXCb" }, "source": [ "### tfds.builder\n", "\n", "`tfds.load` 是 `tfds.core.DatasetBuilder` 的瘦封装容器。您可以使用 `tfds.core.DatasetBuilder` API 获得相同的输出:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2zN_jQ2ER40W" }, "outputs": [], "source": [ "builder = tfds.builder('mnist')\n", "# 1. Create the tfrecord files (no-op if already exists)\n", "builder.download_and_prepare()\n", "# 2. Load the `tf.data.Dataset`\n", "ds = builder.as_dataset(split='train', shuffle_files=True)\n", "print(ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "IwrjccfjoQCD" }, "source": [ "### `tfds build` CLI\n", "\n", "如果您希望生成一个特定的数据集,可以使用 [`tfds` 命令行](https://tensorflow.google.cn/datasets/cli)。例如:\n", "\n", "```sh\n", "tfds build mnist\n", "```\n", "\n", "请参阅[文档](https://tensorflow.google.cn/datasets/cli)查看可用标志。" ] }, { "cell_type": "markdown", "metadata": { "id": "aW132I-rbJXE" }, "source": [ "## 迭代数据集\n", "\n", "### 作为字典\n", "\n", "默认情况下,`tf.data.Dataset` 对象包含 `tf.Tensor` 的 `dict`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JAGjXdk_bIYQ" }, "outputs": [], "source": [ "ds = tfds.load('mnist', split='train')\n", "ds = ds.take(1) # Only take a single example\n", "\n", "for example in ds: # example is `{'image': tf.Tensor, 'label': tf.Tensor}`\n", " print(list(example.keys()))\n", " image = example[\"image\"]\n", " label = example[\"label\"]\n", " print(image.shape, label)" ] }, { "cell_type": "markdown", "metadata": { "id": "sIqX2bmhu-8d" }, "source": [ "要找出 `dict` 键名和结构,请查看[我们目录](https://tensorflow.google.cn/datasets/catalog/overview#all_datasets)中的数据集文档。例如:[mnist 文档](https://tensorflow.google.cn/datasets/catalog/mnist)。" ] }, { "cell_type": "markdown", "metadata": { "id": "umAtqBBqdkDG" }, "source": [ "### 作为元组(`as_supervised=True`)\n", "\n", "使用 `as_supervised=True`,您可以获取 `(features, label)` 元组作为替代的监督数据集。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nJ4O0xy3djfV" }, "outputs": [], "source": [ "ds = tfds.load('mnist', split='train', as_supervised=True)\n", "ds = ds.take(1)\n", "\n", "for image, label in ds: # example is (image, label)\n", " print(image.shape, label)" ] }, { "cell_type": "markdown", "metadata": { "id": "u9palgyHfEwQ" }, "source": [ "### 作为 numpy(`tfds.as_numpy`)\n", "\n", "使用 `tfds.as_numpy` 进行以下转换:\n", "\n", "- `tf.Tensor` -> `np.array`\n", "- `tf.data.Dataset` -> `Iterator[Tree[np.array]]`(`Tree` 可能是任意嵌套的 `Dict`、`Tuple`)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tzQTCUkAfe9R" }, "outputs": [], "source": [ "ds = tfds.load('mnist', split='train', as_supervised=True)\n", "ds = ds.take(1)\n", "\n", "for image, label in tfds.as_numpy(ds):\n", " print(type(image), type(label), label)" ] }, { "cell_type": "markdown", "metadata": { "id": "XaRN-LdXUkl_" }, "source": [ "### 作为 batched tf.Tensor(`batch_size=-1`)\n", "\n", "使用 `batch_size=-1`,您可以在单个批次中加载完整的数据集。\n", "\n", "这可与 `as_supervised=True` 和 `tfds.as_numpy` 结合使用以获取 `(np.array, np.array)` 形式的数据:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Gg8BNsv-UzFl" }, "outputs": [], "source": [ "image, label = tfds.as_numpy(tfds.load(\n", " 'mnist',\n", " split='test',\n", " batch_size=-1,\n", " as_supervised=True,\n", "))\n", "\n", "print(type(image), image.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "kRJrB3L6wgKI" }, "source": [ "请注意,您的数据集可以放入内存,并且所有样本都具有相同的形状。" ] }, { "cell_type": "markdown", "metadata": { "id": "heaKNg7-X4jN" }, "source": [ "## 对您的数据集进行基准分析\n", "\n", "对数据集进行基准分析是对任何可迭代对象(例如 `tf.data.Dataset`、`tfds.as_numpy`…)的简单 `tfds.benchmark` 调用。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZyQzZ98bX3dM" }, "outputs": [], "source": [ "ds = tfds.load('mnist', split='train')\n", "ds = ds.batch(32).prefetch(1)\n", "\n", "tfds.benchmark(ds, batch_size=32)\n", "tfds.benchmark(ds, batch_size=32) # Second epoch much faster due to auto-caching" ] }, { "cell_type": "markdown", "metadata": { "id": "MT0yEX_4kYnV" }, "source": [ "- 不要忘记使用 `batch_size=` kwarg 对每个批次大小的结果进行归一化。\n", "- 总之,第一个预热批次与其他预热批次分开以捕获 `tf.data.Dataset` 额外的设置时间(例如缓冲区初始化…)。\n", "- 请注意,由于 [TFDS 自动缓存功能](https://tensorflow.google.cn/datasets/performances#auto-caching),第二次迭代的速度要快得多。\n", "- `tfds.benchmark` 会返回 `tfds.core.BenchmarkResult` ,可以检查它以进行进一步分析。" ] }, { "cell_type": "markdown", "metadata": { "id": "o-cuwvVbeb43" }, "source": [ "### 构建端到端流水线\n", "\n", "要想深入一点,您可以查看:\n", "\n", "- 我们的[端到端 Keras 示例](https://tensorflow.google.cn/datasets/keras_example)来了解完整的训练流水线(包括批处理、重排…)。\n", "- 有助于提高流水线速度的[性能指南](https://tensorflow.google.cn/datasets/performances)(提示:使用 `tfds.benchmark(ds)` 对数据集进行基准分析)。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "gTRTEQqscxAE" }, "source": [ "## 呈现\n", "\n", "### tfds.as_dataframe\n", "\n", "使用 `tfds.as_dataframe`,可以将 `tf.data.Dataset` 对象转换为 [`pandas.DataFrame`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) 以在 [Colab](https://colab.research.google.com) 上呈现。\n", "\n", "- 添加 `tfds.core.DatasetInfo` 作为 `tfds.as_dataframe` 的第二个参数以呈现图像、音频、文本、视频…\n", "- 使用 `ds.take(x)` 仅显示前 `x` 个样本。`pandas.DataFrame` 将在内存中加载完整数据集,并且显示开销可能非常高。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FKouwN_yVSGQ" }, "outputs": [], "source": [ "ds, info = tfds.load('mnist', split='train', with_info=True)\n", "\n", "tfds.as_dataframe(ds.take(4), info)" ] }, { "cell_type": "markdown", "metadata": { "id": "b-eDO_EXVGWC" }, "source": [ "### tfds.show_examples\n", "\n", "`tfds.show_examples` 返回 `matplotlib.figure.Figure`(现在只支持图像数据集):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DpE2FD56cSQR" }, "outputs": [], "source": [ "ds, info = tfds.load('mnist', split='train', with_info=True)\n", "\n", "fig = tfds.show_examples(ds, info)" ] }, { "cell_type": "markdown", "metadata": { "id": "Y0iVVStvk0oI" }, "source": [ "## 访问数据集元数据\n", "\n", "所有构建工具都包括一个包含数据集元数据的 `tfds.core.DatasetInfo` 对象。\n", "\n", "可以通过以下方式访问:\n", "\n", "- `tfds.load` API:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UgLgtcd1ljzt" }, "outputs": [], "source": [ "ds, info = tfds.load('mnist', with_info=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "XodyqNXrlxTM" }, "source": [ "- `tfds.core.DatasetBuilder` API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nmq97QkilxeL" }, "outputs": [], "source": [ "builder = tfds.builder('mnist')\n", "info = builder.info" ] }, { "cell_type": "markdown", "metadata": { "id": "zMGOk_ZsmPeu" }, "source": [ "数据集信息包含有关数据集的附加信息(版本、引用、首页、描述…)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O-wLIKD-mZQT" }, "outputs": [], "source": [ "print(info)" ] }, { "cell_type": "markdown", "metadata": { "id": "1zvAfRtwnAFk" }, "source": [ "### 特征元数据(标签名称、图像形状…)\n", "\n", "访问 `tfds.features.FeatureDict`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RcyZXncqoFab" }, "outputs": [], "source": [ "info.features" ] }, { "cell_type": "markdown", "metadata": { "id": "KAm9AV7loyw5" }, "source": [ "类、标签名的数量:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HhfzBH6qowpz" }, "outputs": [], "source": [ "print(info.features[\"label\"].num_classes)\n", "print(info.features[\"label\"].names)\n", "print(info.features[\"label\"].int2str(7)) # Human readable version (8 -> 'cat')\n", "print(info.features[\"label\"].str2int('7'))" ] }, { "cell_type": "markdown", "metadata": { "id": "g5eWtk9ro_AK" }, "source": [ "形状、数据类型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SergV_wQowLY" }, "outputs": [], "source": [ "print(info.features.shape)\n", "print(info.features.dtype)\n", "print(info.features['image'].shape)\n", "print(info.features['image'].dtype)" ] }, { "cell_type": "markdown", "metadata": { "id": "thMOZ4IKm55N" }, "source": [ "### 拆分元数据(例如拆分名称、样本数量…)\n", "\n", "访问 `tfds.core.SplitDict`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FBbfwA8Sp4ax" }, "outputs": [], "source": [ "print(info.splits)" ] }, { "cell_type": "markdown", "metadata": { "id": "EVw1UVYa2HgN" }, "source": [ "可用拆分:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fRBieOOquDzX" }, "outputs": [], "source": [ "print(list(info.splits.keys()))" ] }, { "cell_type": "markdown", "metadata": { "id": "iHW0VfA0t3dO" }, "source": [ "获取有关个别拆分的信息:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-h_OSpRsqKpP" }, "outputs": [], "source": [ "print(info.splits['train'].num_examples)\n", "print(info.splits['train'].filenames)\n", "print(info.splits['train'].num_shards)" ] }, { "cell_type": "markdown", "metadata": { "id": "fWhSkHFNuLwW" }, "source": [ "它也适用于 subsplit API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HO5irBZ3uIzQ" }, "outputs": [], "source": [ "print(info.splits['train[15%:75%]'].num_examples)\n", "print(info.splits['train[15%:75%]'].file_instructions)" ] }, { "cell_type": "markdown", "metadata": { "id": "YZp2XJwQQrI0" }, "source": [ "## 问题排查\n", "\n", "### 手动下载(如果下载失败)\n", "\n", "如果由于某种原因下载失败(例如离线…),那么您始终可以自己手动下载数据并将其放置在 `manual_dir` 中(默认为 `~/tensorflow_datasets/download/manual/`)。\n", "\n", "要找到下载网址,请查看:\n", "\n", "- 对于新数据集(作为文件夹实现):[`tensorflow_datasets/`](https://github.com/tensorflow/datasets/tree/master/tensorflow_datasets/)`//checksums.tsv`。例如:[`tensorflow_datasets/datasets/bool_q/checksums.tsv`](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/bool_q/checksums.tsv)。\n", "\n", " 您可以在[我们的目录](https://tensorflow.google.cn/datasets/catalog/overview)中找到数据集的源位置。\n", "\n", "- 对于旧数据集:[`tensorflow_datasets/url_checksums/.txt`](https://github.com/tensorflow/datasets/tree/master/tensorflow_datasets/url_checksums)\n", "\n", "### 修正 `NonMatchingChecksumError`\n", "\n", "TFDS 通过验证下载网址的校验和来确保确定性。如果引发 `NonMatchingChecksumError`,则可能表示:\n", "\n", "- 网站可能宕机(如 `503 status code`)。请检查网址。\n", "- 对于 Google 云端硬盘网址,请稍后再试。当很多人访问同一网址时云端硬盘有时拒绝下载。请参阅[错误](https://github.com/tensorflow/datasets/issues/1482)\n", "- 原始数据集文件可能已更新。在这种情况下,应当更新 TFDS 数据集构建工具。请打开一个新的 Github 议题或拉取请求:\n", " - 使用 `tfds build --register_checksums` 注册新的校验和\n", " - 逐步更新数据集生成代码。\n", " - 更新数据集 `VERSION`\n", " - 更新数据集 `RELEASE_NOTES`:是什么导致校验和发生变化?一些样本发生了改变吗?\n", " - 确保数据集仍能够构建。\n", " - 向我们发送拉取请求\n", "\n", "注:您也可以检查 `~/tensorflow_datasets/download/` 中的下载文件。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GmeeOokMODg2" }, "source": [ "## 引用\n", "\n", "如果您在论文中使用 `tensorflow-datasets`,除了特定于所用数据集(可以在[数据集目录](https://tensorflow.google.cn/datasets/catalog/overview)中找到)的任何引用之外,请包含以下引用。\n", "\n", "```\n", "@misc{TFDS,\n", " title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},\n", " howpublished = {\\url{https://tensorflow.google.cn/datasets}},\n", "}\n", "```" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "overview.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }