{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "wJcYs_ERTnnI"
},
"source": [
"##### Copyright 2021 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"execution": {
"iopub.execute_input": "2022-12-14T22:13:31.657927Z",
"iopub.status.busy": "2022-12-14T22:13:31.657323Z",
"iopub.status.idle": "2022-12-14T22:13:31.660959Z",
"shell.execute_reply": "2022-12-14T22:13:31.660426Z"
},
"id": "HMUDt0CiUJk9"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "77z2OchJTk0l"
},
"source": [
"# Migrate metrics and optimizers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "meUTrR4I6m1C"
},
"source": [
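"In TF1, `tf.metrics` is the API namespace for all the metric functions. Each metric is a function that takes `labels` and `predictions` as input arguments and returns the corresponding metric tensors: a value tensor and an update op. In TF2, `tf.keras.metrics` contains all the metric functions and objects. A `Metric` object can be used with `tf.keras.Model` and `tf.keras.layers.Layer` to calculate metric values."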
]
},
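{
"cell_type": "markdown",
"metadata": {},
"source": [
"The difference between the two styles shows up in a minimal side-by-side sketch (an illustration added here, not taken from the TF1 or TF2 examples below). The TF1 function returns a value tensor and an update op that must be run in a graph session after initializing local variables. The TF2 `Metric` object holds its own state and updates eagerly. The toy labels and predictions are made up. Because 3 of the 4 predictions are correct, both results should be 0.75."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow.compat.v1 as tf1\n",
"\n",
"toy_labels = [0, 0, 1, 1]\n",
"toy_predictions = [0, 0, 0, 1]  # 3 of 4 match the labels\n",
"\n",
"# TF1 style: a metric function builds (value, update_op) in a graph.\n",
"with tf1.Graph().as_default():\n",
"    acc_value, acc_update_op = tf1.metrics.accuracy(\n",
"        labels=tf1.constant(toy_labels), predictions=tf1.constant(toy_predictions))\n",
"    with tf1.Session() as sess:\n",
"        # tf.compat.v1.metrics keeps its running totals in local variables.\n",
"        sess.run(tf1.local_variables_initializer())\n",
"        sess.run(acc_update_op)\n",
"        tf1_accuracy = float(sess.run(acc_value))\n",
"\n",
"# TF2 style: a metric object owns its state and runs eagerly.\n",
"acc_metric = tf.keras.metrics.Accuracy()\n",
"acc_metric.update_state(y_true=toy_labels, y_pred=toy_predictions)\n",
"tf2_accuracy = float(acc_metric.result())\n",
"\n",
"print(tf1_accuracy, tf2_accuracy)"
]
},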
{
"cell_type": "markdown",
"metadata": {
"id": "YdZSoIXEbhg-"
},
"source": [
"## Setup\n",
"\n",
"Start with a couple of necessary TensorFlow imports."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:31.664578Z",
"iopub.status.busy": "2022-12-14T22:13:31.664047Z",
"iopub.status.idle": "2022-12-14T22:13:33.592148Z",
"shell.execute_reply": "2022-12-14T22:13:33.591473Z"
},
"id": "iE0vSfMXumKI"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-12-14 22:13:32.624361: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
"2022-12-14 22:13:32.624482: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
"2022-12-14 22:13:32.624493: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import tensorflow.compat.v1 as tf1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Jsm9Rxx7s1OZ"
},
"source": [
"Prepare some simple data for demonstration."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:33.596819Z",
"iopub.status.busy": "2022-12-14T22:13:33.596006Z",
"iopub.status.idle": "2022-12-14T22:13:33.600184Z",
"shell.execute_reply": "2022-12-14T22:13:33.599585Z"
},
"id": "m7rnGxsXtDkV"
},
"outputs": [],
"source": [
"features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n",
"labels = [0, 0, 1]\n",
"eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]\n",
"eval_labels = [0, 1, 1]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xswk0d4xrFaQ"
},
"source": [
"## TF1: tf.compat.v1.metrics with Estimator\n",
"\n",
"In TF1, metrics can be added to an `EstimatorSpec` as `eval_metric_ops`. Each op is produced by one of the metric functions defined in `tf.metrics`. The following example shows how to use `tf.metrics.accuracy`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:33.603352Z",
"iopub.status.busy": "2022-12-14T22:13:33.602901Z",
"iopub.status.idle": "2022-12-14T22:13:37.942142Z",
"shell.execute_reply": "2022-12-14T22:13:37.941487Z"
},
"id": "lqe9obf7suIj"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp23fn1_x6\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp23fn1_x6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Call initializer instance with the dtype argument instead of passing it to the constructor\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Done calling model_fn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Create CheckpointSaverHook.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Graph was finalized.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Running local_init_op.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Done running local_init_op.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp23fn1_x6/model.ckpt.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:loss = 2.6736233, step = 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmp23fn1_x6/model.ckpt.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Loss for final step: 0.008751018.\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def _input_fn():\n",
" return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n",
"\n",
"def _eval_input_fn():\n",
" return tf1.data.Dataset.from_tensor_slices(\n",
" (eval_features, eval_labels)).batch(1)\n",
"\n",
"def _model_fn(features, labels, mode):\n",
" logits = tf1.layers.Dense(2)(features)\n",
" predictions = tf.math.argmax(input=logits, axis=1)\n",
" loss = tf1.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n",
" optimizer = tf1.train.AdagradOptimizer(0.05)\n",
" train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n",
" accuracy = tf1.metrics.accuracy(labels=labels, predictions=predictions)\n",
" return tf1.estimator.EstimatorSpec(mode, \n",
" predictions=predictions,\n",
" loss=loss, \n",
" train_op=train_op,\n",
" eval_metric_ops={'accuracy': accuracy})\n",
"\n",
"estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n",
"estimator.train(_input_fn)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:37.945765Z",
"iopub.status.busy": "2022-12-14T22:13:37.945019Z",
"iopub.status.idle": "2022-12-14T22:13:38.316471Z",
"shell.execute_reply": "2022-12-14T22:13:38.315817Z"
},
"id": "HsOpjW5plH9Q"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Done calling model_fn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Starting evaluation at 2022-12-14T22:13:38\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Graph was finalized.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp23fn1_x6/model.ckpt-3\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Running local_init_op.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Done running local_init_op.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Inference Time : 0.25721s\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Finished evaluation at 2022-12-14-22:13:38\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Saving dict for global step 3: accuracy = 0.6666667, global_step = 3, loss = 2.0419586\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmp23fn1_x6/model.ckpt-3\n"
]
},
{
"data": {
"text/plain": [
"{'accuracy': 0.6666667, 'loss': 2.0419586, 'global_step': 3}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"estimator.evaluate(_eval_input_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wk4C6qA_OaQx"
},
"source": [
"Metrics can also be added to an Estimator directly via `tf.estimator.add_metrics()`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:38.320076Z",
"iopub.status.busy": "2022-12-14T22:13:38.319560Z",
"iopub.status.idle": "2022-12-14T22:13:38.592375Z",
"shell.execute_reply": "2022-12-14T22:13:38.591775Z"
},
"id": "B2lpLOh9Owma"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp23fn1_x6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Done calling model_fn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Done calling model_fn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Starting evaluation at 2022-12-14T22:13:38\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Graph was finalized.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp23fn1_x6/model.ckpt-3\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Running local_init_op.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Done running local_init_op.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Inference Time : 0.16179s\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Finished evaluation at 2022-12-14-22:13:38\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Saving dict for global step 3: accuracy = 0.6666667, global_step = 3, loss = 2.0419586, mean_squared_error = 0.33333334\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmp23fn1_x6/model.ckpt-3\n"
]
},
{
"data": {
"text/plain": [
"{'accuracy': 0.6666667,\n",
" 'loss': 2.0419586,\n",
" 'mean_squared_error': 0.33333334,\n",
" 'global_step': 3}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def mean_squared_error(labels, predictions):\n",
" labels = tf.cast(labels, predictions.dtype)\n",
" return {\"mean_squared_error\": \n",
" tf1.metrics.mean_squared_error(labels=labels, predictions=predictions)}\n",
"\n",
"estimator = tf1.estimator.add_metrics(estimator, mean_squared_error)\n",
"estimator.evaluate(_eval_input_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KEmzBjfnsxwT"
},
"source": [
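"## TF2: Keras metrics API with tf.keras.Model\n",
"\n",
"In TF2, `tf.keras.metrics` contains all the metric classes and functions. They are designed in an object-oriented style and integrate closely with the other `tf.keras` APIs. All the metrics can be found in the `tf.keras.metrics` namespace, and there is usually a direct mapping between `tf.compat.v1.metrics` and `tf.keras.metrics`.\n",
"\n",
"In the following example, the metrics are added in the `model.compile()` method. Users only need to create the metric instance, without specifying the label and prediction tensors. The Keras model routes the model output and the labels to the metric object."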
]
},
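{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below illustrates the usual one-to-one mapping. It is an illustration added here and uses made-up toy values. It computes the same mean squared error twice: once with `tf.compat.v1.metrics.mean_squared_error`, the function wrapped by `add_metrics()` in the TF1 example, and once with its Keras counterpart `tf.keras.metrics.MeanSquaredError`. One of the three predictions is off by 1.0, so both results should be about 1/3."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow.compat.v1 as tf1\n",
"\n",
"toy_true = [0.0, 1.0, 1.0]\n",
"toy_pred = [0.0, 0.0, 1.0]  # one squared error of 1.0 over 3 values\n",
"\n",
"# TF1: functional metric evaluated in a graph session.\n",
"with tf1.Graph().as_default():\n",
"    mse_value, mse_update_op = tf1.metrics.mean_squared_error(\n",
"        labels=tf1.constant(toy_true), predictions=tf1.constant(toy_pred))\n",
"    with tf1.Session() as sess:\n",
"        sess.run(tf1.local_variables_initializer())\n",
"        sess.run(mse_update_op)\n",
"        tf1_mse = float(sess.run(mse_value))\n",
"\n",
"# TF2: the corresponding Keras metric class, used eagerly.\n",
"mse_metric = tf.keras.metrics.MeanSquaredError()\n",
"mse_metric.update_state(toy_true, toy_pred)\n",
"tf2_mse = float(mse_metric.result())\n",
"\n",
"print(tf1_mse, tf2_mse)"
]
},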
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:38.596142Z",
"iopub.status.busy": "2022-12-14T22:13:38.595500Z",
"iopub.status.idle": "2022-12-14T22:13:38.651963Z",
"shell.execute_reply": "2022-12-14T22:13:38.651351Z"
},
"id": "atVciNgPs0fw"
},
"outputs": [],
"source": [
"dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n",
"eval_dataset = tf.data.Dataset.from_tensor_slices(\n",
" (eval_features, eval_labels)).batch(1)\n",
"\n",
"inputs = tf.keras.Input((2,))\n",
"logits = tf.keras.layers.Dense(2)(inputs)\n",
"predictions = tf.math.argmax(input=logits, axis=1)\n",
"model = tf.keras.models.Model(inputs, predictions)\n",
"optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n",
"\n",
"model.compile(optimizer, loss='mse', metrics=[tf.keras.metrics.Accuracy()])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:38.655633Z",
"iopub.status.busy": "2022-12-14T22:13:38.655114Z",
"iopub.status.idle": "2022-12-14T22:13:38.784404Z",
"shell.execute_reply": "2022-12-14T22:13:38.783809Z"
},
"id": "Kip65sYBlKiu"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"1/3 [=========>....................] - ETA: 0s - loss: 1.0000 - accuracy: 0.0000e+00"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"3/3 [==============================] - 0s 4ms/step - loss: 0.3333 - accuracy: 0.6667\n"
]
},
{
"data": {
"text/plain": [
"{'loss': 0.3333333432674408, 'accuracy': 0.6666666865348816}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(eval_dataset, return_dict=True)"
]
},
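{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `model` above outputs `argmax` predictions so that `tf.keras.metrics.Accuracy` can compare them with the labels. Because `argmax` has no gradient, that model is only suitable for evaluation. The sketch below is one possible closer counterpart of the TF1 `_model_fn`, based on our own assumptions rather than the guide's method. It keeps the logits as the model output and uses a sparse softmax cross-entropy loss. It also passes several metric objects to `compile()`, much as `tf.estimator.add_metrics()` added an extra metric to the Estimator. `tf.keras.metrics.SparseCategoricalAccuracy` takes the argmax of the logits itself. Names such as `logits_model` are hypothetical, and the printed values depend on the random initialization."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Same toy data as above, repeated so that this cell is self-contained.\n",
"train_x = [[1., 1.5], [2., 2.5], [3., 3.5]]\n",
"train_y = [0, 0, 1]\n",
"val_x = [[4., 4.5], [5., 5.5], [6., 6.5]]\n",
"val_y = [0, 1, 1]\n",
"train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y)).batch(1)\n",
"val_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y)).batch(1)\n",
"\n",
"# Keep the logits as the model output so the model stays differentiable.\n",
"logits_model = tf.keras.Sequential([\n",
"    tf.keras.Input(shape=(2,)),\n",
"    tf.keras.layers.Dense(2),\n",
"])\n",
"logits_model.compile(\n",
"    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.05),\n",
"    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"    # Several metric objects, similar in spirit to add_metrics() in TF1.\n",
"    metrics=[\n",
"        tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),\n",
"        tf.keras.metrics.SparseCategoricalCrossentropy(\n",
"            from_logits=True, name='xent'),\n",
"    ])\n",
"logits_model.fit(train_ds, epochs=1, verbose=0)\n",
"eval_logs = logits_model.evaluate(val_ds, return_dict=True, verbose=0)\n",
"print(eval_logs)"
]
},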
{
"cell_type": "markdown",
"metadata": {
"id": "_mcGoCm_X1V0"
},
"source": [
"With eager execution enabled, `tf.keras.metrics.Metric` instances can be used directly to evaluate NumPy data or eager tensors. `tf.keras.metrics.Metric` objects are stateful containers. The metric value can be updated with `metric.update_state(y_true, y_pred)`, and the result can be retrieved with `metric.result()`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:38.787561Z",
"iopub.status.busy": "2022-12-14T22:13:38.787316Z",
"iopub.status.idle": "2022-12-14T22:13:38.802645Z",
"shell.execute_reply": "2022-12-14T22:13:38.801945Z"
},
"id": "TVGn5_IhYhtG"
},
"outputs": [
{
"data": {
"text/plain": [
"0.75"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy = tf.keras.metrics.Accuracy()\n",
"\n",
"accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[0, 0, 0, 1])\n",
"accuracy.result().numpy()\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T22:13:38.805760Z",
"iopub.status.busy": "2022-12-14T22:13:38.805202Z",
"iopub.status.idle": "2022-12-14T22:13:38.814960Z",
"shell.execute_reply": "2022-12-14T22:13:38.814314Z"
},
"id": "wQEV2hHtY_su"
},
"outputs": [
{
"data": {
"text/plain": [
"0.41666666"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[0, 0, 0, 0])\n",
"accuracy.update_state(y_true=[0, 0, 1, 1], y_pred=[1, 1, 0, 0])\n",
"accuracy.result().numpy()"
]
},
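{
"cell_type": "markdown",
"metadata": {},
"source": [
"The value `0.41666666` above is cumulative. After the three `update_state` calls, 5 of the 12 predictions match (3 + 2 + 0), and 5/12 ≈ 0.4167. The metric keeps accumulating until its state is cleared. The following sketch uses the same toy batches, and its variable names are hypothetical. It calls `reset_state()` to start counting again, which is useful, for example, between epochs or evaluation runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"running_acc = tf.keras.metrics.Accuracy()\n",
"toy_batches = [\n",
"    ([0, 0, 1, 1], [0, 0, 0, 1]),  # 3 of 4 correct\n",
"    ([0, 0, 1, 1], [0, 0, 0, 0]),  # 2 of 4 correct\n",
"    ([0, 0, 1, 1], [1, 1, 0, 0]),  # 0 of 4 correct\n",
"]\n",
"for y_true, y_pred in toy_batches:\n",
"    running_acc.update_state(y_true, y_pred)\n",
"cumulative = float(running_acc.result())  # 5 / 12\n",
"\n",
"# Clear the accumulated totals; with no samples seen, the result is 0.0.\n",
"running_acc.reset_state()\n",
"after_reset = float(running_acc.result())\n",
"\n",
"# Only the batches seen after the reset count now.\n",
"running_acc.update_state([0, 1], [0, 1])\n",
"fresh = float(running_acc.result())\n",
"\n",
"print(cumulative, after_reset, fresh)"
]
},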
{
"cell_type": "markdown",
"metadata": {
"id": "E3F3ElcyadW-"
},
"source": [
"For more details about `tf.keras.metrics.Metric`, see the API documentation for `tf.keras.metrics.Metric`, as well as the [migration guide](https://www.tensorflow.org/guide/effective_tf2#new-style_metrics_and_losses)."
]
},
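{
"cell_type": "markdown",
"metadata": {},
"source": [
"When no built-in class fits, you can write a metric by subclassing `tf.keras.metrics.Metric` and implementing `update_state()`, `result()` and `reset_state()`. The `CorrectCount` class below is a hypothetical example and is not part of TensorFlow. It counts how many predictions equal the labels and keeps that count in a metric variable created with `add_weight()`. For brevity, it ignores `sample_weight`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"class CorrectCount(tf.keras.metrics.Metric):\n",
"    # Hypothetical metric: total number of exact matches between y_true and y_pred.\n",
"\n",
"    def __init__(self, name='correct_count', **kwargs):\n",
"        super().__init__(name=name, **kwargs)\n",
"        # Metric state lives in a variable so it persists across update_state() calls.\n",
"        self.num_correct = self.add_weight(name='num_correct', initializer='zeros')\n",
"\n",
"    def update_state(self, y_true, y_pred, sample_weight=None):\n",
"        # sample_weight is accepted for API compatibility but ignored in this sketch.\n",
"        matches = tf.equal(tf.cast(y_true, tf.int64), tf.cast(y_pred, tf.int64))\n",
"        self.num_correct.assign_add(tf.reduce_sum(tf.cast(matches, tf.float32)))\n",
"\n",
"    def result(self):\n",
"        return tf.identity(self.num_correct)\n",
"\n",
"    def reset_state(self):\n",
"        self.num_correct.assign(0.0)\n",
"\n",
"\n",
"counter = CorrectCount()\n",
"counter.update_state([0, 0, 1, 1], [0, 0, 0, 1])  # 3 matches\n",
"counter.update_state([0, 1], [1, 1])              # 1 match\n",
"total_matches = float(counter.result())\n",
"counter.reset_state()\n",
"after_reset = float(counter.result())\n",
"print(total_matches, after_reset)"
]
},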
{
"cell_type": "markdown",
"metadata": {
"id": "eXKY9HEulxQC"
},
"source": [
"## Migrate TF1.x optimizers to Keras optimizers\n",
"\n",
"The optimizers in `tf.compat.v1.train`, such as the [Adam optimizer](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/AdamOptimizer) and the gradient descent optimizer, have equivalents in `tf.keras.optimizers`.\n",
"\n",
"The table below summarizes how to convert these legacy optimizers to their Keras equivalents. You can replace the TF1.x version with the TF2 version directly unless additional steps are required, such as [updating the default learning rate](../../guide/effective_tf2.ipynb#optimizer_defaults).\n",
"\n",
"Note that converting your optimizers [may make old checkpoints incompatible](./migrating_checkpoints.ipynb).\n",
"\n",
"| TF1.x | TF2 | Additional steps |\n",
"| --- | --- | --- |\n",
"| `tf.v1.train.GradientDescentOptimizer` | `tf.keras.optimizers.SGD` | None |\n",
"| `tf.v1.train.MomentumOptimizer` | `tf.keras.optimizers.SGD` | Include the `momentum` argument |\n",
"| `tf.v1.train.AdamOptimizer` | `tf.keras.optimizers.Adam` | Rename the `beta1` and `beta2` arguments to `beta_1` and `beta_2` |\n",
"| `tf.v1.train.RMSPropOptimizer` | `tf.keras.optimizers.RMSprop` | Rename the `decay` argument to `rho` |\n",
"| `tf.v1.train.AdadeltaOptimizer` | `tf.keras.optimizers.Adadelta` | None |\n",
"| `tf.v1.train.AdagradOptimizer` | `tf.keras.optimizers.Adagrad` | None |\n",
"| `tf.v1.train.FtrlOptimizer` | `tf.keras.optimizers.Ftrl` | Remove the `accum_name` and `linear_name` arguments |\n",
"| `tf.contrib.AdamaxOptimizer` | `tf.keras.optimizers.Adamax` | Rename the `beta1` and `beta2` arguments to `beta_1` and `beta_2` |\n",
"| `tf.contrib.Nadam` | `tf.keras.optimizers.Nadam` | Rename the `beta1` and `beta2` arguments to `beta_1` and `beta_2` |\n",
"\n",
"Note: In TF2, all epsilons (numerical stability constants) now default to `1e-7` instead of `1e-8`. This difference is negligible in most use cases."
]
}
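,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is a minimal sketch of the renames in the table. The argument values are illustrative, not recommendations, and each Keras optimizer corresponds to the TF1.x constructor noted in its comment. Adam's `epsilon` is set to `1e-8` explicitly to match the TF1.x `AdamOptimizer` default instead of the TF2 default. The short loop then takes two momentum-SGD steps on `w * w`, starting from `w = 1.0`. Worked by hand, the first step moves `w` to 0.8 and the second to 0.46. These are the same values that the TF1.x `MomentumOptimizer` update rule gives for a constant learning rate."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# TF1.x: tf.compat.v1.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)\n",
"momentum_sgd = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)\n",
"\n",
"# TF1.x: tf.compat.v1.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999)\n",
"# epsilon=1e-8 reproduces the TF1.x default; TF2 would otherwise use 1e-7.\n",
"adam = tf.keras.optimizers.Adam(\n",
"    learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)\n",
"\n",
"# TF1.x: tf.compat.v1.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9)\n",
"rmsprop = tf.keras.optimizers.RMSprop(learning_rate=0.001, rho=0.9)\n",
"\n",
"# Two momentum-SGD steps on loss = w * w (gradient 2 * w), starting from w = 1.0.\n",
"w = tf.Variable(1.0)\n",
"w_history = []\n",
"for _ in range(2):\n",
"    with tf.GradientTape() as tape:\n",
"        loss = w * w\n",
"    grads = tape.gradient(loss, [w])\n",
"    momentum_sgd.apply_gradients(list(zip(grads, [w])))\n",
"    w_history.append(float(w.numpy()))\n",
"\n",
"print(w_history)  # approximately [0.8, 0.46]"
]
}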
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "metrics_optimizers.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 0
}