Simon's blog

Making RMSNorm really fast

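RMS Norm is a common operation used in modern LLMs. Given a vector $v = [v_1, \dots, v_N]$, its RMS Norm is calculated element-wise as $\bar{v}_i = \frac{v_i}{\mathrm{RMS}(v)}\, w_i$, where $w_i$ is a weight and $\mathrm{RMS}(v) = \sqrt{\epsilon + \frac{1}{N}\sum_{i=1}^{N} v_i^2}$. In this blogpost we want to calculate the RMS Norm of each row of the matrix $V = [v_1, \dots, v_{\text{numTokens}}]$, where each row is $v_i = [x_1, \dots, x_{\text{hiddenDim}}]$, with a given weight $w = [w_1, \dots, w_{\text{hiddenDim}}]$.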

Sequential solution

It's important to check the correctness of our kernels with a basic sequential implementation. See below for the simple version we use.

// Sequential reference RMSNorm used to validate the GPU kernels.
//
// x and y are row-major numTokens x hiddenDim matrices; w has hiddenDim
// entries. For every row r:
//   y[r][h] = x[r][h] / sqrt(mean_h(x[r][h]^2) + eps) * w[h]
//
// The sum of squares is accumulated in double so the reference is at least as
// accurate as the float kernels it is compared against. Compare results with a
// tolerance, not bit-exactly.
template <int numTokens, int hiddenDim>
void launchRmsNormCpu(float *x, float *w, float eps, float *y) {
  for (int token = 0; token < numTokens; token++) {
    // size_t row offset: token * hiddenDim in int arithmetic overflows once
    // the matrix has more than 2^31 elements.
    const size_t rowOffset = static_cast<size_t>(token) * hiddenDim;
    const float *xRow = x + rowOffset;
    float *yRow = y + rowOffset;

    double sumSquares = 0.0;
    for (int hidden = 0; hidden < hiddenDim; hidden++) {
      const double v = xRow[hidden];
      sumSquares += v * v;
    }
    const float rms = static_cast<float>(sqrt(sumSquares / hiddenDim + eps));

    for (int hidden = 0; hidden < hiddenDim; hidden++) {
      yRow[hidden] = xRow[hidden] / rms * w[hidden];
    }
  }
}

How to parallelise?

Our approach to parallelisation is pretty simple: each block handles one token. If a block has fewer threads than there are hidden dimensions, each thread handles multiple elements. We then perform a simple reduction, calculate the RMS Norm and write the output. Please see my previous blogpost on reduction if you are not familiar with it.

Naive kernel

A naive solution in CUDA is as follows.

// Naive RMSNorm: one block per token row. Each row is read from global memory
// twice: once for the sum of squares and once for the output.
//
// Layout: gridDim.x == numTokens, blockDim.x == threadsPerBlock. The shared
// array and the reduction are sized by the template parameter, not blockDim.
// Static shared memory: threadsPerBlock + 1 floats.
// x, y: numTokens x hiddenDim row-major; w: hiddenDim floats.
// NOTE(review): pointers are presumably device allocations — confirm at call site.
template <int hiddenDim, int threadsPerBlock>
__global__ void rmsNormKernelNaive(float *x, float *w, float eps, float *y) {
  // The halving tree reduction below only folds in every partial sum when the
  // thread count is a power of two. Otherwise some sums are silently lost.
  static_assert(threadsPerBlock > 0 && threadsPerBlock <= 1024 &&
                    (threadsPerBlock & (threadsPerBlock - 1)) == 0,
                "threadsPerBlock must be a power of two in [1, 1024]");
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float invRms;  // 1 / RMS(row), written by thread 0

  const int tid = threadIdx.x;
  // 64-bit row offset: blockIdx.x * hiddenDim in int arithmetic overflows
  // once the matrix has more than 2^31 elements.
  const size_t rowOffset = static_cast<size_t>(blockIdx.x) * hiddenDim;
  const float *xRow = x + rowOffset;
  float *yRow = y + rowOffset;

  // Pass 1: per-thread partial sum of squares. Consecutive threads read
  // consecutive elements of the row.
  float sum = 0.0f;
  for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    const float v = xRow[i];
    sum += v * v;
  }
  squaredPerThread[tid] = sum;
  __syncthreads();

  // Shared-memory tree reduction; squaredPerThread[0] ends with the row total.
  for (int activeThreads = threadsPerBlock / 2; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

  if (tid == 0) {
    invRms = rsqrtf(squaredPerThread[0] / hiddenDim + eps);
  }
  __syncthreads();

  // Pass 2: re-read x from global memory, then normalise and scale by w.
  // invRms is copied to a register once instead of reloaded every iteration.
  const float scale = invRms;
  for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    yRow[i] = xRow[i] * scale * w[i];
  }
}

// Host-side launcher for rmsNormKernelNaive.
// Launches a grid of numTokens blocks (one per token row), each with
// threadsPerBlock threads, on the default stream. The launch is asynchronous
// and its status is not checked here.
template <int numTokens, int hiddenDim, int threadsPerBlock>
void launchRmsNormNaive(float *x, float *w, float eps, float *y) {
  const dim3 grid(numTokens);
  const dim3 block(threadsPerBlock);
  rmsNormKernelNaive<hiddenDim, threadsPerBlock><<<grid, block>>>(x, w, eps,
                                                                  y);
}

x crosses memory once (read), w crosses memory once (read) and y crosses memory once (written). For numTokens = 1 << 18 and hiddenDim = 1 << 12 the traffic for w is negligible, so we calculate the bandwidth as

// Effective bandwidth: x is read once and y is written once per run. w
// (hiddenDim floats) is ignored as negligible next to the two matrices.
// Widen before multiplying so numTokens * hiddenDim is evaluated in size_t
// rather than int, which would overflow once the matrix has >= 2^31 elements.
const size_t size = static_cast<size_t>(numTokens) * hiddenDim * sizeof(float);
size_t numCrossMemoryBound = 2 * size;
// NOTE(review): assumes `time` is the total elapsed milliseconds over
// `numRounds` launches (e.g. from cudaEventElapsedTime) — confirm.
// The float divisor keeps this from truncating if both operands are integers.
float latency = time / static_cast<float>(numRounds);
// bytes / ms / 1e6 == GB/s (given latency in ms)
float bandwidth = (numCrossMemoryBound / latency) / 1e6;

The result for the above kernel is

Latency = 2.84878 ms
Bandwidth = 3015.3 GB/s
% of max = 91.3727 %

Using shared memory

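As we can see above, the naive kernel reads each element of x from global memory twice: once to compute the sum of squares and once more to compute the output. We can cache the row in shared memory so that the second pass reads from fast on-chip memory instead.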

// RMSNorm that caches the token's row in shared memory. x is read from global
// memory only once; the second pass (normalise and scale) reads the cached row.
//
// Layout: gridDim.x == numTokens, blockDim.x == threadsPerBlock.
// threadsPerBlock must be a power of two (halving tree reduction).
// Static shared memory: (threadsPerBlock + hiddenDim + 1) floats. This must
// fit in the 48 KB limit for statically allocated shared memory, which caps
// hiddenDim at roughly 12K floats.
template <int hiddenDim, int threadsPerBlock>
__global__ void rmsNormKernelSmem(float *x, float *w, float eps, float *y) {
  static_assert(threadsPerBlock > 0 && threadsPerBlock <= 1024 &&
                    (threadsPerBlock & (threadsPerBlock - 1)) == 0,
                "threadsPerBlock must be a power of two in [1, 1024]");
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float xShared[hiddenDim];
  __shared__ float invRms;  // 1 / RMS(row), written by thread 0

  const int tid = threadIdx.x;
  // 64-bit row offset: blockIdx.x * hiddenDim in int arithmetic overflows
  // once the matrix has more than 2^31 elements.
  const size_t rowOffset = static_cast<size_t>(blockIdx.x) * hiddenDim;
  const float *xRow = x + rowOffset;
  float *yRow = y + rowOffset;

  // Pass 1: load the row once, cache it, and accumulate this thread's partial
  // sum of squares.
  float sum = 0.0f;
  for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    const float v = xRow[i];
    xShared[i] = v;
    sum += v * v;
  }
  squaredPerThread[tid] = sum;
  __syncthreads();  // also makes xShared visible to the whole block

  // Shared-memory tree reduction; squaredPerThread[0] ends with the row total.
  for (int activeThreads = threadsPerBlock / 2; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

  if (tid == 0) {
    invRms = rsqrtf(squaredPerThread[0] / hiddenDim + eps);
  }
  __syncthreads();

  // Pass 2: normalise the cached row and scale by w.
  const float scale = invRms;
  for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    yRow[i] = xShared[i] * scale * w[i];
  }
}

// Host-side launcher for rmsNormKernelSmem.
// Runs one block of threadsPerBlock threads per token row (numTokens blocks)
// on the default stream. The launch is asynchronous and its status is not
// checked here.
template <int numTokens, int hiddenDim, int threadsPerBlock>
void launchRmsNormSmem(float *x, float *w, float eps, float *y) {
  const dim3 blocksInGrid(numTokens);
  const dim3 threadsInBlock(threadsPerBlock);
  rmsNormKernelSmem<hiddenDim, threadsPerBlock>
      <<<blocksInGrid, threadsInBlock>>>(x, w, eps, y);
}

This gives a small speedup:

Latency = 2.82101 ms
Bandwidth = 3044.99 GB/s
% of max = 92.2723 %

Using warps

Similar to the technique we applied in prefix sum, we can also do the following:

  1. Reduce within each warp.
  2. Use a single warp to reduce the per-warp results to the final sum.

The code for this process looks as follows:
#define WARP_SIZE 32

// Sums `x` across the 32 lanes of the calling warp using a shuffle-down tree
// (offsets 16, 8, 4, 2, 1). All 32 lanes must call it, because the full mask
// 0xffffffff is used. Afterwards lane 0 holds the complete sum; the other lanes
// hold partial sums only. The *_sync shuffle variant requires CUDA 9+.
__device__ float warpReduce(float x) {
  for (int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) {
    x += __shfl_down_sync(0xffffffff, x, offset);
  }
  return x;
}

// RMSNorm with the row cached in shared memory and a two-stage warp-shuffle
// reduction: first within each warp, then across warps using warp 0.
//
// Layout: gridDim.x == numTokens, blockDim.x == threadsPerBlock.
// threadsPerBlock must be a multiple of WARP_SIZE, because warpReduce uses a
// full-warp mask. It must also be at most WARP_SIZE * WARP_SIZE, because
// sumPerWarp has one slot per warp.
// Static shared memory: (hiddenDim + WARP_SIZE + 1) floats.
template <int hiddenDim, int threadsPerBlock>
__global__ void rmsNormKernelWarp(float *x, float *w, float eps, float *y) {
  static_assert(threadsPerBlock > 0 && threadsPerBlock % WARP_SIZE == 0,
                "threadsPerBlock must be a positive multiple of WARP_SIZE");
  static_assert(threadsPerBlock <= WARP_SIZE * WARP_SIZE,
                "sumPerWarp holds at most WARP_SIZE per-warp partial sums");
  __shared__ float xShared[hiddenDim];
  __shared__ float sumPerWarp[WARP_SIZE];
  __shared__ float invRms;  // 1 / RMS(row), written by thread 0

  const int tid = threadIdx.x;
  const int laneId = tid & (WARP_SIZE - 1);
  const int warpId = tid / WARP_SIZE;
  const int warpsPerBlock = threadsPerBlock / WARP_SIZE;

  // 64-bit row offset: blockIdx.x * hiddenDim in int arithmetic overflows
  // once the matrix has more than 2^31 elements.
  const size_t rowOffset = static_cast<size_t>(blockIdx.x) * hiddenDim;
  const float *xRow = x + rowOffset;
  float *yRow = y + rowOffset;

  float sum = 0.0f;
  for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    const float v = xRow[i];
    xShared[i] = v;
    sum += v * v;
  }

  // Stage 1: each warp reduces its partial sums straight from registers. A
  // thread's own partial needs no shared-memory staging or barrier before the
  // shuffle. Lane 0 ends up with the warp total.
  const float warpSum = warpReduce(sum);
  if (laneId == 0) {
    sumPerWarp[warpId] = warpSum;
  }
  __syncthreads();  // publishes sumPerWarp and xShared to the whole block

  // Stage 2: warp 0 reduces the per-warp totals. Lanes without a
  // corresponding warp contribute 0, so all 32 lanes still join the shuffle.
  if (tid < WARP_SIZE) {
    const float blockSum =
        warpReduce(tid < warpsPerBlock ? sumPerWarp[tid] : 0.0f);
    if (tid == 0) {
      invRms = rsqrtf(blockSum / hiddenDim + eps);
    }
  }
  __syncthreads();

  const float scale = invRms;
  for (int i = tid; i < hiddenDim; i += threadsPerBlock) {
    yRow[i] = xShared[i] * scale * w[i];
  }
}

// Host-side launcher for rmsNormKernelWarp.
// numTokens blocks (one per token row) x threadsPerBlock threads, on the
// default stream. The launch is asynchronous and its status is not checked
// here.
template <int numTokens, int hiddenDim, int threadsPerBlock>
void launchRmsNormWarp(float *x, float *w, float eps, float *y) {
  const dim3 launchGrid(numTokens);
  const dim3 launchBlock(threadsPerBlock);
  rmsNormKernelWarp<hiddenDim, threadsPerBlock>
      <<<launchGrid, launchBlock>>>(x, w, eps, y);
}

The results for this kernel are as follows:

Latency = 2.82263 ms
Bandwidth = 3043.23 GB/s
% of max = 92.2192 %

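Initially I expected this to be faster, but it turns out this is not the case. Presumably this is because the kernel is already memory bound (over 90% of peak bandwidth), so the reduction is only a small fraction of the runtime.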

Vectorise Load and Store

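If we profile the kernels above, we can see that the memory load and store instructions cost us the most. We can optimise that by vectorising loads and stores using CUDA's float4 datatype. This requires hiddenDim to be a multiple of 4 and the buffers to be 16-byte aligned, which holds for pointers returned by cudaMalloc.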

For the smem approach this looks as follows:

// Shared-memory RMSNorm with 16-byte (float4) global loads and stores.
//
// Layout: gridDim.x == numTokens, blockDim.x == threadsPerBlock.
// x and y are numTokens rows of hiddenDim / 4 float4 values each; w has
// hiddenDim / 4 float4 values.
// Preconditions:
//   - hiddenDim % 4 == 0 (checked below);
//   - threadsPerBlock is a power of two (halving tree reduction);
//   - the buffers are 16-byte aligned.
// Static shared memory: threadsPerBlock + hiddenDim + 1 floats (plus any
// alignment padding).
template <int hiddenDim, int threadsPerBlock>
__global__ void rmsNormKernelSmemFloat4(float4 *x, float4 *w, float eps,
                                        float4 *y) {
  // Without this check, hiddenDim >> 2 would silently drop the tail of each
  // row and index every row after the first at the wrong offset.
  static_assert(hiddenDim % 4 == 0,
                "float4 path requires hiddenDim to be a multiple of 4");
  static_assert(threadsPerBlock > 0 && threadsPerBlock <= 1024 &&
                    (threadsPerBlock & (threadsPerBlock - 1)) == 0,
                "threadsPerBlock must be a power of two in [1, 1024]");
  __shared__ float squaredPerThread[threadsPerBlock];
  __shared__ float4 xShared[hiddenDim >> 2];
  __shared__ float invRms;  // 1 / RMS(row), written by thread 0

  const int tid = threadIdx.x;
  const int vecPerRow = hiddenDim >> 2;  // float4 elements per row
  // 64-bit row offset (in float4 units) to avoid int overflow on huge inputs.
  const size_t rowOffset = static_cast<size_t>(blockIdx.x) * vecPerRow;
  const float4 *xRow = x + rowOffset;
  float4 *yRow = y + rowOffset;

  float sum = 0.0f;
  for (int i = tid; i < vecPerRow; i += threadsPerBlock) {
    const float4 v = xRow[i];
    xShared[i] = v;
    sum += (v.x * v.x) + (v.y * v.y) + (v.z * v.z) + (v.w * v.w);
  }
  squaredPerThread[tid] = sum;
  __syncthreads();  // also makes xShared visible to the whole block

  // Shared-memory tree reduction; squaredPerThread[0] ends with the row total.
  for (int activeThreads = threadsPerBlock >> 1; activeThreads > 0;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      squaredPerThread[tid] += squaredPerThread[tid + activeThreads];
    }
    __syncthreads();
  }

  if (tid == 0) {
    // Divide by the scalar element count, not the float4 count.
    invRms = rsqrtf(squaredPerThread[0] / hiddenDim + eps);
  }
  __syncthreads();

  const float scale = invRms;
  for (int i = tid; i < vecPerRow; i += threadsPerBlock) {
    const float4 wv = w[i];
    const float4 xv = xShared[i];
    yRow[i] = make_float4(xv.x * scale * wv.x, xv.y * scale * wv.y,
                          xv.z * scale * wv.z, xv.w * scale * wv.w);
  }
}

// Host-side launcher for rmsNormKernelSmemFloat4. It reinterprets the float
// buffers as float4 so each thread moves 16 bytes per access.
//
// Preconditions:
//   - hiddenDim is a multiple of 4 (checked at compile time);
//   - x, w and y are 16-byte aligned. cudaMalloc guarantees this.
//     NOTE(review): not checked for offset or sub-allocated pointers.
// Launch: numTokens blocks x threadsPerBlock threads on the default stream.
// The launch is asynchronous and its status is not checked here.
template <int numTokens, int hiddenDim, int threadsPerBlock>
void launchRmsNormSmemFloat4(float *x, float *w, float eps, float *y) {
  static_assert(hiddenDim % 4 == 0,
                "float4 path requires hiddenDim to be a multiple of 4");
  float4 *x_ = reinterpret_cast<float4 *>(x);
  float4 *w_ = reinterpret_cast<float4 *>(w);
  float4 *y_ = reinterpret_cast<float4 *>(y);
  rmsNormKernelSmemFloat4<hiddenDim, threadsPerBlock>
      <<<numTokens, threadsPerBlock>>>(x_, w_, eps, y_);
}

This achieves the following bandwidth:

Latency = 2.80455 ms
Bandwidth = 3062.86 GB/s
% of max = 92.8139 %

Similarly, we can vectorise the warp kernel:

#define WARP_SIZE 32

// Warp-wide sum via __shfl_down_sync with halving offsets 16, 8, 4, 2, 1.
// Every one of the 32 lanes must call it (mask 0xffffffff). On return lane 0
// holds the total of all 32 inputs; the other lanes hold partial sums.
// NOTE(review): if this snippet is compiled together with another definition
// of warpReduce in the same translation unit, keep only one of them.
__device__ float warpReduce(float x) {
  float acc = x;
  int offset = WARP_SIZE / 2;
  while (offset > 0) {
    acc += __shfl_down_sync(0xffffffff, acc, offset);
    offset /= 2;
  }
  return acc;
}

// Warp-shuffle RMSNorm with 16-byte (float4) global loads and stores.
//
// Layout: gridDim.x == numTokens, blockDim.x == threadsPerBlock.
// x and y are numTokens rows of hiddenDim / 4 float4 values each; w has
// hiddenDim / 4 float4 values.
// Preconditions:
//   - hiddenDim % 4 == 0;
//   - threadsPerBlock is a multiple of WARP_SIZE and at most
//     WARP_SIZE * WARP_SIZE;
//   - the buffers are 16-byte aligned.
// Static shared memory: hiddenDim + WARP_SIZE + 1 floats.
template <int hiddenDim, int threadsPerBlock>
__global__ void rmsNormKernelWarpFloat4(float4 *x, float4 *w, float eps,
                                        float4 *y) {
  // Without this check, hiddenDim >> 2 would silently drop the tail of each
  // row and index every row after the first at the wrong offset.
  static_assert(hiddenDim % 4 == 0,
                "float4 path requires hiddenDim to be a multiple of 4");
  static_assert(threadsPerBlock > 0 && threadsPerBlock % WARP_SIZE == 0,
                "threadsPerBlock must be a positive multiple of WARP_SIZE");
  static_assert(threadsPerBlock <= WARP_SIZE * WARP_SIZE,
                "sumPerWarp holds at most WARP_SIZE per-warp partial sums");
  __shared__ float4 xShared[hiddenDim >> 2];
  __shared__ float sumPerWarp[WARP_SIZE];
  __shared__ float invRms;  // 1 / RMS(row), written by thread 0

  const int tid = threadIdx.x;
  const int laneId = tid & (WARP_SIZE - 1);
  const int warpId = tid / WARP_SIZE;
  const int warpsPerBlock = threadsPerBlock / WARP_SIZE;
  const int vecPerRow = hiddenDim >> 2;  // float4 elements per row

  // 64-bit row offset (in float4 units) to avoid int overflow on huge inputs.
  const size_t rowOffset = static_cast<size_t>(blockIdx.x) * vecPerRow;
  const float4 *xRow = x + rowOffset;
  float4 *yRow = y + rowOffset;

  float sum = 0.0f;
  for (int i = tid; i < vecPerRow; i += threadsPerBlock) {
    const float4 v = xRow[i];
    xShared[i] = v;
    sum += (v.x * v.x) + (v.y * v.y) + (v.z * v.z) + (v.w * v.w);
  }

  // Stage 1: reduce each warp's partial sums straight from registers. No
  // shared-memory staging or barrier is needed for a thread's own partial.
  const float warpSum = warpReduce(sum);
  if (laneId == 0) {
    sumPerWarp[warpId] = warpSum;
  }
  __syncthreads();  // publishes sumPerWarp and xShared to the whole block

  // Stage 2: warp 0 reduces the per-warp totals. Missing warps contribute 0,
  // so all 32 lanes still join the full-mask shuffle.
  if (tid < WARP_SIZE) {
    const float blockSum =
        warpReduce(tid < warpsPerBlock ? sumPerWarp[tid] : 0.0f);
    if (tid == 0) {
      // Divide by the scalar element count, not the float4 count.
      invRms = rsqrtf(blockSum / hiddenDim + eps);
    }
  }
  __syncthreads();

  const float scale = invRms;
  for (int i = tid; i < vecPerRow; i += threadsPerBlock) {
    const float4 wv = w[i];
    const float4 xv = xShared[i];
    yRow[i] = make_float4(xv.x * scale * wv.x, xv.y * scale * wv.y,
                          xv.z * scale * wv.z, xv.w * scale * wv.w);
  }
}

// Host-side launcher for rmsNormKernelWarpFloat4. It reinterprets the float
// buffers as float4 for 16-byte vectorised access.
//
// Preconditions:
//   - hiddenDim is a multiple of 4 (checked at compile time);
//   - x, w and y are 16-byte aligned. cudaMalloc guarantees this.
//     NOTE(review): not checked for offset or sub-allocated pointers.
// Launch: numTokens blocks x threadsPerBlock threads on the default stream.
// The launch is asynchronous and its status is not checked here.
template <int numTokens, int hiddenDim, int threadsPerBlock>
void launchRmsNormWarpFloat4(float *x, float *w, float eps, float *y) {
  static_assert(hiddenDim % 4 == 0,
                "float4 path requires hiddenDim to be a multiple of 4");
  float4 *x_ = reinterpret_cast<float4 *>(x);
  float4 *w_ = reinterpret_cast<float4 *>(w);
  float4 *y_ = reinterpret_cast<float4 *>(y);

  rmsNormKernelWarpFloat4<hiddenDim, threadsPerBlock>
      <<<numTokens, threadsPerBlock>>>(x_, w_, eps, y_);
}

This achieves the following performance:

Latency = 2.80475 ms
Bandwidth = 3062.63 GB/s
% of max = 92.8071 %

Conclusion

We saw that if we know how reduction works, it is not difficult to implement highly performant kernels for the RMSNorm operation. If you see further optimisation opportunities, I would be happy to hear about them. One thing that surprised me was that using #pragma unroll didn't have a positive effect on the performance. If you liked this blogpost, I would be happy to connect on LinkedIn and exchange ideas about CUDA or other MLSys topics. All the code to reproduce the results above can be found on my GitHub.