1 - Attention is all you need - Training

Note: we adopt a row-first convention (each row of a matrix is one token's feature vector), as is natural in PyTorch, even though column vectors as features agree more with linear algebra convention.
There is a lot of oversimplification out there about how a transformer actually works, so let us start with a very simple idea.
Let us say we have an input sequence $s_1, s_2, s_3$ and a target sequence $t_1, t_2, t_3, t_4$.
During training (for autoregressive models), $s_1 = t_1$, $s_2 = t_2$, $s_3 = t_3$, and we want to predict $\hat{T} = \hat{t}_1, \hat{t}_2, \hat{t}_3, \hat{t}_4$.
What actually happens is: we pass $S = [s_1; s_2; s_3]$ as a matrix (one row per token) to the "encoder block", and we pass the right-shifted target $T_{\text{shift}} = [\langle s\rangle; t_1; t_2; t_3]$ into the "decoder block".
So given $S$ and $T_{\text{shift}}$, we need to produce $\hat{T}$.
The loss then compares $\hat{T}$ against $T = t_1, t_2, t_3, t_4$.

So, to be very clear: to predict a length-4 sequence from a length-3 sequence, we pass the length-3 sequence into the "encoder block" and pass the right-shifted target sequence into the "decoder block" (in an autoregressive task this is just the input sequence with $\langle s\rangle$ prepended; for translation-type tasks it is a different sequence). Using both, we predict the entire output sequence and evaluate it against the full target sequence (both of length 4).
Support/Figures/Pasted image 20250519172942.png
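A minimal sketch of how one training example could be assembled from the pieces above, assuming integer token ids and a hypothetical `BOS = 0` id standing in for $\langle s\rangle$ (both are illustrative assumptions, not from the paper):

```python
# Hypothetical id for the start token <s>; real vocabularies choose their own.
BOS = 0

def make_training_example(source, target, bos=BOS):
    """Return (encoder_input, decoder_input, labels) for one sequence pair.

    decoder_input is the target right-shifted by one with <s> prepended, so
    decoder position i (0-based) sees <s>, t_1..t_i and must predict t_{i+1}.
    """
    encoder_input = list(source)
    decoder_input = [bos] + list(target[:-1])  # <s>, t1, ..., t_{n-1}
    labels = list(target)                       # t1, ..., t_n
    return encoder_input, decoder_input, labels

# Autoregressive case from the notes: s_i = t_i for the first three tokens.
enc, dec, lab = make_training_example([5, 6, 7], [5, 6, 7, 8])
print(enc, dec, lab)  # [5, 6, 7] [0, 5, 6, 7] [5, 6, 7, 8]
```

The decoder input and the labels always have the same length, which is what lets the loss compare $\hat{T}$ and $T$ position by position.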

So, let $S = s_1, s_2, \dots, s_T$ and $T = t_1, t_2, \dots, t_{T+1}$, both stored as matrices of one-hot rows.

Then we pass $S$ into the "encoder" and $T_{\text{shift}} = \langle s\rangle, t_1, t_2, \dots, t_T$ into the "decoder". In principle, the prediction is autoregressive: only $\langle s\rangle$ (plus the encoder output) is used to predict $t_1$, then $\langle s\rangle, t_1$ to predict $t_2$ (recall $t_1 = s_1$ in the autoregressive, non-translation setting), and so on. In practice, $S$ and $T_{\text{shift}}$ are passed in at once, and we use attention with masking to emulate autoregression in parallel.
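To check the "parallel emulation" claim numerically, here is a small NumPy sketch (a single head with random $Q, K, V$ and no learned projections, all assumptions for illustration): causally masked attention over the whole sequence at once gives exactly the rows you would get by running attention separately on each prefix.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    """Scaled dot-product attention where row i may only attend to rows 0..i."""
    T, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True strictly above the diagonal
    scores = np.where(future, -np.inf, scores)          # exp(-inf) = 0 after softmax
    return softmax(scores) @ V

rng = np.random.default_rng(0)
T, d = 5, 4
Q, K, V = (rng.normal(size=(T, d)) for _ in range(3))

# All positions at once (what training does).
parallel = causal_attention(Q, K, V)

# One position at a time, each only ever seeing its own prefix (true autoregression).
sequential = np.stack([
    causal_attention(Q[: i + 1], K[: i + 1], V[: i + 1])[-1] for i in range(T)
])

print(np.allclose(parallel, sequential))  # True
```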

We predict $\hat{T} = \hat{t}_1, \dots, \hat{t}_{T+1}$ and then take $\text{Loss}(T, \hat{T})$.

Let us see what's happening at a high level here:
First, $E_S = \text{emb}(S)$: we embed the one-hot vectors into a $d_{\text{model}}$-dimensional vector space.
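With the row-first convention, multiplying a matrix of one-hot rows by an embedding matrix just selects rows, which is why frameworks implement embeddings as a lookup table (e.g. PyTorch's `torch.nn.Embedding`). A NumPy sketch with a random, hypothetical `W_emb` standing in for the learned matrix:

```python
import numpy as np

vocab_size, d_model = 10, 4
rng = np.random.default_rng(0)
W_emb = rng.normal(size=(vocab_size, d_model))  # learned in practice; random here

token_ids = np.array([3, 1, 7])                 # s1, s2, s3 as integer ids
S = np.eye(vocab_size)[token_ids]               # (3, vocab_size) one-hot rows

E_S = S @ W_emb                                 # (3, d_model)
print(np.allclose(E_S, W_emb[token_ids]))       # True: one-hot matmul == row lookup
```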

Then we apply a positional encoding, $X = \text{Pos}(E_S)$, and feed the result through $N_{\text{layer}}$ encoder blocks: $H = \text{enc}^{N_{\text{layer}}}(X)$. Notice that each encoder layer is different (it has its own weights); the power just denotes applying them sequentially, because I like brevity.
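The notes leave `Pos` abstract. One concrete choice is the fixed sinusoidal encoding from the paper, added to the embeddings; below is a sketch assuming that choice and an even $d_{\text{model}}$ (learned positional embeddings are another common option):

```python
import numpy as np

def sinusoidal_positions(T, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(same angle)."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    pos = np.arange(T)[:, None]                        # (T, 1)
    two_i = np.arange(0, d_model, 2)[None, :]          # (1, d_model/2)
    angles = pos / np.power(10000.0, two_i / d_model)  # (T, d_model/2)
    pe = np.zeros((T, d_model))
    pe[:, 0::2] = np.sin(angles)                       # even columns
    pe[:, 1::2] = np.cos(angles)                       # odd columns
    return pe

def pos(E):
    """Add positional information to a (T, d_model) embedding matrix."""
    T, d_model = E.shape
    return E + sinusoidal_positions(T, d_model)

E_S = np.zeros((3, 8))   # stand-in embeddings so the encoding is visible on its own
X = pos(E_S)
print(X.shape)           # (3, 8)
print(X[0])              # position 0: sin(0)=0, cos(0)=1 -> [0. 1. 0. 1. 0. 1. 0. 1.]
```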

So $H$ is the encoding. There are also $N_{\text{layer}}$ decoder blocks, and each decoder block receives $H$ in its forward pass! The first decoder block also receives $Y = \text{pos}(E_{T_{\text{shift}}})$, where $E_{T_{\text{shift}}} = \text{decoderemb}(T_{\text{shift}})$ and $\text{pos}$ is, as usual, a positional encoding.
$G = \text{dec}^{N_{\text{layer}}}(H, Y)$
Then $\hat{T} = \text{softmax}(\text{linear}(G))$, and the loss is $\text{crossENT}(\hat{T}, T)$.
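A sketch of this final step, assuming the "linear" is a plain affine map `G @ W_out + b_out` (hypothetical names) and that $T$ is stored as one-hot rows. With all-zero weights every logit is equal, so the prediction is uniform and the loss is $\ln(\text{vocab size})$, a handy sanity check:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def output_and_loss(G, W_out, b_out, T_onehot):
    """G: (n, d_model) decoder output; T_onehot: (n, vocab) one-hot targets."""
    logits = G @ W_out + b_out        # linear: (n, vocab)
    T_hat = softmax(logits)           # each row is a distribution over the vocabulary
    # Cross-entropy: -log of the probability given to the correct token, averaged over positions.
    loss = -np.mean(np.sum(T_onehot * np.log(T_hat + 1e-12), axis=1))
    return T_hat, loss

rng = np.random.default_rng(0)
n_pos, d_model, vocab = 4, 8, 5
G = rng.normal(size=(n_pos, d_model))
W_out, b_out = np.zeros((d_model, vocab)), np.zeros(vocab)  # untrained: all logits equal
T_onehot = np.eye(vocab)[[1, 2, 3, 4]]                      # targets t1..t4
T_hat, loss = output_and_loss(G, W_out, b_out, T_onehot)
print(round(loss, 4))  # 1.6094 == ln(5): a uniform guess over 5 tokens
```

In practice, frameworks usually fuse the softmax and the log for numerical stability; for example, PyTorch's `torch.nn.functional.cross_entropy` takes raw logits rather than probabilities.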

Self-Attention: Scaled Dot-Product:

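So, given an embedding matrix $X$ of shape $(T, d_{\text{model}})$, we make three linear projections of this matrix: $K = XW_K$, $Q = XW_Q$, $V = XW_V$.
Both $W_K, W_Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, i.e. shape $(d_{\text{model}}, d_k)$, where $d_k$ is the key/query dimension, and $W_V$ has shape $(d_{\text{model}}, d_v)$.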

Support/Figures/Pasted image 20250519182016.png

Support/Figures/Pasted image 20250519182042.png

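So we compute the dot products between the rows of $Q$ and $K$ (i.e. $QK^\top$) and scale by $1/\sqrt{d_k}$. Optionally, we mask the upper triangle (strictly above the diagonal) with $-\infty$; after the row-wise softmax those entries become 0, so no time-step attends to future time-steps. Finally, we multiply by $V$:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V,$$
where $M$ is the optional mask ($0$ on and below the diagonal, $-\infty$ above it).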

The output of a single attention head has shape $(T, d_v)$: $QK^\top$ has shape $(T, T)$ and holds the attention pattern, and multiplying that by $V$, of shape $(T, d_v)$, gives $(T, d_v)$.
Support/Figures/Pasted image 20250519182620.png
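A single-head sketch in NumPy following the row-first convention; the weights are random stand-ins for learned parameters, and `causal=True` applies the upper-triangle mask:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(X, W_Q, W_K, W_V, causal=False):
    """One scaled dot-product attention head, rows = tokens.

    X: (T, d_model); W_Q, W_K: (d_model, d_k); W_V: (d_model, d_v).
    Returns the (T, d_v) output and the (T, T) attention pattern.
    """
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                    # (T, T)
    if causal:
        T = scores.shape[0]
        future = np.triu(np.ones((T, T), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)     # becomes 0 after softmax
    pattern = softmax(scores, axis=-1)                 # each row sums to 1
    return pattern @ V, pattern

rng = np.random.default_rng(0)
T, d_model, d_k, d_v = 4, 8, 3, 2
X = rng.normal(size=(T, d_model))
W_Q, W_K = rng.normal(size=(d_model, d_k)), rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_v))
out, pattern = attention_head(X, W_Q, W_K, W_V, causal=True)
print(out.shape, pattern.shape)                # (4, 2) (4, 4)
print(np.allclose(np.triu(pattern, k=1), 0))   # True: no weight on future positions
```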

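For multi-head attention, we feed the same embedding matrix through $h$ different heads, each with its own $W_Q, W_K, W_V$. The paper chooses $d_k = d_v = d_{\text{model}}/h$, so concatenating the $h$ head outputs gives a $(T, d_{\text{model}})$ matrix (and the total cost stays similar to one full-width head); multiplying by $W_O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ then mixes the heads, giving $H_{\text{out}}$ of shape $(T, d_{\text{model}})$. Once we understand attention, there is not much else to the architecture: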
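A sketch of multi-head self-attention, assuming the paper's choice $d_k = d_v = d_{\text{model}}/h$ and storing the per-head projections as stacked `(h, d_model, d_model // h)` arrays (a layout chosen here for convenience, not a library API). The heads share the same input `X` and differ only in their weights:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(X, W_Q, W_K, W_V, W_O):
    """X: (T, d_model); W_Q, W_K, W_V: (h, d_model, d_model // h); W_O: (h * d_v, d_model)."""
    heads = []
    for Wq, Wk, Wv in zip(W_Q, W_K, W_V):                  # every head sees the same X
        Q, K, V = X @ Wq, X @ Wk, X @ Wv                    # each (T, d_model // h)
        pattern = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))  # (T, T)
        heads.append(pattern @ V)                           # (T, d_v)
    concat = np.concatenate(heads, axis=-1)                 # (T, h * d_v) == (T, d_model)
    return concat @ W_O                                     # (T, d_model)

rng = np.random.default_rng(0)
T, d_model, h = 5, 8, 2
d_head = d_model // h                                       # d_k = d_v = d_model / h = 4
X = rng.normal(size=(T, d_model))
W_Q, W_K, W_V = (rng.normal(size=(h, d_model, d_head)) for _ in range(3))
W_O = rng.normal(size=(h * d_head, d_model))
H_out = multi_head_self_attention(X, W_Q, W_K, W_V, W_O)
print(H_out.shape)  # (5, 8)
```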

Encoder block:

It does multi-head attention, and a residual connection adds the block's input $X$ to the output of the multi-head attention. Then we normalize, pass the result through an MLP, add another residual connection, and normalize again:

$$L = \text{norm}(\text{MHattn}(X) + X)$$
$$\text{enc} = \text{norm}(\text{MLP}(L) + L)$$

So chaining multiple encoder layers gives us the final encoding $H$. For every encoder block after the first, we use the output of the previous encoder block in place of $X$.
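Putting the encoder block together, a minimal NumPy sketch assuming post-norm (as written above), a ReLU MLP, layer norm without its learnable gain and bias, and random weights in place of trained ones. Each block gets its own parameters, and the output of one feeds the next:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Normalize each row (token) to zero mean and unit variance; learnable gain/bias omitted.
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def mh_attn(X, p):
    heads = []
    for Wq, Wk, Wv in zip(p["W_Q"], p["W_K"], p["W_V"]):
        Q, K, V = X @ Wq, X @ Wk, X @ Wv
        heads.append(softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V)
    return np.concatenate(heads, axis=-1) @ p["W_O"]

def mlp(x, p):
    # Two linear layers with a ReLU in between, applied to each row independently.
    return np.maximum(0.0, x @ p["W_1"] + p["b_1"]) @ p["W_2"] + p["b_2"]

def encoder_block(X, p):
    L = layer_norm(mh_attn(X, p) + X)   # L = norm(MHattn(X) + X)
    return layer_norm(mlp(L, p) + L)    # enc = norm(MLP(L) + L)

def encoder(X, blocks):
    H = X
    for p in blocks:                    # separate weights per block; output feeds the next
        H = encoder_block(H, p)
    return H

def init_block(rng, d_model, h, d_ff):
    s, d_k = 1 / np.sqrt(d_model), d_model // h
    p = {k: rng.normal(scale=s, size=(h, d_model, d_k)) for k in ("W_Q", "W_K", "W_V")}
    p["W_O"] = rng.normal(scale=s, size=(d_model, d_model))
    p["W_1"], p["b_1"] = rng.normal(scale=s, size=(d_model, d_ff)), np.zeros(d_ff)
    p["W_2"], p["b_2"] = rng.normal(scale=1 / np.sqrt(d_ff), size=(d_ff, d_model)), np.zeros(d_model)
    return p

rng = np.random.default_rng(0)
d_model, h, d_ff, n_layers = 8, 2, 32, 3
blocks = [init_block(rng, d_model, h, d_ff) for _ in range(n_layers)]
X = rng.normal(size=(5, d_model))       # stand-in for Pos(emb(S)) with T = 5
H = encoder(X, blocks)
print(H.shape)  # (5, 8)
```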

Decoder block:

Once encoding is done, let $Y$ be the positionally encoded, embedded version of $T_{\text{shift}}$.
Then, for the first decoder block,

$$A = \text{norm}(\text{MaskedMultiHeadattn}(Y) + Y)$$
$$B = \text{norm}(\text{MultiHeadattn}(H, A) + A)$$
$$C = \text{norm}(\text{MLP}(B) + B)$$

For subsequent decoder blocks, the previous block's output $C$ is used in place of $Y$, and the encoded matrix $H$ is passed to EACH DECODER LAYER.

Finally, after applying all the decoder layers, a linear layer followed by a softmax gives $\hat{T}$, and we do gradient descent on the cross-entropy between $\hat{T}$ and $T$.

A small caveat: the encoder-decoder (cross-)attention, the one we wrote as $B = \text{norm}(\text{MultiHeadattn}(H, A) + A)$, takes its keys and values from the final encoder output $H$ (projected by that decoder layer's own $W_K, W_V$), while its queries come from the decoder side, namely $A$, the output of the preceding masked self-attention sub-layer.
All other attention sub-layers are self-attention: their $Q$, $K$, $V$ are all projections of the same input, as usual.
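A decoder-block sketch under similar assumptions (random weights, post-norm, ReLU MLP, layer norm without gain/bias; the parameter layout is made up for illustration). Here `mha` takes separate query and key/value inputs: the masked self-attention passes `Y` twice, while the cross-attention passes `A` for queries and the encoder output `H` for keys and values:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Per-token normalization; the learnable gain and bias are omitted in this sketch.
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def mha(X_q, X_kv, p, causal=False):
    """Queries from X_q (T_q, d_model); keys and values from X_kv (T_kv, d_model)."""
    heads = []
    for Wq, Wk, Wv in zip(p["W_Q"], p["W_K"], p["W_V"]):
        Q, K, V = X_q @ Wq, X_kv @ Wk, X_kv @ Wv
        scores = Q @ K.T / np.sqrt(Q.shape[-1])              # (T_q, T_kv)
        if causal:                                            # hide future positions
            future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
            scores = np.where(future, -np.inf, scores)
        heads.append(softmax(scores) @ V)
    return np.concatenate(heads, axis=-1) @ p["W_O"]

def mlp(x, p):
    return np.maximum(0.0, x @ p["W_1"] + p["b_1"]) @ p["W_2"] + p["b_2"]

def decoder_block(Y, H, p):
    A = layer_norm(mha(Y, Y, p["self"], causal=True) + Y)  # masked self-attention
    B = layer_norm(mha(A, H, p["cross"]) + A)              # Q from A; K, V from H
    return layer_norm(mlp(B, p["mlp"]) + B)                # C

def decoder(Y, H, blocks):
    G = Y
    for p in blocks:              # the same encoder output H goes into every block
        G = decoder_block(G, H, p)
    return G

def attn_params(rng, d_model, h):
    s, d_k = 1 / np.sqrt(d_model), d_model // h
    p = {k: rng.normal(scale=s, size=(h, d_model, d_k)) for k in ("W_Q", "W_K", "W_V")}
    p["W_O"] = rng.normal(scale=s, size=(d_model, d_model))
    return p

def mlp_params(rng, d_model, d_ff):
    return {"W_1": rng.normal(scale=1 / np.sqrt(d_model), size=(d_model, d_ff)),
            "b_1": np.zeros(d_ff),
            "W_2": rng.normal(scale=1 / np.sqrt(d_ff), size=(d_ff, d_model)),
            "b_2": np.zeros(d_model)}

rng = np.random.default_rng(0)
d_model, h, d_ff, n_layers = 8, 2, 32, 2
blocks = [{"self": attn_params(rng, d_model, h), "cross": attn_params(rng, d_model, h),
           "mlp": mlp_params(rng, d_model, d_ff)} for _ in range(n_layers)]
H = rng.normal(size=(3, d_model))   # stand-in encoder output for a length-3 source
Y = rng.normal(size=(4, d_model))   # stand-in for pos(emb(<s>, t1, t2, t3))
G = decoder(Y, H, blocks)
print(G.shape)  # (4, 8): one row per target position, although the source has 3 tokens
```

Because of the mask, row $i$ of `G` depends only on rows $0..i$ of `Y` (plus all of `H`), so perturbing the last row of `Y` leaves the earlier rows of `G` unchanged.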

Why attention?

Support/Figures/Pasted image 20250519185144.png Support/Figures/Pasted image 20250519185224.png
Here (in self-attention) each layer needs only $O(1)$ sequential operations: the $O(n^2 \cdot d)$ work per layer can be fully parallelised. More importantly, processing a full sequence during training takes a single forward pass (up to big-O), and the maximum path length between any two positions is $O(1)$. In a recurrent model with sequence length $n$, we need $n$ sequential steps for the forward pass and then backpropagation through time (BPTT), and information between distant positions has to travel through $O(n)$ steps. Transformers process the whole sequence in one batched go, so the path length is constant, independent of sequence length.

Superposition:

In theory, we would love for each semantic "idea" to be embedded along its own perpendicular direction. That is, if $d_{\text{model}} = N$, we can encode $N$ different semantic ideas as perpendicular directions and represent a combination of ideas as a sum of components along those directions.
However, if we relax this and allow an $\epsilon$-approximately perpendicular encoding of semantics, i.e. every semantic direction is within an angle $\epsilon$ of being perpendicular to every other one, then the number of available directions grows exponentially in $N$ (roughly like $e^{c\,\epsilon^2 N}$ for some constant $c$)!!!! So we get exponential growth in the number of possible directions for a "semantic feature". However, this also means that most semantic features are not aligned with a single basis direction, i.e. a single neuron, which is one of the theories for why LLM interpretability is hard.
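A quick numerical illustration of the relaxed picture, using random unit vectors rather than learned features: sample far more directions than dimensions and measure how close to perpendicular every pair is. The sizes below are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, m = 500, 2000                               # 2000 directions in only 500 dimensions
V = rng.normal(size=(m, N))
V /= np.linalg.norm(V, axis=1, keepdims=True)  # unit vectors, one per row
cos = V @ V.T                                  # pairwise cosine similarities
np.fill_diagonal(cos, 0.0)                     # ignore each vector's similarity with itself
max_cos = np.abs(cos).max()
# Deviation from 90 degrees for the worst pair is arcsin(|cos|).
print(f"max |cos| = {max_cos:.3f}, worst pair is "
      f"{np.degrees(np.arcsin(max_cos)):.1f} degrees away from perpendicular")
```

Even with four times as many directions as dimensions, every pair stays within a modest angle of perpendicular, whereas exact orthogonality would allow at most $N = 500$ directions.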

We will leave implementation and training details for another time; this was just the core architecture.

Sources: 3b1b deep learning series, Attention is all you need paper, my brain.