Simon's blog

NVFP4 GEMV improved

Introduction

A few days ago I explained how the reference kernel for the GEMV competition works. In this blog post I will show various ways to parallelize the reduction over the K-mode.

Recap

Let's quickly recap the basic reference kernel for GEMV:

# The CuTe reference implementation for NVFP4 block-scaled GEMV
@cute.kernel
def kernel(
    mA_mkl: cute.Tensor,
    mB_nkl: cute.Tensor,
    mSFA_mkl: cute.Tensor,
    mSFB_nkl: cute.Tensor,
    mC_mnl: cute.Tensor,
):
    # Get CUDA block and thread indices
    bidx, bidy, bidz = cute.arch.block_idx()
    tidx, _, _ = cute.arch.thread_idx()

    # Extract the local tile for input matrix A (shape: [block_M, block_K, rest_M, rest_K, rest_L])
    gA_mkl = cute.local_tile(
        mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # Extract the local tile for scale factor tensor for A (same shape as gA_mkl)
    # Here, block_M = (32, 4); block_K = (16, 4)
    gSFA_mkl = cute.local_tile(
        mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # Extract the local tile for input matrix B (shape: [block_N, block_K, rest_N, rest_K, rest_L])
    gB_nkl = cute.local_tile(
        mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # Extract the local tile for scale factor tensor for B (same shape as gB_nkl)
    gSFB_nkl = cute.local_tile(
        mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # Extract the local tile for output matrix C (shape: [block_M, block_N, rest_M, rest_N, rest_L])
    gC_mnl = cute.local_tile(
        mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None)
    )

    # Select output element corresponding to this thread and block indices
    tCgC = gC_mnl[tidx, None, bidx, bidy, bidz]
    tCgC = cute.make_tensor(tCgC.iterator, 1)
    res = cute.zeros_like(tCgC, cutlass.Float32)

    # Get the number of k tiles (depth dimension) for the reduction loop
    k_tile_cnt = gA_mkl.layout[3].shape
    for k_tile in range(k_tile_cnt):
        tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz]
        tBgB = gB_nkl[0, None, bidy, k_tile, bidz]
        tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz]
        tBgSFB = gSFB_nkl[0, None, bidy, k_tile, bidz]

        tArA = cute.make_rmem_tensor_like(tAgA, cutlass.Float32)
        tBrB = cute.make_rmem_tensor_like(tBgB, cutlass.Float32)
        tArSFA = cute.make_rmem_tensor_like(tAgSFA, cutlass.Float32)
        tBrSFB = cute.make_rmem_tensor_like(tBgSFB, cutlass.Float32)

        # Load NVFP4 or FP8 values from global memory
        a_val_nvfp4 = tAgA.load()
        b_val_nvfp4 = tBgB.load()
        sfa_val_fp8 = tAgSFA.load()
        sfb_val_fp8 = tBgSFB.load()

        # Convert loaded values to float32 for computation (FFMA)
        a_val = a_val_nvfp4.to(cutlass.Float32)
        b_val = b_val_nvfp4.to(cutlass.Float32)
        sfa_val = sfa_val_fp8.to(cutlass.Float32)
        sfb_val = sfb_val_fp8.to(cutlass.Float32)

        # Store the converted values to RMEM CuTe tensors
        tArA.store(a_val)
        tBrB.store(b_val)
        tArSFA.store(sfa_val)
        tBrSFB.store(sfb_val)

        # Iterate over SF vector tiles and compute the scale&matmul accumulation
        for i in cutlass.range_constexpr(mma_tiler_mnk[2]):
            res += tArA[i] * tArSFA[i] * tBrB[i] * tBrSFB[i]

    # Store the final float16 result back to global memory
    tCgC.store(res.to(cutlass.Float16))
    return

Previously we launched a grid with as many threads as there are rows in a tile, and each thread computed one full dot product. A simple way to improve performance is to increase parallelism, which we can do by parallelizing along the K-mode.
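For intuition, the math the kernel implements can be sketched in plain NumPy (a hand-rolled emulation, assuming a scale-block size of 16 as in NVFP4; the real kernel keeps A/B in NVFP4 and the scale factors in FP8, while here everything is plain float):

```python
import numpy as np

def block_scaled_gemv(A, B, SFA, SFB, block=16):
    """Emulate block-scaled GEMV: y[m] = sum_k (sfa * A[m, k]) * (sfb * B[k]).

    A: (M, K) values, B: (K,) values, SFA: (M, K // block) per-block scales
    for A, SFB: (K // block,) per-block scales for B.
    """
    M, K = A.shape
    y = np.zeros(M, dtype=np.float32)
    for m in range(M):
        acc = np.float32(0.0)
        for kb in range(K // block):
            sl = slice(kb * block, (kb + 1) * block)
            # One scale factor covers one 16-element vector of A and of B.
            acc += SFA[m, kb] * SFB[kb] * np.float32(A[m, sl] @ B[sl])
        y[m] = acc
    return y
```

This matches multiplying each 16-element vector by its scale first and then taking the full dot product, which is what the FFMA loop in the kernel computes element by element.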

Let us look at a few ways this can be done.

When I run the benchmark from the GPU MODE repo on this kernel, it reports the following performance:

benchmark-count: 3
benchmark.0.spec: m: 7168; k: 16384; l: 1; seed: 1111
benchmark.0.runs: 3
benchmark.0.mean: 234495.99742889404
benchmark.0.std: 0.0
benchmark.0.err: 0.0
benchmark.0.best: 234495.99742889404
benchmark.0.worst: 234495.99742889404
benchmark.1.spec: m: 4096; k: 7168; l: 8; seed: 1111
benchmark.1.runs: 31
benchmark.1.mean: 119713.03532200475
benchmark.1.std: 657.4408274935079
benchmark.1.err: 118.0798583867043
benchmark.1.best: 117824.0031003952
benchmark.1.worst: 121855.99654912949
benchmark.2.spec: m: 7168; k: 2048; l: 4; seed: 1111
benchmark.2.runs: 3
benchmark.2.mean: 38911.99827194214
benchmark.2.std: 0.0
benchmark.2.err: 0.0
benchmark.2.best: 38911.99827194214
benchmark.2.worst: 38911.99827194214
check: pass

Parallelize via extra blocks

We launch a grid where each block in the y direction corresponds to one K tile.

    grid = (
        cute.ceil_div(c_tensor.shape[0], mma_tiler_mnk[0]),
        cute.ceil_div(a_tensor.shape[1], mma_tiler_mnk[2]),
        c_tensor.shape[2],
    )

The corresponding kernel then looks as follows:

# The CuTe reference implementation for NVFP4 block-scaled GEMV
@cute.kernel
def kernel(
    mA_mkl: cute.Tensor,
    mB_nkl: cute.Tensor,
    mSFA_mkl: cute.Tensor,
    mSFB_nkl: cute.Tensor,
    mC_mnl: cute.Tensor,  # Now float32 accumulation buffer
):
    # Get CUDA block and thread indices
    bidx, bidy, bidz = cute.arch.block_idx()
    tidx, _, _ = cute.arch.thread_idx()

    # Extract the local tile for input matrix A (shape: [block_M, block_K, rest_M, rest_K, rest_L])
    gA_mkl = cute.local_tile(
        mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # Extract the local tile for scale factor tensor for A (same shape as gA_mkl)
    # Here, block_M = (32, 4); block_K = (16, 4)
    gSFA_mkl = cute.local_tile(
        mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # Extract the local tile for input matrix B (shape: [block_N, block_K, rest_N, rest_K, rest_L])
    gB_nkl = cute.local_tile(
        mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # Extract the local tile for scale factor tensor for B (same shape as gB_nkl)
    gSFB_nkl = cute.local_tile(
        mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # Extract the local tile for output matrix C (shape: [block_M, block_N, rest_M, rest_N, rest_L])
    gC_mnl = cute.local_tile(
        mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None)
    )

    # Select output element corresponding to this thread and block indices
    tCgC = gC_mnl[tidx, None, bidx, 0, bidz]
    tCgC = cute.make_tensor(tCgC.iterator, 1)
    res = cute.zeros_like(tCgC, cutlass.Float32)

    tAgA = gA_mkl[tidx, None, bidx, bidy, bidz]
    tBgB = gB_nkl[0, None, 0, bidy, bidz]
    tAgSFA = gSFA_mkl[tidx, None, bidx, bidy, bidz]
    tBgSFB = gSFB_nkl[0, None, 0, bidy, bidz]

    tArA = cute.make_rmem_tensor_like(tAgA, cutlass.Float16)
    tBrB = cute.make_rmem_tensor_like(tBgB, cutlass.Float16)
    tArSFA = cute.make_rmem_tensor_like(tAgSFA, cutlass.Float32)
    tBrSFB = cute.make_rmem_tensor_like(tBgSFB, cutlass.Float32)
    
    tABrAB = cute.make_rmem_tensor_like(tAgA, cutlass.Float16)
    tSFrSF = cute.make_rmem_tensor_like(tAgSFA, cutlass.Float32) 

    # Load NVFP4 or FP8 values from global memory
    a_val_nvfp4 = tAgA.load()
    b_val_nvfp4 = tBgB.load()
    sfa_val_fp8 = tAgSFA.load()
    sfb_val_fp8 = tBgSFB.load()

    # Convert A/B to float16 and the scale factors to float32 for computation
    a_val = a_val_nvfp4.to(cutlass.Float16)
    b_val = b_val_nvfp4.to(cutlass.Float16)
    sfa_val = sfa_val_fp8.to(cutlass.Float32)
    sfb_val = sfb_val_fp8.to(cutlass.Float32)

    # Store the converted values to RMEM CuTe tensors
    tArA.store(a_val)
    tBrB.store(b_val)
    tArSFA.store(sfa_val)
    tBrSFB.store(sfb_val)

    tABrAB.store(tArA.load() * tBrB.load()) 
    tSFrSF.store(tArSFA.load() * tBrSFB.load()) 
    # Iterate over SF vector tiles and compute the scale&matmul accumulation
    for i in cutlass.range_constexpr(mma_tiler_mnk[2]):
        res += tABrAB[i] * tSFrSF[i]

    # Atomic add to float32 buffer
    atomic_add_fp32(res[0], tCgC.iterator) 
    return

We no longer need to loop over the K tiles; only the loop within each tile remains. At the end, each block performs an atomic add into the correct output element. However, the atomic-add intrinsic

from cutlass import Float32
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import nvvm, llvm

@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    nvvm.atomicrmw(
        res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value()
    )

is only available in F32 precision (this wrapper is similar to the one existing in CuTeDSL for Int32 and can also be found in the flash attention repo). We therefore need to allocate an additional F32 tensor when launching the kernel and then copy it over (in PyTorch) to the F16 output tensor. Running the benchmark on this gives:
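Conceptually, this split-K scheme looks as follows (a serial NumPy sketch with made-up shapes; the atomic add becomes a plain `+=` because the Python loop has no concurrency):

```python
import numpy as np

M, K, TILE_K = 8, 64, 16
rng = np.random.default_rng(0)
A = rng.random((M, K)).astype(np.float32)
B = rng.random(K).astype(np.float32)

# F32 accumulation buffer, zero-initialized before launch (done in PyTorch).
C_f32 = np.zeros(M, dtype=np.float32)

# One thread block per (row tile, k tile); each contributes one partial
# dot product via an atomic add on the f32 buffer.
for k_tile in range(K // TILE_K):           # bidy
    sl = slice(k_tile * TILE_K, (k_tile + 1) * TILE_K)
    for m in range(M):                       # bidx, tidx
        C_f32[m] += A[m, sl] @ B[sl]         # atomic_add_fp32 in the kernel

# The final downcast to the f16 output tensor happens on the host side.
C = C_f32.astype(np.float16)
```

Because floating-point addition is commutative up to rounding here, the order in which blocks land their atomic adds does not change the result beyond rounding noise.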

benchmark-count: 3
benchmark.0.spec: m: 7168; k: 16384; l: 1; seed: 1111
benchmark.0.runs: 3
benchmark.0.mean: 36864.00130391121
benchmark.0.std: 0.0
benchmark.0.err: 0.0
benchmark.0.best: 36864.00130391121
benchmark.0.worst: 36864.00130391121
benchmark.1.spec: m: 4096; k: 7168; l: 8; seed: 1111
benchmark.1.runs: 97
benchmark.1.mean: 55399.91764217308
benchmark.1.std: 541.918369861518
benchmark.1.err: 55.023473865435456
benchmark.1.best: 55231.99960589409
benchmark.1.worst: 59392.00147986412
benchmark.2.spec: m: 7168; k: 2048; l: 4; seed: 1111
benchmark.2.runs: 3
benchmark.2.mean: 24576.00086927414
benchmark.2.std: 0.0
benchmark.2.err: 0.0
benchmark.2.best: 24576.00086927414
benchmark.2.worst: 24576.00086927414
check: pass

We see that this is much better than the reference kernel.

Use different threads and atomic add

In the first parallelization strategy we parallelized via thread blocks. Now we want to do the following: launch extra threads and let them work on one row together.

We launch a grid like so:

    grid = (
        cute.ceil_div(c_tensor.shape[0], threads_per_m),
        1,
        c_tensor.shape[2],
    )

    # Launch the CUDA kernel
    kernel(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor).launch(
        grid=grid,
        block=[threads_per_m, threads_per_k, 1],
        cluster=(1, 1, 1),
    )

where we defined above:

threads_per_m = 32  # Threads per block along the M (row) direction
threads_per_k = 32  # Threads per block along the K (reduction) direction
mma_tiler_mnk = (threads_per_m, 1, 256)  # Tile sizes for M, N, K dimensions
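As a quick sanity check on this launch geometry (a sketch; `ceil_div` here is a plain-Python stand-in for `cute.ceil_div`):

```python
def ceil_div(a: int, b: int) -> int:
    # Stand-in for cute.ceil_div
    return -(-a // b)

threads_per_m, threads_per_k = 32, 32

# A 32 x 32 block is 1024 threads, the CUDA per-block maximum.
threads_per_block = threads_per_m * threads_per_k

# For benchmark 0 (m = 7168, l = 1) the grid becomes (224, 1, 1).
grid = (ceil_div(7168, threads_per_m), 1, 1)
```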
# The CuTe reference implementation for NVFP4 block-scaled GEMV
@cute.kernel
def kernel(
    mA_mkl: cute.Tensor,
    mB_nkl: cute.Tensor,
    mSFA_mkl: cute.Tensor,
    mSFB_nkl: cute.Tensor,
    mC_mnl: cute.Tensor,
):
    # Get CUDA block and thread indices
    bidx, bidy, bidz = cute.arch.block_idx()
    tidx, tidy, _ = cute.arch.thread_idx()

    # Extract the local tile for input matrix A (shape: [block_M, block_K, rest_M, rest_K, rest_L])
    gA_mkl = cute.local_tile(
        mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # Extract the local tile for scale factor tensor for A (same shape as gA_mkl)
    # Here, block_M = (32, 4); block_K = (16, 4)
    gSFA_mkl = cute.local_tile(
        mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # Extract the local tile for input matrix B (shape: [block_N, block_K, rest_N, rest_K, rest_L])
    gB_nkl = cute.local_tile(
        mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # Extract the local tile for scale factor tensor for B (same shape as gB_nkl)
    gSFB_nkl = cute.local_tile(
        mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # Extract the local tile for output matrix C (shape: [block_M, block_N, rest_M, rest_N, rest_L])
    gC_mnl = cute.local_tile(
        mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None)
    )

    # Select output element corresponding to this thread and block indices
    tCgC = gC_mnl[tidx, None, bidx, bidy, bidz]
    tCgC = cute.make_tensor(tCgC.iterator, 1)
    res = cute.zeros_like(tCgC, accum_dtype)

    # Shared Memory
    allocator = cutlass.utils.SmemAllocator()
    smem_layout = cute.make_layout(mma_tiler_mnk[0])
    shared_res = allocator.allocate_tensor(element_type=cutlass.Float32, layout=smem_layout)

    if tidy == 0:
        shared_res[tidx] = 0.0
    cute.arch.sync_threads()
    # Get the number of k tiles (depth dimension) for the reduction loop
    k_tile_cnt = gA_mkl.layout[3].shape
    for k_tile in range(tidy, k_tile_cnt, threads_per_k, unroll_full=True):
        tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz]
        tBgB = gB_nkl[0, None, bidy, k_tile, bidz]
        tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz]
        tBgSFB = gSFB_nkl[0, None, bidy, k_tile, bidz]

        tArA = cute.make_rmem_tensor_like(tAgA, c_dtype)
        tBrB = cute.make_rmem_tensor_like(tBgB, c_dtype)
        tABrAB = cute.make_rmem_tensor_like(tAgA, c_dtype)
        tArSFA = cute.make_rmem_tensor_like(tAgSFA, accum_dtype)
        tBrSFB = cute.make_rmem_tensor_like(tBgSFB, accum_dtype)
        tSFrSF = cute.make_rmem_tensor_like(tAgSFA, accum_dtype)

        # Load NVFP4 or FP8 values from global memory
        a_val_nvfp4 = tAgA.load()
        b_val_nvfp4 = tBgB.load()
        sfa_val_fp8 = tAgSFA.load()
        sfb_val_fp8 = tBgSFB.load()

        # Convert A/B to the compute dtype and the scale factors to the SF dtype
        a_val = a_val_nvfp4.to(c_dtype)
        b_val = b_val_nvfp4.to(c_dtype)
        sfa_val = sfa_val_fp8.to(sf_dtype)
        sfb_val = sfb_val_fp8.to(sf_dtype)

        # Store the converted values to RMEM CuTe tensors
        tArA.store(a_val)
        tBrB.store(b_val)
        tArSFA.store(sfa_val)
        tBrSFB.store(sfb_val)

        tABrAB.store(tArA.load() * tBrB.load())
        tSFrSF.store(tArSFA.load() * tBrSFB.load())

        # Iterate over SF vector tiles and compute the scale&matmul accumulation
        for i in cutlass.range_constexpr(mma_tiler_mnk[2]):
            res += tABrAB[i] * tSFrSF[i]
    
    atomic_add_fp32(res[0], elem_pointer(shared_res, tidx))
    cute.arch.sync_threads()
    if tidy == 0:
        out = scalar_to_ssa(shared_res[tidx], cutlass.Float32)
        # Store the final float16 result back to global memory
        tCgC.store(out.to(cutlass.Float16))
    return

Note that here the threads again collectively compute one result instead of a single thread doing all the work (the additional threads in the y direction share it). We then accumulate the result for one row of the tile (indexed via tidx). Here we don't need the additional F32 allocation because we do the atomic add onto shared memory and convert from there via a fragment. The helper functions scalar_to_ssa and elem_pointer can be found in the flash attention repo as well.
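The work split between the y-threads can be sketched like this (serial NumPy stand-in: each `tidy` handles a strided subset of the K tiles, and the atomic add onto the SMEM slot becomes a plain `+=`):

```python
import numpy as np

K, TILE_K, THREADS_K = 256, 16, 8
rng = np.random.default_rng(0)
x = rng.random(K).astype(np.float32)
y = rng.random(K).astype(np.float32)
k_tile_cnt = K // TILE_K

shared_res = np.float32(0.0)  # one SMEM slot per row (indexed by tidx)
for tidy in range(THREADS_K):
    partial = np.float32(0.0)
    # Strided tile assignment: tidy, tidy + THREADS_K, tidy + 2*THREADS_K, ...
    for k_tile in range(tidy, k_tile_cnt, THREADS_K):
        sl = slice(k_tile * TILE_K, (k_tile + 1) * TILE_K)
        partial += x[sl] @ y[sl]
    shared_res += partial  # atomic_add_fp32 onto shared memory in the kernel
```

All partial sums together cover every K tile exactly once, so the accumulated value equals the full dot product up to rounding.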

The performance is:

benchmark-count: 3
benchmark.0.spec: m: 7168; k: 16384; l: 1; seed: 1111
benchmark.0.runs: 3
benchmark.0.mean: 38911.99827194214
benchmark.0.std: 0.0
benchmark.0.err: 0.0
benchmark.0.best: 38911.99827194214
benchmark.0.worst: 38911.99827194214
benchmark.1.spec: m: 4096; k: 7168; l: 8; seed: 1111
benchmark.1.runs: 200
benchmark.1.mean: 67258.71954113245
benchmark.1.std: 43544.87626238569
benchmark.1.err: 3079.0877291062043
benchmark.1.best: 63423.9986538887
benchmark.1.worst: 679840.0282859802
benchmark.2.spec: m: 7168; k: 2048; l: 4; seed: 1111
benchmark.2.runs: 3
benchmark.2.mean: 26602.66620417436
benchmark.2.std: 36.95069858684305
benchmark.2.err: 21.333495775858562
benchmark.2.best: 26559.999212622643
benchmark.2.worst: 26623.99969995022
check: pass

We see that performance here is worse than for the first "extra blocks" variant but still better than the reference kernel. I suspect the hyperparameters can be tuned further to achieve better performance, but I haven't found the time to do that (yet). Naively I expected this approach to be faster, because we perform the atomic add onto SMEM rather than GMEM, and because we don't need to allocate a new F32 tensor in torch and copy it over on every launch.

Use different threads and no atomic add

We can write a version of the second kernel that doesn't use an atomic add. We do that by allocating a 2D tensor in SMEM, which lets us save the partial result for each (tidx, tidy) pair. We can then reduce over the second mode before writing out the result. Note that the SMEM tensor should be K-major in order to get a better memory access pattern in the reduction step. The kernel looks like this:
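The layout choice matters because with stride `(threads_per_k, 1)` the partials of one row sit contiguously, so the `tidy == 0` reduction reads consecutive addresses. A small sketch of the indexing and the reduction (NumPy, with made-up thread counts):

```python
import numpy as np

THREADS_M, THREADS_K = 4, 8
rng = np.random.default_rng(0)
partials = rng.random((THREADS_M, THREADS_K)).astype(np.float32)

# K-major layout: flat index = tidx * THREADS_K + tidy, so the THREADS_K
# partials belonging to one row are contiguous in shared memory.
flat = partials.ravel()
tidx, tidy = 2, 5
assert flat[tidx * THREADS_K + tidy] == partials[tidx, tidy]

# Each tidy == 0 thread sums one contiguous row of partials.
out = partials.sum(axis=1)
```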

# The CuTe reference implementation for NVFP4 block-scaled GEMV
@cute.kernel
def kernel(
    mA_mkl: cute.Tensor,
    mB_nkl: cute.Tensor,
    mSFA_mkl: cute.Tensor,
    mSFB_nkl: cute.Tensor,
    mC_mnl: cute.Tensor,
):
    # Get CUDA block and thread indices
    bidx, bidy, bidz = cute.arch.block_idx()
    tidx, tidy, _ = cute.arch.thread_idx()

    # Extract the local tile for input matrix A (shape: [block_M, block_K, rest_M, rest_K, rest_L])
    gA_mkl = cute.local_tile(
        mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # Extract the local tile for scale factor tensor for A (same shape as gA_mkl)
    # Here, block_M = (32, 4); block_K = (16, 4)
    gSFA_mkl = cute.local_tile(
        mSFA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None, None)
    )
    # Extract the local tile for input matrix B (shape: [block_N, block_K, rest_N, rest_K, rest_L])
    gB_nkl = cute.local_tile(
        mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # Extract the local tile for scale factor tensor for B (same shape as gB_nkl)
    gSFB_nkl = cute.local_tile(
        mSFB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None, None)
    )
    # Extract the local tile for output matrix C (shape: [block_M, block_N, rest_M, rest_N, rest_L])
    gC_mnl = cute.local_tile(
        mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None, None)
    )

    # Select output element corresponding to this thread and block indices
    tCgC = gC_mnl[tidx, None, bidx, bidy, bidz]
    tCgC = cute.make_tensor(tCgC.iterator, 1)
    res = cute.zeros_like(tCgC, accum_dtype)

    # Shared Memory
    allocator = cutlass.utils.SmemAllocator()
    smem_layout = cute.make_layout((threads_per_m, threads_per_k), stride = (threads_per_k, 1))
    shared_res = allocator.allocate_tensor(element_type=cutlass.Float32, layout=smem_layout)

    # Get the number of k tiles (depth dimension) for the reduction loop
    k_tile_cnt = gA_mkl.layout[3].shape
    for k_tile in range(tidy, k_tile_cnt, threads_per_k, unroll_full=True):
        tAgA = gA_mkl[tidx, None, bidx, k_tile, bidz]
        tBgB = gB_nkl[0, None, bidy, k_tile, bidz]
        tAgSFA = gSFA_mkl[tidx, None, bidx, k_tile, bidz]
        tBgSFB = gSFB_nkl[0, None, bidy, k_tile, bidz]

        tArA = cute.make_rmem_tensor_like(tAgA, c_dtype)
        tBrB = cute.make_rmem_tensor_like(tBgB, c_dtype)
        tABrAB = cute.make_rmem_tensor_like(tAgA, c_dtype)
        tArSFA = cute.make_rmem_tensor_like(tAgSFA, accum_dtype)
        tBrSFB = cute.make_rmem_tensor_like(tBgSFB, accum_dtype)
        tSFrSF = cute.make_rmem_tensor_like(tAgSFA, accum_dtype)

        # Load NVFP4 or FP8 values from global memory
        a_val_nvfp4 = tAgA.load()
        b_val_nvfp4 = tBgB.load()
        sfa_val_fp8 = tAgSFA.load()
        sfb_val_fp8 = tBgSFB.load()

        # Convert A/B to the compute dtype and the scale factors to the SF dtype
        a_val = a_val_nvfp4.to(c_dtype)
        b_val = b_val_nvfp4.to(c_dtype)
        sfa_val = sfa_val_fp8.to(sf_dtype)
        sfb_val = sfb_val_fp8.to(sf_dtype)

        # Store the converted values to RMEM CuTe tensors
        tArA.store(a_val)
        tBrB.store(b_val)
        tArSFA.store(sfa_val)
        tBrSFB.store(sfb_val)

        tABrAB.store(tArA.load() * tBrB.load())
        tSFrSF.store(tArSFA.load() * tBrSFB.load())

        # Iterate over SF vector tiles and compute the scale&matmul accumulation
        for i in cutlass.range_constexpr(mma_tiler_mnk[2]):
            res += tABrAB[i] * tSFrSF[i]
   
    shared_res[(tidx, tidy)] = res[0]
    cute.arch.sync_threads()
    
    if tidy == 0:
        out = cute.zeros_like(tCgC, accum_dtype)
        for i in cutlass.range_constexpr(threads_per_k):
            out += shared_res[(tidx, i)]

        # Store the final float16 result back to global memory
        tCgC.store(out.to(cutlass.Float16))
    return

The performance, when launched with the same parameters as the atomic version above, is as follows:

benchmark-count: 3
benchmark.0.spec: m: 7168; k: 16384; l: 1; seed: 1111
benchmark.0.runs: 3
benchmark.0.mean: 38911.99827194214
benchmark.0.std: 0.0
benchmark.0.err: 0.0
benchmark.0.best: 38911.99827194214
benchmark.0.worst: 38911.99827194214
benchmark.1.spec: m: 4096; k: 7168; l: 8; seed: 1111
benchmark.1.runs: 32
benchmark.1.mean: 65599.9998562038
benchmark.1.std: 362.1300239794551
benchmark.1.err: 64.01614890677993
benchmark.1.best: 65503.99959087372
benchmark.1.worst: 67584.00052785873
benchmark.2.spec: m: 7168; k: 2048; l: 4; seed: 1111
benchmark.2.runs: 3
benchmark.2.mean: 30719.999223947525
benchmark.2.std: 0.0
benchmark.2.err: 0.0
benchmark.2.best: 30719.999223947525
benchmark.2.worst: 30719.999223947525
check: pass

We see that it outperforms the reference kernel as well. As above, we can potentially improve performance further by tuning hyperparameters.

Next Steps

We can think of further improvements to the baseline. A first step could be to fine-tune the hyperparameters above, but we could also try to improve the memory access patterns or pipeline the loop. I hope this blog post helps others improve their kernels. I would be happy to connect via LinkedIn to exchange ideas.