{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T14:50:07.545842Z", "iopub.status.busy": "2024-03-09T14:50:07.545218Z", "iopub.status.idle": "2024-03-09T14:50:07.549506Z", "shell.execute_reply": "2024-03-09T14:50:07.548746Z" }, "id": "8JSGdaDHc_f4" }, "outputs": [], "source": [ "# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "metadata": { "id": "z2_BHI6XdJ30" }, "source": [ "# Text-to-Video retrieval with S3D MIL-NCE" ] }, { "cell_type": "markdown", "metadata": { "id": "Rm0K9ZTgfISB" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View on GitHub\n", " \n", " Download notebook\n", " \n", " See TF Hub model\n", "
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T14:50:07.553910Z", "iopub.status.busy": "2024-03-09T14:50:07.553267Z", "iopub.status.idle": "2024-03-09T14:50:14.337361Z", "shell.execute_reply": "2024-03-09T14:50:14.336483Z" }, "id": "bC_xJPpQd-LO" }, "outputs": [], "source": [ "!pip install -q opencv-python\n", "\n", "import os\n", "\n", "import tensorflow.compat.v2 as tf\n", "import tensorflow_hub as hub\n", "\n", "import numpy as np\n", "import cv2\n", "from IPython import display\n", "import math" ] }, { "cell_type": "markdown", "metadata": { "id": "ZxwaK-jf7qkW" }, "source": [ "## Import TF-Hub model\n", "\n", "This tutorial demonstrates how to use the [S3D MIL-NCE model](https://tfhub.dev/deepmind/mil-nce/s3d/1) from TensorFlow Hub to do **text-to-video retrieval** to find the most similar videos for a given text query.\n", "\n", "The model has 2 signatures, one for generating *video embeddings* and one for generating *text embeddings*. We will use these embeddings to find the nearest neighbors in the embedding space." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T14:50:14.342553Z", "iopub.status.busy": "2024-03-09T14:50:14.341680Z", "iopub.status.idle": "2024-03-09T14:50:21.618557Z", "shell.execute_reply": "2024-03-09T14:50:21.617598Z" }, "id": "nwv4ZQ4qmak5" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-03-09 14:50:17.063759: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] } ], "source": [ "# Load the model once from TF-Hub.\n", "hub_handle = 'https://tfhub.dev/deepmind/mil-nce/s3d/1'\n", "hub_model = hub.load(hub_handle)\n", "\n", "def generate_embeddings(model, input_frames, input_words):\n", " \"\"\"Generate embeddings from the model from video frames and input words.\"\"\"\n", " # Input_frames must be normalized in [0, 1] and of the shape Batch x T x H x W x 3\n", " vision_output = model.signatures['video'](tf.constant(tf.cast(input_frames, dtype=tf.float32)))\n", " text_output = model.signatures['text'](tf.constant(input_words))\n", " return vision_output['video_embedding'], text_output['text_embedding']" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T14:50:21.623077Z", "iopub.status.busy": "2024-03-09T14:50:21.622757Z", "iopub.status.idle": "2024-03-09T14:50:21.634256Z", "shell.execute_reply": "2024-03-09T14:50:21.633567Z" }, "id": "EOZzu9ddekEj" }, "outputs": [], "source": [ "# @title Define video loading and visualization functions { display-mode: \"form\" }\n", "\n", "# Utilities to open video files using CV2\n", "def crop_center_square(frame):\n", " y, x = frame.shape[0:2]\n", " min_dim = min(y, x)\n", " start_x = (x // 2) - (min_dim // 2)\n", " start_y = (y // 2) - (min_dim // 2)\n", " return frame[start_y:start_y+min_dim,start_x:start_x+min_dim]\n", "\n", "\n", "def load_video(video_url, max_frames=32, 
resize=(224, 224)):\n", " path = tf.keras.utils.get_file(os.path.basename(video_url)[-128:], video_url)\n", " cap = cv2.VideoCapture(path)\n", " frames = []\n", " try:\n", " while True:\n", " ret, frame = cap.read()\n", " if not ret:\n", " break\n", " frame = crop_center_square(frame)\n", " frame = cv2.resize(frame, resize)\n", " frame = frame[:, :, [2, 1, 0]]\n", " frames.append(frame)\n", "\n", " if len(frames) == max_frames:\n", " break\n", " finally:\n", " cap.release()\n", " frames = np.array(frames)\n", " if len(frames) < max_frames:\n", " n_repeat = int(math.ceil(max_frames / float(len(frames))))\n", " frames = frames.repeat(n_repeat, axis=0)\n", " frames = frames[:max_frames]\n", " return frames / 255.0\n", "\n", "def display_video(urls):\n", " html = ''\n", " html += ''\n", " for url in urls:\n", " html += ''\n", " html += '
Video 1Video 2Video 3
'\n", " html += ''.format(url)\n", " html += '
'\n", " return display.HTML(html)\n", "\n", "def display_query_and_results_video(query, urls, scores):\n", " \"\"\"Display a text query and the top result videos and scores.\"\"\"\n", " sorted_ix = np.argsort(-scores)\n", " html = ''\n", " html += '

