Consumer-Producer pattern on H100 in CuTeDSL
GPUs are often bound by memory transfer. As I explained in a past post, this makes it necessary to overlap memory transfer and computation to achieve peak performance.
In this blogpost I will focus on the PTX instructions around mbarrier, the concept needed to implement Consumer-Producer patterns from scratch on Hopper GPUs. I take an efficient RMSNorm backward kernel and analyse the code related to the communication between producer and consumer.
I assume the code is known, so make sure to take a look at RMSNormBackward before reading on.
Analysis
We initialise the mbar_ptr by calling
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
    smem, tv_layout, is_persistent=True
)
This reserves the appropriate section of SMEM. Here stage = 2 and is_persistent=True, so we have 4 elements in mbar_ptr.
mbar_ptr = smem.allocate_array(
    cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
)
We then identify the full_barrier and the empty_barrier using pointer arithmetic:
mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
Two elements are used for each of the two barriers: mbar_ptr[0] and mbar_ptr[1] back full_barrier[0] and full_barrier[1], while mbar_ptr[2] and mbar_ptr[3] back empty_barrier[0] and empty_barrier[1]. We will use these in the following way:
empty_barrier[s].wait() is used by the producer to decide whether index s in the buffer is ready to be filled. If so, the producer transfers data from GMEM to SMEM and signals via full_barrier[s].arrive() that it is done. Similarly, the consumer uses full_barrier[s].wait() to see whether it can perform the computation on the corresponding buffer element (in our case a reduction) and lets the producer know via empty_barrier[s].arrive() that it is done. We will see that the wait instruction additionally expects a phase argument: a single bit that tracks the parity of the barrier's completed phases, i.e. how many times we have gone through a full wait/arrive cycle.
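Putting the two sides next to each other, the handshake for one slot s looks roughly like this (a pseudocode sketch, not the actual QuACK code; copy_gmem_to_smem and reduce are placeholders):
# Producer side (pseudocode):
empty_barrier[s].wait(producer_phase)   # block until slot s is free
copy_gmem_to_smem(s)                    # fill slot s with the next tile
full_barrier[s].arrive()                # signal: slot s is now full

# Consumer side (pseudocode):
full_barrier[s].wait(consumer_phase)    # block until slot s is full
reduce(s)                               # consume the data in slot s
empty_barrier[s].arrive()               # signal: slot s is free again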
Let us see how this is implemented in QuACK
which uses the CuTeDSL
.
Before entering the loop we perform a prefetch of the first batch and commit
this operation.
# Prefetch the first batch
row = tXcX[None, None, None, bidx_start][0][0]
if row < M:
    ...
elif tiler_mn[0] > 1:
    # Fill with zero, otherwise smem will be uninitialized, and we could read this back
    # later into registers, causing wrong dW.
    ...
cute.arch.cp_async_commit_group()
if cutlass.const_expr(self.cluster_n > 1):
    cute.arch.cluster_wait()
We then initialise the stage and the phases. The stage tells us which buffer slot is currently being processed. Note that the producer phase gets initialised to 1 because we already performed the prefetch above.
stage = cutlass.Int32(0)
producer_phase = cutlass.Int32(1)
consumer_phase = cutlass.Int32(0)
We then loop in a fashion that allows one block to process multiple tiles when the grid does not cover the whole row dimension of the input matrix. In CUDA this is known as a grid-stride loop.
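Stripped of the CuTeDSL details, the pattern looks like this (plain-Python illustration; num_tiles and process_tile are placeholders):
# Each block starts at its own index and jumps ahead by the grid size,
# so a fixed number of blocks can cover an arbitrary number of tiles.
for bidx in range(bidx_start, num_tiles, gdim):
    process_tile(bidx)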
for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
At the beginning of each loop we prefetch the next batch (stage ^ 1
):
if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
    ...
elif tiler_mn[0] > 1:
    ...
cute.arch.cp_async_commit_group()
if row < M or tiler_mn[0] == 1:
    rstd = mRstd[row]
cute.arch.cp_async_wait_group(1)
Note that we call cute.arch.cp_async_wait_group(1). The argument 1 means we wait until at most one committed cp_async group is still pending. This makes sense because the copy of the current batch must be finished here, while the copy of the next batch may still be in flight.
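As a minimal sketch of the commit/wait pairing (copy_batch stands in for the actual cp.async copies of one tile):
copy_batch(0)                         # issue async copies for the current batch
cute.arch.cp_async_commit_group()     # -> group A
copy_batch(1)                         # issue async copies for the next batch
cute.arch.cp_async_commit_group()     # -> group B
cute.arch.cp_async_wait_group(1)      # returns once at most 1 group is pending:
                                      # group A is done, group B may still run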
We then wait on empty_barrier[stage] for the current stage, i.e. the producer checks whether the consumer has released the current buffer slot so that it is ready to be filled again.
if cutlass.const_expr(self.cluster_n > 1):
    cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
We then call the row_reduce function and provide it with the full_barrier.
mean_xhat_wdy = (
    utils.row_reduce(
        ...
        (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
        phase=consumer_phase,
        ...
    )
    / shape[1]
)
Inside it, cluster_reduce is called (in the case we consider here, i.e. when a cluster is used).
if warp_idx == 0:
    with cute.arch.elect_one():
        num_warps = rows_per_block * warps_per_row
        cute.arch.mbarrier_arrive_and_expect_tx(
            mbar_ptr,
            num_warps * cluster_n * reduction_buffer.element_type.width // 8,
        )
That is, the producer arrives on full_barrier[s] and declares the number of bytes it expects to be transferred (determined by the shape of the reduction buffer).
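To make the byte count concrete: with hypothetical values num_warps = 4, cluster_n = 2 and Float32 entries (width = 32 bits), this evaluates to 4 * 2 * 32 // 8 = 32 bytes.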
We then call wait on full_barrier[s], i.e. the consumer checks with its phase whether the tile is ready to be processed, and if so performs the reduction.
cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
block_reduce_val = init_val
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
for i in cutlass.range_constexpr(num_iter):
    idx = lane_idx + i * cute.arch.WARP_SIZE
    if idx < cute.size(reduction_buffer, mode=[1]):
        block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
return warp_reduce(block_reduce_val, op)
Once done, we arrive on empty_barrier[stage] to signal completion of the reduction op:
if cutlass.const_expr(self.cluster_n > 1):
    # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
    # Requires adjusting the thread_count when initializing the mbar
    cute.arch.sync_warp()
    lane_idx = cute.arch.lane_idx()
    if lane_idx < self.cluster_n:
        cute.arch.mbarrier_arrive(
            mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
        )
We then update the stage. Once we have gone through a full cycle over all stages (a cycle of length 2 in our case), we flip the phases as well.
stage ^= 1
if stage == 0:
    consumer_phase ^= 1
    producer_phase ^= 1
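A tiny plain-Python simulation (illustration only, same update rule as above) shows how the bookkeeping evolves over the first iterations:
stage, producer_phase, consumer_phase = 0, 1, 0
for it in range(4):
    print(it, stage, producer_phase, consumer_phase)
    stage ^= 1
    if stage == 0:
        consumer_phase ^= 1
        producer_phase ^= 1
# prints: 0 0 1 0 / 1 1 1 0 / 2 0 0 1 / 3 1 0 1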
To prevent the cluster from exiting early, we perform a final wait on empty_barrier[stage]. Without it, a CTA could finish and invalidate its SMEM barriers while peer CTAs in the cluster still need to arrive on them.
if cutlass.const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
    cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
Conclusion
I hope this blogpost made the mbarrier
instructions more accessible and helps in understanding the Backward Kernel for RMSNorm
in QuACK
.