
Chunkwise Gated Delta Rule

The chunkwise Gated Delta Rule is important when performing operations such as prefilling or training with the recently popular Gated Delta Attention formulation.

In essence, in the chunkwise algorithm, we want to transform the naive recurrent equations for the transition from one timestep to the next into a form that is more GPU-friendly. This works by splitting the sequence length into chunks and then deriving formulas to perform the transition from one chunk to the next via matrix multiplication. For Gated Delta Net, this is not completely straightforward, which is why I will derive the transition formulas step by step in this blog post. For educational purposes, we start with Linear Attention, then move to the Delta Rule, and finally to the Gated Delta Rule.

Chunkwise Linear Attention

Linear Attention has the recurrence $S_t = S_{t-1} + v_t k_t^T$ with output $o_t = S_t q_t$. Unrolling this within a chunk of size $C$ gives

$$S_{[t]}^r = S_{[t]} + \sum_{i=1}^{r} v_{[t]}^i \left(k_{[t]}^i\right)^T \in \mathbb{R}^{d_v \times d_k}$$

where $v_{[t]}^i \in \mathbb{R}^{d_v \times 1}$ denotes the $i$-th row (we start to count from 1) relative to the beginning of the chunk $V_{tC+1:(t+1)C} \in \mathbb{R}^{C \times d_v}$ of the matrix $V \in \mathbb{R}^{L \times d_v}$, and similarly for $K$. $S_{[t]}$ is simply $S_{tC}$, and $S_{[t]}^r$ is this state shifted by $r$ timesteps.

For the output we have

$$o_{[t]}^r = S_{[t]}^r q_{[t]}^r = S_{[t]} q_{[t]}^r + \sum_{i=1}^{r} v_{[t]}^i \left(\left(k_{[t]}^i\right)^T q_{[t]}^r\right) \in \mathbb{R}^{d_v \times 1}$$

Note we can write this in matrix form as

$$S_{[t+1]} = S_{[t]} + V_{[t]}^T K_{[t]} \in \mathbb{R}^{d_v \times d_k}$$

and

$$O_{[t]} = Q_{[t]} S_{[t]}^T + \left(Q_{[t]} K_{[t]}^T \odot M\right) V_{[t]} \in \mathbb{R}^{C \times d_v}$$

where $M \in \mathbb{R}^{C \times C}$ is the causal mask and $\odot$ denotes elementwise multiplication. Why we need the causal mask can be understood from a picture:

(Figure: Causal Mask)

We see that the sum at timestep $r$ relative to the beginning of the chunk should contain exactly $r$ summands. This is achieved by causal masking, which eliminates all upper-right entries of $Q_{[t]} K_{[t]}^T$ and thus leaves $r$ summands for the $r$-th row of the second term.
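As a sanity check, the chunkwise equations can be verified against the naive recurrence in a few lines of NumPy. This is only an illustrative sketch; the dimensions, seed, and variable names are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, C, d_k, d_v = 8, 4, 3, 5    # sequence length, chunk size, head dims
Q = rng.normal(size=(seq_len, d_k))
K = rng.normal(size=(seq_len, d_k))
V = rng.normal(size=(seq_len, d_v))

# Naive recurrence: S_t = S_{t-1} + v_t k_t^T, o_t = S_t q_t
S = np.zeros((d_v, d_k))
O_ref = np.zeros((seq_len, d_v))
for t in range(seq_len):
    S = S + np.outer(V[t], K[t])
    O_ref[t] = S @ Q[t]

# Chunkwise form: O_[t] = Q_[t] S_[t]^T + (Q_[t] K_[t]^T ⊙ M) V_[t],
#                 S_[t+1] = S_[t] + V_[t]^T K_[t]
M = np.tril(np.ones((C, C)))         # causal mask, diagonal included
S = np.zeros((d_v, d_k))
O = np.zeros((seq_len, d_v))
for t in range(0, seq_len, C):
    Qc, Kc, Vc = Q[t:t+C], K[t:t+C], V[t:t+C]
    O[t:t+C] = Qc @ S.T + (Qc @ Kc.T * M) @ Vc
    S = S + Vc.T @ Kc

assert np.allclose(O, O_ref)
```

Note that the mask includes the diagonal, since the output at relative timestep $r$ already contains the contribution of timestep $r$ itself.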

Chunkwise Delta Rule

For the Delta Rule the state update is

$$S_t = S_{t-1}\left(I - \beta_t k_t k_t^T\right) + \beta_t v_t k_t^T \in \mathbb{R}^{d_v \times d_k}$$

As above, we can expand the state by applying the recursion to obtain the equation for an offset $r$ from the start of a chunk:

$$S_{[t]}^r = S_{[t]} \underbrace{\left(\prod_{i=1}^{r} \left(I - \beta_{[t]}^i k_{[t]}^i \left(k_{[t]}^i\right)^T\right)\right)}_{=:P_{[t]}^r} + \underbrace{\sum_{i=1}^{r} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r} \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right)}_{=:H_{[t]}^r} \in \mathbb{R}^{d_v \times d_k}.$$

We will now prove useful identities for these two terms.

$P_{[t]}^r$

Let's prove that

$$P_{[t]}^r = I - \sum_{i=1}^{r} w_{[t]}^i \left(k_{[t]}^i\right)^T$$

by induction over $r \in \mathbb{N}_{>0}$.

$r = 1$:

$$P_{[t]}^1 = \prod_{i=1}^{1} \left(I - \beta_{[t]}^i k_{[t]}^i \left(k_{[t]}^i\right)^T\right) = I - \beta_{[t]}^1 k_{[t]}^1 \left(k_{[t]}^1\right)^T$$

Choose $w_{[t]}^1 = \beta_{[t]}^1 k_{[t]}^1$ and the equation is fulfilled.

$r \to r+1$:

For the induction step, we continue from

$$P_{[t]}^{r+1} = P_{[t]}^{r} \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right)$$

and use the induction hypothesis

$$P_{[t]}^{r} = I - \sum_{i=1}^{r} w_{[t]}^i \left(k_{[t]}^i\right)^T.$$

Thus,

$$P_{[t]}^{r+1} = \left(I - \sum_{i=1}^{r} w_{[t]}^i \left(k_{[t]}^i\right)^T\right) \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right).$$

Expanding the product yields

$$P_{[t]}^{r+1} = I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T - \sum_{i=1}^{r} w_{[t]}^i \left(k_{[t]}^i\right)^T + \sum_{i=1}^{r} w_{[t]}^i \left(k_{[t]}^i\right)^T \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T.$$

Since $\left(k_{[t]}^i\right)^T \beta_{[t]}^{r+1} k_{[t]}^{r+1}$ is a scalar, the last term can be rewritten as

$$\sum_{i=1}^{r} \beta_{[t]}^{r+1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) w_{[t]}^i \left(k_{[t]}^{r+1}\right)^T.$$

Therefore,

$$P_{[t]}^{r+1} = I - \sum_{i=1}^{r} w_{[t]}^i \left(k_{[t]}^i\right)^T - \beta_{[t]}^{r+1} \left(k_{[t]}^{r+1} - \sum_{i=1}^{r} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) w_{[t]}^i\right) \left(k_{[t]}^{r+1}\right)^T.$$

Now define

$$w_{[t]}^{r+1} := \beta_{[t]}^{r+1} \left(k_{[t]}^{r+1} - \sum_{i=1}^{r} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) w_{[t]}^i\right).$$

Then we obtain

$$P_{[t]}^{r+1} = I - \sum_{i=1}^{r+1} w_{[t]}^i \left(k_{[t]}^i\right)^T.$$

This is exactly the desired form, so the induction step is proved.

Moreover, the proof gives a recursive way to compute the vectors $w_{[t]}^r$: start with

$$w_{[t]}^1 = \beta_{[t]}^1 k_{[t]}^1,$$

and for $r > 1$ compute

$$w_{[t]}^r = \beta_{[t]}^r \left(k_{[t]}^r - \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) w_{[t]}^i\right).$$
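This recursion is easy to check numerically. The following NumPy sketch (arbitrary dimensions and random inputs) verifies the identity $P_{[t]}^r = I - \sum_{i \le r} w_{[t]}^i (k_{[t]}^i)^T$ at every step:

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k = 4, 3
K = rng.normal(size=(C, d_k))
beta = rng.uniform(0.1, 1.0, size=C)

P = np.eye(d_k)                      # running product P_[t]^r
W = np.zeros((C, d_k))               # rows will hold the vectors w_[t]^r
for r in range(C):
    # w_r = beta_r (k_r - sum_{i<r} (k_i^T k_r) w_i)
    W[r] = beta[r] * (K[r] - W[:r].T @ (K[:r] @ K[r]))
    P = P @ (np.eye(d_k) - beta[r] * np.outer(K[r], K[r]))
    # identity: P_[t]^r = I - sum_{i<=r} w_i k_i^T
    assert np.allclose(P, np.eye(d_k) - W[:r+1].T @ K[:r+1])
```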

$H_{[t]}^r$

Let's prove that

$$H_{[t]}^r = \sum_{i=1}^{r} u_{[t]}^i \left(k_{[t]}^i\right)^T$$

by induction over $r \in \mathbb{N}_{>0}$.

$r = 1$:

$$H_{[t]}^1 = \sum_{i=1}^{1} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{1} \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right) = \beta_{[t]}^1 v_{[t]}^1 \left(k_{[t]}^1\right)^T$$

Choose $u_{[t]}^1 = \beta_{[t]}^1 v_{[t]}^1$ and the equation is fulfilled.

$r \to r+1$:

$$H_{[t]}^{r+1} = \sum_{i=1}^{r+1} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r+1} \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right)$$

Let's use the following identities to bring this into a form where we can use the induction hypothesis:

$$\sum_{j=1}^{r+1} x_j = \left(\sum_{j=1}^{r} x_j\right) + x_{r+1}, \qquad \prod_{j=1}^{r+1} x_j = \left(\prod_{j=1}^{r} x_j\right) x_{r+1}$$

We first split off the last summand, then split off the last factor of each remaining product, and finally factor that common factor out of the sum:

$$\begin{aligned}
H_{[t]}^{r+1} &= \sum_{i=1}^{r+1} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r+1} \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right) \\
&= \sum_{i=1}^{r} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r+1} \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T \\
&= \sum_{i=1}^{r} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \left(\prod_{j=i+1}^{r} \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right) \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right)\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T \\
&= \left(\sum_{i=1}^{r} \beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r} \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right) \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T
\end{aligned}$$

This gives us

$$H_{[t]}^{r+1} = H_{[t]}^{r} \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T$$

Plug in the induction hypothesis:

$$H_{[t]}^{r+1} = \left(\sum_{i=1}^{r} u_{[t]}^i \left(k_{[t]}^i\right)^T\right) \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T$$

Now distribute:

$$H_{[t]}^{r+1} = \sum_{i=1}^{r} u_{[t]}^i \left(k_{[t]}^i\right)^T - \sum_{i=1}^{r} u_{[t]}^i \left(k_{[t]}^i\right)^T \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T$$

Since $\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}$ is a scalar, this becomes

$$H_{[t]}^{r+1} = \sum_{i=1}^{r} u_{[t]}^i \left(k_{[t]}^i\right)^T - \sum_{i=1}^{r} \beta_{[t]}^{r+1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) u_{[t]}^i \left(k_{[t]}^{r+1}\right)^T + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T$$

Group the last two terms together:

$$H_{[t]}^{r+1} = \sum_{i=1}^{r} u_{[t]}^i \left(k_{[t]}^i\right)^T + \left(\beta_{[t]}^{r+1} v_{[t]}^{r+1} - \sum_{i=1}^{r} \beta_{[t]}^{r+1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) u_{[t]}^i\right) \left(k_{[t]}^{r+1}\right)^T$$

Factor out $\beta_{[t]}^{r+1}$:

$$H_{[t]}^{r+1} = \sum_{i=1}^{r} u_{[t]}^i \left(k_{[t]}^i\right)^T + \beta_{[t]}^{r+1} \left(v_{[t]}^{r+1} - \sum_{i=1}^{r} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) u_{[t]}^i\right) \left(k_{[t]}^{r+1}\right)^T$$

Define

$$u_{[t]}^{r+1} := \beta_{[t]}^{r+1} \left(v_{[t]}^{r+1} - \sum_{i=1}^{r} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) u_{[t]}^i\right)$$

Then we obtain

$$H_{[t]}^{r+1} = \sum_{i=1}^{r+1} u_{[t]}^i \left(k_{[t]}^i\right)^T$$

which is exactly the desired form.

Moreover, the proof gives a recursive way to compute the vectors $u_{[t]}^r$: start with

$$u_{[t]}^1 = \beta_{[t]}^1 v_{[t]}^1,$$

and for $r > 1$ compute

$$u_{[t]}^r = \beta_{[t]}^r \left(v_{[t]}^r - \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) u_{[t]}^i\right).$$

Hence, for every $r \ge 1$, we can write

$$H_{[t]}^r = \sum_{i=1}^{r} u_{[t]}^i \left(k_{[t]}^i\right)^T.$$
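The analogous numerical check for $H_{[t]}^r$, again a small NumPy sketch with arbitrary dimensions, tracks $H_{[t]}^r$ via its recurrence and compares it against $\sum_{i \le r} u_{[t]}^i (k_{[t]}^i)^T$:

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 5
K = rng.normal(size=(C, d_k))
V = rng.normal(size=(C, d_v))
beta = rng.uniform(0.1, 1.0, size=C)

H = np.zeros((d_v, d_k))             # running H_[t]^r
U = np.zeros((C, d_v))               # rows will hold the vectors u_[t]^r
for r in range(C):
    # u_r = beta_r (v_r - sum_{i<r} (k_i^T k_r) u_i)
    U[r] = beta[r] * (V[r] - U[:r].T @ (K[:r] @ K[r]))
    # recurrence: H^{r+1} = H^r (I - beta k k^T) + beta v k^T
    H = H @ (np.eye(d_k) - beta[r] * np.outer(K[r], K[r])) \
        + beta[r] * np.outer(V[r], K[r])
    # identity: H_[t]^r = sum_{i<=r} u_i k_i^T
    assert np.allclose(H, U[:r+1].T @ K[:r+1])
```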

Simple vectorised expression

We now have an elegant form for the transition of $r$ timesteps within chunk $t$:

$$S_{[t]}^r = S_{[t]} P_{[t]}^r + H_{[t]}^r \in \mathbb{R}^{d_v \times d_k}$$

where the terms on the right-hand side were derived above.

Matrix equation

We can rewrite the transition from one chunk to the next in matrix notation:

$$P_{[t]} = I - W_{[t]}^T K_{[t]} \in \mathbb{R}^{d_k \times d_k}$$

This can be understood by writing

$$W_{[t]}^T = \left[w_{[t]}^1, \ldots, w_{[t]}^C\right] \in \mathbb{R}^{d_k \times C}, \qquad K_{[t]} = \left[\left(k_{[t]}^1\right)^T, \ldots, \left(k_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_k}$$

The matrix multiplication then gives us

$$W_{[t]}^T K_{[t]} = \sum_{i=1}^{C} \left(W_{[t]}^T\right)_{:,i} \left(K_{[t]}\right)_{i,:} = \sum_{i=1}^{C} w_{[t]}^i \left(k_{[t]}^i\right)^T$$

which corresponds to the transition over a full chunk ($r = C$) in the vectorised formulation.

In a similar way we can write

$$H_{[t]} = U_{[t]}^T K_{[t]} \in \mathbb{R}^{d_v \times d_k}$$

Let's derive closed forms for $W_{[t]}$ and $U_{[t]}$.

Write

$$W_{[t]} = \left[\left(w_{[t]}^1\right)^T, \ldots, \left(w_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_k}$$

and

$$K_{[t]} = \left[\left(k_{[t]}^1\right)^T, \ldots, \left(k_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_k}.$$

Also define

$$B_{[t]} := \mathrm{Diag}\left(\beta_{[t]}^1, \ldots, \beta_{[t]}^C\right) \in \mathbb{R}^{C \times C}$$

and

$$G_{[t]} := K_{[t]} K_{[t]}^T \in \mathbb{R}^{C \times C}.$$

The entries of $G_{[t]}$ are

$$\left(G_{[t]}\right)_{r,i} = \left(K_{[t]}\right)_{r,:} \left(K_{[t]}\right)_{i,:}^T = \left(k_{[t]}^r\right)^T k_{[t]}^i = \left(k_{[t]}^i\right)^T k_{[t]}^r.$$

Now define $\mathrm{tril}(A, -1)$ for a matrix $A \in \mathbb{R}^{C \times C}$ as the matrix which keeps only the entries strictly below the main diagonal and sets all other entries to zero. In other words,

$$\left(\mathrm{tril}(A, -1)\right)_{r,i} = \begin{cases} A_{r,i}, & i < r, \\ 0, & i \ge r. \end{cases}$$

Using this, define

$$L_{[t]} := \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T, -1\right) \in \mathbb{R}^{C \times C}.$$

Its entries are therefore

$$\left(L_{[t]}\right)_{r,i} = \begin{cases} \beta_{[t]}^r \left(k_{[t]}^i\right)^T k_{[t]}^r, & i < r, \\ 0, & i \ge r. \end{cases}$$

Recall the recurrence

$$w_{[t]}^r = \beta_{[t]}^r \left(k_{[t]}^r - \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) w_{[t]}^i\right).$$

Since

$$W_{[t]} = \left[\left(w_{[t]}^1\right)^T, \ldots, \left(w_{[t]}^C\right)^T\right]^T,$$

the $r$-th row of $W_{[t]}$ is

$$\left(W_{[t]}\right)_{r,:} = \left(w_{[t]}^r\right)^T.$$

Similarly,

$$\left(K_{[t]}\right)_{r,:} = \left(k_{[t]}^r\right)^T.$$

Therefore the recurrence is equivalent, row by row, to

$$\left(W_{[t]}\right)_{r,:} = \beta_{[t]}^r \left(K_{[t]}\right)_{r,:} - \beta_{[t]}^r \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) \left(W_{[t]}\right)_{i,:}.$$

Using the definition of $L_{[t]}$, we can rewrite the sum as

$$\beta_{[t]}^r \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) \left(W_{[t]}\right)_{i,:} = \sum_{i=1}^{C} \left(L_{[t]}\right)_{r,i} \left(W_{[t]}\right)_{i,:} = \left(L_{[t]} W_{[t]}\right)_{r,:}.$$

Hence

$$\left(W_{[t]}\right)_{r,:} + \left(L_{[t]} W_{[t]}\right)_{r,:} = \left(B_{[t]} K_{[t]}\right)_{r,:}.$$

Therefore

$$\left(\left(I + L_{[t]}\right) W_{[t]}\right)_{r,:} = \left(B_{[t]} K_{[t]}\right)_{r,:}$$

for every $r = 1, \ldots, C$. Thus

$$\left(I + L_{[t]}\right) W_{[t]} = B_{[t]} K_{[t]}.$$

Therefore

$$W_{[t]} = \left(I + L_{[t]}\right)^{-1} B_{[t]} K_{[t]}.$$

Substituting the definition of $L_{[t]}$, we obtain

$$W_{[t]} = \left(I + \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T, -1\right)\right)^{-1} B_{[t]} K_{[t]}.$$

Now define

$$T_{[t]} := \left(I + \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T, -1\right)\right)^{-1} B_{[t]} \in \mathbb{R}^{C \times C}.$$

Then

$$W_{[t]} = T_{[t]} K_{[t]}.$$

In the same way, define

$$U_{[t]} = \left[\left(u_{[t]}^1\right)^T, \ldots, \left(u_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_v}$$

and

$$V_{[t]} = \left[\left(v_{[t]}^1\right)^T, \ldots, \left(v_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_v}.$$

Recall the recurrence

$$u_{[t]}^r = \beta_{[t]}^r \left(v_{[t]}^r - \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) u_{[t]}^i\right).$$

The $r$-th row of $U_{[t]}$ is

$$\left(U_{[t]}\right)_{r,:} = \left(u_{[t]}^r\right)^T.$$

Similarly,

$$\left(V_{[t]}\right)_{r,:} = \left(v_{[t]}^r\right)^T.$$

Therefore the recurrence is equivalent, row by row, to

$$\left(U_{[t]}\right)_{r,:} = \beta_{[t]}^r \left(V_{[t]}\right)_{r,:} - \beta_{[t]}^r \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) \left(U_{[t]}\right)_{i,:}.$$

Again using the definition of $L_{[t]}$, we obtain

$$\beta_{[t]}^r \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) \left(U_{[t]}\right)_{i,:} = \sum_{i=1}^{C} \left(L_{[t]}\right)_{r,i} \left(U_{[t]}\right)_{i,:} = \left(L_{[t]} U_{[t]}\right)_{r,:}.$$

Hence

$$\left(U_{[t]}\right)_{r,:} + \left(L_{[t]} U_{[t]}\right)_{r,:} = \left(B_{[t]} V_{[t]}\right)_{r,:}.$$

Therefore

$$\left(\left(I + L_{[t]}\right) U_{[t]}\right)_{r,:} = \left(B_{[t]} V_{[t]}\right)_{r,:}$$

for every $r = 1, \ldots, C$. Thus

$$\left(I + L_{[t]}\right) U_{[t]} = B_{[t]} V_{[t]}.$$

Therefore

$$U_{[t]} = \left(I + L_{[t]}\right)^{-1} B_{[t]} V_{[t]}.$$

Substituting the definition of $L_{[t]}$, we obtain

$$U_{[t]} = \left(I + \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T, -1\right)\right)^{-1} B_{[t]} V_{[t]}.$$

Using $T_{[t]}$, this can be written as

$$U_{[t]} = T_{[t]} V_{[t]}.$$

Hence we have the closed forms

$$W_{[t]} = T_{[t]} K_{[t]}, \qquad U_{[t]} = T_{[t]} V_{[t]}.$$

Plugging this into the expressions above yields

$$P_{[t]} = I - W_{[t]}^T K_{[t]} = I - K_{[t]}^T T_{[t]}^T K_{[t]}$$

and

$$H_{[t]} = U_{[t]}^T K_{[t]} = V_{[t]}^T T_{[t]}^T K_{[t]}.$$
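The closed forms can be checked against the recursive definitions numerically. A minimal NumPy sketch (dimensions and random data are arbitrary) that solves the triangular system instead of forming the inverse explicitly:

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 5
K = rng.normal(size=(C, d_k))
V = rng.normal(size=(C, d_v))
beta = rng.uniform(0.1, 1.0, size=C)
B = np.diag(beta)

# T = (I + tril(B K K^T, -1))^{-1} B, then W = T K and U = T V
L = np.tril(B @ K @ K.T, -1)
T = np.linalg.solve(np.eye(C) + L, B)
W = T @ K
U = T @ V

# Compare against the recursive definitions of w_r and u_r
W_rec = np.zeros((C, d_k))
U_rec = np.zeros((C, d_v))
for r in range(C):
    dots = K[:r] @ K[r]              # (k_i^T k_r) for i < r
    W_rec[r] = beta[r] * (K[r] - W_rec[:r].T @ dots)
    U_rec[r] = beta[r] * (V[r] - U_rec[:r].T @ dots)

assert np.allclose(W, W_rec)
assert np.allclose(U, U_rec)
```

Since $I + L_{[t]}$ is unit lower triangular, the solve is cheap and always well defined; in a kernel this is typically realised with a forward substitution over the chunk.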

Matrix State and Output Form

We have

$$S_{[t+1]} = S_{[t]} P_{[t]} + H_{[t]} = S_{[t]}\left(I - W_{[t]}^T K_{[t]}\right) + U_{[t]}^T K_{[t]}$$

We can expand the brackets and factor out $K_{[t]}$ to obtain

$$S_{[t+1]} = S_{[t]} + \left(U_{[t]}^T - S_{[t]} W_{[t]}^T\right) K_{[t]} = S_{[t]} + \left(U_{[t]} - W_{[t]} S_{[t]}^T\right)^T K_{[t]}$$

For the output,

$$O_{[t]} = Q_{[t]} S_{[t]}^T + \left(Q_{[t]} K_{[t]}^T \odot M\right)\left(U_{[t]} - W_{[t]} S_{[t]}^T\right) \in \mathbb{R}^{C \times d_v}$$

Compare this to the linear attention equations

$$S_{[t+1]} = S_{[t]} + V_{[t]}^T K_{[t]} \in \mathbb{R}^{d_v \times d_k}$$

and

$$O_{[t]} = Q_{[t]} S_{[t]}^T + \left(Q_{[t]} K_{[t]}^T \odot M\right) V_{[t]} \in \mathbb{R}^{C \times d_v}$$

where we see that, conceptually,

$$V_{[t]} \;\leftrightarrow\; U_{[t]} - W_{[t]} S_{[t]}^T$$
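The full chunkwise Delta Rule can now be verified end to end against the naive recurrence. Again a NumPy sketch with arbitrary dimensions, seed, and names, not an optimised kernel:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, C, d_k, d_v = 8, 4, 3, 5
Q = rng.normal(size=(seq_len, d_k))
K = rng.normal(size=(seq_len, d_k))
V = rng.normal(size=(seq_len, d_v))
beta = rng.uniform(0.1, 1.0, size=seq_len)

# Naive recurrence: S_t = S_{t-1}(I - beta_t k_t k_t^T) + beta_t v_t k_t^T
S = np.zeros((d_v, d_k))
O_ref = np.zeros((seq_len, d_v))
for t in range(seq_len):
    S = S @ (np.eye(d_k) - beta[t] * np.outer(K[t], K[t])) \
        + beta[t] * np.outer(V[t], K[t])
    O_ref[t] = S @ Q[t]

# Chunkwise form
M = np.tril(np.ones((C, C)))
S = np.zeros((d_v, d_k))
O = np.zeros((seq_len, d_v))
for t in range(0, seq_len, C):
    Qc, Kc, Vc = Q[t:t+C], K[t:t+C], V[t:t+C]
    B = np.diag(beta[t:t+C])
    T = np.linalg.solve(np.eye(C) + np.tril(B @ Kc @ Kc.T, -1), B)
    W, U = T @ Kc, T @ Vc
    # O_[t] = Q S^T + (Q K^T ⊙ M)(U - W S^T), then update the state
    O[t:t+C] = Qc @ S.T + (Qc @ Kc.T * M) @ (U - W @ S.T)
    S = S + (U - W @ S.T).T @ Kc

assert np.allclose(O, O_ref)
```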

Gated Delta Net

Gated Delta Net has the following update rule for the state:

$$S_t = S_{t-1}\left(\alpha_t \left(I - \beta_t k_t k_t^T\right)\right) + \beta_t v_t k_t^T \in \mathbb{R}^{d_v \times d_k}.$$

As above, we expand the state by applying the recursion to obtain the equation for an offset $r$ from the start of a chunk:

$$S_{[t]}^r = S_{[t]} \underbrace{\left(\prod_{i=1}^{r} \alpha_{[t]}^i \left(I - \beta_{[t]}^i k_{[t]}^i \left(k_{[t]}^i\right)^T\right)\right)}_{=:F_{[t]}^r} + \underbrace{\sum_{i=1}^{r} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r} \alpha_{[t]}^j \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right)}_{=:G_{[t]}^r} \in \mathbb{R}^{d_v \times d_k}.$$

Thus,

$$S_{[t]}^r = S_{[t]} F_{[t]}^r + G_{[t]}^r.$$

Cumulative gates

Define the cumulative gate

$$\gamma_{[t]}^r = \prod_{j=1}^{r} \alpha_{[t]}^j$$

and for $1 \le i \le r$

$$\Gamma_{[t]}^{r,i} = \frac{\gamma_{[t]}^r}{\gamma_{[t]}^i} = \prod_{j=i+1}^{r} \alpha_{[t]}^j.$$

By convention, $\Gamma_{[t]}^{r,r} = 1$, and we have

$$\gamma_{[t]}^r = \Gamma_{[t]}^{r,i} \, \gamma_{[t]}^i.$$

$F_{[t]}^r$

Since the gate factors are scalars, we can factor them out of the product and obtain

$$F_{[t]}^r = \gamma_{[t]}^r \prod_{i=1}^{r} \left(I - \beta_{[t]}^i k_{[t]}^i \left(k_{[t]}^i\right)^T\right).$$

The remaining product is exactly the Delta Rule term $P_{[t]}^r$, so

$$F_{[t]}^r = \gamma_{[t]}^r P_{[t]}^r.$$

Using the result from the Delta Rule section,

$$P_{[t]}^r = I - \sum_{i=1}^{r} w_{[t]}^i \left(k_{[t]}^i\right)^T,$$

where

$$w_{[t]}^1 = \beta_{[t]}^1 k_{[t]}^1,$$

and for $r > 1$

$$w_{[t]}^r = \beta_{[t]}^r \left(k_{[t]}^r - \sum_{i=1}^{r-1} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) w_{[t]}^i\right).$$

Therefore

$$F_{[t]}^r = \gamma_{[t]}^r \left(I - \sum_{i=1}^{r} w_{[t]}^i \left(k_{[t]}^i\right)^T\right).$$

$G_{[t]}^r$

Let's prove that

$$G_{[t]}^r = \sum_{i=1}^{r} \Gamma_{[t]}^{r,i} \, \tilde{u}_{[t]}^i \left(k_{[t]}^i\right)^T$$

by induction over $r \in \mathbb{N}_{>0}$.

$r = 1$:

$$G_{[t]}^1 = \sum_{i=1}^{1} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{1} \alpha_{[t]}^j \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right) = \beta_{[t]}^1 v_{[t]}^1 \left(k_{[t]}^1\right)^T.$$

Choose

$$\tilde{u}_{[t]}^1 = \beta_{[t]}^1 v_{[t]}^1$$

and the equation is fulfilled.

$r \to r+1$:

We first derive a recurrence for $G_{[t]}^{r+1}$:

$$G_{[t]}^{r+1} = \sum_{i=1}^{r+1} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r+1} \alpha_{[t]}^j \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right).$$

Split off the last term:

$$G_{[t]}^{r+1} = \sum_{i=1}^{r} \left(\beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r+1} \alpha_{[t]}^j \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T.$$

Factor out the last product term:

$$G_{[t]}^{r+1} = \left(\sum_{i=1}^{r} \beta_{[t]}^i v_{[t]}^i \left(k_{[t]}^i\right)^T \prod_{j=i+1}^{r} \alpha_{[t]}^j \left(I - \beta_{[t]}^j k_{[t]}^j \left(k_{[t]}^j\right)^T\right)\right) \alpha_{[t]}^{r+1} \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T.$$

Hence

$$G_{[t]}^{r+1} = G_{[t]}^{r} \, \alpha_{[t]}^{r+1} \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T.$$

Now plug in the induction hypothesis:

$$G_{[t]}^{r+1} = \left(\sum_{i=1}^{r} \Gamma_{[t]}^{r,i} \, \tilde{u}_{[t]}^i \left(k_{[t]}^i\right)^T\right) \alpha_{[t]}^{r+1} \left(I - \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T\right) + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T.$$

Distribute:

$$G_{[t]}^{r+1} = \sum_{i=1}^{r} \alpha_{[t]}^{r+1} \Gamma_{[t]}^{r,i} \, \tilde{u}_{[t]}^i \left(k_{[t]}^i\right)^T - \sum_{i=1}^{r} \alpha_{[t]}^{r+1} \Gamma_{[t]}^{r,i} \, \tilde{u}_{[t]}^i \left(k_{[t]}^i\right)^T \beta_{[t]}^{r+1} k_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T.$$

Since $\alpha_{[t]}^{r+1} \Gamma_{[t]}^{r,i} = \Gamma_{[t]}^{r+1,i}$ and $\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}$ is a scalar, this becomes

$$G_{[t]}^{r+1} = \sum_{i=1}^{r} \Gamma_{[t]}^{r+1,i} \, \tilde{u}_{[t]}^i \left(k_{[t]}^i\right)^T - \sum_{i=1}^{r} \beta_{[t]}^{r+1} \Gamma_{[t]}^{r+1,i} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) \tilde{u}_{[t]}^i \left(k_{[t]}^{r+1}\right)^T + \beta_{[t]}^{r+1} v_{[t]}^{r+1} \left(k_{[t]}^{r+1}\right)^T.$$

Group the last two terms together:

$$G_{[t]}^{r+1} = \sum_{i=1}^{r} \Gamma_{[t]}^{r+1,i} \, \tilde{u}_{[t]}^i \left(k_{[t]}^i\right)^T + \beta_{[t]}^{r+1} \left(v_{[t]}^{r+1} - \sum_{i=1}^{r} \Gamma_{[t]}^{r+1,i} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) \tilde{u}_{[t]}^i\right) \left(k_{[t]}^{r+1}\right)^T.$$

Define

$$\tilde{u}_{[t]}^{r+1} := \beta_{[t]}^{r+1} \left(v_{[t]}^{r+1} - \sum_{i=1}^{r} \Gamma_{[t]}^{r+1,i} \left(\left(k_{[t]}^i\right)^T k_{[t]}^{r+1}\right) \tilde{u}_{[t]}^i\right).$$

Then, using $\Gamma_{[t]}^{r+1,r+1} = 1$, we obtain

$$G_{[t]}^{r+1} = \sum_{i=1}^{r+1} \Gamma_{[t]}^{r+1,i} \, \tilde{u}_{[t]}^i \left(k_{[t]}^i\right)^T.$$

This is exactly the desired form, so the induction step is proved.

Moreover, the proof gives a recursive way to compute the vectors $\tilde{u}_{[t]}^r$: start with

$$\tilde{u}_{[t]}^1 = \beta_{[t]}^1 v_{[t]}^1,$$

and for $r > 1$ compute

$$\tilde{u}_{[t]}^r = \beta_{[t]}^r \left(v_{[t]}^r - \sum_{i=1}^{r-1} \Gamma_{[t]}^{r,i} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) \tilde{u}_{[t]}^i\right) \in \mathbb{R}^{d_v}.$$

Hence, for every $r \ge 1$, we can write

$$G_{[t]}^r = \sum_{i=1}^{r} \Gamma_{[t]}^{r,i} \, \tilde{u}_{[t]}^i \left(k_{[t]}^i\right)^T \in \mathbb{R}^{d_v \times d_k}.$$

Simple vectorised expression

Substituting the expressions for $F_{[t]}^r$ and $G_{[t]}^r$ gives

$$S_{[t]}^r = \gamma_{[t]}^r S_{[t]} + \sum_{i=1}^{r} \Gamma_{[t]}^{r,i} \left(\tilde{u}_{[t]}^i - \gamma_{[t]}^i S_{[t]} w_{[t]}^i\right) \left(k_{[t]}^i\right)^T.$$

Matrix equation

As in the Delta Rule section, define

$$W_{[t]} = \left[\left(w_{[t]}^1\right)^T, \ldots, \left(w_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_k}, \qquad K_{[t]} = \left[\left(k_{[t]}^1\right)^T, \ldots, \left(k_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_k}, \qquad V_{[t]} = \left[\left(v_{[t]}^1\right)^T, \ldots, \left(v_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_v},$$

and

$$B_{[t]} := \mathrm{Diag}\left(\beta_{[t]}^1, \ldots, \beta_{[t]}^C\right) \in \mathbb{R}^{C \times C}.$$

The matrix $W_{[t]}$ is unchanged from the Delta Rule case, so

$$W_{[t]} = \left(I + \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T, -1\right)\right)^{-1} B_{[t]} K_{[t]}.$$

Now define

$$\tilde{U}_{[t]} = \left[\left(\tilde{u}_{[t]}^1\right)^T, \ldots, \left(\tilde{u}_{[t]}^C\right)^T\right]^T \in \mathbb{R}^{C \times d_v}.$$

Also define the matrix $\Gamma_{[t]} \in \mathbb{R}^{C \times C}$ by

$$\left(\Gamma_{[t]}\right)_{r,i} = \begin{cases} \gamma_{[t]}^r / \gamma_{[t]}^i, & i < r, \\ 0, & i \ge r. \end{cases}$$

To see how the recurrence for $\tilde{u}_{[t]}^r$ leads to a matrix equation, write the recurrence again:

$$\tilde{u}_{[t]}^r = \beta_{[t]}^r \left(v_{[t]}^r - \sum_{i=1}^{r-1} \Gamma_{[t]}^{r,i} \left(\left(k_{[t]}^i\right)^T k_{[t]}^r\right) \tilde{u}_{[t]}^i\right).$$

As above, the $r$-th row of $\tilde{U}_{[t]}$ is

$$\left(\tilde{U}_{[t]}\right)_{r,:} = \left(\tilde{u}_{[t]}^r\right)^T,$$

and the $r$-th row of $V_{[t]}$ is

$$\left(V_{[t]}\right)_{r,:} = \left(v_{[t]}^r\right)^T.$$

Moreover,

$$\left(K_{[t]} K_{[t]}^T\right)_{r,i} = \left(k_{[t]}^r\right)^T k_{[t]}^i = \left(k_{[t]}^i\right)^T k_{[t]}^r.$$

Hence the recurrence is equivalent, row by row, to

$$\left(\tilde{U}_{[t]}\right)_{r,:} = \beta_{[t]}^r \left(V_{[t]}\right)_{r,:} - \beta_{[t]}^r \sum_{i=1}^{r-1} \Gamma_{[t]}^{r,i} \left(K_{[t]} K_{[t]}^T\right)_{r,i} \left(\tilde{U}_{[t]}\right)_{i,:}.$$

Now define

$$\tilde{L}_{[t]} := \mathrm{tril}\left(B_{[t]} \left(\Gamma_{[t]} \odot K_{[t]} K_{[t]}^T\right), -1\right) \in \mathbb{R}^{C \times C}.$$

Its entries are

$$\left(\tilde{L}_{[t]}\right)_{r,i} = \begin{cases} \beta_{[t]}^r \Gamma_{[t]}^{r,i} \left(k_{[t]}^i\right)^T k_{[t]}^r, & i < r, \\ 0, & i \ge r. \end{cases}$$

Therefore

$$\beta_{[t]}^r \sum_{i=1}^{r-1} \Gamma_{[t]}^{r,i} \left(K_{[t]} K_{[t]}^T\right)_{r,i} \left(\tilde{U}_{[t]}\right)_{i,:} = \sum_{i=1}^{C} \left(\tilde{L}_{[t]}\right)_{r,i} \left(\tilde{U}_{[t]}\right)_{i,:} = \left(\tilde{L}_{[t]} \tilde{U}_{[t]}\right)_{r,:}.$$

Thus

$$\left(\tilde{U}_{[t]}\right)_{r,:} + \left(\tilde{L}_{[t]} \tilde{U}_{[t]}\right)_{r,:} = \left(B_{[t]} V_{[t]}\right)_{r,:}$$

for every $r = 1, \ldots, C$. Hence

$$\left(I + \tilde{L}_{[t]}\right) \tilde{U}_{[t]} = B_{[t]} V_{[t]}.$$

Solving for $\tilde{U}_{[t]}$ gives

$$\tilde{U}_{[t]} = \left(I + \tilde{L}_{[t]}\right)^{-1} B_{[t]} V_{[t]} = \left(I + \mathrm{tril}\left(B_{[t]} \left(\Gamma_{[t]} \odot K_{[t]} K_{[t]}^T\right), -1\right)\right)^{-1} B_{[t]} V_{[t]} \in \mathbb{R}^{C \times d_v}.$$
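The gated closed form can again be checked against the recursion. A small NumPy sketch (arbitrary dimensions, random gates kept strictly positive so the $\gamma$ ratios are well defined):

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 5
K = rng.normal(size=(C, d_k))
V = rng.normal(size=(C, d_v))
beta = rng.uniform(0.1, 1.0, size=C)
alpha = rng.uniform(0.5, 1.0, size=C)
gamma = np.cumprod(alpha)                 # gamma_r = prod_{j<=r} alpha_j
B = np.diag(beta)

# Gamma_{r,i} = gamma_r / gamma_i for i < r, zero otherwise
Gamma = np.tril(gamma[:, None] / gamma[None, :], -1)

# U~ = (I + tril(B (Gamma ⊙ K K^T), -1))^{-1} B V
L_tilde = np.tril(B @ (Gamma * (K @ K.T)), -1)
U_tilde = np.linalg.solve(np.eye(C) + L_tilde, B @ V)

# Compare against the recursive definition of u~_r
U_rec = np.zeros((C, d_v))
for r in range(C):
    coef = (gamma[r] / gamma[:r]) * (K[:r] @ K[r])  # Gamma_{r,i} (k_i^T k_r)
    U_rec[r] = beta[r] * (V[r] - U_rec[:r].T @ coef)

assert np.allclose(U_tilde, U_rec)
```

In practice the $\gamma$ ratios are computed in log space for numerical stability, since products of many gates in $(0,1)$ underflow quickly; the plain ratios here are fine for a tiny chunk.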

Matrix State and Output Form

Following the paper on Gated Delta Net, define the gate-rescaled quantities

$$\bar{q}_{[t]}^r = \gamma_{[t]}^r q_{[t]}^r, \qquad \bar{w}_{[t]}^r = \gamma_{[t]}^r w_{[t]}^r, \qquad \bar{k}_{[t]}^r = \frac{\gamma_{[t]}^C}{\gamma_{[t]}^r} k_{[t]}^r, \qquad \bar{S}_{[t]} = \gamma_{[t]}^C S_{[t]}.$$

Let $\bar{Q}_{[t]}$, $\bar{W}_{[t]}$, and $\bar{K}_{[t]}$ be the row-wise matrix forms of these vectors. Then the hardware-efficient chunkwise state update is

$$S_{[t+1]} = \bar{S}_{[t]} + \left(\tilde{U}_{[t]} - \bar{W}_{[t]} S_{[t]}^T\right)^T \bar{K}_{[t]} \in \mathbb{R}^{d_v \times d_k}.$$

For the output we obtain

$$O_{[t]} = \bar{Q}_{[t]} S_{[t]}^T + \left(\bar{Q}_{[t]} \bar{K}_{[t]}^T \odot M\right)\left(\tilde{U}_{[t]} - \bar{W}_{[t]} S_{[t]}^T\right) \in \mathbb{R}^{C \times d_v}.$$

Compare this with the DeltaNet equations

$$S_{[t+1]} = S_{[t]} + \left(U_{[t]} - W_{[t]} S_{[t]}^T\right)^T K_{[t]}$$

and

$$O_{[t]} = Q_{[t]} S_{[t]}^T + \left(Q_{[t]} K_{[t]}^T \odot M\right)\left(U_{[t]} - W_{[t]} S_{[t]}^T\right).$$

We see that Gated DeltaNet keeps the same chunkwise structure, but replaces the un-gated quantities by the gate-rescaled forms $\bar{S}_{[t]}$, $\bar{Q}_{[t]}$, $\bar{W}_{[t]}$, $\bar{K}_{[t]}$, and the UT-transformed matrix $\tilde{U}_{[t]}$.

When $\alpha_t = 1$ for all $t$, we have $\gamma_{[t]}^r = 1$, hence

$$\bar{Q}_{[t]} = Q_{[t]}, \qquad \bar{W}_{[t]} = W_{[t]}, \qquad \bar{K}_{[t]} = K_{[t]}, \qquad \bar{S}_{[t]} = S_{[t]},$$

and $\Gamma_{[t]}$ reduces to the strictly lower-triangular all-ones pattern, so the equations reduce to the DeltaNet chunkwise formulation.
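Putting everything together, the gated chunkwise algorithm can be verified end to end against the naive gated recurrence. This sketch applies the $\gamma$/$\Gamma$ within-chunk expression directly rather than the bar-rescaled quantities; it is an illustrative reference, not a hardware-efficient kernel, and all dimensions, seeds, and names are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, C, d_k, d_v = 8, 4, 3, 5
Q = rng.normal(size=(seq_len, d_k))
K = rng.normal(size=(seq_len, d_k))
V = rng.normal(size=(seq_len, d_v))
beta = rng.uniform(0.1, 1.0, size=seq_len)
alpha = rng.uniform(0.5, 1.0, size=seq_len)

# Naive recurrence: S_t = S_{t-1} alpha_t (I - beta_t k_t k_t^T) + beta_t v_t k_t^T
S = np.zeros((d_v, d_k))
O_ref = np.zeros((seq_len, d_v))
for t in range(seq_len):
    S = alpha[t] * (S @ (np.eye(d_k) - beta[t] * np.outer(K[t], K[t]))) \
        + beta[t] * np.outer(V[t], K[t])
    O_ref[t] = S @ Q[t]

# Chunkwise form
M = np.tril(np.ones((C, C)))
S = np.zeros((d_v, d_k))
O = np.zeros((seq_len, d_v))
for t in range(0, seq_len, C):
    Qc, Kc, Vc = Q[t:t+C], K[t:t+C], V[t:t+C]
    b = beta[t:t+C]
    g = np.cumprod(alpha[t:t+C])          # gamma_r within the chunk
    B = np.diag(b)
    # un-gated W from the Delta Rule section
    T = np.linalg.solve(np.eye(C) + np.tril(B @ Kc @ Kc.T, -1), B)
    W = T @ Kc
    # gated U~
    Gamma = np.tril(g[:, None] / g[None, :], -1)
    U_tilde = np.linalg.solve(
        np.eye(C) + np.tril(B @ (Gamma * (Kc @ Kc.T)), -1), B @ Vc)
    # S^r = gamma_r S + sum_{i<=r} Gamma_{r,i} (u~_i - gamma_i S w_i) k_i^T
    A = U_tilde - (g[:, None] * W) @ S.T  # row i: u~_i - gamma_i S w_i
    G_full = Gamma + np.eye(C)            # include Gamma_{r,r} = 1
    O[t:t+C] = g[:, None] * (Qc @ S.T) + (Qc @ Kc.T * G_full * M) @ A
    S = g[-1] * S + (G_full[-1][:, None] * A).T @ Kc

assert np.allclose(O, O_ref)
```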

Conclusion

I hope this blog post makes the calculations involved in the chunkwise formulation of the various Linear Attention variants more accessible. Please check out the paper that introduces Gated Delta Net for more information on the Gated Delta Rule. If you'd like to connect or exchange ideas, you can reach me on LinkedIn or X.