
Simple math to speed up GDN prefill

In this short note we will briefly derive a helpful identity to speed up the GDN prefill algorithm. In my simple Torch implementation of the algorithm this already gave a good speedup of about 18%. I assume that for a custom CUDA C++ kernel the gains can be even more pronounced.

Please read my previous post on GDN for background.

Reminder

The state transition for GDN is as follows:

$$S_t = S_{t-1}\left(\alpha_t \left(I - \beta_t k_t k_t^T\right)\right) + \beta_t v_t k_t^T \in \mathbb{R}^{d_v \times d_k}.$$
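As a concrete reference, here is a minimal Torch sketch of one step of this recurrence; all names and shapes are illustrative assumptions, not taken from any particular implementation:

```python
import torch

d_k, d_v = 8, 4
S = torch.zeros(d_v, d_k)                       # state S_{t-1}

# assumed per-token quantities: scalar gates and key/value vectors
alpha_t, beta_t = 0.9, 0.5
k_t, v_t = torch.randn(d_k), torch.randn(d_v)

# S_t = S_{t-1} (alpha_t (I - beta_t k_t k_t^T)) + beta_t v_t k_t^T
I = torch.eye(d_k)
S = S @ (alpha_t * (I - beta_t * torch.outer(k_t, k_t))) \
    + beta_t * torch.outer(v_t, k_t)
```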

Define the cumulative gate

$$\gamma_{[t]}^{r} = \prod_{j=1}^{r} \alpha_{[t]}^{j}$$

and for $1 \le i \le r$

$$\Gamma_{[t]}^{r,i} = \frac{\gamma_{[t]}^{r}}{\gamma_{[t]}^{i}} = \prod_{j=i+1}^{r} \alpha_{[t]}^{j}.$$
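In code, the cumulative gates of a chunk are just a cumulative product. A minimal sketch, assuming a chunk length `C` and per-token gates `alpha`:

```python
import torch

C = 64                                    # chunk length (assumed)
alpha = torch.rand(C)                     # per-token gates alpha_[t]^j

gamma = torch.cumprod(alpha, dim=0)       # gamma_[t]^r = prod_{j<=r} alpha_[t]^j
Gamma = gamma[:, None] / gamma[None, :]   # Gamma_[t]^{r,i} = gamma_r / gamma_i
```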

We obtained the following chunkwise transition rule

$$S_{[t+1]} = S_{[t]} + \left(\tilde{U}_{[t]} - W_{[t]} S_{[t]}^T\right)^T K_{[t]} \in \mathbb{R}^{d_v \times d_k}.$$

We had (up to a decay factor)

$$W_{[t]} = \left(I + \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T,\, -1\right)\right)^{-1} B_{[t]} K_{[t]}$$

and

$$\tilde{U}_{[t]} = \left(I + \tilde{L}_{[t]}\right)^{-1} B_{[t]} V_{[t]} = \left(I + \mathrm{tril}\left(B_{[t]}\left(\Gamma_{[t]} \odot K_{[t]} K_{[t]}^T\right),\, -1\right)\right)^{-1} B_{[t]} V_{[t]} \in \mathbb{R}^{C \times d_v}.$$
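Putting these together, a naive Torch sketch of the chunkwise quantities might look as follows, ignoring the decay factor mentioned above and assuming shapes `K: (C, d_k)`, `V: (C, d_v)`, `beta: (C,)` and cumulative gates `gamma: (C,)`:

```python
import torch

def chunk_w_u_naive(K, V, beta, gamma):
    """Compute W_[t] and U~_[t] with two matrix inverses (the baseline)."""
    C = K.shape[0]
    I = torch.eye(C, dtype=K.dtype, device=K.device)
    BKKt = beta[:, None] * (K @ K.T)          # B_[t] K_[t] K_[t]^T
    Gamma = gamma[:, None] / gamma[None, :]   # cumulative-gate ratios

    M = I + torch.tril(BKKt, diagonal=-1)
    N = I + torch.tril(Gamma * BKKt, diagonal=-1)

    W = torch.linalg.inv(M) @ (beta[:, None] * K)   # first inverse
    U = torch.linalg.inv(N) @ (beta[:, None] * V)   # second inverse
    return W, U
```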

We see that there are two matrix inverses involved in these computations. Inversion is potentially one of the bottlenecks in the computation of the chunkwise transition, so it would be helpful if we could save one of them.

Save one inverse

Consider

$$M = I + \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T,\, -1\right)$$

and

$$N = I + \mathrm{tril}\left(B_{[t]}\left(\Gamma_{[t]} \odot K_{[t]} K_{[t]}^T\right),\, -1\right).$$

Let $K_{[t]} \in \mathbb{R}^{C \times d_k}$ and denote by $k_\mu$ the $\mu$-th row of $K_{[t]}$. Then

$$\left(K_{[t]} K_{[t]}^T\right)_{\mu\nu} = k_\mu^T k_\nu.$$

Since $B_{[t]}$ is diagonal, we write $B_{[t]} = \mathrm{diag}(\beta_1, \ldots, \beta_C)$.
For $\mu > \nu$ the strict lower triangular part of $M$ therefore has entries

$$M_{\mu\nu} = \beta_\mu\, k_\mu^T k_\nu, \qquad M_{\mu\mu} = 1.$$

Similarly, for $N$ we obtain

$$N_{\mu\nu} = \beta_\mu\, \Gamma_{[t]}^{\mu,\nu}\, k_\mu^T k_\nu, \qquad N_{\mu\mu} = 1.$$
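These entrywise formulas are easy to sanity-check numerically. A small sketch with random data and assumed sizes:

```python
import torch

C, d_k = 4, 8
K = torch.randn(C, d_k, dtype=torch.float64)
beta = torch.rand(C, dtype=torch.float64)
gamma = torch.cumprod(torch.rand(C, dtype=torch.float64), dim=0)

I = torch.eye(C, dtype=torch.float64)
BKKt = beta[:, None] * (K @ K.T)
Gamma = gamma[:, None] / gamma[None, :]

M = I + torch.tril(BKKt, diagonal=-1)
N = I + torch.tril(Gamma * BKKt, diagonal=-1)

M_ref, N_ref = I.clone(), I.clone()
for mu in range(C):
    for nu in range(mu):                       # mu > nu: strict lower triangle
        M_ref[mu, nu] = beta[mu] * (K[mu] @ K[nu])
        N_ref[mu, nu] = beta[mu] * Gamma[mu, nu] * (K[mu] @ K[nu])

assert torch.allclose(M, M_ref) and torch.allclose(N, N_ref)
```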

We see that these already look highly similar; we will make this more explicit below.

Plugging in the definition of $\Gamma$ we obtain

$$\Gamma_{[t]}^{\mu,\nu} = \frac{\gamma_{[t]}^{\mu}}{\gamma_{[t]}^{\nu}}$$

and therefore

$$N_{\mu\nu} = \beta_\mu\, \frac{\gamma_{[t]}^{\mu}}{\gamma_{[t]}^{\nu}}\, k_\mu^T k_\nu.$$

Let us factor out the $\gamma$ factors. Define

$$G = \mathrm{diag}\left(\gamma_{[t]}^{1}, \ldots, \gamma_{[t]}^{C}\right), \qquad G^{-1} = \mathrm{diag}\left(\left(\gamma_{[t]}^{1}\right)^{-1}, \ldots, \left(\gamma_{[t]}^{C}\right)^{-1}\right).$$

For any matrix $A$ we have

$$\left(G A G^{-1}\right)_{\mu\nu} = \gamma_{[t]}^{\mu}\, A_{\mu\nu}\, \left(\gamma_{[t]}^{\nu}\right)^{-1}.$$

Thus multiplying a matrix by $G$ on the left and by $G^{-1}$ on the right multiplies the $(\mu,\nu)$ entry by $\gamma_{[t]}^{\mu} \left(\gamma_{[t]}^{\nu}\right)^{-1}$.
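This scaling identity can be confirmed in one line; a quick sketch with random data:

```python
import torch

C = 5
A = torch.randn(C, C, dtype=torch.float64)
gamma = torch.rand(C, dtype=torch.float64)

G, G_inv = torch.diag(gamma), torch.diag(1.0 / gamma)

# entry (mu, nu) of G A G^{-1} is gamma_mu * A_{mu,nu} / gamma_nu
assert torch.allclose(G @ A @ G_inv, gamma[:, None] * A / gamma[None, :])
```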

Applying this observation to $A = \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T,\, -1\right)$, and noting that conjugation by a diagonal matrix preserves the strictly lower triangular zero pattern, yields

$$\mathrm{tril}\left(B_{[t]}\left(\Gamma_{[t]} \odot K_{[t]} K_{[t]}^T\right),\, -1\right) = G\, \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T,\, -1\right) G^{-1}.$$

Therefore we can write

$$N = I + G\, \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T,\, -1\right) G^{-1}.$$

Since $I = G I G^{-1}$, we obtain

$$N = G\left(I + \mathrm{tril}\left(B_{[t]} K_{[t]} K_{[t]}^T,\, -1\right)\right) G^{-1} = G M G^{-1}.$$

Taking the inverse gives

$$N^{-1} = \left(G M G^{-1}\right)^{-1} = G\, M^{-1} G^{-1}.$$
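Both identities are easy to verify numerically. A sketch with random data and assumed sizes:

```python
import torch

torch.manual_seed(0)
C, d_k = 6, 8
K = torch.randn(C, d_k, dtype=torch.float64)
beta = torch.rand(C, dtype=torch.float64)
gamma = torch.cumprod(torch.rand(C, dtype=torch.float64), dim=0)

I = torch.eye(C, dtype=torch.float64)
BKKt = beta[:, None] * (K @ K.T)
Gamma = gamma[:, None] / gamma[None, :]

M = I + torch.tril(BKKt, diagonal=-1)
N = I + torch.tril(Gamma * BKKt, diagonal=-1)
G, G_inv = torch.diag(gamma), torch.diag(1.0 / gamma)

assert torch.allclose(N, G @ M @ G_inv)                    # N = G M G^{-1}
assert torch.allclose(torch.linalg.inv(N),
                      G @ torch.linalg.inv(M) @ G_inv)     # N^{-1} = G M^{-1} G^{-1}
```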

Thus the inverse $N^{-1}$ appearing in $\tilde{U}_{[t]}$ can be expressed using the same inverse $M^{-1}$ together with two diagonal matrix multiplications. This allows us to reuse the computed inverse and avoid a second matrix inversion.
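In code this means both chunkwise quantities can be computed from a single inverse, with `G` and `G^{-1}` acting as cheap elementwise row scalings. A sketch mirroring the naive version above (same assumed shapes):

```python
import torch

def chunk_w_u_one_inverse(K, V, beta, gamma):
    """Compute W_[t] and U~_[t] reusing one inverse via N^{-1} = G M^{-1} G^{-1}."""
    C = K.shape[0]
    I = torch.eye(C, dtype=K.dtype, device=K.device)
    M = I + torch.tril(beta[:, None] * (K @ K.T), diagonal=-1)
    M_inv = torch.linalg.inv(M)                 # the only matrix inverse

    W = M_inv @ (beta[:, None] * K)
    # U~ = N^{-1} B V = G (M^{-1} (G^{-1} B V)); G, G^{-1} are row scalings
    U = gamma[:, None] * (M_inv @ ((beta / gamma)[:, None] * V))
    return W, U
```

Compared to the naive version, the second inverse is replaced by two elementwise row scalings, which is where the speedup comes from.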

Conclusion

In this short note we have seen how a bit of simple math can speed up parallel algorithms in deep learning. Sometimes it pays to look carefully at the equations and bring them into their simplest form.

The observation of how to "save the inverse" was first made in Comba. Please take a look at their paper for an alternative derivation.