simons blog

SGEMM in CuTeDSL

SGEMM is one of the fundamental operations we aim to optimise on GPUs. In this blogpost I will explain the corresponding example from the CUTLASS repo. I chose SGEMM because it is the simplest non-trivial example among the examples and therefore a good starting point to learn about CuTeDSL.

We will analyse the code in a top-down fashion, i.e. start with the host-side setup and gradually work our way down into the kernel details.

The whole blogpost covers the case we obtain when executing the program with the following arguments, which I explain below:

--mnk 256,128,64 --a_major m --b_major n --c_major m

main function

This is where the problem is set up.

The main function has the following signature:

def main(
    a_major: str,
    b_major: str,
    c_major: str,
    problem_shape: Tuple[int, int, int],
    warmup_iterations: int = 2,
    iterations: int = 100,
    skip_ref_check: bool = False,
):

Via the command line we pass the majorness for the three matrices of interest.

We will consider the case

A - M-Major
B - N-Major
C - M-Major

That simply means that the stride along the M mode of A and C, and along the N mode of B, is 1.

The problem_shape is passed as (M, N, K) = (256, 128, 64).

This defines the layouts of A, B and C as follows.

A: (M,K):(1,M)
B: (N,K):(1,N)
C: (M,N):(1,M)

The layouts in CuTe define the way our matrix elements are laid out in memory.

To understand this concept better let's consider M = 2, K = 3.

In the M-major example below we see, for instance, that the matrix coordinate (0,2) is mapped to the physical index 4. Physical indices are one-dimensional because memory is one-dimensional.

M-Major
(2,3):(1,2)
 0 2 4
 1 3 5
K-Major
(2,3):(3,1)
 0 1 2
 3 4 5
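The coordinate-to-index mapping of such a layout can be sketched in plain Python (ordinary Python for illustration, not CuTeDSL):

```python
def layout_index(coord, stride):
    """Map a logical coordinate to a 1-D physical index: dot(coord, stride)."""
    return sum(c * s for c, s in zip(coord, stride))

# M-major layout (2,3):(1,2): coordinate (0,2) lands at index 4
assert layout_index((0, 2), (1, 2)) == 4
# Enumerating all coordinates column by column recovers 0..5
assert [layout_index((i, j), (1, 2)) for j in range(3) for i in range(2)] == [0, 1, 2, 3, 4, 5]

# K-major layout (2,3):(3,1): the same coordinate lands at index 2
assert layout_index((0, 2), (3, 1)) == 2
```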

The below code simply initialises the tensors with appropriate shape and stride.

    torch.manual_seed(1024)
    M, N, K = problem_shape

    # Create and permute tensor A/B/C
    def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype):
        # is_mode0_major: (mode1, mode0) -> (mode0, mode1)
        # else: (mode0, mode1) -> (mode0, mode1)
        shape = (mode1, mode0) if is_mode0_major else (mode0, mode1)
        permute_order = (1, 0) if is_mode0_major else (0, 1)

        return (
            torch.empty(*shape, dtype=torch.int32)
            .random_(-5, 5)
            .to(dtype=dtype)
            .permute(permute_order)
            .cuda()
        )

    a = create_and_permute_tensor(M, K, a_major == "m", torch.float32)
    b = create_and_permute_tensor(N, K, b_major == "n", torch.float32)
    c = create_and_permute_tensor(M, N, c_major == "m", torch.float32)

    divisibility_a = a.shape[1] if a_major == "k" else a.shape[0]
    divisibility_b = b.shape[1] if b_major == "k" else b.shape[0]
    divisibility_c = c.shape[1] if c_major == "n" else c.shape[0]

    a_tensor = (
        from_dlpack(a, assumed_align=16)
        .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
        .mark_compact_shape_dynamic(
            mode=(1 if a_major == "k" else 0),
            divisibility=divisibility_a,
        )
    )

    b_tensor = (
        from_dlpack(b, assumed_align=16)
        .mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
        .mark_compact_shape_dynamic(
            mode=(1 if b_major == "k" else 0),
            divisibility=divisibility_b,
        )
    )

    c_tensor = (
        from_dlpack(c, assumed_align=16)
        .mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
        .mark_compact_shape_dynamic(
            mode=(1 if c_major == "n" else 0),
            divisibility=divisibility_c,
        )
    )
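What the allocate-then-permute trick achieves can be sketched in pure Python in terms of shape and element strides (`create_shape_and_strides` is a hypothetical helper for illustration, not part of the example):

```python
def create_shape_and_strides(mode0, mode1, is_mode0_major):
    """Mirror create_and_permute_tensor: the logical shape is always (mode0, mode1).

    If mode0 is the major mode, the tensor is allocated as (mode1, mode0)
    row-major and then permuted, so mode0 ends up with stride 1.
    """
    if is_mode0_major:
        return (mode0, mode1), (1, mode0)
    return (mode0, mode1), (mode1, 1)

# A with a_major == "m": (M, K) = (256, 64), M-major
assert create_shape_and_strides(256, 64, True) == ((256, 64), (1, 256))
# B with b_major == "n": (N, K) = (128, 64), N-major
assert create_shape_and_strides(128, 64, True) == ((128, 64), (1, 128))
```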

In our case that means

mA Layout: (256,64):(1,256)
mB Layout: (128,64):(1,128)
mC Layout: (256,128):(1,256)

After that we simply compile our kernel:

    sgemm = SGemm()

    print("Compiling kernel with cute.compile ...")
    start_time = time.time()
    gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor)
    compilation_time = time.time() - start_time
    print(f"Compilation time: {compilation_time:.4f} seconds")

The kernel can then be called like so:

gemm(a_tensor, b_tensor, c_tensor)

SGemm class

The SGemm class has three methods: __init__, __call__ and kernel.

init method

The code below initialises the CTA tiler, the number of pipeline stages and the number of threads, and validates them:

class SGemm:
    def __init__(
        self,
        cta_tiler: Tuple[int, int, int] = (128, 128, 8),
        num_stages: int = 3,
        num_threads: int = 256,
    ):
        self._cta_tiler = cta_tiler
        self._num_stages = num_stages
        self._num_threads = num_threads
        assert num_threads > 0, "needs at least one thread"
        assert num_threads % 16 == 0, "multiples of 16 required for MMA thread layout"

        self._bM, self._bN, self._bK = self._cta_tiler
        assert self._bM % 16 == 0, "multiple of 16 required for tile dimension M"
        assert self._bN % 16 == 0, "multiple of 16 required for tile dimension N"
        assert self._num_stages >= 3, "num_stages must be greater than or equal to 3"

call method

The __call__ method is the analogue of the host code in a CUDA program. Here we prepare all the parameters that will later be handed down to the kernel.

Note that the __call__ method is jitted. It takes the tensors we prepared above as well as an epilogue_op. The epilogue_op is the identity by default and can be used to modify our result after the main computation is over. For example we could set lambda x: 2 * x to multiply the elements of the matrix multiplication by 2 after computing the GEMM.

    @cute.jit
    def __call__(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mC: cute.Tensor,
        epilogue_op: cutlass.Constexpr = lambda x: x,
    ):

We determine the major modes. utils.LayoutEnum.from_tensor returns utils.LayoutEnum.ROW_MAJOR or utils.LayoutEnum.COL_MAJOR. ROW_MAJOR means that the stride along the second mode is 1; COL_MAJOR that the stride along the first mode is 1.

We then initialise the SMEM layouts. If A or B is row major we pad, a common technique (an alternative to swizzling) to reduce shared memory bank conflicts. The concept of memory bank conflicts is explained here.

We furthermore determine the size of SMEM we need to allocate. The maximum size is determined by the physical properties of the chip we deal with.

        self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
        self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
        self.c_major_mode = utils.LayoutEnum.from_tensor(mC)

        padding_a = 4 if self.a_major_mode == utils.LayoutEnum.ROW_MAJOR else 0
        padding_b = 4 if self.b_major_mode == utils.LayoutEnum.ROW_MAJOR else 0
        sA_layout = cute.make_layout(
            (self._bM, self._bK, self._num_stages),
            stride=(1, (self._bM + padding_a), self._bK * (self._bM + padding_a)),
        )
        sB_layout = cute.make_layout(
            (self._bN, self._bK, self._num_stages),
            stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)),
        )

        smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes(
            mB.element_type, sB_layout
        )
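The shared memory footprint can be reproduced with a little arithmetic (a sketch: the cosize of each staged layout times the f32 element width; with our M-major A and N-major B there is no padding):

```python
def smem_bytes(bM, bN, bK, stages, pad_a, pad_b, elem_bytes=4):
    """SMEM footprint of the staged A and B tiles, in bytes (f32 = 4 bytes)."""
    size_a = bK * (bM + pad_a) * stages * elem_bytes
    size_b = bK * (bN + pad_b) * stages * elem_bytes
    return size_a + size_b

# (bM, bN, bK) = (128, 128, 8), 3 stages, col-major A/B -> no padding
assert smem_bytes(128, 128, 8, 3, 0, 0) == 24576  # 24 KiB total
```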

We furthermore initialise the layouts tA and tB which will be used as the thread layouts. Both of them have a size of num_threads and therefore define a map with domain [0, num_threads). Both are major in the first dimension. We divide bM by 4 to account for vectorisation in the case of 16-byte alignment. In what follows we assume alignment.

The vectorisation is also reflected in the value layouts. By default a layout is major in the first dimension if no stride is provided, i.e. we have (4, 1) : (1, 0) as a layout for the values.

The copy atoms are needed for the copy operations from GMEM to SMEM and are self-explanatory in their arguments.

        if self.a_major_mode == utils.LayoutEnum.COL_MAJOR:
            num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1
            atom_async_copy_A = cute.make_copy_atom(
                cute.nvgpu.cpasync.CopyG2SOp(),
                mA.element_type,
                num_bits_per_copy=mA.element_type.width * num_vectorized,
            )
            major_mode_size = self._bM // num_vectorized
            tA = cute.make_layout(
                (major_mode_size, self._num_threads // major_mode_size),
                stride=(1, major_mode_size),
            )
            vA = cute.make_layout((num_vectorized, 1))

        if self.b_major_mode == utils.LayoutEnum.COL_MAJOR:
            num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1
            atom_async_copy_B = cute.make_copy_atom(
                cute.nvgpu.cpasync.CopyG2SOp(),
                mB.element_type,
                num_bits_per_copy=mB.element_type.width * num_vectorized,
            )
            major_mode_size = self._bN // num_vectorized
            tB = cute.make_layout(
                (major_mode_size, self._num_threads // major_mode_size),
                stride=(1, major_mode_size),
            )
            vB = cute.make_layout((num_vectorized, 1))
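A quick sanity check on the numbers for our case (plain Python arithmetic): with bM = 128 and 4-wide vectorised loads, each of the 256 threads owns one 4-element segment along the major mode, and tA covers all threads.

```python
num_threads = 256
bM, bK, num_vectorized = 128, 8, 4

major_mode_size = bM // num_vectorized                         # 32
tA_shape = (major_mode_size, num_threads // major_mode_size)   # thread layout shape
assert tA_shape == (32, 8)

def thread_coord(tid):
    """Thread id -> (m-group, k) coordinate under the M-major layout (32,8):(1,32)."""
    return (tid % major_mode_size, tid // major_mode_size)

assert thread_coord(35) == (3, 1)
# All threads together, times the vector width, cover one full (bM, bK) tile
assert num_threads * num_vectorized == bM * bK
```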

make_tiled_copy_tv is a convenience wrapper.

        tiled_copy_A = cute.make_tiled_copy_tv(atom_async_copy_A, tA, vA)
        tiled_copy_B = cute.make_tiled_copy_tv(atom_async_copy_B, tB, vB)

We can get a better understanding of it by printing out its content. We do that for tiled_copy_A. We see that it holds the Tiler, the TV layout and the copy atom. We will later use it for slicing our tensors for each thread and for copying.

Tiled Copy
  Tiler MN:        (128:1,8:1)
  TV Layout tiled: (256,4):(4,1)
Copy Atom
  ThrID:           1:0
  TV Layout Src:   (1,4):(0,1)
  TV Layout Dst:   (1,4):(0,1)
  Value type:      f32

As we can read in the PTX docs, the MMA operation expects a specific thread layout. We construct it here.

        atoms_layout = cute.make_layout(
            (self._num_threads // 16, 16, 1), stride=(16, 1, 0)
        )
        if self.c_major_mode == utils.LayoutEnum.COL_MAJOR:
            atoms_layout = cute.make_layout(
                (16, self._num_threads // 16, 1), stride=(1, 16, 0)
            )
        op = cute.nvgpu.MmaUniversalOp(cutlass.Float32)
        permutation_tiler_M = cute.make_layout(
            (atoms_layout.shape[0], 4), stride=(4, 1)
        )
        permutation_tiler_N = cute.make_layout(
            (atoms_layout.shape[1], 4), stride=(4, 1)
        )
        tiled_mma = cute.make_tiled_mma(
            op,
            atoms_layout,
            permutation_mnk=(permutation_tiler_M, permutation_tiler_N, None),
        )

We are then ready to launch our kernel. Note that from above we can immediately see that the size of atoms_layout equals num_threads. We furthermore tile the C matrix using bM and bN along the corresponding dimensions to calculate the number of blocks in each dimension. Note that we use ceiling division, i.e. if bM doesn't divide M evenly we round up.

        # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, 1)
        grid_dim = *cute.ceil_div(mC.shape, (self._bM, self._bN)), 1

        self.kernel(
            mA,
            mB,
            mC,
            sA_layout,
            sB_layout,
            tiled_copy_A,
            tiled_copy_B,
            tiled_mma,
            epilogue_op,
        ).launch(
            grid=grid_dim,
            block=[cute.size(atoms_layout), 1, 1],
            smem=smem_size,
        )
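For our problem sizes the grid computation above gives (a plain-Python sketch of cute.ceil_div):

```python
def ceil_div(a, b):
    """Round-up integer division, as used for the grid dimensions."""
    return (a + b - 1) // b

M, N = 256, 128
bM, bN = 128, 128
grid_dim = (ceil_div(M, bM), ceil_div(N, bN), 1)
assert grid_dim == (2, 1, 1)  # two CTAs along M, one along N
```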

kernel method

Here is where our main logic resides.

We annotate the kernel with cute.kernel. The arguments were explained above.

    @cute.kernel
    def kernel(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mC: cute.Tensor,
        sA_layout: cute.Layout,
        sB_layout: cute.Layout,
        tiled_copy_A: cute.TiledCopy,
        tiled_copy_B: cute.TiledCopy,
        tiled_mma: cute.TiledMma,
        epilogue_op: cutlass.Constexpr = lambda x: x,
    ):

This is usual CUDA stuff. thr_mma is the slice of the tiled MMA belonging to the current thread.

        tidx, tidy, tidz = cute.arch.thread_idx()
        bidx, bidy, bidz = cute.arch.block_idx()
        tiler_coord = (bidx, bidy, None)
        thr_mma = tiled_mma.get_slice(tidx)

Note the g in front of the variables below. This indicates we are dealing with tensors residing in GMEM. Each pair of (bidx, bidy) will process a tile.

        gA = cute.local_tile(
            mA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1)
        )
        gB = cute.local_tile(
            mB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1)
        )
        gC = cute.local_tile(
            mC, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, 1, None)
        )

We can print the layouts out:

gA Layout: (128,8,8):(1,256,2048)
gB Layout: (128,8,8):(1,128,1024)
gC Layout: (128,128):(1,256)

We see that the shapes are (bM, bK, k), (bN, bK, k) and (bM, bN). Here k = K / bK = 64 / 8 = 8 is the number of k-tiles.

This code offsets the tensors. This is needed for the case that bK doesn't evenly divide K. Note that when bK divides K the residue is simply 0.

        residue_k = mA.shape[1] - cutlass.Int32(self._bK) * gA.shape[2]
        gA = cute.domain_offset((0, residue_k, 0), gA)
        gB = cute.domain_offset((0, residue_k, 0), gB)
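For our shapes the residue is zero; a quick check in plain Python, plus a hypothetical uneven K to show when the offset actually shifts the tiles:

```python
def residue_k(K, bK):
    """K minus bK times the number of k-tiles (ceiling division), as in the kernel."""
    k_tiles = (K + bK - 1) // bK
    return K - bK * k_tiles

assert residue_k(64, 8) == 0    # our case: bK divides K evenly, no offset
assert residue_k(60, 8) == -4   # uneven case: the last tile starts 4 elements early
```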

We allocate the shared memory and get the slice for the current thread. We then partition the source (GMEM) and destination (SMEM) tensors per thread:

        smem = cutlass.utils.SmemAllocator()
        sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
        sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
        thr_copy_A = tiled_copy_A.get_slice(tidx)
        thr_copy_B = tiled_copy_B.get_slice(tidx)
        tAgA = thr_copy_A.partition_S(gA)
        tAsA = thr_copy_A.partition_D(sA)
        tBgB = thr_copy_B.partition_S(gB)
        tBsB = thr_copy_B.partition_D(sB)

We can print out the shapes to understand better:

tAgA Shape: ((4,1),1,1,8)
tAsA Shape: ((4,1),1,1,3)
tBgB Shape: ((4,1),1,1,8)
tBsB Shape: ((4,1),1,1,3)

We see here that we have (4, 1) in the first mode because of the vectorisation. The last mode of tAgA and tBgB corresponds to the number of k-tiles, the last mode of tAsA and tBsB to num_stages.
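We can verify that these per-thread partitions account for the whole tile (plain Python arithmetic):

```python
num_threads = 256
bM, bK, K = 128, 8, 64

# tAgA shape ((4,1),1,1,8): each thread moves one vectorised 4-element chunk
values_per_thread = 4
# All 256 threads together cover one (bM, bK) tile of A per copy
assert values_per_thread * num_threads == bM * bK
# The last mode of tAgA counts the K / bK = 8 k-tiles resident in GMEM,
# while the last mode of tAsA counts the 3 pipeline stages resident in SMEM
assert K // bK == 8
```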

Note the CuTe notation: We have a g to indicate GMEM and s to indicate SMEM. We have _S to indicate Source and _D to indicate Destination.

Predication is needed to handle the case where the tiles don't evenly divide the corresponding dimension.

We can handle predication with the use of identity_tensor. An identity tensor in CuTe simply maps each coordinate (x,y) to itself. We then replicate the tiling and GMEM partitioning from above.

        mcA = cute.make_identity_tensor(mA.shape)
        mcB = cute.make_identity_tensor(mB.shape)
        cA = cute.local_tile(
            mcA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1)
        )
        cB = cute.local_tile(
            mcB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1)
        )
        cA = cute.domain_offset((0, residue_k, 0), cA)
        cB = cute.domain_offset((0, residue_k, 0), cB)
        # Repeat the partitioning with identity layouts
        tAcA = thr_copy_A.partition_S(cA)
        tBcB = thr_copy_B.partition_S(cB)

The below tensors will be the storage for our predication result. They will contain a 1 if the corresponding element should be copied and a 0 otherwise. I will not go too deeply into predication to keep the blog concise but maybe this can be handled in another blogpost.

        # Allocate predicate tensors for m and n
        tApA = cute.make_fragment(
            cute.make_layout(
                (
                    tAsA.shape[0][1],
                    cute.size(tAsA, mode=[1]),
                    cute.size(tAsA, mode=[2]),
                ),
                stride=(cute.size(tAsA, mode=[1]), 1, 0),
            ),
            cutlass.Boolean,
        )
        tBpB = cute.make_fragment(
            cute.make_layout(
                (
                    tBsB.shape[0][1],
                    cute.size(tBsB, mode=[1]),
                    cute.size(tBsB, mode=[2]),
                ),
                stride=(cute.size(tBsB, mode=[1]), 1, 0),
            ),
            cutlass.Boolean,
        )
        # Allocate predicate tensors for m, n and k for residue k-tile
        tApA_residue_k = cute.make_fragment(
            cute.make_layout(
                (
                    tAsA.shape[0][1],
                    cute.size(tAsA, mode=[1]),
                    cute.size(tAsA, mode=[2]),
                ),
                stride=(
                    cute.size(tAsA, mode=[1]) * cute.size(tAsA, mode=[2]),
                    cute.size(tAsA, mode=[2]),
                    1,
                ),
            ),
            cutlass.Boolean,
        )
        tBpB_residue_k = cute.make_fragment(
            cute.make_layout(
                (
                    tBsB.shape[0][1],
                    cute.size(tBsB, mode=[1]),
                    cute.size(tBsB, mode=[2]),
                ),
                stride=(
                    cute.size(tBsB, mode=[1]) * cute.size(tBsB, mode=[2]),
                    cute.size(tBsB, mode=[2]),
                    1,
                ),
            ),
            cutlass.Boolean,
        )

We perform predication here. From CUDA we know that not every thread should participate in a copy if the block size does not evenly divide the problem dimension, and that is what is reflected here.

        # Set predicates for m/n bounds for mainloop
        for rest_v in range(tApA.shape[0]):
            for m in range(tApA.shape[1]):
                tApA[rest_v, m, 0] = cute.elem_less(
                    tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
                )
        for rest_v in range(tBpB.shape[0]):
            for n in range(tBpB.shape[1]):
                tBpB[rest_v, n, 0] = cute.elem_less(
                    tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
                )

        # Set predicates for m/n/k bounds for residue k tile
        for rest_v in range(tApA_residue_k.shape[0]):
            for m in range(tApA_residue_k.shape[1]):
                for k in range(tApA_residue_k.shape[2]):
                    coord_A = tAcA[(0, rest_v), m, k, 0]
                    tApA_residue_k[rest_v, m, k] = cute.elem_less(
                        (coord_A[0], cutlass.Int32(-1)), (mA.shape[0], coord_A[1])
                    )
        for rest_v in range(tBpB_residue_k.shape[0]):
            for n in range(tBpB_residue_k.shape[1]):
                for k in range(tBpB_residue_k.shape[2]):
                    coord_B = tBcB[(0, rest_v), n, k, 0]
                    tBpB_residue_k[rest_v, n, k] = cute.elem_less(
                        (coord_B[0], cutlass.Int32(-1)), (mB.shape[0], coord_B[1])
                    )

We issue the first asynchronous copy operations and commit them. Note that this is done for the zeroth stage because the last mode of tAsA and tBsB corresponds to the stages. Afterwards we increase gmem_pipe_read, which corresponds to the last mode of tAgA and tBgB, which in turn corresponds to the number of tiles the K dimension is divided into by bK. Note that we use the predicates defined above to copy only the relevant elements where necessary.

        k_pipe_max = cute.size(tAsA, mode=[3])
        k_tile_count = cute.size(tAgA, mode=[3])
        gmem_pipe_read = cutlass.Int32(0)
        cute.copy(
            tiled_copy_A,
            tAgA[None, None, None, gmem_pipe_read],
            tAsA[None, None, None, 0],
            pred=tApA_residue_k,
        )
        cute.copy(
            tiled_copy_B,
            tBgB[None, None, None, gmem_pipe_read],
            tBsB[None, None, None, 0],
            pred=tBpB_residue_k,
        )
        cute.arch.cp_async_commit_group()
        gmem_pipe_read = (
            gmem_pipe_read + 1
            if gmem_pipe_read + 1 < k_tile_count
            else cutlass.Int32(0)
        )
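The gmem_pipe_read update is a wrap-around counter over the k-tiles; a minimal sketch of its behaviour (here it wraps to 0, while the main loop later deliberately wraps to 1 to skip the irregular 0-th tile):

```python
def advance(pipe, count):
    """Wrap-around increment used for gmem_pipe_read (wrapping to 0 here)."""
    return pipe + 1 if pipe + 1 < count else 0

k_tile_count = 8
pipe = 0
history = []
for _ in range(10):
    history.append(pipe)
    pipe = advance(pipe, k_tile_count)
assert history == [0, 1, 2, 3, 4, 5, 6, 7, 0, 1]
```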

In what follows we copy the remaining tiles from GMEM to SMEM.

        # Start async loads for 1st k-tile onwards, no k-residue handling needed
        for k_tile in range(1, k_pipe_max - 1):
            if k_tile < k_tile_count:
                cute.copy(
                    tiled_copy_A,
                    tAgA[None, None, None, gmem_pipe_read],
                    tAsA[None, None, None, k_tile],
                    pred=tApA,
                )
                cute.copy(
                    tiled_copy_B,
                    tBgB[None, None, None, gmem_pipe_read],
                    tBsB[None, None, None, k_tile],
                    pred=tBpB,
                )

            gmem_pipe_read = (
                gmem_pipe_read + 1
                if gmem_pipe_read + 1 < k_tile_count
                else cutlass.Int32(0)
            )
            cute.arch.cp_async_commit_group()

        # all tiles have been copied from global memory, so clear the
        # predicate tensor
        if k_tile_count < k_pipe_max:
            for rest_v in range(tApA.shape[0]):
                for m in range(tApA.shape[1]):
                    tApA[rest_v, m, 0] = cutlass.Boolean(0)
            for rest_v in range(tBpB.shape[0]):
                for n in range(tBpB.shape[1]):
                    tBpB[rest_v, n, 0] = cutlass.Boolean(0)

We define the partitioning for A, B and C needed for the MMA operation. We allocate the fragments for the first stage because the MMA operation expects its arguments to live in RMEM.

        # ///////////////////////////////////////////////////////////////////////////////
        # Define A/B partitioning and C accumulators.
        # ///////////////////////////////////////////////////////////////////////////////
        tCsA = thr_mma.partition_A(sA)
        tCsB = thr_mma.partition_B(sB)
        tCgC = thr_mma.partition_C(gC)
        tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
        tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
        tCrC = tiled_mma.make_fragment_C(tCgC)
        # Clear the accumulator
        tCrC.fill(0.0)

        # Current pipe index in smem to read from / write to
        smem_pipe_read = cutlass.Int32(0)
        smem_pipe_write = cutlass.Int32(k_pipe_max - 1)

        tCsA_p = tCsA[None, None, None, smem_pipe_read]
        tCsB_p = tCsB[None, None, None, smem_pipe_read]

cp_async_wait_group(k_pipe_max - 2) waits until at most k_pipe_max - 2 committed copy groups are still in flight. Once this point is reached we continue.

        k_block_max = cute.size(tCrA, mode=[2])

        if k_block_max > 1:
            # Wait until our first prefetched tile is loaded in
            cute.arch.cp_async_wait_group(k_pipe_max - 2)
            cute.arch.barrier()
            # Prefetch the first rmem from the first k-tile
            cute.autovec_copy(tCsA_p[None, None, 0], tCrA[None, None, 0])
            cute.autovec_copy(tCsB_p[None, None, 0], tCrB[None, None, 0])

Note that by the docs:

@dsl_user_op
def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None:
    """
    Auto-vectorizing SIMT copy policy.

    Given a source and destination tensors that are statically shaped, this policy figures out the
    largest safe vector width that the copy instruction can take and performs the copy.
    """

Now comes our main loop.

        for _ in cutlass.range_dynamic(k_tile_count, unroll=1):
            for k_block in range(k_block_max):
                if k_block == k_block_max - 1:
                    tCsA_p = tCsA[None, None, None, smem_pipe_read]
                    tCsB_p = tCsB[None, None, None, smem_pipe_read]
                    cute.arch.cp_async_wait_group(k_pipe_max - 2)
                    cute.arch.barrier()

                # Load A, B from shared memory to registers for k_block + 1
                k_block_next = (k_block + 1) % k_block_max  # static
                cute.autovec_copy(
                    tCsA_p[None, None, k_block_next],
                    tCrA[None, None, k_block_next],
                )
                cute.autovec_copy(
                    tCsB_p[None, None, k_block_next],
                    tCrB[None, None, k_block_next],
                )

                # Fetch next A: To better interleave global memory access and
                # compute instructions, we intentionally use the sequence:
                # copy A, perform GEMM, then copy B.
                if k_block == 0:
                    cute.copy(
                        tiled_copy_A,
                        tAgA[None, None, None, gmem_pipe_read],
                        tAsA[None, None, None, smem_pipe_write],
                        # Use predicates because the m-mode may be irregular
                        pred=tApA,
                    )

                # Thread-level register gemm for k_block
                cute.gemm(
                    tiled_mma,
                    tCrC,
                    tCrA[None, None, k_block],
                    tCrB[None, None, k_block],
                    tCrC,
                )

                # Fetch next B and update smem pipeline read/write
                if k_block == 0:
                    cute.copy(
                        tiled_copy_B,
                        tBgB[None, None, None, gmem_pipe_read],
                        tBsB[None, None, None, smem_pipe_write],
                        # Use predicates because the n-mode may be irregular
                        pred=tBpB,
                    )
                    cute.arch.cp_async_commit_group()
                    smem_pipe_write = smem_pipe_read
                    smem_pipe_read = smem_pipe_read + 1
                    if smem_pipe_read == k_pipe_max:
                        smem_pipe_read = cutlass.Int32(0)
                    # After copying all tiles, we avoid clearing the predicate
                    # tensor in the `mainloop` to prevent increasing its
                    # instruction count. Instead, we continue copying the
                    # first tile, though it won't be used. The 0-th tile is not
                    # copied due to its irregular shape, which could lead to
                    # illegal memory accesses.
                    gmem_pipe_read = (
                        gmem_pipe_read + 1
                        if gmem_pipe_read + 1 < k_tile_count
                        else cutlass.Int32(1)
                    )

Note that we process all tiles in stages. Let's walk through the loop body again, piece by piece.

If necessary (i.e. at the end of the current stage) we advance to the next prefetched tile and wait for it. We then load the next k-block of A and B from SMEM into registers.

                if k_block == k_block_max - 1:
                    tCsA_p = tCsA[None, None, None, smem_pipe_read]
                    tCsB_p = tCsB[None, None, None, smem_pipe_read]
                    cute.arch.cp_async_wait_group(k_pipe_max - 2)
                    cute.arch.barrier()

                # Load A, B from shared memory to registers for k_block + 1
                k_block_next = (k_block + 1) % k_block_max  # static
                cute.autovec_copy(
                    tCsA_p[None, None, k_block_next],
                    tCrA[None, None, k_block_next],
                )
                cute.autovec_copy(
                    tCsB_p[None, None, k_block_next],
                    tCrB[None, None, k_block_next],
                )

Here we perform a copy from GMEM to SMEM under predication. Remember that the last mode of tAgA corresponds to the current Tile and the last mode of tAsA corresponds to the current stage.

                if k_block == 0:
                    cute.copy(
                        tiled_copy_A,
                        tAgA[None, None, None, gmem_pipe_read],
                        tAsA[None, None, None, smem_pipe_write],
                        # Use predicates because the m-mode may be irregular
                        pred=tApA,
                    )

The GEMM is computed blockwise and accumulated into the dedicated register fragment tCrC.

                # Thread-level register gemm for k_block
                cute.gemm(
                    tiled_mma,
                    tCrC,
                    tCrA[None, None, k_block],
                    tCrB[None, None, k_block],
                    tCrC,
                )

We fetch the next B tile. We then update the smem_pipe_read and gmem_pipe_read parameters.

                # Fetch next B and update smem pipeline read/write
                if k_block == 0:
                    cute.copy(
                        tiled_copy_B,
                        tBgB[None, None, None, gmem_pipe_read],
                        tBsB[None, None, None, smem_pipe_write],
                        # Use predicates because the n-mode may be irregular
                        pred=tBpB,
                    )
                    cute.arch.cp_async_commit_group()
                    smem_pipe_write = smem_pipe_read
                    smem_pipe_read = smem_pipe_read + 1
                    if smem_pipe_read == k_pipe_max:
                        smem_pipe_read = cutlass.Int32(0)
                    # After copying all tiles, we avoid clearing the predicate
                    # tensor in the `mainloop` to prevent increasing its
                    # instruction count. Instead, we continue copying the
                    # first tile, though it won't be used. The 0-th tile is not
                    # copied due to its irregular shape, which could lead to
                    # illegal memory accesses.
                    gmem_pipe_read = (
                        gmem_pipe_read + 1
                        if gmem_pipe_read + 1 < k_tile_count
                        else cutlass.Int32(1)
                    )

We make sure no copy operations are pending and then apply the epilogue to tCrC. Finally we copy the result from RMEM to GMEM under predication.

        cute.arch.cp_async_wait_group(0)
        cute.arch.barrier()
        tCrC.store(epilogue_op(tCrC.load()))

        # predicate
        cC = cute.make_identity_tensor(gC.shape)
        tCpC = thr_mma.partition_C(cC)
        predC = cute.make_fragment(tCrC.layout, cutlass.Boolean)
        residue_m = mC.shape[0] - cutlass.Int32(self._bM) * bidx
        residue_n = mC.shape[1] - cutlass.Int32(self._bN) * bidy
        for i in range(cute.size(tCrC.shape)):
            predC[i] = cute.elem_less(tCpC[i], (residue_m, residue_n))
        numIterM = cute.size(tCrC, mode=[1])
        numIterN = cute.size(tCrC, mode=[2])
        atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mC.element_type)
        cute.copy(atom, tCrC, tCgC, pred=predC)
        return

Conclusion

CuTeDSL is very performant but not easy to understand. I hope this blogpost shed some light on the SGemm example and equips us to understand and design more complex kernels in the future.