GPT-2 From Scratch in MLX
train.py is ~200 lines of Python code that define and train GPT-2 from scratch using MLX and NumPy as the only dependencies. This README walks through writing train.py from scratch. The model is trained on ~1 million characters of Shakespeare contained in input.txt, and it can be trained in around 10 minutes on a MacBook to produce coherent, Shakespeare-like text.
This tutorial follows Andrej Karpathy's GPT from scratch, but implemented in MLX.
Table of Contents
- Preparing the data
- Creating the vocabulary
- Coding GPT-2
- Input Embeddings
- Positional Embeddings
- Self-Attention
- Multi-Head Attention
- MLP
- Block
- Layernorms and Skip Connections
- Forward Pass
- Sampling
- Initialization
- Training Loop
- References
Preparing the data
Install mlx and run the following imports.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as utils
import numpy as np
import math
The first step to training an LLM is collecting a large corpus of text data and then tokenizing it. Tokenization is the process of mapping text to integers, which can be fed into the LLM. Our training corpus for this model will be the works of Shakespeare concatenated into one file. This is roughly 1 million characters and looks like this:
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You are all resolved rather to die than to famish?
All:
Resolved. resolved.
First Citizen:
First, you know Caius Marcius is chief enemy to the people.
...
Creating the vocabulary
First, we read the file as a single long string into the text variable. Then we use the set() function to get all the unique characters in the text, sort them, and use them as our vocabulary. Printing vocab shows all the characters in our vocabulary as one string; there are 65 characters in total, and these will be our tokens.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
vocab = sorted(list(set(text)))
vocab_size = len(vocab)
print(''.join(vocab))
# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
print(vocab_size)
# 65
Production models typically use tokenization algorithms such as byte-pair encoding to build a larger vocabulary of sub-word chunks. Since our focus today is on the architecture, we will stick with character-level tokenization. Next, we map our vocabulary to integers known as token IDs. Then we can encode our text into tokens and decode them back into a string.
# Create mapping from vocab to integers
itos = {i:c for i,c in enumerate(vocab)} # int to string
stoi = {c:i for i,c in enumerate(vocab)} # string to int
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[i] for i in x])
print(encode("hello world"))
# [46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
print(decode(encode("hello world")))
# hello world
We use the enumerate() function to iterate over each character and its index in the vocabulary, creating a dictionary itos that maps integers to characters and a dictionary stoi that maps characters to integers. We then use these mappings to build our encode and decode functions. Now we can encode the entire text and split it into training data (the first 90%) and validation data (the last 10%).
data = encode(text)
split = int(0.9 * len(data))
train_data = data[:split]
val_data = data[split:]
Currently, our training data is just one very long list of tokens. However, we are training our model to predict the next token given some previous tokens, so our dataset should consist of examples where the input is a sequence of tokens and the label is the correct next token. We define a hyperparameter called the context length (ctx_len), which is the maximum number of tokens used to predict the next token. Each training example will be ctx_len tokens long. Let's look at the first ctx_len+1 tokens.
ctx_len = 8
print(train_data[:ctx_len + 1])
# [18, 47, 56, 57, 58, 1, 15, 47, 58]
# x: [18, 47, 56, 57, 58, 1, 15, 47] | y: 58
This is one training example where the input is "18, 47, 56, 57, 58, 1, 15, 47" and the desired output is "58". This is 8 tokens of context. However, we also want to train the model to predict the next token given only 7, 6, 5, …, 1 tokens of context, because during generation the model often has fewer than ctx_len tokens to work with. Therefore we also consider the 8 sub-examples packed into this example:
ctx_len = 8
print(train_data[:ctx_len + 1])
# [18, 47, 56, 57, 58, 1, 15, 47, 58]
# 8 sub examples
# [18] --> 47
# [18, 47] --> 56
# [18, 47, 56] --> 57
# [18, 47, 56, 57] --> 58
# [18, 47, 56, 57, 58] --> 1
# [18, 47, 56, 57, 58, 1] --> 15
# [18, 47, 56, 57, 58, 1, 15] --> 47
# [18, 47, 56, 57, 58, 1, 15, 47] --> 58
Notice that the labels are simply the inputs shifted left.
print("inputs: ", train_data[:ctx_len])
print("labels: ", train_data[1:ctx_len+1]) # labels = inputs indexed 1 higher
# inputs: [18, 47, 56, 57, 58, 1, 15, 47]
# labels: [47, 56, 57, 58, 1, 15, 47, 58]
At index 0 the input is 18 and the label is 47. At index 1 the input is everything before and including index 1 which is [18, 47] and the label is 56, etc. Now that we understand that the labels are simply the input sequence indexed one higher we can build our datasets.
# Creating training and validation datasets
ctx_len = 8
X_train = mx.array([train_data[i:i+ctx_len] for i in range(0, len(train_data) - ctx_len, ctx_len)])
y_train = mx.array([train_data[i+1:i+ctx_len+1] for i in range(0, len(train_data) - ctx_len, ctx_len)])
X_val = mx.array([val_data[i:i+ctx_len] for i in range(0, len(val_data) - ctx_len, ctx_len)])
y_val = mx.array([val_data[i+1:i+ctx_len+1] for i in range(0, len(val_data) - ctx_len, ctx_len)])
We loop through the data, taking non-overlapping chunks of size ctx_len as the inputs (X) and the same chunks shifted one index higher as the labels (y). We then convert these Python lists into mlx array objects, since the model internals will be written with mlx and expect mlx arrays as input.
One more thing: during training we don't want to feed the model one example at a time; we want to feed it multiple examples in parallel for efficiency. This group of examples is called a batch, and the number of examples in it is the batch size. So we define a function to generate batches for training.
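def get_batches(X, y, b_size, shuffle=True):
    if shuffle:
        # Shuffle inputs and labels with the same random permutation
        ix = np.arange(X.shape[0])
        np.random.shuffle(ix)
        ix = mx.array(ix)
        X = X[ix]
        y = y[ix]
    # Yield consecutive chunks of b_size examples (the last one may be smaller)
    for i in range(0, X.shape[0], b_size):
        input = X[i:i+b_size]
        label = y[i:i+b_size]
        yield input, label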
If shuffle=True, we shuffle the data by indexing it with a randomly permuted array of indices. Then we loop through the dataset and yield chunks of b_size examples from the input and label arrays. These chunks are known as mini-batches: stacked examples that we process in parallel. These mini-batches will be the model's input during training.
Here's an example of a minibatch of 4 examples with context length 8. This minibatch packs 32 next-token prediction problems. The model will predict the next token for each token in the input and the labels will be used to calculate the loss. Notice that the labels contain the next token for each index of the inputs.
You'll want to keep this picture in your mind because the shapes of these tensors will get hairy. For now, just remember that we will input a tensor of shape (batch_size, ctx_len) to the model.
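The minibatch described above can be reproduced with a small NumPy sketch. The list `toy_data` is a hypothetical stand-in for `train_data` (token IDs 0 to 99, so the one-position shift is easy to see), the chunking mirrors the X_train/y_train construction, and shuffling is skipped.

```python
import numpy as np

# Hypothetical stand-in for train_data: the token IDs 0..99
toy_data = list(range(100))
ctx_len, batch_size = 8, 4

# Non-overlapping chunks of ctx_len tokens; labels are the same chunks shifted by one
starts = range(0, len(toy_data) - ctx_len, ctx_len)
X = np.array([toy_data[i:i + ctx_len] for i in starts])
y = np.array([toy_data[i + 1:i + ctx_len + 1] for i in starts])

# Take the first minibatch without shuffling
xb, yb = X[:batch_size], y[:batch_size]
print(xb.shape, yb.shape)  # (4, 8) (4, 8)
print(xb.size)             # 32 next-token prediction problems

# Each label is the token that follows the corresponding input position
for t in range(ctx_len):
    print(f"context {xb[0, :t + 1].tolist()} --> {yb[0, t]}")
```

The loop at the end prints the 8 sub-examples packed into the first row, the same pattern as the [18] --> 47 listing earlier.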
Coding GPT-2
Let's look at the GPT-2 architecture to get an overview of what we are trying to implement. Don't worry if this looks confusing. We will implement it step by step from bottom to top. Let's start by implementing the input embeddings.
Input Embeddings
The purpose of the input embedding layer is to map token IDs to vectors. Each token will be mapped to a vector which will be its representation as it is forwarded through the model. The vectors for each token will accumulate and exchange information as they pass through the model and eventually be used to predict the next token. These vectors are called embeddings.
The simplest way to map token IDs to vectors is through a lookup table. We create a matrix of size (vocab_size, n_emb) where each row is the embedding vector for the corresponding token. This matrix is known as the embedding weights.
The diagram shows an example embedding layer of size (65, 6). This means there are 65 tokens in the vocabulary and each one will be represented by a length-6 embedding vector. The input sequence is used to index the embedding weights to get the vector corresponding to each token. Remember the minibatches we input into the model? Originally a minibatch has shape (batch_size, ctx_len). After passing through the embedding layer it has shape (batch_size, ctx_len, n_emb): instead of each token being a single integer, each token is now a vector of length n_emb.
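Because a lookup table is just row indexing, the shape change can be checked with a NumPy sketch. The random matrix `W` is a placeholder for the learned embedding weights (in MLX, `nn.Embedding(vocab_size, n_emb)` stores such a matrix in its `weight` attribute), and the token IDs are random.

```python
import numpy as np

vocab_size, n_emb = 65, 6
batch_size, ctx_len = 4, 8
rng = np.random.default_rng(0)

# Embedding weights: one row of length n_emb per token ID
W = rng.normal(size=(vocab_size, n_emb))

# A minibatch of token IDs, shape (batch_size, ctx_len)
x = rng.integers(0, vocab_size, size=(batch_size, ctx_len))

# Embedding lookup is integer indexing into the rows of W
emb = W[x]
print(emb.shape)  # (4, 8, 6)

# The vector for token x[0, 0] is exactly its row in W
assert np.array_equal(emb[0, 0], W[x[0, 0]])
```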
Let's define the embedding layer in code now.
n_emb = 6 # You can add these hyperparams at the top of your file
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_emb)
We will define a class to organize our implementation. We subclass nn.Module to take advantage of mlx's features. Then, in the __init__ method, we call the superclass constructor and initialize our token embedding layer, called wte.
Positional Embeddings
Next up is the positional embeddings. The purpose of positional embeddings is to encode information about the position of each token in the sequence. This can be added to our input embeddings to get a complete representation of each token that contains information about the token's position in the sequence.
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_emb) # token embeddings
        self.wpe = nn.Embedding(ctx_len, n_emb) # position embeddings
The position embeddings work the same way as the token embeddings, except that instead of a row for each token we have a row for each possible position index. This means our embedding weights have shape (ctx_len, n_emb). Now we implement the __call__ method of our GPT class, which contains the forward pass of the model.
# Tensor shapes commented
def __call__(self, x):
    B, T = x.shape # (B = batch_size, T = ctx_len)
    tok_emb = self.wte(x) # (B, T, n_emb)
    pos_emb = self.wpe(mx.arange(T)) # (T, n_emb)
    x = tok_emb + pos_emb # (B, T, n_emb)
First, we break out the dimensions of our input into variables B and T for easy handling. In sequence modeling contexts B and T are usually used as shorthand for "batch" and "time" dimensions. In this case, the "time" dimension of our sequence is the context length.
Next, we calculate the token and position embeddings. Notice that the input to the position embedding layer is mx.arange(T). This outputs an array of consecutive integers from 0 to T-1, which are exactly the positions we want to embed. After passing it through the embedding layer we have a tensor of shape (T, n_emb), because the embedding layer plucks out the length-n_emb vector for each of the T positions. Note that even though pos_emb does not have the same shape as tok_emb, we can add the two because mlx broadcasts, or replicates, pos_emb across the batch dimension to allow elementwise addition. Finally, we perform the addition to get new token representations that include positional information.
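The broadcast addition can be checked with a NumPy sketch; NumPy and MLX follow the same broadcasting rules for this case. Random arrays stand in for the learned token and position embeddings.

```python
import numpy as np

B, T, n_emb = 4, 8, 6
rng = np.random.default_rng(0)

tok_emb = rng.normal(size=(B, T, n_emb))  # stand-in for self.wte(x)
wpe = rng.normal(size=(T, n_emb))         # stand-in position embedding weights
pos_emb = wpe[np.arange(T)]               # (T, n_emb): one row per position 0..T-1

# Broadcasting treats pos_emb as (1, T, n_emb) and repeats it across the batch
x = tok_emb + pos_emb
print(x.shape)  # (4, 8, 6)

# Same result as adding the position embeddings to each batch element explicitly
explicit = np.stack([tok_emb[b] + pos_emb for b in range(B)])
assert np.allclose(x, explicit)
```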
Self-Attention
So far the representation vectors for each token have been calculated independently. They have not had the opportunity to exchange any information. This is intuitively bad in language modeling because the meaning and usage of words depend on the surrounding context. Self-attention is how we incorporate information from previous tokens into a given token. First, let's consider a naive approach. What if we simply represented each token as the average of its representation vector and the vectors of all the tokens before it? This achieves our goal of packing information from previous tokens into the representation for a given token. Here's what it would look like.
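A minimal NumPy sketch of this naive averaging for a single sequence, with toy sizes T = 4 and C = 2 chosen only for illustration (the batch dimension is dropped for clarity):

```python
import numpy as np

T, C = 4, 2  # sequence length and embedding size (toy values)
rng = np.random.default_rng(0)
x = rng.normal(size=(T, C))  # one sequence of token vectors

# For each position t, average the vectors of tokens 0..t (inclusive)
x_avg = np.zeros_like(x)
for t in range(T):
    x_avg[t] = x[:t + 1].mean(axis=0)

# The first token only sees itself; the last sees the whole sequence
assert np.allclose(x_avg[0], x[0])
assert np.allclose(x_avg[-1], x.mean(axis=0))
```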
But self-attention doesn't involve writing a for-loop. The key insight is we can achieve this previous token averaging with matrix multiplication!
By multiplying our input sequence on the left by a special matrix, we get the desired result. This matrix is known as the attention weights. Notice that each row of the attention weight matrix specifies "how much" of each other token goes into the representation of a given token. For example, row two is [0.5, 0.5, 0, 0], so row two of the result is 0.5*token1 + 0.5*token2 + 0*token3 + 0*token4, or the average of token1 and token2. Note that the attention weights form a lower-triangular matrix (zeros in the upper-right entries). This prevents future tokens from contributing to the representation of a given token, so tokens can only communicate with previous tokens, which matches generation, where the model only has access to previous tokens.
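The same averaging can be sketched in NumPy as a single matrix multiplication. Normalizing a lower-triangular matrix of ones row by row is just one convenient way to build these weights; the final check compares the result against the explicit loop.

```python
import numpy as np

T, C = 4, 2
rng = np.random.default_rng(0)
x = rng.normal(size=(T, C))

# Lower-triangular averaging weights: row t holds 1/(t+1) in columns 0..t
wei = np.tril(np.ones((T, T)))
wei = wei / wei.sum(axis=1, keepdims=True)
# Row two (index 1) is [0.5, 0.5, 0, 0], matching the example in the text

# One matmul replaces the loop: (T, T) @ (T, C) -> (T, C)
x_avg = wei @ x

# Check against the explicit for-loop average
loop = np.stack([x[:t + 1].mean(axis=0) for t in range(T)])
assert np.allclose(x_avg, loop)
```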
Let's look at how we can construct the attention weight matrix.
Notice that if we create an array of zeros with -inf in the upper right entries and then perform row-wise softmax we get the desired attention weights. A good exercise is to step through the softmax calculation for a row to see how this works. The takeaway is that we can take some array of size (ctx_len, ctx_len) and softmax each row to get attention weights that sum to one.
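Here is that construction as a NumPy sketch. The `softmax` helper is written by hand so the example needs only NumPy; in MLX, `mx.softmax(x, axis=-1)` could be used instead.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability; exp(-inf) evaluates to 0
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

T = 4
mask = np.tril(np.ones((T, T), dtype=bool))  # True on and below the diagonal

# Zeros everywhere, -inf in the upper-right (future) entries
scores = np.where(mask, 0.0, -np.inf)
wei = softmax(scores, axis=-1)

# Each row spreads its weight evenly over the allowed positions
assert np.allclose(wei[1], [0.5, 0.5, 0.0, 0.0])
assert np.allclose(wei.sum(axis=-1), 1.0)
```

Because every row keeps its diagonal entry at 0, each row has at least one finite value, so the softmax never divides by zero.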
Now we can leave the realm of naive self-attention. Instead of simply averaging previous tokens, we use arbitrary weighted sums over previous tokens. Notice what happens when we do row-wise softmax of an arbitrary matrix.
We still get weights that sum to one on each row. During training, we can learn the numbers in the matrix on the left which will specify how much each token goes into the representation for another token. This is how tokens pay "attention" to each other. But we still haven't understood where this matrix on the left came from. These pre-softmax attention weights are calculated from the tokens themselves, but indirectly through three linear projections.
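A NumPy sketch of the same idea with arbitrary scores. Random numbers stand in for the values that training would produce; the masking and row-wise softmax are applied exactly as before.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

T = 4
rng = np.random.default_rng(0)
raw = rng.normal(size=(T, T))  # arbitrary pre-softmax scores

mask = np.tril(np.ones((T, T), dtype=bool))
wei = softmax(np.where(mask, raw, -np.inf), axis=-1)

# Rows still sum to one and future positions still get zero weight,
# but the weights are no longer a uniform average
assert np.allclose(wei.sum(axis=-1), 1.0)
assert np.all(wei[~mask] == 0.0)
print(np.round(wei, 3))
```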
Keys, Queries, and Values
Each token in our sequence emits three new vectors, called its key, query, and value, each produced by one of the three linear projections. We use the dot product of the query vector of one token and the key vector of another token to quantify the "affinity" between those two tokens. We want the pairwise affinities of each token with every other token, so we multiply the query matrix (4x3) by the transposed key matrix (3x4) to get the raw attention weights (4x4). Due to the way matrix multiplication works, the (i, j) entry of the raw attention weights is the dot product of the query of token i with the key of token j, i.e. the "affinity" between the two. Thus we have computed interactions between every pair of tokens. However, we don't want a token to receive information from future tokens, so we set the upper-right entries to -inf, which become zero after softmax. Then we perform row-wise softmax to get the final attention weights. Instead of multiplying these weights directly with the input, we multiply them with the values. The result is the new token representations.
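Putting the pieces together, here is a single-head sketch in NumPy using the 4-token, size-3 shapes from the description above. The matrices `W_q`, `W_k` and `W_v` are hypothetical random stand-ins for learned, bias-free linear projections, and the batch dimension is omitted; this illustrates the computation rather than the Attention class the tutorial builds next.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

T, n_emb, head_size = 4, 6, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(T, n_emb))  # one sequence of token representations

# Three linear projections (hypothetical random weights, no bias)
W_q = rng.normal(size=(n_emb, head_size))
W_k = rng.normal(size=(n_emb, head_size))
W_v = rng.normal(size=(n_emb, head_size))
q, k, v = x @ W_q, x @ W_k, x @ W_v  # each (4, 3)

# Pairwise affinities: entry (i, j) is q[i] . k[j]
raw = q @ k.T  # (4, 3) @ (3, 4) -> (4, 4)

# Mask out future tokens, then softmax each row
mask = np.tril(np.ones((T, T), dtype=bool))
wei = softmax(np.where(mask, raw, -np.inf), axis=-1)

# Weighted sum of values gives the new token representations
out = wei @ v  # (4, 4) @ (4, 3) -> (4, 3)
print(out.shape)  # (4, 3)
```

Since the first row of `wei` is [1, 0, 0, 0], the first token's output is simply its own value vector.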
Now that we understand attention conceptually, let's implement it.
class Attention(nn.Module):
def