Grouped Blockscaled GEMM - Intro
This is the first blog post in a series about Grouped Blockscaled GEMM on Blackwell; the code can be found here.
In plain English, that means we compute a group of GEMMs, i.e. multiple blockscaled GEMMs with potentially distinct problem sizes. Before we study the kernel itself, it is a useful exercise to understand the setup inside the `run` function, because it tells us which parameters the computation expects. Here we will focus on that part before turning to the more interesting aspects of the example.
run
# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
l: int,
mode0: int,
mode1: int,
is_mode0_major: bool,
dtype: type[cutlass.Numeric],
is_dynamic_layout: bool = True,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
"""Create GPU tensor from either a new or existing CPU tensor.
:param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one.
:type torch_tensor_cpu: torch.Tensor, optional
"""
# Create new CPU tensor
torch_tensor_cpu = cutlass_torch.matrix(
l,
mode0,
mode1,
is_mode0_major,
cutlass.Float32,
)
# Create GPU tensor from CPU tensor (new or existing)
cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
)
# Mark tensor with element divisibility for 16B alignment
cute_tensor.mark_compact_shape_dynamic(
mode=0 if is_mode0_major else 1,
stride_order=(2, 1, 0) if is_mode0_major else (2, 0, 1),
divisibility=32 if dtype == cutlass.Float4E2M1FN else 16,
)
# omit stride for L mode as it is always 1
stride = (1, mode0) if is_mode0_major else (mode1, 1)
return (
torch_tensor.data_ptr(),
torch_tensor,
cute_tensor,
torch_tensor_cpu,
stride,
)
This is a very simple function that creates a tensor with the given shape, stride and divisibility for alignment (a short usage sketch follows below). It returns:
- Pointer to the torch tensor
- The torch tensor
- The cute tensor
- The torch tensor on the CPU (in `Float32`)
- The stride, which simply indicates whether we have an M- or K-major tensor
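As a small usage sketch (the shapes, batch size and dtype below are example values, not fixed by the example):

```python
# Hypothetical call: an M-major A operand of shape (m, k) = (256, 256) with l = 1.
ptr_a, torch_a, cute_a, ref_a_fp32, stride_a = create_tensor_and_stride(
    l=1,                  # batch (L) dimension
    mode0=256,            # M
    mode1=256,            # K
    is_mode0_major=True,  # A is M-major
    dtype=cutlass.Float4E2M1FN,
)
# stride_a == (1, 256): the M mode is contiguous, the K mode strides by mode0.
```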
def create_tensors_abc_for_all_groups(
problem_sizes_mnkl: List[tuple[int, int, int, int]],
ab_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
a_major: str,
b_major: str,
c_major: str,
) -> tuple[
List[List[int]],
List[List[torch.Tensor]],
List[tuple],
List[List[tuple]],
List[List[torch.Tensor]],
]:
ref_torch_fp32_tensors_abc = []
torch_tensors_abc = []
cute_tensors_abc = []
strides_abc = []
ptrs_abc = []
# Iterate through all groups and create tensors for each group
for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl):
# Create tensors A, B, C
(
ptr_a,
torch_tensor_a,
cute_tensor_a,
ref_torch_fp32_tensor_a,
stride_mk_a,
) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)
(
ptr_b,
torch_tensor_b,
cute_tensor_b,
ref_torch_fp32_tensor_b,
stride_nk_b,
) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)
(
ptr_c,
torch_tensor_c,
cute_tensor_c,
ref_torch_fp32_tensor_c,
stride_mn_c,
) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)
ref_torch_fp32_tensors_abc.append(
[ref_torch_fp32_tensor_a, ref_torch_fp32_tensor_b, ref_torch_fp32_tensor_c]
)
ptrs_abc.append([ptr_a, ptr_b, ptr_c])
torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
cute_tensors_abc.append(
(
cute_tensor_a,
cute_tensor_b,
cute_tensor_c,
)
)
return (
ptrs_abc,
torch_tensors_abc,
cute_tensors_abc,
strides_abc,
ref_torch_fp32_tensors_abc,
)
Essentially we end up with nested lists because we have multiple groups: the length of each outer list equals the number of groups, and each entry contains three elements, one each for `a`, `b` and `c`. Concretely, the function returns (see the small indexing sketch after the list):
- List of pointers for `(a, b, c)` for each group
- List of corresponding torch tensors for `(a, b, c)` for each group
- List of cute tensors for `(a, b, c)` for each group
- List of strides for `(a, b, c)` for each group
- List of torch tensors on CPU in `Float32` data format for `(a, b, c)` for each group
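A quick sketch of how this nested structure is indexed (purely illustrative; the names follow the snippet above):

```python
# Each outer list has one entry per group; each entry holds the A, B, C items.
num_groups = len(ptrs_abc)  # == len(problem_sizes_mnkl)
for group_idx in range(num_groups):
    ptr_a, ptr_b, ptr_c = ptrs_abc[group_idx]              # raw device pointers
    stride_a, stride_b, stride_c = strides_abc[group_idx]  # per-tensor (mode0, mode1) strides
    cute_a, cute_b, cute_c = cute_tensors_abc[group_idx]   # cute tensors for this group
```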
The same is done for the scale factors for `(a, b)` for each group:
def create_tensors_sfasfb_for_all_groups(
problem_sizes_mnkl: List[tuple[int, int, int, int]],
sf_dtype: Type[cutlass.Numeric],
sf_vec_size: int,
) -> tuple[
List[List[int]],
List[List[torch.Tensor]],
List[tuple],
List[List[torch.Tensor]],
]:
ptrs_sfasfb = []
torch_tensors_sfasfb = []
cute_tensors_sfasfb = []
refs_sfasfb = []
# Iterate through all groups and create tensors for each group
for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl):
sfa_ref, ptr_sfa, sfa_tensor, sfa_torch = create_scale_factor_tensor(
l, m, k, sf_vec_size, sf_dtype
)
sfb_ref, ptr_sfb, sfb_tensor, sfb_torch = create_scale_factor_tensor(
l, n, k, sf_vec_size, sf_dtype
)
ptrs_sfasfb.append([ptr_sfa, ptr_sfb])
torch_tensors_sfasfb.append([sfa_torch, sfb_torch])
cute_tensors_sfasfb.append(
(
sfa_tensor,
sfb_tensor,
)
)
refs_sfasfb.append([sfa_ref, sfb_ref])
return (
ptrs_sfasfb,
torch_tensors_sfasfb,
cute_tensors_sfasfb,
refs_sfasfb,
)
See the example for the helper functions which I will omit here.
# Create tensors A, B, C for all groups
(
ptrs_abc,
torch_tensors_abc,
cute_tensors_abc,
strides_abc,
ref_f32_torch_tensors_abc,
) = create_tensors_abc_for_all_groups(
problem_sizes_mnkl,
ab_dtype,
c_dtype,
a_major,
b_major,
c_major,
)
# Create tensors SFA, SFB for all groups
(
ptrs_sfasfb,
torch_tensors_sfasfb,
cute_tensors_sfasfb,
refs_f32_torch_tensors_sfasfb,
) = create_tensors_sfasfb_for_all_groups(
problem_sizes_mnkl,
sf_dtype,
sf_vec_size,
)
# Choose A, B, C, SFA, SFB with the smallest size to create initial tensormaps
key_size_a = lambda item: item[1][0] * item[1][2]
key_size_b = lambda item: item[1][1] * item[1][2]
key_size_c = lambda item: item[1][0] * item[1][1]
# Find the indices of the groups with the smallest tensor sizes
min_a_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_a)
min_b_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_b)
min_c_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_c)
initial_cute_tensors_abc = [
cute_tensors_abc[min_a_idx][0], # A with smallest (m, k)
cute_tensors_abc[min_b_idx][1], # B with smallest (n, k)
cute_tensors_abc[min_c_idx][2], # C with smallest (m, n)
]
initial_cute_tensors_sfasfb = [
cute_tensors_sfasfb[min_a_idx][0], # SFA with smallest (m, k)'s group
cute_tensors_sfasfb[min_b_idx][1], # SFB with smallest (n, k)'s group
]
Here we use convenience lambdas to set up the initial tensormaps.
The `run` signature tells us that `problem_sizes_mnkl` is a `List[Tuple[int, int, int, int]]`, so it contains the problem size for each group.
Assume we have two groups, i.e. two GEMMs, with problem sizes:
- 256 x 256 x 256 x 1
- 512 x 512 x 512 x 1
Then `problem_sizes_mnkl` is `[(256, 256, 256, 1), (512, 512, 512, 1)]`. In the lambdas we always index `item[1][X]` because we apply them to `enumerate`, i.e. `item[0]` is just the group index (which is what we keep from `min`) and `item[1]` is the `(m, n, k, l)` tuple. With that understood the rest is clear:

- `item[1][0] * item[1][2]` is `M * K`, so `min` picks the smallest "area" of matrix A across the groups
- `item[1][1] * item[1][2]` is `N * K`, so `min` picks the smallest "area" of matrix B across the groups
- `item[1][0] * item[1][1]` is `M * N`, so `min` picks the smallest "area" of matrix C across the groups

So we find the smallest areas separately for `A`, `B` and `C` across all groups and use the corresponding tensors as `initial_cute_tensors_abc`. The scale factors are chosen correspondingly to `A`/`B`.
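To make the selection concrete, here is a tiny standalone sketch using the two example problem sizes from above:

```python
problem_sizes_mnkl = [(256, 256, 256, 1), (512, 512, 512, 1)]

# item is (group_idx, (m, n, k, l)), so item[1] is the problem size tuple.
key_size_a = lambda item: item[1][0] * item[1][2]  # M * K, the "area" of A

min_a_idx, min_a_sizes = min(enumerate(problem_sizes_mnkl), key=key_size_a)
print(min_a_idx, min_a_sizes)  # 0 (256, 256, 256, 1) -> group 0 has the smallest A
```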
hardware_info = cutlass.utils.HardwareInfo()
sm_count = hardware_info.get_max_active_clusters(1)
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1]
)
# Prepare tensormap buffer for each SM
num_tensormap_buffers = sm_count
tensormap_shape = (
num_tensormap_buffers,
Sm100GroupedBlockScaledGemmKernel.num_tensormaps,
Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8,
)
tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like(
torch.empty(tensormap_shape, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
)
Note that:
bytes_per_tensormap = 128
num_tensormaps = 5
are hardcoded and independent of our problem config. The `sm_count` on B200 is 148, and thus the tensormap buffer has shape `(148, 5, 16)`. Note that we divide `bytes_per_tensormap` by 8, the number of bytes in an `int64`, so each tensormap occupies 16 `int64` elements.
The `tensor_of_tensormap` will have the layout `(148, 5, 16):(80, 16, 1)`. The stride of 16 `int64` elements between tensormaps means that in the underlying pointer each tensormap is offset from the previous one by exactly the number of bytes in one tensormap. The layout can be thought of as multiple planes: each plane corresponds to one SM, and within each plane the `num_tensormaps` tensormaps are laid out one after another.
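A small arithmetic sketch of the buffer sizing, assuming the B200 SM count of 148 mentioned above:

```python
bytes_per_tensormap = 128  # hardcoded in the kernel
num_tensormaps = 5         # hardcoded in the kernel
sm_count = 148             # B200

tensormap_shape = (sm_count, num_tensormaps, bytes_per_tensormap // 8)
print(tensormap_shape)  # (148, 5, 16) int64 elements

# Offset between consecutive tensormaps of one SM, in bytes:
print((bytes_per_tensormap // 8) * 8)  # 128 bytes, exactly one tensormap
```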
grouped_blockscaled_gemm = Sm100GroupedBlockScaledGemmKernel(
sf_vec_size,
mma_tiler_mn,
cluster_shape_mn,
)
# layout (num_groups, 4):(4, 1)
(
tensor_of_dim_size_mnkl,
tensor_of_dim_size_mnkl_torch,
) = cutlass_torch.cute_tensor_like(
torch.tensor(problem_sizes_mnkl, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
# layout (num_groups, 3, 2):(6, 2, 1)
tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like(
torch.tensor(strides_abc, dtype=torch.int32),
cutlass.Int32,
is_dynamic_layout=False,
assumed_align=16,
)
# layout (num_groups,3):(3, 1)
tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_abc, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
assumed_align=16,
)
# layout (num_groups,2):(2, 1)
tensor_of_ptrs_sfasfb, tensor_of_ptrs_sfasfb_torch = cutlass_torch.cute_tensor_like(
torch.tensor(ptrs_sfasfb, dtype=torch.int64),
cutlass.Int64,
is_dynamic_layout=False,
assumed_align=16,
)
Let's unpack what is happening here:
- `tensor_of_dim_size_mnkl` has layout `(num_groups, 4):(4, 1)`, which is clear because each group has a distinct problem size `(m, n, k, l)`
- `tensor_of_strides_abc` has layout `(num_groups, 3, 2):(6, 2, 1)` because we specify two strides per tensor `a`, `b`, `c` per group
- `tensor_of_ptrs_abc` has layout `(num_groups, 3):(3, 1)` because we have three pointers to `a`, `b`, `c` per group
- `tensor_of_ptrs_sfasfb` has layout `(num_groups, 2):(2, 1)` because we have two scale factor pointers per group
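As a quick check of those shapes, here is a minimal sketch with two hypothetical groups and made-up stride values:

```python
import torch

problem_sizes_mnkl = [(256, 256, 256, 1), (512, 512, 512, 1)]
strides_abc = [[(1, 256), (256, 1), (256, 1)],   # made-up strides for group 0
               [(1, 512), (512, 1), (512, 1)]]   # made-up strides for group 1

print(torch.tensor(problem_sizes_mnkl, dtype=torch.int32).shape)  # torch.Size([2, 4])
print(torch.tensor(strides_abc, dtype=torch.int32).shape)         # torch.Size([2, 3, 2])
```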
# Compute total number of cluster tiles we need to compute for given grouped GEMM problem
def compute_total_num_clusters(
problem_sizes_mnkl: List[tuple[int, int, int, int]],
cluster_tile_shape_mn: tuple[int, int],
) -> int:
total_num_clusters = 0
for m, n, _, _ in problem_sizes_mnkl:
num_clusters_mn = tuple(
(x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn)
)
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
return total_num_clusters
# Compute cluster tile shape
def compute_cluster_tile_shape(
mma_tiler_mn: tuple[int, int],
cluster_shape_mn: tuple[int, int],
) -> tuple[int, int]:
cta_tile_shape_mn = [128, mma_tiler_mn[1]]
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
cluster_tile_shape_mn = compute_cluster_tile_shape(mma_tiler_mn, cluster_shape_mn)
total_num_clusters = compute_total_num_clusters(
problem_sizes_mnkl, cluster_tile_shape_mn
)
`compute_cluster_tile_shape` calculates the cluster tile shape, which is essentially `(128 * cluster_shape_m, mma_tiler_n * cluster_shape_n)`. Afterwards, `compute_total_num_clusters` calculates for each group how many of these tiles are needed to cover the whole `(M, N)` output and sums them over all groups; this is `total_num_clusters`.
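A worked example, assuming `mma_tiler_mn = (128, 128)`, `cluster_shape_mn = (2, 1)` and the two example problem sizes (these values are illustrative, not fixed by the example):

```python
import functools  # needed by compute_total_num_clusters

mma_tiler_mn = (128, 128)
cluster_shape_mn = (2, 1)

cluster_tile_shape_mn = compute_cluster_tile_shape(mma_tiler_mn, cluster_shape_mn)
print(cluster_tile_shape_mn)  # (256, 128)

problem_sizes_mnkl = [(256, 256, 256, 1), (512, 512, 512, 1)]
# Group 0: ceil(256 / 256) * ceil(256 / 128) = 1 * 2 = 2 cluster tiles
# Group 1: ceil(512 / 256) * ceil(512 / 128) = 2 * 4 = 8 cluster tiles
print(compute_total_num_clusters(problem_sizes_mnkl, cluster_tile_shape_mn))  # 10
```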
# Initialize Stream
current_stream = cutlass_torch.default_stream()
# Compile grouped GEMM kernel
compiled_grouped_gemm = cute.compile(
grouped_blockscaled_gemm,
initial_cute_tensors_abc[0],
initial_cute_tensors_abc[1],
initial_cute_tensors_abc[2],
initial_cute_tensors_sfasfb[0],
initial_cute_tensors_sfasfb[1],
num_groups,
tensor_of_dim_size_mnkl,
tensor_of_strides_abc,
tensor_of_ptrs_abc,
tensor_of_ptrs_sfasfb,
total_num_clusters,
tensor_of_tensormap,
max_active_clusters,
current_stream,
options=f"--opt-level 2",
)
Finally, we pass all these parameters to `cute.compile` to compile the kernel, and that's it for the basic setup.
Conclusion
I hope this blog post serves as a good intro to the basic setup before launching a CuTeDSL grouped blockscaled GEMM kernel. The experimentation was enabled by Verda; please check out their website if you want to perform your own experiments on a B200 GPU. I am happy to connect on LinkedIn to exchange ideas.