Simon's blog

Persistent Float8 Dense Gemm on Hopper

In the past I wrote a persistent gemm using the CuTeDSL. This blogpost generalises the approach taken there and shows how to modify the dense_gemm example for Hopper to turn it into a persistent kernel. It is a generalisation in the sense that here we perform a batched Gemm, i.e. we multiply A of shape MxKxL with B of shape NxKxL to obtain C of shape MxNxL, where L is the batch dimension. Furthermore, I modified the epilogue to allow arbitrary output types, i.e. we can multiply two matrices in Float8 precision and output in Float8 or Float16. The previous version only allowed Float32 outputs.
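As a point of reference for the shapes, the plain-Python sketch below spells out the batched product, assuming both operands are stored as described (A as MxKxL, B as NxKxL), i.e. C[m, n, l] = sum over k of A[m, k, l] * B[n, k, l]. The function name batched_gemm_ref is hypothetical; it is a slow correctness oracle you could compare kernel output against, not a model of how the kernel computes.

```python
from typing import List

# Nested lists indexed as t[i][j][l], mirroring the (M, K, L), (N, K, L) and (M, N, L) shapes.
Tensor3 = List[List[List[float]]]


def batched_gemm_ref(a: Tensor3, b: Tensor3) -> Tensor3:
    """Hypothetical reference: C[m][n][l] = sum_k A[m][k][l] * B[n][k][l]."""
    m_size, k_size, l_size = len(a), len(a[0]), len(a[0][0])
    n_size = len(b)
    # B has to agree with A in the K and L modes.
    assert len(b[0]) == k_size and len(b[0][0]) == l_size
    c = [[[0.0] * l_size for _ in range(n_size)] for _ in range(m_size)]
    for batch in range(l_size):  # one independent GEMM per batch index
        for m in range(m_size):
            for n in range(n_size):
                acc = 0.0  # accumulate in high precision; a kernel casts only at the end
                for k in range(k_size):
                    acc += a[m][k][batch] * b[n][k][batch]
                c[m][n][batch] = acc
    return c


if __name__ == "__main__":
    # M=2, K=2, L=2 and N=1; the two batches hold different data.
    a = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
    b = [[[1.0, 1.0], [1.0, -1.0]]]
    print(batched_gemm_ref(a, b))  # [[[4.0, -2.0]], [[12.0, -2.0]]]
```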

Please check out my past posts on CuTeDSL if you are not familiar with the concepts used below. You can find them here.

Adjustments to be made

Compared to the ordinary dense gemm, the first adjustment is that we need the TileScheduler abstraction.

The function below takes the output tensor C as well as cta_tile_shape_mnk and cluster_shape_mn. It is taken from the corresponding Blackwell example and computes the launch grid as well as the parameters for the tile scheduler, which we will use to obtain the correct tile for each CTA.

    @staticmethod
    def _compute_grid(
        c: cute.Tensor,
        cta_tile_shape_mnk: Tuple[int, int, int],
        cluster_shape_mn: Tuple[int, int],
    ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
        """Use persistent tile scheduler to compute the grid size for the output tensor C.

        :param c: The output tensor C
        :type c: cute.Tensor
        :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
        :type cta_tile_shape_mnk: tuple[int, int, int]
        :param cluster_shape_mn: Shape of each cluster in M, N dimensions.
        :type cluster_shape_mn: tuple[int, int]

        :return: A tuple containing:
            - tile_sched_params: Parameters for the persistent tile scheduler.
            - grid: Grid shape for kernel launch.
        :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
        """
        c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
        gc = cute.zipped_divide(c, tiler=c_shape)
        num_ctas_mnl = gc[(0, (None, None, None))].shape
        cluster_shape_mnl = (*cluster_shape_mn, 1)
        max_active_clusters = cutlass.const_expr(128)  # Hardcoded bound on concurrently resident clusters (hardware dependent)

        tile_sched_params = utils.PersistentTileSchedulerParams(
            num_ctas_mnl, cluster_shape_mnl
        )
        grid = utils.StaticPersistentTileScheduler.get_grid_shape(
            tile_sched_params, max_active_clusters
        )

        return tile_sched_params, grid
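To see what these parameters amount to, here is a simplified pure-Python model of a static persistent schedule. It assumes a 1x1 cluster and an M-fastest linearisation of the (M, N, L) tile grid; both are my assumptions for illustration, and the real utils.StaticPersistentTileScheduler may order and group tiles differently (for example per cluster). The idea it captures: launch no more CTAs than can be resident (the hardcoded 128 above), and let each CTA stride through the linear tile index by the grid size. All function names are hypothetical.

```python
import math
from typing import Dict, List, Tuple

Tile = Tuple[int, int, int]


def num_tiles_mnl(m: int, n: int, batch: int, tile_m: int, tile_n: int) -> Tile:
    # Like zipped_divide(c, (bM, bN)): partial tiles at the edges still count.
    return math.ceil(m / tile_m), math.ceil(n / tile_n), batch


def persistent_grid(num_tiles: Tile, max_active_ctas: int) -> int:
    # Launch no more CTAs than there are tiles, and no more than can be resident.
    total = num_tiles[0] * num_tiles[1] * num_tiles[2]
    return min(total, max_active_ctas)


def static_schedule(num_tiles: Tile, grid: int) -> Dict[int, List[Tile]]:
    """CTA b handles linear tiles b, b + grid, b + 2 * grid, ... (1x1 clusters assumed)."""
    tiles_m, tiles_n, tiles_l = num_tiles
    total = tiles_m * tiles_n * tiles_l
    work: Dict[int, List[Tile]] = {cta: [] for cta in range(grid)}
    for cta in range(grid):
        linear = cta                      # ~ initial_work_tile_info()
        while linear < total:             # ~ work_tile.is_valid_tile
            # Assumed order: M fastest, then N, then the batch mode L.
            tile_m = linear % tiles_m
            tile_n = (linear // tiles_m) % tiles_n
            tile_l = linear // (tiles_m * tiles_n)
            work[cta].append((tile_m, tile_n, tile_l))
            linear += grid                # ~ advance_to_next_work()
    return work


if __name__ == "__main__":
    tiles = num_tiles_mnl(1000, 512, 3, 128, 256)
    grid = persistent_grid(tiles, max_active_ctas=20)
    print(grid, static_schedule(tiles, grid)[0])
```

For example, a 1000x512x3 output with 128x256 tiles gives 8x2x3 = 48 tiles; with at most 20 resident CTAs, CTA 0 processes tiles (0, 0, 0), (4, 0, 1) and (0, 1, 2), and no tile is visited twice.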

The second adjustment is in the init function. As in the gemm I implemented previously, we use 2 consumers and 1 producer, i.e. a total of 3 warp groups: two act as consumers (computing via WGMMA) and one acts as the producer (loading data via TMA).

        self.num_consumer = 2
        self.num_producer = 1
        ...
        self.num_threads_per_warp_group = 128
        self.threads_per_cta = (
            self.num_consumer + self.num_producer
        ) * self.num_threads_per_warp_group
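The warp-group arithmetic can be checked with a few lines of plain Python. The sketch below assumes 32 threads per warp and a 65536-entry (64K) 32-bit register file per SM, as on H100, and uses the register counts that the mainloop further down passes to warpgroup_reg_dealloc(24) and warpgroup_reg_alloc(240). The helper names are hypothetical.

```python
THREADS_PER_WARP = 32
WARPS_PER_WARP_GROUP = 4
REGISTERS_PER_SM = 64 * 1024  # assumed: 32-bit registers per SM on H100


def threads_per_cta(num_consumer: int, num_producer: int, threads_per_warp_group: int = 128) -> int:
    return (num_consumer + num_producer) * threads_per_warp_group


def role_of_thread(tid: int) -> str:
    # Warp group 0 (warps 0-3) is the producer, mirroring `if warp_idx < 4` in the kernel.
    warp_idx = tid // THREADS_PER_WARP
    return "producer" if warp_idx < WARPS_PER_WARP_GROUP else "consumer"


def registers_per_cta(producer_regs: int, consumer_regs: int, num_consumer: int,
                      num_producer: int, threads_per_warp_group: int = 128) -> int:
    # Per-thread register counts after reallocation, summed over all warp groups.
    return threads_per_warp_group * (num_producer * producer_regs + num_consumer * consumer_regs)


if __name__ == "__main__":
    print(threads_per_cta(2, 1), role_of_thread(127), role_of_thread(128))
    total = registers_per_cta(24, 240, num_consumer=2, num_producer=1)
    print(total, total <= REGISTERS_PER_SM)
```

With 2 consumers and 1 producer this gives 384 threads per CTA, threads 0 to 127 (warps 0 to 3) form the producer warp group selected by the `warp_idx < 4` branch, and 128 * (24 + 2 * 240) = 64512 registers in total, which fits into a single SM's register file and is consistent with running one CTA per SM.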

The third adjustment concerns the mainloop and the epilogue.

Below is the mainloop. Note the similarity to the gemm I implemented previously. The only change here is that we now also use the third tile coordinate, tile_l, to select the correct batch.

        # /////////////////////////////////////////////////////////////////////////////
        #  Producer
        # /////////////////////////////////////////////////////////////////////////////
        if warp_idx < 4:
            cute.arch.warpgroup_reg_dealloc(24)

            if warp_idx == 0:
                tile_sched = cutlass.utils.StaticPersistentTileScheduler.create(
                    tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
                )
                work_tile = tile_sched.initial_work_tile_info()

                while work_tile.is_valid_tile:
                    tile_m, tile_n, tile_l = work_tile.tile_idx

                    for tile_k in cutlass.range(k_tile_cnt):
                        mainloop_pipeline.producer_acquire(mainloop_producer_state)

                        tAgA_index = (None, tile_m, tile_k, tile_l)
                        tAsA_index = (None, mainloop_producer_state.index)
                        tBgB_index = (None, tile_n, tile_k, tile_l)
                        tBsB_index = (None, mainloop_producer_state.index)

                        tAgA_k = tAgA_mkl[tAgA_index]
                        tAsA_pipe = tAsA[tAsA_index]
                        tBgB_k = tBgB_nkl[tBgB_index]
                        tBsB_pipe = tBsB[tBsB_index]

                        cute.copy(
                            tma_atom_a,
                            tAgA_k,
                            tAsA_pipe,
                            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                                mainloop_producer_state
                            ),
                            mcast_mask=a_mcast_mask,
                        )
                        cute.copy(
                            tma_atom_b,
                            tBgB_k,
                            tBsB_pipe,
                            tma_bar_ptr=mainloop_pipeline.producer_get_barrier(
                                mainloop_producer_state
                            ),
                            mcast_mask=b_mcast_mask,
                        )
                        # Mainloop pipeline's producer commit is a NOP
                        mainloop_pipeline.producer_commit(mainloop_producer_state)
                        mainloop_producer_state.advance()
                    # Advance Tile
                    tile_sched.advance_to_next_work()
                    work_tile = tile_sched.get_current_work()
        # Consumer
        else:
            cute.arch.warpgroup_reg_alloc(240)

            tile_sched = cutlass.utils.StaticPersistentTileScheduler.create(
                tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
            )

            work_tile = tile_sched.initial_work_tile_info()
            num_k_blocks = cute.size(tCrA, mode=[2])

            accumulators = cute.make_fragment(
                tCgC[None, None, None, 0, 0, 0].shape, self.acc_dtype
            )
            while work_tile.is_valid_tile:
                tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)

                tile_m, tile_n, tile_l = work_tile.tile_idx

                for k_tile in cutlass.range(k_tile_cnt):
                    # /////////////////////////////////////////////////////////////////////////////
                    #  Wait for TMA copies to complete
                    # /////////////////////////////////////////////////////////////////////////////
                    mainloop_pipeline.consumer_wait(mainloop_consumer_state)
                    # /////////////////////////////////////////////////////////////////////////////
                    #  WGMMA
                    # /////////////////////////////////////////////////////////////////////////////
                    cute.nvgpu.warpgroup.fence()
                    for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
                        k_block_coord = (
                            None,
                            None,
                            k_block_idx,
                            mainloop_consumer_state.index,
                        )
                        tCrA_1phase = tCrA[k_block_coord]
                        tCrB_1phase = tCrB[k_block_coord]

                        cute.gemm(
                            tiled_mma,
                            accumulators,
                            tCrA_1phase,
                            tCrB_1phase,
                            accumulators,
                        )
                        tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)

                    cute.nvgpu.warpgroup.commit_group()
                    # Wait until no WGMMA groups are outstanding (wait_group(0)) before releasing the smem stage
                    cute.nvgpu.warpgroup.wait_group(0)

                    mainloop_pipeline.consumer_release(mainloop_consumer_state)
                    mainloop_consumer_state.advance()

                store_copy = cute.make_copy_atom(
                    cute.nvgpu.CopyUniversalOp(), self.c_dtype
                )
                accumulators_vector = cute.make_fragment_like(
                    accumulators, self.c_dtype
                )
                accumulators_vector.store(accumulators.load().to(self.c_dtype))
                cute.copy(
                    store_copy,
                    accumulators_vector,
                    tCgC[None, None, None, tile_m, tile_n, tile_l],
                )

                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()
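A detail that is easy to miss in the persistent version: mainloop_producer_state and mainloop_consumer_state are only advanced, never re-created, inside the while loop, so producer and consumer keep walking the same circular buffer of shared-memory stages across work tiles. The pure-Python model below is my simplified illustration of such a state (a stage index plus a phase bit that flips on wrap-around); RingState and stages_used are hypothetical names, not the CUTLASS pipeline classes.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class RingState:
    """Simplified pipeline state: current shared-memory stage plus its phase bit."""
    stages: int
    index: int = 0
    phase: int = 0

    def advance(self) -> None:
        self.index += 1
        if self.index == self.stages:  # wrap around the circular buffer
            self.index = 0
            self.phase ^= 1            # barrier parity flips on every wrap


def stages_used(num_tiles: int, k_tile_cnt: int, stages: int) -> List[List[Tuple[int, int]]]:
    # Created once, outside the tile loop, like mainloop_producer_state in the kernel,
    # so it keeps advancing across work tiles instead of restarting at stage 0.
    state = RingState(stages)
    per_tile = []
    for _ in range(num_tiles):
        used = []
        for _ in range(k_tile_cnt):
            used.append((state.index, state.phase))
            state.advance()
        per_tile.append(used)
    return per_tile


if __name__ == "__main__":
    for tile, used in enumerate(stages_used(num_tiles=2, k_tile_cnt=3, stages=4)):
        print(tile, used)
```

With 4 stages and 3 k-tiles per work tile, the first tile uses stages 0, 1, 2 in phase 0, and the second tile continues at stage 3 before wrapping to stages 0 and 1 in phase 1, rather than restarting at stage 0.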

For the tiling to work, we need to keep all tile modes when partitioning the global tensors:

        # (bM, bK, RestM, RestK, RestL)
        gA_mkl = cute.local_tile(
            mA_mkl,
            cute.slice_(self.tile_shape_mnk, (None, 0, None)),
            (None, None, None),
        )
        # (bN, bK, RestN, RestK, RestL)
        gB_nkl = cute.local_tile(
            mB_nkl,
            cute.slice_(self.tile_shape_mnk, (0, None, None)),
            (None, None, None),
        )
        # (bM, bN, RestM, RestN, RestL)
        gC_mnl = cute.local_tile(
            mC_mnl,
            cute.slice_(self.tile_shape_mnk, (None, None, 0)),
            (None, None, None),
        )

The dense_gemm example uses a projection here, which results in Rest_M, Rest_N and Rest_L being projected out, since each CTA there only ever works on a single tile. We don't want that here: a persistent CTA processes many tiles, so we need these modes to index into whichever tile the scheduler assigns to it.
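To illustrate why the Rest modes matter, the following plain-Python sketch computes the shape that local_tile produces when every coordinate is None, following the (bM, bK, RestM, RestK, RestL) comments above, and the global offset of the tile picked out by indexing those modes with a scheduler coordinate. The helpers are hypothetical and ignore layouts and strides.

```python
import math
from typing import Tuple


def local_tile_shape(shape_xyl: Tuple[int, int, int], tile_xy: Tuple[int, int]) -> Tuple[int, ...]:
    """Shape of a fully unprojected tiling: (b0, b1, Rest0, Rest1, RestL)."""
    d0, d1, batch = shape_xyl
    b0, b1 = tile_xy
    # Partial tiles at the edges still get their own Rest index, hence ceil.
    return (b0, b1, math.ceil(d0 / b0), math.ceil(d1 / b1), batch)


def tile_origin(tile_xy: Tuple[int, int], rest_coord: Tuple[int, int, int]) -> Tuple[int, int, int]:
    """Element offset of the tile selected by indexing (Rest0, Rest1, RestL)."""
    b0, b1 = tile_xy
    r0, r1, rl = rest_coord
    return (r0 * b0, r1 * b1, rl)  # the batch mode is not tiled, so it maps 1:1


if __name__ == "__main__":
    print(local_tile_shape((1024, 4096, 4), (128, 128)))
    print(tile_origin((128, 128), (3, 5, 2)))
```

For A of shape (1024, 4096, 4) and a 128x128 (bM, bK) tile, the shape is (128, 128, 8, 32, 4), and the coordinate (tile_m=3, tile_k=5, tile_l=2) starts at element (384, 640) of batch 2. With the projection, only (bM, bK, RestK) would remain and there would be nothing left to index with tile_m and tile_l.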

Compared to my previous persistent gemm, I made one further small adjustment in the epilogue:

                store_copy = cute.make_copy_atom(
                    cute.nvgpu.CopyUniversalOp(), self.c_dtype
                )
                accumulators_vector = cute.make_fragment_like(
                    accumulators, self.c_dtype
                )
                accumulators_vector.store(accumulators.load().to(self.c_dtype))
                cute.copy(
                    store_copy,
                    accumulators_vector,
                    tCgC[None, None, None, tile_m, tile_n, tile_l],
                )

Note that we initialise another fragment here to cast the accumulator to the output datatype. This is necessary because cute.copy only works if source and destination have the same datatype, while the accumulator is held in self.acc_dtype rather than self.c_dtype.
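The cast is a genuine, lossy conversion rather than a reinterpretation of the bytes. Python has no Float8 type, so the sketch below illustrates the same step with Float16, using the standard struct module's IEEE half-precision 'e' format: Python floats stand in for the Float32 accumulators and are rounded into a new 2-byte-per-element buffer, analogous to accumulators_vector, before anything is stored. The function names are hypothetical.

```python
import struct
from typing import List


def cast_fragment_to_fp16(acc: List[float]) -> bytes:
    # Round every accumulator value to IEEE Float16 and pack it into a new buffer,
    # playing the role of the c_dtype fragment that is then copied out.
    return struct.pack(f"<{len(acc)}e", *acc)


def read_fp16(buf: bytes) -> List[float]:
    # Decode the packed Float16 values back into Python floats for inspection.
    return list(struct.unpack(f"<{len(buf) // 2}e", buf))


if __name__ == "__main__":
    packed = cast_fragment_to_fp16([0.1, 1.0, 65504.0])
    print(len(packed), read_fp16(packed))
```

Here 0.1 comes back as 0.0999755859375 and the three values occupy 6 bytes instead of the 12 a Float32 fragment would need, which is why the fragment has to be converted rather than copied as-is.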

Performance and Outlook on further improvement

The current performance compared to the dense_gemm example is as follows:

A and B are in Float8E4M3FN.

C in Float8E4M3FN: 1056 TFLOPs for the persistent kernel vs. 1389 TFLOPs for the dense_gemm example, i.e. 76% of the baseline.

C in Float16: 1147 TFLOPs for the persistent kernel vs. 1345 TFLOPs for the dense_gemm example, i.e. 85% of the baseline.
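For reference, TFLOPs figures like these are conventionally computed from 2 * M * N * K * L floating-point operations (one multiply and one add per term) divided by the runtime. The problem sizes behind the measurements are not stated here, so the sketch only shows the formula with an illustrative 8192^3 problem and a hypothetical 1 ms runtime, and reproduces the relative numbers from the reported TFLOPs. The helper names are hypothetical.

```python
def gemm_tflops(m: int, n: int, k: int, batch: int, seconds: float) -> float:
    # Each of the M*N*L outputs needs K multiply-adds, i.e. 2*K floating-point operations.
    return 2.0 * m * n * k * batch / seconds / 1e12


def percent_of_baseline(persistent_tflops: float, baseline_tflops: float) -> int:
    return round(100 * persistent_tflops / baseline_tflops)


if __name__ == "__main__":
    print(percent_of_baseline(1056, 1389), percent_of_baseline(1147, 1345))  # 76 85
    # Illustrative only: an 8192^3 GEMM finishing in a hypothetical 1 ms.
    print(gemm_tflops(8192, 8192, 8192, 1, 1e-3))
```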

We should be able to at least match the performance of the baseline by using TMA stores in the epilogue of the persistent version; currently the epilogue writes the accumulators directly from registers to global memory. I intend to implement that in the future. There are also further Float8-specific optimisations we can apply, which I also intend to implement.

Conclusion

I hope this blogpost helped you better understand how to write a general persistent version of a dense gemm in the CuTeDSL. Feel free to connect with me via LinkedIn to exchange ideas on GPU programming.

All the code can be found in my github repo.