{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Pnn4rDWGqDZL" }, "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T21:16:30.042138Z", "iopub.status.busy": "2022-12-14T21:16:30.041673Z", "iopub.status.idle": "2022-12-14T21:16:30.045825Z", "shell.execute_reply": "2022-12-14T21:16:30.045209Z" }, "id": "l534d35Gp68G" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "3TI3Q3XBesaS" }, "source": [ "# 체크포인트 훈련하기" ] }, { "cell_type": "markdown", "metadata": { "id": "yw_a0iGucY8z" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org에서 보기 Google Colab에서 실행GitHub에서 소스 보기노트북 다운로드
" ] }, { "cell_type": "markdown", "metadata": { "id": "LeDp7dovcbus" }, "source": [ "\"텐서플로 모델 저장하기\" 라는 문구는 보통 둘중 하나를 의미합니다:\n", "\n", "1. Checkpoints, 혹은\n", "2. SavedModel.\n", "\n", "Checkpoint는 모델이 사용한 모든 매개변수(`tf.Variable` 객체들)의 정확한 값을 캡처합니다. Chekcpoint는 모델에 의해 정의된 연산에 대한 설명을 포함하지 않으므로 일반적으로 저장된 매개변수 값을 사용할 소스 코드를 사용할 수 있을 때만 유용합니다.\n", "\n", "반면 SavedModel 형식은 매개변수 값(체크포인트) 외에 모델에 의해 정의된 연산에 대한 일련화된 설명을 포함합니다. 이 형식의 모델은 모델을 만든 소스 코드와 독립적입니다. 따라서 TensorFlow Serving, TensorFlow Lite, TensorFlow.js 또는 다른 프로그래밍 언어(C, C++, Java, Go, Rust, C# 등. TensorFlow APIs)로 배포하기에 적합합니다.\n", "\n", "이 가이드는 체크포인트 쓰기 및 읽기를 위한 API들을 다룹니다." ] }, { "cell_type": "markdown", "metadata": { "id": "U0nm8k-6xfh2" }, "source": [ "## 설치" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:30.049334Z", "iopub.status.busy": "2022-12-14T21:16:30.048901Z", "iopub.status.idle": "2022-12-14T21:16:31.998940Z", "shell.execute_reply": "2022-12-14T21:16:31.998235Z" }, "id": "VEvpMYAKsC4z" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 21:16:31.009549: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 21:16:31.009664: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 21:16:31.009675: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:32.003421Z", "iopub.status.busy": "2022-12-14T21:16:32.002536Z", "iopub.status.idle": "2022-12-14T21:16:32.007883Z", "shell.execute_reply": "2022-12-14T21:16:32.006946Z" }, "id": "OEQCseyeC4Ev" }, "outputs": [], "source": [ "class Net(tf.keras.Model):\n", " \"\"\"A simple linear model.\"\"\"\n", "\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.l1 = tf.keras.layers.Dense(5)\n", "\n", " def call(self, x):\n", " return self.l1(x)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:32.011225Z", "iopub.status.busy": "2022-12-14T21:16:32.010756Z", "iopub.status.idle": "2022-12-14T21:16:35.379368Z", "shell.execute_reply": "2022-12-14T21:16:35.378676Z" }, "id": "utqeoDADC5ZR" }, "outputs": [], "source": [ "net = Net()" ] }, { "cell_type": "markdown", "metadata": { "id": "5vsq3-pffo1I" }, "source": [ "## `tf.keras` 훈련 API들로부터 저장하기\n", "\n", "See the [`tf.keras` guide on saving and restoring](https://www.tensorflow.org/guide/keras/save_and_serialize) .\n", "\n", "`tf.keras.Model.save_weights` 가 텐서플로 CheckPoint를 저장합니다. " ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:35.382966Z", "iopub.status.busy": "2022-12-14T21:16:35.382720Z", "iopub.status.idle": "2022-12-14T21:16:35.398756Z", "shell.execute_reply": "2022-12-14T21:16:35.398157Z" }, "id": "SuhmrYPEl4D_" }, "outputs": [], "source": [ "net.save_weights('easy_checkpoint')" ] }, { "cell_type": "markdown", "metadata": { "id": "XseWX5jDg4lQ" }, "source": [ "## Checkpoints 작성하기\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1jpZPz76ZP3K" }, "source": [ "텐서플로 모델의 지속적인 상태는 `tf.Variable` 객체에 저장되어 있습니다. 이들은 직접으로 구성할 수 있지만, `tf.keras.layers` 혹은 `tf.keras.Model`와 같은 고수준 API들로 만들어 지기도 합니다.\n", "\n", "변수를 관리하는 가장 쉬운 방법은 Python 객체에 변수를 연결한 다음 해당 객체를 참조하는 것입니다.\n", "\n", "`tf.train.Checkpoint`, `tf.keras.layers.Layer`, and `tf.keras.Model`의 하위클래스들은 해당 속성에 할당된 변수를 자동 추적합니다. 다음 예시는 간단한 선형 model을 구성하고, 모든 model 변수의 값을 포합하는 checkpoint를 씁니다." ] }, { "cell_type": "markdown", "metadata": { "id": "x0vFBr_Im73_" }, "source": [ "`Model.save_weights`를 사용해 손쉽게 model-checkpoint를 저장할 수 있습니다." ] }, { "cell_type": "markdown", "metadata": { "id": "FHTJ1JzxCi8a" }, "source": [ "### 직접 Checkpoint작성하기" ] }, { "cell_type": "markdown", "metadata": { "id": "6cF9fqYOCrEO" }, "source": [ "#### 설치" ] }, { "cell_type": "markdown", "metadata": { "id": "fNjf9KaLdIRP" }, "source": [ "`tf.train.Checkpoint`의 모든 특성을 입증하기 위해서 toy dataset과 optimization step을 정의해야 합니다." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:35.402663Z", "iopub.status.busy": "2022-12-14T21:16:35.402192Z", "iopub.status.idle": "2022-12-14T21:16:35.406397Z", "shell.execute_reply": "2022-12-14T21:16:35.405590Z" }, "id": "tSNyP4IJ9nkU" }, "outputs": [], "source": [ "def toy_dataset():\n", " inputs = tf.range(10.)[:, None]\n", " labels = inputs * 5. + tf.range(5.)[None, :]\n", " return tf.data.Dataset.from_tensor_slices(\n", " dict(x=inputs, y=labels)).repeat().batch(2)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:35.409366Z", "iopub.status.busy": "2022-12-14T21:16:35.408904Z", "iopub.status.idle": "2022-12-14T21:16:35.413084Z", "shell.execute_reply": "2022-12-14T21:16:35.412322Z" }, "id": "ICm1cufh_JH8" }, "outputs": [], "source": [ "def train_step(net, example, optimizer):\n", " \"\"\"Trains `net` on `example` using `optimizer`.\"\"\"\n", " with tf.GradientTape() as tape:\n", " output = net(example['x'])\n", " loss = tf.reduce_mean(tf.abs(output - example['y']))\n", " variables = net.trainable_variables\n", " gradients = tape.gradient(loss, variables)\n", " optimizer.apply_gradients(zip(gradients, variables))\n", " return loss" ] }, { "cell_type": "markdown", "metadata": { "id": "vxzGpHRbOVO6" }, "source": [ "#### Checkpoint객체 생성\n", "\n", "인위적으로 checkpoint를 만드려면 `tf.train.Checkpoint` 객체가 필요합니다. Checkpoint하고 싶은 객체의 위치는 객체의 특성으로 설정이 되어 있습니다.\n", "\n", "`tf.train.CheckpointManager`도 다수의 checkpoint를 관리할때 도움이 됩니다" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:35.416665Z", "iopub.status.busy": "2022-12-14T21:16:35.416028Z", "iopub.status.idle": "2022-12-14T21:16:35.451396Z", "shell.execute_reply": "2022-12-14T21:16:35.450835Z" }, "id": "ou5qarOQOWYl" }, "outputs": [], "source": [ "opt = tf.keras.optimizers.Adam(0.1)\n", "dataset = toy_dataset()\n", "iterator = iter(dataset)\n", "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)\n", "manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "8ZbYSD4uCy96" }, "source": [ "#### 훈련하고 model checkpoint작성하기" ] }, { "cell_type": "markdown", "metadata": { "id": "NP9IySmCeCkn" }, "source": [ "다음 훈련 루프는 model과 optimizer의 인스턴스를 만든 후 `tf.train.Checkpoint` 객체에 수집합니다. 이것은 각 데이터 배치에 있는 루프의 훈련 단계를 호출하고, 주기적으로 디스크에 checkpoint를 작성합니다." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:35.455109Z", "iopub.status.busy": "2022-12-14T21:16:35.454460Z", "iopub.status.idle": "2022-12-14T21:16:35.459263Z", "shell.execute_reply": "2022-12-14T21:16:35.458522Z" }, "id": "BbCS5A6K1VSH" }, "outputs": [], "source": [ "def train_and_checkpoint(net, manager):\n", " ckpt.restore(manager.latest_checkpoint)\n", " if manager.latest_checkpoint:\n", " print(\"Restored from {}\".format(manager.latest_checkpoint))\n", " else:\n", " print(\"Initializing from scratch.\")\n", "\n", " for _ in range(50):\n", " example = next(iterator)\n", " loss = train_step(net, example, opt)\n", " ckpt.step.assign_add(1)\n", " if int(ckpt.step) % 10 == 0:\n", " save_path = manager.save()\n", " print(\"Saved checkpoint for step {}: {}\".format(int(ckpt.step), save_path))\n", " print(\"loss {:1.2f}\".format(loss.numpy()))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:35.462455Z", "iopub.status.busy": "2022-12-14T21:16:35.461862Z", "iopub.status.idle": "2022-12-14T21:16:36.403551Z", "shell.execute_reply": "2022-12-14T21:16:36.402626Z" }, "id": "Ik3IBMTdPW41" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initializing from scratch.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint for step 10: ./tf_ckpts/ckpt-1\n", "loss 26.76\n", "Saved checkpoint for step 20: ./tf_ckpts/ckpt-2\n", "loss 20.18\n", "Saved checkpoint for step 30: ./tf_ckpts/ckpt-3\n", "loss 13.62\n", "Saved checkpoint for step 40: ./tf_ckpts/ckpt-4\n", "loss 7.20\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint for step 50: ./tf_ckpts/ckpt-5\n", "loss 2.25\n" ] } ], "source": [ "train_and_checkpoint(net, manager)" ] }, { "cell_type": "markdown", "metadata": { "id": "2wzcc1xYN-sH" }, "source": [ "#### 복구하고 훈련 계속하기" ] }, { "cell_type": "markdown", "metadata": { "id": "lw1QeyRBgsLE" }, "source": [ "첫 번째 과정 이후 새로운 model과 매니저를 전달할 수 있지만, 일을 마무리 한 정확한 지점에서 훈련을 가져와야 합니다:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:36.407614Z", "iopub.status.busy": "2022-12-14T21:16:36.406953Z", "iopub.status.idle": "2022-12-14T21:16:36.929633Z", "shell.execute_reply": "2022-12-14T21:16:36.928827Z" }, "id": "UjilkTOV2PBK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Restored from ./tf_ckpts/ckpt-5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint for step 60: ./tf_ckpts/ckpt-6\n", "loss 0.72\n", "Saved checkpoint for step 70: ./tf_ckpts/ckpt-7\n", "loss 0.56\n", "Saved checkpoint for step 80: ./tf_ckpts/ckpt-8\n", "loss 0.45\n", "Saved checkpoint for step 90: ./tf_ckpts/ckpt-9\n", "loss 0.41\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint for step 100: ./tf_ckpts/ckpt-10\n", "loss 0.26\n" ] } ], "source": [ "opt = tf.keras.optimizers.Adam(0.1)\n", "net = Net()\n", "dataset = toy_dataset()\n", "iterator = iter(dataset)\n", "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)\n", "manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)\n", "\n", "train_and_checkpoint(net, manager)" ] }, { "cell_type": "markdown", "metadata": { "id": "dxJT9vV-2PnZ" }, "source": [ "`tf.train.CheckpointManager` 객체가 이전 checkpoint들을 제거합니다. 위는 가장 최근의 3개 checkpoint만 유지하도록 구성되어 있습니다." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:36.933155Z", "iopub.status.busy": "2022-12-14T21:16:36.932653Z", "iopub.status.idle": "2022-12-14T21:16:36.936572Z", "shell.execute_reply": "2022-12-14T21:16:36.935818Z" }, "id": "3zmM0a-F5XqC" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']\n" ] } ], "source": [ "print(manager.checkpoints) # List the three remaining checkpoints" ] }, { "cell_type": "markdown", "metadata": { "id": "qwlYDyjemY4P" }, "source": [ "예를 들어, `'./tf_ckpts/ckpt-10'`같은 경로들은 디스크에 있는 파일이 아닙니다. 대신에 이 경로들은 `index` 파일과 변수 값들을 담고있는 파일들의 전위 표기입니다. 이 전위 표기들은 `CheckpointManager`가 상태를 저장하는 하나의 `checkpoint` 파일 (`'./tf_ckpts/checkpoint'`)에 그룹으로 묶여있습니다." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:36.940204Z", "iopub.status.busy": "2022-12-14T21:16:36.939618Z", "iopub.status.idle": "2022-12-14T21:16:37.113827Z", "shell.execute_reply": "2022-12-14T21:16:37.112767Z" }, "id": "t1feej9JntV_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "checkpoint\t\t ckpt-8.data-00000-of-00001 ckpt-9.index\r\n", "ckpt-10.data-00000-of-00001 ckpt-8.index\r\n", "ckpt-10.index\t\t ckpt-9.data-00000-of-00001\r\n" ] } ], "source": [ "!ls ./tf_ckpts" ] }, { "cell_type": "markdown", "metadata": { "id": "DR2wQc9x6b3X" }, "source": [ "\n", "\n", "## 작동 원리\n", "\n", "텐서플로는 로드되는 객체에서 시작하여 명명된 엣지가 있는 방향 그래프를 통과시켜 변수를 checkpoint된 값과 일치시킵니다. 엣지의 이름들은 특히 기여한 객체의 이름에서 따왔습니다. 예를들면, `self.l1 = tf.keras.layers.Dense(5)`안의 `\"l1\"`. `tf.train.Checkpoint` 이것의 키워드 전달인자 이름을 사용했습니다, 여기에서는 `\"step\"` in `tf.train.Checkpoint(step=...)`.\n", "\n", "위의 예에서 나온 종속성 그래프는 다음과 같습니다.:\n", "\n", "![훈련 반복 예시의 의존 그래프 시각화](http://tensorflow.org/images/guide/whole_checkpoint.svg)\n", "\n", "optimizer는 빨간색으로, regular 변수는 파란색으로, optimizer 슬롯 변수는 주황색으로 표시합니다. 다른 nodes는, 예를 들면 `tf.train.Checkpoint`, 이 검은색임을 나타냅니다.\n", "\n", "슬롯 변수는 옵티마이저 상태의 일부이지만 특정 변수에 대해 생성됩니다. 예를 들어 위의 `''` 엣지는 Adam 옵티마이저가 각 변수에 대해 추적하는 모멘텀에 해당합니다. 슬롯 변수는 변수와 옵티마이저가 모두 저장되어 점선으로 된 엣지인 경우에만 체크포인트에 저장됩니다." ] }, { "cell_type": "markdown", "metadata": { "id": "VpY5IuanUEQ0" }, "source": [ "`tf.train.Checkpoint`로 불러온 `restore()` 오브젝트 큐는그`Checkpoint` 개체에서 일치하는 방법이 있습니다. 변수 값 복원을 요청한 복원 작업 대기 행렬로 정리합니다. 예를 들어, 우리는 네트워크와 계층을 통해 그것에 대한 하나의 경로를 재구성함으로서 위에서 정의한 모델에서 커널만 로드할 수 있습니다." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:37.117957Z", "iopub.status.busy": "2022-12-14T21:16:37.117700Z", "iopub.status.idle": "2022-12-14T21:16:37.132351Z", "shell.execute_reply": "2022-12-14T21:16:37.131707Z" }, "id": "wmX2AuyH7TVt" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 0. 0. 0. 0.]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2.800037 1.7855548 3.146778 3.5596383 3.8949137]\n" ] } ], "source": [ "to_restore = tf.Variable(tf.zeros([5]))\n", "print(to_restore.numpy()) # All zeros\n", "fake_layer = tf.train.Checkpoint(bias=to_restore)\n", "fake_net = tf.train.Checkpoint(l1=fake_layer)\n", "new_root = tf.train.Checkpoint(net=fake_net)\n", "status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))\n", "print(to_restore.numpy()) # This gets the restored value." ] }, { "cell_type": "markdown", "metadata": { "id": "GqEW-_pJDAnE" }, "source": [ "이 새로운 개체에 대한 의존도 그래프는 우리가 위에 적은 더 큰 checkpoint보다 작은 하위 그래프입니다. 이것은 오직 `tf.train.Checkpoint`에서 checkpoints 셀때 편향과 저장 카운터만 포함합니다.\n", "\n", "![편향 변수의 서브그래프 시각화](http://tensorflow.org/images/guide/partial_checkpoint.svg)\n", "\n", "`restore()` 함수는 선택적으로 확인을 거친 객체의 상태를 반환합니다. 새로 만든 `Checkpoint`에서 우리가 만든 모든 개체가 복원되어 `status.assert_existing_objects_matched`가 통과합니다." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:37.135372Z", "iopub.status.busy": "2022-12-14T21:16:37.135148Z", "iopub.status.idle": "2022-12-14T21:16:37.142499Z", "shell.execute_reply": "2022-12-14T21:16:37.141775Z" }, "id": "P9TQXl81Dq5r" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "status.assert_existing_objects_matched()" ] }, { "cell_type": "markdown", "metadata": { "id": "GoMwf8CFDu9r" }, "source": [ "checkpoint에는 계층의 커널과 optimizer의 변수를 포함하여 일치하지 않는 많은 개체가 있습니다. `status.assert_consumed`는 checkpoint와 프로그램이 정확히 일치할 경우에만 통과하고 여기에 예외를 둘 것입니다." ] }, { "cell_type": "markdown", "metadata": { "id": "KCcmJ-2j9RUP" }, "source": [ "### 지연된 복원\n", "\n", "텐서플로우의 `Layer` 객체는 입력 형상을 이용할 수 있을 때 변수 생성을 첫 번째 호출로 지연시킬 수 있습니다. 예를 들어, `Dense` 층의 커널의 모양은 계층의 입력과 출력 형태 모두에 따라 달라지기 때문에, 생성자 인수로 필요한 출력 형태는 그 자체로 변수를 만들기에 충분한 정보가 아닙니다. 예를 들어, `Layer` 층의 커널의 모양은 계층의 입력과 출력 형태 모두에 따라 달라지기 때문에, 생성자 인수로 필요한 출력 형태는 그 자체로 변수를 만들기에 충분한 정보가 아닙니다.\n", "\n", "이 관용구를 지원하기 위해 `tf.train.Checkpoint`는 아직 일치하는 변수가 없는 복원을 연기합니다." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:37.146485Z", "iopub.status.busy": "2022-12-14T21:16:37.145916Z", "iopub.status.idle": "2022-12-14T21:16:37.153046Z", "shell.execute_reply": "2022-12-14T21:16:37.152354Z" }, "id": "TXYUCO3v-I72" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0. 0. 0.]]\n", "[[4.575237 4.8758802 4.761698 4.9702144 4.997744 ]]\n" ] } ], "source": [ "deferred_restore = tf.Variable(tf.zeros([1, 5]))\n", "print(deferred_restore.numpy()) # Not restored; still zeros\n", "fake_layer.kernel = deferred_restore\n", "print(deferred_restore.numpy()) # Restored" ] }, { "cell_type": "markdown", "metadata": { "id": "-DWhJ3glyobN" }, "source": [ "### checkpoints 수동 검사\n", "\n", "`tf.train.load_checkpoint`는 체크포인트 내용에 대한 낮은 수준의 액세스를 제공 `CheckpointReader`를 반환합니다. 여기에는 각 변수의 키에서 검사점의 각 변수에 대한 모양 및 dtype으로의 매핑이 포함됩니다. 변수의 키는 위에 표시된 그래프와 같이 객체 경로입니다.\n", "\n", "참고: 체크포인트에는 더 높은 수준의 구조가 없습니다. 변수의 경로와 값만 알고 있으며 `models`, `layers` 또는 이들이 연결되는 방식에 대한 개념이 없습니다." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:37.156496Z", "iopub.status.busy": "2022-12-14T21:16:37.155904Z", "iopub.status.idle": "2022-12-14T21:16:37.161712Z", "shell.execute_reply": "2022-12-14T21:16:37.160863Z" }, "id": "RlRsADTezoBD" }, "outputs": [ { "data": { "text/plain": [ "['_CHECKPOINTABLE_OBJECT_GRAPH',\n", " 'iterator/.ATTRIBUTES/ITERATOR_STATE',\n", " 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_learning_rate/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/1/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/2/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/3/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'optimizer/_variables/4/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',\n", " 'step/.ATTRIBUTES/VARIABLE_VALUE']" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reader = tf.train.load_checkpoint('./tf_ckpts/')\n", "shape_from_key = reader.get_variable_to_shape_map()\n", "dtype_from_key = reader.get_variable_to_dtype_map()\n", "\n", "sorted(shape_from_key.keys())" ] }, { "cell_type": "markdown", "metadata": { "id": "AVrdvbNvgq5V" }, "source": [ "따라서 `net.l1.kernel`의 값에 관심이 있다면 다음 코드로 이 값을 얻을 수 있습니다." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:37.164902Z", "iopub.status.busy": "2022-12-14T21:16:37.164373Z", "iopub.status.idle": "2022-12-14T21:16:37.168646Z", "shell.execute_reply": "2022-12-14T21:16:37.167908Z" }, "id": "lYhX_XWCgl92" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape: [1, 5]\n", "Dtype: float32\n" ] } ], "source": [ "key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'\n", "\n", "print(\"Shape:\", shape_from_key[key])\n", "print(\"Dtype:\", dtype_from_key[key].name)" ] }, { "cell_type": "markdown", "metadata": { "id": "2Zk92jM5gRDW" }, "source": [ "변수 값을 검사할 수 있는 `get_tensor` 메서드도 제공합니다." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:37.171834Z", "iopub.status.busy": "2022-12-14T21:16:37.171277Z", "iopub.status.idle": "2022-12-14T21:16:37.176359Z", "shell.execute_reply": "2022-12-14T21:16:37.175606Z" }, "id": "cDJO3cgmecvi" }, "outputs": [ { "data": { "text/plain": [ "array([[4.575237 , 4.8758802, 4.761698 , 4.9702144, 4.997744 ]],\n", " dtype=float32)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reader.get_tensor(key)" ] }, { "cell_type": "markdown", "metadata": { "id": "5fxk_BnZ4W1b" }, "source": [ "### 목록 및 딕셔너리 추적\n", "\n", "체크포인트는 속성 중 하나에 설정된 모든 변수 또는 추적 가능한 개체를 '추적'하여 `tf.Variable` 개체의 값을 저장하고 복원합니다. 저장을 실행할 때 도달 가능한 모든 추적 개체로부터 변수가 재귀적으로 수집됩니다.\n", "\n", "`self.l1 = tf.keras.layer.Dense(5)`,와 같은 직접적인 속성 할당은 목록과 사전적 속성에 할당하면 내용이 추적됩니다." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:37.179539Z", "iopub.status.busy": "2022-12-14T21:16:37.179031Z", "iopub.status.idle": "2022-12-14T21:16:37.197357Z", "shell.execute_reply": "2022-12-14T21:16:37.196827Z" }, "id": "rfaIbDtDHAr_" }, "outputs": [], "source": [ "save = tf.train.Checkpoint()\n", "save.listed = [tf.Variable(1.)]\n", "save.listed.append(tf.Variable(2.))\n", "save.mapped = {'one': save.listed[0]}\n", "save.mapped['two'] = save.listed[1]\n", "save_path = save.save('./tf_list_example')\n", "\n", "restore = tf.train.Checkpoint()\n", "v2 = tf.Variable(0.)\n", "assert 0. == v2.numpy() # Not restored yet\n", "restore.mapped = {'two': v2}\n", "restore.restore(save_path)\n", "assert 2. == v2.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "UTKvbxHcI3T2" }, "source": [ "당신은 래퍼(wrapper) 객체를 목록과 사전에 있음을 알아차릴겁니다. 이러한 래퍼는 기본 데이터 구조의 checkpoint 가능한 버전입니다. 속성 기반 로딩과 마찬가지로, 이러한 래퍼들은 변수의 값이 용기에 추가되는 즉시 복원됩니다." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T21:16:37.200799Z", "iopub.status.busy": "2022-12-14T21:16:37.200330Z", "iopub.status.idle": "2022-12-14T21:16:37.206328Z", "shell.execute_reply": "2022-12-14T21:16:37.205691Z" }, "id": "s0Uq1Hv5JCmm" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ListWrapper([])\n" ] } ], "source": [ "restore.listed = []\n", "print(restore.listed) # ListWrapper([])\n", "v1 = tf.Variable(0.)\n", "restore.listed.append(v1) # Restores v1, from restore() in the previous cell\n", "assert 1. == v1.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "OxCIf2J6JyQ8" }, "source": [ "추적 가능한 개체는 `tf.train.Checkpoint`, `tf.Module` 및 해당 서브 클래스(예: `keras.layers.Layer` 및 `keras.Model`)와 인식된 Python 컨테이너를 포함합니다.\n", "\n", "- `dict`(및 `collections.OrderedDict`)\n", "- `list`\n", "- `tuple`(및 `collections.namedtuple`, `typing.NamedTuple`)\n", "\n", "다음을 포함하여 기타 컨테이너 유형은 **지원하지 않습니다**.\n", "\n", "- `collections.defaultdict`\n", "- `set`\n", "\n", "다음을 포함한 기타 Python 개체는 **무시됩니다**.\n", "\n", "- `int`\n", "- `string`\n", "- `float`\n" ] }, { "cell_type": "markdown", "metadata": { "id": "knyUFMrJg8y4" }, "source": [ "## 요약\n", "\n", "텐서프로우 객체는 사용하는 변수의 값을 저장하고 복원할 수 있는 쉬운 자동 메커니즘을 제공합니다.\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "checkpoint.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }