
PingPong 🏓 in the CuTeDSL with QuACK 🦆

Introduction

PingPong is an important pattern in GEMM design that we haven't covered so far. The idea is that we still have a persistent kernel, but differing from the approach we employed before, the two consumer warp groups perform their MMA calculations and their epilogues separately and in an overlapping manner. This pattern differs from the simple persistent implementation because it requires intra-consumer communication so that at any point in time only one consumer warpgroup performs the MMA while the other performs the epilogue. The QuACK library implements a variation of PingPong and we will analyse it here.
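
To get an intuition for the resulting schedule, here is a purely illustrative sketch (pingpong_schedule is a made-up helper, not part of QuACK) of how the two warp groups alternate between MMA and epilogue:

def pingpong_schedule(num_tiles: int):
	"""Yield (step, warp_group, phase, tile) for an idealised pingpong schedule."""
	for step in range(num_tiles):
		wg = step % 2                      # WG0 takes even tiles, WG1 odd tiles
		yield (step, wg, "mma", step)      # this WG occupies the tensor cores
		if step > 0:
			# ... while the other WG runs the epilogue of the tile it just finished
			yield (step, 1 - wg, "epilogue", step - 1)

for step, wg, phase, tile in pingpong_schedule(6):
	print(f"step {step}: WG{wg} {phase:8s} tile {tile}")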

To reproduce the setup we used, launch the kernel like this:

CUDA_VISIBLE_DEVICES=0 uv run dense_gemm_sm90.py                                   \
      --mnkl 4096,4096,4096,1 --tile_shape_mnk 128,208,64                  \
      --cluster_shape_mn 1,1 --a_dtype BFloat16 --b_dtype BFloat16           \
      --d_dtype BFloat16 --acc_dtype Float32                                \
      --a_major k --b_major k --d_major n \
      --warmup_iterations 100 --iterations 1000 \
      --persistent --pingpong

__init__

We set the layout as follows:

atom_layout_m, atom_layout_n = 1, 1
self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1)

We set the number of MMA warp groups. Note that for PingPong we have a total of 2 warpgroups, each of them operating on a (1, 1, 1) atom.

self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)

For PingPong each MMA warp group operates separately; that's why we don't multiply by self.mma_warp_groups, which we would do for an ordinary persistent kernel where the consumers cooperate on one tile.

self.num_mma_threads = (
	self.mma_warp_groups if not self.pingpong else 1
) * self.num_threads_per_warp_group
self.num_epi_threads = (
	self.mma_warp_groups if not self.pingpong else 1
) * self.num_threads_per_warp_group
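
Plugging in our launch configuration (atom_layout_mnk = (1, 1, 1), pingpong enabled, 128 threads per warp group) gives, as a quick check:

num_threads_per_warp_group = 128
mma_warp_groups = 1 * 2                            # = 2
num_mma_threads = 1 * num_threads_per_warp_group   # = 128 (per warp group)
num_epi_threads = 1 * num_threads_per_warp_group   # = 128 (per warp group)
# A cooperative persistent kernel would instead use 2 * 128 = 256 threads on one tile.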

By default we don't use load_A_cpasync, so for the pingpong kernel we have one warp loading A and B, plus one further warp for the epilogue load. We assign the warps as follows:

0-7 -> MMA (2 warpgroups)
8   -> LOAD A/B
9   -> LOAD EPI
self.num_ab_load_warps = 1 if not self.load_A_cpasync else 2
self.num_ab_load_threads = cute.arch.WARP_SIZE * self.num_ab_load_warps
self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
self.ab_load_warp_id = self.mma_warp_groups * 4
self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
regs_per_thread = math.prod(self.tile_shape_mnk[:2]) // self.num_mma_threads
if self.fp8_slow_accum:
	regs_per_thread *= 2
if self.mma_warp_groups == 3:
	self.num_regs_load, self.num_regs_mma = 32, 160
else:
	heavy_register_pressure = regs_per_thread >= 208
	self.num_regs_load, self.num_regs_mma = (
		(40, 232) if not heavy_register_pressure else (24, 240)
	)

Note that we have

tile_shape_mnk =  (128, 208, 64)
num_mma_threads =  128
regs_per_thread =  208

i.e. we put heavy pressure on the registers belonging to the MMA part. Note that the PTX docs state the restriction that self.num_regs_load and self.num_regs_mma must be in the range 24 to 256 (both inclusive) and must be a multiple of 8. The setting of (24, 240) is common to achieve good performance.
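
As a quick sanity check with our launch configuration (a minimal sketch of the arithmetic above):

regs_per_thread = 128 * 208 // 128                 # = 208 accumulator values per thread
heavy_register_pressure = regs_per_thread >= 208   # True
num_regs_load, num_regs_mma = 24, 240
# Both values satisfy the constraint quoted above from the PTX docs:
assert all(24 <= r <= 256 and r % 8 == 0 for r in (num_regs_load, num_regs_mma))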

_setup_attributes


self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages(
	self.tile_shape_mnk,
	self.epi_tile,
	self.a_dtype,
	self.b_dtype,
	self.d_dtype,
	self.c_dtype,
	self.smem_capacity,
	self.occupancy,
	# epi_smem will reuse smem ab if not persistent.
	overlap_sD_sA=not self.is_persistent,
)

A PingPong kernel is persistent and hence we don't overlap sD_sA. Let's see in _compute_stages what that means:

@staticmethod
def _compute_stages(
	...
) -> Tuple[int, int, int]:

	epi_stage = 2
	if overlap_sD_sA:
		epi_bytes = 0
	else:
		d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8
		epi_bytes = d_bytes_per_stage * epi_stage
	...
	# Refine epilogue stages:
	# Calculate remaining smem after allocating for A/B stages and reserved bytes
	# Add remaining unused smem to epilogue
	if not overlap_sD_sA:
		epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // d_bytes_per_stage
	return ab_stage, epi_stage, epi_c_stage

Note that the approach taken in the dense_gemm.py example from the CUTLASS repo (which is not persistent) is to reuse the SMEM allocated to A for the epilogue. From the code above it's clear that here the epilogue gets its own portion of SMEM. Additionally, we don't hardcode the number of epilogue stages but adjust it dynamically based on the number of bytes we need for one epilogue tile.
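
To make the refinement concrete, here is a rough back-of-the-envelope sketch with made-up round numbers; the real _compute_stages additionally accounts for the C stages, mbarriers and other reserved bytes:

ab_bytes_per_stage = (128 * 64 + 208 * 64) * 2                 # bf16 A + B tile = 43008 bytes
d_bytes_per_stage = 8 * 1024                                   # hypothetical epilogue D tile
smem_for_mainloop = 190 * 1024                                 # hypothetical budget after reservations

ab_stage = smem_for_mainloop // ab_bytes_per_stage             # = 4
leftover = smem_for_mainloop - ab_stage * ab_bytes_per_stage   # = 22528 bytes
epi_stage = 2 + leftover // d_bytes_per_stage                  # = 2 + 2 = 4 epilogue stages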

__call__

As always this function is responsible for preparing the inputs and everything else that is used in the kernel. We will focus here on the differences from the previous GEMM kernels we've seen in the CuTeDSL.

The QuACK library implements its own TileScheduler. I will probably explain it in more detail in a later blogpost. Here we note that we provide it with is_persistent for persistent tile scheduling.

problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (
	mD.shape[2],
)
TileScheduler = StaticTileScheduler
tile_sched_args = TileSchedulerArguments(
	problem_shape_ntile_mnl=problem_shape_ntile_mnl,
	raster_order=RasterOrderOption.Heuristic,
	group_size=8,
	cluster_shape_mnk=self.cluster_shape_mnk,
	is_persistent=self.is_persistent,
)

Note that we allocate space for sD. This is a difference from the previous examples, where we reused the SMEM from A for the epilogue, as explained above.

@cute.struct
class SharedStorage:
	ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
	epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
	sD: cute.struct.Align[
		cute.struct.MemRange[self.d_dtype, epi_smem_size],
		self.buffer_align_bytes,
	]
	sC: cute.struct.Align[
		cute.struct.MemRange[
			self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size
		],
		self.buffer_align_bytes,
	]
	sA: cute.struct.Align[
		cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
		self.buffer_align_bytes,
	]
	sB: cute.struct.Align[
		cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
		self.buffer_align_bytes,
	]

self.shared_storage = SharedStorage

kernel

We do the main setup logic as usual. I will not go into each piece specifically as I covered these parts (or slight variations of them) in past blogposts.

As mentioned above we don't reuse A for the persistent PingPong schedule.

sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
if const_expr(not self.is_persistent):
	sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout_staged.inner, dtype=self.d_dtype)
	sD = cute.make_tensor(sD_ptr, epi_smem_layout_staged.outer)
else:
	sD = storage.sD.get_tensor(
		epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner
	)

Here we select the correct warps for our load stage. Note warpgroup_reg_dealloc, which restricts the register count of the load warps; TMA doesn't need heavy register usage.

if warp_idx >= self.ab_load_warp_id:
	cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
	if const_expr(mC_mnl is not None):
		epi_load_barrier = pipeline.NamedBarrier(
			barrier_id=int(NamedBarrierGemm.EpilogueLoad),
			num_threads=self.num_ab_load_threads + self.num_epi_load_threads,
		)
	else:
		epi_load_barrier = None
	if (
		warp_idx >= self.ab_load_warp_id
		and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
	):

Here is the producer logic (which is similar to what we know already). The tile scheduler is responsible for providing the coordinates of the tile we work on.

while work_tile.is_valid_tile:
	tile_coord_mnkl = work_tile.tile_idx
	batch_idx = tile_coord_mnkl[3]
	...
	tile_scheduler.prefetch_next_work()
	tile_scheduler.advance_to_next_work()
	work_tile = tile_scheduler.get_current_work()
	# End of persistent scheduler loop
ab_pipeline.producer_tail(ab_producer_state)

This if clause selects the two warpgroups responsible for the MMA. We will now analyse this part in depth.

if warp_idx < self.ab_load_warp_id:

Similar to above, we adjust the register count; note that here we of course increase it instead of decreasing it. For the pingpong schedule both the first and second consumer warpgroup get one warp marked as tma_warp (warp 0 and warp 4 respectively).

cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
is_tma_warp = cutlass.Boolean(
	(not self.pingpong and warp_idx == 0)
	or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
)

We'll set up our warp group layout.

tidx, _, _ = cute.arch.thread_idx()
warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
if const_expr(self.pingpong):
	tidx = tidx % self.num_threads_per_warp_group
warp_group_thread_layout = cute.make_layout(
	self.mma_warp_groups if not self.pingpong else 1,
	stride=self.num_threads_per_warp_group,
)
thr_mma = tiled_mma.get_slice(
	warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
)

Note that for the pingpong we have

tidx -> tidx % 128
warp_group_thread_layout = 1 : 128

which emphasises that we treat each warp group separately from the other one.
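
For example, take a thread in the second warp group (an illustrative calculation based on the lines above):

tidx = 200
warp_group_idx = tidx // 128          # = 1 (second warp group)
tidx = tidx % 128                     # = 72
# Both warp groups slice the tiled MMA at offset 0, since with pingpong each
# warp group is the only MMA participant for its own tile:
mma_slice_offset = 0                  # warp_group_thread_layout(0) for the 1:128 layout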

tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))

acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
acc = cute.make_fragment(acc_shape, self.acc_dtype)
if const_expr(self.fp8_slow_accum):
	acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
else:
	acc_slow = None

if const_expr(self.pingpong):
	if warp_group_idx == 0:
		# WG0 needs a start signal at the very beginning
		self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
		self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")

Let's take a look at pingpong_barrier_arrive:

def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str):
	assert stage in ["mma", "epi"]
	barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
	cute.arch.barrier_arrive(
		barrier_id=int(barrier) + warp_group_idx,
		number_of_threads=2 * self.num_threads_per_warp_group,
	)

We use a simple enumeration here to obtain the appropriate barrier id:

class NamedBarrierGemm(enum.IntEnum):
    Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()
    # For mainloop load warps to signal that the epilogue load warp can start.
    # This is to avoid loading C too early, interfering with loading A and B.
    EpilogueLoad = enum.auto()
    MmaWG0 = enum.auto()
    MmaWG1 = enum.auto()
    EpiWG0 = enum.auto()
    EpiWG1 = enum.auto()

We need the barrier_arrive for synchronisation between the two warp groups, i.e. we "signal the arrival of the executing thread", to quote the PTX docs.
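
Putting the enum values together (enum.auto() starts at 1), the concrete barrier ids used by the pingpong handshake work out as follows; the small helper below only restates the arithmetic from pingpong_barrier_arrive:

# Epilogue = 1, EpilogueLoad = 2, MmaWG0 = 3, MmaWG1 = 4, EpiWG0 = 5, EpiWG1 = 6
def pingpong_barrier_id(warp_group_idx: int, stage: str) -> int:
	base = 3 if stage == "mma" else 5      # MmaWG0 or EpiWG0
	return base + warp_group_idx

assert pingpong_barrier_id(0, "mma") == 3  # WG0 waits on MmaWG0 ...
assert pingpong_barrier_id(1, "mma") == 4  # ... and arrives on MmaWG1 to release WG1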

Note that the first warp group will operate on a different tile than the second. That's why the second warp group initially calls advance_to_next_work to get its own (next) tile, and additionally advances its ab_read_state by k_tile_cnt, i.e. the number of K tiles one full output tile consumes (each K tile is itself processed by several WGMMA instructions, since a single WGMMA handles at most 16 elements along the K dimension).

ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)

tile_scheduler = TileSchedulerCls()
if const_expr(self.pingpong):
	if warp_idx >= 4:
		# Advance 2nd Math WG to the next work tile for the startup
		tile_scheduler.advance_to_next_work()
		# Advance 2nd Math WG pipeline states to the end of 1st Math WG
		ab_read_state.advance_iters(k_tile_cnt)
work_tile = tile_scheduler.initial_work_tile_info()
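
With our launch configuration the startup offset is easy to check:

k_tile_cnt = 4096 // 64      # = 64 K tiles per output tile (problem K / tile K)
num_k_blocks = 64 // 16      # = 4 WGMMA k-blocks per K tile (a bf16 WGMMA has K = 16)
# So at startup the second warp group advances its consumer pipeline state by 64
# iterations, i.e. past all K tiles the first warp group consumes for its first tile.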

Similar to above we schedule our tiles with the tile scheduler.

while work_tile.is_valid_tile:

We'll then execute the mma.

ab_read_state, tiled_mma = self.mma(
	ab_pipeline,
	ab_read_state,
	tiled_mma,
	tCrA,
	tCrB,
	acc,
	acc_slow,
	k_tile_cnt,
	warp_group_idx,
)

Let's take a closer look at the underlying logic:

@cute.jit
def mma(
	...
) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
	# /////////////////////////////////////////////////////////////////////////////
	#  Prologue MMAs
	# /////////////////////////////////////////////////////////////////////////////
	k_pipe_mmas = 1
	ab_release_state = ab_read_state.clone()
	num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
	if const_expr(self.pingpong):
		self.pingpong_barrier_sync(warp_group_idx, stage="mma")
	peek_ab_full_status = cutlass.Boolean(True)
	if 0 < k_tile_cnt:
		peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
	tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
	num_k_blocks = cute.size(tCrA, mode=[2])
	# TODO: this is probably not correct if k_tile_cnt == 0
	for k_tile in cutlass.range(num_prologue_mma):
		# Wait for A/B buffer to be ready
		ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
		warpgroup.fence()
		for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
			k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
			cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
			tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
		warpgroup.commit_group()
		ab_read_state.advance()
		peek_ab_full_status = cutlass.Boolean(True)
		if k_tile + 1 < k_tile_cnt:
			peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
	if const_expr(self.fp8_slow_accum):
		warpgroup.wait_group(0)
		acc_slow.store(acc.load())

	# /////////////////////////////////////////////////////////////////////////////
	#  MAINLOOP
	# /////////////////////////////////////////////////////////////////////////////
	for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
		# Wait for TMA copies to complete
		ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
		# WGMMA
		warpgroup.fence()
		if const_expr(self.fp8_slow_accum):
			tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
		for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
			k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
			cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
			tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
		warpgroup.commit_group()
		# Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
		if const_expr(not self.fp8_slow_accum):
			warpgroup.wait_group(k_pipe_mmas)
		else:
			warpgroup.wait_group(0)
			acc_slow.store(acc_slow.load() + acc.load())
		ab_pipeline.consumer_release(ab_release_state)
		ab_read_state.advance()
		ab_release_state.advance()
		peek_ab_full_status = cutlass.Boolean(True)
		if k_tile + 1 < k_tile_cnt:
			peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
	if const_expr(self.pingpong):
		# Cue for next WG's MMA to start
		self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
	if const_expr(not self.fp8_slow_accum):
		# fp8_slow_accum would already called wait_group(0) inside the loop
		warpgroup.wait_group(0)
	for k_tile in cutlass.range(k_pipe_mmas, unroll=1):
		ab_pipeline.consumer_release(ab_release_state)
		ab_release_state.advance()
	if const_expr(self.fp8_slow_accum):
		acc.store(acc_slow.load())
	# If we don't return the tiled_mma, we get compiler error
	# "operand #0 does not dominate this use"
	return ab_read_state, tiled_mma

Note that the above is highly similar to the logic we normally employ. The difference for the pingpong is that we need to take care of the synchronisation between the two consumer warp groups. This is quite similar to the usual producer/consumer pattern we also employ in "normal" persistent kernels.

In persistent kernels we use wait and arrive on full and empty barriers to synchronise work between consumer and producer.

Here, we first sync on this warp group's own MMA barrier: MmaWG0 (offset 0) for the first warp group and MmaWG1 (offset 1) for the second.

self.pingpong_barrier_sync(warp_group_idx, stage="mma")

After performing the MMA we signal via

self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")

that we are done with the calculation. WG0 signals WG1's barrier (MmaWG1) and WG1 signals WG0's barrier (MmaWG0). This ensures that there are never two warp groups issuing MMAs at the same time: only one of the two warpgroups is busy calculating at any point in time.

Note that one might argue that, to be precise,

warpgroup.wait_group(0)

should be executed before the arrive, to be sure no WGMMAs are still in flight when arriving on the barrier of the other warpgroup.

After executing the mainloop we perform

if const_expr(self.pingpong):
	# Update starting mainloop pipeline state for the next tile
	ab_read_state.advance_iters(k_tile_cnt)

That reflects the nature of pingpong, where the output tiles alternate between the warp groups: WG0 -> WG1 -> WG0 -> ... Each warp group executes MMAs over k_tile_cnt K tiles for its own tile, and the other warp group consumes the next k_tile_cnt K tiles, so we advance the ab_read_state past them as above.
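
A small sketch of why this advance is needed (consumed_iters is a made-up helper; the indices are absolute iterations of the shared A/B pipeline):

k_tile_cnt = 64

def consumed_iters(wg_idx: int, num_tiles: int):
	"""Absolute A/B pipeline iterations consumed by one warp group."""
	iters = []
	for tile in range(wg_idx, num_tiles, 2):    # WG0: tiles 0, 2, ...; WG1: tiles 1, 3, ...
		start = tile * k_tile_cnt               # skip the K tiles of the other WG's tiles
		iters.extend(range(start, start + k_tile_cnt))
	return iters

# Together the two warp groups cover every pipeline iteration exactly once:
assert sorted(consumed_iters(0, 4) + consumed_iters(1, 4)) == list(range(4 * k_tile_cnt))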

After the mainloop comes the epilogue. Once a warp group is done with its MMA, it performs the epilogue (i.e. the transfer from RMEM to GMEM via SMEM using TMA) while the other warp group is busy computing its MMA.

We initially create the appropriate barrier.

# /////////////////////////////////////////////////////////////////////////////
#  EPILOGUE
# /////////////////////////////////////////////////////////////////////////////
if const_expr(self.pingpong):
	self.pingpong_barrier_sync(warp_group_idx, "epi")

What comes then is very similar to the usual epilogue for Hopper, so I will not go over it line by line.

However for PingPong it is important that we do the following.

if const_expr(self.pingpong):
	# Update starting load/store pipeline states for the next tile
	epi_read_state.advance_iters(c_tile_cnt)
	# With pingpong, 2 WGs write two different output tiles to the same smem,
	# so we have to make sure the smem content is done reading before signalling
	# the next WG's epilogue.
	if warp_idx == 0 or warp_idx == 4:
		cute.arch.cp_async_bulk_wait_group(0, read=True)
	self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")

That concludes the analysis of the PingPong kernel.

Conclusion

I hope this blogpost helped to demystify PingPong scheduling in the CuTeDSL. If you'd like to exchange ideas I would be happy to connect on LinkedIn. The kernel I analysed here can be found in the QuACK repo here.