Grouped Blockscaled Gemm - Host code
This is the next blog post in which I aim to explain grouped blockscaled GEMM for B200 GPUs in a top-down manner. Before studying the internals of the grouped blockscaled GEMM kernel, we should understand the setup of MMA and TMA, the tile scheduler and the other parts that matter during kernel execution. This post explains the parts that differ from the usual persistent blockscaled dense GEMM approach.
__call__
@cute.jit
def __call__(
    self,
    initial_a: cute.Tensor,
    initial_b: cute.Tensor,
    initial_c: cute.Tensor,
    initial_sfa: cute.Tensor,
    initial_sfb: cute.Tensor,
    group_count: cutlass.Constexpr[int],
    problem_shape_mnkl: cute.Tensor,
    strides_abc: cute.Tensor,
    tensor_address_abc: cute.Tensor,
    tensor_address_sfasfb: cute.Tensor,
    total_num_clusters: cutlass.Constexpr[int],
    tensormap_cute_tensor: cute.Tensor,
    max_active_clusters: cutlass.Constexpr[int],
    stream: cuda.CUstream,
):
Let's quickly recap the meaning of each of these arguments:
initial_{a|b|c}: The smallest a, b and c across all groups, where "smallest" means the smallest "area" of the matrix, i.e. M · K for a, and similarly for the other matrices.
initial_{sfa|sfb}: The corresponding scale factors for a and b above.
group_count: Number of groups. Known at compile time.
problem_shape_mnkl: Tensor containing the problem shape (M, N, K, L) for each group.
strides_abc: Tensor with a layout of shape (group_count, 3, 2). The first mode is the group index, the second mode selects whether we deal with A, B or C, and the last mode holds the strides in modes M and K for A, modes N and K for B, and modes M and N for C. The strides of each tensor are two-dimensional because the L mode is always chosen to be one, so its stride is trivially given.
tensor_address_abc: Tensor with a layout of shape (group_count, 3). The first mode is the group index, the second mode selects whether we deal with A, B or C. Each entry is a pointer to the corresponding tensor within the group.
tensor_address_sfasfb: The corresponding pointers to the scale factors for each group.
total_num_clusters: The total number of tiles needed to cover each group, summed over all groups (taking into account the cluster shape we use). Known at compile time.
tensormap_cute_tensor: Buffers of tensormaps. Each buffer corresponds to one SM and contains 5 tensormaps. Created from an empty torch tensor (i.e. it initially contains garbage).
max_active_clusters: The user-chosen number of clusters, capped at the maximum number of clusters available on our hardware. Known at compile time.
stream: The stream in which we launch our kernel.
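To make the shapes of these host-side tensors concrete, here is a minimal sketch in plain Python (no CuTe DSL). The group sizes, the helper name strides_for_group and the assumption of K-major A/B and N-major C are mine and only serve as illustration; the real example builds torch tensors instead of lists.

```python
# Hypothetical problem sizes (M, N, K, L) for three groups; L is always 1.
problem_shape_mnkl = [
    (256, 512, 128, 1),
    (128, 1024, 64, 1),
    (512, 128, 256, 1),
]


def strides_for_group(m: int, n: int, k: int) -> list[tuple[int, int]]:
    """Strides of A (M, K), B (N, K) and C (M, N).

    Assumption: A and B are K-major and C is N-major, so the
    contiguous mode has stride 1 in each case.
    """
    return [(k, 1), (k, 1), (n, 1)]


# Nested list of shape (group_count, 3, 2):
# group index, then A/B/C, then the two strides.
strides_abc = [strides_for_group(m, n, k) for m, n, k, _ in problem_shape_mnkl]

# Area of each matrix as a function of the problem shape.
areas = {
    "a": lambda m, n, k: m * k,
    "b": lambda m, n, k: n * k,
    "c": lambda m, n, k: m * n,
}

# Index of the group that provides the "initial" tensor for each matrix.
smallest_group = {
    name: min(
        range(len(problem_shape_mnkl)),
        key=lambda g: area(*problem_shape_mnkl[g][:3]),
    )
    for name, area in areas.items()
}
```

Note that the smallest a, b and c do not have to come from the same group, which is why each matrix is selected independently here.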
In the following we will briefly go over the "usual" Blackwell GPU stuff. Please take a look at my other Blackwell blogs if you feel lost at any point. If there is something new I will explain it in more detail.
self.a_dtype = initial_a.element_type
self.b_dtype = initial_b.element_type
self.sf_dtype = initial_sfa.element_type
self.c_dtype = initial_c.element_type
self.a_major_mode = utils.LayoutEnum.from_tensor(initial_a).mma_major_mode()
self.b_major_mode = utils.LayoutEnum.from_tensor(initial_b).mma_major_mode()
self.c_layout = utils.LayoutEnum.from_tensor(initial_c)
if cutlass.const_expr(self.a_dtype != self.b_dtype):
    raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
# Setup attributes that depend on gemm inputs
self._setup_attributes()
The _setup_attributes method works largely as usual.
The only major deviation is shown below. Please read the original code for details, as this is basically just a bookkeeping step that ensures we don't run into trouble with the SMEM.
mbar_smem_bytes = self._get_mbar_smem_bytes(
    num_acc_stage=self.num_acc_stage,
    num_ab_stage=self.num_ab_stage,
    num_c_stage=self.num_c_stage,
)
# Use utils.TensorMapUpdateMode.SMEM by default
tensormap_smem_bytes = (
    Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap
    * Sm100GroupedBlockScaledGemmKernel.num_tensormaps
)
if (
    mbar_smem_bytes
    + tensormap_smem_bytes
    + Sm100GroupedBlockScaledGemmKernel.tensor_memory_management_bytes
    > self.reserved_smem_bytes
):
    raise ValueError(
        f"smem consumption for mbar and tensormap {mbar_smem_bytes + tensormap_smem_bytes} exceeds the "
        f"reserved smem bytes {self.reserved_smem_bytes}"
    )
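The following plain-Python sketch mirrors this bookkeeping with concrete numbers. It is not the original _get_mbar_smem_bytes (which also takes num_c_stage); I assume one Int64 mbarrier for each full and empty barrier per stage, a 128-byte tensormap, 5 tensormaps, 12 bytes for tensor memory management and a reserved budget of 1024 bytes. All function names and the default budget are hypothetical.

```python
BYTES_PER_MBARRIER = 8         # each mbarrier is stored as one Int64
BYTES_PER_TENSORMAP = 128      # size of one TMA descriptor
NUM_TENSORMAPS = 5             # assumed: A, B, SFA, SFB and C
TMEM_MANAGEMENT_BYTES = 8 + 4  # assumed: one Int64 mbarrier + one Int32 buffer


def reserved_smem_usage(num_ab_stage: int, num_acc_stage: int) -> int:
    """Bytes used by mbarriers, tensormaps and tensor memory management."""
    # One full and one empty barrier per AB stage and per accumulator stage.
    mbar_bytes = BYTES_PER_MBARRIER * 2 * (num_ab_stage + num_acc_stage)
    tensormap_bytes = BYTES_PER_TENSORMAP * NUM_TENSORMAPS
    return mbar_bytes + tensormap_bytes + TMEM_MANAGEMENT_BYTES


def check_reserved_smem(
    num_ab_stage: int,
    num_acc_stage: int,
    reserved_smem_bytes: int = 1024,  # hypothetical budget
) -> int:
    """Return the usage or raise if it exceeds the reserved budget."""
    used = reserved_smem_usage(num_ab_stage, num_acc_stage)
    if used > reserved_smem_bytes:
        raise ValueError(
            f"smem consumption {used} exceeds the reserved smem bytes "
            f"{reserved_smem_bytes}"
        )
    return used
```

Under these assumptions the tensormaps alone take 640 of the reserved bytes, so the number of pipeline stages that fits into the remaining budget is noticeably smaller than in a dense kernel with the same reservation.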
Afterwards we perform the usual setup of the tiled MMA and TMA with the initial tensors. Note that the choice of initial tensors is only relevant for the TMA, because the tiled_mma does not use the shape of the initial tensors but only their majorness and data types. The SMEM layouts are also the same because they only depend on the tile size we use. So the only difference across groups is the GMEM tensor we provide to the helper functions that set up the TMA. We will see later how we assign the correct tensors to the TMA during the transfer from GMEM to SMEM.
Note we could also just read the docstring here:
For grouped GEMM, tensor shapes, tensor strides, and tensor address are all provided
by different tensors in global memory. The "initial" tensors only carry data type and
majorness information.
# Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout
# ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL)
sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(
    initial_a.shape, self.sf_vec_size
)
initial_sfa = cute.make_tensor(initial_sfa.iterator, sfa_layout)
# ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(
    initial_b.shape, self.sf_vec_size
)
initial_sfb = cute.make_tensor(initial_sfb.iterator, sfb_layout)

tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
    self.a_dtype,
    self.a_major_mode,
    self.b_major_mode,
    self.sf_dtype,
    self.sf_vec_size,
    self.cta_group,
    self.mma_inst_shape_mn,
)
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
    self.a_dtype,
    self.a_major_mode,
    self.b_major_mode,
    self.sf_dtype,
    self.sf_vec_size,
    cute.nvgpu.tcgen05.CtaGroup.ONE,
    self.mma_inst_shape_mn_sfb,
)
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
# Setup TMA load for A
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
    self.cluster_shape_mn, tiled_mma.thr_id
)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
    a_op,
    initial_a,
    a_smem_layout,
    self.mma_tiler,
    tiled_mma,
    self.cluster_layout_vmnk.shape,
)

# Setup TMA load for B
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
    self.cluster_shape_mn, tiled_mma.thr_id
)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
    b_op,
    initial_b,
    b_smem_layout,
    self.mma_tiler,
    tiled_mma,
    self.cluster_layout_vmnk.shape,
)

# Setup TMA load for SFA
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
    self.cluster_shape_mn, tiled_mma.thr_id
)
sfa_smem_layout = cute.slice_(
    self.sfa_smem_layout_staged, (None, None, None, 0)
)
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
    sfa_op,
    initial_sfa,
    sfa_smem_layout,
    self.mma_tiler,
    tiled_mma,
    self.cluster_layout_vmnk.shape,
    internal_type=cutlass.Int16,
)

# Setup TMA load for SFB
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
    self.cluster_shape_mn, tiled_mma.thr_id
)
sfb_smem_layout = cute.slice_(
    self.sfb_smem_layout_staged, (None, None, None, 0)
)
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
    sfb_op,
    initial_sfb,
    sfb_smem_layout,
    self.mma_tiler_sfb,
    tiled_mma_sfb,
    self.cluster_layout_sfb_vmnk.shape,
    internal_type=cutlass.Int16,
)

a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
self.num_tma_load_bytes = (
    a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size
) * atom_thr_size
# Setup TMA store for C
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
    cpasync.CopyBulkTensorTileS2GOp(),
    initial_c,
    epi_smem_layout,
    self.epi_tile,
)
Afterwards we compute the grid.
# Compute grid size
self.tile_sched_params, grid = self._compute_grid(
total_num_clusters, self.cluster_shape_mn, max_active_clusters
)
Let's take a closer look at how this is done:
@staticmethod
def _compute_grid(
    total_num_clusters: int,
    cluster_shape_mn: tuple[int, int],
    max_active_clusters: cutlass.Constexpr[int],
) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]:
    # Create problem shape with M, N dimensions from cluster shape
    # and L dimension representing the total number of clusters.
    problem_shape_ntile_mnl = (
        cluster_shape_mn[0],
        cluster_shape_mn[1],
        cutlass.Int32(total_num_clusters),
    )

    tile_sched_params = utils.PersistentTileSchedulerParams(
        problem_shape_ntile_mnl, (*cluster_shape_mn, 1)
    )

    grid = utils.StaticPersistentTileScheduler.get_grid_shape(
        tile_sched_params, max_active_clusters
    )

    return tile_sched_params, grid
As we can see the setup of the persistent tile scheduler is slightly different from the usual setup which looks like this:
@staticmethod
def _compute_grid(
    c: cute.Tensor,
    cta_tile_shape_mnk: Tuple[int, int, int],
    cluster_shape_mn: Tuple[int, int],
    max_active_clusters: cutlass.Constexpr,
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
    c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
    gc = cute.zipped_divide(c, tiler=c_shape)
    num_ctas_mnl = gc[(0, (None, None, None))].shape
    cluster_shape_mnl = (*cluster_shape_mn, 1)

    tile_sched_params = utils.PersistentTileSchedulerParams(
        num_ctas_mnl, cluster_shape_mnl
    )
    grid = utils.StaticPersistentTileScheduler.get_grid_shape(
        tile_sched_params, max_active_clusters
    )

    return tile_sched_params, grid
Let's quickly recap what the code in the persistent blockscaled dense GEMM does. We extract (bM, bN) into c_shape. We then use zipped_divide to obtain the tensor gc with shape ((bM, bN), (RestM, RestN, RestL)). Indexing with (0, (None, None, None)) fixes the first mode and keeps the second, so the shape (RestM, RestN, RestL), i.e. the number of tiles per mode, is what we pass as the num_ctas_mnl parameter.
Contrast that with the grouped blockscaled kernel: here we use the cluster shape as the first and second mode of problem_shape_ntile_mnl and put total_num_clusters into the last mode. Note that total_num_clusters is simply the total number of "cluster-scaled" (bM, bN) tiles we need to cover the outputs of all groups.
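The difference between the two variants can be sketched in plain Python. This is an illustration rather than the library code: the function names are hypothetical, I assume that a cluster tile is the CTA tile scaled by the cluster shape, and I assume that the static persistent scheduler launches at most max_active_clusters clusters.

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def dense_num_ctas_mnl(c_shape_mnl, cta_tile_mn):
    """Dense case: number of CTA tiles per mode, (RestM, RestN, RestL)."""
    m, n, l = c_shape_mnl
    bm, bn = cta_tile_mn
    return (ceil_div(m, bm), ceil_div(n, bn), l)


def grouped_total_num_clusters(problem_shape_mnkl, cluster_tile_mn):
    """Grouped case: cluster tiles needed to cover C, summed over all groups."""
    cm, cn = cluster_tile_mn
    return sum(
        ceil_div(m, cm) * ceil_div(n, cn) for m, n, _, _ in problem_shape_mnkl
    )


def grouped_grid(total_num_clusters, cluster_shape_mn, max_active_clusters):
    """Grid of a persistent launch over a flat list of cluster tiles.

    Assumption: the number of persistent clusters is capped at
    max_active_clusters.
    """
    num_clusters = min(total_num_clusters, max_active_clusters)
    return (*cluster_shape_mn, num_clusters)


# Hypothetical example: CTA tile (128, 128) and cluster shape (2, 1).
cta_tile_mn = (128, 128)
cluster_shape_mn = (2, 1)
cluster_tile_mn = (
    cta_tile_mn[0] * cluster_shape_mn[0],
    cta_tile_mn[1] * cluster_shape_mn[1],
)
groups = [(256, 512, 128, 1), (128, 1024, 64, 1), (512, 128, 256, 1)]

total = grouped_total_num_clusters(groups, cluster_tile_mn)
grid = grouped_grid(total, cluster_shape_mn, max_active_clusters=8)
```

In the grouped case the M and N extents of the individual problems no longer appear in the scheduler parameters. The scheduler only sees a one-dimensional list of cluster tiles in the L mode, and mapping a tile index back to a group and a tile coordinate has to happen inside the kernel.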
The next step is to set up the SMEM. Note that we have one tensormap buffer per SM. Because we have a persistent kernel, we launch a grid in which each block corresponds to one SM. So we allocate num_tensormaps tensormaps per block in SMEM.
The rest is the same as for a usual persistent kernel.
self.buffer_align_bytes = 1024
self.size_tensormap_in_i64 = (
    Sm100GroupedBlockScaledGemmKernel.num_tensormaps
    * Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap
    // 8
)

# Define shared storage for kernel
@cute.struct
class SharedStorage:
    tensormap_buffer: cute.struct.MemRange[
        cutlass.Int64, self.size_tensormap_in_i64
    ]
    ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
    ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
    acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
    acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
    tmem_dealloc_mbar_ptr: cutlass.Int64
    tmem_holding_buf: cutlass.Int32
    # (EPI_TILE_M, EPI_TILE_N, STAGE)
    sC: cute.struct.Align[
        cute.struct.MemRange[
            self.c_dtype,
            cute.cosize(self.c_smem_layout_staged.outer),
        ],
        self.buffer_align_bytes,
    ]
    # (MMA, MMA_M, MMA_K, STAGE)
    sA: cute.struct.Align[
        cute.struct.MemRange[
            self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)
        ],
        self.buffer_align_bytes,
    ]
    # (MMA, MMA_N, MMA_K, STAGE)
    sB: cute.struct.Align[
        cute.struct.MemRange[
            self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)
        ],
        self.buffer_align_bytes,
    ]
    # (MMA, MMA_M, MMA_K, STAGE)
    sSFA: cute.struct.Align[
        cute.struct.MemRange[
            self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)
        ],
        self.buffer_align_bytes,
    ]
    # (MMA, MMA_N, MMA_K, STAGE)
    sSFB: cute.struct.Align[
        cute.struct.MemRange[
            self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)
        ],
        self.buffer_align_bytes,
    ]

self.shared_storage = SharedStorage
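As a quick sanity check of the tensormap sizes, the arithmetic can be written out in plain Python. I assume 128 bytes per tensormap and 5 tensormaps per buffer; the helper names, the ordering of the tensormaps inside a buffer and the SM count in the usage comment are hypothetical.

```python
BYTES_PER_TENSORMAP = 128  # size of one TMA descriptor
NUM_TENSORMAPS = 5         # tensormaps per buffer, as stated above
BYTES_PER_I64 = 8

# Number of Int64 elements in the per-block SMEM tensormap buffer.
size_tensormap_in_i64 = NUM_TENSORMAPS * BYTES_PER_TENSORMAP // BYTES_PER_I64


def tensormap_byte_offset(index: int) -> int:
    """Byte offset of the index-th tensormap inside one buffer."""
    if not 0 <= index < NUM_TENSORMAPS:
        raise IndexError(f"tensormap index {index} out of range")
    return index * BYTES_PER_TENSORMAP


def gmem_workspace_shape(num_sms: int) -> tuple[int, int, int]:
    """Shape of an Int64 workspace holding one buffer per SM."""
    return (num_sms, NUM_TENSORMAPS, BYTES_PER_TENSORMAP // BYTES_PER_I64)


def gmem_workspace_bytes(num_sms: int) -> int:
    """Total size in bytes of the workspace for num_sms SMs."""
    sms, maps, words = gmem_workspace_shape(num_sms)
    return sms * maps * words * BYTES_PER_I64


# Usage with a hypothetical SM count:
# gmem_workspace_shape(148) gives one (5, 16) block of Int64 per SM.
```

The SMEM tensormap_buffer therefore holds 80 Int64 values, which is small compared to the staged A, B and scale factor buffers, and the GMEM workspace grows linearly with the number of SMs.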
Conclusion
I hope this brief note explains which steps on the host side differ from the ones for a usual persistent kernel. The next blog post will continue with the logic within the kernel. Thanks to Verda, which enabled me to do these experiments. If you want to exchange ideas, I am happy to connect on LinkedIn.