simons blog

Grouped Blockscaled GEMM - Intro

This is the first blogpost in a series about Grouped Blockscaled GEMM on Blackwell; the code can be found here.

In plain English that means we can calculate a group of GEMMs, i.e. multiple blockscaled GEMMs with potentially distinct problem sizes. Before we study the kernel itself, it is a useful exercise to understand the setup within the run function, because we can learn about the expected parameters we will use during our computation. Here we will focus on that part before studying more interesting aspects of the example.

run

# Create tensor and return the pointer, tensor, and stride
def create_tensor_and_stride(
    l: int,
    mode0: int,
    mode1: int,
    is_mode0_major: bool,
    dtype: type[cutlass.Numeric],
    is_dynamic_layout: bool = True,
) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]:
    """Create an (l, mode0, mode1) GPU tensor plus its fp32 CPU reference.

    :param l: extent of the batch (L) mode
    :param mode0: extent of the first non-batch mode (e.g. M or N)
    :param mode1: extent of the second non-batch mode (e.g. K or N)
    :param is_mode0_major: whether mode0 is the contiguous (stride-1) mode
    :param dtype: element type of the GPU tensor
    :param is_dynamic_layout: whether the CuTe tensor uses a dynamic layout
    :return: (data pointer, GPU torch tensor, CuTe tensor,
              fp32 CPU reference tensor, (mode0, mode1) strides)
    """

    # Reference data is always generated in fp32 on the CPU; the GPU copy
    # below is converted to the requested dtype.
    torch_tensor_cpu = cutlass_torch.matrix(
        l,
        mode0,
        mode1,
        is_mode0_major,
        cutlass.Float32,
    )

    # Create GPU tensor (and its CuTe view) from the CPU tensor.
    cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like(
        torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16
    )

    # Mark the contiguous mode with an element divisibility that gives 16B
    # alignment: 32 elements for the 4-bit Float4E2M1FN, 16 elements otherwise.
    cute_tensor.mark_compact_shape_dynamic(
        mode=0 if is_mode0_major else 1,
        stride_order=(2, 1, 0) if is_mode0_major else (2, 0, 1),
        divisibility=32 if dtype == cutlass.Float4E2M1FN else 16,
    )

    # Strides for (mode0, mode1); the stride for the L mode is omitted as it
    # is always 1.
    stride = (1, mode0) if is_mode0_major else (mode1, 1)

    return (
        torch_tensor.data_ptr(),
        torch_tensor,
        cute_tensor,
        torch_tensor_cpu,
        stride,
    )

This is a very simple function that creates a tensor with the given shape, stride and divisibility for alignment. It returns the data pointer, the GPU tensor (as both a torch and a CuTe tensor), the fp32 CPU reference tensor, and the stride.

def create_tensors_abc_for_all_groups(
    problem_sizes_mnkl: List[tuple[int, int, int, int]],
    ab_dtype: Type[cutlass.Numeric],
    c_dtype: Type[cutlass.Numeric],
    a_major: str,
    b_major: str,
    c_major: str,
) -> tuple[
    List[List[int]],
    List[List[torch.Tensor]],
    List[tuple],
    List[List[tuple]],
    List[List[torch.Tensor]],
]:
    """Create the A, B and C tensors for every group of the grouped GEMM.

    Every returned list has one entry per group, and each entry holds the
    values for A, B and C (in that order).

    :param problem_sizes_mnkl: per-group (m, n, k, l) problem sizes
    :param ab_dtype: element type of A and B
    :param c_dtype: element type of C
    :param a_major: "m" if A is m-major, otherwise k-major
    :param b_major: "n" if B is n-major, otherwise k-major
    :param c_major: "m" if C is m-major, otherwise n-major
    :return: (pointers, GPU torch tensors, CuTe tensors, strides,
              fp32 CPU reference tensors)
    """
    ref_torch_fp32_tensors_abc = []
    torch_tensors_abc = []
    cute_tensors_abc = []
    strides_abc = []
    ptrs_abc = []

    # Iterate through all groups and create tensors for each group
    # (the group index itself is not needed here).
    for m, n, k, l in problem_sizes_mnkl:
        # A is (m, k)
        (
            ptr_a,
            torch_tensor_a,
            cute_tensor_a,
            ref_torch_fp32_tensor_a,
            stride_mk_a,
        ) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype)

        # B is (n, k)
        (
            ptr_b,
            torch_tensor_b,
            cute_tensor_b,
            ref_torch_fp32_tensor_b,
            stride_nk_b,
        ) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype)

        # C is (m, n)
        (
            ptr_c,
            torch_tensor_c,
            cute_tensor_c,
            ref_torch_fp32_tensor_c,
            stride_mn_c,
        ) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype)

        ref_torch_fp32_tensors_abc.append(
            [ref_torch_fp32_tensor_a, ref_torch_fp32_tensor_b, ref_torch_fp32_tensor_c]
        )
        ptrs_abc.append([ptr_a, ptr_b, ptr_c])
        torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c])
        strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c])
        cute_tensors_abc.append(
            (
                cute_tensor_a,
                cute_tensor_b,
                cute_tensor_c,
            )
        )

    return (
        ptrs_abc,
        torch_tensors_abc,
        cute_tensors_abc,
        strides_abc,
        ref_torch_fp32_tensors_abc,
    )

Essentially we will have nested lists because we have different groups. The length of each list is the number of groups, and each list entry contains three items: one each for A, B and C.

Same for the scale factors for (a,b) for each group.

def create_tensors_sfasfb_for_all_groups(
    problem_sizes_mnkl: List[tuple[int, int, int, int]],
    sf_dtype: Type[cutlass.Numeric],
    sf_vec_size: int,
) -> tuple[
    List[List[int]],
    List[List[torch.Tensor]],
    List[tuple],
    List[List[torch.Tensor]],
]:
    """Create the scale-factor tensors SFA and SFB for every group.

    Every returned list has one entry per group, and each entry holds the
    values for SFA and SFB (in that order).

    :param problem_sizes_mnkl: per-group (m, n, k, l) problem sizes
    :param sf_dtype: element type of the scale factors
    :param sf_vec_size: scale-factor vector size
    :return: (pointers, GPU torch tensors, CuTe tensors, reference tensors)
    """
    ptrs_sfasfb = []
    torch_tensors_sfasfb = []
    cute_tensors_sfasfb = []
    refs_sfasfb = []

    # Iterate through all groups and create tensors for each group
    # (the group index itself is not needed here).
    for m, n, k, l in problem_sizes_mnkl:
        # SFA covers A's (m, k) extent, SFB covers B's (n, k) extent.
        sfa_ref, ptr_sfa, sfa_tensor, sfa_torch = create_scale_factor_tensor(
            l, m, k, sf_vec_size, sf_dtype
        )
        sfb_ref, ptr_sfb, sfb_tensor, sfb_torch = create_scale_factor_tensor(
            l, n, k, sf_vec_size, sf_dtype
        )
        ptrs_sfasfb.append([ptr_sfa, ptr_sfb])
        torch_tensors_sfasfb.append([sfa_torch, sfb_torch])
        cute_tensors_sfasfb.append(
            (
                sfa_tensor,
                sfb_tensor,
            )
        )
        refs_sfasfb.append([sfa_ref, sfb_ref])

    return (
        ptrs_sfasfb,
        torch_tensors_sfasfb,
        cute_tensors_sfasfb,
        refs_sfasfb,
    )

See the example for the helper functions which I will omit here.

    # Create tensors A, B, C for all groups
    (
        ptrs_abc,
        torch_tensors_abc,
        cute_tensors_abc,
        strides_abc,
        ref_f32_torch_tensors_abc,
    ) = create_tensors_abc_for_all_groups(
        problem_sizes_mnkl,
        ab_dtype,
        c_dtype,
        a_major,
        b_major,
        c_major,
    )
    # Create tensors SFA, SFB for all groups
    (
        ptrs_sfasfb,
        torch_tensors_sfasfb,
        cute_tensors_sfasfb,
        refs_f32_torch_tensors_sfasfb,
    ) = create_tensors_sfasfb_for_all_groups(
        problem_sizes_mnkl,
        sf_dtype,
        sf_vec_size,
    )

    # Choose A, B, C, SFA, SFB with the smallest size to create initial tensormaps
    # Each key receives an (index, (m, n, k, l)) pair from enumerate below:
    # A is sized m*k, B is n*k, C is m*n.
    key_size_a = lambda item: item[1][0] * item[1][2]
    key_size_b = lambda item: item[1][1] * item[1][2]
    key_size_c = lambda item: item[1][0] * item[1][1]
    # Find the indices of the groups with the smallest tensor sizes
    min_a_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_a)
    min_b_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_b)
    min_c_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_c)
    # Each group's entry holds (A, B, C) / (SFA, SFB) in order, so index
    # 0/1/2 picks the operand within the chosen group.
    initial_cute_tensors_abc = [
        cute_tensors_abc[min_a_idx][0],  # A with smallest (m, k)
        cute_tensors_abc[min_b_idx][1],  # B with smallest (n, k)
        cute_tensors_abc[min_c_idx][2],  # C with smallest (m, n)
    ]
    initial_cute_tensors_sfasfb = [
        cute_tensors_sfasfb[min_a_idx][0],  # SFA with smallest (m, k)'s group
        cute_tensors_sfasfb[min_b_idx][1],  # SFB with smallest (n, k)'s group
    ]

Here we use convenience lambdas to set up the initial tensormaps. The run signature tells us that problem_sizes_mnkl is List[Tuple[int, int, int, int]], so it contains the problem sizes for each group.

Assume we have two groups which are two GEMMs with problem sizes:

    hardware_info = cutlass.utils.HardwareInfo()
    # With cluster size 1 this yields one buffer per SM (148 on B200 —
    # see note below).
    sm_count = hardware_info.get_max_active_clusters(1)
    max_active_clusters = hardware_info.get_max_active_clusters(
        cluster_shape_mn[0] * cluster_shape_mn[1]
    )
    # Prepare tensormap buffer for each SM
    num_tensormap_buffers = sm_count
    # Shape: (buffers, tensormaps per buffer, int64 words per tensormap);
    # // 8 converts bytes_per_tensormap into int64 elements.
    tensormap_shape = (
        num_tensormap_buffers,
        Sm100GroupedBlockScaledGemmKernel.num_tensormaps,
        Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8,
    )
    # Static (is_dynamic_layout=False) int64 workspace for the tensormaps.
    tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like(
        torch.empty(tensormap_shape, dtype=torch.int64),
        cutlass.Int64,
        is_dynamic_layout=False,
    )

note that:

    bytes_per_tensormap = 128
    num_tensormaps = 5

are hardcoded and independent of our problem config. The sm_count on B200 is 148 and thus the shape will be (148, 5, 16). Note that we divided bytes_per_tensormap by 8 — the number of bytes in an int64 — to obtain the last mode.

The tensor_of_tensormap will have a layout (148,5,16):(80,16,1). This means the underlying pointer is offset such that each tensormap gets an offset of exactly the size of one tensormap. The layout can be thought of as multiple planes: each plane corresponds to one SM, and within each plane the num_tensormaps tensormaps are laid out one after another.

    grouped_blockscaled_gemm = Sm100GroupedBlockScaledGemmKernel(
        sf_vec_size,
        mma_tiler_mn,
        cluster_shape_mn,
    )

    # Per-group (m, n, k, l) problem sizes as a device tensor.
    # layout (num_groups, 4):(4, 1)
    (
        tensor_of_dim_size_mnkl,
        tensor_of_dim_size_mnkl_torch,
    ) = cutlass_torch.cute_tensor_like(
        torch.tensor(problem_sizes_mnkl, dtype=torch.int32),
        cutlass.Int32,
        is_dynamic_layout=False,
        assumed_align=16,
    )

    # Per-group strides for A, B, C (two stride entries per operand).
    # layout (num_groups, 3, 2):(6, 2, 1)
    tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like(
        torch.tensor(strides_abc, dtype=torch.int32),
        cutlass.Int32,
        is_dynamic_layout=False,
        assumed_align=16,
    )

    # Per-group data pointers for A, B, C (int64 addresses).
    # layout (num_groups,3):(3, 1)
    tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like(
        torch.tensor(ptrs_abc, dtype=torch.int64),
        cutlass.Int64,
        is_dynamic_layout=False,
        assumed_align=16,
    )

    # Per-group data pointers for the SFA, SFB scale factors.
    # layout (num_groups,2):(2, 1)
    tensor_of_ptrs_sfasfb, tensor_of_ptrs_sfasfb_torch = cutlass_torch.cute_tensor_like(
        torch.tensor(ptrs_sfasfb, dtype=torch.int64),
        cutlass.Int64,
        is_dynamic_layout=False,
        assumed_align=16,
    )

Let's unpack what is happening here:

    # Compute total number of cluster tiles we need to compute for given grouped GEMM problem
    def compute_total_num_clusters(
        problem_sizes_mnkl: List[tuple[int, int, int, int]],
        cluster_tile_shape_mn: tuple[int, int],
    ) -> int:
        """Sum, over all groups, the number of cluster tiles needed to
        cover each group's (m, n) output extent.

        :param problem_sizes_mnkl: per-group (m, n, k, l) problem sizes
        :param cluster_tile_shape_mn: (m, n) extent of one cluster tile
        :return: total number of cluster tiles across all groups
        """
        tile_m, tile_n = cluster_tile_shape_mn
        total_num_clusters = 0
        for m, n, _, _ in problem_sizes_mnkl:
            # Ceiling division: tiles needed to cover each output mode.
            total_num_clusters += ((m + tile_m - 1) // tile_m) * (
                (n + tile_n - 1) // tile_n
            )
        return total_num_clusters

    # Compute cluster tile shape
    def compute_cluster_tile_shape(
        mma_tiler_mn: tuple[int, int],
        cluster_shape_mn: tuple[int, int],
    ) -> tuple[int, int]:
        """Return the (m, n) extent of one cluster tile: the per-CTA tile
        (128, mma_tiler_n) scaled by the cluster shape."""
        cta_tile_m, cta_tile_n = 128, mma_tiler_mn[1]
        cluster_m, cluster_n = cluster_shape_mn
        return (cta_tile_m * cluster_m, cta_tile_n * cluster_n)

    # Derive the cluster tile shape, then count how many such tiles are
    # needed across all groups.
    cluster_tile_shape_mn = compute_cluster_tile_shape(mma_tiler_mn, cluster_shape_mn)
    total_num_clusters = compute_total_num_clusters(
        problem_sizes_mnkl, cluster_tile_shape_mn
    )

compute_cluster_tile_shape will calculate the cluster tile shape, which is essentially $(128 \cdot N_{c,x},\; bN \cdot N_{c,y})$, where $(N_{c,x}, N_{c,y})$ is the cluster shape and $bN$ is the N mode of the MMA tiler. Afterwards we calculate for each group how many of these tiles we need to cover the whole matrix (M, N); the sum over all groups is total_num_clusters.

    # Initialize Stream
    current_stream = cutlass_torch.default_stream()

    # Compile grouped GEMM kernel. The smallest-per-operand tensors chosen
    # above seed the initial tensormaps; the tensor_of_* arguments carry
    # the per-group metadata (sizes, strides, pointers).
    compiled_grouped_gemm = cute.compile(
        grouped_blockscaled_gemm,
        initial_cute_tensors_abc[0],
        initial_cute_tensors_abc[1],
        initial_cute_tensors_abc[2],
        initial_cute_tensors_sfasfb[0],
        initial_cute_tensors_sfasfb[1],
        num_groups,
        tensor_of_dim_size_mnkl,
        tensor_of_strides_abc,
        tensor_of_ptrs_abc,
        tensor_of_ptrs_sfasfb,
        total_num_clusters,
        tensor_of_tensormap,
        max_active_clusters,
        current_stream,
        options=f"--opt-level 2",
    )

We'll finally pass all these parameters to the kernel and that's it with the basic setup.

Conclusion

I hope this blogpost serves as a good intro to the basic setup before launching a CuTeDSL grouped blockscaled GEMM kernel. The experimentation was enabled by Verda; please check out their website if you want to perform your own experiments on a B200 GPU. I am happy to connect on LinkedIn to exchange ideas.