Backprop through RMSNorm
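Given: $y_i = w_i\,\hat{x}_i$, $\hat{x}_i = r\,x_i$, $r = \left(\frac{1}{N}\sum_{j=1}^{N} x_j^2 + \epsilon\right)^{-1/2}$, and the upstream gradient $\frac{\partial L}{\partial y_i}$. Here $x \in \mathbb{R}^N$ is a single row, $w$ is the learned per-feature scale, and $r$ (the reciprocal RMS, `rstd` in code) is shared by every element of the row.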
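A minimal plain-Python sketch of the forward pass for one row follows. It assumes $r = \left(\frac{1}{N}\sum_j x_j^2 + \epsilon\right)^{-1/2}$, $\hat{x}_i = r\,x_i$ and $y_i = w_i\,\hat{x}_i$. `rmsnorm_row` is a hypothetical helper name. It returns $\hat{x}$ and $r$ alongside $y$ because the backward pass reuses them, which is also why the PyTorch reference function takes `rstd` as an argument.

```python
import math


def rmsnorm_row(x, w, eps=1e-6):
    """Forward RMSNorm for one row: returns (y, x_hat, r)."""
    n = len(x)
    # Mean of squares over the row, then the reciprocal RMS r = 1/sqrt(ms + eps).
    ms = sum(v * v for v in x) / n
    r = 1.0 / math.sqrt(ms + eps)
    # Normalize the row, then apply the per-feature scale w.
    x_hat = [v * r for v in x]
    y = [wi * xh for wi, xh in zip(w, x_hat)]
    return y, x_hat, r


y, x_hat, r = rmsnorm_row([2.0, 2.0, 2.0, 2.0], [1.0, 2.0, 3.0, 4.0], eps=0.0)
# r == 0.5, x_hat == [1.0, 1.0, 1.0, 1.0], y == [1.0, 2.0, 3.0, 4.0]
```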
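First backpropagation: from $y_i = w_i\,\hat{x}_i$,
$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\,w_i, \qquad \frac{\partial L}{\partial w_i} = \frac{\partial L}{\partial y_i}\,\hat{x}_i,$$
where $\frac{\partial L}{\partial w_i}$ is also summed over all rows of the batch, because $w$ is shared across rows.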
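Second backpropagation: $r$ feeds into every $\hat{x}_j$ of the row, so we need to sum over all paths from $r$ to $L$:
$$\frac{\partial L}{\partial r} = \sum_{j=1}^{N} \frac{\partial L}{\partial \hat{x}_j}\,\frac{\partial \hat{x}_j}{\partial r} = \sum_{j=1}^{N} \frac{\partial L}{\partial \hat{x}_j}\,x_j.$$
Combining the two expressions gives
$$\frac{\partial L}{\partial r} = \sum_{j=1}^{N} \frac{\partial L}{\partial y_j}\,w_j\,x_j.$$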
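Third backpropagation: $x_i$ receives gradient from two nodes. One is $\hat{x}_i$ directly. The other is $r$, which depends on $x_i$ through the mean of squares. We sum over both:
$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial x_i} + \frac{\partial L}{\partial r}\,\frac{\partial r}{\partial x_i}.$$
First term: holding $r$ fixed, $\frac{\partial \hat{x}_i}{\partial x_i} = r$, so
$$\frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial x_i} = \frac{\partial L}{\partial y_i}\,w_i\,r.$$
Second term: differentiating $r = \left(\frac{1}{N}\sum_j x_j^2 + \epsilon\right)^{-1/2}$ gives
$$\frac{\partial r}{\partial x_i} = -\frac{1}{2}\left(\frac{1}{N}\sum_{j=1}^{N} x_j^2 + \epsilon\right)^{-3/2}\cdot\frac{2x_i}{N} = -\frac{r^3 x_i}{N},$$
so
$$\frac{\partial L}{\partial r}\,\frac{\partial r}{\partial x_i} = -\frac{r^3 x_i}{N}\sum_{j=1}^{N} \frac{\partial L}{\partial y_j}\,w_j\,x_j.$$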
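The only non-obvious derivative in this step is $\frac{\partial r}{\partial x_i} = -\frac{r^3 x_i}{N}$. The sketch below checks it numerically with central finite differences. The helper names and the step size `h = 1e-6` are illustrative choices. The closed form holds for any $\epsilon$, because $r^3 = \left(\frac{1}{N}\sum_j x_j^2 + \epsilon\right)^{-3/2}$ already includes it.

```python
import math


def rstd(x, eps=1e-6):
    """Reciprocal RMS of one row: r = (mean(x^2) + eps) ** -0.5."""
    return 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)


def drstd_dx_analytic(x, eps=1e-6):
    """Closed form from the derivation: dr/dx_i = -r^3 * x_i / N."""
    r = rstd(x, eps)
    n = len(x)
    return [-(r ** 3) * v / n for v in x]


def drstd_dx_numeric(x, eps=1e-6, h=1e-6):
    """Central finite differences of r with respect to each x_i."""
    grads = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += h
        xm[i] -= h
        grads.append((rstd(xp, eps) - rstd(xm, eps)) / (2 * h))
    return grads


x = [0.5, -1.2, 2.0, 0.3]
analytic = drstd_dx_analytic(x)
numeric = drstd_dx_numeric(x)
max_err = max(abs(a - b) for a, b in zip(analytic, numeric))
```

The sign behaves as expected. Increasing a positive $x_i$ raises the RMS and lowers $r$, and increasing a negative one does the opposite.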
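From here it follows that
$$\frac{\partial L}{\partial x_i} = r\,w_i\,\frac{\partial L}{\partial y_i} - \frac{r^3 x_i}{N}\sum_{j=1}^{N} w_j\,\frac{\partial L}{\partial y_j}\,x_j.$$
Factoring out the common factor $r$ gives
$$\frac{\partial L}{\partial x_i} = r\left(w_i\,\frac{\partial L}{\partial y_i} - \frac{r^2 x_i}{N}\sum_{j=1}^{N} w_j\,\frac{\partial L}{\partial y_j}\,x_j\right).$$
We can simplify further using the definition $\hat{x}_i = r\,x_i$, since $r^2 x_i x_j = \hat{x}_i\,\hat{x}_j$:
$$\frac{\partial L}{\partial x_i} = r\left(w_i\,\frac{\partial L}{\partial y_i} - \hat{x}_i\cdot\frac{1}{N}\sum_{j=1}^{N} \hat{x}_j\,w_j\,\frac{\partial L}{\partial y_j}\right).$$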
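To sanity-check the final formula $\frac{\partial L}{\partial x_i} = r\left(w_i g_i - \hat{x}_i\cdot\frac{1}{N}\sum_j \hat{x}_j w_j g_j\right)$ with $g = \partial L/\partial y$, here is a dependency-free sketch for one row, compared against central finite differences. The surrogate loss $L = \sum_i g_i\,y_i$ is an assumption chosen so that $\partial L/\partial y = g$ exactly. All function names are hypothetical.

```python
import math


def rmsnorm_row(x, w, eps=1e-6):
    """Forward pass for one row: returns (y, x_hat, r)."""
    r = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    x_hat = [v * r for v in x]
    return [wi * xh for wi, xh in zip(w, x_hat)], x_hat, r


def rmsnorm_row_bwd(x, w, dy, eps=1e-6):
    """dL/dx_i = r * (w_i*g_i - x_hat_i * (1/N) * sum_j x_hat_j*w_j*g_j), with g = dy."""
    _, x_hat, r = rmsnorm_row(x, w, eps)
    wdy = [wi * gi for wi, gi in zip(w, dy)]                 # w_i * g_i
    c1 = sum(xh * a for xh, a in zip(x_hat, wdy)) / len(x)  # row-wide correction term
    return [r * (a - xh * c1) for a, xh in zip(wdy, x_hat)]


def surrogate_loss(x, w, dy, eps=1e-6):
    """L = sum_i dy_i * y_i, chosen so that dL/dy is exactly dy."""
    y, _, _ = rmsnorm_row(x, w, eps)
    return sum(g * yi for g, yi in zip(dy, y))


def numeric_dx(x, w, dy, eps=1e-6, h=1e-6):
    """Central finite differences of the surrogate loss with respect to each x_i."""
    grads = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += h
        xm[i] -= h
        grads.append((surrogate_loss(xp, w, dy, eps) - surrogate_loss(xm, w, dy, eps)) / (2 * h))
    return grads


x = [0.5, -1.2, 2.0, 0.3]
w = [1.0, 0.5, -2.0, 1.5]
dy = [0.1, -0.4, 0.7, 0.2]
analytic = rmsnorm_row_bwd(x, w, dy)
numeric = numeric_dx(x, w, dy)
max_err = max(abs(a - b) for a, b in zip(analytic, numeric))

# With eps = 0 the output is invariant to rescaling x, so dL/dx is orthogonal to x.
dx0 = rmsnorm_row_bwd(x, w, dy, eps=0.0)
orth = sum(xi * gi for xi, gi in zip(x, dx0))
```

The orthogonality check follows from the simplified formula. $\sum_i x_i \frac{\partial L}{\partial x_i} = N c_1 - c_1 \sum_i \hat{x}_i^2$, and $\sum_i \hat{x}_i^2 = N$ exactly when $\epsilon = 0$.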
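This agrees with the reference implementation here. In it, `rstd` is $r$, `wdy` is $w_i\,\frac{\partial L}{\partial y_i}$, and `c1` is $\frac{1}{N}\sum_j \hat{x}_j\,w_j\,\frac{\partial L}{\partial y_j}$, computed once per row: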
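```python
import torch

def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
    """Reference implementation for RMSNorm backward pass.

    Shapes: x and dout are (M, N), w is (N,), rstd is (M,) as saved by the forward pass.
    eps is not used here because rstd already includes it.
    """
    x_f32 = x.float()  # upcast x to fp32 before normalizing
    x_hat = x_f32 * rstd.unsqueeze(1)  # x_hat = x * r, one r per row
    wdy = dout * w  # w_i * dL/dy_i
    c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)  # (1/N) sum_j x_hat_j * w_j * dL/dy_j
    dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)  # r * (w_i * dL/dy_i - x_hat_i * c1)
    dw = (dout * x_hat).sum(dim=0)  # dL/dW: sum of dL/dy * x_hat over rows
    return dx.to(x.dtype), dw.to(w.dtype)  # cast back to the input dtypes
```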
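For readers without PyTorch at hand, here is a dependency-free sketch of the same computation over a batch of rows stored as lists of lists. It mirrors `rmsnorm_bwd_ref` step by step. One difference is an assumption made so the block is self-contained: it recomputes `rstd` from `x` instead of taking it as an argument. `rmsnorm_fwd_lists`, `rmsnorm_bwd_lists` and `surrogate_loss` are hypothetical names. The finite-difference check targets `dw`, which is the one gradient that reduces over the batch dimension.

```python
import math


def rmsnorm_fwd_lists(x, w, eps=1e-6):
    """Forward pass over a batch of rows; returns (y, x_hat, rstd), one rstd per row."""
    y, x_hat, rstd = [], [], []
    for row in x:
        r = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        xh = [v * r for v in row]
        rstd.append(r)
        x_hat.append(xh)
        y.append([wi * v for wi, v in zip(w, xh)])
    return y, x_hat, rstd


def rmsnorm_bwd_lists(x, w, dout, eps=1e-6):
    """Mirrors rmsnorm_bwd_ref: dx is computed per row, dw is summed over rows."""
    _, x_hat, rstd = rmsnorm_fwd_lists(x, w, eps)
    n = len(w)
    dx, dw = [], [0.0] * n
    for xh, g, r in zip(x_hat, dout, rstd):
        wdy = [wi * gi for wi, gi in zip(w, g)]
        c1 = sum(a * b for a, b in zip(xh, wdy)) / n
        dx.append([(wg - xv * c1) * r for wg, xv in zip(wdy, xh)])
        for i in range(n):
            dw[i] += g[i] * xh[i]  # dL/dw accumulates dout * x_hat over every row
    return dx, dw


def surrogate_loss(x, w, dout, eps=1e-6):
    """L = sum over rows and columns of dout * y, so that dL/dy == dout."""
    y, _, _ = rmsnorm_fwd_lists(x, w, eps)
    return sum(g * v for g_row, y_row in zip(dout, y) for g, v in zip(g_row, y_row))


x = [[0.5, -1.2, 2.0], [1.0, 0.0, -0.5]]
w = [1.0, 0.5, -2.0]
dout = [[0.1, -0.4, 0.7], [0.3, 0.2, -0.1]]
dx, dw = rmsnorm_bwd_lists(x, w, dout)

# Central finite differences on w as an independent check of dw.
h = 1e-6
dw_numeric = []
for i in range(len(w)):
    wp = list(w)
    wm = list(w)
    wp[i] += h
    wm[i] -= h
    dw_numeric.append((surrogate_loss(x, wp, dout) - surrogate_loss(x, wm, dout)) / (2 * h))
```

Only `dw` reduces over rows. `dx` is computed independently per row, which is why the single-row derivation is enough to obtain it.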