This is a PyTorch Tutorial to Transformers.
While we will apply the transformer to a specific task – machine translation – in this tutorial, this is still a tutorial on transformers and how they work. You've come to the right place, regardless of your intended task, application, or domain – natural language processing (NLP) or computer vision (CV).
This is the sixth in a series of tutorials I'm writing about implementing cool models on your own with the amazing PyTorch library.
Basic knowledge of PyTorch is assumed.
If you're new to PyTorch, first read Deep Learning with PyTorch: A 60 Minute Blitz and Learning PyTorch with Examples.
Questions, suggestions, or corrections can be posted as issues.
I'm using PyTorch 1.4 in Python 3.6.
Objective
Broadly...
To build a transformer model.
More precisely...
To build a model that can generate an output sequence given an input sequence.
We choose an application that represents one of the most complex uses of a transformer – a sequence-to-sequence problem.
But once you understand the transformer, you can just as easily apply it to any task, application, or even domain (NLP or CV) of your choosing. After all, an image is also a sequence, but over two-dimensional space instead of time.
Even more precisely...
To build a model that can translate from one language to another.
Um ein Modell zu erstellen, das von einer Sprache in eine andere übersetzen kann.
We will be implementing the pioneering research paper 'Attention Is All You Need', which introduced the Transformer network to the world. A watershed moment for cutting-edge Natural Language Processing.
Wir werden das wegweisende Forschungspapier "Attention Is All You Need" umsetzen, das das Transformer-Netzwerk in die Welt eingeführt hat. Ein Wendepunkt für die hochmoderne Natural Language Processing.
Specifically, we are going to be translating from English to German. And yes, everything written here in German is straight from the horse's mouth! (The horse, of course, being the model.)
Konkret werden wir vom Englischen ins Deutsche übersetzen. Und ja, alles, was hier in deutscher Sprache geschrieben wird, ist direkt aus dem Mund des Pferdes! (Das Pferd ist natürlich das Modell.)
Concepts
- Machine Translation. duh.
- Transformer Network. We have all but retired recurrent neural networks (RNNs) in favour of transformers, a new type of sequence model that possesses an unparalleled ability for representation and abstraction – all while being simpler, more efficient, and significantly more parallelizable. Today, the application of transformers is near universal, as their resounding success in NLP has also led to increasing adoption in computer vision tasks.
- Multi-Head Scaled Dot-Product Attention. At the heart of the transformer is the attention mechanism, specifically this flavour of attention. It allows the transformer to interpret and encode a sequence in a multitude of contexts and with an unprecedented level of nuance. (A minimal code sketch of this mechanism follows just after this list.)
- Encoder-Decoder Architecture. Similar to RNNs, transformer models for sequence transduction typically consist of an encoder that encodes an input sequence, and a decoder that decodes it, token by token, into the output sequence.
- Positional Embeddings. Unlike RNNs, transformers do not innately account for the sequential nature of text – they instead view such a sequence as a bag of tokens, or pieces of text, that can be freely mixed and matched with tokens from the same or different bag. The coordinates of tokens in a sequence are therefore manually injected into the transformer as one-dimensional vectors or embeddings, allowing the transformer to incorporate their relative positions into its calculations.
- Byte Pair Encoding. Language models are both enabled and constrained by their vocabularies. Machine translation, especially, is an open-vocabulary problem. Byte Pair Encoding is a way to construct a vocabulary of moderate size that can still represent nearly any word, whether it is common, rare, or entirely unseen.
- Beam Search. As an alternative to simply choosing the highest-scoring token at each step of the generative process, we consider multiple candidates, reserving judgement until we see what they give rise to in subsequent steps – before finally picking the best overall output sequence.
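To make the attention concept above a little more tangible, here is a minimal sketch of scaled dot-product attention in PyTorch. This is only an illustration of the idea, not the implementation we will build later in this tutorial – there are no multiple heads, no learned projections, and the shapes are arbitrary.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(queries, keys, values, mask=None):
    # queries: (batch, n_queries, d_k); keys: (batch, n_keys, d_k); values: (batch, n_keys, d_v)
    d_k = queries.size(-1)
    scores = torch.matmul(queries, keys.transpose(-2, -1)) / d_k ** 0.5  # (batch, n_queries, n_keys)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))            # hide disallowed positions
    weights = F.softmax(scores, dim=-1)                                  # weights over the keys sum to 1
    return torch.matmul(weights, values)                                 # a weighted average of the values

# Toy usage: a batch of 2 sequences, 5 tokens each, 64 dimensions per token.
q = k = v = torch.randn(2, 5, 64)
print(scaled_dot_product_attention(q, k, v).shape)  # torch.Size([2, 5, 64])
```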
Overview
In this section, I will present an overview of the transformer. If you're already familiar with it, you can skip straight to the Implementation section or the commented code.
Transformers have completely changed the deep learning landscape. They've replaced recurrent neural networks (RNNs) as the workhorse of modern NLP. They have caused a seismic shift in our sequence modeling capabilities, not only with their amazing representational ability, but also with their capacity for transfer learning after self-supervised pre-training on large amounts of data. You have no doubt heard of these models in one form or another – BERT, GPT, etc.
Today, they are also increasingly being used in computer vision applications as an alternative to, or in combination with, convolutional neural networks (CNNs).
As in the original transformer paper, the context presented here is an NLP task – specifically the sequence transduction problem of machine translation. If you want to apply transformers to images, this tutorial is still a good place to learn about how they work.
Better than RNNs, but how?
Recurrent neural networks (RNNs) have long been a staple among NLP practitioners. They operate upon a sequence – you guessed it – sequentially. This isn't weird at all; it's even intuitive because that's how we do it too – we read text from one side to another.
On the other hand, our ability to train deep neural networks is somewhat predicated on our ability to perform efficient computation. The sequential nature of RNNs precludes parallelization – you cannot move to the next token in the sequence without processing the previous one. Training such networks on large amounts of data can therefore be prohibitively time-consuming.
Transformers can process elements in a sequence in parallel.
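To see what this buys us in code, here is a small, purely illustrative comparison – an explicit loop over timesteps with a recurrent cell versus a single call to an off-the-shelf transformer encoder layer that handles every position at once. The dimensions are arbitrary, and this is not the model we will build later.

```python
import torch
import torch.nn as nn

seq_len, batch, d_model = 100, 32, 512
x = torch.randn(seq_len, batch, d_model)  # (sequence, batch, features)

# An RNN must walk the sequence one step at a time (shown here with an explicit loop).
rnn_cell = nn.GRUCell(d_model, d_model)
h = torch.zeros(batch, d_model)
for t in range(seq_len):        # step t cannot begin until step t-1 has finished
    h = rnn_cell(x[t], h)

# A transformer layer attends over every position at once; one call handles the whole sequence.
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8)
out = encoder_layer(x)          # (seq_len, batch, d_model), all positions computed in parallel
```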
The sequential processing of text in an RNN introduces another problem. The output of the RNN at a given position is conditioned directly on the output at the previous position, which in turn is conditioned on its previous position, and so on. However, logical dependencies in text can occur across longer distances. It is often the case that you need access to information from a dozen positions ago.
No doubt some of this information can persist across moderate distances, but a lot of it could have decayed in the daisy-chain of computation that defines the RNN. It's easy to see why – the output at each position must encode not only information about that position itself but also anything that may (or may not) be useful ten steps down the line, and the outputs of every intervening step must likewise encode their own information while passing along whatever else might still be relevant.
There have been various modifications to the original RNN cell over the years to alleviate this problem, the most notable of which is probably the Long Short-Term Memory (LSTM) cell. It introduces an additional pathway known as the "cell state" for the sequential flow of information across cells, thereby reducing the burden on the cell outputs to encode all of this information. While this allows for modeling longer dependencies, the fundamental problem still exists – an RNN can access other positions only through intervening positions and not directly.
Transformers allow direct access to other positions.
This means that each position can use information directly from other positions in a sequence, producing a highly context-aware and nuanced output. This, along with other design choices we will see later, makes way for transformers' unprecedented representational ability.
The "direct access" we are speaking of occurs through an attention mechanism, something not completely unlike attention mechanisms you may have encountered earlier – for example, between an RNN encoder and decoder. If the concept is completely unfamiliar to you, it doesn't matter as we will go over it in quite some detail very soon.
I would like to point out here that RNNs need not be unidirectional. Since important textual context can occur both before and after a certain position in the sequence, we often use bidirectional RNNs, where we operate two different RNN cells in opposite directions and combine their outputs. On the other hand, a transformer layer is inherently bidirectional – unless we choose to constrain access to a particular direction, as we will see later.
Also note that in RNNs, information about the positions of individual elements in the sequence is inherently encoded by the fact that we process them in a specific order. Since a transformer processes all elements at once, we need another way of introducing this positional information. We will explore this later on.
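As a preview, here is one common way to do it – the fixed sinusoidal encoding from the "Attention Is All You Need" paper, where each position gets a unique pattern of sines and cosines that is simply added to the token embeddings. Treat this as a sketch of the idea only; whether we end up using this scheme or learned positional embeddings is discussed later.

```python
import math
import torch

def sinusoidal_positions(max_len, d_model):
    # Position i, even dimension 2k gets sin(i / 10000^(2k/d_model)); odd dimension 2k+1 gets the cosine.
    positions = torch.arange(max_len).unsqueeze(1).float()                # (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    encoding = torch.zeros(max_len, d_model)
    encoding[:, 0::2] = torch.sin(positions * div_term)
    encoding[:, 1::2] = torch.cos(positions * div_term)
    return encoding

# Add positional information to token embeddings of shape (seq_len, d_model).
pe = sinusoidal_positions(max_len=100, d_model=512)
token_embeddings = torch.randn(100, 512)  # stand-in embeddings, for illustration only
inputs = token_embeddings + pe
```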
Another thing to keep in mind is that, during inference, the part of a trained transformer network that deals with the generation of a new sequence still operates autoregressively, like in an RNN, where each element of the sequence is produced from previously generated elements. During training, everything is parallelized. Encoding a known sequence, such as the input sequence, is always parallelized.
Wait. A sequence of what?
Words? Not necessarily.
In fact, over the years, we've tried just about everything – characters, words, and... subwords.
Subwords can be characters, words, or anything in between. They are a nice trade-off between a compact character vocabulary, whose units aren't meaningful by themselves and which produces extremely long sequences, and a monstrously large vocabulary of full words that would still be completely tripped up by a new word. Subwords allow for encoding almost any word, even unseen words, with a relatively compact vocabulary.
To create an optimal vocabulary of subwords, we will be using a technique called Byte Pair Encoding, a form of subword tokenization. We will study how this works later. For now, I only wanted to give you a heads up in case you're wondering why the sequences in our examples are not always split into full words.
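If you're curious about the core idea right away, here is a toy, from-scratch sketch of the merging procedure at the heart of Byte Pair Encoding – repeatedly fuse the most frequent adjacent pair of symbols into a new, longer symbol. The word counts are made up, and a real subword library handles many more details; this is only an illustration of the algorithm.

```python
from collections import Counter

def most_frequent_pair(corpus):
    # corpus maps a word (as a tuple of its current symbols) to its frequency
    pairs = Counter()
    for symbols, freq in corpus.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs.most_common(1)[0][0]

def merge_pair(corpus, pair):
    # rewrite every word, fusing each occurrence of the chosen pair into one symbol
    merged = {}
    for symbols, freq in corpus.items():
        out, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

# Toy corpus: words split into characters, with made-up counts.
corpus = {tuple("lower"): 5, tuple("lowest"): 2, tuple("newer"): 6, tuple("wider"): 3}
for _ in range(5):  # learn 5 merges
    pair = most_frequent_pair(corpus)
    corpus = merge_pair(corpus, pair)
    print("merged", pair)
```

After enough merges, frequent words survive intact as single symbols while rarer words remain decomposable into smaller pieces – which is exactly the behaviour you'll see in the tokenized examples below.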
As an example, let's consider the following –
The most beautiful thing we can experience is the mysterious. It is the fundamental emotion which stands at the cradle of true art and true science.
This is a quotation from Albert Einstein, but only a translation because he originally said it in his native German –
Das Schönste, was wir erleben können, ist das Geheimnisvolle. Es ist das Grundgefühl, das an der Wiege von wahrer Kunst und Wissenschaft steht.
The model implemented in this tutorial translates this common English translation back to German, as –
Das Schönste, was wir erleben können, ist das Geheimnis: Es ist die grundlegende Emotion die an der Wiege der wahren Kunst und der wahren Wissenschaft steht.
The model does a fairly good job, as far as I can tell. It opts for the noun Geheimnis (instead of the adjective Geheimnisvolle), which is primarily used to mean something akin to "secret" rather than "mystery", but the context is still clear. Also, it presents the output as a single sentence, which is not surprising because it was trained predominantly on single sentences.
The English quote above is tokenized as follows –
["_The", "_most", "_beautiful", "_thing", "_we", "_can", "_experience", "_is", "_the", "_myster", "ious", ".", "_It", "_is", "_the", "_fundamental", "_emotion", "_which", "_stands", "_at", "_the", "_c", "rad", "le", "_of", "_true", "_art", "_and", "_true", "_science", "."]
The German translation by the model is tokenized as –
["_Das", "_Schön", "ste,", "_was", "_wir", "_erleben", "_können,", "_ist", "_das", "_Geheim", "nis", ":", "_Es", "_ist", "_die", "_grundlegende", "_Em", "otion", "_die", "_an", "_der", "_Wie", "ge", "_der", "_wahren", "_Kunst", "_und", "_der", "_wahren", "_Wissenschaft", "_steht."]
We will use this English-German input-output pair as an example frequently throughout this tutorial, although there may be other examples as well.
You can see in the tokenization that common words are retained as full words, but others are split into multiple parts. The "_" (underscore) character signifies the beginning of a word in the original sentence.
From this point on, I will simply refer to the units in a sequence as tokens.
A familiar form
While a transformer is quite different in its inner workings compared to an RNN, it takes on a structure that you may already be familiar with – an encoder and a decoder.
The goal of the encoder is to encode the input sequence into a deep, "learned" representation. Like in an RNN, this involves encoding each token in the sequence.
The goal of the decoder is to use this encoded representation of the input sequence to produce the output sequence. Like in an RNN, during training, we use something called teacher forcing – we use every word from the true output sequence to learn to generate the next true word in the sequence. But while an RNN still processes the input sequence and generates sequentially, a transformer does it in parallel.
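To make the training-time picture concrete, here is a tiny, hypothetical sketch of teacher forcing – the decoder's input is the gold output sequence shifted right by one position, and the loss is computed over every position at once. The token IDs and the stand-in logits below are invented purely for illustration.

```python
import torch
import torch.nn.functional as F

# Teacher forcing (sketch): the decoder reads the gold sequence shifted right by one
# position and learns to predict the next token at every position in parallel.
gold = torch.tensor([[2, 15, 8, 41, 3]])   # hypothetical IDs for <BOS> w1 w2 w3 <EOS>
decoder_input = gold[:, :-1]               # <BOS> w1 w2 w3
targets = gold[:, 1:]                      # w1  w2 w3 <EOS>

vocab_size = 100
logits = torch.randn(1, decoder_input.size(1), vocab_size)  # stand-in for the decoder's outputs
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.reshape(-1))
```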
During inference, the decoder operates autoregressively, which means that it generates one word at a time, each of which is used to generate the next word.
The first generation is prompted with a <BOS> token, which implies "beginning of sequence – start generating".
In the next step, the second token is generated from the first generation.
And then, another.
You get the drift. The generative process is terminated upon the generation of an <EOS> token, which signifies "end of sequence – I'm done".
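Putting that together, autoregressive generation at inference time looks roughly like the loop below. The model.decode interface here is hypothetical – it stands in for "give me next-token scores from the encoded input and everything generated so far" – and greedy argmax selection is used for simplicity (we will do better with beam search later).

```python
import torch

def greedy_decode(model, encoder_out, bos_id, eos_id, max_len=100):
    # Hypothetical interface: model.decode(encoder_out, generated) returns next-token
    # scores over the vocabulary. The names are illustrative, not the tutorial's actual API.
    generated = torch.tensor([[bos_id]])                  # start with <BOS>
    for _ in range(max_len):
        scores = model.decode(encoder_out, generated)     # assumed shape: (1, vocab_size)
        next_token = scores.argmax(dim=-1, keepdim=True)  # pick the highest-scoring token
        generated = torch.cat([generated, next_token], dim=1)
        if next_token.item() == eos_id:                   # stop once <EOS> is produced
            break
    return generated.squeeze(0)
```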
Depending on the task at hand, you need either only the encoder, only the decoder, or both –
- A sequence classification or labeling task requires only the encoder. Popular transformer models like BERT are encoder-only.
- A sequence-to-sequence task (such as machine translation in this tutorial) conventionally uses both the encoder and decoder in the set-up we just described. Popular transformer models like T5 and BART are encoder-decoder formulations.
- Sequence generation can also be accomplished by a decoder-only model, where the input sequence or prompt can be used as the first bunch of tokens in an autoregressive decoder. The popular GPT family of transformer models are decoder-only.
In fact, after this tutorial, it will be easy for you to read and understand the research papers for these and other popular transformer models because they adopt, for the most part, the same transformer architecture with some modifications.
Now, without further delay, let's dive into what makes a transformer... a transformer.
Queries, Keys, and Values
Consider a search problem.
The goal is, given a query, to find a value that is most closely matched to it. This will be accomplished by comparing the query to a key associated with each candidate value.
But in real life, there is rarely a single relevant match – relevancy is a spectrum!
Would it then make sense to return a weighted average of the values instead as the result?
Yes, it would absolutely make sense, especially if the values are vectors or embeddings, where the different dimensions numerically encode various (and often abstract)