{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "toCy3v03Dwx7" }, "source": [ "##### Copyright 2021 The TensorFlow Hub Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:15.015968Z", "iopub.status.busy": "2024-01-11T19:50:15.015354Z", "iopub.status.idle": "2024-01-11T19:50:15.019245Z", "shell.execute_reply": "2024-01-11T19:50:15.018648Z" }, "id": "QKe-ubNcDvgv" }, "outputs": [], "source": [ "# Copyright 2021 The TensorFlow Hub Authors. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "metadata": { "id": "KqtQzBCpIJ7Y" }, "source": [ "# MoveNet: Ultra Fast and Accurate Pose Detection Model" ] }, { "cell_type": "markdown", "metadata": { "id": "MCmFOosnSkCd" }, "source": [ "View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook | See TF Hub models" ] },
{ "cell_type": "markdown", "metadata": { "id": "6x99e0aEY_d6" }, "source": [ "**[MoveNet](https://t.co/QpfnVL0YYI?amp=1)** is an ultra fast and accurate model that detects 17 keypoints of a body. It is offered on [TF Hub](https://tfhub.dev/s?q=movenet) in two variants, known as Lightning and Thunder. Lightning is intended for latency-critical applications, while Thunder is intended for applications that require high accuracy. Both models run faster than real time (30+ FPS) on most modern desktops, laptops, and phones, which is crucial for live fitness, health, and wellness applications.\n", "\n", "*Images downloaded from Pexels (https://www.pexels.com/)*\n", "\n", "This Colab walks you through how to load MoveNet and run inference on an input image and a video.\n", "\n", "Note: Check out the [live demo](https://storage.googleapis.com/tfjs-models/demos/pose-detection/index.html?model=movenet) to see how the model works!" ] }, { "cell_type": "markdown", "metadata": { "id": "10_zkgbZBkIE" }, "source": [ "# Pose Estimation with MoveNet" ] }, { "cell_type": "markdown", "metadata": { "id": "9u_VGR6_BmbZ" }, "source": [ "## Visualization Libraries & Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:15.022793Z", "iopub.status.busy": "2024-01-11T19:50:15.022564Z", "iopub.status.idle": "2024-01-11T19:50:25.874954Z", "shell.execute_reply": "2024-01-11T19:50:25.873857Z" }, "id": "TtcwSIcgbIVN" }, "outputs": [], "source": [ "!pip install -q imageio\n", "!pip install -q opencv-python\n", "!pip install -q git+https://github.com/tensorflow/docs" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:25.879374Z", "iopub.status.busy": "2024-01-11T19:50:25.879094Z", "iopub.status.idle": "2024-01-11T19:50:28.722503Z", "shell.execute_reply": "2024-01-11T19:50:28.721810Z" }, "id": "9BLeJv-pCCld" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-01-11 19:50:26.313001: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-01-11 
19:50:26.313045: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-01-11 19:50:26.314632: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "from tensorflow_docs.vis import embed\n", "import numpy as np\n", "import cv2\n", "\n", "# Import matplotlib libraries\n", "from matplotlib import pyplot as plt\n", "from matplotlib.collections import LineCollection\n", "import matplotlib.patches as patches\n", "\n", "# Some modules to display an animation using imageio.\n", "import imageio\n", "from IPython.display import HTML, display" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T19:50:28.726762Z", "iopub.status.busy": "2024-01-11T19:50:28.726381Z", "iopub.status.idle": "2024-01-11T19:50:28.745476Z", "shell.execute_reply": "2024-01-11T19:50:28.744815Z" }, "id": "bEJBMeRb3YUy" }, "outputs": [], "source": [ "#@title Helper functions for visualization\n", "\n", "# Dictionary that maps from joint names to keypoint indices.\n", "KEYPOINT_DICT = {\n", " 'nose': 0,\n", " 'left_eye': 1,\n", " 'right_eye': 2,\n", " 'left_ear': 3,\n", " 'right_ear': 4,\n", " 'left_shoulder': 5,\n", " 'right_shoulder': 6,\n", " 'left_elbow': 7,\n", " 'right_elbow': 8,\n", " 'left_wrist': 9,\n", " 'right_wrist': 10,\n", " 'left_hip': 11,\n", " 'right_hip': 12,\n", " 'left_knee': 13,\n", " 'right_knee': 14,\n", " 'left_ankle': 15,\n", " 'right_ankle': 16\n", "}\n", "\n", "# Maps bones to a matplotlib color name.\n", "KEYPOINT_EDGE_INDS_TO_COLOR = {\n", " (0, 1): 'm',\n", " (0, 2): 'c',\n", " (1, 3): 'm',\n", " (2, 4): 'c',\n", " (0, 5): 'm',\n", " (0, 6): 
'c',\n", " (5, 7): 'm',\n", " (7, 9): 'm',\n", " (6, 8): 'c',\n", " (8, 10): 'c',\n", " (5, 6): 'y',\n", " (5, 11): 'm',\n", " (6, 12): 'c',\n", " (11, 12): 'y',\n", " (11, 13): 'm',\n", " (13, 15): 'm',\n", " (12, 14): 'c',\n", " (14, 16): 'c'\n", "}\n", "\n", "def _keypoints_and_edges_for_display(keypoints_with_scores,\n", " height,\n", " width,\n", " keypoint_threshold=0.11):\n", " \"\"\"Returns high confidence keypoints and edges for visualization.\n", "\n", " Args:\n", " keypoints_with_scores: A numpy array with shape [1, 1, 17, 3] representing\n", " the keypoint coordinates and scores returned from the MoveNet model.\n", " height: height of the image in pixels.\n", " width: width of the image in pixels.\n", " keypoint_threshold: minimum confidence score for a keypoint to be\n", " visualized.\n", "\n", " Returns:\n", " A (keypoints_xy, edges_xy, edge_colors) containing:\n", " * the coordinates of all keypoints of all detected entities;\n", " * the coordinates of all skeleton edges of all detected entities;\n", " * the colors in which the edges should be plotted.\n", " \"\"\"\n", " keypoints_all = []\n", " keypoint_edges_all = []\n", " edge_colors = []\n", " num_instances, _, _, _ = keypoints_with_scores.shape\n", " for idx in range(num_instances):\n", " kpts_x = keypoints_with_scores[0, idx, :, 1]\n", " kpts_y = keypoints_with_scores[0, idx, :, 0]\n", " kpts_scores = keypoints_with_scores[0, idx, :, 2]\n", " kpts_absolute_xy = np.stack(\n", " [width * np.array(kpts_x), height * np.array(kpts_y)], axis=-1)\n", " kpts_above_thresh_absolute = kpts_absolute_xy[\n", " kpts_scores > keypoint_threshold, :]\n", " keypoints_all.append(kpts_above_thresh_absolute)\n", "\n", " for edge_pair, color in KEYPOINT_EDGE_INDS_TO_COLOR.items():\n", " if (kpts_scores[edge_pair[0]] > keypoint_threshold and\n", " kpts_scores[edge_pair[1]] > keypoint_threshold):\n", " x_start = kpts_absolute_xy[edge_pair[0], 0]\n", " y_start = kpts_absolute_xy[edge_pair[0], 1]\n", " x_end = 
kpts_absolute_xy[edge_pair[1], 0]\n", " y_end = kpts_absolute_xy[edge_pair[1], 1]\n", " line_seg = np.array([[x_start, y_start], [x_end, y_end]])\n", " keypoint_edges_all.append(line_seg)\n", " edge_colors.append(color)\n", " if keypoints_all:\n", " keypoints_xy = np.concatenate(keypoints_all, axis=0)\n", " else:\n", " keypoints_xy = np.zeros((0, 2))\n", "\n", " if keypoint_edges_all:\n", " edges_xy = np.stack(keypoint_edges_all, axis=0)\n", " else:\n", " edges_xy = np.zeros((0, 2, 2))\n", " return keypoints_xy, edges_xy, edge_colors\n", "\n", "\n", "def draw_prediction_on_image(\n", " image, keypoints_with_scores, crop_region=None, close_figure=False,\n", " output_image_height=None):\n", " \"\"\"Draws the keypoint predictions on image.\n", "\n", " Args:\n", " image: A numpy array with shape [height, width, channel] representing the\n", " pixel values of the input image.\n", " keypoints_with_scores: A numpy array with shape [1, 1, 17, 3] representing\n", " the keypoint coordinates and scores returned from the MoveNet model.\n", " crop_region: A dictionary that defines the coordinates of the bounding box\n", " of the crop region in normalized coordinates (see the init_crop_region\n", " function below for more detail). 
If provided, this function will also\n", " draw the bounding box on the image.\n", " output_image_height: An integer indicating the height of the output image.\n", " Note that the image aspect ratio will be the same as the input image.\n", "\n", " Returns:\n", " A numpy array with shape [out_height, out_width, channel] representing the\n", " image overlaid with keypoint predictions.\n", " \"\"\"\n", " height, width, channel = image.shape\n", " aspect_ratio = float(width) / height\n", " fig, ax = plt.subplots(figsize=(12 * aspect_ratio, 12))\n", " # To remove the huge white borders\n", " fig.tight_layout(pad=0)\n", " ax.margins(0)\n", " # Turn off tick labels\n", " ax.set_yticklabels([])\n", " ax.set_xticklabels([])\n", " plt.axis('off')\n", "\n", " im = ax.imshow(image)\n", " line_segments = LineCollection([], linewidths=(4), linestyle='solid')\n", " ax.add_collection(line_segments)\n", " scat = ax.scatter([], [], s=60, color='#FF1493', zorder=3)\n", "\n", " (keypoint_locs, keypoint_edges,\n", " edge_colors) = _keypoints_and_edges_for_display(\n", " keypoints_with_scores, height, width)\n", "\n", " if keypoint_edges.shape[0]:\n", " line_segments.set_segments(keypoint_edges)\n", " line_segments.set_color(edge_colors)\n", " if keypoint_locs.shape[0]:\n", " scat.set_offsets(keypoint_locs)\n", "\n", " if crop_region is not None:\n", " xmin = max(crop_region['x_min'] * width, 0.0)\n", " ymin = max(crop_region['y_min'] * height, 0.0)\n", " rec_width = min(crop_region['x_max'], 0.99) * width - xmin\n", " rec_height = min(crop_region['y_max'], 0.99) * height - ymin\n", " rect = patches.Rectangle(\n", " (xmin, ymin), rec_width, rec_height,\n", " linewidth=1, edgecolor='b', facecolor='none')\n", " ax.add_patch(rect)\n", "\n", " fig.canvas.draw()\n", " # tostring_rgb() is deprecated since Matplotlib 3.8, so read the RGBA\n", " # buffer and drop the alpha channel instead.\n", " image_from_plot = np.ascontiguousarray(\n", " np.asarray(fig.canvas.buffer_rgba())[:, :, :3])\n", " image_from_plot = image_from_plot.reshape(\n", " 
fig.canvas.get_width_height()[::-1] + (3,))\n", " plt.close(fig)\n", " if output_image_height is not None:\n", " output_image_width = int(output_image_height / height * width)\n", " image_from_plot = cv2.resize(\n", " image_from_plot, dsize=(output_image_width, output_image_height),\n", " interpolation=cv2.INTER_CUBIC)\n", " return image_from_plot\n", "\n", "def to_gif(images, duration):\n", " \"\"\"Converts image sequence (4D numpy array) to gif.\"\"\"\n", " imageio.mimsave('./animation.gif', images, duration=duration)\n", " return embed.embed_file('./animation.gif')\n", "\n", "def progress(value, max=100):\n", " \"\"\"Returns an HTML progress bar showing value out of max.\"\"\"\n", " return HTML(\"\"\"\n", " <progress value='{value}' max='{max}' style='width: 100%'>\n", " {value}\n", " </progress>\n", " \"\"\".format(value=value, max=max))" ] }, { "cell_type": "markdown", "metadata": { "id": "UvrN0iQiOxhR" }, "source": [ "## Load Model from TF Hub" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:28.748878Z", "iopub.status.busy": "2024-01-11T19:50:28.748652Z", "iopub.status.idle": "2024-01-11T19:50:42.703326Z", "shell.execute_reply": "2024-01-11T19:50:42.702559Z" }, "id": "zeGHgANcT7a1" }, "outputs": [], "source": [ "model_name = \"movenet_lightning\" #@param [\"movenet_lightning\", \"movenet_thunder\", \"movenet_lightning_f16.tflite\", \"movenet_thunder_f16.tflite\", \"movenet_lightning_int8.tflite\", \"movenet_thunder_int8.tflite\"]\n", "\n", "if \"tflite\" in model_name:\n", " if \"movenet_lightning_f16\" in model_name:\n", " !wget -q -O model.tflite https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/float16/4?lite-format=tflite\n", " input_size = 192\n", " elif \"movenet_thunder_f16\" in model_name:\n", " !wget -q -O model.tflite https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/float16/4?lite-format=tflite\n", " input_size = 256\n", " elif \"movenet_lightning_int8\" in model_name:\n", " !wget -q -O model.tflite 
https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite\n", " input_size = 192\n", " elif \"movenet_thunder_int8\" in model_name:\n", " !wget -q -O model.tflite https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/int8/4?lite-format=tflite\n", " input_size = 256\n", " else:\n", " raise ValueError(\"Unsupported model name: %s\" % model_name)\n", "\n", " # Initialize the TFLite interpreter\n", " interpreter = tf.lite.Interpreter(model_path=\"model.tflite\")\n", " interpreter.allocate_tensors()\n", "\n", " def movenet(input_image):\n", " \"\"\"Runs detection on an input image.\n", "\n", " Args:\n", " input_image: A [1, height, width, 3] tensor representing the input image\n", " pixels. Note that the height/width should already be resized to match the\n", " expected input resolution of the model before passing into this function.\n", "\n", " Returns:\n", " A [1, 1, 17, 3] float numpy array representing the predicted keypoint\n", " coordinates and scores.\n", " \"\"\"\n", " # TF Lite format expects tensor type of uint8.\n", " input_image = tf.cast(input_image, dtype=tf.uint8)\n", " input_details = interpreter.get_input_details()\n", " output_details = interpreter.get_output_details()\n", " interpreter.set_tensor(input_details[0]['index'], input_image.numpy())\n", " # Invoke inference.\n", " interpreter.invoke()\n", " # Get the model prediction.\n", " keypoints_with_scores = interpreter.get_tensor(output_details[0]['index'])\n", " return keypoints_with_scores\n", "\n", "else:\n", " if \"movenet_lightning\" in model_name:\n", " module = hub.load(\"https://tfhub.dev/google/movenet/singlepose/lightning/4\")\n", " input_size = 192\n", " elif \"movenet_thunder\" in model_name:\n", " module = hub.load(\"https://tfhub.dev/google/movenet/singlepose/thunder/4\")\n", " input_size = 256\n", " else:\n", " raise ValueError(\"Unsupported model name: %s\" % model_name)\n", "\n", " def movenet(input_image):\n", " \"\"\"Runs 
detection on an input image.\n", "\n", " Args:\n", " input_image: A [1, height, width, 3] tensor representing the input image\n", " pixels. Note that the height/width should already be resized to match the\n", " expected input resolution of the model before passing into this function.\n", "\n", " Returns:\n", " A [1, 1, 17, 3] float numpy array representing the predicted keypoint\n", " coordinates and scores.\n", " \"\"\"\n", " model = module.signatures['serving_default']\n", "\n", " # SavedModel format expects tensor type of int32.\n", " input_image = tf.cast(input_image, dtype=tf.int32)\n", " # Run model inference.\n", " outputs = model(input_image)\n", " # Output is a [1, 1, 17, 3] tensor.\n", " keypoints_with_scores = outputs['output_0'].numpy()\n", " return keypoints_with_scores" ] }, { "cell_type": "markdown", "metadata": { "id": "-h1qHYaqD9ap" }, "source": [ "## Single Image Example" ] }, { "cell_type": "markdown", "metadata": { "id": "ymTVR2I9x22I" }, "source": [ "This section demonstrates a minimal working example of running the model on a **single image** to predict the 17 human keypoints." ] }, { "cell_type": "markdown", "metadata": { "id": "5I3xBq80E3N_" }, "source": [ "### Load Input Image" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:42.707822Z", "iopub.status.busy": "2024-01-11T19:50:42.707521Z", "iopub.status.idle": "2024-01-11T19:50:43.037204Z", "shell.execute_reply": "2024-01-11T19:50:43.036021Z" }, "id": "GMO4B-wx5psP" }, "outputs": [], "source": [ "!curl -o input_image.jpeg https://images.pexels.com/photos/4384679/pexels-photo-4384679.jpeg --silent" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:43.042008Z", "iopub.status.busy": "2024-01-11T19:50:43.041362Z", "iopub.status.idle": "2024-01-11T19:50:43.187814Z", "shell.execute_reply": "2024-01-11T19:50:43.186986Z" }, "id": "lJZYQ8KYFQ6x" }, "outputs": [], "source": [ "# Load the input image.\n", "image_path = 'input_image.jpeg'\n", "image = 
tf.io.read_file(image_path)\n", "image = tf.image.decode_jpeg(image)" ] }, { "cell_type": "markdown", "metadata": { "id": "S_UWRdQxE6WN" }, "source": [ "### Run Inference" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:43.192019Z", "iopub.status.busy": "2024-01-11T19:50:43.191471Z", "iopub.status.idle": "2024-01-11T19:50:46.895359Z", "shell.execute_reply": "2024-01-11T19:50:46.894684Z" }, "id": "VHmTwACwFW-v" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Resize and pad the image to keep the aspect ratio and fit the expected size.\n", "input_image = tf.expand_dims(image, axis=0)\n", "input_image = tf.image.resize_with_pad(input_image, input_size, input_size)\n", "\n", "# Run model inference.\n", "keypoints_with_scores = movenet(input_image)\n", "\n", "# Visualize the predictions with image.\n", "display_image = tf.expand_dims(image, axis=0)\n", "display_image = tf.cast(tf.image.resize_with_pad(\n", " display_image, 1280, 1280), dtype=tf.int32)\n", "output_overlay = draw_prediction_on_image(\n", " np.squeeze(display_image.numpy(), axis=0), keypoints_with_scores)\n", "\n", "plt.figure(figsize=(5, 5))\n", "plt.imshow(output_overlay)\n", "_ = plt.axis('off')" ] }, { "cell_type": "markdown", "metadata": { "id": "rKm-B0eMYeg8" }, "source": [ "## Video (Image Sequence) Example" ] }, { "cell_type": "markdown", "metadata": { "id": "gdPFXabLyiKv" }, "source": [ "This section demonstrates how to apply intelligent cropping based on the detection from the previous frame when the input is a sequence of frames. This allows the model to devote its attention and resources to the main subject, resulting in better prediction quality without sacrificing speed.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-01-11T19:50:46.900978Z", "iopub.status.busy": "2024-01-11T19:50:46.900707Z", "iopub.status.idle": "2024-01-11T19:50:46.918413Z", "shell.execute_reply": "2024-01-11T19:50:46.917781Z" }, "id": "SYFdK-JHYhrv" }, "outputs": [], "source": [ "#@title Cropping Algorithm\n", "\n", "# Confidence score to determine whether a keypoint prediction is reliable.\n", "MIN_CROP_KEYPOINT_SCORE = 0.2\n", "\n", "def init_crop_region(image_height, image_width):\n", " \"\"\"Defines the default crop region.\n", "\n", " The function provides the initial crop region (pads the full image from both\n", " sides to make it a square image) when the algorithm cannot reliably determine\n", " the crop region from the previous frame.\n", " \"\"\"\n", " if image_width > 
image_height:\n", " box_height = image_width / image_height\n", " box_width = 1.0\n", " y_min = (image_height / 2 - image_width / 2) / image_height\n", " x_min = 0.0\n", " else:\n", " box_height = 1.0\n", " box_width = image_height / image_width\n", " y_min = 0.0\n", " x_min = (image_width / 2 - image_height / 2) / image_width\n", "\n", " return {\n", " 'y_min': y_min,\n", " 'x_min': x_min,\n", " 'y_max': y_min + box_height,\n", " 'x_max': x_min + box_width,\n", " 'height': box_height,\n", " 'width': box_width\n", " }\n", "\n", "def torso_visible(keypoints):\n", " \"\"\"Checks whether there are enough torso keypoints.\n", "\n", " This function checks whether the model is confident in predicting at least\n", " one hip and at least one shoulder, which is required to determine a good\n", " crop region.\n", " \"\"\"\n", " return ((keypoints[0, 0, KEYPOINT_DICT['left_hip'], 2] >\n", " MIN_CROP_KEYPOINT_SCORE or\n", " keypoints[0, 0, KEYPOINT_DICT['right_hip'], 2] >\n", " MIN_CROP_KEYPOINT_SCORE) and\n", " (keypoints[0, 0, KEYPOINT_DICT['left_shoulder'], 2] >\n", " MIN_CROP_KEYPOINT_SCORE or\n", " keypoints[0, 0, KEYPOINT_DICT['right_shoulder'], 2] >\n", " MIN_CROP_KEYPOINT_SCORE))\n", "\n", "def determine_torso_and_body_range(\n", " keypoints, target_keypoints, center_y, center_x):\n", " \"\"\"Calculates the maximum distance from each keypoint to the center location.\n", "\n", " The function returns the maximum distances from the two sets of keypoints:\n", " the full 17 keypoints and the 4 torso keypoints. The returned information will be\n", " used to determine the crop size. 
See determine_crop_region for more detail.\n", " \"\"\"\n", " torso_joints = ['left_shoulder', 'right_shoulder', 'left_hip', 'right_hip']\n", " max_torso_yrange = 0.0\n", " max_torso_xrange = 0.0\n", " for joint in torso_joints:\n", " dist_y = abs(center_y - target_keypoints[joint][0])\n", " dist_x = abs(center_x - target_keypoints[joint][1])\n", " if dist_y > max_torso_yrange:\n", " max_torso_yrange = dist_y\n", " if dist_x > max_torso_xrange:\n", " max_torso_xrange = dist_x\n", "\n", " max_body_yrange = 0.0\n", " max_body_xrange = 0.0\n", " for joint in KEYPOINT_DICT.keys():\n", " if keypoints[0, 0, KEYPOINT_DICT[joint], 2] < MIN_CROP_KEYPOINT_SCORE:\n", " continue\n", " dist_y = abs(center_y - target_keypoints[joint][0])\n", " dist_x = abs(center_x - target_keypoints[joint][1])\n", " if dist_y > max_body_yrange:\n", " max_body_yrange = dist_y\n", "\n", " if dist_x > max_body_xrange:\n", " max_body_xrange = dist_x\n", "\n", " return [max_torso_yrange, max_torso_xrange, max_body_yrange, max_body_xrange]\n", "\n", "def determine_crop_region(\n", " keypoints, image_height,\n", " image_width):\n", " \"\"\"Determines the region to crop the image for the model to run inference on.\n", "\n", " The algorithm uses the detected joints from the previous frame to estimate\n", " the square region that encloses the full body of the target person and\n", " centers at the midpoint of two hip joints. 
The crop size is determined by\n", " the distances between each joint and the center point.\n", " When the model is not confident in the torso joint predictions, the\n", " function returns a default crop, which is the full image padded to a square.\n", " \"\"\"\n", " target_keypoints = {}\n", " for joint in KEYPOINT_DICT.keys():\n", " target_keypoints[joint] = [\n", " keypoints[0, 0, KEYPOINT_DICT[joint], 0] * image_height,\n", " keypoints[0, 0, KEYPOINT_DICT[joint], 1] * image_width\n", " ]\n", "\n", " if torso_visible(keypoints):\n", " center_y = (target_keypoints['left_hip'][0] +\n", " target_keypoints['right_hip'][0]) / 2\n", " center_x = (target_keypoints['left_hip'][1] +\n", " target_keypoints['right_hip'][1]) / 2\n", "\n", " (max_torso_yrange, max_torso_xrange,\n", " max_body_yrange, max_body_xrange) = determine_torso_and_body_range(\n", " keypoints, target_keypoints, center_y, center_x)\n", "\n", " crop_length_half = np.amax(\n", " [max_torso_xrange * 1.9, max_torso_yrange * 1.9,\n", " max_body_yrange * 1.2, max_body_xrange * 1.2])\n", "\n", " tmp = np.array(\n", " [center_x, image_width - center_x, center_y, image_height - center_y])\n", " crop_length_half = np.amin(\n", " [crop_length_half, np.amax(tmp)])\n", "\n", " crop_corner = [center_y - crop_length_half, center_x - crop_length_half]\n", "\n", " if crop_length_half > max(image_width, image_height) / 2:\n", " return init_crop_region(image_height, image_width)\n", " else:\n", " crop_length = crop_length_half * 2\n", " return {\n", " 'y_min': crop_corner[0] / image_height,\n", " 'x_min': crop_corner[1] / image_width,\n", " 'y_max': (crop_corner[0] + crop_length) / image_height,\n", " 'x_max': (crop_corner[1] + crop_length) / image_width,\n", " 'height': (crop_corner[0] + crop_length) / image_height -\n", " crop_corner[0] / image_height,\n", " 'width': (crop_corner[1] + crop_length) / image_width -\n", " crop_corner[1] / image_width\n", " }\n", " else:\n", " return 
init_crop_region(image_height, image_width)\n", "\n", "def crop_and_resize(image, crop_region, crop_size):\n", " \"\"\"Crops and resizes the image to prepare for the model input.\"\"\"\n", " boxes=[[crop_region['y_min'], crop_region['x_min'],\n", " crop_region['y_max'], crop_region['x_max']]]\n", " output_image = tf.image.crop_and_resize(\n", " image, box_indices=[0], boxes=boxes, crop_size=crop_size)\n", " return output_image\n", "\n", "def run_inference(movenet, image, crop_region, crop_size):\n", " \"\"\"Runs model inference on the cropped region.\n", "\n", " The function runs the model inference on the cropped region and maps the\n", " model output back to the original image coordinate system.\n", " \"\"\"\n", " image_height, image_width, _ = image.shape\n", " input_image = crop_and_resize(\n", " tf.expand_dims(image, axis=0), crop_region, crop_size=crop_size)\n", " # Run model inference.\n", " keypoints_with_scores = movenet(input_image)\n", " # Update the coordinates.\n", " for idx in range(17):\n", " keypoints_with_scores[0, 0, idx, 0] = (\n", " crop_region['y_min'] * image_height +\n", " crop_region['height'] * image_height *\n", " keypoints_with_scores[0, 0, idx, 0]) / image_height\n", " keypoints_with_scores[0, 0, idx, 1] = (\n", " crop_region['x_min'] * image_width +\n", " crop_region['width'] * image_width *\n", " keypoints_with_scores[0, 0, idx, 1]) / image_width\n", " return keypoints_with_scores" ] }, { "cell_type": "markdown", "metadata": { "id": "L2JmA1xAEntQ" }, "source": [ "### Load Input Image Sequence" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:46.921818Z", "iopub.status.busy": "2024-01-11T19:50:46.921153Z", "iopub.status.idle": "2024-01-11T19:50:48.303023Z", "shell.execute_reply": "2024-01-11T19:50:48.301989Z" }, "id": "CzJxbxDckWl2" }, "outputs": [], "source": [ "!wget -q -O dance.gif https://github.com/tensorflow/tfjs-models/raw/master/pose-detection/assets/dance_input.gif" ] }, 
{ "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:48.307314Z", "iopub.status.busy": "2024-01-11T19:50:48.306997Z", "iopub.status.idle": "2024-01-11T19:50:48.463413Z", "shell.execute_reply": "2024-01-11T19:50:48.462746Z" }, "id": "IxbMFZJUkd6W" }, "outputs": [], "source": [ "# Load the input image sequence (all frames of the GIF).\n", "image_path = 'dance.gif'\n", "image = tf.io.read_file(image_path)\n", "image = tf.image.decode_gif(image)" ] }, { "cell_type": "markdown", "metadata": { "id": "CJKeQ4siEtU9" }, "source": [ "### Run Inference with the Cropping Algorithm" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-01-11T19:50:48.467804Z", "iopub.status.busy": "2024-01-11T19:50:48.467189Z", "iopub.status.idle": "2024-01-11T19:50:57.250126Z", "shell.execute_reply": "2024-01-11T19:50:57.249153Z" }, "id": "9B57XS0NZPIy" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " 41\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. 
Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. 
Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. 
Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. 
Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/tmp/ipykernel_128592/2693263076.py:162: MatplotlibDeprecationWarning: The tostring_rgb function was deprecated in Matplotlib 3.8 and will be removed two minor releases later. 
Use buffer_rgba instead.\n", " image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Run MoveNet on every frame, re-cropping around the person detected in the previous frame.\n", "num_frames, image_height, image_width, _ = image.shape\n", "crop_region = init_crop_region(image_height, image_width)\n", "\n", "output_images = []\n", "bar = display(progress(0, num_frames-1), display_id=True)\n", "for frame_idx in range(num_frames):\n", "  keypoints_with_scores = run_inference(\n", "      movenet, image[frame_idx, :, :, :], crop_region,\n", "      crop_size=[input_size, input_size])\n", "  output_images.append(draw_prediction_on_image(\n", "      image[frame_idx, :, :, :].numpy().astype(np.int32),\n", "      keypoints_with_scores, crop_region=None,\n", "      close_figure=True, output_image_height=300))\n", "  crop_region = determine_crop_region(\n", "      keypoints_with_scores, image_height, image_width)\n", "  bar.update(progress(frame_idx, num_frames-1))\n", "\n", "# Prepare gif visualization.\n", "output = np.stack(output_images, axis=0)\n", "to_gif(output, duration=100)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "9u_VGR6_BmbZ", "5I3xBq80E3N_", "L2JmA1xAEntQ" ], "name": "movenet.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }