simons blog

CuTe partitions

In CuTe we have multiple ways to tile our tensors, i.e. to divide the work across the threads of our GPU.

The CuTe docs describe three modes of tiling: inner partitioning, outer partitioning and thread-value partitioning.

In this blogpost we will examine all three of them and implement a copy kernel for each concept.

Pure Inner Partition

We can launch and verify the kernel like so:

# imports as used in the CuTe DSL examples (exact paths may differ between versions)
import torch

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack


if __name__ == "__main__":
    cutlass.cuda.initialize_cuda_context()

    tensor_shape = (8192, 8192)
    block_shape = (1, 16)
    num_threads = 256

    S = torch.randn(tensor_shape, device="cuda", dtype=torch.bfloat16)
    D = torch.zeros(tensor_shape, device="cuda", dtype=torch.bfloat16)

    tensor_S = from_dlpack(S, assumed_align=16)
    tensor_D = from_dlpack(D, assumed_align=16)

    launch_copy(tensor_S, tensor_D, block_shape, num_threads)

    torch.testing.assert_close(S, D)

    print(S)
    print(D)

We see that we provide a block_shape along with the number of threads.

Our launch function looks as follows:

@cute.jit
def launch_copy(
    tensor_S: cute.Tensor, 
    tensor_D: cute.Tensor, 
    block_shape: cute.Shape,  
    num_threads: cutlass.Constexpr[cutlass.Int32],
):
    print("Tensors:")
    print(f"tensor_S = {tensor_S}")
    print(f"tensor_D = {tensor_D}")

    # Tile (m, n) by (M, N) to obtain ((M, N), m', n'),
    # where m' and n' are the numbers of block tiles along each mode
    tiled_tensor_S = cute.tiled_divide(tensor_S, block_shape)  # ((M, N), m', n')
    tiled_tensor_D = cute.tiled_divide(tensor_D, block_shape)  # ((M, N), m', n')

    print("Block Tile Tensor:")
    print(f"tiled_tensor_S = {tiled_tensor_S}")
    print(f"tiled_tensor_D = {tiled_tensor_D}")

    grid_dim = (
        (cute.size(tiled_tensor_D, mode=[1]) * cute.size(tiled_tensor_D, mode=[2]))
        // num_threads,
        1,
        1,
    )
    block_dim = (num_threads, 1, 1)

    print("Grid and Block Configuration:")
    print(f"grid_dim = {grid_dim}")
    print(f"block_dim = {block_dim}")

    copy_kernel(tiled_tensor_S, tiled_tensor_D).launch(grid=grid_dim, block=block_dim)
    return

Note that we use tiled_divide here and provide it with the block shape. This function tiles the tensor into blocks of that shape.

Assume the tensor has a layout of (m, n):(n, 1) and the block shape is (M, N). Applying tiled_divide then tiles the layout associated with the tensor into the shape ((M, N), m', n'), where m' and n' are simply the original modes divided by the corresponding modes of the block shape. Note that the block shape (1, 16) is not chosen at random: a bfloat16 number is 16 bits wide, so 8 contiguous elements make up 128 bits, which the compiler can merge into a single 128-bit load, and that is often more efficient. We place the 16 along the contiguous columns of the K-major input tensor.

This will give us

tiled_tensor_S = tensor<ptr<bf16, gmem, align<16>> o ((1,16),8192,512):((0,1),8192,16)>
tiled_tensor_D = tensor<ptr<bf16, gmem, align<16>> o ((1,16),8192,512):((0,1),8192,16)>
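
As a quick sanity check of the shape arithmetic above, a plain-Python sketch (nothing CuTe-specific, numbers taken from this example):

# rest modes: original extents divided by the block shape
m_rest = 8192 // 1     # 8192 block tiles along M
n_rest = 8192 // 16    # 512 block tiles along N

# vectorization: a bf16 element is 16 bits wide
elems_per_128bit_access = 128 // 16                 # 8 elements per 128-bit access
accesses_per_tile = 16 // elems_per_128bit_access   # 2 x LDG.E.128 per (1, 16) tile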

Note that we can nicely see that the first mode is (1,16):(0,1), which means we will be able to schedule two 128-bit loads per tile.

We then divide the product of the rest modes (modes 1 and 2) across the threads and launch the kernel with the appropriate number of blocks.

    grid_dim = (
        (cute.size(tiled_tensor_D, mode=[1]) * cute.size(tiled_tensor_D, mode=[2]))
        // num_threads,
        1,
        1,
    )
    block_dim = (num_threads, 1, 1)
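
Plugging in the numbers of this example (one (1, 16) tile per thread), a minimal sketch of the arithmetic:

num_tiles = 8192 * 512          # product of the rest modes: 4_194_304 tiles
num_blocks = num_tiles // 256   # 16_384 blocks of 256 threads, one tile per thread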

Let's now look at the kernel:

@cute.kernel
def copy_kernel(S: cute.Tensor, D: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdimx, _, _ = cute.arch.block_dim()

    # one (1, 16) tile per thread: map the global thread index
    # to a (tile-row, tile-column) coordinate in the rest modes
    num = cute.size(S, mode=[2])
    thread_index = bidx * bdimx + tidx
    x = thread_index // num
    y = thread_index % num
    # Slice into the tiled tensors
    block_coordinate = ((None, None), x, y)
    tile_S = S[block_coordinate]
    tile_D = D[block_coordinate]

    print("Block Tile:")
    print(f"tile_S = {tile_S}")
    print(f"tile_D = {tile_D}")

    fragment = cute.make_fragment_like(tile_S)

    print("Fragment:")
    print(f"fragment = {fragment}")

    fragment.store(tile_S.load())
    tile_D.store(fragment.load())
    return

We recover the indices for the first and second modes and slice the tensor accordingly. We take the whole zeroth mode (which is composed of two sub-modes), as this corresponds to the tile we want to process with one thread.

This will give us something like this:

tile_S = tensor<ptr<bf16, gmem, align<16>> o (1,16):(0,1)>
tile_D = tensor<ptr<bf16, gmem, align<16>> o (1,16):(0,1)>
Fragment:
fragment = tensor<ptr<bf16, rmem, align<16>> o (1,16):(0,1)>

Here we use make_fragment_like to allocate a fragment in registers with the same layout as the tile. We then finish off by transferring GMEM -> RMEM -> GMEM.

In ncu we can verify that the compiler does what we expect:

00007f79 07c71f20	      LDG.E.128 R12, [R2.64+0x10]
00007f79 07c71f30	      LDG.E.128 R8, [R2.64]
00007f79 07c71f40	      IMAD.WIDE R4, R4, R5, c[0x0][0x168]
00007f79 07c71f50	      STG.E.128 [R4.64+0x10], R12
00007f79 07c71f60	      STG.E.128 [R4.64], R8

This shows that 128-bit LDG/STG instructions were indeed used.

This kernel achieves a memory throughput of 93.32% on my consumer GPU (NVIDIA GeForce RTX 3060 Laptop GPU).

Inner and Outer Partition

In the kernel above we had to do some index gymnastics to choose the appropriate tile. It would be nice if we could somehow automate this process of assigning threads to tiles.

It turns out CuTe offers a concept just for that: outer partitioning. In this approach we will first tile the matrix using a block_shape (i.e. perform an inner partition) and then use a thread_layout to assign the subtiles of this tile to the individual threads (this is called an outer partition).

@cute.jit
def launch_copy(
    tensor_S: cute.Tensor,  # Pointer to Source
    tensor_D: cute.Tensor,  # Pointer to Destination
    block_shape: cute.Shape,  # (M, N)
    thread_shape: cute.Shape,
):
    print("Tensors:")
    print(f"tensor_S = {tensor_S}")
    print(f"tensor_D = {tensor_D}")

    # Tile (m, n) by (M, N) to obtain ((M, N), (m', n')),
    # where m' and n' are the numbers of block tiles along each mode
    tiled_tensor_S = cute.zipped_divide(tensor_S, block_shape)  # ((M, N), (m', n'))
    tiled_tensor_D = cute.zipped_divide(tensor_D, block_shape)  # ((M, N), (m', n'))

    print("Block Tile Tensor:")
    print(f"tiled_tensor_S = {tiled_tensor_S}")
    print(f"tiled_tensor_D = {tiled_tensor_D}")

    thr_layout = cute.make_layout(thread_shape, stride=(thread_shape[1], 1))

    print("Thread Layout:")
    print(f"thr_layout = {thr_layout}")

    grid_dim = (
        cute.size(tiled_tensor_D, mode=[1]),
        1,
        1,
    )
    block_dim = (cute.size(thr_layout), 1, 1)

    print("Grid and Block Configuration:")
    print(f"grid_dim = {grid_dim}")
    print(f"block_dim = {block_dim}")

    copy_kernel(tiled_tensor_S, tiled_tensor_D, thr_layout).launch(
        grid=grid_dim, block=block_dim
    )
    return

Note that here we chose

block_shape = (32, 256)
thread_shape = (8, 32)

The block shape is a tunable parameter of the kernel.

Here we use zipped_divide, which groups the rest modes into a single second mode, giving ((M, N), (m', n')). In my opinion this turns out to be a little more convenient.

We chose the thread layout to be (8,32):(32,1), i.e. a K-major layout for the threads, which organizes two warp groups (256 threads) and follows the majorness of the input tensor.
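
To make the K-major choice concrete, here is a small sketch (my own illustration, plain Python) of how a linear thread index lands in the (8,32):(32,1) thread layout:

# the layout (8,32):(32,1) maps coordinate (row, col) to row * 32 + col,
# so a linear thread index tidx corresponds to (tidx // 32, tidx % 32)
tidx = 37
row = tidx // 32   # 1
col = tidx % 32    # 5
# 32 consecutive threads (one warp) therefore sit next to each other along
# the contiguous N dimension of the tensor, which keeps accesses coalesced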

Here we define the grid and block dimensions directly via the layouts: one block per block tile, and one thread per entry of the thread layout.
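
Assuming the same 8192 x 8192 tensors as before, this works out to (a quick arithmetic sketch):

num_block_tiles = (8192 // 32) * (8192 // 256)  # 256 * 32 = 8192 blocks in the grid
threads_per_block = 8 * 32                      # 256, i.e. cute.size(thr_layout)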

The kernel looks as follows:

@cute.kernel
def copy_kernel(S: cute.Tensor, D: cute.Tensor, ThreadLayout: cute.Layout):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    x = bidx
    # Slice into the tiled tensors
    block_coordinate = ((None, None), x)
    tile_S = S[block_coordinate]
    tile_D = D[block_coordinate]

    print("Block Tile:")
    print(f"tile_S = {tile_S}")
    print(f"tile_D = {tile_D}")

    thr_tile_S = cute.local_partition(tile_S, ThreadLayout, tidx)
    thr_tile_D = cute.local_partition(tile_D, ThreadLayout, tidx)

    print("Thread Tile:")
    print(f"thr_tile_S = {thr_tile_S}")
    print(f"thr_tile_D = {thr_tile_D}")

    fragment = cute.make_fragment_like(thr_tile_S)

    print("Fragment:")
    print(f"fragment = {fragment}")

    fragment.store(thr_tile_S.load())
    thr_tile_D.store(fragment.load())
    return

Note that we only have two modes because we used zipped_divide, so we can index the rest mode directly with the block index.

We then use local_partition to distribute the tile into subtiles that are then processed by the individual threads.

In our example:

thr_tile_S = tensor<ptr<bf16, gmem> o (4,8):(65536,32)>
thr_tile_D = tensor<ptr<bf16, gmem> o (4,8):(65536,32)>

This is clear because 32/8 = 4 and 256/32 = 8. local_partition needs to be given the tidx to figure out which subtile the calling thread takes.
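
Here is a back-of-the-envelope derivation of that (4,8):(65536,32) layout, assuming the row-major 8192 x 8192 source tensor (so stepping one row skips 8192 elements):

per_thread_shape = (32 // 8, 256 // 32)  # (4, 8) elements per thread

# within one (32, 256) block tile, neighbouring threads sit 8 rows and
# 32 columns apart, so one thread's own elements are spaced by
stride_m = 8 * 8192   # 65536 (8 rows of the 8192-wide tensor)
stride_n = 32 * 1     # 32    (32 columns along the contiguous dimension)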

We then copy again from GMEM -> RMEM -> GMEM.

This kernel achieves 92.07% memory throughput. Note that this is achieved without vectorized loads per thread, as can be verified in the profiler. However, we have a higher workload per thread, because each thread takes care of 32 elements of the matrix.

Thread Value Partitioning

I wrote a blogpost on thread-value layouts in the past, and implemented a similar kernel here as well. Please read those if you are not familiar with the concept; that is why I will go over this section quickly.

We choose:

thread_shape = (32, 8)
value_shape = (4, 8)

Note that we have chosen the value shape such that we will later be able to perform vectorized loads and stores, as in the first kernel.
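
A quick check of what these shapes buy us (plain arithmetic based on the shapes above):

# 8 bf16 values along the contiguous dimension form one 128-bit access,
# and the (4, 8) value tile gives 4 such accesses per thread
bits_per_access = 8 * 16   # 128
vector_loads_per_thread = 4

# the block tile covered by one thread block is thread_shape * value_shape
tile_m = 32 * 4   # 128
tile_n = 8 * 8    # 64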

@cute.jit
def launch_copy(
    tensor_S: cute.Tensor,
    tensor_D: cute.Tensor,
    thread_shape: cute.Shape,
    value_shape: cute.Shape,
):
    print("Tensors:")
    print(f"tensor_S = {tensor_S}")
    print(f"tensor_D = {tensor_D}")

    # Obtain Tiler and TV Layout
    thr_layout = cute.make_layout(thread_shape, stride=(thread_shape[1], 1))
    val_layout = cute.make_layout(value_shape, stride=(value_shape[1], 1))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)

    print("Block Tile Tensor:")
    print(f"thr_layout = {thr_layout}")
    print(f"val_layout = {val_layout}")
    print(f"tiler_mn = {tiler_mn}")
    print(f"tv_layout = {tv_layout}")

    tiled_tensor_S = cute.zipped_divide(tensor_S, tiler_mn)  # (M, N), (m', n'))
    tiled_tensor_D = cute.zipped_divide(tensor_D, tiler_mn)  # (M, N), (m', n'))

    print("Block Tile Tensor:")
    print(f"tiled_tensor_S = {tiled_tensor_S}")
    print(f"tiled_tensor_D = {tiled_tensor_D}")

    grid_dim = (
        cute.size(tiled_tensor_D, mode=[1]),
        1,
        1,
    )
    block_dim = (cute.size(tv_layout, mode=[0]), 1, 1)

    print("Grid and Block Configuration:")
    print(f"grid_dim = {grid_dim}")
    print(f"block_dim = {block_dim}")

    copy_kernel(tiled_tensor_S, tiled_tensor_D, tv_layout).launch(
        grid=grid_dim, block=block_dim
    )
    return

We let make_layout_tv figure out the tiler and the TV layout that we will need in the kernel to compose the tiled tensor with the TV layout.

Note that

L o TV-Layout takes a (Thread, Value) pair as input and maps it via L, the layout of the block tile, to the corresponding location in our matrix.

@cute.kernel
def copy_kernel(S: cute.Tensor, D: cute.Tensor, tv_layout: cute.Layout):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    x = bidx
    # Slice into the tiled tensors
    block_coordinate = ((None, None), x)
    tile_S = S[block_coordinate]
    tile_D = D[block_coordinate]

    print("Block Tile:")
    print(f"tile_S = {tile_S}")
    print(f"tile_D = {tile_D}")

    thr_val_tile_S = cute.composition(tile_S, tv_layout)
    thr_val_tile_D = cute.composition(tile_D, tv_layout)
    thr_coord = (tidx, None)
    val_tile_S = thr_val_tile_S[thr_coord]
    val_tile_D = thr_val_tile_D[thr_coord]

    print("thr_val_tile_S:")
    print(f"thr_val_tile_S = {thr_val_tile_S}")
    print(f"thr_val_tile_D = {thr_val_tile_D}")
    print(f"val_tile_S = {val_tile_S}")
    print(f"val_tile_D = {val_tile_D}")

    fragment = cute.make_fragment_like(val_tile_S)

    print("Fragment:")
    print(f"fragment = {fragment}")

    fragment.store(val_tile_S.load())
    val_tile_D.store(fragment.load())
    return

In the kernel we compose our tile with the tv_layout.

This will look as follows:

Block Tile:
tile_S = tensor<ptr<bf16, gmem, align<16>> o (128,64):(4096,1)>
tile_D = tensor<ptr<bf16, gmem, align<16>> o (128,64):(4096,1)>
thr_val_tile_S:
thr_val_tile_S = tensor<ptr<bf16, gmem, align<16>> o ((8,32),(8,4)):((8,16384),(1,4096))>
thr_val_tile_D = tensor<ptr<bf16, gmem, align<16>> o ((8,32),(8,4)):((8,16384),(1,4096))>
val_tile_S = tensor<ptr<bf16, gmem, align<16>> o ((8,4)):((1,4096))>
val_tile_D = tensor<ptr<bf16, gmem, align<16>> o ((8,4)):((1,4096))>
Fragment:
fragment = tensor<ptr<bf16, rmem, align<16>> o ((8,4)):((1,8))>
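
To make the composition concrete, here is a small sketch (my own illustration) of how the composed layout above maps a (thread, value) pair to an element offset; I assume, as is usual in CuTe, that a linear index is folded into a multi-modal shape with the leftmost sub-mode varying fastest:

# strides taken from the printout above:
#   thread mode (8, 32) : (8, 16384)
#   value  mode (8, 4)  : (1, 4096)
def offset(tidx, vidx):
    t0, t1 = tidx % 8, tidx // 8   # linear thread index -> (8, 32) coordinate
    v0, v1 = vidx % 8, vidx // 8   # linear value index  -> (8, 4) coordinate
    return t0 * 8 + t1 * 16384 + v0 * 1 + v1 * 4096

# thread 0 touches elements 0..7 (one 128-bit access), then 4096..4103, ...
print([offset(0, v) for v in range(16)])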

Note how nicely we obtain val_tile_S = ((8,4)):((1,4096)) for each thread. This gives us the ability to perform vectorized loads. We can confirm this in the profiler, and indeed:

LDG.E.128 R8, [R2.64]
LDG.E.128 R12, [R2.64+0x2000]
LDG.E.128 R16, [R2.64+0x4000]
LDG.E.128 R20, [R2.64+0x6000]
IADD3 R4, P0, R4, c[0x0][0x168], RZ
IADD3.X R5, R0, c[0x0][0x16c], RZ, P0, !PT
STG.E.128 [R4.64], R8
STG.E.128 [R4.64+0x2000], R12
STG.E.128 [R4.64+0x4000], R16
STG.E.128 [R4.64+0x6000], R20

which is what we expect and want: four vectorized loads. The addresses differ by 0x2000 = 8192 bytes, i.e. 4096 bf16 elements, which corresponds to the second stride of the value tile. This kernel achieves 89.53% memory throughput.

Conclusion

I hope this blogpost showed that there are multiple ways of partitioning our work in the CuTe DSL. For a given situation we can pick whatever is most convenient and choose the corresponding partition pattern. This blogpost was inspired by multiple examples given in the CUTLASS library. Please consider giving the library a star. You can connect with me via LinkedIn to exchange ideas.