Grouped Blockscaled Gemm - Kernel
In this blog post we discuss the kernel of a grouped blockscaled GEMM on B200. This wraps up our recent series on grouped blockscaled GEMM on B200 GPUs. We focus on the differences between a persistent grouped GEMM in CuTeDSL and an ordinary persistent GEMM.
Kernel
@cute.kernel
def kernel(
    self,
    tiled_mma: cute.TiledMma,
    tiled_mma_sfb: cute.TiledMma,
    tma_atom_a: cute.CopyAtom,
    mA_mkl: cute.Tensor,
    tma_atom_b: cute.CopyAtom,
    mB_nkl: cute.Tensor,
    tma_atom_sfa: cute.CopyAtom,
    mSFA_mkl: cute.Tensor,
    tma_atom_sfb: cute.CopyAtom,
    mSFB_nkl: cute.Tensor,
    tma_atom_c: cute.CopyAtom,
    mC_mnl: cute.Tensor,
    cluster_layout_vmnk: cute.Layout,
    cluster_layout_sfb_vmnk: cute.Layout,
    a_smem_layout_staged: cute.ComposedLayout,
    b_smem_layout_staged: cute.ComposedLayout,
    sfa_smem_layout_staged: cute.Layout,
    sfb_smem_layout_staged: cute.Layout,
    c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
    epi_tile: cute.Tile,
    tile_sched_params: utils.PersistentTileSchedulerParams,
    group_count: cutlass.Constexpr,
    problem_sizes_mnkl: cute.Tensor,
    strides_abc: cute.Tensor,
    ptrs_abc: cute.Tensor,
    ptrs_sfasfb: cute.Tensor,
    tensormaps: cute.Tensor,
):
The kernel takes mostly the usual arguments. In addition, we now pass it:
- ptrs_abc: a CuTe tensor with the pointers to A/B/C for each group
- ptrs_sfasfb: a CuTe tensor with the pointers to SFA/SFB for each group
- strides_abc: a CuTe tensor with the strides of A/B/C for each group
- tensormaps: a CuTe tensor holding one tensormap each for A/B/SFA/SFB/C for each SM
- problem_sizes_mnkl: a CuTe tensor with the problem shape (M, N, K, L) for each group (L is assumed to be 1)
- group_count: the number of groups, assumed to be known at compile time
The first difference from the usual GEMM is the setup of the tensormaps in shared memory:
tensormap_smem_ptr = storage.tensormap_buffer.data_ptr()
tensormap_a_smem_ptr = tensormap_smem_ptr
tensormap_b_smem_ptr = (
    tensormap_a_smem_ptr
    + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8
)
tensormap_sfa_smem_ptr = (
    tensormap_b_smem_ptr
    + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8
)
tensormap_sfb_smem_ptr = (
    tensormap_sfa_smem_ptr
    + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8
)
tensormap_c_smem_ptr = (
    tensormap_sfb_smem_ptr
    + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8
)
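Note that the tensormaps occupy consecutive regions of the SMEM buffer. The buffer holds 8-byte (Int64) elements, so advancing the pointer by bytes_per_tensormap // 8 elements skips exactly one tensormap.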
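As a sanity check of this pointer arithmetic, here is a plain-Python sketch (no CuTe involved) that reproduces the offsets. It assumes bytes_per_tensormap = 128, the size of a CUDA TMA descriptor (CUtensorMap), and an Int64 element type for tensormap_buffer; the names TENSORMAP_ORDER and smem_offsets are hypothetical.

```python
# Plain-Python sketch of the SMEM tensormap offsets (not CuTe code).
# Assumptions: 128-byte TMA descriptors, buffer element type Int64 (8 bytes).
BYTES_PER_TENSORMAP = 128
INT64_BYTES = 8
TENSORMAP_ORDER = ("a", "b", "sfa", "sfb", "c")  # order used in the kernel


def smem_offsets(bytes_per_tensormap: int = BYTES_PER_TENSORMAP) -> dict:
    """Return (Int64 element offset, byte offset) of each tensormap."""
    step = bytes_per_tensormap // INT64_BYTES  # pointer advance in elements
    offsets = {}
    elem_offset = 0
    for name in TENSORMAP_ORDER:
        offsets[name] = (elem_offset, elem_offset * INT64_BYTES)
        elem_offset += step  # the next tensormap starts right after this one
    return offsets


offsets = smem_offsets()
for name, (elems, nbytes) in offsets.items():
    print(f"tensormap_{name}_smem_ptr = base + {elems} Int64 elements ({nbytes} bytes)")
```

Advancing an Int64 pointer by 16 elements moves it by 128 bytes, i.e. exactly one descriptor, which is why the kernel divides bytes_per_tensormap by 8.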
Afterwards we set up the pipelines and the multicast masks and perform the block- and thread-level partitioning as usual. The partitioning is performed on the initial tensors we created on the host. In the official example these are simply the smallest matrices across all groups, but we could just as well choose a small fixed problem shape, initialise the tensors with it and associate these layouts with null pointers.
grid_dim = cute.arch.grid_dim()
tensormap_workspace_idx = (
    bidz * grid_dim[1] * grid_dim[0] + bidy * grid_dim[0] + bidx
)
tensormap_manager = utils.TensorMapManager(
    utils.TensorMapUpdateMode.SMEM,
    Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap,
)
tensormap_a_gmem_ptr = tensormap_manager.get_tensormap_ptr(
    tensormaps[(tensormap_workspace_idx, 0, None)].iterator
)
tensormap_b_gmem_ptr = tensormap_manager.get_tensormap_ptr(
    tensormaps[(tensormap_workspace_idx, 1, None)].iterator
)
tensormap_sfa_gmem_ptr = tensormap_manager.get_tensormap_ptr(
    tensormaps[(tensormap_workspace_idx, 2, None)].iterator
)
tensormap_sfb_gmem_ptr = tensormap_manager.get_tensormap_ptr(
    tensormaps[(tensormap_workspace_idx, 3, None)].iterator
)
tensormap_c_gmem_ptr = tensormap_manager.get_tensormap_ptr(
    tensormaps[(tensormap_workspace_idx, 4, None)].iterator
)
Here we linearize the CTA's position in the grid. Because the kernel is persistent, each CTA in the grid corresponds one-to-one to an SM, and this index selects the first mode of the tensormaps tensor, as we saw in a previous blog post. The second mode selects the tensormap: there are 5 of them, one each for A, B, SFA, SFB and C (indices 0 to 4, in that order). The last mode has extent bytes_per_tensormap // 8, i.e. the bytes per tensormap divided by 8, because the elements are Int64; CuTe accounts for the element size when it offsets the pointer through the layout.
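The indexing can be mimicked on the host. The following plain-Python sketch linearizes the CTA coordinate exactly as the kernel does and computes where a given tensormap lives in the GMEM workspace. It assumes a compact row-major tensormaps tensor of shape (num_ctas, 5, bytes_per_tensormap // 8) with 128-byte tensormaps; the grid shape and the function names are hypothetical.

```python
BYTES_PER_TENSORMAP = 128  # assumed size of one TMA descriptor
NUM_TENSORMAPS = 5         # A, B, SFA, SFB, C -> indices 0..4
INT64_BYTES = 8


def workspace_idx(bidx: int, bidy: int, bidz: int, grid_dim: tuple) -> int:
    """Linearize the CTA coordinate like the kernel does (x fastest)."""
    return bidz * grid_dim[1] * grid_dim[0] + bidy * grid_dim[0] + bidx


def tensormap_byte_offset(cta_linear_idx: int, tensormap_idx: int) -> int:
    """Byte offset of tensormaps[(cta, tensormap_idx, 0)] in a compact
    row-major (num_ctas, 5, 16) Int64 tensor (assumed layout)."""
    elems_per_map = BYTES_PER_TENSORMAP // INT64_BYTES  # 16 Int64 elements
    elem_offset = (cta_linear_idx * NUM_TENSORMAPS + tensormap_idx) * elems_per_map
    return elem_offset * INT64_BYTES  # CuTe applies this scaling for us


grid = (2, 1, 74)  # hypothetical persistent grid with 148 CTAs
idx = workspace_idx(1, 0, 3, grid)  # 3 * 1 * 2 + 0 * 2 + 1 = 7
print(idx, tensormap_byte_offset(idx, 4))  # offset of the C tensormap of CTA 7
```

Because every CTA owns a private slice of the workspace, CTAs can rewrite their own tensormaps without coordinating with each other.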
TMA
In the TMA warp we find the following code:
if warp_idx == self.tma_warp_id:
    #
    # Persistent tile scheduling loop
    #
    tile_sched = utils.StaticPersistentTileScheduler.create(
        tile_sched_params, cute.arch.block_idx(), grid_dim
    )
    # grouped gemm tile scheduler helper will compute the group index for the tile we're working on
    group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper(
        group_count,
        tile_sched_params,
        self.cluster_tile_shape_mnk,
        utils.create_initial_search_state(),
    )
    tensormap_init_done = cutlass.Boolean(False)
    # group index of last tile
    last_group_idx = cutlass.Int32(-1)
    work_tile = tile_sched.initial_work_tile_info()
    ab_producer_state = pipeline.make_pipeline_state(
        pipeline.PipelineUserType.Producer, self.num_ab_stage
    )
    while work_tile.is_valid_tile:
        cur_tile_coord = work_tile.tile_idx
        grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z(
            cur_tile_coord,
            problem_sizes_mnkl,
        )
        cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
        cur_group_idx = grouped_gemm_cta_tile_info.group_idx
        is_group_changed = cur_group_idx != last_group_idx
        # skip tensormap update if we're working on the same group
        if is_group_changed:
            real_tensor_a = self.make_tensor_abc_for_tensormap_update(
                cur_group_idx,
                self.a_dtype,
                (
                    grouped_gemm_cta_tile_info.problem_shape_m,
                    grouped_gemm_cta_tile_info.problem_shape_n,
                    grouped_gemm_cta_tile_info.problem_shape_k,
                ),
                strides_abc,
                ptrs_abc,
                0,  # 0 for tensor A
            )
            real_tensor_b = self.make_tensor_abc_for_tensormap_update(
                cur_group_idx,
                self.b_dtype,
                (
                    grouped_gemm_cta_tile_info.problem_shape_m,
                    grouped_gemm_cta_tile_info.problem_shape_n,
                    grouped_gemm_cta_tile_info.problem_shape_k,
                ),
                strides_abc,
                ptrs_abc,
                1,  # 1 for tensor B
            )
            real_tensor_sfa = self.make_tensor_sfasfb_for_tensormap_update(
                cur_group_idx,
                self.sf_dtype,
                (
                    grouped_gemm_cta_tile_info.problem_shape_m,
                    grouped_gemm_cta_tile_info.problem_shape_n,
                    grouped_gemm_cta_tile_info.problem_shape_k,
                ),
                ptrs_sfasfb,
                0,  # 0 for tensor SFA
            )
            real_tensor_sfb = self.make_tensor_sfasfb_for_tensormap_update(
                cur_group_idx,
                self.sf_dtype,
                (
                    grouped_gemm_cta_tile_info.problem_shape_m,
                    grouped_gemm_cta_tile_info.problem_shape_n,
                    grouped_gemm_cta_tile_info.problem_shape_k,
                ),
                ptrs_sfasfb,
                1,  # 1 for tensor SFB
            )
            if tensormap_init_done == False:
                # wait tensormap initialization complete
                self.tensormap_ab_init_barrier.arrive_and_wait()
                tensormap_init_done = True
            tensormap_manager.update_tensormap(
                (
                    real_tensor_a,
                    real_tensor_b,
                    real_tensor_sfa,
                    real_tensor_sfb,
                ),
                (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb),
                (
                    tensormap_a_gmem_ptr,
                    tensormap_b_gmem_ptr,
                    tensormap_sfa_gmem_ptr,
                    tensormap_sfb_gmem_ptr,
                ),
                self.tma_warp_id,
                (
                    tensormap_a_smem_ptr,
                    tensormap_b_smem_ptr,
                    tensormap_sfa_smem_ptr,
                    tensormap_sfb_smem_ptr,
                ),
            )
Let's unpack this step by step:
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
    tile_sched_params, cute.arch.block_idx(), grid_dim
)
# grouped gemm tile scheduler helper will compute the group index for the tile we're working on
group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper(
    group_count,
    tile_sched_params,
    self.cluster_tile_shape_mnk,
    utils.create_initial_search_state(),
)
tensormap_init_done = cutlass.Boolean(False)
# group index of last tile
last_group_idx = cutlass.Int32(-1)
work_tile = tile_sched.initial_work_tile_info()
ab_producer_state = pipeline.make_pipeline_state(
    pipeline.PipelineUserType.Producer, self.num_ab_stage
)
Before looping over the tiles assigned to the current SM, we initialise a GroupedGemmTileSchedulerHelper:
class GroupedGemmTileSchedulerHelper:
    """
    A helper to translate the raw block index (x, y, z) from tile scheduler to real CTA tile index for grouped gemm.
    :param group_count: Number of groups in current grouped gemm problem
    :type group_count: int
    :param tile_sched_params: Parameter used to create the tile scheduler this helper works with
    :type tile_sched_params: PersistentTileSchedulerParams
    :param cluster_tile_shape_mnk: The shape of cluster tile as (m, n, k)
    :type cluster_tile_shape_mnk: tuple[int, int, int]
    :param search_state: The initial search state
    :type search_state: GroupedGemmGroupSearchState
    """
    def __init__(
        self,
        group_count: int,
        tile_sched_params: PersistentTileSchedulerParams,
        cluster_tile_shape_mnk: tuple[int, int, int],
        search_state: GroupedGemmGroupSearchState,
    ) -> None:
        self.tile_sched_params = tile_sched_params
        self.group_count = group_count
        self.lane_idx = cute.arch.lane_idx()
        self.cluster_tile_shape_mnk = cluster_tile_shape_mnk
        self.search_state = search_state
where the initial search state we pass in is created by
def create_initial_search_state() -> GroupedGemmGroupSearchState:
    """
    Create an initial search state for grouped gemm.
    :return: A new search state with initial values
    :rtype: GroupedGemmGroupSearchState
    """
    return GroupedGemmGroupSearchState(
        start_group_idx=Int32(0),
        tile_count_prev_group=Int32(0),
        tile_count_searched=Int32(0),
    )
We then loop over the tiles and do the following:
cur_tile_coord = work_tile.tile_idx
grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z(
    cur_tile_coord,
    problem_sizes_mnkl,
)
cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k
cur_group_idx = grouped_gemm_cta_tile_info.group_idx
is_group_changed = cur_group_idx != last_group_idx
delinearize_z will return GroupedWorkTileInfo(cta_tile_coord, is_valid, group_search_result).
# Grouped Work Tile Information
class GroupedWorkTileInfo(WorkTileInfo):
    """A class to represent information about a work tile.
    :ivar tile_idx: The index of the tile.
    :type tile_idx: cute.Coord
    :ivar is_valid_tile: Whether the tile is valid.
    :type is_valid_tile: Boolean
    :ivar group_search_result: Group work tile information.
    :type group_search_result: GroupSearchResult
    """
From this object we then extract the number of K tiles and the group index of the current tile, and we flag whether the group index differs from that of the previous tile. Each group has its own tensors (base pointers, shapes and strides) and therefore needs its own tensormaps, so a group change requires a tensormap update, while consecutive tiles of the same group can reuse the current ones.
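To build intuition for what the helper computes, here is a sequential plain-Python sketch of the group search. It is a simplification under stated assumptions, not the CUTLASS implementation: the real GroupedGemmTileSchedulerHelper keeps a GroupedGemmGroupSearchState, presumably also uses the lanes of a warp (it stores self.lane_idx), and additionally returns the tile's M/N coordinates and K tile count. The sketch only maps a linear cluster-tile index to a group, resuming the search from the previously found group because a static persistent schedule visits tile indices in increasing order, and then derives is_group_changed like the kernel. SearchState and find_group are hypothetical names.

```python
from dataclasses import dataclass


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


@dataclass
class SearchState:
    # Hypothetical sequential stand-in for GroupedGemmGroupSearchState.
    start_group_idx: int = 0        # group where the next search starts
    tile_count_prev_group: int = 0  # tiles in all groups before start_group_idx


def find_group(linear_tile_idx, problem_sizes_mnk, cluster_tile_mn, state):
    """Map a linear cluster-tile index to its group index.

    Assumes indices are visited in increasing order, so the search resumes
    where the previous one stopped instead of starting at group 0.
    """
    group = state.start_group_idx
    tiles_before = state.tile_count_prev_group
    while group < len(problem_sizes_mnk):
        m, n, _ = problem_sizes_mnk[group]
        tiles_in_group = ceil_div(m, cluster_tile_mn[0]) * ceil_div(n, cluster_tile_mn[1])
        if linear_tile_idx < tiles_before + tiles_in_group:
            state.start_group_idx = group
            state.tile_count_prev_group = tiles_before
            return group
        tiles_before += tiles_in_group
        group += 1
    raise IndexError("tile index lies beyond the last group")


problems = [(256, 256, 512), (128, 512, 1024), (384, 128, 256)]  # 4 + 4 + 3 tiles
state = SearchState()
last_group_idx = -1
trace = []
for tile_idx in range(0, 11, 2):  # tiles of one hypothetical persistent CTA
    group_idx = find_group(tile_idx, problems, (128, 128), state)
    is_group_changed = group_idx != last_group_idx  # -> tensormap update needed
    trace.append((tile_idx, group_idx, is_group_changed))
    last_group_idx = group_idx
print(trace)
```

In this example only three of the six tiles trigger a tensormap update; the other three reuse the tensormaps of the previous tile.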
# skip tensormap update if we're working on the same group
if is_group_changed:
    real_tensor_a = self.make_tensor_abc_for_tensormap_update(
        cur_group_idx,
        self.a_dtype,
        (
            grouped_gemm_cta_tile_info.problem_shape_m,
            grouped_gemm_cta_tile_info.problem_shape_n,
            grouped_gemm_cta_tile_info.problem_shape_k,
        ),
        strides_abc,
        ptrs_abc,
        0,  # 0 for tensor A
    )
    real_tensor_b = self.make_tensor_abc_for_tensormap_update(
        cur_group_idx,
        self.b_dtype,
        (
            grouped_gemm_cta_tile_info.problem_shape_m,
            grouped_gemm_cta_tile_info.problem_shape_n,
            grouped_gemm_cta_tile_info.problem_shape_k,
        ),
        strides_abc,
        ptrs_abc,
        1,  # 1 for tensor B
    )
    real_tensor_sfa = self.make_tensor_sfasfb_for_tensormap_update(
        cur_group_idx,
        self.sf_dtype,
        (
            grouped_gemm_cta_tile_info.problem_shape_m,
            grouped_gemm_cta_tile_info.problem_shape_n,
            grouped_gemm_cta_tile_info.problem_shape_k,
        ),
        ptrs_sfasfb,
        0,  # 0 for tensor SFA
    )
    real_tensor_sfb = self.make_tensor_sfasfb_for_tensormap_update(
        cur_group_idx,
        self.sf_dtype,
        (
            grouped_gemm_cta_tile_info.problem_shape_m,
            grouped_gemm_cta_tile_info.problem_shape_n,
            grouped_gemm_cta_tile_info.problem_shape_k,
        ),
        ptrs_sfasfb,
        1,  # 1 for tensor SFB
    )
    if tensormap_init_done == False:
        # wait tensormap initialization complete
        self.tensormap_ab_init_barrier.arrive_and_wait()
        tensormap_init_done = True
    tensormap_manager.update_tensormap(
        (
            real_tensor_a,
            real_tensor_b,
            real_tensor_sfa,
            real_tensor_sfb,
        ),
        (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb),
        (
            tensormap_a_gmem_ptr,
            tensormap_b_gmem_ptr,
            tensormap_sfa_gmem_ptr,
            tensormap_sfb_gmem_ptr,
        ),
        self.tma_warp_id,
        (
            tensormap_a_smem_ptr,
            tensormap_b_smem_ptr,
            tensormap_sfa_smem_ptr,
            tensormap_sfb_smem_ptr,
        ),
    )
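This is where the update happens. The tensors of pointers (ptrs_abc and ptrs_sfasfb) give convenient access to the data pointers of each group's distinct PyTorch tensors, and strides_abc provides the matching strides. On the first group change, the TMA warp waits on tensormap_ab_init_barrier until the tensormaps have been initialized (by the MMA warp, as we will see below); afterwards tensormap_manager.update_tensormap rewrites them for the new group, receiving both the GMEM and the SMEM pointer of each tensormap.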
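The meaning of the tensor index argument (0 for A, 1 for B, 2 for C) can be illustrated with a small plain-Python sketch that gathers the pointer, shape and stride a per-group tensor needs. The mode orders are inferred from the kernel's mA_mkl / mB_nkl / mC_mnl names; the (group_count, 3, 2) layout of strides_abc and all values are assumptions for illustration, and describe_tensor_for_update is a hypothetical name. The real make_tensor_abc_for_tensormap_update builds a CuTe tensor rather than a dict.

```python
def describe_tensor_for_update(group_idx, problem_mnk, strides_abc, ptrs_abc, tensor_index):
    """Hypothetical plain-Python mirror of what a per-group tensor needs.

    Assumed mode orders: A -> (M, K, L), B -> (N, K, L), C -> (M, N, L), L = 1.
    """
    m, n, k = problem_mnk
    shapes = {0: (m, k, 1), 1: (n, k, 1), 2: (m, n, 1)}
    stride_0, stride_1 = strides_abc[group_idx][tensor_index]
    return {
        "ptr": ptrs_abc[group_idx][tensor_index],
        "shape": shapes[tensor_index],
        # The L mode has extent 1, so its stride never affects an address.
        "stride": (stride_0, stride_1, 0),
    }


# Illustrative tables for two groups (addresses are made up).
ptrs_abc = [
    [0x7F0000000000, 0x7F0000100000, 0x7F0000200000],
    [0x7F0000300000, 0x7F0000400000, 0x7F0000500000],
]
# K-major A (M, K), K-major B (N, K), row-major C (M, N): (leading stride, 1).
strides_abc = [
    [(512, 1), (512, 1), (256, 1)],     # group 0: M=256, N=256, K=512
    [(1024, 1), (1024, 1), (512, 1)],   # group 1: M=128, N=512, K=1024
]
info = describe_tensor_for_update(1, (128, 512, 1024), strides_abc, ptrs_abc, 2)
print(info)
```

Only the pointer, shape and stride change between groups; the data type and the TMA copy atom stay the same, which is what makes in-place tensormap updates sufficient.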
Before continuing with the next tile, we save the current group index:
last_group_idx = cur_group_idx
MMA
#
# Initialize tensormaps for A, B, SFA and SFB
#
tensormap_manager.init_tensormap_from_atom(
    tma_atom_a, tensormap_a_smem_ptr, self.mma_warp_id
)
tensormap_manager.init_tensormap_from_atom(
    tma_atom_b, tensormap_b_smem_ptr, self.mma_warp_id
)
tensormap_manager.init_tensormap_from_atom(
    tma_atom_sfa, tensormap_sfa_smem_ptr, self.mma_warp_id
)
tensormap_manager.init_tensormap_from_atom(
    tma_atom_sfb, tensormap_sfb_smem_ptr, self.mma_warp_id
)
# indicate tensormap initialization has finished
self.tensormap_ab_init_barrier.arrive_and_wait()
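The MMA warp initializes the tensormaps through TensorMapManager.init_tensormap_from_atom, a thin wrapper in which a single elected thread of the given warp copies the tensormap held by the copy atom to the destination: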
# init tensormap pointed by dst_ptr with the one inside copy_atom.
# dst_ptr should be pointing to a global memory location or a smem location
# warp_id specifies which warp to perform the initialization
@dsl_user_op
@cute.jit
def init_tensormap_from_atom(
    self,
    copy_atom: cute.CopyAtom,
    dst_ptr: cute.Pointer,
    warp_id: int,
    *,
    loc=None,
    ip=None,
) -> None:
    warp_idx = cute.arch.warp_idx(loc=loc, ip=ip)
    warp_idx = cute.arch.make_warp_uniform(warp_idx, loc=loc, ip=ip)
    if warp_idx == warp_id:
        with cute.arch.elect_one(loc=loc, ip=ip):
            cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr, loc=loc, ip=ip)
    cute.arch.sync_warp(loc=loc, ip=ip)
    return
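Note that these SMEM tensormaps are the same ones the TMA warp updates whenever the group changes; the barrier guarantees that this initialization has completed before the TMA warp performs its first update.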
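The ordering between the MMA and TMA warps can be mimicked with Python threads. In this sketch threading.Barrier(2) stands in for tensormap_ab_init_barrier and each thread plays one warp. It only models the ordering (initialize once, wait before the first update, update only on group changes), not the hardware barrier or the TMA descriptors themselves; the event strings and the per-tile group list are made up.

```python
import threading

log = []
log_lock = threading.Lock()
init_barrier = threading.Barrier(2)  # stand-in for tensormap_ab_init_barrier


def record(event: str) -> None:
    with log_lock:
        log.append(event)


def mma_warp() -> None:
    # MMA warp: initialize the SMEM tensormaps once, then arrive and wait.
    record("mma: init A/B/SFA/SFB tensormaps")
    init_barrier.wait()


def tma_warp(group_of_each_tile) -> None:
    tensormap_init_done = False
    last_group_idx = -1
    for group_idx in group_of_each_tile:
        if group_idx != last_group_idx:
            if not tensormap_init_done:
                init_barrier.wait()  # only the first update has to wait
                tensormap_init_done = True
            record(f"tma: update tensormaps for group {group_idx}")
        record(f"tma: load tile of group {group_idx}")
        last_group_idx = group_idx


threads = [
    threading.Thread(target=mma_warp),
    threading.Thread(target=tma_warp, args=([0, 0, 1, 2, 2],)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("\n".join(log))
```

Whatever the thread interleaving, the initialization is always logged before the first update, and the five tiles cause only three updates.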
Afterwards we perform our usual setup steps for the MMA.
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
    tile_sched_params, cute.arch.block_idx(), grid_dim
)
# grouped gemm tile scheduler helper will compute the group index for the tile we're working on
group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper(
    group_count,
    tile_sched_params,
    self.cluster_tile_shape_mnk,
    utils.create_initial_search_state(),
)
As in the TMA warp, we create a helper that maps the tile scheduler's linear tile index to the group the tile belongs to.
cur_tile_coord = work_tile.tile_idx
# MMA warp is only interested in number of tiles along K dimension
(
    cur_k_tile_cnt,
    cur_group_idx,
) = group_gemm_ts_helper.search_cluster_tile_count_k(
    cur_tile_coord,
    problem_sizes_mnkl,
)
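Within the loop over all tiles we obtain the number of K tiles the MMA loop has to iterate over. This matters because different groups can have different K extents and therefore a different number of K tiles. The MMA warp does not touch the tensormaps inside this loop, so it only needs search_cluster_tile_count_k rather than the full delinearize_z.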
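A minimal sketch of why the MMA loop bound changes per group, assuming the K tile count is the K extent divided by the cluster tile's K size and rounded up (the problem sizes, tile size and tile-to-group assignment below are made up):

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def k_tile_counts(problem_sizes_mnk, cluster_tile_k):
    """Number of K tiles the MMA loop runs for a tile of each group."""
    return [ceil_div(k, cluster_tile_k) for (_, _, k) in problem_sizes_mnk]


problems = [(256, 256, 512), (128, 512, 1024), (384, 128, 256)]
counts = k_tile_counts(problems, 256)
# Tiles processed by one hypothetical persistent CTA, given by group index:
tile_groups = [0, 1, 1, 2]
total_mma_k_iterations = sum(counts[g] for g in tile_groups)
print(counts, total_mma_k_iterations)
```

A fixed trip count would either skip K tiles of the larger groups or wait for K tiles that the TMA warp never loads for the smaller ones, so the bound has to be looked up per tile.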
Epilogue
In the epilogue the pattern is similar to the one in the TMA warp. First, the first epilogue warp initializes the tensormap for C:
# initialize tensormap for C
tensormap_manager.init_tensormap_from_atom(
    tma_atom_c,
    tensormap_c_smem_ptr,
    self.epilog_warp_id[0],
)
Again we create the helper before iterating over the tiles.
#
# Persistent tile scheduling loop
#
tile_sched = utils.StaticPersistentTileScheduler.create(
    tile_sched_params, cute.arch.block_idx(), grid_dim
)
# grouped gemm tile scheduler helper will compute the group index for the tile we're working on
group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper(
    group_count,
    tile_sched_params,
    self.cluster_tile_shape_mnk,
    utils.create_initial_search_state(),
)
Within the tile loop we apply the same update rule as in the TMA warp, except that here only the tensormap for C has to be updated rather than those for A/B/SFA/SFB.
cur_tile_coord = work_tile.tile_idx
grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z(
    cur_tile_coord,
    problem_sizes_mnkl,
)
cur_group_idx = grouped_gemm_cta_tile_info.group_idx
is_group_changed = cur_group_idx != last_group_idx
if is_group_changed:
    # construct tensor c based on real shape, stride information
    real_tensor_c = self.make_tensor_abc_for_tensormap_update(
        cur_group_idx,
        self.c_dtype,
        (
            grouped_gemm_cta_tile_info.problem_shape_m,
            grouped_gemm_cta_tile_info.problem_shape_n,
            grouped_gemm_cta_tile_info.problem_shape_k,
        ),
        strides_abc,
        ptrs_abc,
        2,  # 2 for tensor C
    )
    tensormap_manager.update_tensormap(
        (real_tensor_c,),
        (tma_atom_c,),
        (tensormap_c_gmem_ptr,),
        self.epilog_warp_id[0],
        (tensormap_c_smem_ptr,),
    )
If the group has changed, the first epilogue warp also fences the update of the GMEM tensormap for C, so that the subsequent subtiled TMEM -> RMEM -> SMEM -> GMEM store uses the updated tensormap:
if is_group_changed:
    if warp_idx == self.epilog_warp_id[0]:
        tensormap_manager.fence_tensormap_update(tensormap_c_gmem_ptr)
At the end of each loop iteration we update the last group index:
last_group_idx = cur_group_idx
Conclusion
I hope this blog post helped explain how to implement a grouped blockscaled GEMM with CuTeDSL. Please contact me on LinkedIn if you want to exchange ideas.