{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ckM5wJMsNTYL" }, "source": [ "##### Copyright 2023 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "NKvERjPVNWxu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "bqePLdDjNhNk" }, "source": [ "# Import a JAX model using JAX2TF" ] }, { "cell_type": "markdown", "metadata": { "id": "gw3w46yhNiK_" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "IyrsY3uTOmPY" }, "source": [ "This notebook provides a complete, runnable example of creating a model using [JAX](https://jax.readthedocs.io/en/latest/) and bringing it into TensorFlow to continue training. This is made possible by [JAX2TF](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), a lightweight API that provides a pathway from the JAX ecosystem to the TensorFlow ecosystem.\n", "\n", "JAX is a high-performance array computing library. To create the model, this notebook uses [Flax](https://flax.readthedocs.io/en/latest/), a neural network library for JAX. To train it, it uses [Optax](https://optax.readthedocs.io), an optimization library for JAX.\n", "\n", "If you're a researcher using JAX, JAX2TF gives you a path to production using TensorFlow's proven tools.\n", "\n", "This can be useful in many ways. Here are just a few:\n", "\n", "* Inference: Taking a model written for JAX and deploying it either on a server using TF Serving, on-device using TFLite, or on the web using TensorFlow.js.\n", "\n", "* Fine-tuning: Taking a model that was trained using JAX, bringing its components to TF using JAX2TF, and continuing to train it in TensorFlow with your existing training data and setup.\n", "\n", "* Fusion: Combining parts of models that were trained using JAX with those trained using TensorFlow, for maximum flexibility.\n", "\n", "The key to enabling this kind of interoperation between JAX and TensorFlow is `jax2tf.convert`, which takes in model components created on top of JAX (your loss function, prediction function, etc.) and creates equivalent representations of them as TensorFlow functions, which can then be exported as a TensorFlow SavedModel." 
] }, { "cell_type": "markdown", "metadata": { "id": "G6rtu96yOepm" }, "source": [ "## Setup\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9yqxfHzr0LPF" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", "import flax\n", "import optax\n", "import os\n", "from matplotlib import pyplot as plt\n", "from jax.experimental import jax2tf\n", "from threading import Lock # Only used in the visualization utility.\n", "from functools import partial" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SDnTaZO0r872" }, "outputs": [], "source": [ "# Needed for TensorFlow and JAX to coexist in GPU memory.\n", "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = \"false\"\n", "gpus = tf.config.list_physical_devices('GPU')\n", "if gpus:\n", "  try:\n", "    for gpu in gpus:\n", "      tf.config.experimental.set_memory_growth(gpu, True)\n", "  except RuntimeError as e:\n", "    # Memory growth must be set before GPUs have been initialized.\n", "    print(e)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "BXOjCNJxDLil" }, "outputs": [], "source": [ "#@title Visualization utilities\n", "\n", "plt.rcParams[\"figure.figsize\"] = (20,8)\n", "\n", "# The utility for displaying training and validation curves.\n", "def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):\n", "\n", "  ignore_first_n_epochs = int(ignore_first_n/steps_per_epochs)\n", "\n", "  # The losses.\n", "  ax = plt.subplot(121)\n", "  if loss is not None:\n", "    x = np.arange(len(loss)) / steps_per_epochs #* epochs\n", "    ax.plot(x, loss)\n", "  ax.plot(range(1, epochs+1), avg_loss, \"-o\", linewidth=3)\n", "  ax.plot(range(1, epochs+1), eval_loss, \"-o\", linewidth=3)\n", "  ax.set_title('Loss')\n", "  ax.set_ylabel('loss')\n", "  ax.set_xlabel('epoch')\n", "  if loss is not None:\n", "    ax.set_ylim(0, 
np.max(loss[ignore_first_n:]))\n", "    ax.legend(['train', 'avg train', 'eval'])\n", "  else:\n", "    ymin = np.min(avg_loss[ignore_first_n_epochs:])\n", "    ymax = np.max(avg_loss[ignore_first_n_epochs:])\n", "    ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)\n", "    ax.legend(['avg train', 'eval'])\n", "\n", "  # The accuracy.\n", "  ax = plt.subplot(122)\n", "  ax.set_title('Eval Accuracy')\n", "  ax.set_ylabel('accuracy')\n", "  ax.set_xlabel('epoch')\n", "  ymin = np.min(eval_accuracy[ignore_first_n_epochs:])\n", "  ymax = np.max(eval_accuracy[ignore_first_n_epochs:])\n", "  ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)\n", "  ax.plot(range(1, epochs+1), eval_accuracy, \"-o\", linewidth=3)\n", "\n", "class Progress:\n", "  \"\"\"Text mode progress bar.\n", "  Usage:\n", "      p = Progress(30)\n", "      p.step()\n", "      p.step()\n", "      p.step(reset=True) # to restart from 0%\n", "  The progress bar displays a new header at each restart.\"\"\"\n", "  def __init__(self, maxi, size=100, msg=\"\"):\n", "    \"\"\"\n", "    :param maxi: the number of steps required to reach 100%\n", "    :param size: the number of characters taken on the screen by the progress bar\n", "    :param msg: the message displayed in the header of the progress bar\n", "    \"\"\"\n", "    self.maxi = maxi\n", "    self.p = self.__start_progress(maxi)() # `()`: to get the iterator from the generator.\n", "    self.header_printed = False\n", "    self.msg = msg\n", "    self.size = size\n", "    self.lock = Lock()\n", "\n", "  def step(self, reset=False):\n", "    with self.lock:\n", "      if reset:\n", "        self.__init__(self.maxi, self.size, self.msg)\n", "      if not self.header_printed:\n", "        self.__print_header()\n", "      next(self.p)\n", "\n", "  def __print_header(self):\n", "    print()\n", "    format_string = \"0%{: ^\" + str(self.size - 6) + \"}100%\"\n", "    print(format_string.format(self.msg))\n", "    self.header_printed = True\n", "\n", "  def __start_progress(self, maxi):\n", "    def print_progress():\n", "      # Bresenham's algorithm. 
Yields the number of dots printed.\n", "      # This always prints `size` dots in total over `maxi` invocations.\n", "      dx = maxi\n", "      dy = self.size\n", "      d = dy - dx\n", "      for x in range(maxi):\n", "        k = 0\n", "        while d >= 0:\n", "          print('=', end=\"\", flush=True)\n", "          k += 1\n", "          d -= dx\n", "        d += dy\n", "        yield k\n", "      # Keep yielding the last result if there are too many steps.\n", "      while True:\n", "        yield k\n", "\n", "    return print_progress" ] }, { "cell_type": "markdown", "metadata": { "id": "6xgS_8nDDIu8" }, "source": [ "## Download and prepare the MNIST dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nbN7rmuF0VFB" }, "outputs": [], "source": [ "(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()\n", "\n", "train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))\n", "train_data = train_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),\n", "                                         tf.one_hot(y, depth=10)))\n", "\n", "BATCH_SIZE = 256\n", "train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)\n", "train_data = train_data.cache()\n", "train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)\n", "\n", "test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))\n", "test_data = test_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),\n", "                                       tf.one_hot(y, depth=10)))\n", "test_data = test_data.batch(10000)\n", "test_data = test_data.cache()\n", "\n", "(one_batch, one_batch_labels) = next(iter(train_data)) # just one batch\n", "(all_test_data, all_test_labels) = next(iter(test_data)) # all in one batch since batch size is 10000" ] }, { "cell_type": "markdown", "metadata": { "id": "LuZTo7SM3W_n" }, "source": [ "## Configure training\n", "This notebook will create and train a simple model for demonstration purposes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3vbKB4yZ3aTL" }, "outputs": [], "source": [ "# Training hyperparameters.\n", "JAX_EPOCHS = 3\n", "TF_EPOCHS = 7\n", "STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE\n", "LEARNING_RATE = 0.01\n", "LEARNING_RATE_EXP_DECAY = 0.6\n", "\n", "# The learning rate schedule for JAX (with Optax).\n", "jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)\n", "\n", "# The learning rate schedule for TensorFlow.\n", "tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "Od3sMwQxtC34" }, "source": [ "## Create the model using Flax" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-ybqQF2zd2QX" }, "outputs": [], "source": [ "class ConvModel(flax.linen.Module):\n", "\n", "  @flax.linen.compact\n", "  def __call__(self, x, train):\n", "    x = flax.linen.Conv(features=12, kernel_size=(3,3), padding=\"SAME\", use_bias=False)(x)\n", "    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)\n", "    x = x.reshape((x.shape[0], -1)) # flatten\n", "    x = flax.linen.Dense(features=200, use_bias=True)(x)\n", "    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)\n", "    x = flax.linen.Dropout(rate=0.3, deterministic=not train)(x)\n", "    x = flax.linen.relu(x)\n", "    x = flax.linen.Dense(features=10)(x)\n", "    #x = flax.linen.log_softmax(x)\n", "    return x\n", "\n", "  # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss` (as a single number).\n", "  # `jax.grad` will differentiate it against the first argument.\n", "  # The user must split trainable and non-trainable variables into `params` and `other_state`.\n", "  # Must pass a different RNG key 
each time for the dropout mask to be different.\n", "  def loss(self, params, other_state, rng, data, labels, train):\n", "    logits, batch_stats = self.apply({'params': params, **other_state},\n", "                                     data,\n", "                                     mutable=['batch_stats'],\n", "                                     rngs={'dropout': rng},\n", "                                     train=train)\n", "    # The loss averaged across the batch dimension.\n", "    loss = optax.softmax_cross_entropy(logits, labels).mean()\n", "    return loss, batch_stats\n", "\n", "  def predict(self, state, data):\n", "    logits = self.apply(state, data, train=False) # predict and accuracy disable dropout and use accumulated batch norm stats (train=False)\n", "    probabilities = flax.linen.log_softmax(logits)\n", "    return probabilities\n", "\n", "  def accuracy(self, state, data, labels):\n", "    probabilities = self.predict(state, data)\n", "    predictions = jnp.argmax(probabilities, axis=-1)\n", "    dense_labels = jnp.argmax(labels, axis=-1)\n", "    accuracy = jnp.equal(predictions, dense_labels).mean()\n", "    return accuracy" ] }, { "cell_type": "markdown", "metadata": { "id": "7Cr0FRNFtHN4" }, "source": [ "## Write the training step function" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tmDwApcpgZzw" }, "outputs": [], "source": [ "# The training step.\n", "@partial(jax.jit, static_argnums=[0]) # this forces jax.jit to recompile for every new model\n", "def train_step(model, state, optimizer_state, rng, data, labels):\n", "\n", "  other_state, params = state.pop('params') # differentiate only against 'params' which represents trainable variables\n", "  (loss, batch_stats), grads = jax.value_and_grad(model.loss, has_aux=True)(params, other_state, rng, data, labels, train=True)\n", "\n", "  updates, optimizer_state = optimizer.update(grads, optimizer_state)\n", "  params = optax.apply_updates(params, updates)\n", "  new_state = state.copy(add_or_replace={**batch_stats, 'params': params})\n", "\n", "  rng, _ = jax.random.split(rng)\n", "\n", "  return new_state, optimizer_state, rng, loss" ] }, { 
"cell_type": "markdown", "metadata": { "id": "Zr16g6NzV4O9" }, "source": [ "## Write the training loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zbl5w-KUV7Qw" }, "outputs": [], "source": [ "def train(model, state, optimizer_state, train_data, epochs, losses, avg_losses, eval_losses, eval_accuracies):\n", "  p = Progress(STEPS_PER_EPOCH)\n", "  rng = jax.random.PRNGKey(0)\n", "  for epoch in range(epochs):\n", "\n", "    # This is where the learning rate schedule state is stored in the optimizer state.\n", "    optimizer_step = optimizer_state[1].count\n", "\n", "    # Run an epoch of training.\n", "    for step, (data, labels) in enumerate(train_data):\n", "      p.step(reset=(step==0))\n", "      state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())\n", "      losses.append(loss)\n", "    avg_loss = np.mean(losses[-step:])\n", "    avg_losses.append(avg_loss)\n", "\n", "    # Run one epoch of evals (10,000 test images in a single batch).\n", "    other_state, params = state.pop('params')\n", "    # Gotcha: must discard modified batch_stats here\n", "    eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)\n", "    eval_losses.append(eval_loss)\n", "    eval_accuracy = model.accuracy(state, all_test_data.numpy(), all_test_labels.numpy())\n", "    eval_accuracies.append(eval_accuracy)\n", "\n", "    print(\"\\nEpoch\", epoch, \"train loss:\", avg_loss, \"eval loss:\", eval_loss, \"eval accuracy\", eval_accuracy, \"lr:\", jlr_decay(optimizer_step))\n", "\n", "  return state, optimizer_state" ] }, { "cell_type": "markdown", "metadata": { "id": "DGB3W5g0Wt1H" }, "source": [ "## Create the model and the optimizer (with Optax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mW5mkmCWtN8W" }, "outputs": [], "source": [ "# The model.\n", "model = ConvModel()\n", "state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, 
train=True) # Flax allows a separate RNG for \"dropout\"\n", "\n", "# The optimizer.\n", "optimizer = optax.adam(learning_rate=jlr_decay) # Gotcha: it does not seem to be possible to pass just a callable as LR, must be an Optax Schedule\n", "optimizer_state = optimizer.init(state['params'])\n", "\n", "losses=[]\n", "avg_losses=[]\n", "eval_losses=[]\n", "eval_accuracies=[]" ] }, { "cell_type": "markdown", "metadata": { "id": "FJdsKghBNF" }, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nmcofTTBZSIb" }, "outputs": [], "source": [ "new_state, new_optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS+TF_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "n_20vgvDXB5r" }, "outputs": [], "source": [ "display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)" ] }, { "cell_type": "markdown", "metadata": { "id": "0lT3cdENCBzL" }, "source": [ "## Partially train the model\n", "\n", "You will continue training the model in TensorFlow shortly." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KT-xqj5N7C6L" }, "outputs": [], "source": [ "model = ConvModel()\n", "state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for \"dropout\"\n", "\n", "# The optimizer.\n", "optimizer = optax.adam(learning_rate=jlr_decay) # LR must be an Optax LR Schedule\n", "optimizer_state = optimizer.init(state['params'])\n", "\n", "losses, avg_losses, eval_losses, eval_accuracies = [], [], [], []" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oa362HMDbzDE" }, "outputs": [], "source": [ "state, optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0IyZtUPPCt0y" }, "outputs": [], "source": [ "display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)" ] }, { "cell_type": "markdown", "metadata": { "id": "uNtlSaOCCumB" }, "source": [ "## Save just enough for inference\n", "\n", "If your goal is to deploy your JAX model (so you can run inference using `model.predict()`), simply exporting it to [SavedModel](https://www.tensorflow.org/guide/saved_model) is sufficient. This section demonstrates how to accomplish that." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O653B3-5H8FL" }, "outputs": [], "source": [ "# Test data with a different batch size to test polymorphic shapes.\n", "x, y = next(iter(train_data.unbatch().batch(13)))\n", "\n", "m = tf.Module()\n", "# Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function).\n", "state_vars = tf.nest.map_structure(tf.Variable, state)\n", "# Keep the wrapped state as a flat list (needed in TensorFlow fine-tuning).\n", "m.vars = tf.nest.flatten(state_vars)\n", "# Convert the desired JAX function (`model.predict`).\n", "predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=[\"...\", \"(b, 28, 28, 1)\"])\n", "# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work).\n", "@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])\n", "def predict(data):\n", "  return predict_fn(state_vars, data)\n", "m.predict = predict\n", "tf.saved_model.save(m, \"./\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8HFx67zStgvo" }, "outputs": [], "source": [ "# Test the converted function.\n", "print(\"Converted function predictions:\", np.argmax(m.predict(x).numpy(), axis=-1))\n", "# Reload the model.\n", "reloaded_model = tf.saved_model.load(\"./\")\n", "# Test the reloaded converted function (the result should be the same).\n", "print(\"Reloaded function predictions:\", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))" ] }, { "cell_type": "markdown", "metadata": { "id": "eEk8wv4HJu94" }, "source": [ "## Save everything\n", "If your goal is a comprehensive export (useful if you're planning on bringing the model into TensorFlow for fine-tuning, fusion, etc.), this section demonstrates how to save the model so you can access methods including:\n", "\n", " - model.predict\n", " - model.accuracy\n", " - model.loss (including train=True/False bool, RNG for 
dropout and BatchNorm state updates)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9mty52pmvDDp" }, "outputs": [], "source": [ "from collections import abc\n", "\n", "def _fix_frozen(d):\n", "  \"\"\"Changes any mappings (e.g. frozendict) back to dict.\"\"\"\n", "  if isinstance(d, list):\n", "    return [_fix_frozen(v) for v in d]\n", "  elif isinstance(d, tuple):\n", "    return tuple(_fix_frozen(v) for v in d)\n", "  elif not isinstance(d, abc.Mapping):\n", "    return d\n", "  d = dict(d)\n", "  for k, v in d.items():\n", "    d[k] = _fix_frozen(v)\n", "  return d" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3HEsKNXbCwXw" }, "outputs": [], "source": [ "class TFModel(tf.Module):\n", "  def __init__(self, state, model):\n", "    super().__init__()\n", "\n", "    # Special care needed for the train=True/False parameter in the loss\n", "    @jax.jit\n", "    def loss_with_train_bool(state, rng, data, labels, train):\n", "      other_state, params = state.pop('params')\n", "      loss, batch_stats = jax.lax.cond(train,\n", "                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=True),\n", "                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=False),\n", "                                       state, data, labels)\n", "      # must use JAX to split the RNG, therefore, must do it in a @jax.jit function\n", "      new_rng, _ = jax.random.split(rng)\n", "      return loss, batch_stats, new_rng\n", "\n", "    self.state_vars = tf.nest.map_structure(tf.Variable, state)\n", "    self.vars = tf.nest.flatten(self.state_vars)\n", "    self.jax_rng = tf.Variable(jax.random.PRNGKey(0))\n", "\n", "    self.loss_fn = jax2tf.convert(loss_with_train_bool, polymorphic_shapes=[\"...\", \"...\", \"(b, 28, 28, 1)\", \"(b, 10)\", \"...\"])\n", "    self.accuracy_fn = jax2tf.convert(model.accuracy, polymorphic_shapes=[\"...\", \"(b, 28, 28, 1)\", \"(b, 10)\"])\n", "    self.predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=[\"...\", \"(b, 28, 28, 1)\"])\n", "\n", "  
# Must specify TensorSpec manually for variable batch size to work\n", "  @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])\n", "  def predict(self, data):\n", "    # Make sure the TFModel.predict function implicitly uses self.state_vars and not the JAX state directly;\n", "    # otherwise, all model weights would be embedded in the TF graph as constants.\n", "    return self.predict_fn(self.state_vars, data)\n", "\n", "  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),\n", "                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n", "               autograph=False)\n", "  def train_loss(self, data, labels):\n", "    loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, True)\n", "    # update batch norm stats\n", "    flat_vars = tf.nest.flatten(self.state_vars['batch_stats'])\n", "    flat_values = tf.nest.flatten(batch_stats['batch_stats'])\n", "    for var, val in zip(flat_vars, flat_values):\n", "      var.assign(val)\n", "    # update RNG\n", "    self.jax_rng.assign(new_rng)\n", "    return loss\n", "\n", "  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),\n", "                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n", "               autograph=False)\n", "  def eval_loss(self, data, labels):\n", "    loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, False)\n", "    return loss\n", "\n", "  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),\n", "                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n", "               autograph=False)\n", "  def accuracy(self, data, labels):\n", "    return self.accuracy_fn(self.state_vars, data, labels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "znJrAVpcxO9u" }, "outputs": [], "source": [ "# Instantiate the model.\n", "tf_model = TFModel(state, model)\n", "\n", "# Save the model.\n", "tf.saved_model.save(tf_model, \"./\")" ] }, { "cell_type": "markdown", 
"metadata": { "id": "Y02DHEwTjNzV" }, "source": [ "## Reload the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i75yS3v2jPpM" }, "outputs": [], "source": [ "reloaded_model = tf.saved_model.load(\"./\")\n", "\n", "# Test if it works and that the batch size is indeed variable.\n", "x,y = next(iter(train_data.unbatch().batch(13)))\n", "print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))\n", "x,y = next(iter(train_data.unbatch().batch(20)))\n", "print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))\n", "\n", "print(reloaded_model.accuracy(one_batch, one_batch_labels))\n", "print(reloaded_model.accuracy(all_test_data, all_test_labels))" ] }, { "cell_type": "markdown", "metadata": { "id": "DiwEAwQmlx1x" }, "source": [ "## Continue training the converted JAX model in TensorFlow" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MubFcO_jl2vE" }, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)\n", "\n", "# Set the iteration step for the learning rate to resume from where it left off in JAX.\n", "optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)\n", "\n", "p = Progress(STEPS_PER_EPOCH)\n", "\n", "for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):\n", "\n", "  # This is where the learning rate schedule state is stored in the optimizer state.\n", "  optimizer_step = optimizer.iterations\n", "\n", "  for step, (data, labels) in enumerate(train_data):\n", "    p.step(reset=(step==0))\n", "    with tf.GradientTape() as tape:\n", "      #loss = reloaded_model.loss(data, labels, True)\n", "      loss = reloaded_model.train_loss(data, labels)\n", "    grads = tape.gradient(loss, reloaded_model.vars)\n", "    optimizer.apply_gradients(zip(grads, reloaded_model.vars))\n", "    losses.append(loss)\n", "  avg_loss = np.mean(losses[-step:])\n", "  avg_losses.append(avg_loss)\n", "\n", "  eval_loss = reloaded_model.eval_loss(all_test_data.numpy(), all_test_labels.numpy()).numpy()\n", "  
eval_losses.append(eval_loss)\n", "  eval_accuracy = reloaded_model.accuracy(all_test_data.numpy(), all_test_labels.numpy()).numpy()\n", "  eval_accuracies.append(eval_accuracy)\n", "\n", "  print(\"\\nEpoch\", epoch, \"train loss:\", avg_loss, \"eval loss:\", eval_loss, \"eval accuracy\", eval_accuracy, \"lr:\", tflr_decay(optimizer.iterations).numpy())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "50V1FSmI6UTk" }, "outputs": [], "source": [ "display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=2*STEPS_PER_EPOCH)\n", "\n", "# The loss takes a hit when the training restarts, but does not go back to random levels.\n", "# This is likely caused by the optimizer momentum being reinitialized." ] }, { "cell_type": "markdown", "metadata": { "id": "L7lSziW0K0ny" }, "source": [ "## Next steps\n", "You can learn more about [JAX](https://jax.readthedocs.io/en/latest/index.html) and [Flax](https://flax.readthedocs.io/en/latest) on their documentation websites, which contain detailed guides and examples. If you're new to JAX, be sure to explore the [JAX 101 tutorials](https://jax.readthedocs.io/en/latest/jax-101/index.html), and check out the [Flax quickstart](https://flax.readthedocs.io/en/latest/getting_started.html). To learn more about converting JAX models to TensorFlow format, check out the [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf) utility on GitHub. If you're interested in converting JAX models to run in the browser with TensorFlow.js, visit [JAX on the Web with TensorFlow.js](https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html). If you'd like to prepare JAX models to run in TensorFlow Lite, visit the [JAX Model Conversion For TFLite](https://www.tensorflow.org/lite/examples/jax_conversion/overview) guide." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "jax2tf.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }