CuTeDSL on Hopper - WGMMA and TMA intro
To write performant kernels for Hopper in CuTeDSL we need the concepts of WGMMA and TMA.
This is the first part of a multi-part blog series on CuTeDSL for Hopper that analyses the dense_gemm.py example from the CuTeDSL examples. In this part we don't analyse the kernel itself. Instead we focus on the parts of the kernel setup that are specific to Hopper, i.e. the setup of the WGMMA and TMA atoms that are essential to understand the operations used in the kernel. The second part will explain the logic of the kernel. Note that for simplicity I will, for now, consider the case where we don't use clustering, as this is another concept that deserves a separate explanation. The example we refer to here can be found in the CUTLASS repo.
__call__
In this jitted method we prepare the setup for the Kernel. Below we will describe the most important steps.
MMA atom
Here we initialise the MMA atom which will be used to prepare our inputs for the WGMMA instruction.
tiled_mma = sm90_utils.make_trivial_tiled_mma(
self.a_dtype,
self.b_dtype,
self.a_layout.sm90_mma_major_mode(),
self.b_layout.sm90_mma_major_mode(),
self.acc_dtype,
self.atom_layout_mnk,
tiler_mn=(64, self.tile_shape_mnk[1]),
)
We can print out the object afterwards.
{} Tiled MMA
Thr Layout VMNK: (128,1,1,1):(1,0,0,0)
Permutation MNK: (_,_,_)
MMA Atom
ThrID: 128:1
Shape MNK: (64,128,16)
TV Layout A: (128,(64,16)):(0,(1,64))
TV Layout B: (128,(128,16)):(0,(1,128))
TV Layout C: ((4,8,4),(2,2,16)):((128,1,16),(64,8,512))
We see that we have 128 = 4 * 32 threads. That is because we use the WGMMA instruction. WGMMA stands for warpgroup-level matrix multiply and accumulate. A warp group consists of 4 warps of 32 threads each.
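As a small plain-Python illustration of this hierarchy (the thread and warp sizes are architectural facts; the variable names are just for illustration):
# Plain Python, not CuTeDSL: the warp-group thread hierarchy on Hopper.
THREADS_PER_WARP = 32
WARPS_PER_WARPGROUP = 4
THREADS_PER_WARPGROUP = THREADS_PER_WARP * WARPS_PER_WARPGROUP  # = 128

tid = 97                            # an arbitrary thread index within the warp group
warp_id = tid // THREADS_PER_WARP   # warp 3 of the warp group (0..3)
lane_id = tid % THREADS_PER_WARP    # lane 1 within that warp (0..31)
print(warp_id, lane_id)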
The A and B layouts are of similar form. Both have a thread layout of 128:0, which simply means the threads are broadcast: they all refer to the same location in memory. All 128 threads then have access to the whole matrix tile.
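A quick plain-Python sanity check of what a stride of 0 means (the helper below is my own, not a CuTeDSL call): every thread index is mapped to offset 0.
# Plain Python: evaluating the thread layout 128:0 at every thread index.
def broadcast_offset(tid, stride=0):
    return tid * stride

assert all(broadcast_offset(t) == 0 for t in range(128))  # all threads see the same element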
C has a more interesting thread layout because C is held in registers. It can be understood by looking at the following picture:
Note that we consider the elements to be laid out with a stride of 1 in the first mode. For the thread layout only the position of the first entry of each thread (i.e. d0) is important. The rest can be inferred via the value layout.
We can group the Threads as follows:
T0 - T3
...
T28 - T31
----------
T32 - T35
...
T60 - T63
----------
...
----------
T96 - T99
...
T124 - T127
----------
Note that we only want to describe the d0 entry for each thread, as explained above.
- To go from T0, d0 -> T1, d0 we need to traverse all the rows two times (i.e. move two columns to the right), so the corresponding stride is 2 * 64 = 128. To the right we take four such steps, so the corresponding mode is (4, 128).
- To go from T0, d0 -> T4, d0 we need to move down one row, so the corresponding stride is 1. Going down we take eight such steps, so the corresponding mode is (8, 1).
- To go from T0, d0 -> T32, d0 we need to move down sixteen rows, so the corresponding stride is 16. There are four such groups (one per colour in the picture), so the corresponding mode is (4, 16).
From here it follows that the thread layout is (4, 8, 4) : (128, 1, 16).
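We can sanity-check these strides with a few lines of plain Python (the helper is my own, not a CuTeDSL call). The thread index decomposes column-major over the shape (4, 8, 4), i.e. tid = i + 4 * j + 32 * k, and the offset follows from the strides (128, 1, 16):
# Plain Python: offset of thread coordinate (i, j, k) under (4, 8, 4) : (128, 1, 16).
def thread_offset(i, j, k):
    return i * 128 + j * 1 + k * 16

assert thread_offset(0, 0, 0) == 0     # T0,  d0
assert thread_offset(1, 0, 0) == 128   # T1,  d0: two columns (2 * 64) to the right
assert thread_offset(0, 1, 0) == 1     # T4,  d0: one row down
assert thread_offset(0, 0, 1) == 16    # T32, d0: sixteen rows down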
For the value layout we need to describe all the values d0, d1, ..., dZ, dW of a thread.
We see from above that for N = 128 we have 128 / 8 = 16 groups that always repeat the same pattern. Therefore we can describe the first group and obtain the remaining mode by simply calculating the stride to go from one group to the next.
The shape of one group is clearly (2, 2). To go from d0 to d1 we need to make 64 steps (one for each row, i.e. move one column to the right). To go from d0 to d2 we need to make 8 steps (move eight rows down), so the layout of one group is (2, 2) : (64, 8). As described above there are 16 groups, and the stride between groups is 8 * 64 = 512 because we need to traverse all the rows eight times (i.e. move eight columns to the right) to get from d0 to d4.
Therefore the value layout is (2, 2, 16) : (64, 8, 512).
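Again we can check this with plain Python (an illustrative helper of my own): evaluating (2, 2, 16) : (64, 8, 512) for a single thread and converting the linear offsets back to (row, column) coordinates of the column-major 64 x 128 tile.
# Plain Python: T0's values under (2, 2, 16) : (64, 8, 512); offset = row + 64 * col.
def value_offset(i, j, k):
    return i * 64 + j * 8 + k * 512

for k in range(2):          # only the first two of the 16 groups
    for j in range(2):
        for i in range(2):
            off = value_offset(i, j, k)
            row, col = off % 64, off // 64
            print(f"group {k}: value ({i},{j}) -> offset {off:4d} = (row {row}, col {col})")
# group 0: d0..d3 sit at (0,0), (0,1), (8,0), (8,1)
# group 1: d4..d7 sit at the same rows, shifted 8 columns to the right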
Thanks to my friend Zining Zhang for the helpful discussion on WGMMA layouts.
TMA atom
TMA (the Tensor Memory Accelerator) is a dedicated unit on Hopper GPUs that performs fast memory transfers between GMEM and SMEM.
We can create the TMA atoms as follows:
tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
a,
self.a_smem_layout_staged,
(self.tile_shape_mnk[0], self.tile_shape_mnk[2]),
self.cluster_shape_mnk[1],
)
tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
b,
self.b_smem_layout_staged,
(self.tile_shape_mnk[1], self.tile_shape_mnk[2]),
self.cluster_shape_mnk[0],
)
tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors(
c,
self.epi_smem_layout_staged,
self.epi_tile,
)
Note that a_smem_layout_staged and b_smem_layout_staged were created before. They look as follows:
A SMEM Layout Staged:
{} S<3,4,3> o 0 o ((8,16),(64,1),(1,7)):((64,512),(1,0),(0,8192))
B SMEM Layout Staged:
{} S<3,4,3> o 0 o ((8,16),(64,1),(1,7)):((64,512),(1,0),(0,8192))
S<3,4,3> is a swizzle layout. Swizzling is a technique to avoid bank conflicts; I wrote a blog post on this topic which you may find interesting. From the PTX docs we can see that S<3,4,3> is the canonical 128B swizzle.
S<3,4,3> o 0 o L means the following: apply the inner transformation S<3,4,3> to the outer layout L with an offset of 0, as we can read in make_composed_layout inside cutlass.cute.core.
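To make this more concrete, here is a plain-Python sketch of what the canonical 128B swizzle does to byte offsets (my own illustration of the S<3,4,3> bit pattern, not the library code): each 128-byte row is split into eight 16-byte chunks, and the chunk index is XOR'd with the row index modulo 8.
# Plain Python: the canonical 128B swizzle S<B=3, M=4, S=3> on byte offsets,
# i.e. XOR bits 7..9 of the offset into bits 4..6.
def swizzle_128B(byte_offset):
    B, M, S = 3, 4, 3
    mask = (1 << B) - 1                          # 3 bits -> 8 chunks / 8 rows
    row_phase = (byte_offset >> (M + S)) & mask
    return byte_offset ^ (row_phase << M)

# Chunk pattern of the first four 128-byte rows: each row is a different
# permutation of 0..7, so rows no longer map to the same banks.
for row in range(4):
    print([(swizzle_128B(row * 128 + chunk * 16) // 16) % 8 for chunk in range(8)])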
The Staged Layouts are calculated as follows:
In a first step we obtain a Layout Atom:
a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
sm90_utils.get_smem_layout_atom(
a_layout,
a_dtype,
a_major_mode_size,
),
a_dtype,
)
A SMEM Layout Atom:
{} S<3,4,3> o 0 o (8,64):(64,1)
The details can be found in cute.nvgpu.warpgroup.make_smem_layout_atom. For K-major we do something like the following (shortened) code:
elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
num_contiguous_bits = 1024
sw = core.make_swizzle(3, 4, 3)
...
num_contiguous_elems = num_contiguous_bits // element_type.width
...
return core.make_composed_layout(
sw,
0,
core.make_layout(
(8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
),
loc=loc,
ip=ip,
)
Again I refer to the PTX docs for more details.
The staged layout is then calculated via tile_to_shape, which repeats the SMEM layout atom to tile the whole tensor shape. Here the tensor shape is the tiler shape with the number of stages appended.
a_smem_layout_staged = cute.tile_to_shape(
a_smem_layout_atom,
cute.append(a_smem_shape, ab_stage),
order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
)
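Reading the numbers off the printed staged layout above, the tiling works out as follows (plain Python just to make the arithmetic explicit; the tile sizes are read from the printout, not from the code):
# Plain Python: how the (8, 64):(64, 1) atom tiles the staged A layout printed above.
atom_shape = (8, 64)      # SMEM layout atom: 8 x 64 elements
tile_shape = (128, 64)    # (M, K) tile per stage, read off the printed layout
num_stages = 7            # appended as a third mode by tile_to_shape

repeats = (tile_shape[0] // atom_shape[0], tile_shape[1] // atom_shape[1])  # (16, 1)
stage_stride = tile_shape[0] * tile_shape[1]                                # 8192

# This matches ((8,16),(64,1),(1,7)) : ((64,512),(1,0),(0,8192)) from above.
print(repeats, stage_stride)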
The TMA atoms are then generated via the calls to _make_tma_atoms_and_tensors and _make_tma_store_atoms_and_tensors shown at the beginning of this section.
Note that for A and B we will copy from GMEM -> SMEM and for C we will copy from SMEM -> GMEM.
We can check out _make_tma_atoms_and_tensors to get a better understanding:
@staticmethod
def _make_tma_atoms_and_tensors(
tensor: cute.Tensor,
smem_layout_staged: cute.ComposedLayout,
smem_tile: tuple[int, int],
mcast_dim: int,
) -> tuple[cute.CopyAtom, cute.Tensor]:
"""Create TMA atoms and tensors for input tensors.
:param tensor: Input tensor (A or B)
:type tensor: cute.Tensor
:param smem_layout_staged: Shared memory layout for the tensor
:type smem_layout_staged: cute.ComposedLayout
:param smem_tile: Shared memory tile shape
:type smem_tile: Tuple[int, int]
:param mcast_dim: Multicast dimension
:type mcast_dim: int
:return: TMA atom and tensor
:rtype: Tuple[cute.CopyAtom, cute.Tensor]
"""
op = (
cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
if mcast_dim == 1
else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()
)
smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tma_tile_atom(
op,
tensor,
smem_layout,
smem_tile,
num_multicast=mcast_dim,
)
return tma_atom, tma_tensor
We see that op determines the "direction" in which we copy (and whether multicast is used). For A and B, as shown above, it is G2S, i.e. GMEM to SMEM.
According to the docs, make_tma_tile_atom figures out the bulk tensor asynchronous copy instruction to use, with the maximum "TMA vector length", to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided layout, consistent with the provided tiler.
For C we will do a SMEM -> GMEM transfer via TMA during the epilogue. The corresponding atom is calculated as follows:
@staticmethod
def _make_tma_store_atoms_and_tensors(
tensor_c: cute.Tensor,
epi_smem_layout_staged: cute.ComposedLayout,
epi_tile: tuple[int, int],
) -> tuple[cute.CopyAtom, cute.Tensor]:
"""Create TMA atoms and tensors for C tensor storage.
:param tensor_c: Output tensor C
:type tensor_c: cute.Tensor
:param epi_smem_layout_staged: Shared memory layout for epilogue
:type epi_smem_layout_staged: cute.ComposedLayout
:param epi_tile: Epilogue tile shape
:type epi_tile: Tuple[int, int]
:return: TMA atom and tensor for C
:rtype: Tuple[cute.CopyAtom, cute.Tensor]
"""
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
c_cta_v_layout = cute.composition(
cute.make_identity_layout(tensor_c.shape), epi_tile
)
tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tma_tile_atom(
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
tensor_c,
epi_smem_layout,
c_cta_v_layout,
)
return tma_atom_c, tma_tensor_c
The process is similar. Note that here we have cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp to copy SMEM -> GMEM.
SharedStorage
In CuTeDSL we can use SharedStorage to allocate and handle shared memory.
@cute.struct
class SharedStorage:
mainloop_pipeline_array_ptr: cute.struct.MemRange[
cutlass.Int64, self.ab_stage * 2
]
sa: cute.struct.Align[
cute.struct.MemRange[
self.a_dtype, cute.cosize(self.a_smem_layout_staged)
],
self.buffer_align_bytes,
]
sb: cute.struct.Align[
cute.struct.MemRange[
self.b_dtype, cute.cosize(self.b_smem_layout_staged)
],
self.buffer_align_bytes,
]
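To get a feeling for what this struct allocates, we can compute the shared-memory budget it implies with plain Python, using the staged layouts printed earlier and assuming 16-bit A/B elements (the element width is an assumption here):
# Plain Python: rough SMEM budget of the SharedStorage struct above.
elems_per_stage = 128 * 64          # cosize of one stage, from the printed staged layout
stages = 7
bytes_per_elem = 2                  # assuming fp16/bf16 inputs

sa_bytes = elems_per_stage * stages * bytes_per_elem   # 114688
sb_bytes = elems_per_stage * stages * bytes_per_elem   # 114688
mbar_bytes = stages * 2 * 8                            # Int64 mbarriers, 2 per stage

print(sa_bytes + sb_bytes + mbar_bytes, "bytes of shared memory (before alignment/epilogue)")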
Process in Batches
Note that we process a batch of GEMMs in the kernel. The blocks in the z direction index into the batch dimension, so each block plane performs an ordinary GEMM.
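As a plain-Python illustration of this mapping (the problem sizes below are made up for the example):
# Plain Python: how the launch grid maps to the batched GEMM.
tile_m, tile_n = 128, 128
m, n, batch = 512, 1024, 4                 # hypothetical problem sizes

grid = (m // tile_m, n // tile_n, batch)   # the z dimension indexes the batch

for bz in range(grid[2]):                  # one independent GEMM per batch entry
    for by in range(grid[1]):
        for bx in range(grid[0]):
            # block (bx, by, bz) computes output tile (bx, by) of GEMM number bz
            pass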
Conclusion
I hope this post made WGMMA and TMA and their usage in CuTeDSL more accessible. You may want to check out the corresponding blog posts I wrote in the past to get a different perspective. I also recommend the PTX docs mentioned above and my other CuTeDSL blog posts.