{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "30155835fc9f" }, "source": [ "##### Copyright 2022 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-10-17T12:26:59.077356Z", "iopub.status.busy": "2023-10-17T12:26:59.076891Z", "iopub.status.idle": "2023-10-17T12:26:59.080622Z", "shell.execute_reply": "2023-10-17T12:26:59.080010Z" }, "id": "906e07f6e562" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "5hrbPTziJK15" }, "source": [ "# Load LM Checkpoints using Model Garden" ] },
{ "cell_type": "markdown", "metadata": { "id": "yyyk1KMlJdWd" }, "source": [ "This tutorial demonstrates how to load BERT, ALBERT, and ELECTRA pretrained checkpoints and use them for downstream tasks.\n", "\n", "[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users take full advantage of TensorFlow for their research and product development." ] }, { "cell_type": "markdown", "metadata": { "id": "uEG4RYHolQij" }, "source": [ "## Install TF Model Garden package" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:26:59.084350Z", "iopub.status.busy": "2023-10-17T12:26:59.083920Z", "iopub.status.idle": "2023-10-17T12:27:09.445433Z", "shell.execute_reply": "2023-10-17T12:27:09.444441Z" }, "id": "kPfC1NJZnJq1" }, "outputs": [], "source": [ "!pip install -U -q \"tf-models-official\"" ] }, { "cell_type": "markdown", "metadata": { "id": "Op9R3zy3lUk8" }, "source": [ "## Import necessary libraries" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:09.449901Z", "iopub.status.busy": "2023-10-17T12:27:09.449639Z", "iopub.status.idle": "2023-10-17T12:27:11.927098Z", "shell.execute_reply": "2023-10-17T12:27:11.926326Z" }, "id": "6_y4Rfq23wK-" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 12:27:09.738068: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-10-17 12:27:09.738115: E
tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-10-17 12:27:09.738155: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import os\n", "import yaml\n", "import json\n", "\n", "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:11.931599Z", "iopub.status.busy": "2023-10-17T12:27:11.930811Z", "iopub.status.idle": "2023-10-17T12:27:13.968982Z", "shell.execute_reply": "2023-10-17T12:27:13.967896Z" }, "id": "xjgv3gllzbYQ" }, "outputs": [], "source": [ "import tensorflow_models as tfm\n", "\n", "from official.core import exp_factory" ] }, { "cell_type": "markdown", "metadata": { "id": "J-t2mo6VQNfY" }, "source": [ "## Load BERT model pretrained checkpoints" ] }, { "cell_type": "markdown", "metadata": { "id": "hdBsFnI20LDE" }, "source": [ "### Select required BERT model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:13.973810Z", "iopub.status.busy": "2023-10-17T12:27:13.973290Z", "iopub.status.idle": "2023-10-17T12:27:23.597185Z", "shell.execute_reply": "2023-10-17T12:27:23.596119Z" }, "id": "apn3VgxUlr5G" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2023-10-17 12:27:14-- https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-12_H-768_A-12.tar.gz\r\n", "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.219.207, 209.85.146.207, 209.85.147.207, ...\r\n", "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.219.207|:443... connected.\r\n", "HTTP request sent, awaiting response... 
" ] }, { "name": "stdout", "output_type": "stream", "text": [ "200 OK\r\n", "Length: 401886728 (383M) [application/octet-stream]\r\n", "Saving to: ‘cased_L-12_H-768_A-12.tar.gz’\r\n", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "cased_L-12_H-768_A- 100%[===================>] 383.27M 79.4MB/s in 5.3s \r\n", "\r\n", "2023-10-17 12:27:19 (72.9 MB/s) - ‘cased_L-12_H-768_A-12.tar.gz’ saved [401886728/401886728]\r\n", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "cased_L-12_H-768_A-12/\r\n", "cased_L-12_H-768_A-12/vocab.txt\r\n", "cased_L-12_H-768_A-12/bert_model.ckpt.index\r\n", "cased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "cased_L-12_H-768_A-12/params.yaml\r\n", "cased_L-12_H-768_A-12/bert_config.json\r\n" ] } ], "source": [ "# @title Download Checkpoint of the Selected Model { display-mode: \"form\",
run: \"auto\" }\n", "model_display_name = 'BERT-base cased English' # @param ['BERT-base uncased English','BERT-base cased English','BERT-large uncased English', 'BERT-large cased English', 'BERT-large, Uncased (Whole Word Masking)', 'BERT-large, Cased (Whole Word Masking)', 'BERT-base MultiLingual','BERT-base Chinese']\n", "\n", "if model_display_name == 'BERT-base uncased English':\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz\"\n", " !tar -xvf \"uncased_L-12_H-768_A-12.tar.gz\"\n", "elif model_display_name == 'BERT-base cased English':\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-12_H-768_A-12.tar.gz\"\n", " !tar -xvf \"cased_L-12_H-768_A-12.tar.gz\"\n", "elif model_display_name == \"BERT-large uncased English\":\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-24_H-1024_A-16.tar.gz\"\n", " !tar -xvf \"uncased_L-24_H-1024_A-16.tar.gz\"\n", "elif model_display_name == \"BERT-large cased English\":\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-24_H-1024_A-16.tar.gz\"\n", " !tar -xvf \"cased_L-24_H-1024_A-16.tar.gz\"\n", "elif model_display_name == \"BERT-large, Uncased (Whole Word Masking)\":\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_uncased_L-24_H-1024_A-16.tar.gz\"\n", " !tar -xvf \"wwm_uncased_L-24_H-1024_A-16.tar.gz\"\n", "elif model_display_name == \"BERT-large, Cased (Whole Word Masking)\":\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_cased_L-24_H-1024_A-16.tar.gz\"\n", " !tar -xvf \"wwm_cased_L-24_H-1024_A-16.tar.gz\"\n", "elif model_display_name == \"BERT-base MultiLingual\":\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/multi_cased_L-12_H-768_A-12.tar.gz\"\n", " !tar -xvf \"multi_cased_L-12_H-768_A-12.tar.gz\"\n", "elif model_display_name == \"BERT-base Chinese\":\n", " !wget 
\"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/chinese_L-12_H-768_A-12.tar.gz\"\n", " !tar -xvf \"chinese_L-12_H-768_A-12.tar.gz\"" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:23.601469Z", "iopub.status.busy": "2023-10-17T12:27:23.600918Z", "iopub.status.idle": "2023-10-17T12:27:23.609425Z", "shell.execute_reply": "2023-10-17T12:27:23.608830Z" }, "id": "jzxyziRuaC95" }, "outputs": [ { "data": { "text/plain": [ "'cased_L-12_H-768_A-12'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Lookup table of the directory name corresponding to each model checkpoint\n", "folder_bert_dict = {\n", " 'BERT-base uncased English': 'uncased_L-12_H-768_A-12',\n", " 'BERT-base cased English': 'cased_L-12_H-768_A-12',\n", " 'BERT-large uncased English': 'uncased_L-24_H-1024_A-16',\n", " 'BERT-large cased English': 'cased_L-24_H-1024_A-16',\n", " 'BERT-large, Uncased (Whole Word Masking)': 'wwm_uncased_L-24_H-1024_A-16',\n", " 'BERT-large, Cased (Whole Word Masking)': 'wwm_cased_L-24_H-1024_A-16',\n", " 'BERT-base MultiLingual': 'multi_cased_L-12_H-768_A-12',\n", " 'BERT-base Chinese': 'chinese_L-12_H-768_A-12'\n", "}\n", "\n", "folder_bert = folder_bert_dict.get(model_display_name)\n", "folder_bert" ] }, { "cell_type": "markdown", "metadata": { "id": "q1WrYswpZPlc" }, "source": [ "### Construct BERT Model Using the New `params.yaml`\n", "\n", "In addition to constructing the BERT encoder here, `params.yaml` can be used for training with the bundled trainer."
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:23.612467Z", "iopub.status.busy": "2023-10-17T12:27:23.612154Z", "iopub.status.idle": "2023-10-17T12:27:23.619764Z", "shell.execute_reply": "2023-10-17T12:27:23.619198Z" }, "id": "quu1s8Hi2szo" }, "outputs": [ { "data": { "text/plain": [ "{'task': {'model': {'encoder': {'bert': {'attention_dropout_rate': 0.1,\n", " 'dropout_rate': 0.1,\n", " 'hidden_activation': 'gelu',\n", " 'hidden_size': 768,\n", " 'initializer_range': 0.02,\n", " 'intermediate_size': 3072,\n", " 'max_position_embeddings': 512,\n", " 'num_attention_heads': 12,\n", " 'num_layers': 12,\n", " 'type_vocab_size': 2,\n", " 'vocab_size': 28996},\n", " 'type': 'bert'}}}}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config_file = os.path.join(folder_bert, \"params.yaml\")\n", "config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n", "config_dict" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:23.622726Z", "iopub.status.busy": "2023-10-17T12:27:23.622491Z", "iopub.status.idle": "2023-10-17T12:27:23.636255Z", "shell.execute_reply": "2023-10-17T12:27:23.635705Z" }, "id": "3t8o0iG9v8ac" }, "outputs": [ { "data": { "text/plain": [ "{'vocab_size': 28996,\n", " 'hidden_size': 768,\n", " 'num_layers': 12,\n", " 'num_attention_heads': 12,\n", " 'hidden_activation': 'gelu',\n", " 'intermediate_size': 3072,\n", " 'dropout_rate': 0.1,\n", " 'attention_dropout_rate': 0.1,\n", " 'max_position_embeddings': 512,\n", " 'type_vocab_size': 2,\n", " 'initializer_range': 0.02,\n", " 'embedding_size': None,\n", " 'output_range': None,\n", " 'return_all_encoder_outputs': False,\n", " 'return_attention_scores': False,\n", " 'norm_first': False}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Method 1: pass encoder config dict 
into EncoderConfig\n", "encoder_config = tfm.nlp.encoders.EncoderConfig(config_dict[\"task\"][\"model\"][\"encoder\"])\n", "encoder_config.get().as_dict()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:23.639341Z", "iopub.status.busy": "2023-10-17T12:27:23.638854Z", "iopub.status.idle": "2023-10-17T12:27:23.652252Z", "shell.execute_reply": "2023-10-17T12:27:23.651682Z" }, "id": "2I5PetB6wPvb" }, "outputs": [ { "data": { "text/plain": [ "{'vocab_size': 28996,\n", " 'hidden_size': 768,\n", " 'num_layers': 12,\n", " 'num_attention_heads': 12,\n", " 'hidden_activation': 'gelu',\n", " 'intermediate_size': 3072,\n", " 'dropout_rate': 0.1,\n", " 'attention_dropout_rate': 0.1,\n", " 'max_position_embeddings': 512,\n", " 'type_vocab_size': 2,\n", " 'initializer_range': 0.02,\n", " 'embedding_size': None,\n", " 'output_range': None,\n", " 'return_all_encoder_outputs': False,\n", " 'return_attention_scores': False,\n", " 'norm_first': False}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Method 2: use override_params_dict function to override default Encoder params\n", "encoder_config = tfm.nlp.encoders.EncoderConfig()\n", "tfm.hyperparams.override_params_dict(encoder_config, config_dict[\"task\"][\"model\"][\"encoder\"], is_strict=True)\n", "encoder_config.get().as_dict()" ] }, { "cell_type": "markdown", "metadata": { "id": "5yHiG_9oS3Uw" }, "source": [ "### Construct BERT Model Using the Old `bert_config.json`" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:23.655525Z", "iopub.status.busy": "2023-10-17T12:27:23.655104Z", "iopub.status.idle": "2023-10-17T12:27:23.659968Z", "shell.execute_reply": "2023-10-17T12:27:23.659407Z" }, "id": "WEyaqLcW3nne" }, "outputs": [ { "data": { "text/plain": [ "{'hidden_size': 768,\n", " 'initializer_range': 0.02,\n", " 'intermediate_size': 3072,\n",
" 'max_position_embeddings': 512,\n", " 'num_attention_heads': 12,\n", " 'num_layers': 12,\n", " 'type_vocab_size': 2,\n", " 'vocab_size': 28996,\n", " 'hidden_activation': 'gelu',\n", " 'dropout_rate': 0.1,\n", " 'attention_dropout_rate': 0.1}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bert_config_file = os.path.join(folder_bert, \"bert_config.json\")\n", "config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n", "config_dict" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:23.663102Z", "iopub.status.busy": "2023-10-17T12:27:23.662622Z", "iopub.status.idle": "2023-10-17T12:27:23.675255Z", "shell.execute_reply": "2023-10-17T12:27:23.674666Z" }, "id": "xSIcaW9tdrl4" }, "outputs": [ { "data": { "text/plain": [ "{'vocab_size': 28996,\n", " 'hidden_size': 768,\n", " 'num_layers': 12,\n", " 'num_attention_heads': 12,\n", " 'hidden_activation': 'gelu',\n", " 'intermediate_size': 3072,\n", " 'dropout_rate': 0.1,\n", " 'attention_dropout_rate': 0.1,\n", " 'max_position_embeddings': 512,\n", " 'type_vocab_size': 2,\n", " 'initializer_range': 0.02,\n", " 'embedding_size': None,\n", " 'output_range': None,\n", " 'return_all_encoder_outputs': False,\n", " 'return_attention_scores': False,\n", " 'norm_first': False}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder_config = tfm.nlp.encoders.EncoderConfig({\n", " 'type':'bert',\n", " 'bert': config_dict\n", "})\n", "\n", "encoder_config.get().as_dict()" ] }, { "cell_type": "markdown", "metadata": { "id": "yZznAP--TDLe" }, "source": [ "### Construct a Classifier with `encoder_config`\n", "\n", "Here, we construct a new BERT classifier with 2 classes and plot its model architecture. A BERT classifier consists of a BERT encoder built from the selected encoder config, a dropout layer, and an MLP classification head."
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:23.678489Z", "iopub.status.busy": "2023-10-17T12:27:23.677896Z", "iopub.status.idle": "2023-10-17T12:27:27.415328Z", "shell.execute_reply": "2023-10-17T12:27:27.414369Z" }, "id": "Ny962I8nqs4n" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-17 12:27:24.243086: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", "Skipping registering GPU devices...\n" ] }, { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n", "bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)\n", "\n", "tf.keras.utils.plot_model(bert_classifier)" ] }, { "cell_type": "markdown", "metadata": { "id": "IStKfxXkTJMu" }, "source": [ "### Load Pretrained Weights into the BERT Classifier\n", "\n", "The provided pretrained checkpoint only contains weights for the BERT encoder within the BERT classifier; the weights of the classification head are still randomly initialized. `expect_partial()` silences warnings about checkpoint values that the encoder does not use, and `assert_existing_objects_matched()` checks that every existing object in the encoder has a matching entry in the checkpoint."
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:27.419655Z", "iopub.status.busy": "2023-10-17T12:27:27.419007Z", "iopub.status.idle": "2023-10-17T12:27:27.843139Z", "shell.execute_reply": "2023-10-17T12:27:27.842304Z" }, "id": "G9_XCBpOEo4y" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n", "checkpoint.read(\n", " os.path.join(folder_bert, 'bert_model.ckpt')).expect_partial().assert_existing_objects_matched()" ] }, { "cell_type": "markdown", "metadata": { "id": "E6Hu1FFgQWUU" }, "source": [ "## Load ALBERT model pretrained checkpoints" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:27.847077Z", "iopub.status.busy": "2023-10-17T12:27:27.846567Z", "iopub.status.idle": "2023-10-17T12:27:42.948033Z", "shell.execute_reply": "2023-10-17T12:27:42.946986Z" }, "id": "TWUtFeWxQn0V" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2023-10-17 12:27:27-- https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xxlarge.tar.gz\r\n", "Resolving storage.googleapis.com (storage.googleapis.com)... 172.253.114.207, 172.217.214.207, 142.251.6.207, ...\r\n", "Connecting to storage.googleapis.com (storage.googleapis.com)|172.253.114.207|:443... connected.\r\n", "HTTP request sent, awaiting response... " ] }, { "name": "stdout", "output_type": "stream", "text": [ "200 OK\r\n", "Length: 826059238 (788M) [application/octet-stream]\r\n", "Saving to: ‘albert_xxlarge.tar.gz’\r\n", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "albert_xxlarge.tar.
100%[===================>] 787.79M 117MB/s in 6.5s \r\n", "\r\n", "2023-10-17 12:27:34 (122 MB/s) - ‘albert_xxlarge.tar.gz’ saved [826059238/826059238]\r\n", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "albert_xxlarge/\r\n", "albert_xxlarge/bert_model.ckpt.index\r\n", "albert_xxlarge/30k-clean.model\r\n", "albert_xxlarge/30k-clean.vocab\r\n", "albert_xxlarge/bert_model.ckpt.data-00000-of-00001\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "albert_xxlarge/params.yaml\r\n", "albert_xxlarge/albert_config.json\r\n" ] } ], "source": [ "# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n", "albert_model_display_name = 'ALBERT-xxlarge English' # @param ['ALBERT-base English', 'ALBERT-large English', 'ALBERT-xlarge English', 'ALBERT-xxlarge English']\n", "\n", "if albert_model_display_name == 'ALBERT-base English':\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_base.tar.gz\"\n", " !tar -xvf \"albert_base.tar.gz\"\n", "elif albert_model_display_name == 'ALBERT-large English':\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_large.tar.gz\"\n", " !tar -xvf \"albert_large.tar.gz\"\n", "elif albert_model_display_name == \"ALBERT-xlarge English\":\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xlarge.tar.gz\"\n", " !tar -xvf \"albert_xlarge.tar.gz\"\n", "elif albert_model_display_name == \"ALBERT-xxlarge English\":\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xxlarge.tar.gz\"\n", " !tar -xvf \"albert_xxlarge.tar.gz\"" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:42.952880Z", "iopub.status.busy": "2023-10-17T12:27:42.952150Z", "iopub.status.idle": "2023-10-17T12:27:42.958377Z", "shell.execute_reply": "2023-10-17T12:27:42.957763Z" }, "id": "5lZDWD7zUAAO" }, "outputs": [ { "data": { "text/plain": [ 
"'albert_xxlarge'" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Lookup table of the directory name corresponding to each model checkpoint\n", "folder_albert_dict = {\n", " 'ALBERT-base English': 'albert_base',\n", " 'ALBERT-large English': 'albert_large',\n", " 'ALBERT-xlarge English': 'albert_xlarge',\n", " 'ALBERT-xxlarge English': 'albert_xxlarge'\n", "}\n", "\n", "folder_albert = folder_albert_dict.get(albert_model_display_name)\n", "folder_albert" ] }, { "cell_type": "markdown", "metadata": { "id": "ftXwmObdU2fS" }, "source": [ "### Construct ALBERT Model Using the New `params.yaml`\n", "\n", "In addition to constructing the ALBERT encoder here, `params.yaml` can be used for training with the bundled trainer." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:42.962052Z", "iopub.status.busy": "2023-10-17T12:27:42.961488Z", "iopub.status.idle": "2023-10-17T12:27:42.968925Z", "shell.execute_reply": "2023-10-17T12:27:42.968354Z" }, "id": "VXn20q2oU1UJ" }, "outputs": [ { "data": { "text/plain": [ "{'task': {'model': {'encoder': {'albert': {'attention_dropout_rate': 0.0,\n", " 'dropout_rate': 0.0,\n", " 'embedding_width': 128,\n", " 'hidden_activation': 'gelu',\n", " 'hidden_size': 4096,\n", " 'initializer_range': 0.02,\n", " 'intermediate_size': 16384,\n", " 'max_position_embeddings': 512,\n", " 'num_attention_heads': 64,\n", " 'num_layers': 12,\n", " 'type_vocab_size': 2,\n", " 'vocab_size': 30000},\n", " 'type': 'albert'}}}}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config_file = os.path.join(folder_albert, \"params.yaml\")\n", "config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n", "config_dict" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:42.972118Z", "iopub.status.busy": "2023-10-17T12:27:42.971752Z",
"iopub.status.idle": "2023-10-17T12:27:42.984904Z", "shell.execute_reply": "2023-10-17T12:27:42.984323Z" }, "id": "Uo_TSMSvWOX_" }, "outputs": [ { "data": { "text/plain": [ "{'vocab_size': 30000,\n", " 'embedding_width': 128,\n", " 'hidden_size': 4096,\n", " 'num_layers': 12,\n", " 'num_attention_heads': 64,\n", " 'hidden_activation': 'gelu',\n", " 'intermediate_size': 16384,\n", " 'dropout_rate': 0.0,\n", " 'attention_dropout_rate': 0.0,\n", " 'max_position_embeddings': 512,\n", " 'type_vocab_size': 2,\n", " 'initializer_range': 0.02}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Method 1: pass encoder config dict into EncoderConfig\n", "encoder_config = tfm.nlp.encoders.EncoderConfig(config_dict[\"task\"][\"model\"][\"encoder\"])\n", "encoder_config.get().as_dict()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:42.988191Z", "iopub.status.busy": "2023-10-17T12:27:42.987558Z", "iopub.status.idle": "2023-10-17T12:27:43.001705Z", "shell.execute_reply": "2023-10-17T12:27:43.001007Z" }, "id": "u7oJe93uWcy0" }, "outputs": [ { "data": { "text/plain": [ "{'vocab_size': 30000,\n", " 'embedding_width': 128,\n", " 'hidden_size': 4096,\n", " 'num_layers': 12,\n", " 'num_attention_heads': 64,\n", " 'hidden_activation': 'gelu',\n", " 'intermediate_size': 16384,\n", " 'dropout_rate': 0.0,\n", " 'attention_dropout_rate': 0.0,\n", " 'max_position_embeddings': 512,\n", " 'type_vocab_size': 2,\n", " 'initializer_range': 0.02}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Method 2: use override_params_dict function to override default Encoder params\n", "encoder_config = tfm.nlp.encoders.EncoderConfig()\n", "tfm.hyperparams.override_params_dict(encoder_config, config_dict[\"task\"][\"model\"][\"encoder\"], is_strict=True)\n", "encoder_config.get().as_dict()" ] }, { "cell_type": "markdown", "metadata": { "id": 
"abpQFw80Wx6c" }, "source": [ "### Construct ALBERT Model Using the Old `albert_config.json`" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:43.004660Z", "iopub.status.busy": "2023-10-17T12:27:43.004433Z", "iopub.status.idle": "2023-10-17T12:27:43.009490Z", "shell.execute_reply": "2023-10-17T12:27:43.008876Z" }, "id": "Xb99qms6WuPa" }, "outputs": [ { "data": { "text/plain": [ "{'hidden_size': 4096,\n", " 'initializer_range': 0.02,\n", " 'intermediate_size': 16384,\n", " 'max_position_embeddings': 512,\n", " 'num_attention_heads': 64,\n", " 'type_vocab_size': 2,\n", " 'vocab_size': 30000,\n", " 'embedding_width': 128,\n", " 'attention_dropout_rate': 0.0,\n", " 'dropout_rate': 0.0,\n", " 'num_layers': 12,\n", " 'hidden_activation': 'gelu'}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "albert_config_file = os.path.join(folder_albert, \"albert_config.json\")\n", "config_dict = json.loads(tf.io.gfile.GFile(albert_config_file).read())\n", "config_dict" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:43.012497Z", "iopub.status.busy": "2023-10-17T12:27:43.012171Z", "iopub.status.idle": "2023-10-17T12:27:43.024585Z", "shell.execute_reply": "2023-10-17T12:27:43.024014Z" }, "id": "mCW0RJHcEtVV" }, "outputs": [ { "data": { "text/plain": [ "{'vocab_size': 30000,\n", " 'embedding_width': 128,\n", " 'hidden_size': 4096,\n", " 'num_layers': 12,\n", " 'num_attention_heads': 64,\n", " 'hidden_activation': 'gelu',\n", " 'intermediate_size': 16384,\n", " 'dropout_rate': 0.0,\n", " 'attention_dropout_rate': 0.0,\n", " 'max_position_embeddings': 512,\n", " 'type_vocab_size': 2,\n", " 'initializer_range': 0.02}" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder_config = tfm.nlp.encoders.EncoderConfig({\n", " 'type':'albert',\n", " 'albert': 
config_dict\n", "})\n", "\n", "encoder_config.get().as_dict()" ] }, { "cell_type": "markdown", "metadata": { "id": "EIAMaOxdZw5u" }, "source": [ "### Construct a Classifier with `encoder_config`\n", "\n", "Here, we construct a new `BertClassifier` with 2 classes on top of the ALBERT encoder and plot its model architecture. The classifier consists of the ALBERT encoder built from the selected encoder config, a dropout layer, and a classification head." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:43.027646Z", "iopub.status.busy": "2023-10-17T12:27:43.027237Z", "iopub.status.idle": "2023-10-17T12:27:45.257212Z", "shell.execute_reply": "2023-10-17T12:27:45.256122Z" }, "id": "xTkUisEEFEey" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "albert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n", "albert_classifier = tfm.nlp.models.BertClassifier(network=albert_encoder, num_classes=2)\n", "\n", "tf.keras.utils.plot_model(albert_classifier)" ] }, { "cell_type": "markdown", "metadata": { "id": "m6EG_7CaZ2rI" }, "source": [ "### Load Pretrained Weights into the Classifier\n", "\n", "The provided pretrained checkpoint contains weights only for the ALBERT encoder inside the classifier. Weights for the classification head are still randomly initialized."
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:45.261529Z", "iopub.status.busy": "2023-10-17T12:27:45.260851Z", "iopub.status.idle": "2023-10-17T12:27:45.624803Z", "shell.execute_reply": "2023-10-17T12:27:45.624073Z" }, "id": "7dOG3agXZ9Dx" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint = tf.train.Checkpoint(encoder=albert_encoder)\n", "checkpoint.read(\n", " os.path.join(folder_albert, 'bert_model.ckpt')).expect_partial().assert_existing_objects_matched()" ] }, { "cell_type": "markdown", "metadata": { "id": "6xsbeS-EcCqu" }, "source": [ "## Load ELECTRA model pretrained checkpoints" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:45.628328Z", "iopub.status.busy": "2023-10-17T12:27:45.628068Z", "iopub.status.idle": "2023-10-17T12:27:48.685718Z", "shell.execute_reply": "2023-10-17T12:27:48.684634Z" }, "id": "VpwIrAR4cIBF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2023-10-17 12:27:45-- https://storage.googleapis.com/tf_model_garden/nlp/electra/small.tar.gz\r\n", "Resolving storage.googleapis.com (storage.googleapis.com)... 172.253.114.207, 172.217.214.207, 142.251.6.207, ...\r\n", "Connecting to storage.googleapis.com (storage.googleapis.com)|172.253.114.207|:443... connected.\r\n", "HTTP request sent, awaiting response... 
" ] }, { "name": "stdout", "output_type": "stream", "text": [ "200 OK\r\n", "Length: 157951922 (151M) [application/octet-stream]\r\n", "Saving to: ‘small.tar.gz’\r\n", "\r\n", "\r", "small.tar.gz 0%[ ] 0 --.-KB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "small.tar.gz 10%[=> ] 16.01M 79.9MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "small.tar.gz 39%[======> ] 59.93M 150MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "small.tar.gz 63%[===========> ] 96.01M 153MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "small.tar.gz 92%[=================> ] 138.75M 167MB/s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "small.tar.gz 100%[===================>] 150.63M 173MB/s in 0.9s \r\n", "\r\n", "2023-10-17 12:27:46 (173 MB/s) - ‘small.tar.gz’ saved [157951922/157951922]\r\n", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "small/\r\n", "small/ckpt-1000000.data-00000-of-00001\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "small/params.yaml\r\n", "small/checkpoint\r\n", "small/ckpt-1000000.index\r\n" ] } ], "source": [ "# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n", "electra_model_display_name = 'ELECTRA-small English' # @param ['ELECTRA-small English', 'ELECTRA-base English']\n", "\n", "if electra_model_display_name == 'ELECTRA-small English':\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/electra/small.tar.gz\"\n", " !tar -xvf \"small.tar.gz\"\n", "elif electra_model_display_name == 'ELECTRA-base English':\n", " !wget \"https://storage.googleapis.com/tf_model_garden/nlp/electra/base.tar.gz\"\n", " !tar -xvf \"base.tar.gz\"" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:48.689731Z", "iopub.status.busy": "2023-10-17T12:27:48.689450Z", "iopub.status.idle": 
"2023-10-17T12:27:48.695251Z", "shell.execute_reply": "2023-10-17T12:27:48.694629Z" }, "id": "fy4FmsNOhlNa" }, "outputs": [ { "data": { "text/plain": [ "'small'" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Lookup table of the directory name corresponding to each model checkpoint\n", "folder_electra_dict = {\n", " 'ELECTRA-small English': 'small',\n", " 'ELECTRA-base English': 'base'\n", "}\n", "\n", "folder_electra = folder_electra_dict.get(electra_model_display_name)\n", "folder_electra" ] }, { "cell_type": "markdown", "metadata": { "id": "rgAcf-Fl3RTG" }, "source": [ "### Construct BERT Model Using the `params.yaml`\n", "\n", "Besides constructing the BERT-style encoder here, `params.yaml` can also be used for training with the bundled trainer." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:48.698883Z", "iopub.status.busy": "2023-10-17T12:27:48.698324Z", "iopub.status.idle": "2023-10-17T12:27:48.708951Z", "shell.execute_reply": "2023-10-17T12:27:48.708368Z" }, "id": "ZNBg5xzqh0Gr" }, "outputs": [ { "data": { "text/plain": [ "{'model': {'cls_heads': [{'activation': 'tanh',\n", " 'cls_token_idx': 0,\n", " 'dropout_rate': 0.1,\n", " 'inner_dim': 64,\n", " 'name': 'next_sentence',\n", " 'num_classes': 2}],\n", " 'disallow_correct': False,\n", " 'discriminator_encoder': {'type': 'bert',\n", " 'bert': {'attention_dropout_rate': 0.1,\n", " 'dropout_rate': 0.1,\n", " 'embedding_size': 128,\n", " 'hidden_activation': 'gelu',\n", " 'hidden_size': 256,\n", " 'initializer_range': 0.02,\n", " 'intermediate_size': 1024,\n", " 'max_position_embeddings': 512,\n", " 'num_attention_heads': 4,\n", " 'num_layers': 12,\n", " 'type_vocab_size': 2,\n", " 'vocab_size': 30522}},\n", " 'discriminator_loss_weight': 50.0,\n", " 'generator_encoder': {'type': 'bert',\n", " 'bert': {'attention_dropout_rate': 0.1,\n", " 'dropout_rate': 0.1,\n", " 'embedding_size': 128,\n",
'hidden_activation': 'gelu',\n", " 'hidden_size': 64,\n", " 'initializer_range': 0.02,\n", " 'intermediate_size': 256,\n", " 'max_position_embeddings': 512,\n", " 'num_attention_heads': 1,\n", " 'num_layers': 12,\n", " 'type_vocab_size': 2,\n", " 'vocab_size': 30522}},\n", " 'num_classes': 2,\n", " 'num_masked_tokens': 76,\n", " 'sequence_length': 512,\n", " 'tie_embeddings': True}}" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config_file = os.path.join(folder_electra, \"params.yaml\")\n", "config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n", "config_dict" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:48.712338Z", "iopub.status.busy": "2023-10-17T12:27:48.711773Z", "iopub.status.idle": "2023-10-17T12:27:48.724602Z", "shell.execute_reply": "2023-10-17T12:27:48.724016Z" }, "id": "i-yX-KgJyduv" }, "outputs": [ { "data": { "text/plain": [ "{'vocab_size': 30522,\n", " 'hidden_size': 256,\n", " 'num_layers': 12,\n", " 'num_attention_heads': 4,\n", " 'hidden_activation': 'gelu',\n", " 'intermediate_size': 1024,\n", " 'dropout_rate': 0.1,\n", " 'attention_dropout_rate': 0.1,\n", " 'max_position_embeddings': 512,\n", " 'type_vocab_size': 2,\n", " 'initializer_range': 0.02,\n", " 'embedding_size': 128,\n", " 'output_range': None,\n", " 'return_all_encoder_outputs': False,\n", " 'return_attention_scores': False,\n", " 'norm_first': False}" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "disc_encoder_config = tfm.nlp.encoders.EncoderConfig(\n", " config_dict['model']['discriminator_encoder']\n", ")\n", "\n", "disc_encoder_config.get().as_dict()" ] }, { "cell_type": "markdown", "metadata": { "id": "1AdrMkH73VYz" }, "source": [ "### Construct a Classifier with `encoder_config`\n", "\n", "Here, we construct a Classifier with 2 classes and plot its model architecture. 
The classifier consists of the ELECTRA discriminator encoder built from the selected encoder config, a dropout layer, and a classification head.\n", "\n", "**Note**: The generator is discarded, and only the discriminator is used for downstream tasks." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:48.727907Z", "iopub.status.busy": "2023-10-17T12:27:48.727540Z", "iopub.status.idle": "2023-10-17T12:27:51.168411Z", "shell.execute_reply": "2023-10-17T12:27:51.167342Z" }, "id": "98Pt-SxszAvN" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "disc_encoder = tfm.nlp.encoders.build_encoder(disc_encoder_config)\n", "electra_disc_classifier = tfm.nlp.models.BertClassifier(network=disc_encoder, num_classes=2)\n", "tf.keras.utils.plot_model(electra_disc_classifier)" ] }, { "cell_type": "markdown", "metadata": { "id": "aWQ2FKj64X5U" }, "source": [ "### Load Pretrained Weights into the Classifier\n", "\n", "The provided pretrained checkpoint contains weights for the entire ELECTRA model. We load only its discriminator weights (conveniently stored under the name `encoder`) into the classifier. Weights for the classification head are still randomly initialized."
] }, { "cell_type": "code", "execution_count": 28, "metadata": { "execution": { "iopub.execute_input": "2023-10-17T12:27:51.172362Z", "iopub.status.busy": "2023-10-17T12:27:51.172065Z", "iopub.status.idle": "2023-10-17T12:27:51.379102Z", "shell.execute_reply": "2023-10-17T12:27:51.378428Z" }, "id": "99pznFJszQfV" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint = tf.train.Checkpoint(encoder=disc_encoder)\n", "checkpoint.read(\n", " tf.train.latest_checkpoint(folder_electra)\n", ").expect_partial().assert_existing_objects_matched()" ] } ], "metadata": { "colab": { "name": "load_lm_ckpts.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }