Decoder-Only Transformer LM
Description
A decoder-only Transformer LM refers to a language model architecture that only employs the decoder component of a Transformer model. In the original Transformer architecture, used in tasks like machine translation, the model consists of an encoder and a decoder. The encoder processes the input sequence, creating a contextualized representation, while the decoder generates the output sequence based on that representation.
In a decoder-only Transformer LM, the encoder is omitted and only the decoder stack is used. Because there is no encoder output to attend to, each decoder block keeps only its self-attention and feed-forward sublayers. This setup is typical of autoregressive language models, whose task is to predict the next token in a sequence from the tokens before it. The self-attention is causally masked: each position may attend only to itself and earlier positions, so the model cannot look ahead at the tokens it is trained to predict.
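Decoder-only Transformer LMs are commonly used for language generation, text completion, and dialogue generation, where the model must produce coherent, contextually appropriate sequences of text. These models are trained on large text corpora by maximum likelihood estimation, i.e., by minimizing the cross-entropy between each predicted next-token distribution and the actual next token. Training typically uses teacher forcing: the ground-truth preceding tokens, rather than the model's own earlier predictions, are fed in as context. Combined with the causal mask, this lets every position of a sequence be trained in parallel.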
Implementation
The following example uses the PeerMR fork of TensorFlow.js. This fork implements the All-Reduce pattern so that you can train models on large amounts of data across multiple web browsers. It is also a good example of how to port Python machine-learning code to TensorFlow.js. The original Python version, written with Apple's MLX framework, is available at https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm.
// adapted from: https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm
// Use PMRFS for storage and 8 workers. JavaScript has no keyword arguments: writing
// `workerCount = 8` here would assign an undeclared (implicit global) variable, which is a
// ReferenceError in strict mode, so the worker count is passed positionally.
const execution = new JobExecution('pmrfs', 8);
// require the PeerMR fork of Tensorflow.js
execution.scripts = ['https://storage.googleapis.com/peermr.com/js/tf.es2017-5.js'];
// require a WebGPU browser
execution.gpu = true;
// Not necessary for this example, but here we specify some WebGPU requirements
// to demonstrate the capability.
// Workers that do not meet these requirements will not be selected to run this job.
execution.gpuRequirements = {
  info: {
    vendor: 'apple',
  },
  features: {
    'texture-compression-astc': true,
  },
  limits: {
    maxBufferSize: 4294967296, // 4 GiB
  },
};
// coordinator browser will fetch the input and distribute in pieces to all workers
execution.setInputFn(async function inputs() {
  /**
   * Download the Penn Treebank training split, build its vocabulary and encode it as ids.
   * @returns {Promise<[Object<string, number>, number[]]>} [vocab, trainData]: a token -> id
   *   map and the encoded corpus, with the <eos> id appended after every line.
   * @throws {Error} if the download fails or the vocabulary is not exactly 10,000 tokens.
   */
  async function getDatasets() {
    const trainDataUrl = 'https://storage.googleapis.com/peermr/ptb/ptb.train.txt';
    context.log(`fetching train data ${trainDataUrl}`);
    const trainDataResponse = await fetch(trainDataUrl);
    // fetch() only rejects on network failures; do not tokenize an HTTP error page.
    if (!trainDataResponse.ok) {
      throw new Error(`failed to fetch ${trainDataUrl}: HTTP ${trainDataResponse.status}`);
    }
    const trainDataText = await trainDataResponse.text();
    const eos = '<eos>';

    // Split on a real newline. The escape '\\n' is a backslash followed by "n", which never
    // occurs in the corpus, so all sentences ran together and no <eos> separated them.
    const lines = trainDataText.split('\n');
    if (lines[lines.length - 1] === '') {
      lines.pop(); // the file's trailing newline does not start an extra, empty sentence
    }
    const sentences = lines.map((line) =>
      line.trim().split(' ').map((token) => token.trim()).filter((token) => token.length > 0),
    );

    // Assign ids in first-seen order; <eos> is registered first, so its id is 0.
    context.log(`parsing vocab`);
    const ids = new Map([[eos, 0]]);
    for (const sentence of sentences) {
      for (const token of sentence) {
        if (!ids.has(token)) {
          ids.set(token, ids.size);
        }
      }
    }
    const tokenCount = ids.size;
    context.log(`${tokenCount} vocab tokens`);
    if (tokenCount !== 10_000) {
      throw new Error(`invalid token count ${tokenCount}`);
    }
    context.log(`group tokens by index`);
    // Handed to the map stage as a plain object (it reads the vocab size via Object.keys).
    const vocab = Object.fromEntries(ids);
    context.log(`${Object.keys(vocab).length} vocab size`);

    const eosId = ids.get(eos);
    const trainData = sentences.flatMap((sentence) => [...sentence.map((token) => ids.get(token)), eosId]);
    context.log(`${trainData.length} train data size`);
    return [vocab, trainData];
  }

  const [vocab, trainData] = await getDatasets();
  const workerCount = context.getWorkerCount();
  // Shard i is [floor(i * n / w), floor((i + 1) * n / w)): integer boundaries that tile the
  // corpus exactly. A fractional n / w stride relied on slice() truncating its arguments and
  // could round the final boundary just below n, silently dropping the last token.
  const tokenTotal = trainData.length;
  const shardStart = (i) => Math.floor((i * tokenTotal) / workerCount);
  return Array.from({length: workerCount}, (_, i) => ({
    'vocab': vocab,
    'data': trainData.slice(shardStart(i), shardStart(i + 1)),
  }));
});
const stage = new MapStage(async function map(mapInputs) {
  if (mapInputs.length !== 1) {
    throw new Error('invalid mapper input length: ' + mapInputs.length);
  }
  const contextSize = 1024; // context size in tokens of the model
  const numBlocks = 12; // number of transformer blocks
  const numHeads = 16; // number of heads used for multi-head attention
  const batchSize = 2; // minibatch size
  const iterations = 1000; // full passes over this worker's shard of training data
  const learningRate = 1e-3; // SGD learning rate
  const modelDims = 1024; // dimensionality of embeddings and hidden layers

  /**
   * Build an additive causal attention mask of shape [N, N].
   * Entry (i, j) is -1e9 when j > i (a future position) and 0 otherwise, so adding it to the
   * attention scores before the softmax drives the weight on future tokens to ~0.
   * @param {number} N - sequence length
   * @returns {tf.Tensor2D} float32 mask
   */
  function createAdditiveCausalMask(N) {
    return tf.tidy(() => {
      const indices = tf.range(0, N);
      const mask = tf.less(tf.reshape(indices, [-1, 1]), tf.reshape(indices, [1, -1]));
      return mask.cast('float32').mul(tf.scalar(-1e9));
    });
  }

  /**
   * Base class for layers that own other layers as fields.
   * tf.layers.Layer only reports weights created through its own addWeight(). Without these
   * overrides the weights of nested layers are invisible to the enclosing model: the
   * optimizer never updates them, and countParams()/summary() leave them out.
   * Subclasses list their children in nestedLayers().
   */
  class CompositeLayer extends tf.layers.Layer {
    nestedLayers() {
      return [];
    }

    get trainableWeights() {
      if (!this.trainable) {
        return [];
      }
      return super.trainableWeights.concat(...this.nestedLayers().map((layer) => layer.trainableWeights));
    }

    set trainableWeights(weights) {
      super.trainableWeights = weights;
    }

    get nonTrainableWeights() {
      // a frozen composite reports every nested weight as non-trainable
      const kind = this.trainable ? 'nonTrainableWeights' : 'weights';
      return super.nonTrainableWeights.concat(...this.nestedLayers().map((layer) => layer[kind]));
    }

    set nonTrainableWeights(weights) {
      super.nonTrainableWeights = weights;
    }
  }

  /** Multi-head causal self-attention with bias-free query/key/value/output projections. */
  class SelfAttention extends CompositeLayer {
    /**
     * @param {string} name - layer name
     * @param {number} numHeads - number of attention heads; must divide modelDims
     * @param {number} modelDims - width of the input and output features
     * @param {number} contextSize - maximum sequence length (size of the causal mask)
     * @throws {Error} if modelDims is not divisible by numHeads
     */
    constructor(name, numHeads, modelDims, contextSize) {
      // Per-example input is [contextSize, modelDims]. The shape [modelDims, modelDims] only
      // coincides with it while contextSize === modelDims.
      super({name: name, inputShape: [contextSize, modelDims]});
      if (modelDims % numHeads !== 0) {
        throw new Error(`modelDims (${modelDims}) must be divisible by numHeads (${numHeads})`);
      }
      this.Wq = tf.layers.dense({units: modelDims, useBias: false});
      this.Wk = tf.layers.dense({units: modelDims, useBias: false});
      this.Wv = tf.layers.dense({units: modelDims, useBias: false});
      this.Wo = tf.layers.dense({units: modelDims, useBias: false});
      // allocated once per layer and reused by every call
      this.causalMask = createAdditiveCausalMask(contextSize);
      this.numHeads = numHeads;
      this.headDim = modelDims / numHeads;
      this.scale = tf.scalar(1.0 / Math.sqrt(this.headDim));
    }

    /** The four projections, skipping any not assigned yet (e.g. while super() runs). */
    nestedLayers() {
      return [this.Wq, this.Wk, this.Wv, this.Wo].filter((layer) => layer !== undefined);
    }

    build(inputShape) {
      // every projection maps modelDims -> modelDims, so all four share the input shape
      for (const layer of this.nestedLayers()) {
        layer.build(inputShape);
      }
    }

    call(inputs, kwargs) {
      return tf.tidy(() => {
        // Layer.call may receive either a tensor or a one-element array of tensors
        const x = Array.isArray(inputs) ? inputs[0] : inputs;
        const [B, L] = x.shape;
        // [B, L, modelDims] -> [B, numHeads, L, headDim]
        const splitHeads = (t) =>
          tf.transpose(tf.reshape(t, [B, L, this.numHeads, this.headDim]), [0, 2, 1, 3]);
        const queries = splitHeads(this.Wq.apply(x));
        const keys = splitHeads(this.Wk.apply(x));
        const values = splitHeads(this.Wv.apply(x));
        // scaled dot-product scores: [B, numHeads, L, L]
        const scores = tf.mul(tf.matMul(queries, keys, false, true), this.scale);
        // slice the mask so sequences shorter than contextSize still broadcast against scores
        const mask = L === this.causalMask.shape[0]
          ? this.causalMask
          : tf.slice(this.causalMask, [0, 0], [L, L]);
        const attention = tf.softmax(tf.add(scores, mask), -1);
        // [B, numHeads, L, headDim] -> [B, L, modelDims]
        const merged = tf.reshape(tf.transpose(tf.matMul(attention, values), [0, 2, 1, 3]), [B, L, -1]);
        return this.Wo.apply(merged);
      });
    }
  }
  SelfAttention.className = 'SelfAttention';
  tf.serialization.registerClass(SelfAttention);

  /**
   * Pre-norm Transformer block: h = x + attn(ln(x)); out = h + mlp(ln(h)).
   * It is named after the MLX reference's encoder layer. With the causal mask inside
   * SelfAttention it acts as a decoder-only block.
   */
  class EncoderLayer extends CompositeLayer {
    constructor(name, numHeads, modelDims, contextSize) {
      const inputShape = [contextSize, modelDims];
      super({name: name, inputShape});
      // Wrapping the sublayers in tf.sequential() builds them as they are added (the first
      // layer declares inputShape), so this layer needs no build() of its own.
      this.selfAttn = tf.sequential();
      this.selfAttn.add(tf.layers.layerNormalization({epsilon: 1e-5, inputShape}));
      this.selfAttn.add(new SelfAttention(`${name}-self-attention`, numHeads, modelDims, contextSize));
      this.mlp = tf.sequential();
      this.mlp.add(tf.layers.layerNormalization({epsilon: 1e-5, inputShape}));
      this.mlp.add(tf.layers.dense({units: 4 * modelDims, activation: 'relu'}));
      this.mlp.add(tf.layers.dense({units: modelDims}));
    }

    /** The attention and MLP sub-models, skipping any not assigned yet. */
    nestedLayers() {
      return [this.selfAttn, this.mlp].filter((layer) => layer !== undefined);
    }

    call(inputs, kwargs) {
      return tf.tidy(() => {
        const x = Array.isArray(inputs) ? inputs[0] : inputs;
        const h = tf.add(x, this.selfAttn.apply(x));
        return tf.add(h, this.mlp.apply(h));
      });
    }
  }
  EncoderLayer.className = 'EncoderLayer';
  tf.serialization.registerClass(EncoderLayer);

  /**
   * Token embedding -> numLayers Transformer blocks -> softmax over the vocabulary.
   * Input: [batch, contextSize] token ids. Output: [batch, contextSize, vocabSize]
   * next-token probabilities.
   * NOTE(review): the MLX reference this was ported from also adds a sinusoidal positional
   * encoding after the embedding and a final LayerNorm after the blocks. Neither is present
   * here; confirm whether that is intended.
   */
  class TransformerLM extends tf.LayersModel {
    constructor(vocabSize, numLayers, numHeads, modelDims, contextSize) {
      // inputLength and the input shape are sequence lengths, i.e. contextSize
      const embedding = tf.layers.embedding({
        name: 'embedding',
        inputDim: vocabSize,
        outputDim: modelDims,
        inputLength: contextSize,
      });
      const trnsfrmr = tf.sequential();
      for (let i = 0; i < numLayers; i++) {
        trnsfrmr.add(new EncoderLayer(
          `transformer-encoder-layer-${i}`, numHeads, modelDims, contextSize));
      }
      // tf.js 'sparseCategoricalCrossentropy' treats predictions as probabilities: it clips
      // them to [epsilon, 1 - epsilon] before taking the log. The head must therefore end in
      // a softmax; raw logits would be clipped and mostly receive zero gradient.
      const projection = tf.layers.dense({
        name: 'projection',
        units: vocabSize,
        activation: 'softmax',
      });
      const input = tf.input({shape: [contextSize]});
      const output = projection.apply(trnsfrmr.apply(embedding.apply(input)));
      super({
        inputs: input,
        outputs: output,
        name: 'transformer_lm',
      });
    }
  }
  TransformerLM.className = 'TransformerLM';
  tf.serialization.registerClass(TransformerLM);

  /**
   * Build one minibatch from consecutive sliding windows of contextSize + 1 tokens.
   * Window k starts at offset + k. Inputs are its first contextSize tokens; targets are its
   * last contextSize tokens (the same sequence shifted by one).
   * Windows that would run past the end of `dataset` are dropped, so the final batch may be
   * smaller than batchSize. A short window would make the batch ragged and tensor
   * construction would throw.
   * @returns {[tf.Tensor2D, tf.Tensor2D]} [inputs, targets], each [rows, contextSize]; the
   *   caller disposes them
   * @throws {RangeError} if no complete window starts at `offset`
   */
  function toSamples(contextSize, dataset, batchSize, offset) {
    const windowSize = contextSize + 1; // include target
    const lastStart = dataset.length - windowSize; // last offset with a complete window
    const windows = [];
    for (let i = offset; i < offset + batchSize && i <= lastStart; i++) {
      windows.push(dataset.slice(i, i + windowSize));
    }
    if (windows.length === 0) {
      throw new RangeError(`no complete ${windowSize}-token window at offset ${offset}`);
    }
    return tf.tidy(() => {
      const X = tf.tensor2d(windows, [windows.length, windowSize]);
      const inputs = X.slice([0, 0], [windows.length, contextSize]); // all but last column
      const targets = X.slice([0, 1], [windows.length, contextSize]); // all but first column
      return [inputs, targets];
    });
  }

  const [, mapData] = mapInputs[0];
  const trainData = mapData['data'];
  const vocab = mapData['vocab'];
  const vocabSize = Object.keys(vocab).length;
  context.log(`vocab size: ${vocabSize}`);
  const transformer = new TransformerLM(
    vocabSize, numBlocks, numHeads, modelDims, contextSize,
  );
  transformer.compile({
    optimizer: tf.train.sgd(learningRate),
    loss: 'sparseCategoricalCrossentropy',
  });
  transformer.build([batchSize, contextSize]);
  transformer.summary();
  const paramCount = transformer.countParams();
  context.log(`training a transformer with ${paramCount} parameters`);
  // number of complete (contextSize + 1)-token windows in this shard; <= 0 means no training
  const sampleCount = trainData.length - (contextSize + 1) + 1;
  // NOTE(review): if trainOnBatch synchronizes with the other workers (All-Reduce), workers
  // whose shards yield different step counts may stall; confirm.
  for (let epoch = 0; epoch < iterations; epoch++) {
    for (let offset = 0; offset < sampleCount; offset += batchSize) {
      const [inputs, targets] = toSamples(contextSize, trainData, batchSize, offset);
      let loss;
      try {
        // the PeerMR fork's trainOnBatch takes the job context as a third argument
        loss = await transformer.trainOnBatch(inputs, targets, context);
      } finally {
        inputs.dispose();
        targets.dispose();
      }
      if (offset % 10 === 0) {
        context.log(`iteration ${epoch}, offset ${offset}, loss: ${loss}`);
        context.log(tf.memory());
      }
    }
  }
  await context.onComplete();
});
const stageTimeout = 60 * 60 * 8; // 28800: eight hours, presumably in seconds (confirm against the PeerMR API)
stage.timeout = stageTimeout;
// register the map stage as this job's only stage, then launch the job
execution.addStage(stage);
// `jobRunner` is not defined in this script; presumably the PeerMR page provides it
execution.start(jobRunner);