Input query: {}

'.format(query)\n", " html += 'Results:
'\n", " html += ''\n", " html += ''.format(scores[sorted_ix[0]])\n", " html += ''.format(scores[sorted_ix[1]])\n", " html += ''.format(scores[sorted_ix[2]])\n", " for i, idx in enumerate(sorted_ix):\n", " url = urls[sorted_ix[i]];\n", " html += ''\n", " html += '
Rank #1, Score:{:.2f}Rank #2, Score:{:.2f}Rank #3, Score:{:.2f}
'\n", " html += ''.format(url)\n", " html += '
'\n", " return html\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T14:50:21.637718Z", "iopub.status.busy": "2024-03-09T14:50:21.637117Z", "iopub.status.idle": "2024-03-09T14:50:22.338995Z", "shell.execute_reply": "2024-03-09T14:50:22.338217Z" }, "id": "Ime5V4kDewh8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://upload.wikimedia.org/wikipedia/commons/b/b0/YosriAirTerjun.gif\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 0/1207385\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 0s/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m 376832/1207385\u001b[0m \u001b[32m━━━━━━\u001b[0m\u001b[37m━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 0us/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m1207385/1207385\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://upload.wikimedia.org/wikipedia/commons/e/e6/Guitar_solo_gif.gif\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 0/1021622\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 0s/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m 344064/1021622\u001b[0m \u001b[32m━━━━━━\u001b[0m\u001b[37m━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 0us/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m1021622/1021622\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://upload.wikimedia.org/wikipedia/commons/3/30/2009-08-16-autodrift-by-RalfR-gif-by-wau.gif\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 0/1506603\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 0s/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m 344064/1506603\u001b[0m \u001b[32m━━━━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 0us/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m1506603/1506603\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n" ] }, { "data": { "text/html": [ "
Video 1Video 2Video 3
" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @title Load example videos and define text queries { display-mode: \"form\" }\n", "\n", "video_1_url = 'https://upload.wikimedia.org/wikipedia/commons/b/b0/YosriAirTerjun.gif' # @param {type:\"string\"}\n", "video_2_url = 'https://upload.wikimedia.org/wikipedia/commons/e/e6/Guitar_solo_gif.gif' # @param {type:\"string\"}\n", "video_3_url = 'https://upload.wikimedia.org/wikipedia/commons/3/30/2009-08-16-autodrift-by-RalfR-gif-by-wau.gif' # @param {type:\"string\"}\n", "\n", "video_1 = load_video(video_1_url)\n", "video_2 = load_video(video_2_url)\n", "video_3 = load_video(video_3_url)\n", "all_videos = [video_1, video_2, video_3]\n", "\n", "query_1_video = 'waterfall' # @param {type:\"string\"}\n", "query_2_video = 'playing guitar' # @param {type:\"string\"}\n", "query_3_video = 'car drifting' # @param {type:\"string\"}\n", "all_queries_video = [query_1_video, query_2_video, query_3_video]\n", "all_videos_urls = [video_1_url, video_2_url, video_3_url]\n", "display_video(all_videos_urls)" ] }, { "cell_type": "markdown", "metadata": { "id": "NCLKv_L_8Anc" }, "source": [ "## Demonstrate text-to-video retrieval\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T14:50:22.342522Z", "iopub.status.busy": "2024-03-09T14:50:22.342234Z", "iopub.status.idle": "2024-03-09T14:50:23.808121Z", "shell.execute_reply": "2024-03-09T14:50:23.807284Z" }, "id": "9oX8ItFUjybi" }, "outputs": [], "source": [ "# Prepare video inputs.\n", "videos_np = np.stack(all_videos, axis=0)\n", "\n", "# Prepare text input.\n", "words_np = np.array(all_queries_video)\n", "\n", "# Generate the video and text embeddings.\n", "video_embd, text_embd = generate_embeddings(hub_model, videos_np, words_np)\n", "\n", "# Scores between video and text are computed by dot products.\n", "all_scores = np.dot(text_embd, 
tf.transpose(video_embd))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-03-09T14:50:23.812516Z", "iopub.status.busy": "2024-03-09T14:50:23.811826Z", "iopub.status.idle": "2024-03-09T14:50:23.817933Z", "shell.execute_reply": "2024-03-09T14:50:23.817175Z" }, "id": "d4AwYmODmE9Y" }, "outputs": [ { "data": { "text/html": [ "

Input query: waterfall

Results:
Rank #1, Score:4.71Rank #2, Score:-1.63Rank #3, Score:-4.17

Input query: playing guitar

Results:
Rank #1, Score:6.50Rank #2, Score:-1.79Rank #3, Score:-2.67

Input query: car drifting

Results:
Rank #1, Score:8.78Rank #2, Score:-1.07Rank #3, Score:-2.17

" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Display results.\n", "html = ''\n", "for i, words in enumerate(words_np):\n", " html += display_query_and_results_video(words, all_videos_urls, all_scores[i, :])\n", " html += '
'\n", "display.HTML(html)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "text_to_video_retrieval_with_s3d_milnce.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }