simons blog

Backprop through LayerNorm

Equation for LayerNorm:

$$y_i = \frac{x_i - m}{\sqrt{v + \epsilon}}\, w_i = \hat{x}_i\, w_i$$

where $m = \frac{1}{N}\sum_{i=1}^N x_i$, $v = \frac{1}{N}\sum_{i=1}^N (x_i - m)^2$, and $\hat{x}_i = \frac{x_i - m}{\sqrt{v + \epsilon}}$.

As a computational graph: *(figure: computational graph of LayerNorm)*
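As a concrete reference before diving into the backward pass, the forward pass can be sketched in NumPy. This is my own sketch; the function name `layernorm_forward` and the optional bias `b` (which the backward code at the end also handles) are assumptions, not from the derivation above:

```python
import numpy as np

def layernorm_forward(x, w, b, eps=1e-5):
    # x: (N,) vector; w, b: per-feature scale and shift (b is assumed here)
    m = x.mean()                       # m = (1/N) * sum_i x_i
    v = ((x - m) ** 2).mean()          # v = (1/N) * sum_i (x_i - m)^2 (biased variance)
    norm = (x - m) / np.sqrt(v + eps)  # x_hat_i
    y = norm * w + b                   # y_i = x_hat_i * w_i + b_i
    return y, (x, w, m, v, norm)

x = np.array([1.0, 2.0, 4.0, 7.0])
w = np.ones(4); b = np.zeros(4)
y, cache = layernorm_forward(x, w, b)
# with w = 1, b = 0, y is just x_hat: zero mean, near-unit variance
```

Caching `(x, w, m, v, norm)` is what makes the backward pass below possible without re-running the forward computation.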

$\frac{\partial L}{\partial y_i}$ is given (the upstream gradient).

$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot w_i$$

To get the gradient of $v$, sum over all paths from $v$ to the loss, i.e. through every $\hat{x}_j$:

$$\frac{\partial L}{\partial v} = \sum_j \frac{\partial L}{\partial \hat{x}_j} \cdot \frac{\partial \hat{x}_j}{\partial v}, \qquad \frac{\partial \hat{x}_j}{\partial v} = (x_j - m)\,\frac{\partial}{\partial v}(v + \epsilon)^{-1/2} = -\frac{1}{2}\,(x_j - m)\,(v + \epsilon)^{-3/2}$$

Plugging in:

$$\frac{\partial L}{\partial v} = -\frac{1}{2}\,(v + \epsilon)^{-3/2}\sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot (x_j - m)$$
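A quick finite-difference sanity check of this expression (my own sketch: $g_j$ stands in for the upstream gradient $\frac{\partial L}{\partial y_j}$, and $v$ is treated as an independent input to $\hat{x}$):

```python
import numpy as np

rng = np.random.default_rng(0)
N, eps = 8, 1e-5
x = rng.normal(size=N); w = rng.normal(size=N); g = rng.normal(size=N)  # g_j = dL/dy_j
m = x.mean()
v = ((x - m) ** 2).mean()

def L_of_v(v_):
    # loss as a function of v alone: L = sum_j g_j * w_j * (x_j - m) / sqrt(v + eps)
    return np.sum(g * w * (x - m) / np.sqrt(v_ + eps))

# analytic: dL/dv = -1/2 * (v + eps)^(-3/2) * sum_j g_j * w_j * (x_j - m)
dLdv = -0.5 * (v + eps) ** -1.5 * np.sum(g * w * (x - m))
# central difference
h = 1e-6
dLdv_num = (L_of_v(v + h) - L_of_v(v - h)) / (2 * h)
assert abs(dLdv - dLdv_num) < 1e-5
```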

$m$ has incoming paths from both $\hat{x}$ and $v$. We sum over the paths through the $\hat{x}_j$ and add the path through $v$.

$$\frac{\partial L}{\partial m} = \frac{\partial L}{\partial v} \cdot \frac{\partial v}{\partial m} + \sum_j \frac{\partial L}{\partial \hat{x}_j} \cdot \frac{\partial \hat{x}_j}{\partial m}$$

$$\frac{\partial v}{\partial m} = \frac{1}{N}\,\frac{\partial}{\partial m}\sum_j (x_j - m)^2 = -\frac{2}{N}\sum_j (x_j - m) = -\frac{2}{N}\Big(\sum_j x_j - N \cdot \frac{1}{N}\sum_j x_j\Big) = 0$$

$$\frac{\partial \hat{x}_j}{\partial m} = \frac{\partial}{\partial m}\,\frac{x_j - m}{\sqrt{v + \epsilon}} = -\frac{1}{\sqrt{v + \epsilon}}$$

Plugging in:

$$\frac{\partial L}{\partial m} = -\frac{1}{(v + \epsilon)^{1/2}}\sum_j \frac{\partial L}{\partial y_j} \cdot w_j$$
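The same finite-difference check works for $m$ (again my own sketch with $g_j = \frac{\partial L}{\partial y_j}$; $v$ is held fixed because, as just shown, $\frac{\partial v}{\partial m} = 0$):

```python
import numpy as np

rng = np.random.default_rng(1)
N, eps = 8, 1e-5
x = rng.normal(size=N); w = rng.normal(size=N); g = rng.normal(size=N)  # g_j = dL/dy_j
m = x.mean()
v = ((x - m) ** 2).mean()

def L_of_m(m_):
    # loss as a function of m alone; v is held fixed since dv/dm = 0
    return np.sum(g * w * (x - m_) / np.sqrt(v + eps))

# analytic: dL/dm = -(v + eps)^(-1/2) * sum_j g_j * w_j
dLdm = -(v + eps) ** -0.5 * np.sum(g * w)
h = 1e-6
dLdm_num = (L_of_m(m + h) - L_of_m(m - h)) / (2 * h)
assert abs(dLdm - dLdm_num) < 1e-5
```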

Final backprop into $x_i$: incoming paths from $m$, $v$ and $\hat{x}_i$.

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial m} \cdot \frac{\partial m}{\partial x_i} + \frac{\partial L}{\partial v} \cdot \frac{\partial v}{\partial x_i} + \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial x_i}$$

We calculate the three terms separately.

First term:

$$\frac{\partial m}{\partial x_i} = \frac{\partial}{\partial x_i}\,\frac{1}{N}\sum_j x_j = \frac{1}{N}, \qquad \frac{\partial L}{\partial m} \cdot \frac{\partial m}{\partial x_i} = -\frac{1}{(v + \epsilon)^{1/2}} \cdot \frac{1}{N}\sum_j \frac{\partial L}{\partial y_j} \cdot w_j$$

Second term:

$$\frac{\partial v}{\partial x_i} = \frac{1}{N}\,\frac{\partial}{\partial x_i}\sum_j (x_j - m)^2 = \frac{2}{N}\,(x_i - m), \qquad \frac{\partial L}{\partial v} \cdot \frac{\partial v}{\partial x_i} = -\frac{1}{(v + \epsilon)^{3/2}} \cdot \frac{1}{N}\,(x_i - m)\sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot (x_j - m)$$

We can simplify this using the definition of $\hat{x}$:

$$\frac{\partial L}{\partial v} \cdot \frac{\partial v}{\partial x_i} = -\frac{1}{(v + \epsilon)^{1/2}} \cdot \frac{1}{N}\,\hat{x}_i\sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot \hat{x}_j$$

Last term:

$$\frac{\partial \hat{x}_i}{\partial x_i} = \frac{\partial}{\partial x_i}\,\frac{x_i - m}{\sqrt{v + \epsilon}} = \frac{1}{\sqrt{v + \epsilon}}, \qquad \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial x_i} = \frac{\partial L}{\partial y_i} \cdot w_i \cdot \frac{1}{(v + \epsilon)^{1/2}}$$

To summarise, we have the three terms:

$$\frac{\partial L}{\partial m} \cdot \frac{\partial m}{\partial x_i} = -\frac{1}{(v + \epsilon)^{1/2}} \cdot \frac{1}{N}\sum_j \frac{\partial L}{\partial y_j} \cdot w_j$$

$$\frac{\partial L}{\partial v} \cdot \frac{\partial v}{\partial x_i} = -\frac{1}{(v + \epsilon)^{1/2}} \cdot \frac{1}{N}\,\hat{x}_i\sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot \hat{x}_j$$

$$\frac{\partial L}{\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial x_i} = \frac{\partial L}{\partial y_i} \cdot w_i \cdot \frac{1}{(v + \epsilon)^{1/2}}$$

Now sum these up to obtain final result:

$$\frac{\partial L}{\partial x_i} = \frac{1}{(v + \epsilon)^{1/2}}\left[\frac{\partial L}{\partial y_i} \cdot w_i - \frac{1}{N}\sum_j \frac{\partial L}{\partial y_j} \cdot w_j - \hat{x}_i \cdot \frac{1}{N}\sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot \hat{x}_j\right]$$
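The closed-form result can be verified end to end against finite differences of the full LayerNorm (my own check; $g$ again stands in for the upstream gradient, and the scalar loss is $L = \sum_i g_i y_i$):

```python
import numpy as np

rng = np.random.default_rng(2)
N, eps = 8, 1e-5
x = rng.normal(size=N); w = rng.normal(size=N); g = rng.normal(size=N)  # g_i = dL/dy_i

def loss(x_):
    # full LayerNorm followed by contraction with the upstream gradient g
    m = x_.mean(); v = ((x_ - m) ** 2).mean()
    y = (x_ - m) / np.sqrt(v + eps) * w
    return np.sum(g * y)

# closed-form gradient from the derivation above
m = x.mean(); v = ((x - m) ** 2).mean()
xhat = (x - m) / np.sqrt(v + eps)
gw = g * w  # dL/dy_i * w_i
dx = (gw - gw.mean() - xhat * (gw * xhat).mean()) / np.sqrt(v + eps)

# per-coordinate central differences
h = 1e-6
dx_num = np.zeros(N)
for i in range(N):
    e = np.zeros(N); e[i] = h
    dx_num[i] = (loss(x + e) - loss(x - e)) / (2 * h)
assert np.allclose(dx, dx_num, atol=1e-6)
```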

This is exactly the formula implemented in the backward pass below.

    @staticmethod
    def backward(dout, cache):
        # dout: upstream gradient dL/dy of shape (B, T, C); cache from the forward pass
        x, w, mean, rstd = cache
        # recompute the norm (save memory at the cost of compute)
        norm = (x - mean) * rstd  # x_hat
        # gradients for weight and bias, summed over batch and time dims
        db = dout.sum((0, 1))
        dw = (dout * norm).sum((0, 1))
        # gradient for the input: dnorm_i = dL/dy_i * w_i, then the three-term formula
        dnorm = dout * w
        dx = dnorm - dnorm.mean(-1, keepdim=True) - norm * (dnorm * norm).mean(-1, keepdim=True)
        dx *= rstd  # the common 1/sqrt(v + eps) factor
        return dx, dw, db
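To see the method run on a batched input, here is a NumPy translation of the pair of passes (torch's `keepdim` becomes `keepdims`; the harness names and shapes are my own), with a spot-check of one `dx` entry against a central difference:

```python
import numpy as np

def forward(x, w, b, eps=1e-5):
    mean = x.mean(-1, keepdims=True)
    var = ((x - mean) ** 2).mean(-1, keepdims=True)
    rstd = 1.0 / np.sqrt(var + eps)
    norm = (x - mean) * rstd
    return norm * w + b, (x, w, mean, rstd)

def backward(dout, cache):
    x, w, mean, rstd = cache
    norm = (x - mean) * rstd
    db = dout.sum((0, 1))
    dw = (dout * norm).sum((0, 1))
    dnorm = dout * w
    dx = dnorm - dnorm.mean(-1, keepdims=True) - norm * (dnorm * norm).mean(-1, keepdims=True)
    dx *= rstd
    return dx, dw, db

rng = np.random.default_rng(3)
B, T, C = 2, 3, 5
x = rng.normal(size=(B, T, C)); w = rng.normal(size=C); b = rng.normal(size=C)
dout = rng.normal(size=(B, T, C))  # upstream gradient; loss = sum(dout * y)

y, cache = forward(x, w, b)
dx, dw, db = backward(dout, cache)

# spot-check dx[0, 0, 0] with a central difference of the scalar loss
h = 1e-6
xp = x.copy(); xp[0, 0, 0] += h
xm = x.copy(); xm[0, 0, 0] -= h
num = (np.sum(dout * forward(xp, w, b)[0]) - np.sum(dout * forward(xm, w, b)[0])) / (2 * h)
assert abs(dx[0, 0, 0] - num) < 1e-5
```

Note how `dx` is computed per row (the `-1` axis) while `dw` and `db` are reduced over the batch and time axes, matching the original method.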