CuTeDSL on Hopper - WGMMA and TMA intro

To write performant kernels for Hopper in CuTeDSL we need to understand the concepts of WGMMA and TMA.

This is the first part of a multi-part blog series on CuTeDSL for Hopper that analyses the dense_gemm.py example from the CuTeDSL examples. Here we don't analyse the kernel itself. Instead we focus on the parts of the kernel setup that are specific to Hopper, i.e. the setup of the WGMMA and TMA atoms that are essential for understanding the operations used in the kernel.

The second part will explain the logic of the kernel. Note that for simplicity I only consider the case without clustering for now, as that is another concept that deserves its own explanation.

The example we refer to here can be found in the CUTLASS repo.

__call__

In this jitted method we prepare the setup for the kernel. Below we describe the most important steps.

MMA atom

Here we initialise the MMA atom which will be used to prepare our inputs for the WGMMA instruction.

        tiled_mma = sm90_utils.make_trivial_tiled_mma(
            self.a_dtype,
            self.b_dtype,
            self.a_layout.sm90_mma_major_mode(),
            self.b_layout.sm90_mma_major_mode(),
            self.acc_dtype,
            self.atom_layout_mnk,
            tiler_mn=(64, self.tile_shape_mnk[1]),
        )

We can print out the object afterwards.

{} Tiled MMA
  Thr Layout VMNK: (128,1,1,1):(1,0,0,0)
  Permutation MNK: (_,_,_)
MMA Atom
  ThrID:           128:1
  Shape MNK:       (64,128,16)
  TV Layout A:     (128,(64,16)):(0,(1,64))
  TV Layout B:     (128,(128,16)):(0,(1,128))
  TV Layout C:     ((4,8,4),(2,2,16)):((128,1,16),(64,8,512))

We see that we have 128 = 4 * 32 threads. That is because we use the WGMMA instruction. WGMMA stands for warpgroup-level matrix multiply and accumulate. A warpgroup consists of 4 warps of 32 threads each.

The A and B layouts are of similar form. Both have a thread layout of 128:0, which simply means the threads are broadcast: they all refer to the same location in memory. All 128 threads then have access to the whole matrix tile.
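
As a quick sanity check (plain Python, no CuTeDSL involved), we can evaluate the TV layout of A by hand. The thread mode 128:0 contributes nothing to the offset, and the value mode (64,16):(1,64) enumerates the whole 64x16 A tile:

    # Evaluate the TV layout of A, (128,(64,16)):(0,(1,64)), by hand.
    # The thread mode 128:0 contributes nothing to the offset, so every
    # thread "sees" the full 64x16 tile of A.
    def tv_offset_a(thread, m, k):
        return thread * 0 + m * 1 + k * 64

    # All 128 threads map value (m, k) to the same element of the tile:
    assert tv_offset_a(0, 5, 3) == tv_offset_a(127, 5, 3) == 5 + 3 * 64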

C has a more interesting thread layout because C is held in registers. It can be understood by looking at the following picture:

[Figure: layout of the WGMMA accumulator C across the 128 threads, with values d0, d1, ... per thread]

Note that we consider the elements to be laid out with a stride of 1 in the first mode.

For the thread layout only the position of the first entry of each thread (i.e. d0) is important. The rest can be inferred from the value layout.

We can group the Threads as follows:

 T0 -  T3
   ...
T28 - T31
----------
T32 - T35
   ...
T60 - T63
----------
   ...
----------
T96 - T99
   ...
T124 - T127
----------

Note that we only describe the d0 entry of each thread, as explained above.

From here it follows that the Thread Layout is (4, 8, 4) : (128, 1, 16).

For the value layout we need to describe all the values d0, d1, d2, ... held by a single thread.

We see from above that for N = 128 we have 128 / 8 = 16 groups that all follow the same pattern. Therefore we can describe the first group and obtain the remaining mode by calculating the stride needed to go from one group to the next.

The shape of one group is (2, 2). To go from d0 to d1 we need to make 64 steps (one for each of the 64 rows). To go from d0 to d2 we need to make 8 steps, so the layout of one group is (2, 2) : (64, 8). As described above there are 16 groups, and the stride between groups is 8 * 64 = 512 because we need to traverse 8 columns of 64 elements each to get from d0 to d4.

Therefore the Value Layout is (2, 2, 16) : (64, 8, 512).
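
We can double check this with a few lines of plain Python (again, no CuTeDSL involved): evaluate the full TV layout of C, ((4,8,4),(2,2,16)):((128,1,16),(64,8,512)), and map the resulting offsets back to (m, n) coordinates of the 64x128 accumulator tile (stride 1 along M, 64 along N):

    # Evaluate the TV layout of C and map offsets back to (m, n) coordinates
    # of the 64x128 accumulator tile (stride 1 along M, stride 64 along N).
    def c_offset(thread, value):
        # Decompose the flat thread index into the (4, 8, 4) thread mode ...
        t0, t1, t2 = thread % 4, (thread // 4) % 8, thread // 32
        # ... and the flat value index into the (2, 2, 16) value mode.
        v0, v1, v2 = value % 2, (value // 2) % 2, value // 4
        return t0 * 128 + t1 * 1 + t2 * 16 + v0 * 64 + v1 * 8 + v2 * 512

    def to_mn(offset):
        return offset % 64, offset // 64

    # Thread 0 holds d0..d3 at (0,0), (0,1), (8,0), (8,1), then d4 at (0,8).
    print([to_mn(c_offset(0, v)) for v in range(5)])
    # Thread 1 sits two columns to the right, thread 4 one row below,
    # thread 32 (the next warp) sixteen rows below.
    print(to_mn(c_offset(1, 0)), to_mn(c_offset(4, 0)), to_mn(c_offset(32, 0)))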

Thanks to my friend Zining Zhang for the helpful discussion on WGMMA Layouts.

TMA atom

TMA (Tensor Memory Accelerator) is a hardware unit on Hopper GPUs that performs fast asynchronous memory transfers between GMEM and SMEM.

We can create the TMA atoms as follows:

        tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
            a,
            self.a_smem_layout_staged,
            (self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
            self.cluster_shape_mnk[1],
        )

        tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
            b,
            self.b_smem_layout_staged,
            (self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
            self.cluster_shape_mnk[0],
        )

        tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors(
            c,
            self.epi_smem_layout_staged,
            self.epi_tile,
        )

Note that a_smem_layout_staged and b_smem_layout_staged were created earlier. They look as follows:

A SMEM Layout Staged:
{} S<3,4,3> o 0 o ((8,16),(64,1),(1,7)):((64,512),(1,0),(0,8192))
B SMEM Layout Staged:
{} S<3,4,3> o 0 o ((8,16),(64,1),(1,7)):((64,512),(1,0),(0,8192))

S<3,4,3> is a swizzle layout. Swizzling is a technique to avoid shared memory bank conflicts; I wrote a blogpost on this topic which you may find interesting. From the PTX docs we can see that S<3,4,3> corresponds to the canonical 128B swizzling.

S<3,4,3> o 0 o L means the following: apply the inner transformation S<3,4,3> to the offsets produced by the outer layout L, with an additional offset of 0, as we can read in make_composed_layout inside cutlass.cute.core.
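
To give a rough idea of what the swizzle does, here is a small sketch of S<3,4,3> following the CuTe C++ definition of Swizzle<BBits, MBase, SShift> (which I assume CuTeDSL mirrors): it XORs a few higher bits of the offset produced by the layout into a few lower bits. In which unit that offset is counted depends on how the composed layout was constructed, so treat this purely as an illustration:

    # Sketch of S<3,4,3> following the CuTe C++ Swizzle<BBits, MBase, SShift>
    # definition (assumed to carry over to CuTeDSL): XOR bits [7..9] of the
    # offset into bits [4..6].
    def swizzle_3_4_3(offset):
        bbits, mbase, sshift = 3, 4, 3
        mask = ((1 << bbits) - 1) << (mbase + sshift)   # selects bits 7..9
        return offset ^ ((offset & mask) >> sshift)     # flips bits 4..6

    # Composition with the outer layout (8,64):(64,1): first compute the plain
    # offset, add the (here zero) offset, then apply the swizzle to the result.
    def composed(row, col):
        return swizzle_3_4_3(row * 64 + col * 1 + 0)

    print([composed(0, c) for c in range(0, 64, 16)])  # [0, 16, 32, 48] -> untouched
    print([composed(2, c) for c in range(0, 64, 16)])  # [144, 128, 176, 192] -> permuted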

The Staged Layouts are calculated as follows:

In a first step we obtain a Layout Atom:

        a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
            sm90_utils.get_smem_layout_atom(
                a_layout,
                a_dtype,
                a_major_mode_size,
            ),
            a_dtype,
        )
A SMEM Layout Atom:
{} S<3,4,3> o 0 o (8,64):(64,1)

The details can be found in cute.nvgpu.warpgroup.make_smem_layout_atom. For K-major it does something like the following shortened code:

    elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
        num_contiguous_bits = 1024
        sw = core.make_swizzle(3, 4, 3)
        ...
        num_contiguous_elems = num_contiguous_bits // element_type.width
        ...
        return core.make_composed_layout(
            sw,
            0,
            core.make_layout(
                (8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
            ),
            loc=loc,
            ip=ip,
        )
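
For a 16-bit dtype this arithmetic yields exactly the (8, 64) atom printed above; as a quick illustration:

    # The arithmetic behind the printed atom: a 128B swizzle atom covers 1024
    # contiguous bits, i.e. 1024 // width elements in its contiguous mode.
    def atom_shape(element_width_bits):
        num_contiguous_elems = 1024 // element_width_bits
        return (8, num_contiguous_elems)

    print(atom_shape(16))  # (8, 64) for fp16/bf16, matching the printout above
    print(atom_shape(8))   # (8, 128) for fp8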

Again I refer to the PTX docs for more details.

The staged layout is then calculated via tile_to_shape, which repeats the SMEM layout atom to tile the whole tensor shape. Here the tensor shape is the tile shape with the number of stages appended.

        a_smem_layout_staged = cute.tile_to_shape(
            a_smem_layout_atom,
            cute.append(a_smem_shape, ab_stage),
            order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
        )
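
We can recover the staged layout printed above by hand. Reading off that printout, the A tile here is (128, 64) with 7 stages; tile_to_shape then repeats the (8, 64) atom until this shape is covered, laying the copies out in order (0, 1, 2):

    # Plain-Python sketch of how tile_to_shape arrives at the staged A layout.
    atom_shape  = (8, 64)        # from S<3,4,3> o 0 o (8,64):(64,1)
    atom_cosize = 8 * 64         # 512 elements covered by one atom
    target      = (128, 64, 7)   # (tile_m, tile_k, stages), read off the printout

    repeats = (target[0] // atom_shape[0],  # 16 atoms along M
               target[1] // atom_shape[1],  # 1 atom along K
               target[2])                   # 7 stages

    # With order=(0, 1, 2) the atom copies are laid out M-first, then K, then
    # stage, each repeat mode strided by the number of elements already covered:
    stride_m_repeat     = atom_cosize               # 512
    stride_stage_repeat = atom_cosize * repeats[0]  # 8192 = one full (128, 64) tile

    print(repeats, stride_m_repeat, stride_stage_repeat)
    # (16, 1, 7) 512 8192 -> ((8,16),(64,1),(1,7)):((64,512),(1,0),(0,8192))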

The TMA atoms are then generated from these staged SMEM layouts via _make_tma_atoms_and_tensors, as shown at the beginning of this section.

Note that for A and B we will copy from GMEM -> SMEM and for C we will copy from SMEM -> GMEM.

We can look at _make_tma_atoms_and_tensors to get a better understanding:

    @staticmethod
    def _make_tma_atoms_and_tensors(
        tensor: cute.Tensor,
        smem_layout_staged: cute.ComposedLayout,
        smem_tile: tuple[int, int],
        mcast_dim: int,
    ) -> tuple[cute.CopyAtom, cute.Tensor]:
        """Create TMA atoms and tensors for input tensors.

        :param tensor: Input tensor (A or B)
        :type tensor: cute.Tensor
        :param smem_layout_staged: Shared memory layout for the tensor
        :type smem_layout_staged: cute.ComposedLayout
        :param smem_tile: Shared memory tile shape
        :type smem_tile: Tuple[int, int]
        :param mcast_dim: Multicast dimension
        :type mcast_dim: int

        :return: TMA atom and tensor
        :rtype: Tuple[cute.CopyAtom, cute.Tensor]
        """
        op = (
            cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
            if mcast_dim == 1
            else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()
        )

        smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
        tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tma_tile_atom(
            op,
            tensor,
            smem_layout,
            smem_tile,
            num_multicast=mcast_dim,
        )
        return tma_atom, tma_tensor

We see that op determines the "direction" of the copy. For A and B, as shown above, it is G2S, i.e. GMEM to SMEM. If mcast_dim is greater than 1 the multicast variant of the op is chosen; since we ignore clustering here, mcast_dim is 1 and the plain CopyBulkTensorTileG2SOp is used.

According to the docs, make_tma_tile_atom figures out the bulk tensor asynchronous copy instruction to use, with the maximum "TMA vector length", to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided layout, consistent with the provided tiler.

For C we will do a SMEM -> GMEM transfer via TMA during the epilogue.

The corresponding atom is calculated as follows:

    @staticmethod
    def _make_tma_store_atoms_and_tensors(
        tensor_c: cute.Tensor,
        epi_smem_layout_staged: cute.ComposedLayout,
        epi_tile: tuple[int, int],
    ) -> tuple[cute.CopyAtom, cute.Tensor]:
        """Create TMA atoms and tensors for C tensor storage.

        :param tensor_c: Output tensor C
        :type tensor_c: cute.Tensor
        :param epi_smem_layout_staged: Shared memory layout for epilogue
        :type epi_smem_layout_staged: cute.ComposedLayout
        :param epi_tile: Epilogue tile shape
        :type epi_tile: Tuple[int, int]

        :return: TMA atom and tensor for C
        :rtype: Tuple[cute.CopyAtom, cute.Tensor]
        """
        epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
        c_cta_v_layout = cute.composition(
            cute.make_identity_layout(tensor_c.shape), epi_tile
        )
        tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tma_tile_atom(
            cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
            tensor_c,
            epi_smem_layout,
            c_cta_v_layout,
        )

        return tma_atom_c, tma_tensor_c

The process is similar. Note that here we have cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp to copy SMEM -> GMEM.

SharedStorage

In CuTeDSL we can define a SharedStorage struct with cute.struct to describe and handle the shared memory the kernel needs.

        @cute.struct
        class SharedStorage:
            mainloop_pipeline_array_ptr: cute.struct.MemRange[
                cutlass.Int64, self.ab_stage * 2
            ]
            sa: cute.struct.Align[
                cute.struct.MemRange[
                    self.a_dtype, cute.cosize(self.a_smem_layout_staged)
                ],
                self.buffer_align_bytes,
            ]
            sb: cute.struct.Align[
                cute.struct.MemRange[
                    self.b_dtype, cute.cosize(self.b_smem_layout_staged)
                ],
                self.buffer_align_bytes,
            ]

Process in Batches

Note that the kernel processes a whole batch of GEMMs. The blocks in the z direction index into the batch dimension, so each plane of blocks performs an ordinary GEMM.
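
A minimal sketch of that grid logic (with a hypothetical helper, not the exact code from the example): the launch grid contains one (M-tiles x N-tiles) plane of CTAs per batch entry, and blockIdx.z selects the GEMM within the batch:

    # Hypothetical sketch of the batched launch grid: one plane of CTAs per
    # batch entry; blockIdx.z picks the GEMM, blockIdx.x/y pick its output tile.
    def make_grid(m, n, batch, tile_m, tile_n):
        ceil_div = lambda a, b: (a + b - 1) // b
        return (ceil_div(m, tile_m), ceil_div(n, tile_n), batch)

    print(make_grid(m=4096, n=4096, batch=8, tile_m=128, tile_n=128))  # (32, 32, 8)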

Conclusion

I hope this post made WGMMA and TMA and their usage in CuTeDSL more accessible. You may want to check out the corresponding blogposts I wrote in the past to get a different perspective.

I also recommend checking out the PTX docs mentioned above and the other CuTeDSL blogposts I wrote.