{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:27:23.769191Z", "iopub.status.busy": "2022-12-14T20:27:23.768763Z", "iopub.status.idle": "2022-12-14T20:27:23.773226Z", "shell.execute_reply": "2022-12-14T20:27:23.772579Z" }, "id": "8JSGdaDHc_f4" }, "outputs": [], "source": [ "# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "metadata": { "id": "z2_BHI6XdJ30" }, "source": [ "# Text-to-Video retrieval with S3D MIL-NCE" ] }, { "cell_type": "markdown", "metadata": { "id": "Rm0K9ZTgfISB" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 | 查看 TF Hub 模型\n", "
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:27:23.776967Z", "iopub.status.busy": "2022-12-14T20:27:23.776386Z", "iopub.status.idle": "2022-12-14T20:27:29.207205Z", "shell.execute_reply": "2022-12-14T20:27:29.206447Z" }, "id": "bC_xJPpQd-LO" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 20:27:28.148852: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 20:27:28.148944: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 20:27:28.148953: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "!pip install -q opencv-python\n", "\n", "import os\n", "\n", "import tensorflow.compat.v2 as tf\n", "import tensorflow_hub as hub\n", "\n", "import numpy as np\n", "import cv2\n", "from IPython import display\n", "import math" ] }, { "cell_type": "markdown", "metadata": { "id": "ZxwaK-jf7qkW" }, "source": [ "## 导入 TF-Hub 模型\n", "\n", "本教程演示了如何使用 TensorFlow Hub 中的 [S3D MIL-NCE 模型](https://tfhub.dev/deepmind/mil-nce/s3d/1)执行**文本到视频检索**,以便找到与给定文本查询最相似的视频。\n", "\n", "该模型有 2 个签名,一个用于生成*视频嵌入向量*,另一个用于生成*文本嵌入向量*,我们利用这些嵌入向量来查找嵌入向量空间中的最近邻。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:27:29.211790Z", "iopub.status.busy": "2022-12-14T20:27:29.211044Z", "iopub.status.idle": "2022-12-14T20:27:38.309857Z", "shell.execute_reply": "2022-12-14T20:27:38.309093Z" }, "id": "nwv4ZQ4qmak5" }, "outputs": [], "source": [ "# Load the model once from TF-Hub.\n", "hub_handle = 'https://tfhub.dev/deepmind/mil-nce/s3d/1'\n", "hub_model = hub.load(hub_handle)\n", "\n", "def generate_embeddings(model, input_frames, input_words):\n", " \"\"\"Generate embeddings from the model from video frames and input words.\"\"\"\n", " # Input_frames must be normalized in [0, 1] and of the shape Batch x T x H x W x 3\n", " vision_output = model.signatures['video'](tf.constant(tf.cast(input_frames, dtype=tf.float32)))\n", " text_output = model.signatures['text'](tf.constant(input_words))\n", " return vision_output['video_embedding'], text_output['text_embedding']" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:27:38.314255Z", "iopub.status.busy": "2022-12-14T20:27:38.313992Z", "iopub.status.idle": "2022-12-14T20:27:38.324819Z", "shell.execute_reply": "2022-12-14T20:27:38.324143Z" }, "id": "EOZzu9ddekEj" }, "outputs": [], 
"source": [ "# @title Define video loading and visualization functions { display-mode: \"form\" }\n", "\n", "# Utilities to open video files using CV2\n", "def crop_center_square(frame):\n", " y, x = frame.shape[0:2]\n", " min_dim = min(y, x)\n", " start_x = (x // 2) - (min_dim // 2)\n", " start_y = (y // 2) - (min_dim // 2)\n", " return frame[start_y:start_y+min_dim,start_x:start_x+min_dim]\n", "\n", "\n", "def load_video(video_url, max_frames=32, resize=(224, 224)):\n", " path = tf.keras.utils.get_file(os.path.basename(video_url)[-128:], video_url)\n", " cap = cv2.VideoCapture(path)\n", " frames = []\n", " try:\n", " while True:\n", " ret, frame = cap.read()\n", " if not ret:\n", " break\n", " frame = crop_center_square(frame)\n", " frame = cv2.resize(frame, resize)\n", " frame = frame[:, :, [2, 1, 0]]\n", " frames.append(frame)\n", "\n", " if len(frames) == max_frames:\n", " break\n", " finally:\n", " cap.release()\n", " frames = np.array(frames)\n", " if len(frames) < max_frames:\n", " n_repeat = int(math.ceil(max_frames / float(len(frames))))\n", " frames = frames.repeat(n_repeat, axis=0)\n", " frames = frames[:max_frames]\n", " return frames / 255.0\n", "\n", "def display_video(urls):\n", " html = ''\n", " html += ''\n", " for url in urls:\n", " html += ''\n", " html += '
Video 1Video 2Video 3
'\n", " html += ''.format(url)\n", " html += '
'\n", " return display.HTML(html)\n", "\n", "def display_query_and_results_video(query, urls, scores):\n", " \"\"\"Display a text query and the top result videos and scores.\"\"\"\n", " sorted_ix = np.argsort(-scores)\n", " html = ''\n", " html += '

Input query: {}

'.format(query)\n", " html += 'Results:
'\n", " html += ''\n", " html += ''.format(scores[sorted_ix[0]])\n", " html += ''.format(scores[sorted_ix[1]])\n", " html += ''.format(scores[sorted_ix[2]])\n", " for i, idx in enumerate(sorted_ix):\n", " url = urls[sorted_ix[i]];\n", " html += ''\n", " html += '
Rank #1, Score:{:.2f}Rank #2, Score:{:.2f}Rank #3, Score:{:.2f}
'\n", " html += ''.format(url)\n", " html += '
'\n", " return html\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:27:38.328365Z", "iopub.status.busy": "2022-12-14T20:27:38.327808Z", "iopub.status.idle": "2022-12-14T20:27:39.184150Z", "shell.execute_reply": "2022-12-14T20:27:39.183421Z" }, "id": "Ime5V4kDewh8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://upload.wikimedia.org/wikipedia/commons/b/b0/YosriAirTerjun.gif\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 8192/1207385 [..............................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 65536/1207385 [>.............................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 327680/1207385 [=======>......................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1207385/1207385 [==============================] - 0s 0us/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://upload.wikimedia.org/wikipedia/commons/e/e6/Guitar_solo_gif.gif\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 8192/1021622 [..............................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 65536/1021622 [>.............................] 
- ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 344064/1021622 [=========>....................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1021622/1021622 [==============================] - 0s 0us/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://upload.wikimedia.org/wikipedia/commons/3/30/2009-08-16-autodrift-by-RalfR-gif-by-wau.gif\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 8192/1506603 [..............................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 65536/1506603 [>.............................] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 327680/1506603 [=====>........................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1376256/1506603 [==========================>...] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1506603/1506603 [==============================] - 0s 0us/step\n" ] }, { "data": { "text/html": [ "
Video 1Video 2Video 3
" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @title Load example videos and define text queries { display-mode: \"form\" }\n", "\n", "video_1_url = 'https://upload.wikimedia.org/wikipedia/commons/b/b0/YosriAirTerjun.gif' # @param {type:\"string\"}\n", "video_2_url = 'https://upload.wikimedia.org/wikipedia/commons/e/e6/Guitar_solo_gif.gif' # @param {type:\"string\"}\n", "video_3_url = 'https://upload.wikimedia.org/wikipedia/commons/3/30/2009-08-16-autodrift-by-RalfR-gif-by-wau.gif' # @param {type:\"string\"}\n", "\n", "video_1 = load_video(video_1_url)\n", "video_2 = load_video(video_2_url)\n", "video_3 = load_video(video_3_url)\n", "all_videos = [video_1, video_2, video_3]\n", "\n", "query_1_video = 'waterfall' # @param {type:\"string\"}\n", "query_2_video = 'playing guitar' # @param {type:\"string\"}\n", "query_3_video = 'car drifting' # @param {type:\"string\"}\n", "all_queries_video = [query_1_video, query_2_video, query_3_video]\n", "all_videos_urls = [video_1_url, video_2_url, video_3_url]\n", "display_video(all_videos_urls)" ] }, { "cell_type": "markdown", "metadata": { "id": "NCLKv_L_8Anc" }, "source": [ "## 演示文本到视频检索\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:27:39.187524Z", "iopub.status.busy": "2022-12-14T20:27:39.187284Z", "iopub.status.idle": "2022-12-14T20:27:42.287808Z", "shell.execute_reply": "2022-12-14T20:27:42.286966Z" }, "id": "9oX8ItFUjybi" }, "outputs": [], "source": [ "# Prepare video inputs.\n", "videos_np = np.stack(all_videos, axis=0)\n", "\n", "# Prepare text input.\n", "words_np = np.array(all_queries_video)\n", "\n", "# Generate the video and text embeddings.\n", "video_embd, text_embd = generate_embeddings(hub_model, videos_np, words_np)\n", "\n", "# Scores between video and text is computed by dot products.\n", "all_scores = np.dot(text_embd, tf.transpose(video_embd))" ] }, { 
"cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T20:27:42.292782Z", "iopub.status.busy": "2022-12-14T20:27:42.292479Z", "iopub.status.idle": "2022-12-14T20:27:42.298150Z", "shell.execute_reply": "2022-12-14T20:27:42.297488Z" }, "id": "d4AwYmODmE9Y" }, "outputs": [ { "data": { "text/html": [ "

Input query: waterfall

Results:
Rank #1, Score:4.71Rank #2, Score:-1.63Rank #3, Score:-4.17

Input query: playing guitar

Results:
Rank #1, Score:6.50Rank #2, Score:-1.79Rank #3, Score:-2.67

Input query: car drifting

Results:
Rank #1, Score:8.78Rank #2, Score:-1.07Rank #3, Score:-2.17

" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Display results.\n", "html = ''\n", "for i, words in enumerate(words_np):\n", " html += display_query_and_results_video(words, all_videos_urls, all_scores[i, :])\n", " html += '
'\n", "display.HTML(html)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "text_to_video_retrieval_with_s3d_milnce.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }