PingPong 🏓 in the CuTeDSL with QuACK 🦆
Introduction
PingPong is an important pattern in GEMM design that we haven't covered so far. The idea is that we again have a persistent kernel but, differing from the approach we employed before, the two consumer warp groups perform their MMA calculation as well as the epilogue separately and in an overlapping manner. This differs from the simple persistent implementation because we need intra-consumer communication so that at any point in time only one consumer warpgroup performs the MMA and only one performs the epilogue. The QuACK library implements a variation of PingPong and we will analyse it here.
To reproduce the setup, we launch the kernel like this:
CUDA_VISIBLE_DEVICES=0 uv run dense_gemm_sm90.py \
--mnkl 4096,4096,4096,1 --tile_shape_mnk 128,208,64 \
--cluster_shape_mn 1,1 --a_dtype BFloat16 --b_dtype BFloat16 \
--d_dtype BFloat16 --acc_dtype Float32 \
--a_major k --b_major k --d_major n \
--warmup_iterations 100 --iterations 1000 \
--persistent --pingpong
__init__
We set the layout as follows:
atom_layout_m, atom_layout_n = 1, 1
self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)
Next we set the number of MMA warp groups. Note that for PingPong we have a total of 2 warpgroups, each of them operating on a (1, 1, 1) atom.
self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
For PingPong each MMA warp group works separately, that's why we don't multiply by self.mma_warp_groups as we would for an ordinary persistent kernel where the consumers cooperate on the same tile.
self.num_mma_threads = (
self.mma_warp_groups if not self.pingpong else 1
) * self.num_threads_per_warp_group
self.num_epi_threads = (
self.mma_warp_groups if not self.pingpong else 1
) * self.num_threads_per_warp_group
By default we don't use load_A_cpasync, so for the pingpong kernel we have one warp loading A and B. We use one further warp for the epilogue load. With two MMA warp groups (8 warps) the assignment is
0-7 -> MMA
8 -> LOAD A/B
9 -> LOAD EPI
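These indices follow directly from the code below; as a quick plain-Python check of the arithmetic for our configuration:
WARP_SIZE = 32
atom_layout_mnk = (1, 1, 1)
pingpong = True
load_A_cpasync = False

mma_warp_groups = (1 * 1 * 1) * (2 if pingpong else 1)     # -> 2
num_ab_load_warps = 1 if not load_A_cpasync else 2         # -> 1
ab_load_warp_id = mma_warp_groups * 4                      # -> 8
epi_load_warp_id = ab_load_warp_id + num_ab_load_warps     # -> 9
num_ab_load_threads = WARP_SIZE * num_ab_load_warps        # -> 32
print(mma_warp_groups, ab_load_warp_id, epi_load_warp_id)  # 2 8 9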
self.num_ab_load_warps = 1 if not self.load_A_cpasync else 2
self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
self.ab_load_warp_id = self.mma_warp_groups * 4
self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // self.num_mma_threads
if self.fp8_slow_accum:
regs_per_thread *= 2
if self.mma_warp_groups == 3:
self.num_regs_load, self.num_regs_mma = 32, 160
else:
heavy_register_pressure = regs_per_thread >= 208
self.num_regs_load, self.num_regs_mma = (
(40, 232) if not heavy_register_pressure else (24, 240)
)
Note that we have
tile_shape_mnk = (128, 208)
num_mma_threads = 128
regs_per_thread = 208
i.e. we put heavy pressure on the registers belonging to the MMA part. Note that the PTX docs state the restriction that self.num_regs_load and self.num_regs_mma must be in the range 24 to 256 (both inclusive) and must be a multiple of 8. The setting of (24, 240) is a common choice to achieve good performance.
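Working through the register arithmetic for our configuration (a quick sanity check in plain Python):
tile_m, tile_n = 128, 208
num_mma_threads = 128                                  # one warp group per tile for pingpong
regs_per_thread = tile_m * tile_n // num_mma_threads   # 208 accumulator values per thread
heavy_register_pressure = regs_per_thread >= 208       # True for this tile shape
num_regs_load, num_regs_mma = (40, 232) if not heavy_register_pressure else (24, 240)
print(regs_per_thread, num_regs_load, num_regs_mma)    # 208 24 240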
_setup_attributes
We first compute the stage counts:
self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
self.tile_shape_mnk,
self.epi_tile,
self.a_dtype,
self.b_dtype,
self.d_dtype,
self.c_dtype,
self.smem_capacity,
self.occupancy,
# epi_smem will reuse smem ab if not persistent.
overlap_sD_sA=not self.is_persistent,
)
A PingPong kernel is persistent and hence we don't overlap sD_sA. Let's look at _compute_stages to see what that means:
@staticmethod
def _compute_stages(
...
) -> Tuple[int, int, int]:
epi_stage = 2
if overlap_sD_sA:
epi_bytes = 0
else:
d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8
epi_bytes = d_bytes_per_stage * epi_stage
...
# Refine epilogue stages:
# Calculate remaining smem after allocating for A/B stages and reserved bytes
# Add remaining unused smem to epilogue
if not overlap_sD_sA:
epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // d_bytes_per_stage
return ab_stage, epi_stage, epi_c_stage
Note that the approach taken in the dense_gemm.py example from the CUTLASS repo (which is not persistent) is to reuse the SMEM allocated to A for the epilogue. From the code above we see that for the epilogue we now allocate its own part of the SMEM. Additionally we don't hardcode the number of epilogue stages but adjust it dynamically based on the number of bytes we need for one tile of the epilogue.
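As a rough illustration of that refinement (plain Python; the byte counts below are made-up placeholder values, not the ones QuACK actually computes for this configuration):
def refine_epi_stage(remaining_bytes, ab_bytes_per_stage, ab_stage, d_bytes_per_stage):
    epi_stage = 2
    # hand whatever SMEM is left after the A/B stages to the epilogue
    leftover = remaining_bytes - ab_bytes_per_stage * ab_stage
    return epi_stage + leftover // d_bytes_per_stage

# e.g. ~100 KiB left for A/B + D, 4 A/B stages of ~21 KiB each, 8 KiB per D stage
print(refine_epi_stage(100 * 1024, 21 * 1024, 4, 8 * 1024))  # 2 + 2 = 4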
__call__
As always this function is responsible for preparing the inputs and everything else that is used in the kernel. We will focus here on the differences from the previous GEMM kernels we've seen in the CuTeDSL.
The QuACK library implements its own TileScheduler. I will probably explain it in more detail in a later blogpost. Here we just note that we provide it with is_persistent for persistent tile scheduling.
problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (
mD.shape[2],
)
TileScheduler = StaticTileScheduler
tile_sched_args = TileSchedulerArguments(
problem_shape_ntile_mnl=problem_shape_ntile_mnl,
raster_order=RasterOrderOption.Heuristic,
group_size=8,
cluster_shape_mnk=self.cluster_shape_mnk,
is_persistent=self.is_persistent,
)
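For the 4096 x 4096 x 4096 problem and the (128, 208) tile from the launch command, the tile grid works out as follows (plain Python):
import math

M, N, L = 4096, 4096, 1
tile_m, tile_n = 128, 208
problem_shape_ntile_mnl = (math.ceil(M / tile_m), math.ceil(N / tile_n), L)
print(problem_shape_ntile_mnl)  # (32, 20, 1) -> 640 output tiles to schedule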
Note that we allocate space for sD. This differs from the previous examples, where we reused the SMEM from A for the epilogue, as explained above.
@cute.struct
class SharedStorage:
ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
sD: cute.struct.Align[
cute.struct.MemRange[self.d_dtype, epi_smem_size],
self.buffer_align_bytes,
]
sC: cute.struct.Align[
cute.struct.MemRange[
self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
],
self.buffer_align_bytes,
]
sA: cute.struct.Align[
cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
self.buffer_align_bytes,
]
sB: cute.struct.Align[
cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
self.buffer_align_bytes,
]
self.shared_storage = SharedStorage
kernel
We do the main setup logic as usual. I will not go into each piece specifically as I covered these parts (or slight variations of them) in past blogposts. As mentioned above, we don't reuse the SMEM of A for the epilogue in the persistent PingPong schedule.
sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
if const_expr(not self.is_persistent):
sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
else:
sD = storage.sD.get_tensor(
epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
)
Here we select the correct warps for our load stage. Note warpgroup_reg_dealloc
which will restrict our registers for the load warps. TMA
doesn't need heavy register usage.
if warp_idx >= self.ab_load_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
if const_expr(mC_mnl is not None):
epi_load_barrier = pipeline.NamedBarrier(
barrier_id=int(NamedBarrierGemm.EpilogueLoad),
num_threads=self.num_ab_load_threads + self.num_epi_load_threads,
)
else:
epi_load_barrier = None
if (
warp_idx >= self.ab_load_warp_id
and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
):
Here is the producer logic (which is similar to what we know already). The tile scheduler is responsible for providing the coordinates of the tile we work on.
while work_tile.is_valid_tile:
tile_coord_mnkl = work_tile.tile_idx
batch_idx = tile_coord_mnkl[3]
...
tile_scheduler.prefetch_next_work()
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()
# End of persistent scheduler loop
ab_pipeline.producer_tail(ab_producer_state)
This if clause selects the two warpgroups responsible for the MMA. We will now analyse this part in depth.
if warp_idx < self.ab_load_warp_id:
Similar to above we adjust the register count, though here of course we increase it instead of decreasing it. For the pingpong schedule both the first and the second consumer warpgroup get one warp marked as tma_warp.
cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
is_tma_warp = cutlass.Boolean(
(not self.pingpong and warp_idx == 0)
or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
)
We then set up our warpgroup layout.
tidx, _, _ = cute.arch.thread_idx()
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
if const_expr(self.pingpong):
tidx = tidx % self.num_threads_per_warp_group
warp_group_thread_layout = cute.make_layout(
self.mma_warp_groups if not self.pingpong else 1,
stride=self.num_threads_per_warp_group,
)
thr_mma = tiled_mma.get_slice(
warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
)
Note that for the pingpong we have
tidx -> tidx % 128
warp_group_thread_layout = 1:128
which emphasizes that we treat each warp group separately from the other one.
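A minimal illustration of this remapping in plain Python: threads from both warp groups end up with local ids 0..127 and both slice the tiled MMA at position 0.
num_threads_per_warp_group = 128
for tidx in (0, 127, 128, 255):                          # threads from WG0 and WG1
    warp_group_idx = tidx // num_threads_per_warp_group  # 0 or 1
    local_tidx = tidx % num_threads_per_warp_group       # pingpong remaps to 0..127
    mma_slice = 0                                        # warp_group_thread_layout(0)
    print(tidx, warp_group_idx, local_tidx, mma_slice)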
tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))
acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
acc = cute.make_fragment(acc_shape, self.acc_dtype)
if const_expr(self.fp8_slow_accum):
acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
else:
acc_slow = None
if const_expr(self.pingpong):
if warp_group_idx == 0:
# WG0 needs a start signal at the very beginning
self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
Let's take a look at pingpong_barrier_arrive:
def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
assert stage in ["mma", "epi"]
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
cute.arch.barrier_arrive(
barrier_id=int(barrier) + warp_group_idx,
number_of_threads=2 * self.num_threads_per_warp_group,
)
We use a simple enumeration here to obtain the appropriate barrier id:
class NamedBarrierGemm(enum.IntEnum):
Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
# For mainloop load warps to signal that the epilogue load warp can start.
# This is to avoid loading C too early, interfering with loading A and B.
EpilogueLoad = enum.auto()
MmaWG0 = enum.auto()
MmaWG1 = enum.auto()
EpiWG0 = enum.auto()
EpiWG1 = enum.auto()
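From this enum and the barrier_id = int(barrier) + warp_group_idx arithmetic in pingpong_barrier_arrive we can reconstruct which named barrier each warp group uses (a small standalone check):
import enum

class NamedBarrierGemm(enum.IntEnum):
    Epilogue = enum.auto()      # 1
    EpilogueLoad = enum.auto()  # 2
    MmaWG0 = enum.auto()        # 3
    MmaWG1 = enum.auto()        # 4
    EpiWG0 = enum.auto()        # 5
    EpiWG1 = enum.auto()        # 6

for warp_group_idx in (0, 1):
    mma_barrier = int(NamedBarrierGemm.MmaWG0) + warp_group_idx  # 3 / 4
    epi_barrier = int(NamedBarrierGemm.EpiWG0) + warp_group_idx  # 5 / 6
    print(warp_group_idx, mma_barrier, epi_barrier)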
We need the barrier_arrive for synchronisation between the two warp groups, i.e. we "signal the arrival of the executing thread", to quote the PTX docs.
Note that the first warp group will operate on a different tile than the second. That's why, for the second warp group, we initially call advance_to_next_work to get the next tile and additionally advance the ab_read_state by k_tile_cnt, i.e. the number of K tiles needed to consume one full output tile (the WGMMA instruction can only handle up to 16 elements in the K dimension at a time, so each K tile is processed in several WGMMA steps).
ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
tile_scheduler = TileSchedulerCls()
if const_expr(self.pingpong):
if warp_idx >= 4:
# Advance 2nd Math WG to the next work tile for the startup
tile_scheduler.advance_to_next_work()
# Advance 2nd Math WG pipeline states to the end of 1st Math WG
ab_read_state.advance_iters(k_tile_cnt)
work_tile = tile_scheduler.initial_work_tile_info()
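To see why this offset is needed, here is a toy model (plain Python, illustrative numbers) of the circular pipeline state: the producer fills k_tile_cnt stages for the first math warp group's tile before it starts on the second one's tile, so the second group's consumer state has to start at iteration k_tile_cnt to land on the matching buffer index and phase.
ab_stage = 4          # illustrative number of A/B pipeline stages
k_tile_cnt = 64       # 4096 / 64 K-tiles per output tile in our setup

def index_and_phase(iteration, num_stages):
    # circular buffer index plus a phase bit that flips on each wrap-around
    return iteration % num_stages, (iteration // num_stages) % 2

print(index_and_phase(0, ab_stage))            # where the 1st math WG starts
print(index_and_phase(k_tile_cnt, ab_stage))   # where the 2nd math WG starts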
Similar to above we schedule our tiles with the tile scheduler.
while work_tile.is_valid_tile:
We'll then execute the mma.
ab_read_state, tiled_mma = self.mma(
ab_pipeline,
ab_read_state,
tiled_mma,
tCrA,
tCrB,
acc,
acc_slow,
k_tile_cnt,
warp_group_idx,
)
Let's take a closer look at the underlying logic:
@cute.jit
def mma(
...
) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
# /////////////////////////////////////////////////////////////////////////////
# Prologue MMAs
# /////////////////////////////////////////////////////////////////////////////
k_pipe_mmas = 1
ab_release_state = ab_read_state.clone()
num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
if const_expr(self.pingpong):
self.pingpong_barrier_sync(warp_group_idx, stage="mma")
peek_ab_full_status = cutlass.Boolean(True)
if 0 < k_tile_cnt:
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
num_k_blocks = cute.size(tCrA, mode=[2])
# TODO: this is probably not correct if k_tile_cnt == 0
for k_tile in cutlass.range(num_prologue_mma):
# Wait for A/B buffer to be ready
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
warpgroup.fence()
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
warpgroup.commit_group()
ab_read_state.advance()
peek_ab_full_status = cutlass.Boolean(True)
if k_tile + 1 < k_tile_cnt:
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
if const_expr(self.fp8_slow_accum):
warpgroup.wait_group(0)
acc_slow.store(acc.load())
# /////////////////////////////////////////////////////////////////////////////
# MAINLOOP
# /////////////////////////////////////////////////////////////////////////////
for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
# Wait for TMA copies to complete
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
# WGMMA
warpgroup.fence()
if const_expr(self.fp8_slow_accum):
tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
warpgroup.commit_group()
# Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
if const_expr(not self.fp8_slow_accum):
warpgroup.wait_group(k_pipe_mmas)
else:
warpgroup.wait_group(0)
acc_slow.store(acc_slow.load() + acc.load())
ab_pipeline.consumer_release(ab_release_state)
ab_read_state.advance()
ab_release_state.advance()
peek_ab_full_status = cutlass.Boolean(True)
if k_tile + 1 < k_tile_cnt:
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
if const_expr(self.pingpong):
# Cue for next WG's MMA to start
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
if const_expr(not self.fp8_slow_accum):
# fp8_slow_accum would already called wait_group(0) inside the loop
warpgroup.wait_group(0)
for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
ab_pipeline.consumer_release(ab_release_state)
ab_release_state.advance()
if const_expr(self.fp8_slow_accum):
acc.store(acc_slow.load())
# If we don't return the tiled_mma, we get compiler error
# "operand #0 does not dominate this use"
return ab_read_state, tiled_mma
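Before we look at the synchronisation, here is a toy trace (plain Python) of the k_pipe_mmas = 1 software pipelining above: a buffer is only released after the WGMMAs of the following buffer have been committed, and the tail loop releases the remaining buffers after wait_group(0).
k_pipe_mmas, k_tile_cnt = 1, 5
committed, released = [], []
for k in range(k_tile_cnt):
    committed.append(k)                    # commit_group() for buffer k
    if k >= k_pipe_mmas:                   # mainloop releases with a lag of k_pipe_mmas
        released.append(k - k_pipe_mmas)   # ab_pipeline.consumer_release(ab_release_state)
released.extend(range(k_tile_cnt - k_pipe_mmas, k_tile_cnt))  # tail releases
print(committed)  # [0, 1, 2, 3, 4]
print(released)   # [0, 1, 2, 3, 4]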
Note that the mma function is highly similar to the logic we normally employ. The difference for the pingpong schedule is that we need to take care of the synchronisation between the two consumer warp groups. This is quite similar to the usual Producer/Consumer pattern we also employ in "normal" persistent kernels.
In persistent kernels we use wait
and arrive
on full
and empty
barriers to synchronise work between consumer and producer.
Here each warp group first syncs on its own MMA barrier; the offset added to the base barrier id is 0 for the first warp group and 1 for the second.
self.pingpong_barrier_sync(warp_group_idx, stage="mma")
After performing the MMA
we signal via
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
that we are done with the calculation. The first warp group signals the second warp group's barrier (offset 1) and the second signals the first's (offset 0). This ensures that there are not multiple MMAs going on at the same time: only one of the two warpgroups is busy computing at any point in time.
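To build intuition for this handshake, here is a minimal sketch that mimics the MMA ping-pong with Python threading semaphores; it only models the control flow (the real kernel uses hardware named barriers shared by 2 x 128 threads, not semaphores).
import threading

# two "go" signals, one per warp group; WG0 gets its initial arrive up front
mma_go = [threading.Semaphore(1), threading.Semaphore(0)]

def math_warp_group(wg, num_tiles=3):
    for t in range(num_tiles):
        mma_go[wg].acquire()        # ~ pingpong_barrier_sync(wg, stage="mma")
        print(f"WG{wg}: MMA on work tile {2 * t + wg}")
        mma_go[1 - wg].release()    # ~ pingpong_barrier_arrive(1 - wg, stage="mma")

threads = [threading.Thread(target=math_warp_group, args=(wg,)) for wg in (0, 1)]
for th in threads:
    th.start()
for th in threads:
    th.join()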
Note that one might argue that to be precise
warpgroup.wait_group(0)
should be executed before the arrive, to be sure none of our WGMMAs are still in flight when we arrive on the barrier of the other warpgroup.
After executing the mainloop we perform
if const_expr(self.pingpong):
# Update starting mainloop pipeline state for the next tile
ab_read_state.advance_iters(k_tile_cnt)
That reflects the alternating nature of pingpong, where the work tiles are processed in the pattern WG0 -> WG1 -> WG0 -> .... In between our tiles the other warp group consumes k_tile_cnt K tiles from the A/B pipeline, so we skip over those iterations in our ab_read_state as above.
After the mainloop comes the epilogue. Once a warp group is done with the MMA for its tile it performs the epilogue (i.e. the transfer from RMEM to GMEM via SMEM using TMA) while the other warp group is busy computing its MMA.
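A toy timeline of that overlap (plain Python, purely illustrative): while one warp group runs the epilogue for its tile, the other is already in the MMA mainloop for the next tile.
num_tiles = 4
for t in range(num_tiles):
    wg = t % 2                               # tiles alternate between the two math WGs
    prev = f"tile {t - 1}" if t > 0 else "nothing yet"
    print(f"step {t}: WG{wg} runs MMA on tile {t}, WG{1 - wg} runs the epilogue for {prev}")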
We first sync on the appropriate epilogue barrier.
# /////////////////////////////////////////////////////////////////////////////
# EPILOGUE
# /////////////////////////////////////////////////////////////////////////////
if const_expr(self.pingpong):
self.pingpong_barrier_sync(warp_group_idx, "epi")
What follows is very similar to the usual epilogue for Hopper, so I will not go over it line by line. However, for PingPong it is important that we do the following.
if const_expr(self.pingpong):
# Update starting load/store pipeline states for the next tile
epi_read_state.advance_iters(c_tile_cnt)
# With pingpong, 2 WGs write two different output tiles to the same smem,
# so we have to make sure the smem content is done reading before signalling
# the next WG's epilogue.
if warp_idx == 0 or warp_idx == 4:
cute.arch.cp_async_bulk_wait_group(0, read=True)
self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
- We need to update the epi_read_state by the number of tiles we process in one epilogue.
- Only warp_idx == 0 or warp_idx == 4 (i.e. the tma_warp of each warp group) executes cp_async_bulk_wait_group, the TMA sync.
- Finally we arrive on the other warp group's epilogue barrier to let it know that it can start its own epilogue.
That concludes the analysis of the PingPong
kernel.
Conclusion
I hope this blogpost helped to demystify PingPong scheduling in the CuTeDSL. If you'd like to exchange ideas I would be happy to connect on LinkedIn. The kernel I analysed here can be found in the QuACK repo here.