simons blog

Simple reduction in CuTeDSL

Not long ago the Quack library showed that CuTeDSL can be used not only for GEMM kernels but also to implement highly efficient memory-bound kernels. In this blogpost I want to show how to perform a reduction in CuTeDSL, using a very simple RMSNorm kernel as the example.

RMSNorm

RMSNorm performs the following task:

Given a matrix $X \in \mathbb{R}^{N_t \times N_h}$, where $N_t$ is the number of tokens and $N_h$ the hidden dimension, we want to perform row-wise normalization. We do that by taking the row-wise mean of the squared elements and using that factor to normalize each element in the row. We further multiply each entry with a per-channel weight $w_i$, a learned vector with $N_h$ entries.

In formula:

$$y_i = \frac{x_i}{\mathrm{RMS}(x)} \cdot w_i, \quad \text{where} \quad \mathrm{RMS}(x) = \sqrt{\epsilon + \frac{1}{n}\sum_{i=1}^{n} x_i^2}$$
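
To make the formula concrete, here is a minimal PyTorch sketch of the same computation (rms_norm_ref is just an illustrative name; it should match what nn.RMSNorm computes):

    import torch

    def rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Row-wise mean of squares, then scale each row by the reciprocal
        # square root and multiply element-wise with the weight vector.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x * rms * w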

Clearly this is a memory-bound kernel that involves a reduction, one of the fundamental algorithms in GPU programming. Below we will show how to implement a baseline in CuTeDSL.

The kernel

Before implementing the kernel we compute the RMSNorm in PyTorch as well, so that we can check our implementation for correctness:

    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float32)    
    w = torch.randn(hidden_dim, device="cuda", dtype=torch.float32)
    y = torch.zeros(num_tokens, hidden_dim, device="cuda", dtype=torch.float32)

    rms_norm = nn.RMSNorm(hidden_dim, eps=eps)
    rms_norm.weight.data = w 
    y_ref = rms_norm(x)

The CuTe kernel can then be invoked as follows:

    mX = from_dlpack(x, assumed_align=16)
    mW = from_dlpack(w, assumed_align=16)
    mY = from_dlpack(y, assumed_align=16)

    rms_norm_(mX, mW, mY, num_tokens, hidden_dim, eps)

and we can check for correctness as follows:

torch.testing.assert_close(y_ref, y)

The host-side launch function is defined as follows:

@cute.jit
def rms_norm_(mX: cute.Tensor, mW: cute.Tensor, mY: cute.Tensor, 
              num_tokens: cutlass.Constexpr, hidden_dim: cutlass.Constexpr, epsilon: cutlass.Constexpr
):
  threads_per_block = 256
  rms_norm_kernel(mX, mW, mY, threads_per_block, num_tokens, hidden_dim, epsilon).launch(
    grid=(num_tokens,1,1),
    block=(threads_per_block,1,1)
  )

Note that the value for threads_per_block was obtained by experimenting a little bit. We launch one block for each token. Below you will see why we do that and how we cover the full hidden dimension even when it gets large.
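
To make the indexing concrete, here is a plain-Python illustration (host-side, not kernel code) of which columns each thread of a block visits for hidden_dim = 1024 and 256 threads per block:

    threads_per_block = 256
    hidden_dim = 1024

    # Thread t of a block processes columns t, t + 256, t + 512, ...
    for tidx in (0, 1, 255):
        cols = list(range(tidx, hidden_dim, threads_per_block))
        print(f"thread {tidx:3d} -> columns {cols}")

    # thread   0 -> columns [0, 256, 512, 768]
    # thread   1 -> columns [1, 257, 513, 769]
    # thread 255 -> columns [255, 511, 767, 1023]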

@cute.kernel
def rms_norm_kernel(mX: cute.Tensor, mW: cute.Tensor, mY: cute.Tensor, threads_per_block: cutlass.Constexpr, 
                    num_tokens: cutlass.Constexpr, hidden_dim: cutlass.Constexpr, epsilon: cutlass.Constexpr
):
  allocator = cutlass.utils.SmemAllocator()
  layout = cute.make_layout((threads_per_block))
  scalar_layout = cute.make_layout((1))

  sdata = allocator.allocate_tensor(cutlass.Float32, layout=layout, byte_alignment=16, swizzle=None)
  squared_reduce = allocator.allocate_tensor(cutlass.Float32, layout=scalar_layout)

The first step in the kernel is to dynamically allocate two tensors in shared memory: sdata with one Float32 entry per thread for the partial sums, and squared_reduce with a single entry that will hold the final normalization factor.
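
As a quick back-of-the-envelope calculation (host-side Python, assuming 4-byte Float32 elements), the shared memory footprint of this configuration is tiny:

    threads_per_block = 256
    bytes_per_float = 4

    sdata_bytes = threads_per_block * bytes_per_float  # 1024 bytes for the partial sums
    scalar_bytes = bytes_per_float                      # 4 bytes for the final scale factor
    print(sdata_bytes + scalar_bytes)                   # 1028 bytes, far below the per-block limit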

  tidx, _, _ = cute.arch.thread_idx()
  bidx, _, _ = cute.arch.block_idx()
   
  block_sum = 0.0
  for i in range(tidx, hidden_dim, threads_per_block, unroll_full=True):
    x_ = mX[(bidx, i)]
    block_sum += x_ * x_ 
  
  sdata[tidx] = block_sum
  cute.arch.sync_threads()

Here we see how we cover the whole hidden_dim. Each thread handles multiple values, reading them out in a strided fashion with a stride of threads_per_block. For example, if hidden_dim = 1024, thread 0 reads the entries at 0, 256, 512 and 768. We then accumulate the squared entries into block_sum and store that in the sdata tensor, which was allocated above to contain exactly threads_per_block entries. Afterwards we have a tensor with 256 entries, each containing one such partial sum. In the setting described above, sdata[0] contains $x_0^2 + x_{256}^2 + x_{512}^2 + x_{768}^2$.
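
We can emulate the same accumulation pattern on the host with NumPy to convince ourselves that nothing is lost in the partial sums (an illustrative check, not kernel code):

    import numpy as np

    threads_per_block, hidden_dim = 256, 1024
    x = np.random.randn(hidden_dim).astype(np.float32)  # one row, i.e. one token

    # Each "thread" t accumulates the squares of its strided slice of the row.
    sdata = np.array(
        [np.sum(x[t::threads_per_block] ** 2) for t in range(threads_per_block)],
        dtype=np.float32,
    )

    # sdata[0] == x[0]**2 + x[256]**2 + x[512]**2 + x[768]**2
    # Reducing all partial sums recovers the full sum of squares of the row.
    assert np.isclose(sdata.sum(), np.sum(x ** 2), rtol=1e-4)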

We now need to reduce sdata to obtain the full sum of squares we want to use for normalization. We'll use the usual tree-reduction algorithm for parallel reduction. In code:

  if tidx < 128:
    sdata[tidx] += sdata[tidx + 128]
  cute.arch.sync_threads()
  
  if tidx < 64:
    sdata[tidx] += sdata[tidx + 64]
  cute.arch.sync_threads()
  
  if tidx < 32:
    sdata[tidx] += sdata[tidx + 32]
    res = cute.arch.warp_reduction_sum(sdata[tidx], threads_in_group=32)
    if tidx == 0:
      squared_reduce[0] = cute.math.rsqrt(res/hidden_dim + epsilon, fastmath=True)

  cute.arch.sync_threads()

Note that here we first have 128 active threads, then 64, and then only one warp of 32 threads. Within that warp we don't need any further block-wide synchronization because all threads in a warp are scheduled together and the warp shuffle instructions synchronize the participating lanes. warp_reduction_sum efficiently reduces the final 32 values via warp shuffle instructions; you can read the code here to understand it better. Once the reduction is done we compute the reciprocal square root via the efficient PTX wrapper cute.math.rsqrt and save it into our "scalar" shared memory tensor.
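
Conceptually, the reduction keeps halving the number of active entries until one warp's worth remains. Here is a host-side NumPy sketch of the same tree, where a plain sum over the last 32 entries stands in for the warp shuffle reduction:

    import numpy as np

    sdata = np.random.rand(256).astype(np.float32)  # stand-in for the per-thread partial sums
    expected = sdata.sum()

    work = sdata.copy()
    work[:128] += work[128:256]  # step 1: 256 -> 128 active entries
    work[:64] += work[64:128]    # step 2: 128 -> 64
    work[:32] += work[32:64]     # step 3: 64 -> 32, i.e. one warp
    res = work[:32].sum()        # stands in for cute.arch.warp_reduction_sum

    assert np.isclose(res, expected, rtol=1e-4)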

  rms = squared_reduce[0]
  for i in range(tidx, hidden_dim, threads_per_block, unroll_full=True):
    mY[(bidx, i)] = mX[(bidx, i)] * rms * mW[i] 

The final step is to write out the normalized values, striding over the hidden dimension in the same way as before.

Code

The full code is really small, so I share it here instead of creating a repo for it. It also includes a small section that benchmarks performance using the CUTLASS testing utilities.

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from math import ceil
import torch
import torch.nn as nn

def benchmark(callable, mX, mW, mY,  
              num_tokens, hidden_dim, epsilon):
    avg_time_us = cute.testing.benchmark(
        callable,
        kernel_arguments=cute.testing.JitArguments(mX, mW, mY),
        warmup_iterations=100,
        iterations=1000,
    )

    # Calculate metrics
    # ----------------
    dtype = mX.element_type

    # Calculate total bytes transferred:
    # - 1 Read and 1 Write
    # - Each element is dtype.width bits
    bytes_per_element = dtype.width // 8
    total_bytes = hidden_dim * num_tokens * 2 * bytes_per_element

    # Calculate achieved bandwidth
    achieved_bandwidth = total_bytes / (avg_time_us * 1000)  # GB/s

    # Print results
    # ------------
    print(f"Performance Metrics:")
    print(f"-------------------")
    print(f"Kernel execution time: {avg_time_us:.4f} us")
    print(f"Memory throughput: {achieved_bandwidth:.2f} GB/s")

@cute.kernel
def rms_norm_kernel(mX: cute.Tensor, mW: cute.Tensor, mY: cute.Tensor, threads_per_block: cutlass.Constexpr, 
                    num_tokens: cutlass.Constexpr, hidden_dim: cutlass.Constexpr, epsilon: cutlass.Constexpr
):
  allocator = cutlass.utils.SmemAllocator()
  layout = cute.make_layout((threads_per_block))
  scalar_layout = cute.make_layout((1))

  sdata = allocator.allocate_tensor(cutlass.Float32, layout=layout, byte_alignment=16, swizzle=None)
  squared_reduce = allocator.allocate_tensor(cutlass.Float32, layout=scalar_layout)

  tidx, _, _ = cute.arch.thread_idx()
  bidx, _, _ = cute.arch.block_idx()
   
  block_sum = 0.0
  for i in range(tidx, hidden_dim, threads_per_block, unroll_full=True):
    x_ = mX[(bidx, i)]
    block_sum += x_ * x_ 
  
  sdata[tidx] = block_sum
  cute.arch.sync_threads()

  if tidx < 128:
    sdata[tidx] += sdata[tidx + 128]
  cute.arch.sync_threads()
  
  if tidx < 64:
    sdata[tidx] += sdata[tidx + 64]
  cute.arch.sync_threads()
  
  if tidx < 32:
    sdata[tidx] += sdata[tidx + 32]
    res = cute.arch.warp_reduction_sum(sdata[tidx], threads_in_group=32)
    if tidx == 0:
      squared_reduce[0] = cute.math.rsqrt(res/hidden_dim + epsilon, fastmath=True)

  cute.arch.sync_threads()
  
  rms = squared_reduce[0]
  for i in range(tidx, hidden_dim, threads_per_block, unroll_full=True):
    mY[(bidx, i)] = mX[(bidx, i)] * rms * mW[i] 

@cute.jit
def rms_norm_(mX: cute.Tensor, mW: cute.Tensor, mY: cute.Tensor, 
              num_tokens: cutlass.Constexpr, hidden_dim: cutlass.Constexpr, epsilon: cutlass.Constexpr
):
  threads_per_block = 256
  rms_norm_kernel(mX, mW, mY, threads_per_block, num_tokens, hidden_dim, epsilon).launch(
    grid=(num_tokens,1,1),
    block=(threads_per_block,1,1)
  )

if __name__ == "__main__":
    num_tokens = 65536
    hidden_dim = 1024 
    eps = 1e-5

    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float32)    
    w = torch.randn(hidden_dim, device="cuda", dtype=torch.float32)
    y = torch.zeros(num_tokens, hidden_dim, device="cuda", dtype=torch.float32)

    rms_norm = nn.RMSNorm(hidden_dim, eps=eps)
    rms_norm.weight.data = w 
    y_ref = rms_norm(x)

    mX = from_dlpack(x, assumed_align=16)
    mW = from_dlpack(w, assumed_align=16)
    mY = from_dlpack(y, assumed_align=16)

    rms_norm_(mX, mW, mY, num_tokens, hidden_dim, eps)
    torch.testing.assert_close(y_ref, y)

    compiled = cute.compile(rms_norm_, mX, mW, mY, num_tokens, hidden_dim, eps)
    benchmark(compiled, mX, mW, mY, num_tokens, hidden_dim, eps)

Note that I didn't apply the full set of "tricks" to obtain maximum performance, as this serves more as an educational example. For example, we could additionally add vectorized loads or other tricks that I have implemented in CUDA before. For best performance you should use the Quack library anyway.

However, even the above, fairly simple kernel achieves pretty good performance for some test cases on my consumer GPU. For example, for hidden_dim = 1024 it achieves 87% utilization of the available memory bandwidth. To reach peak performance you can take a look at the further optimisations the Quack library employs.
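
For reference, the utilization number is simply the measured throughput divided by the peak memory bandwidth of the card; the values below are placeholders, to be replaced with the output of benchmark() above and your GPU's spec sheet figure:

    # Placeholder numbers for illustration only.
    achieved_bandwidth_gb_s = 800.0  # e.g. the value printed by benchmark()
    peak_bandwidth_gb_s = 920.0      # e.g. from your GPU's spec sheet

    utilization = achieved_bandwidth_gb_s / peak_bandwidth_gb_s
    print(f"Utilization: {utilization:.0%}")  # ~87% with these placeholder numbers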

Conclusion

I hope this blogpost showed that the low-level control CuTeDSL offers also lets us implement primitives like reduction, and potentially scan, quite naturally. If you want to exchange further ideas, feel free to connect with me on LinkedIn.