{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T22:54:41.696048Z", "iopub.status.busy": "2022-12-14T22:54:41.695591Z", "iopub.status.idle": "2022-12-14T22:54:41.699227Z", "shell.execute_reply": "2022-12-14T22:54:41.698688Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "VZ-KA8k5kybx" }, "source": [ "# テンソルスライスの基礎" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード
" ] }, { "cell_type": "markdown", "metadata": { "id": "AixIdVeRk3CO" }, "source": [ "オブジェクト検出や NLP などの機械学習アプリケーションでは、テンソルのサブセクション(スライス)を使用する必要がある場合があります。たとえば、モデルアーキテクチャにルーティングが含まれている場合、1 つのレイヤーが次のレイヤーにルーティングされるトレーニングサンプルを制御することがあります。この場合、テンソルスライス演算を使用して、テンソルを分割し、正しい順序に戻すことができます。\n", "\n", "NLP アプリケーションでは、トレーニング時にテンソルスライスを使用してワードマスキングを実行できます。たとえば、各文でマスクする単語インデックスを選択し、その単語をラベルとして取り出し、選択した単語をマスクトークンに置き換えることで、文のリストからトレーニングデータを生成できます。\n", "\n", "このガイドでは、TensorFlow API を使用して次を実行する方法を学習します。\n", "\n", "- テンソルからスライスを抽出する\n", "- テンソルの特定のインデックスにデータを挿入する\n", "\n", "このガイドは、テンソルのインデキシングに精通していることを前提としています。このガイドを開始する前に、[テンソル](https://www.tensorflow.org/guide/tensor#indexing)、および [TensorFlow NumPy](https://www.tensorflow.org/guide/tf_numpy#indexing) ガイドのインデキシングセクションをお読みください。" ] }, { "cell_type": "markdown", "metadata": { "id": "FcWhWYn7eXkF" }, "source": [ "## セットアップ\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:41.702546Z", "iopub.status.busy": "2022-12-14T22:54:41.702292Z", "iopub.status.idle": "2022-12-14T22:54:43.605973Z", "shell.execute_reply": "2022-12-14T22:54:43.605276Z" }, "id": "m6uvewqi0jso" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 22:54:42.642753: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:54:42.642872: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 22:54:42.642881: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": { "id": "K-muS4ej5zoN" }, "source": [ "## テンソルスライスを抽出する\n", "\n", "`tf.slice` を使用して、NumPy のようなテンソルスライスを実行します。\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:43.610447Z", "iopub.status.busy": "2022-12-14T22:54:43.610031Z", "iopub.status.idle": "2022-12-14T22:54:46.971934Z", "shell.execute_reply": "2022-12-14T22:54:46.971229Z" }, "id": "wZep0cjs0Oai" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([1 2 3], shape=(3,), dtype=int32)\n" ] } ], "source": [ "t1 = tf.constant([0, 1, 2, 3, 4, 5, 6, 7])\n", "\n", "print(tf.slice(t1,\n", " begin=[1],\n", " size=[3]))" ] }, { "cell_type": "markdown", "metadata": { "id": "Vh3xI3j0DRJ2" }, "source": [ "または、Python 構文を使用することもできます。テンソルスライスは、開始から終了までの範囲で等間隔に配置されていることに注意してください。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:46.975544Z", "iopub.status.busy": "2022-12-14T22:54:46.975058Z", "iopub.status.idle": "2022-12-14T22:54:46.980907Z", "shell.execute_reply": "2022-12-14T22:54:46.980351Z" }, "id": "P1MtEyKuWuDD" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([1 2 3], shape=(3,), dtype=int32)\n" ] } ], "source": [ "print(t1[1:4])" ] }, { "cell_type": "markdown", "metadata": { "id": "cjq1o8D2wKKs" }, "source": [ "" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:46.984326Z", "iopub.status.busy": "2022-12-14T22:54:46.983884Z", "iopub.status.idle": "2022-12-14T22:54:46.988959Z", "shell.execute_reply": "2022-12-14T22:54:46.988387Z" }, "id": "UunuLTIuwDA-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([5 6 7], shape=(3,), dtype=int32)\n" ] } ], "source": [ "print(t1[-3:])" ] }, { "cell_type": "markdown", "metadata": { "id": "EHvRB-XTwRTd" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "SW1zFFTnUpCQ" }, "source": [ "2 次元テンソルの場合、次を使用できます。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:46.992473Z", "iopub.status.busy": "2022-12-14T22:54:46.991877Z", "iopub.status.idle": "2022-12-14T22:54:46.997952Z", "shell.execute_reply": "2022-12-14T22:54:46.997379Z" }, "id": "kThZhmpAVAQw" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 1 2]\n", " [ 6 7]\n", " [11 12]], shape=(3, 2), dtype=int32)\n" ] } ], "source": [ "t2 = tf.constant([[0, 1, 2, 3, 4],\n", " [5, 6, 7, 8, 9],\n", " [10, 11, 12, 13, 14],\n", " [15, 16, 17, 18, 19]])\n", "\n", "print(t2[:-1, 1:3])" ] }, { "cell_type": "markdown", "metadata": { "id": "xA5Xt4OdVUui" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "iJPggqsH15fI" }, "source": [ "高次元テンソルでも `tf.slice` を使用できます。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.001234Z", "iopub.status.busy": "2022-12-14T22:54:47.000783Z", "iopub.status.idle": "2022-12-14T22:54:47.005791Z", "shell.execute_reply": "2022-12-14T22:54:47.005180Z" }, "id": "Re5eX1OXnKOZ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([[[25 27]]], shape=(1, 1, 2), dtype=int32)\n" ] } ], "source": [ "t3 = tf.constant([[[1, 3, 5, 7],\n", " [9, 11, 13, 15]],\n", " [[17, 19, 21, 23],\n", " [25, 27, 29, 31]]\n", " ])\n", "\n", "print(tf.slice(t3,\n", " begin=[1, 1, 0],\n", " size=[1, 1, 2]))" ] }, { "cell_type": "markdown", "metadata": { "id": "x-O5FNV9qOJK" }, "source": [ "また、`tf.strided_slice` を使用して、テンソルの次元をストライドすることでテンソルのスライスを抽出することもできます。" ] }, { "cell_type": "markdown", "metadata": { "id": "b9FhvrOnJsJb" }, "source": [ "`tf.gather` を使用して、テンソルの 1 つの軸から特定のインデックスを抽出します。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.009028Z", "iopub.status.busy": "2022-12-14T22:54:47.008544Z", "iopub.status.idle": "2022-12-14T22:54:47.021982Z", "shell.execute_reply": "2022-12-14T22:54:47.021314Z" }, "id": "TwviZrrIj2h7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([0 3 6], shape=(3,), dtype=int32)\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(tf.gather(t1,\n", " indices=[0, 3, 6]))\n", "\n", "# This is similar to doing\n", "\n", "t1[::3]" ] }, { "cell_type": "markdown", "metadata": { "id": "oKyjGi2zyzEC" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "obrjeKy1WfTN" }, "source": [ "`tf.gather` では、インデックスが等間隔である必要はありません。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.025301Z", "iopub.status.busy": "2022-12-14T22:54:47.024858Z", "iopub.status.idle": "2022-12-14T22:54:47.030524Z", "shell.execute_reply": "2022-12-14T22:54:47.029942Z" }, "id": "LjJcwcZ0druw" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([b'c' b'a' b't' b's'], shape=(4,), dtype=string)\n" ] } ], "source": [ "alphabet = tf.constant(list('abcdefghijklmnopqrstuvwxyz'))\n", "\n", "print(tf.gather(alphabet,\n", " indices=[2, 0, 19, 18]))" ] }, { "cell_type": "markdown", "metadata": { "id": "mSHmUXIyeaJG" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "XsxMx49SOaVu" }, "source": [ "テンソルの複数の軸からスライスを抽出するには、`tf.gather_nd` を使用します。これは、行または列だけではなく、行列の要素を収集する場合に役立ちます。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.033862Z", "iopub.status.busy": "2022-12-14T22:54:47.033363Z", "iopub.status.idle": "2022-12-14T22:54:47.039078Z", "shell.execute_reply": "2022-12-14T22:54:47.038404Z" }, "id": "mT52NFWVdiTe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[2 7]\n", " [3 8]\n", " [0 5]], shape=(3, 2), dtype=int32)\n" ] } ], "source": [ "t4 = tf.constant([[0, 5],\n", " [1, 6],\n", " [2, 7],\n", " [3, 8],\n", " [4, 9]])\n", "\n", "print(tf.gather_nd(t4,\n", " indices=[[2], [3], [0]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "87NN7YQhh2-a" }, "source": [ "" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.042176Z", "iopub.status.busy": "2022-12-14T22:54:47.041608Z", "iopub.status.idle": "2022-12-14T22:54:47.048261Z", "shell.execute_reply": "2022-12-14T22:54:47.047644Z" }, "id": "_z6F2WcPJ9Rh" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([ 0 16], shape=(2,), dtype=int64)\n" ] } ], "source": [ "t5 = np.reshape(np.arange(18), [2, 3, 3])\n", "\n", "print(tf.gather_nd(t5,\n", " indices=[[0, 0, 0], [1, 2, 1]]))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.051228Z", "iopub.status.busy": "2022-12-14T22:54:47.050837Z", "iopub.status.idle": "2022-12-14T22:54:47.055294Z", "shell.execute_reply": "2022-12-14T22:54:47.054695Z" }, "id": "gyIjhm7cV2N0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[[ 0 1 2]\n", " [ 6 7 8]]\n", "\n", " [[ 9 10 11]\n", " [15 16 17]]], shape=(2, 2, 3), dtype=int64)\n" ] } ], "source": [ "# Return a list of two matrices\n", "\n", "print(tf.gather_nd(t5,\n", " indices=[[[0, 0], [0, 2]], [[1, 0], [1, 2]]]))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.058447Z", "iopub.status.busy": "2022-12-14T22:54:47.057968Z", "iopub.status.idle": "2022-12-14T22:54:47.062256Z", "shell.execute_reply": "2022-12-14T22:54:47.061637Z" }, "id": "368D4ciDWB3r" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0 1 2]\n", " [ 6 7 8]\n", " [ 9 10 11]\n", " [15 16 17]], shape=(4, 3), dtype=int64)\n" ] } ], "source": [ "# Return one matrix\n", "\n", "print(tf.gather_nd(t5,\n", " indices=[[0, 0], [0, 2], [1, 0], [1, 2]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "od51VzS2SSPS" }, "source": [ "## データをテンソルに挿入する\n", "\n", "`tf.scatter_nd` を使用して、テンソルの特定のスライス/インデックスにデータを挿入します。値を挿入するテンソルはゼロで初期化されることに注意してください。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.065371Z", "iopub.status.busy": "2022-12-14T22:54:47.064880Z", "iopub.status.idle": "2022-12-14T22:54:47.070526Z", "shell.execute_reply": "2022-12-14T22:54:47.069949Z" }, "id": "jlALYLWm1KhN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor([ 0 2 0 4 0 6 0 8 0 10], shape=(10,), dtype=int32)\n" ] } ], "source": [ "t6 = tf.constant([10])\n", "indices = tf.constant([[1], [3], [5], [7], [9]])\n", "data = tf.constant([2, 4, 6, 8, 10])\n", "\n", "print(tf.scatter_nd(indices=indices,\n", " updates=data,\n", " shape=t6))" ] }, { "cell_type": "markdown", "metadata": { "id": "CD5vd-kxksW7" }, "source": [ "ゼロで初期化されたテンソルを必要とする `tf.scatter_nd` のようなメソッドは、スパーステンソルのイニシャライザに似ています。`tf.gather_nd` と `tf.scatter_nd` を使用して、スパーステンソル演算の動作を模倣できます。\n", "\n", "これら 2 つのメソッドを組み合わせて使用してスパーステンソルを構築する例を考えてみましょう。" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.073543Z", "iopub.status.busy": "2022-12-14T22:54:47.073079Z", "iopub.status.idle": "2022-12-14T22:54:47.076742Z", "shell.execute_reply": "2022-12-14T22:54:47.076191Z" }, "id": "xyK69QgRmrlW" }, "outputs": [], "source": [ "# Gather values from one tensor by specifying indices\n", "\n", "new_indices = tf.constant([[0, 2], [2, 1], [3, 3]])\n", "t7 = tf.gather_nd(t2, indices=new_indices)" ] }, { "cell_type": "markdown", "metadata": { "id": "_7V_Qfa4qkdn" }, "source": [ "" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.080139Z", "iopub.status.busy": "2022-12-14T22:54:47.079600Z", "iopub.status.idle": "2022-12-14T22:54:47.083679Z", "shell.execute_reply": "2022-12-14T22:54:47.083118Z" }, "id": "QWT1E1Weqjx2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0 0 2 0 0]\n", " [ 0 0 0 0 0]\n", " [ 0 11 0 0 0]\n", " [ 0 0 0 18 0]], shape=(4, 5), dtype=int32)\n" ] } ], "source": [ "# Add these values into a new tensor\n", "\n", "t8 = tf.scatter_nd(indices=new_indices, updates=t7, shape=tf.constant([4, 5]))\n", "\n", "print(t8)" ] }, { "cell_type": "markdown", "metadata": { "id": "NUyYjnvCn_vu" }, "source": [ "これは次と同様です。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.086875Z", "iopub.status.busy": "2022-12-14T22:54:47.086345Z", "iopub.status.idle": "2022-12-14T22:54:47.091186Z", "shell.execute_reply": "2022-12-14T22:54:47.090641Z" }, "id": "LeqFwUgroE4j" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SparseTensor(indices=tf.Tensor(\n", "[[0 2]\n", " [2 1]\n", " [3 3]], shape=(3, 2), dtype=int64), values=tf.Tensor([ 2 11 18], shape=(3,), dtype=int32), dense_shape=tf.Tensor([4 5], shape=(2,), dtype=int64))\n" ] } ], "source": [ "t9 = tf.SparseTensor(indices=[[0, 2], [2, 1], [3, 3]],\n", " values=[2, 11, 18],\n", " dense_shape=[4, 5])\n", "\n", "print(t9)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.094240Z", "iopub.status.busy": "2022-12-14T22:54:47.093725Z", "iopub.status.idle": "2022-12-14T22:54:47.100961Z", "shell.execute_reply": "2022-12-14T22:54:47.100396Z" }, "id": "5MaF6RlJot33" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[ 0 0 2 0 0]\n", " [ 0 0 0 0 0]\n", " [ 0 11 0 0 0]\n", " [ 0 0 0 18 0]], shape=(4, 5), dtype=int32)\n" ] } ], "source": [ "# Convert the sparse tensor into a dense tensor\n", "\n", "t10 = tf.sparse.to_dense(t9)\n", "\n", "print(t10)" ] }, { "cell_type": "markdown", "metadata": { "id": "4sf3F3Xk56Bt" }, "source": [ "既存の値を持つテンソルにデータを挿入するには、`tf.tensor_scatter_nd_add` を使用します。" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.104169Z", "iopub.status.busy": "2022-12-14T22:54:47.103731Z", "iopub.status.idle": "2022-12-14T22:54:47.109396Z", "shell.execute_reply": "2022-12-14T22:54:47.108845Z" }, "id": "mte2ifOb6sQO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[2 7 6]\n", " [9 5 1]\n", " [4 3 8]], shape=(3, 3), dtype=int32)\n" ] } ], "source": [ "t11 = tf.constant([[2, 7, 0],\n", " [9, 0, 1],\n", " [0, 3, 8]])\n", "\n", "# Convert the tensor into a magic square by inserting numbers at appropriate indices\n", "\n", "t12 = tf.tensor_scatter_nd_add(t11,\n", " indices=[[0, 2], [1, 1], [2, 0]],\n", " updates=[6, 5, 4])\n", "\n", "print(t12)" ] }, { "cell_type": "markdown", "metadata": { "id": "2dQYyROU09G6" }, "source": [ "同様に、`tf.tensor_scatter_nd_sub` を使用して、既存の値を持つテンソルから値を減算します。" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.112775Z", "iopub.status.busy": "2022-12-14T22:54:47.112234Z", "iopub.status.idle": "2022-12-14T22:54:47.117926Z", "shell.execute_reply": "2022-12-14T22:54:47.117375Z" }, "id": "ac6_i6uK1EI6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[1 0 0]\n", " [0 1 0]\n", " [0 0 1]], shape=(3, 3), dtype=int32)\n" ] } ], "source": [ "# Convert the tensor into an identity matrix\n", "\n", "t13 = tf.tensor_scatter_nd_sub(t11,\n", " indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 1], [2, 2]],\n", " updates=[1, 7, 9, -1, 1, 3, 7])\n", "\n", "print(t13)" ] }, { "cell_type": "markdown", "metadata": { "id": "B_2DuzRRwVc8" }, "source": [ "`tf.tensor_scatter_nd_min` を使用して、要素ごとの最小値をあるテンソルから別のテンソルにコピーします。" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.121140Z", "iopub.status.busy": "2022-12-14T22:54:47.120640Z", "iopub.status.idle": "2022-12-14T22:54:47.126880Z", "shell.execute_reply": "2022-12-14T22:54:47.126319Z" }, "id": "T_4FrHrHlkHK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[-2 -7 -6]\n", " [-9 -5 1]\n", " [-4 -3 -8]], shape=(3, 3), dtype=int32)\n" ] } ], "source": [ "t14 = tf.constant([[-2, -7, 0],\n", " [-9, 0, 1],\n", " [0, -3, -8]])\n", "\n", "t15 = tf.tensor_scatter_nd_min(t14,\n", " indices=[[0, 2], [1, 1], [2, 0]],\n", " updates=[-6, -5, -4])\n", "\n", "print(t15)" ] }, { "cell_type": "markdown", "metadata": { "id": "PkaiKyrF0WtX" }, "source": [ "同様に、`tf.tensor_scatter_nd_max` を使用して、要素ごとの最大値をあるテンソルから別のテンソルにコピーします。" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T22:54:47.130101Z", "iopub.status.busy": "2022-12-14T22:54:47.129569Z", "iopub.status.idle": "2022-12-14T22:54:47.135062Z", "shell.execute_reply": "2022-12-14T22:54:47.134444Z" }, "id": "izJu0nXi0GDq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[-2 -7 6]\n", " [-9 5 1]\n", " [ 4 -3 -8]], shape=(3, 3), dtype=int32)\n" ] } ], "source": [ "t16 = tf.tensor_scatter_nd_max(t14,\n", " indices=[[0, 2], [1, 1], [2, 0]],\n", " updates=[6, 5, 4])\n", "\n", "print(t16)" ] }, { "cell_type": "markdown", "metadata": { "id": "QAffUOa-85lF" }, "source": [ "## その他の資料とリソース\n", "\n", "このガイドでは、TensorFlow で利用可能なテンソルスライス演算を使用して、テンソル内の要素をより詳細に制御する方法を学びました。\n", "\n", "- `tf.experimental.numpy.take_along_axis` や `tf.experimental.numpy.take` など、TensorFlow NumPy で利用可能なスライス演算を確認してください。\n", "\n", "- また、[テンソルガイド](https://www.tensorflow.org/guide/tensor)と[変数ガイド](https://www.tensorflow.org/guide/variable)も参照してください。" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "tensor_slicing.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }