SGEMM in CuTeDSL
SGEMM is one of the fundamental operations we aim to optimise on GPUs.
In this blogpost I will explain the corresponding example from the CUTLASS repo.
I chose SGEMM because it is the simplest non-trivial example in the examples directory and therefore a good starting point to learn about CuTeDSL.
We will analyse the code in a top-down approach, i.e. start at the high-level entry point and work gradually down into the details.
The whole blogpost explains the case that we obtain when executing the program with the following arguments, which I explain below:
--mnk 256,128,64 --a_major m --b_major n --c_major m
main function
This is where the problem is set up.
The main function has following signature
def main(
a_major: str,
b_major: str,
c_major: str,
problem_shape: Tuple[int, int, int],
warmup_iterations: int = 2,
iterations: int = 100,
skip_ref_check: bool = False,
):
Via the command line we passed the majorness for the three matrices of interest.
We will consider the case
A - M-Major
B - N-Major
C - M-Major
That simply means that the stride along the M mode for A and C, and along the N mode for B, is 1.
The problem_shape is passed as (M, N, K) = (256, 128, 64).
This defines the layouts of A, B and C as follows.
The layouts in CuTe define the way our matrix elements are laid out in memory.
To understand this concept deeper let's consider M = 2, K = 3.
For example we see for M-Major that the matrix coordinate (i, k) is mapped to the physical coordinate i * 1 + k * 2 = i + 2k. Physical coordinates are one dimensional because memory is a one dimensional concept.
M-Major
(2,3):(1,2)
0 2 4
1 3 5
K-Major
(2,3):(3,1)
0 1 2
3 4 5
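The coordinate-to-index map behind these two pictures can be sketched in plain Python. physical_index is a hypothetical helper for illustration, not part of CuTeDSL:

```python
# Hypothetical helper: map a 2-D matrix coordinate to its 1-D physical
# index given a CuTe-style stride tuple (dot product of coord and stride).
def physical_index(coord, stride):
    return sum(c * s for c, s in zip(coord, stride))

# M-Major layout (2,3):(1,2): stepping along M moves by 1 in memory.
m_major = (1, 2)
# K-Major layout (2,3):(3,1): stepping along K moves by 1 in memory.
k_major = (3, 1)

assert physical_index((1, 2), m_major) == 5  # bottom-right element
assert physical_index((0, 1), m_major) == 2
assert physical_index((1, 0), k_major) == 3
```

Evaluating the helper over all six coordinates reproduces exactly the two pictures above.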
The below code simply initialises the tensors with appropriate shape and stride.
torch.manual_seed(1024)
M, N, K = problem_shape
# Create and permute tensor A/B/C
def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype):
# is_mode0_major: (mode1, mode0) -> (mode0, mode1)
# else: (mode0, mode1) -> (mode0, mode1)
shape = (mode1, mode0) if is_mode0_major else (mode0, mode1)
permute_order = (1, 0) if is_mode0_major else (0, 1)
return (
torch.empty(*shape, dtype=torch.int32)
.random_(-5, 5)
.to(dtype=dtype)
.permute(permute_order)
.cuda()
)
a = create_and_permute_tensor(M, K, a_major == "m", torch.float32)
b = create_and_permute_tensor(N, K, b_major == "n", torch.float32)
c = create_and_permute_tensor(M, N, c_major == "m", torch.float32)
divisibility_a = a.shape[1] if a_major == "k" else a.shape[0]
divisibility_b = b.shape[1] if b_major == "k" else b.shape[0]
divisibility_c = c.shape[1] if c_major == "n" else c.shape[0]
a_tensor = (
from_dlpack(a, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if a_major == "k" else 0),
divisibility=divisibility_a,
)
)
b_tensor = (
from_dlpack(b, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0))
.mark_compact_shape_dynamic(
mode=(1 if b_major == "k" else 0),
divisibility=divisibility_b,
)
)
c_tensor = (
from_dlpack(c, assumed_align=16)
.mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0))
.mark_compact_shape_dynamic(
mode=(1 if c_major == "n" else 0),
divisibility=divisibility_c,
)
)
In our case that means
mA Layout: (256,64):(1,256)
mB Layout: (128,64):(1,128)
mC Layout: (256,128):(1,256)
After that we simply compile our kernel:
sgemm = SGemm()
print("Compiling kernel with cute.compile ...")
start_time = time.time()
gemm = cute.compile(sgemm, a_tensor, b_tensor, c_tensor)
compilation_time = time.time() - start_time
print(f"Compilation time: {compilation_time:.4f} seconds")
The kernel can then be called like so:
gemm(a_tensor, b_tensor, c_tensor)
SGemm class
The SGemm class has three methods: __init__, __call__ and kernel.
__init__ method
The below code initialises:
cta_tiler: Consists of (bM, bN, bK) = (128, 128, 8); each of these is used to tile along the corresponding dimension.
num_stages: We use multiple stages to overlap memory transfer and compute. GPUs are fundamentally memory bound, so we want to keep the compute units busy while transferring memory from GMEM -> SMEM, GMEM -> RMEM or SMEM -> RMEM. num_stages defines the number of these stages.
num_threads: This one is clear, the number of threads per block.
class SGemm:
def __init__(
self,
cta_tiler: Tuple[int, int, int] = (128, 128, 8),
num_stages: int = 3,
num_threads: int = 256,
):
self._cta_tiler = cta_tiler
self._num_stages = num_stages
self._num_threads = num_threads
assert num_threads > 0, "needs at least one thread"
assert num_threads % 16 == 0, "multiples of 16 required for MMA thread layout"
self._bM, self._bN, self._bK = self._cta_tiler
assert self._bM % 16 == 0, "multiple of 16 required for tile dimension M"
assert self._bN % 16 == 0, "multiple of 16 required for tile dimension N"
assert self._num_stages >= 3, "num_stages must be greater than or equal to 3"
call method
The call method is the analogue of the host code in a CUDA program. Here we prepare all the parameters which will later be handed down to the kernel.
Note that the __call__ method is jitted. It takes the tensors we prepared above as well as an epilogue_op. The epilogue_op is the identity by default and can be used to modify our result after the main computation is over. For example we could set lambda x: 2 * x to multiply the elements of the matrix multiplication by 2 after computing the GEMM.
@cute.jit
def __call__(
self,
mA: cute.Tensor,
mB: cute.Tensor,
mC: cute.Tensor,
epilogue_op: cutlass.Constexpr = lambda x: x,
):
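The epilogue is applied elementwise to the accumulator after the mainloop finishes. A minimal pure-Python analogue (the list of floats is an illustrative stand-in for the tCrC fragment, not CuTeDSL API):

```python
# Hedged sketch: the epilogue is just a function applied to each
# accumulator value after the main GEMM computation.
epilogue_op = lambda x: 2 * x        # example from the text: scale by 2

accumulators = [1.0, 2.5, -3.0]      # stand-in for the tCrC fragment
result = [epilogue_op(v) for v in accumulators]
assert result == [2.0, 5.0, -6.0]
```

The default lambda x: x leaves the accumulator untouched.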
We determine the major modes. utils.LayoutEnum.from_tensor will return utils.LayoutEnum.ROW_MAJOR or utils.LayoutEnum.COL_MAJOR. ROW_MAJOR means that the stride along the second dimension is 1 and vice versa.
We'll then initialise the SMEM layouts. If we deal with row-major tensors in A or B we perform padding, a common technique (alternative to swizzling) to reduce memory bank conflicts. The concept of memory bank conflicts is explained here.
We furthermore determine the size of SMEM we need to allocate. The maximum size is determined by the physical properties of the chip we deal with.
self.a_major_mode = utils.LayoutEnum.from_tensor(mA)
self.b_major_mode = utils.LayoutEnum.from_tensor(mB)
self.c_major_mode = utils.LayoutEnum.from_tensor(mC)
padding_a = 4 if self.a_major_mode == utils.LayoutEnum.ROW_MAJOR else 0
padding_b = 4 if self.b_major_mode == utils.LayoutEnum.ROW_MAJOR else 0
sA_layout = cute.make_layout(
(self._bM, self._bK, self._num_stages),
stride=(1, (self._bM + padding_a), self._bK * (self._bM + padding_a)),
)
sB_layout = cute.make_layout(
(self._bN, self._bK, self._num_stages),
stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)),
)
smem_size = cute.size_in_bytes(mA.element_type, sA_layout) + cute.size_in_bytes(
mB.element_type, sB_layout
)
We furthermore initialise the layouts tA and tB which will be used as the thread layouts. We can easily see that both of them have a size of num_threads and therefore define a map with domain [0, num_threads). We see that they are both major in the first dimension. We divide the major mode by the vector width to accommodate vectorisation in the case of 16-byte alignment. In what follows we assume 16-byte alignment.
The vectorisation is also reflected in the Value layouts. By default a Layout is major in the first dimension if no stride is provided, i.e. we have (4, 1) : (1, 0) as the layout for the values.
The copy atoms are needed for copy operation from GMEM to SMEM and are self explanatory in their arguments.
if self.a_major_mode == utils.LayoutEnum.COL_MAJOR:
num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1
atom_async_copy_A = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(),
mA.element_type,
num_bits_per_copy=mA.element_type.width * num_vectorized,
)
major_mode_size = self._bM // num_vectorized
tA = cute.make_layout(
(major_mode_size, self._num_threads // major_mode_size),
stride=(1, major_mode_size),
)
vA = cute.make_layout((num_vectorized, 1))
if self.b_major_mode == utils.LayoutEnum.COL_MAJOR:
num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1
atom_async_copy_B = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(),
mB.element_type,
num_bits_per_copy=mB.element_type.width * num_vectorized,
)
major_mode_size = self._bN // num_vectorized
tB = cute.make_layout(
(major_mode_size, self._num_threads // major_mode_size),
stride=(1, major_mode_size),
)
vB = cute.make_layout((num_vectorized, 1))
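We can sketch in plain Python how the (32,8):(1,32) thread layout together with the (4,1) value layout covers a (128,8) tile of A. The names here are illustrative, not CuTeDSL API:

```python
# Hedged sketch: with bM = 128, 4-wide vectorisation and 256 threads,
# tA = (32, 8):(1, 32) assigns each thread a starting coordinate, and
# the (4, 1) value layout gives it 4 contiguous M-elements.
bM, bK, num_threads, vec = 128, 8, 256, 4
major_mode_size = bM // vec          # 32 thread positions along M

def thread_coord(tid):
    # tA is major in the first (M) mode: (tid % 32, tid // 32)
    return (tid % major_mode_size, tid // major_mode_size)

covered = set()
for tid in range(num_threads):
    m_pos, k = thread_coord(tid)
    for v in range(vec):
        covered.add((m_pos * vec + v, k))

# Every element of the (128, 8) tile is copied exactly once.
assert len(covered) == bM * bK
```

So thread 0 copies elements (0..3, 0), thread 1 copies (4..7, 0), and so on; the second tA mode walks along k.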
make_tiled_copy_tv is a convenience wrapper.
tiled_copy_A = cute.make_tiled_copy_tv(atom_async_copy_A, tA, vA)
tiled_copy_B = cute.make_tiled_copy_tv(atom_async_copy_B, tB, vB)
We can get a better understanding of it by printing out its content. We do that for tiled_copy_A. We see that it holds the Tiler, the TV Layout and the Copy Atom. We will later use it for slicing our tensors for each thread and copying.
Tiled Copy
Tiler MN: (128:1,8:1)
TV Layout tiled: (256,4):(4,1)
Copy Atom
ThrID: 1:0
TV Layout Src: (1,4):(0,1)
TV Layout Dst: (1,4):(0,1)
Value type: f32
As we can read in the PTX docs, the MMA operation expects a specific thread layout. We obtain it here.
atoms_layout = cute.make_layout(
(self._num_threads // 16, 16, 1), stride=(16, 1, 0)
)
if self.c_major_mode == utils.LayoutEnum.COL_MAJOR:
atoms_layout = cute.make_layout(
(16, self._num_threads // 16, 1), stride=(1, 16, 0)
)
op = cute.nvgpu.MmaUniversalOp(cutlass.Float32)
permutation_tiler_M = cute.make_layout(
(atoms_layout.shape[0], 4), stride=(4, 1)
)
permutation_tiler_N = cute.make_layout(
(atoms_layout.shape[1], 4), stride=(4, 1)
)
tiled_mma = cute.make_tiled_mma(
op,
atoms_layout,
permutation_mnk=(permutation_tiler_M, permutation_tiler_N, None),
)
We are then ready to launch our kernel. Note that from above we can immediately see that the size of our atoms_layout equals num_threads. We furthermore tile the C matrix using bM and bN along the corresponding dimensions to calculate the number of blocks in each dimension. Note that we use a ceiling division, i.e. if bM or bN doesn't divide the corresponding dimension evenly we round up.
# grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, 1)
grid_dim = *cute.ceil_div(mC.shape, (self._bM, self._bN)), 1
self.kernel(
mA,
mB,
mC,
sA_layout,
sB_layout,
tiled_copy_A,
tiled_copy_B,
tiled_mma,
epilogue_op,
).launch(
grid=grid_dim,
block=[cute.size(atoms_layout), 1, 1],
smem=smem_size,
)
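The grid computation for our problem can be checked with a plain-Python ceil_div mirroring cute.ceil_div:

```python
# Hedged sketch: ceiling division used to compute the launch grid.
def ceil_div(a, b):
    return (a + b - 1) // b

M, N = 256, 128
bM, bN = 128, 128
grid_dim = (ceil_div(M, bM), ceil_div(N, bN), 1)
assert grid_dim == (2, 1, 1)   # two CTAs along M, one along N
```

With a ragged extent such as M = 250 we would still get two CTAs along M; predication (below) handles the out-of-bounds rows.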
kernel method
Here is where our main logic resides.
We annotate the kernel with cute.kernel. The arguments were explained above.
@cute.kernel
def kernel(
self,
mA: cute.Tensor,
mB: cute.Tensor,
mC: cute.Tensor,
sA_layout: cute.Layout,
sB_layout: cute.Layout,
tiled_copy_A: cute.TiledCopy,
tiled_copy_B: cute.TiledCopy,
tiled_mma: cute.TiledMma,
epilogue_op: cutlass.Constexpr = lambda x: x,
):
This is the usual CUDA indexing. tiled_mma.get_slice obtains the correct thread slice of the tiled MMA.
tidx, tidy, tidz = cute.arch.thread_idx()
bidx, bidy, bidz = cute.arch.block_idx()
tiler_coord = (bidx, bidy, None)
thr_mma = tiled_mma.get_slice(tidx)
Note the g in front of the variables below. This indicates we are dealing with tensors residing in GMEM. Each pair of (bidx, bidy) will process a tile.
gA = cute.local_tile(
mA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1)
)
gB = cute.local_tile(
mB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1)
)
gC = cute.local_tile(
mC, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, 1, None)
)
We can print the layouts out:
gA Layout: (128,8,8):(1,256,2048)
gB Layout: (128,8,8):(1,128,1024)
gC Layout: (128,128):(1,256)
We see that the shapes are (bM, bK, K / bK), (bN, bK, K / bK) and (bM, bN). Here K / bK = 64 / 8 = 8.
This code offsets the tensors. This is needed for the case that bK doesn't evenly divide K. Note that in the case that bK divides K the residue is simply 0.
residue_k = mA.shape[1] - cutlass.Int32(self._bK) * gA.shape[2]
gA = cute.domain_offset((0, residue_k, 0), gA)
gB = cute.domain_offset((0, residue_k, 0), gB)
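The residue arithmetic can be traced in plain Python. K = 64 divides bK = 8 evenly in our run, so residue_k is 0; a ragged K (62 is an illustrative value) yields a negative residue that shifts the k-domain:

```python
# Hedged sketch of the residue computation in the kernel.
def ceil_div(a, b):
    return (a + b - 1) // b

bK = 8
for K, expected in [(64, 0), (62, -2)]:
    num_k_tiles = ceil_div(K, bK)        # corresponds to gA.shape[2]
    residue_k = K - bK * num_k_tiles     # mA.shape[1] - bK * gA.shape[2]
    assert residue_k == expected
```

A non-zero residue shifts the k-domain so that the first k-tile is the ragged one, which is why only the very first copy below uses the *_residue_k predicates.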
We allocate the shared memory. We'll then get the slice for the current thread and partition the GMEM and SMEM tensors accordingly.
smem = cutlass.utils.SmemAllocator()
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
thr_copy_A = tiled_copy_A.get_slice(tidx)
thr_copy_B = tiled_copy_B.get_slice(tidx)
tAgA = thr_copy_A.partition_S(gA)
tAsA = thr_copy_A.partition_D(sA)
tBgB = thr_copy_B.partition_S(gB)
tBsB = thr_copy_B.partition_D(sB)
We can print out the shapes to understand better:
tAgA Shape: ((4,1),1,1,8)
tAsA Shape: ((4,1),1,1,3)
tBgB Shape: ((4,1),1,1,8)
tBsB Shape: ((4,1),1,1,3)
We see here that we have (4, 1) in the first mode because of the vectorisation. The last mode of tAgA/tBgB corresponds to the number of k-tiles and the last mode of tAsA/tBsB to num_stages from above.
Note the CuTe notation: We have a g to indicate GMEM and s to indicate SMEM. We have _S to indicate Source and _D to indicate Destination.
Predication is needed to handle the case where the tiles don't evenly divide the corresponding dimension.
We can handle predication with the use of an identity tensor. An identity tensor in CuTe simply maps each coordinate to itself, i.e. (i, j) -> (i, j). We then replicate the tiling and GMEM partitioning from above.
mcA = cute.make_identity_tensor(mA.shape)
mcB = cute.make_identity_tensor(mB.shape)
cA = cute.local_tile(
mcA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1)
)
cB = cute.local_tile(
mcB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1)
)
cA = cute.domain_offset((0, residue_k, 0), cA)
cB = cute.domain_offset((0, residue_k, 0), cB)
# Repeat the partitioning with identity layouts
tAcA = thr_copy_A.partition_S(cA)
tBcB = thr_copy_B.partition_S(cB)
The below tensors will be the storage for our predication result. They will contain a 1 if the corresponding element should be copied and a 0 otherwise. I will not go too deeply into predication to keep the blog concise but maybe this can be handled in another blogpost.
# Allocate predicate tensors for m and n
tApA = cute.make_fragment(
cute.make_layout(
(
tAsA.shape[0][1],
cute.size(tAsA, mode=[1]),
cute.size(tAsA, mode=[2]),
),
stride=(cute.size(tAsA, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
tBpB = cute.make_fragment(
cute.make_layout(
(
tBsB.shape[0][1],
cute.size(tBsB, mode=[1]),
cute.size(tBsB, mode=[2]),
),
stride=(cute.size(tBsB, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
# Allocate predicate tensors for m, n and k for residue k-tile
tApA_residue_k = cute.make_fragment(
cute.make_layout(
(
tAsA.shape[0][1],
cute.size(tAsA, mode=[1]),
cute.size(tAsA, mode=[2]),
),
stride=(
cute.size(tAsA, mode=[1]) * cute.size(tAsA, mode=[2]),
cute.size(tAsA, mode=[2]),
1,
),
),
cutlass.Boolean,
)
tBpB_residue_k = cute.make_fragment(
cute.make_layout(
(
tBsB.shape[0][1],
cute.size(tBsB, mode=[1]),
cute.size(tBsB, mode=[2]),
),
stride=(
cute.size(tBsB, mode=[1]) * cute.size(tBsB, mode=[2]),
cute.size(tBsB, mode=[2]),
1,
),
),
cutlass.Boolean,
)
We perform predication here. From CUDA we know that not every thread should participate in a copy if the block size does not evenly divide the dimension of the problem, and that is what is reflected here.
# Set predicates for m/n bounds for mainloop
for rest_v in range(tApA.shape[0]):
for m in range(tApA.shape[1]):
tApA[rest_v, m, 0] = cute.elem_less(
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
)
for rest_v in range(tBpB.shape[0]):
for n in range(tBpB.shape[1]):
tBpB[rest_v, n, 0] = cute.elem_less(
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
)
# Set predicates for m/n/k bounds for residue k tile
for rest_v in range(tApA_residue_k.shape[0]):
for m in range(tApA_residue_k.shape[1]):
for k in range(tApA_residue_k.shape[2]):
coord_A = tAcA[(0, rest_v), m, k, 0]
tApA_residue_k[rest_v, m, k] = cute.elem_less(
(coord_A[0], cutlass.Int32(-1)), (mA.shape[0], coord_A[1])
)
for rest_v in range(tBpB_residue_k.shape[0]):
for n in range(tBpB_residue_k.shape[1]):
for k in range(tBpB_residue_k.shape[2]):
coord_B = tBcB[(0, rest_v), n, k, 0]
tBpB_residue_k[rest_v, n, k] = cute.elem_less(
(coord_B[0], cutlass.Int32(-1)), (mB.shape[0], coord_B[1])
)
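The m-bound predicate has a simple pure-Python analogue: an element participates in the copy only if its global m-coordinate (recovered via the identity tensor) lies inside M. The numbers below are an illustrative ragged case, not our 256x128x64 run:

```python
# Hedged sketch of the m-bound predicate for one CTA along M.
M, bM = 200, 128          # ragged: the second CTA hangs over the edge
bidx = 1                  # second CTA along M
pred = [bidx * bM + m < M for m in range(bM)]
assert sum(pred) == M - bM   # only 72 of the 128 rows are valid
```

The residue-k predicates add the analogous check along k, which is only needed for the single ragged k-tile.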
We issue the first asynchronous copy operation and commit it. Note that this is done for the zeroth stage because the last mode of tAsA and tBsB corresponds to the stages. Afterwards we increase gmem_pipe_read, which indexes the last mode of tAgA and tBgB; this mode in turn corresponds to the number of bK-sized tiles along the K dimension. Note that we use the predicates defined above to copy only the relevant elements if necessary.
k_pipe_max = cute.size(tAsA, mode=[3])
k_tile_count = cute.size(tAgA, mode=[3])
gmem_pipe_read = cutlass.Int32(0)
cute.copy(
tiled_copy_A,
tAgA[None, None, None, gmem_pipe_read],
tAsA[None, None, None, 0],
pred=tApA_residue_k,
)
cute.copy(
tiled_copy_B,
tBgB[None, None, None, gmem_pipe_read],
tBsB[None, None, None, 0],
pred=tBpB_residue_k,
)
cute.arch.cp_async_commit_group()
gmem_pipe_read = (
gmem_pipe_read + 1
if gmem_pipe_read + 1 < k_tile_count
else cutlass.Int32(0)
)
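The wrapping update of gmem_pipe_read can be simulated in plain Python; it walks through the k-tiles and wraps back to 0 after the last one:

```python
# Hedged sketch: the prologue's wrapping k-tile index.
k_tile_count = 8
gmem_pipe_read = 0
visited = []
for _ in range(k_tile_count + 2):   # a few steps past the wrap point
    visited.append(gmem_pipe_read)
    gmem_pipe_read = gmem_pipe_read + 1 if gmem_pipe_read + 1 < k_tile_count else 0
assert visited == [0, 1, 2, 3, 4, 5, 6, 7, 0, 1]
```

Note that the mainloop later wraps to 1 instead of 0, for the reason explained in its inline comment.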
In what follows we copy the remaining tiles from GMEM to SMEM.
# Start async loads for 1st k-tile onwards, no k-residue handling needed
for k_tile in range(1, k_pipe_max - 1):
if k_tile < k_tile_count:
cute.copy(
tiled_copy_A,
tAgA[None, None, None, gmem_pipe_read],
tAsA[None, None, None, k_tile],
pred=tApA,
)
cute.copy(
tiled_copy_B,
tBgB[None, None, None, gmem_pipe_read],
tBsB[None, None, None, k_tile],
pred=tBpB,
)
gmem_pipe_read = (
gmem_pipe_read + 1
if gmem_pipe_read + 1 < k_tile_count
else cutlass.Int32(0)
)
cute.arch.cp_async_commit_group()
# all tiles have been copied from global memory, so clear the
# predicate tensor
if k_tile_count < k_pipe_max:
for rest_v in range(tApA.shape[0]):
for m in range(tApA.shape[1]):
tApA[rest_v, m, 0] = cutlass.Boolean(0)
for rest_v in range(tBpB.shape[0]):
for n in range(tBpB.shape[1]):
tBpB[rest_v, n, 0] = cutlass.Boolean(0)
We'll define the partitioning for A, B and C needed for the MMA operation. We allocate the fragments for the first stage because the MMA operation expects its arguments to live in RMEM.
# ///////////////////////////////////////////////////////////////////////////////
# Define A/B partitioning and C accumulators.
# ///////////////////////////////////////////////////////////////////////////////
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCgC = thr_mma.partition_C(gC)
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
tCrC = tiled_mma.make_fragment_C(tCgC)
# Clear the accumulator
tCrC.fill(0.0)
# Current pipe index in smem to read from / write to
smem_pipe_read = cutlass.Int32(0)
smem_pipe_write = cutlass.Int32(k_pipe_max - 1)
tCsA_p = tCsA[None, None, None, smem_pipe_read]
tCsB_p = tCsB[None, None, None, smem_pipe_read]
cp_async_wait_group(k_pipe_max - 2) waits until at most k_pipe_max - 2 committed copy groups are still pending. Once this barrier is reached we continue.
k_block_max = cute.size(tCrA, mode=[2])
if k_block_max > 1:
# Wait until our first prefetched tile is loaded in
cute.arch.cp_async_wait_group(k_pipe_max - 2)
cute.arch.barrier()
# Prefetch the first rmem from the first k-tile
cute.autovec_copy(tCsA_p[None, None, 0], tCrA[None, None, 0])
cute.autovec_copy(tCsB_p[None, None, 0], tCrB[None, None, 0])
Note that by the docs:
@dsl_user_op
def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None:
"""
Auto-vectorizing SIMT copy policy.
Given a source and destination tensors that are statically shaped, this policy figures out the
largest safe vector width that the copy instruction can take and performs the copy.
"""
Now comes our main loop.
for _ in cutlass.range_dynamic(k_tile_count, unroll=1):
for k_block in range(k_block_max):
if k_block == k_block_max - 1:
tCsA_p = tCsA[None, None, None, smem_pipe_read]
tCsB_p = tCsB[None, None, None, smem_pipe_read]
cute.arch.cp_async_wait_group(k_pipe_max - 2)
cute.arch.barrier()
# Load A, B from shared memory to registers for k_block + 1
k_block_next = (k_block + 1) % k_block_max # static
cute.autovec_copy(
tCsA_p[None, None, k_block_next],
tCrA[None, None, k_block_next],
)
cute.autovec_copy(
tCsB_p[None, None, k_block_next],
tCrB[None, None, k_block_next],
)
# Fetch next A: To better interleave global memory access and
# compute instructions, we intentionally use the sequence:
# copy A, perform GEMM, then copy B.
if k_block == 0:
cute.copy(
tiled_copy_A,
tAgA[None, None, None, gmem_pipe_read],
tAsA[None, None, None, smem_pipe_write],
# Use predicates because the m-mode may be irregular
pred=tApA,
)
# Thread-level register gemm for k_block
cute.gemm(
tiled_mma,
tCrC,
tCrA[None, None, k_block],
tCrB[None, None, k_block],
tCrC,
)
# Fetch next B and update smem pipeline read/write
if k_block == 0:
cute.copy(
tiled_copy_B,
tBgB[None, None, None, gmem_pipe_read],
tBsB[None, None, None, smem_pipe_write],
# Use predicates because the n-mode may be irregular
pred=tBpB,
)
cute.arch.cp_async_commit_group()
smem_pipe_write = smem_pipe_read
smem_pipe_read = smem_pipe_read + 1
if smem_pipe_read == k_pipe_max:
smem_pipe_read = cutlass.Int32(0)
# After copying all tiles, we avoid clearing the predicate
# tensor in the `mainloop` to prevent increasing its
# instruction count. Instead, we continue copying the
# first tile, though it won't be used. The 0-th tile is not
# copied due to its irregular shape, which could lead to
# illegal memory accesses.
gmem_pipe_read = (
gmem_pipe_read + 1
if gmem_pipe_read + 1 < k_tile_count
else cutlass.Int32(1)
)
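The SMEM pipeline indices follow a fixed rotation that we can trace in plain Python: with num_stages = 3 the stage being written always trails the stage being read by one step.

```python
# Hedged sketch: evolution of the SMEM pipeline indices over five k-tiles.
k_pipe_max = 3
smem_pipe_read, smem_pipe_write = 0, k_pipe_max - 1
schedule = []
for _ in range(5):
    schedule.append((smem_pipe_read, smem_pipe_write))
    # Update exactly as in the mainloop above.
    smem_pipe_write = smem_pipe_read
    smem_pipe_read = smem_pipe_read + 1
    if smem_pipe_read == k_pipe_max:
        smem_pipe_read = 0
assert schedule == [(0, 2), (1, 0), (2, 1), (0, 2), (1, 0)]
```

The stage just consumed is immediately recycled as the write target for the next prefetch, which is what makes three stages sufficient.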
Note that we process all tiles in stages.
At the end of the current stage (i.e. on the last k_block) we advance the SMEM read slices, wait for the outstanding copies, and prefetch the next k_block from SMEM into registers.
if k_block == k_block_max - 1:
tCsA_p = tCsA[None, None, None, smem_pipe_read]
tCsB_p = tCsB[None, None, None, smem_pipe_read]
cute.arch.cp_async_wait_group(k_pipe_max - 2)
cute.arch.barrier()
# Load A, B from shared memory to registers for k_block + 1
k_block_next = (k_block + 1) % k_block_max # static
cute.autovec_copy(
tCsA_p[None, None, k_block_next],
tCrA[None, None, k_block_next],
)
cute.autovec_copy(
tCsB_p[None, None, k_block_next],
tCrB[None, None, k_block_next],
)
Here we perform a copy from GMEM to SMEM under predication. Remember that the last mode of tAgA corresponds to the current Tile and the last mode of tAsA corresponds to the current stage.
if k_block == 0:
cute.copy(
tiled_copy_A,
tAgA[None, None, None, gmem_pipe_read],
tAsA[None, None, None, smem_pipe_write],
# Use predicates because the m-mode may be irregular
pred=tApA,
)
The GEMM we want to compute is performed blockwise and accumulated into the dedicated register fragment tCrC.
# Thread-level register gemm for k_block
cute.gemm(
tiled_mma,
tCrC,
tCrA[None, None, k_block],
tCrB[None, None, k_block],
tCrC,
)
We fetch the next B. We then update the smem_pipe_read, smem_pipe_write and gmem_pipe_read parameters.
# Fetch next B and update smem pipeline read/write
if k_block == 0:
cute.copy(
tiled_copy_B,
tBgB[None, None, None, gmem_pipe_read],
tBsB[None, None, None, smem_pipe_write],
# Use predicates because the n-mode may be irregular
pred=tBpB,
)
cute.arch.cp_async_commit_group()
smem_pipe_write = smem_pipe_read
smem_pipe_read = smem_pipe_read + 1
if smem_pipe_read == k_pipe_max:
smem_pipe_read = cutlass.Int32(0)
# After copying all tiles, we avoid clearing the predicate
# tensor in the `mainloop` to prevent increasing its
# instruction count. Instead, we continue copying the
# first tile, though it won't be used. The 0-th tile is not
# copied due to its irregular shape, which could lead to
# illegal memory accesses.
gmem_pipe_read = (
gmem_pipe_read + 1
if gmem_pipe_read + 1 < k_tile_count
else cutlass.Int32(1)
)
We make sure no copy is pending, apply the epilogue to tCrC, and then copy the result from RMEM to GMEM under predication.
cute.arch.cp_async_wait_group(0)
cute.arch.barrier()
tCrC.store(epilogue_op(tCrC.load()))
# predicate
cC = cute.make_identity_tensor(gC.shape)
tCpC = thr_mma.partition_C(cC)
predC = cute.make_fragment(tCrC.layout, cutlass.Boolean)
residue_m = mC.shape[0] - cutlass.Int32(self._bM) * bidx
residue_n = mC.shape[1] - cutlass.Int32(self._bN) * bidy
for i in range(cute.size(tCrC.shape)):
predC[i] = cute.elem_less(tCpC[i], (residue_m, residue_n))
numIterM = cute.size(tCrC, mode=[1])
numIterN = cute.size(tCrC, mode=[2])
atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mC.element_type)
cute.copy(atom, tCrC, tCgC, pred=predC)
return
Conclusion
CuTeDSL is very performant but not easy to understand. I hope this blogpost shed some light on the SGemm example and equips us to understand and design more complex kernels in the future.