simons blog

Persistent GEMM in CuTeDSL on Hopper

In kernel design we can employ warp specialisation: the producer (which loads data) and the consumer(s) (which compute) run in different warp groups. Combined with CuTeDSL's TileScheduler abstraction, which distributes the output tiles over a fixed number of persistently running CTAs, this gives us a persistent GEMM. In this post I show you how to write a simple but performant GEMM kernel that uses warp specialisation and achieves >650 TFLOPs.

I will focus on the parts that differ from the Hopper dense_gemm example here, which I explained in detail here and here.

Call

We can run and benchmark the Kernel as follows:

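from functools import partial

import cuda.bindings.driver as cuda
import cutlass.cute as cute
import torch
from cutlass.cute.runtime import from_dlpack
from triton.testing import do_bench  # benchmarking helper; warmup and rep are in ms


def run():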
    M, N, K = 8192, 8192, 8192
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
    c = torch.zeros(M, N, device="cuda", dtype=torch.float32)

    a_tensor = from_dlpack(a, assumed_align=16)  # (M, K) : (K, 1) - K-Major
    b_tensor = from_dlpack(b, assumed_align=16)  # (N, K) : (K, 1) - K-Major
    c_tensor = from_dlpack(c, assumed_align=16)  # (M, N) : (N, 1) - N-Major

    kernel = Kernel()

    # Get current CUDA stream from PyTorch
    torch_stream = torch.cuda.current_stream()
    # Get the raw stream pointer as a CUstream
    current_stream = cuda.CUstream(torch_stream.cuda_stream)

    wgmma = cute.compile(kernel, a_tensor, b_tensor, c_tensor, stream=current_stream)

    wgmma(a_tensor, b_tensor, c_tensor, current_stream)
    torch.cuda.synchronize()
    c_ref = torch.matmul(a, b.t())
    torch.testing.assert_close(c.to(torch.bfloat16), c_ref, atol=1e-03, rtol=1e-03)

    wgmma_callable = partial(wgmma, a_tensor, b_tensor, c_tensor, current_stream)

    avg_time = do_bench(wgmma_callable, warmup=500, rep=10000)
    print(f"Time in ms = {avg_time}")
    print(f"TFLOPs = {(2 * M * N * K / 1e12) / (avg_time / 1000)}")
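The TFLOPs line in run() counts 2·M·N·K floating-point operations, one multiply and one add per multiply-accumulate. A small sketch of that arithmetic (the helper names are hypothetical) shows how a measured time maps to throughput and back:

```python
def gemm_tflops(M: int, N: int, K: int, time_ms: float) -> float:
    """Throughput of C = A @ B^T in TFLOPs, counting 2 FLOPs per multiply-accumulate."""
    flops = 2 * M * N * K  # M * N outputs, each a dot product of length K
    return flops / 1e12 / (time_ms / 1e3)


def time_ms_for(M: int, N: int, K: int, tflops: float) -> float:
    """Inverse of gemm_tflops: kernel time (ms) needed to reach a given throughput."""
    return 2 * M * N * K / (tflops * 1e12) * 1e3


M = N = K = 8192
print(f"{2 * M * N * K / 1e12:.3f} TFLOP per GEMM")       # 1.100
print(f"{time_ms_for(M, N, K, 643):.3f} ms at 643 TFLOPs")  # 1.710
```

So at this problem size a difference of 15 TFLOPs corresponds to only a few hundredths of a millisecond per GEMM, which is why a long benchmark (rep=10000) is useful.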

Kernel Setup

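We tile along the M, N and K dimensions with bM = 128, bN = 256 and bK = 64. We use one producer and two consumer warp groups and set the MMA atom layout to (2, 1, 1) accordingly, so the two consumers split each tile along M. We multiply BFloat16 inputs and accumulate in Float32.

self.op selects the TMA copy (CopyBulkTensorTileG2SOp) that we use to transfer tiles of A and B efficiently from GMEM to SMEM. self.stage = 4 is the number of pipeline stages, i.e. how many (A, B) tile pairs can be buffered in SMEM at the same time.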

    def __init__(self):
        # Config
        ######################################################
        self.bM = 128
        self.bN = 256
        self.bK = 64

        self.num_consumer = 2
        self.num_producer = 1

        self.atom_layout_mnk = (self.num_consumer, 1, 1)
        self.acc_dtype = cutlass.Float32

        self.op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()

        self.stage = 4

        self.shared_storage = None
        ######################################################

Before launching the kernel we set up the tiled MMA. Both inputs are K-major, i.e. they have a stride of 1 along the K dimension.

        tiled_mma = sm90_utils.make_trivial_tiled_mma(
            mA.element_type,
            mB.element_type,
            cute.nvgpu.warpgroup.OperandMajorMode.K,
            cute.nvgpu.warpgroup.OperandMajorMode.K,
            self.acc_dtype,
            self.atom_layout_mnk,
            tiler_mn=(64, self.bN),
        )
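To make the work split concrete: with atom_layout_mnk = (2, 1, 1) and tiler_mn=(64, self.bN), each consumer warp group owns a 64 x 256 slice of the 128 x 256 CTA tile. The sketch below does the arithmetic under two assumptions about Hopper: a BFloat16 WGMMA instruction covers 16 elements of K, and a warp group has 128 threads. The function name is hypothetical.

```python
THREADS_PER_WARPGROUP = 128  # 4 warps x 32 threads
WGMMA_K_BF16 = 16            # assumed K extent of one 16-bit WGMMA instruction


def mma_work_split(bM: int, bN: int, bK: int, num_consumer: int):
    """Per-consumer tile, fp32 accumulator registers per thread, and WGMMAs per K tile."""
    consumer_tile = (bM // num_consumer, bN)  # the (2, 1, 1) atom layout stacks consumers along M
    # one fp32 accumulator occupies one 32-bit register
    acc_regs = consumer_tile[0] * consumer_tile[1] // THREADS_PER_WARPGROUP
    k_blocks = bK // WGMMA_K_BF16             # corresponds to num_k_blocks in the consumer loop
    return consumer_tile, acc_regs, k_blocks


print(mma_work_split(128, 256, 64, 2))  # ((64, 256), 128, 4)
```

The 128 accumulator registers per consumer thread are the main reason the consumers later request extra registers.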

We also need to set up the TMA atoms for A and B:

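        # A TMA atom describes the copy of a single stage, so take one stage of
        # each staged SMEM layout. The A and B tiles differ in size (bM vs bN),
        # so each operand needs its own layout.
        a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
        b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0))
        tma_atom_a, tma_tensor_a = cute.nvgpu.cpasync.make_tiled_tma_atom(
            self.op,
            mA,
            a_smem_layout,
            (self.bM, self.bK),
            num_multicast=1,
        )
        tma_atom_b, tma_tensor_b = cute.nvgpu.cpasync.make_tiled_tma_atom(
            self.op,
            mB,
            b_smem_layout,
            (self.bN, self.bK),
            num_multicast=1,
        )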

We then set up the tile scheduler and the grid:

tile_sched_params, grid = self._compute_grid(M, N, self.bM, self.bN)

Our GPU has 132 SMs, so we could set max_active_clusters as high as 132, but in my experiments the performance with 128 was better or equal. This trick is borrowed from here.

    @staticmethod
    def _compute_grid(M: int, N: int, bM: int, bN: int):
        num_ctas_mnl = (M // bM, N // bN, 1)  # Number of tiles in each dimension
        cluster_shape_mnl = (1, 1, 1)  # No Cluster
        max_active_clusters = cutlass.const_expr(128)  # at most 132 (one CTA per SM); 128 worked best

        tile_sched_params = cutlass.utils.PersistentTileSchedulerParams(
            num_ctas_mnl, cluster_shape_mnl
        )
        grid = cutlass.utils.StaticPersistentTileScheduler.get_grid_shape(
            tile_sched_params, max_active_clusters
        )
        return tile_sched_params, grid
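The sketch below is a simplified model of static persistent scheduling, not the CUTLASS implementation: CTA b is assumed to process the linear tile indices b, b + G, b + 2G, ... for a grid of G CTAs, in M-fastest order (the real scheduler's tile order may differ). It also counts the rounds for our problem, which offers one plausible explanation of why 128 CTAs are no slower than 132.

```python
import math


def static_persistent_schedule(tiles_m: int, tiles_n: int, num_ctas: int):
    """Toy round-robin scheduler: CTA b gets linear tiles b, b + num_ctas, ..."""
    schedule = {}
    for cta in range(num_ctas):
        tiles = []
        linear = cta
        while linear < tiles_m * tiles_n:
            tiles.append((linear % tiles_m, linear // tiles_m))  # (tile_m, tile_n), M fastest
            linear += num_ctas
        schedule[cta] = tiles
    return schedule


def rounds_and_utilisation(num_tiles: int, num_ctas: int):
    rounds = math.ceil(num_tiles / num_ctas)  # tiles processed by the busiest CTA
    return rounds, num_tiles / (rounds * num_ctas)


tiles_m, tiles_n = 8192 // 128, 8192 // 256  # 64 x 32 = 2048 output tiles
for ctas in (128, 132):
    print(ctas, rounds_and_utilisation(tiles_m * tiles_n, ctas))
# Both grid sizes need 16 rounds; with 132 CTAs the last round is only partly filled.
```

Because the producer and the consumers each construct the same scheduler from the same block index, they walk through the same tile sequence without having to communicate about it.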

Finally, we launch the kernel:

        # Launch the kernel on `stream` (asynchronous with respect to the host)
        self.kernel(
            tma_atom_a,
            tma_tensor_a,
            tma_atom_b,
            tma_tensor_b,
            mC,
            tiled_mma,
            a_smem_layout_staged,
            b_smem_layout_staged,
            tile_sched_params,
        ).launch(
            grid=grid,
            block=[num_threads, 1, 1],  # (1 producer + 2 consumer) warp groups x 128 = 384 threads
            cluster=(1, 1, 1),
            smem=self.shared_storage.size_in_bytes(),
            stream=stream,
        )

Kernel

Inside the kernel we define the pipeline as follows.

        # mbar arrays
        mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr()

        # Threads/warps participating in the pipelines
        mainloop_pipeline_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread
        )
        # One arrival per consumer warp: 4 warps per consumer warp group
        num_consumer_warps = 4 * self.num_consumer
        mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, size=num_consumer_warps, alignment=num_consumer_warps
        )

        # States
        mainloop_consumer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.stage
        )
        mainloop_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.stage
        )

The consumer group size is 4 x num_consumer because each consumer is a warp group, which consists of 4 warps, and each consumer warp arrives once when it releases a stage.
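The pipeline states track which SMEM stage to use next and which mbarrier phase to wait on. Below is a toy model (the class name is hypothetical, not the CuTeDSL type) of how advance() cycles through the self.stage buffers. It assumes, as I understand make_pipeline_state, that the producer starts in phase 1 so that its first acquires on the initially empty buffers do not block, while the consumer starts in phase 0.

```python
from dataclasses import dataclass


@dataclass
class ToyPipelineState:
    """Toy model of a pipeline state: current SMEM stage and mbarrier phase bit."""
    stages: int
    index: int = 0
    phase: int = 0

    def advance(self) -> None:
        self.index += 1
        if self.index == self.stages:  # wrapped around the ring of buffers
            self.index = 0
            self.phase ^= 1            # the barrier parity flips on every lap


producer = ToyPipelineState(stages=4, phase=1)  # assumed initial producer phase
consumer = ToyPipelineState(stages=4, phase=0)
for _ in range(6):  # e.g. six K tiles
    producer.advance()
    consumer.advance()
print(producer)  # ToyPipelineState(stages=4, index=2, phase=0)
print(consumer)  # ToyPipelineState(stages=4, index=2, phase=1)
```

Since the producer and the consumers advance once per K tile, they visit the stages in the same order; only the phase they wait on differs.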

We then initialise the pipeline:

        cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape))
        mainloop_pipeline = pipeline.PipelineTmaAsync.create(
            barrier_storage=mainloop_pipeline_array_ptr,
            num_stages=self.stage,
            producer_group=mainloop_pipeline_producer_group,
            consumer_group=mainloop_pipeline_consumer_group,
            tx_count=tma_copy_bytes,
            cta_layout_vmnk=cta_layout_vmnk,
        )

Here cta_layout_mnk = cute.make_layout((1, 1, 1)) because we don't use clusters yet.
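The tx_count passed to PipelineTmaAsync.create tells each stage's full barrier how many bytes the TMA will deliver. Assuming tma_copy_bytes is the size of one A tile plus one B tile (which is what one stage receives), the arithmetic below (hypothetical names) also shows how much SMEM the four stages take. The 227 KiB figure is the opt-in maximum of dynamic shared memory per block on H100.

```python
BF16_BYTES = 2
MAX_SMEM_PER_BLOCK = 227 * 1024  # H100 opt-in limit: 232,448 bytes


def stage_bytes(bM: int, bN: int, bK: int, elem_bytes: int = BF16_BYTES) -> int:
    """Bytes the TMA writes into one pipeline stage: an A tile plus a B tile."""
    return (bM * bK + bN * bK) * elem_bytes


tx_count = stage_bytes(128, 256, 64)
for stages in (4, 5):
    total = stages * tx_count
    print(stages, total, total <= MAX_SMEM_PER_BLOCK)
# 4 stages: 196608 bytes, fits; 5 stages: 245760 bytes, does not fit
```

With these tile sizes, four stages is therefore the deepest pipeline that fits.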

We tile the global tensors as follows:

        gA_mk = cute.local_tile(mA_mk, (self.bM, self.bK), (None, None))
        gB_nk = cute.local_tile(mB_nk, (self.bN, self.bK), (None, None))
        gC_mn = cute.local_tile(mC_mn, (self.bM, self.bN), (None, None))

Note that we pass (None, None) as the tile coordinate and don't use a projection. This keeps the tile modes along M, N and K (e.g. gA_mk has shape (bM, bK, M / bM, K / bK)), which we need later to index into the tile that the scheduler assigns.
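A NumPy analogy of what cute.local_tile(mA_mk, (bM, bK), (None, None)) produces. This is only an analogy, since CuTe tensors are layouts rather than NumPy arrays: each dimension is split into an intra-tile mode and a tile-index mode, and both tile-index modes are kept so that any tile can be selected later.

```python
import numpy as np

M, K, bM, bK = 256, 128, 128, 64
A = np.arange(M * K).reshape(M, K)  # row-major (M, K), i.e. K-major like mA

# Split M into (M // bM, bM) and K into (K // bK, bK), then order the modes as
# (bM, bK, M // bM, K // bK), mirroring gA_mk.
gA = A.reshape(M // bM, bM, K // bK, bK).transpose(1, 3, 0, 2)
print(gA.shape)  # (128, 64, 2, 2)

# Indexing the tile modes selects one (bM, bK) tile, like gA_mk[None, None, tile_m, tile_k].
tile_m, tile_k = 1, 0
tile = gA[:, :, tile_m, tile_k]
assert np.array_equal(tile, A[tile_m * bM:(tile_m + 1) * bM, tile_k * bK:(tile_k + 1) * bK])
```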

We then partition SMEM and GMEM for the TMA copies:

        a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
        a_cta_crd = 0  # coordinate within the size-1 multicast layout (no clusters yet)
        sa_for_tma_partition = cute.group_modes(sa, 0, 2)
        gA_for_tma_partition = cute.group_modes(gA_mk, 0, 2)
        tAsA, tAgA_mk = cute.nvgpu.cpasync.tma_partition(
            tma_atom_a,
            a_cta_crd,
            a_cta_layout,
            sa_for_tma_partition,
            gA_for_tma_partition,
        )

B is partitioned analogously.

Finally, we partition C and the staged SMEM buffers with the MMA and create the MMA operand fragments. With that we are ready to define our producer/consumer pattern.

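        # MMA partitions of C (tile modes kept) and of the staged SMEM buffers
        tCgC = thr_mma.partition_C(gC_mn)
        tCsA = thr_mma.partition_A(sa)
        tCsB = thr_mma.partition_B(sb)
        # WGMMA operand fragments; for SMEM operands these are descriptors, not registers
        tCrA = tiled_mma.make_fragment_A(tCsA)
        tCrB = tiled_mma.make_fragment_B(tCsB)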

Producer

Below is the code for the producer. The first warp group is the producer, and one of its warps loads the tiles from GMEM to SMEM via TMA. We use tAgA_k_index = (None, tile_m, tile_k) for A and (None, tile_n, tile_k) for B to index into the appropriate tile in GMEM, where tile_m and tile_n are provided by the tile scheduler. Following the approach taken here, we call cute.arch.warpgroup_reg_dealloc(24): the producer only issues TMA copies and needs very few registers, so it gives the rest back for the consumers.

        # Producer
        if warp_group_idx < 1:
            cute.arch.warpgroup_reg_dealloc(24)
            if warp_idx_in_warpgroup == 0:
                tile_sched = cutlass.utils.StaticPersistentTileScheduler.create(
                    tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
                )

                work_tile = tile_sched.initial_work_tile_info()

                while work_tile.is_valid_tile:
                    tile_m, tile_n, _ = work_tile.tile_idx

                    for tile_k in cutlass.range(k_tile_cnt):
                        mainloop_pipeline.producer_acquire(mainloop_producer_state)

                        tAgA_k_index = (None, tile_m, tile_k)
                        tAsA_stage_index = (None, mainloop_producer_state.index)
                        tBgB_k_index = (None, tile_n, tile_k)
                        tBsB_stage_index = (None, mainloop_producer_state.index)

                        tAgA_k = tAgA_mk[tAgA_k_index]
                        tAsA_pipe = tAsA[tAsA_stage_index]
                        tBgB_k = tBgB_nk[tBgB_k_index]
                        tBsB_pipe = tBsB[tBsB_stage_index]

                        cute.copy(
                            tma_atom_a,
                            tAgA_k,
                            tAsA_pipe,
                            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                                mainloop_producer_state
                            ),
                            mcast_mask=0,
                        )
                        cute.copy(
                            tma_atom_b,
                            tBgB_k,
                            tBsB_pipe,
                            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                                mainloop_producer_state
                            ),
                            mcast_mask=0,
                        )

                        mainloop_pipeline.producer_commit(mainloop_producer_state)
                        mainloop_producer_state.advance()

                    tile_sched.advance_to_next_work()
                    work_tile = tile_sched.get_current_work()
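The register counts passed to warpgroup_reg_dealloc and warpgroup_reg_alloc have to fit into the SM's register file. A quick check, assuming 65,536 32-bit registers per SM on Hopper and the setmaxnreg rule that counts lie in [24, 256] and are multiples of 8 (the helper names are hypothetical):

```python
REGS_PER_SM = 65536
THREADS_PER_WARPGROUP = 128


def valid_setmaxnreg(n: int) -> bool:
    return 24 <= n <= 256 and n % 8 == 0


def registers_used(producer_regs: int, consumer_regs: int,
                   num_producer: int = 1, num_consumer: int = 2) -> int:
    """Total registers claimed by all warp groups of one CTA."""
    return THREADS_PER_WARPGROUP * (num_producer * producer_regs + num_consumer * consumer_regs)


for consumer_regs in (240, 248):
    used = registers_used(24, consumer_regs)
    print(consumer_regs, used, used <= REGS_PER_SM)
# 240: 64512 registers, fits; 248: 66560 registers, does not fit
```

In this model, 240 is the largest multiple of 8 the consumers can request while the producer keeps 24, whereas a uniform split over all 384 threads would allow only about 170 registers per thread.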

Consumer

The remaining warp groups (two in our case) are the consumers, i.e. they perform the GEMM. cute.arch.warpgroup_reg_alloc(240) raises their register limit. Setting tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) at the start of each tile makes the first WGMMA overwrite the accumulators instead of adding to them, so they don't have to be zeroed explicitly; after the first k-block we switch ACCUMULATE back on. The accumulator fragment has the shape of a single output tile (tCgC[None, None, None, 0, 0] drops the tile modes) and is reused for every tile the scheduler assigns. Once all K tiles are accumulated, we copy it to the corresponding tile of C in GMEM.

        # Consumer
        else:
            cute.arch.warpgroup_reg_alloc(240)

            tile_sched = cutlass.utils.StaticPersistentTileScheduler.create(
                tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
            )

            work_tile = tile_sched.initial_work_tile_info()
            num_k_blocks = cute.size(tCrA, mode=[2])
            accumulators = cute.make_fragment(
                tCgC[None, None, None, 0, 0].shape, self.acc_dtype
            )
            while work_tile.is_valid_tile:
                tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)

                tile_m, tile_n, _ = work_tile.tile_idx

                for k_tile in cutlass.range(k_tile_cnt):
                    # /////////////////////////////////////////////////////////////////////////////
                    #  Wait for TMA copies to complete
                    # /////////////////////////////////////////////////////////////////////////////
                    mainloop_pipeline.consumer_wait(mainloop_consumer_state)
                    # /////////////////////////////////////////////////////////////////////////////
                    #  WGMMA
                    # /////////////////////////////////////////////////////////////////////////////
                    cute.nvgpu.warpgroup.fence()
                    for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
                        k_block_coord = (
                            None,
                            None,
                            k_block_idx,
                            mainloop_consumer_state.index,
                        )
                        tCrA_1phase = tCrA[k_block_coord]
                        tCrB_1phase = tCrB[k_block_coord]

                        cute.gemm(
                            tiled_mma,
                            accumulators,
                            tCrA_1phase,
                            tCrB_1phase,
                            accumulators,
                        )
                        tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)

                    cute.nvgpu.warpgroup.commit_group()
                    # Wait until all outstanding WGMMAs have completed before releasing the SMEM stage
                    cute.nvgpu.warpgroup.wait_group(0)

                    mainloop_pipeline.consumer_release(mainloop_consumer_state)
                    mainloop_consumer_state.advance()

                store_copy = cute.make_copy_atom(
                    cute.nvgpu.CopyUniversalOp(), self.acc_dtype
                )
                cute.copy(
                    store_copy, accumulators, tCgC[None, None, None, tile_m, tile_n]
                )

                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()

This relatively simple code already achieves 643 TFLOPs.
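The ACCUMULATE flag in the consumer loop is what lets us reuse the accumulator fragment across tiles without zeroing it. The NumPy model below (toy_wgmma is a hypothetical stand-in for cute.gemm) splits one K tile into 16-wide k-blocks, starts from stale values left over from a previous tile, and shows that turning ACCUMULATE off for the first k-block discards them:

```python
import numpy as np


def toy_wgmma(acc, a_blk, b_blk, accumulate):
    """ACCUMULATE=False: D = A @ B^T; ACCUMULATE=True: D = A @ B^T + D."""
    prod = a_blk @ b_blk.T
    return prod + acc if accumulate else prod


rng = np.random.default_rng(0)
bM, bN, bK, k_block = 8, 8, 64, 16
a = rng.standard_normal((bM, bK)).astype(np.float32)
b = rng.standard_normal((bN, bK)).astype(np.float32)  # K-major B, as in the post

acc = np.full((bM, bN), 1e6, dtype=np.float32)  # stale results of the previous tile
accumulate = False                              # tiled_mma.set(ACCUMULATE, False) per tile
for k in range(0, bK, k_block):
    acc = toy_wgmma(acc, a[:, k:k + k_block], b[:, k:k + k_block], accumulate)
    accumulate = True                           # set to True after each k-block

print(np.allclose(acc, a @ b.T, atol=1e-3))  # True: the stale 1e6 values never leak in
```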

Use clusters

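There is a simple modification that gets us to 658 TFLOPs: clusters. The TMA unit can use a cluster to perform a multicast, i.e. a single load writes the same data into the shared memory of several CTAs in the cluster. With the cluster shape (2, 1, 1) used below, the two CTAs of a cluster compute output tiles with different tile_m but the same tile_n, so each needs its own A tile while both need the same B tile. The B tile is therefore multicast: it is fetched once per cluster (each CTA loads half of it and multicasts that half to both CTAs) instead of once per CTA.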
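To quantify the saving, here is a back-of-the-envelope model (hypothetical helper; it counts only the tile bytes requested per K step and ignores L2 hit rates). Without multicast, every CTA of a cluster loads its own A and B tiles; with multicast, each distinct tile is fetched once per cluster.

```python
BF16_BYTES = 2


def bytes_per_k_step(bM, bN, bK, cluster_m, cluster_n, multicast):
    """Tile bytes requested by one cluster for one K tile."""
    a_tile = bM * bK * BF16_BYTES  # one A tile per distinct tile_m in the cluster
    b_tile = bN * bK * BF16_BYTES  # one B tile per distinct tile_n in the cluster
    if multicast:
        return cluster_m * a_tile + cluster_n * b_tile
    return cluster_m * cluster_n * (a_tile + b_tile)  # every CTA loads both tiles itself


without = bytes_per_k_step(128, 256, 64, 2, 1, multicast=False)
with_mc = bytes_per_k_step(128, 256, 64, 2, 1, multicast=True)
print(without, with_mc, 1 - with_mc / without)  # 98304 65536 ~0.33
```

For the (2, 1, 1) cluster this is a third fewer tile bytes per K step. Since the B tile (bN = 256) is the larger of the two, clustering along M is the direction that saves more.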

In __init__ we define an additional copy operation:

self.op_cluster = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()

which is used for the tiles that are multicast, as well as the cluster shape:

self.cluster_shape_mnk = (2, 1, 1)

Within the __call__ function we add:

        # Cluster
        self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
        self.num_mcast_ctas_a = self.cluster_shape_mnk[1]
        self.num_mcast_ctas_b = self.cluster_shape_mnk[0]
        self.is_a_mcast = self.num_mcast_ctas_a > 1
        self.is_b_mcast = self.num_mcast_ctas_b > 1

and modify the setup of TMA as follows:

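        a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0))
        tma_atom_a, tma_tensor_a = cute.nvgpu.cpasync.make_tiled_tma_atom(
            # A is shared by the CTAs along N, so it is multicast only if cluster N > 1
            self.op if self.cluster_shape_mnk[1] == 1 else self.op_cluster,
            mA,
            a_smem_layout,
            (self.bM, self.bK),
            num_multicast=self.cluster_shape_mnk[1],
        )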

and analogously for B.

We then pass self.cta_layout_mnk as an additional argument to the kernel and to the function that computes the grid and the tile scheduler parameters, and launch the kernel with cluster=self.cluster_shape_mnk.

Inside the kernel we compute the multicast masks that are passed to the TMA copy operations:

        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )
        cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)

        # ///////////////////////////////////////////////////////////////////////////////
        # Get mcast mask
        # ///////////////////////////////////////////////////////////////////////////////
        a_mcast_mask = cute.make_layout_image_mask(
            cta_layout_mnk, cluster_coord_mnk, mode=1
        )
        b_mcast_mask = cute.make_layout_image_mask(
            cta_layout_mnk, cluster_coord_mnk, mode=0
        )

        a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
        b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
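My reading of make_layout_image_mask is that it sets one bit per CTA rank that shares this CTA's cluster coordinates in every mode except mode, i.e. the CTAs that receive the multicast. Below is a pure-Python model of that, assuming the default column-major (M-fastest) rank order of cute.make_layout(self.cluster_shape_mnk); the functions are hypothetical.

```python
def cta_rank(coord, shape):
    """Column-major (M-fastest) rank of a CTA within the cluster."""
    m, n, k = coord
    return m + shape[0] * (n + shape[1] * k)


def image_mask(shape, coord, mode):
    """Bitmask of CTAs that match `coord` in every mode except `mode`."""
    mask = 0
    for i in range(shape[mode]):
        peer = list(coord)
        peer[mode] = i
        mask |= 1 << cta_rank(tuple(peer), shape)
    return mask


# Cluster (2, 1, 1), second CTA: B goes to both CTAs, A only to itself
# (and a_mcast_mask is set to 0 anyway because is_a_mcast is False).
print(bin(image_mask((2, 1, 1), (1, 0, 0), mode=0)))  # 0b11 -> B mask
print(bin(image_mask((2, 1, 1), (1, 0, 0), mode=1)))  # 0b10 -> A mask
```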

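We also need to adjust the consumer arrive count of the mainloop pipeline. With multicast, the producer of a CTA writes into the SMEM of every CTA that shares its A or B tile, so it may only refill a stage once the consumer warps of all these CTAs have released it: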

        # Each consumer warp arrives once on the barrier of every CTA it shares multicast tiles with
        mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
        num_warps = 8  # consumer warps per CTA: 2 warp groups x 4 warps
        consumer_arrive_cnt = mcast_size * num_warps
        mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, consumer_arrive_cnt, alignment=consumer_arrive_cnt
        )

Finally, when setting up the TMA partition we pass the CTA's coordinate within the cluster: a_cta_crd = cluster_coord_mnk[1] for A and, analogously, cluster_coord_mnk[0] for B.
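Why is mcast_size = num_mcast_ctas_a + num_mcast_ctas_b - 1? The CTAs that must release a stage before it can be refilled are those with the same tile_m (they share A) plus those with the same tile_n (they share B), and the CTA itself belongs to both sets. A small check in plain Python, with hypothetical names and the 8 consumer warps from above:

```python
from itertools import product


def sharing_ctas(cluster_mn, coord_mn):
    """CTAs sharing A (same m coordinate) or B (same n coordinate) with this CTA."""
    cm, cn = cluster_mn
    m, n = coord_mn
    share_a = {(m, j) for j in range(cn)}
    share_b = {(i, n) for i in range(cm)}
    return share_a | share_b


def consumer_arrive_count(cluster_mn, consumer_warps=8):
    cm, cn = cluster_mn
    num_mcast_ctas_a, num_mcast_ctas_b = cn, cm  # as in __call__: cluster N for A, cluster M for B
    return (num_mcast_ctas_a + num_mcast_ctas_b - 1) * consumer_warps


for cluster in [(1, 1), (2, 1), (2, 2)]:
    sizes = {len(sharing_ctas(cluster, c)) for c in product(range(cluster[0]), range(cluster[1]))}
    print(cluster, sizes, consumer_arrive_count(cluster))
# (1, 1) {1} 8; (2, 1) {2} 16; (2, 2) {3} 24
```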

In the producer loop we pass the masks to the copy operations (a_mcast_mask for A, b_mcast_mask for B):

                        cute.copy(
                            tma_atom_a,
                            tAgA_k,
                            tAsA_pipe,
                            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                                mainloop_producer_state
                            ),
                            mcast_mask=a_mcast_mask,
                        )

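Apart from replacing CTA-level synchronisation with cluster-level synchronisation where necessary (for example after initialising the pipeline barriers, so that no CTA multicasts into a peer whose barriers are not yet initialised), that's it.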

Conclusion

I hope this blog post made warp-specialised kernel design more approachable. There are further optimisations we could employ, but I chose to stop at this point to keep the post concise and focused on the topic at hand. You can contact me on LinkedIn to discuss GPU programming or MLSys in general. The full code can be found in my GitHub repo.