Persistent GEMM in CuTeDSL on Hopper
In kernel design we can employ warp specialisation, which means that the producer and the consumer(s) run in different warps or warp groups. CuTeDSL offers a TileScheduler abstraction that makes it easy to write such a persistent, warp-specialised kernel. In this post I show you how to write a simple but performant GEMM kernel that uses warp specialisation and achieves >650 TFLOPs.
In my explanation I will focus on the parts that differ from the dense_gemm example on Hopper here, which I explained in detail here and here.
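Concretely, warp specialisation boils down to a branch on the warp-group index inside the kernel (a warp group is 4 warps, i.e. 128 threads): the first warp group becomes the producer, which only issues TMA loads from GMEM to SMEM, and the remaining warp groups become consumers, which only issue WGMMAs on the staged tiles. The TileScheduler then hands each persistent CTA a sequence of output tiles to work through.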
Call
We can run and benchmark the kernel as follows:
import torch
from functools import partial

import cuda.bindings.driver as cuda
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from triton.testing import do_bench  # assumed: Triton's do_bench (warmup/rep in ms); any equivalent timer works


def run():
    M, N, K = 8192, 8192, 8192
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
    c = torch.zeros(M, N, device="cuda", dtype=torch.float32)

    a_tensor = from_dlpack(a, assumed_align=16)  # (M, K) : (K, 1) - K-Major
    b_tensor = from_dlpack(b, assumed_align=16)  # (N, K) : (K, 1) - K-Major
    c_tensor = from_dlpack(c, assumed_align=16)  # (M, N) : (N, 1) - N-Major

    kernel = Kernel()

    # Get current CUDA stream from PyTorch as a raw CUstream
    torch_stream = torch.cuda.current_stream()
    current_stream = cuda.CUstream(torch_stream.cuda_stream)

    # Compile once, run, and verify against a PyTorch reference
    wgmma = cute.compile(kernel, a_tensor, b_tensor, c_tensor, stream=current_stream)
    wgmma(a_tensor, b_tensor, c_tensor, current_stream)
    torch.cuda.synchronize()
    c_ref = torch.matmul(a, b.t())
    torch.testing.assert_close(c.to(torch.bfloat16), c_ref, atol=1e-03, rtol=1e-03)

    # Benchmark
    wgmma_callable = partial(wgmma, a_tensor, b_tensor, c_tensor, current_stream)
    avg_time = do_bench(wgmma_callable, warmup=500, rep=10000)
    print(f"Time in ms = {avg_time}")
    print(f"TFLOPs = {(2 * M * N * K / 1e12) / (avg_time / 1000)}")
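To put the TFLOPs target into perspective, here is a quick back-of-the-envelope check (plain Python arithmetic, nothing CuTeDSL-specific):
# FLOPs of one 8192^3 GEMM and the runtime a given TFLOPs rate corresponds to
M = N = K = 8192
flops = 2 * M * N * K  # ~1.1e12 floating point operations per GEMM
target_tflops = 650
print(flops / (target_tflops * 1e12) * 1e3, "ms per GEMM at 650 TFLOPs")  # ~1.69 ms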
Kernel Setup
We tile along the M, N and K dimensions. We furthermore define the number of consumers as 2 and the atom layout accordingly. We accumulate in Float32 while the two inputs being multiplied are BFloat16. The op is the copy operation for the TMA transfers, which we will use to move A and B efficiently from GMEM -> SMEM.
def __init__(self):
    # Config
    ######################################################
    self.bM = 128  # CTA tile size along M
    self.bN = 256  # CTA tile size along N
    self.bK = 64   # CTA tile size along K
    self.num_consumer = 2  # consumer (MMA) warp groups
    self.num_producer = 1  # producer (TMA) warp group
    self.atom_layout_mnk = (self.num_consumer, 1, 1)
    self.acc_dtype = cutlass.Float32
    self.op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()  # TMA copy op for GMEM -> SMEM
    self.stage = 4  # number of pipeline stages (SMEM buffers)
    self.shared_storage = None
    ######################################################
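As a sanity check that this configuration fits on the chip, we can estimate the shared-memory footprint of the pipeline (back-of-the-envelope, ignoring the few bytes used by the mbarriers):
# SMEM budget: each stage buffers one bf16 A tile and one bf16 B tile
bM, bN, bK, stages = 128, 256, 64, 4
bytes_per_stage = (bM * bK + bN * bK) * 2  # 48 KiB
total_kib = stages * bytes_per_stage / 1024
print(total_kib, "KiB of ~228 KiB SMEM per H100 SM")  # 192.0 KiB -> fits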
Before we call the kernel we set up our MMA unit. We consider both inputs to be K-Major, i.e. with a stride of 1 along the K dimension.
tiled_mma = sm90_utils.make_trivial_tiled_mma(
    mA.element_type,
    mB.element_type,
    cute.nvgpu.warpgroup.OperandMajorMode.K,
    cute.nvgpu.warpgroup.OperandMajorMode.K,
    self.acc_dtype,
    self.atom_layout_mnk,
    tiler_mn=(64, self.bN),
)
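With atom_layout_mnk = (2, 1, 1) the two consumer warp groups are stacked along M, and each one drives a 64 x bN WGMMA atom (hence tiler_mn=(64, self.bN)), so together they cover the full 128 x 256 CTA tile for every K step.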
We need to set up the TMA atoms for A and B as well:
tma_atom_a, tma_tensor_a = cute.nvgpu.cpasync.make_tiled_tma_atom(
    self.op,
    mA,
    smem_layout,
    (self.bM, self.bK),
    num_multicast=1,
)
tma_atom_b, tma_tensor_b = cute.nvgpu.cpasync.make_tiled_tma_atom(
    self.op,
    mB,
    smem_layout,
    (self.bN, self.bK),
    num_multicast=1,
)
We then set up the TileScheduler and the grid as follows:
tile_sched_params, grid = self._compute_grid(M, N, self.bM, self.bN)
Note that we could use a max_active_clusters of up to 132 (the number of SMs on an H100 SXM), but in my experiments the performance was better or equal with 128. This trick is borrowed from here.
@staticmethod
def _compute_grid(M: int, N: int, bM: int, bN: int):
    num_ctas_mnl = (M // bM, N // bN, 1)  # Number of tiles in each dimension
    cluster_shape_mnl = (1, 1, 1)  # No cluster
    max_active_clusters = cutlass.const_expr(128)  # Hardware allows up to 132; 128 performed best here
    tile_sched_params = cutlass.utils.PersistentTileSchedulerParams(
        num_ctas_mnl, cluster_shape_mnl
    )
    grid = cutlass.utils.StaticPersistentTileScheduler.get_grid_shape(
        tile_sched_params, max_active_clusters
    )
    return tile_sched_params, grid
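For our problem size the tile/grid arithmetic works out as follows:
M = N = 8192
bM, bN = 128, 256
num_tiles = (M // bM) * (N // bN)  # 64 * 32 = 2048 output tiles
persistent_ctas = 128              # grid size chosen above
print(num_tiles // persistent_ctas, "tiles per persistent CTA")  # 16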
We then call the kernel:
# Launch the kernel synchronously
self.kernel(
    tma_atom_a,
    tma_tensor_a,
    tma_atom_b,
    tma_tensor_b,
    mC,
    tiled_mma,
    a_smem_layout_staged,
    b_smem_layout_staged,
    tile_sched_params,
).launch(
    grid=grid,
    block=[num_threads, 1, 1],
    cluster=(1, 1, 1),
    smem=self.shared_storage.size_in_bytes(),
    stream=stream,
)
Kernel
Inside the kernel we define the pipeline as follows.
# mbar arrays
mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr()

# Threads/warps participating in the pipelines
mainloop_pipeline_producer_group = pipeline.CooperativeGroup(
    pipeline.Agent.Thread
)
num_consumers = 4 * self.num_consumer  # 4 warps per consumer warp group
mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
    pipeline.Agent.Thread, size=num_consumers, alignment=num_consumers
)

# States
mainloop_consumer_state = pipeline.make_pipeline_state(
    pipeline.PipelineUserType.Consumer, self.stage
)
mainloop_producer_state = pipeline.make_pipeline_state(
    pipeline.PipelineUserType.Producer, self.stage
)
Here we need 4 warps per consumer because each consumer operates as a full warp group, and a warp group consists of 4 warps.
We then initialise the pipeline:
cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
mainloop_pipeline = pipeline.PipelineTmaAsync.create(
    barrier_storage=mainloop_pipeline_array_ptr,
    num_stages=self.stage,
    producer_group=mainloop_pipeline_producer_group,
    consumer_group=mainloop_pipeline_consumer_group,
    tx_count=tma_copy_bytes,
    cta_layout_vmnk=cta_layout_vmnk,
)
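The tx_count is the number of bytes the TMA barrier expects per stage; for our tile sizes tma_copy_bytes should amount to (128 * 64 + 256 * 64) * 2 bytes = 48 KiB, i.e. one A tile plus one B tile.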
Note that here we choose cta_layout_mnk = cute.make_layout((1, 1, 1)) because we don't use clusters (yet).
We initialise our tensors as follows:
gA_mk = cute.local_tile(mA_mk, (self.bM, self.bK), (None, None))
gB_nk = cute.local_tile(mB_nk, (self.bN, self.bK), (None, None))
gC_mn = cute.local_tile(mC_mn, (self.bM, self.bN), (None, None))
Note that we don't employ a projection here; the tile coordinate (None, None) keeps the number of tiles along the M, N and K directions as separate modes, because we will need them to index into the correct tile later on.
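Concretely, for our problem sizes gA_mk carries a (128, 64) tile in its first two modes, followed by 8192/128 = 64 tiles along M and 8192/64 = 128 tiles along K; gB_nk analogously carries 32 tiles along N and 128 tiles along K.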
We'll initialise our TMA tensors as follows:
a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
a_cta_crd = 1
sa_for_tma_partition = cute.group_modes(sa, 0, 2)
gA_for_tma_partition = cute.group_modes(gA_mk, 0, 2)
tAsA, tAgA_mk = cute.nvgpu.cpasync.tma_partition(
    tma_atom_a,
    a_cta_crd,
    a_cta_layout,
    sa_for_tma_partition,
    gA_for_tma_partition,
)
The partitioning for B is analogous.
After partitioning with tiled_mma we are ready to implement our producer/consumer pattern.
tCgC = thr_mma.partition_C(gC_mn)
tCsA = thr_mma.partition_A(sa)
tCsB = thr_mma.partition_B(sb)
tCrA = tiled_mma.make_fragment_A(tCsA)
tCrB = tiled_mma.make_fragment_B(tCsB)
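Note that for WGMMA both operands are read directly from shared memory, so tCrA and tCrB are not register copies of the data: make_fragment_A / make_fragment_B build SMEM descriptors over the staged tiles, and the trailing stage mode is what we index with the pipeline state in the mainloop below.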
Producer
Below you see the code for the producer. The first warp group acts as the producer and loads the tiles from GMEM -> SMEM via TMA. Note that we use tAgA_k_index = (None, tile_m, tile_k), and similarly for B, to index into the appropriate tile in GMEM; the TileScheduler provides these tile coordinates. Since a TMA copy only has to be issued by a single thread, it suffices to let the first warp of the producer warp group (warp_idx_in_warpgroup == 0) drive all loads. We also call cute.arch.warpgroup_reg_dealloc(24), following the approach taken here, to account for the fact that the producer needs very few registers during the loading phase.
# Producer
if warp_group_idx < 1:
    cute.arch.warpgroup_reg_dealloc(24)
    if warp_idx_in_warpgroup == 0:
        tile_sched = cutlass.utils.StaticPersistentTileScheduler.create(
            tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
        )
        work_tile = tile_sched.initial_work_tile_info()
        while work_tile.is_valid_tile:
            tile_m, tile_n, _ = work_tile.tile_idx
            for tile_k in cutlass.range(k_tile_cnt):
                mainloop_pipeline.producer_acquire(mainloop_producer_state)

                tAgA_k_index = (None, tile_m, tile_k)
                tAsA_stage_index = (None, mainloop_producer_state.index)
                tBgB_k_index = (None, tile_n, tile_k)
                tBsB_stage_index = (None, mainloop_producer_state.index)

                tAgA_k = tAgA_mk[tAgA_k_index]
                tAsA_pipe = tAsA[tAsA_stage_index]
                tBgB_k = tBgB_nk[tBgB_k_index]
                tBsB_pipe = tBsB[tBsB_stage_index]

                cute.copy(
                    tma_atom_a,
                    tAgA_k,
                    tAsA_pipe,
                    tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                        mainloop_producer_state
                    ),
                    mcast_mask=0,
                )
                cute.copy(
                    tma_atom_b,
                    tBgB_k,
                    tBsB_pipe,
                    tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                        mainloop_producer_state
                    ),
                    mcast_mask=0,
                )
                mainloop_pipeline.producer_commit(mainloop_producer_state)
                mainloop_producer_state.advance()
            tile_sched.advance_to_next_work()
            work_tile = tile_sched.get_current_work()
Consumer
The remaining warp groups (in our case 2) act as consumers, i.e. they perform the actual GEMM computation. cute.arch.warpgroup_reg_alloc(240) allocates more registers to the consumer threads. tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) clears the accumulator before we start a new output tile. We then accumulate over the whole K loop for one tile (which is why the accumulator fragment has no tile modes) and finally copy the finished tile to its destination in GMEM.
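The values 24 and 240 are not arbitrary; they are presumably chosen so that the whole CTA stays within the 64K 32-bit registers of an SM's register file (setmaxnreg operates in multiples of 8):
# Register budget: 1 producer warp group (128 threads) + 2 consumer warp groups (256 threads)
producer_regs = 128 * 24       # 3072
consumer_regs = 2 * 128 * 240  # 61440
print(producer_regs + consumer_regs, "of 65536 registers per SM")  # 64512 -> fits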
# Consumer
else:
    cute.arch.warpgroup_reg_alloc(240)
    tile_sched = cutlass.utils.StaticPersistentTileScheduler.create(
        tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
    )
    work_tile = tile_sched.initial_work_tile_info()
    num_k_blocks = cute.size(tCrA, mode=[2])
    accumulators = cute.make_fragment(
        tCgC[None, None, None, 0, 0].shape, self.acc_dtype
    )
    while work_tile.is_valid_tile:
        tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
        tile_m, tile_n, _ = work_tile.tile_idx
        for k_tile in cutlass.range(k_tile_cnt):
            # /////////////////////////////////////////////////////////////////////////////
            # Wait for TMA copies to complete
            # /////////////////////////////////////////////////////////////////////////////
            mainloop_pipeline.consumer_wait(mainloop_consumer_state)
            # /////////////////////////////////////////////////////////////////////////////
            # WGMMA
            # /////////////////////////////////////////////////////////////////////////////
            cute.nvgpu.warpgroup.fence()
            for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
                k_block_coord = (
                    None,
                    None,
                    k_block_idx,
                    mainloop_consumer_state.index,
                )
                tCrA_1phase = tCrA[k_block_coord]
                tCrB_1phase = tCrB[k_block_coord]
                cute.gemm(
                    tiled_mma,
                    accumulators,
                    tCrA_1phase,
                    tCrB_1phase,
                    accumulators,
                )
                tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
            cute.nvgpu.warpgroup.commit_group()
            # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
            cute.nvgpu.warpgroup.wait_group(0)
            mainloop_pipeline.consumer_release(mainloop_consumer_state)
            mainloop_consumer_state.advance()
        store_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), self.acc_dtype
        )
        cute.copy(
            store_copy, accumulators, tCgC[None, None, None, tile_m, tile_n]
        )
        tile_sched.advance_to_next_work()
        work_tile = tile_sched.get_current_work()
This relatively simple code already achieves 643 TFLOPs.
Use clusters
There is a simple modification that takes us to 658 TFLOPs: clusters. Clusters can be leveraged by the TMA unit to perform multicast loads. Multicast means that, within one cluster, we load for example two different tiles of A (one per CTA) but only a single tile of B; that B tile is multicast, i.e. delivered to both CTAs of the cluster, and combined with both A tiles.
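For example, with cluster_shape_mnk = (2, 1, 1) the two CTAs of a cluster compute neighbouring output tiles along M: they need different A tiles but the exact same B tile, so the B tile is loaded from GMEM once and multicast into both CTAs' shared memories, roughly halving the GMEM traffic for B.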
Within __init__ we define an additional copy operation:
self.op_cluster = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()
which will be used for the tiles that need to be multicast, together with the cluster shape:
self.cluster_shape_mnk = (2, 1, 1)
Within the __call__ function we add:
# Cluster
self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
self.is_a_mcast = self.num_mcast_ctas_a > 1
self.is_b_mcast = self.num_mcast_ctas_b > 1
and modify the TMA setup as follows:
tma_atom_a, tma_tensor_a = cute.nvgpu.cpasync.make_tiled_tma_atom(
    self.op if self.cluster_shape_mnk[1] == 1 else self.op_cluster,
    mA,
    smem_layout,
    (self.bM, self.bK),
    num_multicast=self.cluster_shape_mnk[1],
)
and analogously for B.
We then pass self.cta_layout_mnk as an additional argument to the kernel and to the function that computes the grid and tile scheduler parameters, and we launch the kernel with cluster=self.cluster_shape_mnk.
Within the kernel we need to compute the mcast masks that will be passed to the TMA copy operations.
cta_rank_in_cluster = cute.arch.make_warp_uniform(
    cute.arch.block_idx_in_cluster()
)
cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
# ///////////////////////////////////////////////////////////////////////////////
# Get mcast mask
# ///////////////////////////////////////////////////////////////////////////////
a_mcast_mask = cute.make_layout_image_mask(
    cta_layout_mnk, cluster_coord_mnk, mode=1
)
b_mcast_mask = cute.make_layout_image_mask(
    cta_layout_mnk, cluster_coord_mnk, mode=0
)
a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
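For our (2, 1, 1) cluster, b_mcast_mask ends up with a bit set for both CTA ranks in the cluster (the B tile is delivered to both), while a_mcast_mask would only cover the issuing CTA and is replaced by 0 anyway since is_a_mcast is False.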
We need to adjust the consumer arrive count of the mainloop pipeline accordingly:
# Each warp will contribute to the arrive count with the mcast size
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
num_warps = 8
consumer_arrive_cnt = mcast_size * num_warps
mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
    pipeline.Agent.Thread, consumer_arrive_cnt, alignment=consumer_arrive_cnt
)
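For our (2, 1, 1) cluster this gives mcast_size = 2 + 1 - 1 = 2 and, with 8 consumer warps per CTA, consumer_arrive_cnt = 16: a multicast stage has to be released by the consumer warps of both CTAs it was delivered to before the producer may reuse it.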
We also supply the appropriate cluster coordinate when setting up the TMA partition, a_cta_crd = cluster_coord_mnk[1] for A and similarly for B.
In the Producer loop we provide the appropriate masks to the copy operations:
cute.copy(
    tma_atom_a,
    tAgA_k,
    tAsA_pipe,
    tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
        mainloop_producer_state
    ),
    mcast_mask=a_mcast_mask,
)
Apart from replacing thread-level with cluster-level synchronisation where necessary, that's it.
Conclusion
I hope this blog post made warp-specialised kernel design more approachable. Note that there are further optimisations we could employ, but I chose to stop at this point to keep the post concise and focused on the topic at hand. You can contact me on LinkedIn to discuss GPU programming or MLSys in general. The full code can be found in my GitHub repo.