simons blog

Backprop through RMSNorm

$$y_i = \mathrm{RMSNorm}(x_i) = \frac{1}{\sqrt{\frac{1}{N}\sum_{j=1}^{N} x_j^2 + \epsilon}} \, x_i \, w_i$$
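As a concrete reading of this definition, here is a minimal pure-Python sketch of the forward pass for a single vector. The function name `rmsnorm_forward` and the default `eps` are illustrative choices, not taken from any library.

```python
import math

def rmsnorm_forward(x, w, eps=1e-6):
    """RMSNorm forward pass for one vector x with per-element weights w."""
    n = len(x)
    a = sum(v * v for v in x) / n        # mean of squares
    rstd = 1.0 / math.sqrt(a + eps)      # reciprocal root-mean-square
    return [v * rstd * wi for v, wi in zip(x, w)]

# Small example; eps=0.0 only to keep the arithmetic easy to follow by hand.
y = rmsnorm_forward([3.0, 4.0], [1.0, 2.0], eps=0.0)
print(y)
```

Here the mean of squares is 12.5, so each input is divided by sqrt(12.5) before being scaled by its weight.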

Backprop

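$$a = \frac{1}{N}\sum_{j=1}^{N} x_j^2, \qquad \hat{x}_i = \frac{1}{\sqrt{a+\epsilon}} \, x_i, \qquad y_i = \hat{x}_i \cdot w_i, \qquad \frac{\partial L}{\partial y_i} \text{ given.}$$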

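First backpropagation step:

$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \frac{\partial y_i}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot w_i$$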

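Second backpropagation step: $a$ feeds into every $\hat{x}_j$, so we need to sum over all paths from $\hat{x}$ to $a$.

$$\frac{\partial L}{\partial a} = \sum_j \frac{\partial L}{\partial \hat{x}_j} \cdot \frac{\partial \hat{x}_j}{\partial a}, \qquad \frac{\partial \hat{x}_j}{\partial a} = x_j \frac{\partial}{\partial a} (a+\epsilon)^{-1/2} = -\frac{1}{2} \frac{x_j}{(a+\epsilon)^{3/2}}$$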

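Combining the two expressions gives

$$\frac{\partial L}{\partial a} = -\frac{1}{2(a+\epsilon)^{3/2}} \sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot x_j$$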
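The sign and the 3/2 exponent in this intermediate result are easy to get wrong, so a quick numerical check can help. The sketch below treats $a$ as a free variable (holding the $x_j$ fixed) and uses the surrogate loss $L = \sum_j g_j y_j$, so that $g_j$ plays the role of $\partial L / \partial y_j$. The helper names and test values are made up for illustration.

```python
import math

def loss_given_a(a, x, w, g, eps):
    """Surrogate loss sum_j g_j * w_j * x_j / sqrt(a + eps), with a free."""
    return sum(gj * wj * xj / math.sqrt(a + eps) for gj, wj, xj in zip(g, w, x))

def dL_da(a, x, w, g, eps):
    """Closed form: -1 / (2 (a + eps)^(3/2)) * sum_j g_j * w_j * x_j."""
    s = sum(gj * wj * xj for gj, wj, xj in zip(g, w, x))
    return -0.5 * s / (a + eps) ** 1.5

x = [0.5, -1.0, 2.0]
w = [1.0, 0.5, -2.0]
g = [0.3, -0.7, 0.2]   # stands in for dL/dy
eps = 1e-6
a = sum(v * v for v in x) / len(x)

# Central finite difference in a, with x held fixed.
h = 1e-6
numeric = (loss_given_a(a + h, x, w, g, eps) - loss_given_a(a - h, x, w, g, eps)) / (2 * h)
analytic = dL_da(a, x, w, g, eps)
print(numeric, analytic)
```

The two printed values should agree to several digits; flipping the sign in `dL_da` makes the mismatch obvious.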

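Third backpropagation step: $x_i$ receives gradient from two nodes, $a$ and $\hat{x}_i$, so we sum both contributions:

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial a} \cdot \frac{\partial a}{\partial x_i} + \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial x_i}$$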

First term

$$\frac{\partial a}{\partial x_i} = \frac{2}{N} \cdot x_i$$

Second term

$$\frac{\partial \hat{x}_i}{\partial x_i} = \frac{1}{(a+\epsilon)^{1/2}}$$

From here it follows

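$$\frac{\partial L}{\partial x_i} = -\frac{x_i}{(a+\epsilon)^{3/2}} \cdot \frac{1}{N} \sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot x_j + \frac{1}{(a+\epsilon)^{1/2}} \frac{\partial L}{\partial y_i} \cdot w_i.$$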

Factoring out common terms gives:

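$$\frac{\partial L}{\partial x_i} = \frac{1}{(a+\epsilon)^{1/2}} \left( \frac{\partial L}{\partial y_i} \cdot w_i - \frac{x_i}{(a+\epsilon)^{1/2}} \cdot \frac{1}{N} \sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot \frac{x_j}{(a+\epsilon)^{1/2}} \right).$$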

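We can simplify further by using the definition of $\hat{x}_i$:

$$\frac{\partial L}{\partial x_i} = \frac{1}{(a+\epsilon)^{1/2}} \left( \frac{\partial L}{\partial y_i} \cdot w_i - \hat{x}_i \cdot \frac{1}{N} \sum_j \frac{\partial L}{\partial y_j} \cdot w_j \cdot \hat{x}_j \right).$$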
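One way to gain confidence in the final expression is to compare it with a finite-difference gradient. The following is a minimal pure-Python sketch for a single vector, again using the surrogate loss $L = \sum_j g_j y_j$ so that $g$ stands in for $\partial L / \partial y$. The function names and test values are hypothetical.

```python
import math

def forward(x, w, eps):
    """RMSNorm forward pass for one vector."""
    a = sum(v * v for v in x) / len(x)
    rstd = 1.0 / math.sqrt(a + eps)
    return [v * rstd * wi for v, wi in zip(x, w)]

def backward_dx(x, w, g, eps):
    """dL/dx_i = rstd * (g_i w_i - x_hat_i * mean_j(g_j w_j x_hat_j))."""
    n = len(x)
    a = sum(v * v for v in x) / n
    rstd = 1.0 / math.sqrt(a + eps)
    x_hat = [v * rstd for v in x]
    wdy = [gi * wi for gi, wi in zip(g, w)]
    c1 = sum(xh * t for xh, t in zip(x_hat, wdy)) / n
    return [(t - xh * c1) * rstd for t, xh in zip(wdy, x_hat)]

def numeric_dx(x, w, g, eps, h=1e-6):
    """Central finite differences of L = sum_j g_j * y_j with respect to x."""
    grads = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += h
        xm[i] -= h
        lp = sum(gj * yj for gj, yj in zip(g, forward(xp, w, eps)))
        lm = sum(gj * yj for gj, yj in zip(g, forward(xm, w, eps)))
        grads.append((lp - lm) / (2 * h))
    return grads

x = [0.5, -1.0, 2.0, 0.25]
w = [1.0, 0.5, -2.0, 1.5]
g = [0.3, -0.7, 0.2, 1.1]   # stands in for dL/dy
eps = 1e-6

analytic = backward_dx(x, w, g, eps)
numeric = numeric_dx(x, w, g, eps)
max_err = max(abs(p - q) for p, q in zip(analytic, numeric))
print(max_err)
```

The variable names `wdy` and `c1` are chosen to mirror the reference implementation, which makes the correspondence between formula and code easier to see.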

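This agrees with the reference implementation, where `dout` is $\partial L / \partial y$, `rstd` is $1/\sqrt{a+\epsilon}$, and `c1` is the mean over $j$ inside the parentheses: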

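def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
    """Reference implementation for RMSNorm backward pass."""
    # rstd is passed in as 1 / sqrt(a + eps) per row, so eps is not used in this function.
    x_f32 = x.float()
    x_hat = x_f32 * rstd.unsqueeze(1)  # normalized input x_hat
    wdy = dout * w  # dL/dy_i * w_i
    c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)  # (1/N) * sum_j dL/dy_j * w_j * x_hat_j
    dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)  # final expression for dL/dx_i

    # dL/dW
    dw = (dout * x_hat).sum(dim=0)
    return dx.to(x.dtype), dw.to(w.dtype)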
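
The derivation above covers only the input gradient, while the reference also returns `dw`. A short sketch of where that line comes from, assuming the rows of `x` are indexed by a batch index $b$ (introduced here for illustration, matching the `sum(dim=0)` in the code) and that the same weight $w_i$ is shared by every row:

$$\frac{\partial L}{\partial w_i} = \sum_b \frac{\partial L}{\partial y_{b,i}} \cdot \frac{\partial y_{b,i}}{\partial w_i} = \sum_b \frac{\partial L}{\partial y_{b,i}} \cdot \hat{x}_{b,i}$$

Since $y_{b,i} = \hat{x}_{b,i} \cdot w_i$ and $\hat{x}$ does not depend on $w$, there is only one path per row, and the contributions from all rows add up.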