Let the compiler do the work in CuTeDSL
To achieve peak performance on H100 for matrix multiplication (GEMM), we need to prefetch matrix tiles when we write our kernels in a non-persistent way.
One approach, which I explained in a previous blog post and which the official CuTeDSL Hopper example also takes, is to write the prefetching logic manually in the prologue of the mainloop.
There is another way, which uses an experimental feature of the CuTeDSL
and lets the compiler do this work for us. This results in less code for the user to write. However, keep in mind that the performance is slightly lower than with the manual approach and that we have less control this way.
Here I will explain the experimental `pipelining` argument to `cutlass.range`.
The usual way
This is the traditional way that I already described in the above-mentioned blog post. We define a number that determines how many tiles to prefetch, and schedule the corresponding tiles to be transferred. Afterwards we perform a prologue MMA. It runs with accumulation disabled for its first k-block, so it overwrites, rather than accumulates into, the registers we write our result to. We then enter the mainloop, which performs computation and memory transfer in a cyclical way, synchronizing via the usual pipeline operations.
# Prefetching mainloop, adapted from the official CuTeDSL Hopper GEMM example.
# This is an excerpt of a kernel method body: warp_idx, k_tile_cnt,
# prefetch_k_tile_cnt, tiled_mma, accumulators, the TMA atoms, the multicast
# masks and the partitioned tensors (tAgA_mkl, tAsA, tBgB_nkl, tBsB, tCrA,
# tCrB) are defined outside this excerpt.
mainloop_producer_state = pipeline.make_pipeline_state(
    pipeline.PipelineUserType.Producer, self.ab_stage
)
if warp_idx == 0:
    # /////////////////////////////////////////////////////////////////////////////
    # Prefetch TMA load
    # /////////////////////////////////////////////////////////////////////////////
    # Warp 0 issues the first prefetch_k_tile_cnt tile loads before any MMA runs.
    for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1):
        # Wait for A/B buffers to be empty before loading into them.
        # Also sets the transaction barrier for the A/B buffers.
        mainloop_pipeline.producer_acquire(mainloop_producer_state)
        # Slice global/shared memrefs to the current k_tile / pipeline stage.
        tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)]
        tAsA_pipe = tAsA[(None, mainloop_producer_state.index)]
        tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)]
        tBsB_pipe = tBsB[(None, mainloop_producer_state.index)]
        # TMA load A/B, signalling completion on this stage's barrier.
        cute.copy(
            tma_atom_a,
            tAgA_k,
            tAsA_pipe,
            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                mainloop_producer_state
            ),
            mcast_mask=a_mcast_mask,
        )
        cute.copy(
            tma_atom_b,
            tBgB_k,
            tBsB_pipe,
            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                mainloop_producer_state
            ),
            mcast_mask=b_mcast_mask,
        )
        # Mainloop pipeline's producer commit is a NOP
        mainloop_pipeline.producer_commit(mainloop_producer_state)
        mainloop_producer_state.advance()

# /////////////////////////////////////////////////////////////////////////////
# Prologue MMAs
# /////////////////////////////////////////////////////////////////////////////
# Number of committed WGMMA groups the mainloop leaves in flight
# (see wait_group(k_pipe_mmas) below).
k_pipe_mmas = 1
mainloop_consumer_read_state = pipeline.make_pipeline_state(
    pipeline.PipelineUserType.Consumer, self.ab_stage
)
mainloop_consumer_release_state = pipeline.make_pipeline_state(
    pipeline.PipelineUserType.Consumer, self.ab_stage
)
# Peek at whether the next tile has landed; the result is handed to
# consumer_wait. Only peek while tiles remain.
peek_ab_full_status = cutlass.Boolean(1)
if mainloop_consumer_read_state.count < k_tile_cnt:
    peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
        mainloop_consumer_read_state
    )
# The very first gemm overwrites the accumulators instead of adding to them.
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_tile in cutlass.range(k_pipe_mmas, unroll_full=True):
    # Wait for A/B buffer to be ready
    mainloop_pipeline.consumer_wait(
        mainloop_consumer_read_state, peek_ab_full_status
    )
    cute.nvgpu.warpgroup.fence()
    for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
        k_block_coord = (
            None,
            None,
            k_block_idx,
            mainloop_consumer_read_state.index,
        )
        tCrA_1phase = tCrA[k_block_coord]
        tCrB_1phase = tCrB[k_block_coord]
        cute.gemm(
            tiled_mma,
            accumulators,
            tCrA_1phase,
            tCrB_1phase,
            accumulators,
        )
        # Must stay inside the k-block loop: only the first gemm may overwrite.
        tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
    cute.nvgpu.warpgroup.commit_group()
    mainloop_consumer_read_state.advance()
    peek_ab_full_status = cutlass.Boolean(1)
    if mainloop_consumer_read_state.count < k_tile_cnt:
        peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
            mainloop_consumer_read_state
        )

# /////////////////////////////////////////////////////////////////////////////
# MAINLOOP
# /////////////////////////////////////////////////////////////////////////////
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
    # Wait for TMA copies to complete
    mainloop_pipeline.consumer_wait(
        mainloop_consumer_read_state, peek_ab_full_status
    )
    # WGMMA
    cute.nvgpu.warpgroup.fence()
    for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
        k_block_coord = (
            None,
            None,
            k_block_idx,
            mainloop_consumer_read_state.index,
        )
        tCrA_1phase = tCrA[k_block_coord]
        tCrB_1phase = tCrB[k_block_coord]
        cute.gemm(
            tiled_mma,
            accumulators,
            tCrA_1phase,
            tCrB_1phase,
            accumulators,
        )
    cute.nvgpu.warpgroup.commit_group()
    # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete;
    # only then is the stage tracked by the release state safe to free.
    cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
    mainloop_pipeline.consumer_release(mainloop_consumer_release_state)
    mainloop_consumer_read_state.advance()
    mainloop_consumer_release_state.advance()
    peek_ab_full_status = cutlass.Boolean(1)
    if mainloop_consumer_read_state.count < k_tile_cnt:
        peek_ab_full_status = mainloop_pipeline.consumer_try_wait(
            mainloop_consumer_read_state
        )
    # TMA load: refill the freed stage with the next outstanding tile.
    if warp_idx == 0 and mainloop_producer_state.count < k_tile_cnt:
        # Wait for A/B buffers to be empty before loading into them.
        # Also sets the transaction barrier for the A/B buffers.
        mainloop_pipeline.producer_acquire(mainloop_producer_state)
        # Slice global/shared memrefs to the current k_tile / pipeline stage.
        tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)]
        tAsA_pipe = tAsA[(None, mainloop_producer_state.index)]
        tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)]
        tBsB_pipe = tBsB[(None, mainloop_producer_state.index)]
        # TMA load A/B
        cute.copy(
            tma_atom_a,
            tAgA_k,
            tAsA_pipe,
            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                mainloop_producer_state
            ),
            mcast_mask=a_mcast_mask,
        )
        cute.copy(
            tma_atom_b,
            tBgB_k,
            tBsB_pipe,
            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                mainloop_producer_state
            ),
            mcast_mask=b_mcast_mask,
        )
        # Mainloop pipeline's producer commit is a NOP
        mainloop_pipeline.producer_commit(mainloop_producer_state)
        mainloop_producer_state.advance()

# wait_group(k_pipe_mmas) above leaves up to k_pipe_mmas WGMMA groups in
# flight; drain them before anything reads the accumulators.
cute.nvgpu.warpgroup.wait_group(0)
One step back: GEMM without prefetch
It is nice that the CUTLASS
developers already provide a baseline for us that employs prefetching. However, we could also ask how to write a naive GEMM. By naive I mean simply performing Copy -> Compute -> Copy -> ...
without any prefetching.
We can write that as follows:
# /////////////////////////////////////////////////////////////////////////////
# Setup
# /////////////////////////////////////////////////////////////////////////////
# Excerpt of a kernel method body; warp_idx, gA_mkl, tiled_mma, accumulators,
# the TMA atoms, multicast masks and partitioned tensors are defined elsewhere.
k_tile_cnt = cute.size(gA_mkl, mode=[2])
mainloop_producer_state = pipeline.make_pipeline_state(
    pipeline.PipelineUserType.Producer, self.ab_stage
)
mainloop_consumer_state = pipeline.make_pipeline_state(
    pipeline.PipelineUserType.Consumer, self.ab_stage
)
# The very first gemm overwrites the accumulators instead of adding to them.
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
num_k_blocks = cute.size(tCrA, mode=[2])
# /////////////////////////////////////////////////////////////////////////////
# MAINLOOP
# /////////////////////////////////////////////////////////////////////////////
# No prefetching: each iteration issues the load for its own k_tile and then
# immediately waits for it before computing.
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
    # TMA load
    if warp_idx == 0:
        # Wait for A/B buffers to be empty before loading into them.
        # Also sets the transaction barrier for the A/B buffers.
        mainloop_pipeline.producer_acquire(mainloop_producer_state)
        # Slice global/shared memrefs to the current k_tile / pipeline stage.
        tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)]
        tAsA_pipe = tAsA[(None, mainloop_producer_state.index)]
        tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)]
        tBsB_pipe = tBsB[(None, mainloop_producer_state.index)]
        # TMA load A/B
        cute.copy(
            tma_atom_a,
            tAgA_k,
            tAsA_pipe,
            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                mainloop_producer_state
            ),
            mcast_mask=a_mcast_mask,
        )
        cute.copy(
            tma_atom_b,
            tBgB_k,
            tBsB_pipe,
            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                mainloop_producer_state
            ),
            mcast_mask=b_mcast_mask,
        )
        # Mainloop pipeline's producer commit is a NOP
        mainloop_pipeline.producer_commit(mainloop_producer_state)
        mainloop_producer_state.advance()
    # Wait for TMA copies to complete
    mainloop_pipeline.consumer_wait(mainloop_consumer_state)
    # WGMMA
    cute.nvgpu.warpgroup.fence()
    for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
        k_block_coord = (
            None,
            None,
            k_block_idx,
            mainloop_consumer_state.index,
        )
        tCrA_1phase = tCrA[k_block_coord]
        tCrB_1phase = tCrB[k_block_coord]
        cute.gemm(
            tiled_mma,
            accumulators,
            tCrA_1phase,
            tCrB_1phase,
            accumulators,
        )
        # Must stay inside the k-block loop: only the first gemm may overwrite.
        tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
    cute.nvgpu.warpgroup.commit_group()
    # Wait for ALL outstanding wgmmas before the stage is handed back.
    cute.nvgpu.warpgroup.wait_group(0)
    mainloop_pipeline.consumer_release(mainloop_consumer_state)
    mainloop_consumer_state.advance()
    # No consumer_try_wait here: its result would be unused, and the next
    # tile's load has not been issued yet, so it could only stall.
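Note that this code is simpler than the code above. However, its performance is far worse, because each tile is loaded only when it is about to be consumed, so we always wait for memory before computing.
The kernel with prefetching achieves 792 TFLOPs,
while the naive version only achieves 308 TFLOPs.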
Compiler to the rescue
What if we could give the compiler a hint to generate the prefetching for us? It turns out there is an experimental feature in the CuTeDSL
that allows us to do that.
For that, we first need to implement a minimal PipelineState:
# /////////////////////////////////////////////////////////////////////////////
# Simple Pipeline
# /////////////////////////////////////////////////////////////////////////////
class PipelineStateMinimal:
    """Plain snapshot of a position in the circular pipeline buffer.

    Unlike a pipeline state that is stepped forward with ``advance()``, this
    object has no update logic: callers rebuild it for every loop iteration.

    Attributes:
        count: Running tile counter (the k-tile index in the mainloop).
        index: Slot in the circular buffer that this tile uses.
        phase: Barrier phase bit associated with that slot.
    """

    def __init__(self, count, index, phase):
        self.count = count
        self.index = index
        self.phase = phase

    def __repr__(self):
        return (
            f"{type(self).__name__}(count={self.count!r}, "
            f"index={self.index!r}, phase={self.phase!r})"
        )
We then need to make a minimal adjustment to our naive code, as follows:
# /////////////////////////////////////////////////////////////////////////////
# Setup
# /////////////////////////////////////////////////////////////////////////////
# Excerpt of a kernel method body; warp_idx, gA_mkl, tiled_mma, accumulators,
# the TMA atoms, multicast masks and partitioned tensors are defined elsewhere.
k_tile_cnt = cute.size(gA_mkl, mode=[2])
# The very first gemm overwrites the accumulators instead of adding to them.
tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
num_k_blocks = cute.size(tCrA, mode=[2])
# /////////////////////////////////////////////////////////////////////////////
# MAINLOOP
# /////////////////////////////////////////////////////////////////////////////
# Same copy -> compute body as the naive version, but the pipeline states are
# recomputed from k_tile in every iteration instead of being advanced, and the
# experimental `pipelining` argument is the hint that lets the compiler
# generate the prefetching.
for k_tile in cutlass.range(
    k_tile_cnt,
    pipelining=self.ab_stage - 1,
):
    # Consumer state: slot k_tile % ab_stage; phase flips every ab_stage tiles.
    mainloop_consumer_state = PipelineStateMinimal(
        k_tile,
        k_tile % self.ab_stage,
        cutlass.Int32((k_tile // self.ab_stage) % 2),
    )
    # TMA load
    if warp_idx == 0:
        # Producer state: same slot as the consumer, inverted phase bit.
        mainloop_producer_state = PipelineStateMinimal(
            k_tile,
            k_tile % self.ab_stage,
            cutlass.Int32((k_tile // self.ab_stage) % 2) ^ 1,
        )
        # Wait for A/B buffers to be empty before loading into them.
        # Also sets the transaction barrier for the A/B buffers.
        mainloop_pipeline.producer_acquire(mainloop_producer_state)
        # Slice global/shared memrefs to the current k_tile / pipeline stage.
        tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)]
        tAsA_pipe = tAsA[(None, mainloop_producer_state.index)]
        tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)]
        tBsB_pipe = tBsB[(None, mainloop_producer_state.index)]
        # TMA load A/B
        cute.copy(
            tma_atom_a,
            tAgA_k,
            tAsA_pipe,
            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                mainloop_producer_state
            ),
            mcast_mask=a_mcast_mask,
        )
        cute.copy(
            tma_atom_b,
            tBgB_k,
            tBsB_pipe,
            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                mainloop_producer_state
            ),
            mcast_mask=b_mcast_mask,
        )
        # Mainloop pipeline's producer commit is a NOP
        mainloop_pipeline.producer_commit(mainloop_producer_state)
    # Wait for TMA copies to complete
    mainloop_pipeline.consumer_wait(mainloop_consumer_state)
    # WGMMA
    cute.nvgpu.warpgroup.fence()
    for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
        k_block_coord = (
            None,
            None,
            k_block_idx,
            mainloop_consumer_state.index,
        )
        tCrA_1phase = tCrA[k_block_coord]
        tCrB_1phase = tCrB[k_block_coord]
        cute.gemm(
            tiled_mma,
            accumulators,
            tCrA_1phase,
            tCrB_1phase,
            accumulators,
        )
        # Must stay inside the k-block loop: only the first gemm may overwrite.
        tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
    cute.nvgpu.warpgroup.commit_group()
    # Wait for ALL outstanding wgmmas before the stage is handed back.
    cute.nvgpu.warpgroup.wait_group(0)
    mainloop_pipeline.consumer_release(mainloop_consumer_state)
    # No consumer_try_wait here: the state is rebuilt next iteration, so a
    # peek on this (already consumed) state would only be discarded.
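Here we manually compute the pipeline state in each iteration. Instead of advancing a mutable state object, we derive the count, the buffer index (`k_tile % self.ab_stage`) and the phase bit (`(k_tile // self.ab_stage) % 2`, inverted for the producer) directly from the loop index. We additionally pass the `pipelining` argument to `cutlass.range`. That gives the compiler a hint that we want it to perform the prefetching.
The corresponding performance is very close to the manual version above: 784 TFLOPs.
That means we more than doubled the performance compared to the naive version.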
Conclusion
I hope this blog post showed the flexibility one has when writing CuTeDSL
kernels. You can find all the code in my GitHub repo. I am happy to discuss topics in GPU programming and MLSys in general on LinkedIn.