CuTe partitions
In CuTe we have multiple ways to tile our tensors, i.e. to distribute the work across the threads of our GPU.
The CuTe docs describe three modes of tiling:
Inner Partition
Outer Partition
Thread Value Partition
In this blogpost we will examine all three of them and implement a Copy kernel corresponding to each concept.
Pure Inner Partition
We can launch and verify the kernel like so:
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

if __name__ == "__main__":
    cutlass.cuda.initialize_cuda_context()
    tensor_shape = (8192, 8192)
    block_shape = (1, 16)
    num_threads = 256
    S = torch.randn(tensor_shape, device="cuda", dtype=torch.bfloat16)
    D = torch.zeros(tensor_shape, device="cuda", dtype=torch.bfloat16)
    tensor_S = from_dlpack(S, assumed_align=16)
    tensor_D = from_dlpack(D, assumed_align=16)
    launch_copy(tensor_S, tensor_D, block_shape, num_threads)
    torch.testing.assert_close(S, D)
    print(S)
    print(D)
We see that we provide a block_shape
along with the number of threads.
Our launch function looks as follows:
@cute.jit
def launch_copy(
    tensor_S: cute.Tensor,
    tensor_D: cute.Tensor,
    block_shape: cute.Shape,
    num_threads: cutlass.Constexpr[cutlass.Int32],
):
    print("Tensors:")
    print(f"tensor_S = {tensor_S}")
    print(f"tensor_D = {tensor_D}")
    # Tile (m, n) by (M, N) to obtain ((M, N), m', n'),
    # where m' and n' are the numbers of block tiles per mode
    tiled_tensor_S = cute.tiled_divide(tensor_S, block_shape)  # ((M, N), m', n')
    tiled_tensor_D = cute.tiled_divide(tensor_D, block_shape)  # ((M, N), m', n')
    print("Block Tile Tensor:")
    print(f"tiled_tensor_S = {tiled_tensor_S}")
    print(f"tiled_tensor_D = {tiled_tensor_D}")
    # One thread per (1, 16) tile: blocks = total tiles / threads per block
    grid_dim = (
        (cute.size(tiled_tensor_D, mode=[1]) * cute.size(tiled_tensor_D, mode=[2]))
        // num_threads,
        1,
        1,
    )
    block_dim = (num_threads, 1, 1)
    print("Grid and Block Configuration:")
    print(f"grid_dim = {grid_dim}")
    print(f"block_dim = {block_dim}")
    copy_kernel(tiled_tensor_S, tiled_tensor_D).launch(grid=grid_dim, block=block_dim)
    return
Note that we use tiled_divide here and provide it with the block shape. This function tiles the tensor: assume the tensor has a layout of (m,n):(n,1) and the block shape is (M, N). Applying tiled_divide will then tile the layout associated with the tensor into the shape ((M,N), m', n'), where m' and n' are simply the original modes divided by the corresponding mode of the block shape. Note that it is no accident that we chose the block shape to be (1, 16): a bfloat16 number has 16 bits, so 8 contiguous loads are needed for the compiler to be able to merge them into a single 128-bit load, which is often more efficient. We chose the 16 along the contiguous column dimension of the K-major input tensor.
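To make the shape arithmetic concrete, here is a plain-Python sketch (not CuTe; tiled_shape is a hypothetical helper just for illustration):
def tiled_shape(tensor_shape, block_shape):
    # tiled_divide shape arithmetic: ((M, N), m', n')
    (m, n), (M, N) = tensor_shape, block_shape
    assert m % M == 0 and n % N == 0
    return ((M, N), m // M, n // N)

print(tiled_shape((8192, 8192), (1, 16)))  # ((1, 16), 8192, 512)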
This will give us
tiled_tensor_S = tensor<ptr<bf16, gmem, align<16>> o ((1,16),8192,512):((0,1),8192,16)>
tiled_tensor_D = tensor<ptr<bf16, gmem, align<16>> o ((1,16),8192,512):((0,1),8192,16)>
Note that we can nicely see that the first mode is (1,16):(0,1), i.e. 16 contiguous bfloat16 elements, which means we will be able to schedule 2 LD128 instructions per tile.
We then divide the product of the rest modes (modes 1 and 2) by the number of threads per block and launch the kernel with the appropriate number of blocks to cover all tiles.
grid_dim = (
    (cute.size(tiled_tensor_D, mode=[1]) * cute.size(tiled_tensor_D, mode=[2]))
    // num_threads,
    1,
    1,
)
block_dim = (num_threads, 1, 1)
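With our concrete numbers this is a quick sanity check (plain Python, just arithmetic):
# 8192 tiles along mode 1, 512 tiles along mode 2,
# one (1, 16) tile per thread, 256 threads per block
print(8192 * 512 // 256)  # 16384 blocks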
Let's now look at the kernel:
@cute.kernel
def copy_kernel(S: cute.Tensor, D: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdimx, _, _ = cute.arch.block_dim()
    num = cute.size(S, mode=[2])  # number of tiles along mode 2
    thread_id = bidx * bdimx + tidx  # global linear thread index
    x = thread_id // num  # tile index along mode 1
    y = thread_id % num  # tile index along mode 2
    # Slice into the tiled tensors, keeping the whole zeroth mode
    block_coordinate = ((None, None), x, y)
    tile_S = S[block_coordinate]
    tile_D = D[block_coordinate]
    print("Block Tile:")
    print(f"tile_S = {tile_S}")
    print(f"tile_D = {tile_D}")
    fragment = cute.make_fragment_like(tile_S)
    print("Fragment:")
    print(f"fragment = {fragment}")
    # GMEM -> RMEM -> GMEM
    fragment.store(tile_S.load())
    tile_D.store(fragment.load())
    return
We recover the indices for the first and second modes and slice the tensor accordingly. We take the whole zeroth mode (which is composed of two submodes), as this corresponds to the tile we want to process with one thread.
This will give us something like this:
tile_S = tensor<ptr<bf16, gmem, align<16>> o (1,16):(0,1)>
tile_D = tensor<ptr<bf16, gmem, align<16>> o (1,16):(0,1)>
Fragment:
fragment = tensor<ptr<bf16, rmem, align<16>> o (1,16):(0,1)>
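To see what the slice picks out, here is a plain-Python sketch (tile_offsets is a hypothetical helper, using the strides ((0,1), 8192, 16) from the printout above):
def tile_offsets(x, y):
    # Slicing ((None, None), x, y) keeps mode 0 whole: 16
    # contiguous elements starting at x*8192 + y*16
    base = x * 8192 + y * 16
    return [base + k for k in range(16)]

print(tile_offsets(0, 0))  # [0, 1, ..., 15]
print(tile_offsets(0, 1))  # [16, 17, ..., 31]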
Here we used make_fragment_like to allocate a fragment in registers. We then finish off by transferring GMEM -> RMEM -> GMEM.
In ncu we can verify that the compiler does what we expect:
00007f79 07c71f20 LDG.E.128 R12, [R2.64+0x10]
00007f79 07c71f30 LDG.E.128 R8, [R2.64]
00007f79 07c71f40 IMAD.WIDE R4, R4, R5, c[0x0][0x168]
00007f79 07c71f50 STG.E.128 [R4.64+0x10], R12
00007f79 07c71f60 STG.E.128 [R4.64], R8
This shows precisely that 128-bit LDG/STG instructions were used.
This kernel achieves a memory throughput of 93.32% on my consumer GPU (NVIDIA GeForce RTX 3060 Laptop GPU).
Inner and Outer Partition
In the kernel above we had to do some index gymnastics to choose the appropriate tile. It would be nice if we could somehow automate this process of assigning threads to tiles.
It turns out CuTe offers a concept just for that: Outer Partition. In this approach we first tile the matrix using a block_shape (i.e. perform an inner partition) and then use a thread_layout to assign the subtiles of this tile to the threads (this is called an outer partition).
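As a quick plain-Python illustration of the per-thread workload this implies, with the block and thread shapes we are about to choose (just shape arithmetic, not CuTe):
# Each thread's subtile shape is the block shape divided
# mode-wise by the thread shape
block_shape = (32, 256)
thread_shape = (8, 32)
subtile = tuple(b // t for b, t in zip(block_shape, thread_shape))
print(subtile)  # (4, 8) -> 32 elements per thread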
@cute.jit
def launch_copy(
    tensor_S: cute.Tensor,  # Source tensor
    tensor_D: cute.Tensor,  # Destination tensor
    block_shape: cute.Shape,  # (M, N)
    thread_shape: cute.Shape,
):
    print("Tensors:")
    print(f"tensor_S = {tensor_S}")
    print(f"tensor_D = {tensor_D}")
    # Tile (m, n) by (M, N) to obtain ((M, N), (m', n')),
    # where m' and n' are the numbers of block tiles per mode
    tiled_tensor_S = cute.zipped_divide(tensor_S, block_shape)  # ((M, N), (m', n'))
    tiled_tensor_D = cute.zipped_divide(tensor_D, block_shape)  # ((M, N), (m', n'))
    print("Block Tile Tensor:")
    print(f"tiled_tensor_S = {tiled_tensor_S}")
    print(f"tiled_tensor_D = {tiled_tensor_D}")
    # K-major thread layout: (8, 32):(32, 1)
    thr_layout = cute.make_layout(thread_shape, stride=(thread_shape[1], 1))
    print("Thread Layout:")
    print(f"thr_layout = {thr_layout}")
    # One block per block tile
    grid_dim = (
        cute.size(tiled_tensor_D, mode=[1]),
        1,
        1,
    )
    block_dim = (cute.size(thr_layout), 1, 1)
    print("Grid and Block Configuration:")
    print(f"grid_dim = {grid_dim}")
    print(f"block_dim = {block_dim}")
    copy_kernel(tiled_tensor_S, tiled_tensor_D, thr_layout).launch(
        grid=grid_dim, block=block_dim
    )
    return
Note that here we chose
block_shape = (32, 256)
thread_shape = (8, 32)
The block shape is a tunable parameter of the kernel.
Here we perform zipped_divide, which groups the rest modes into a single second mode and turned out to be a little more convenient (in my opinion).
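Here is a plain-Python sketch of the resulting shape (zipped_shape is a hypothetical helper for illustration):
def zipped_shape(tensor_shape, block_shape):
    # zipped_divide groups the rest modes: ((M, N), (m', n'))
    rest = tuple(t // b for t, b in zip(tensor_shape, block_shape))
    return (block_shape, rest)

print(zipped_shape((8192, 8192), (32, 256)))  # ((32, 256), (256, 32))
print(256 * 32)  # 8192 block tiles -> 8192 blocks in the grid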
We chose the thread layout to be (8,32):(32,1), i.e. a K-major layout for the threads, which organizes two warp groups (256 threads) and follows the input tensor in its majorness.
Here we define the grid and block dimensions directly via the layouts.
The kernel looks as follows:
@cute.kernel
def copy_kernel(S: cute.Tensor, D: cute.Tensor, thr_layout: cute.Layout):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    x = bidx  # one block per block tile
    # Slice into the tiled tensors, keeping the whole zeroth mode
    block_coordinate = ((None, None), x)
    tile_S = S[block_coordinate]
    tile_D = D[block_coordinate]
    print("Block Tile:")
    print(f"tile_S = {tile_S}")
    print(f"tile_D = {tile_D}")
    # Outer partition: distribute the tile onto the threads
    thr_tile_S = cute.local_partition(tile_S, thr_layout, tidx)
    thr_tile_D = cute.local_partition(tile_D, thr_layout, tidx)
    print("Thread Tile:")
    print(f"thr_tile_S = {thr_tile_S}")
    print(f"thr_tile_D = {thr_tile_D}")
    fragment = cute.make_fragment_like(thr_tile_S)
    print("Fragment:")
    print(f"fragment = {fragment}")
    # GMEM -> RMEM -> GMEM
    fragment.store(thr_tile_S.load())
    thr_tile_D.store(fragment.load())
    return
Note that we only have two modes because we used zipped_divide, so we can index directly into the second mode via the block index.
We then use local_partition to distribute the tile onto subtiles that will then be processed by individual threads.
In our example:
thr_tile_S = tensor<ptr<bf16, gmem> o (4,8):(65536,32)>
thr_tile_D = tensor<ptr<bf16, gmem> o (4,8):(65536,32)>
This is clear because 32/8 = 4 and 256/32 = 8. Local partitioning needs to be provided with the tidx to figure out which subtile to take.
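To build intuition for which elements of the (32, 256) tile a given thread owns, here is a plain-Python sketch (thread_coords is a hypothetical helper, assuming the outer-partition semantics described above; the resulting strides match the printed thr_tile layout):
def thread_coords(tidx, thread_shape=(8, 32), block_shape=(32, 256)):
    # Thread tidx sits at (r, c) in the K-major (8, 32) thread layout
    # and owns the elements (r + 8*i, c + 32*j) of the (32, 256) tile
    r, c = tidx // thread_shape[1], tidx % thread_shape[1]
    return [(r + thread_shape[0] * i, c + thread_shape[1] * j)
            for i in range(block_shape[0] // thread_shape[0])
            for j in range(block_shape[1] // thread_shape[1])]

print(len(thread_coords(0)))  # 32 elements per thread
print(thread_coords(0)[:3])   # [(0, 0), (0, 32), (0, 64)]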
We'll then copy again from GMEM -> RMEM -> GMEM.
This kernel achieves 92.07% efficiency. Note that this is achieved without performing vectorized loads per thread, as can be verified in the profiler. However, we have a higher workload per thread because each thread takes care of 32 elements of the matrix.
Thread Value Partitioning
I wrote a blogpost on thread-value layouts in the past, and covered a similar kernel in another post. Please read those if you are not familiar with the concept. That's why I will go over this section quickly.
We choose:
thread_shape = (32, 8)
value_shape = (4, 8)
Note that we have chosen the value shape such that we will later be able to perform vectorized loads and stores like in the first kernel.
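A quick plain-Python check of the block tile this implies (just arithmetic; the (128, 64) tiler shows up again in the kernel printout below):
# The block tile covers thread_shape * value_shape elements per
# mode, and 8 contiguous bf16 values make one 128-bit access
thread_shape = (32, 8)
value_shape = (4, 8)
tiler_mn = tuple(t * v for t, v in zip(thread_shape, value_shape))
print(tiler_mn)             # (128, 64)
print(value_shape[1] * 16)  # 128 bits per row of values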
@cute.jit
def launch_copy(
    tensor_S: cute.Tensor,
    tensor_D: cute.Tensor,
    thread_shape: cute.Shape,
    value_shape: cute.Shape,
):
    print("Tensors:")
    print(f"tensor_S = {tensor_S}")
    print(f"tensor_D = {tensor_D}")
    # Obtain tiler and TV layout
    thr_layout = cute.make_layout(thread_shape, stride=(thread_shape[1], 1))
    val_layout = cute.make_layout(value_shape, stride=(value_shape[1], 1))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
    print("Thread and Value Layouts:")
    print(f"thr_layout = {thr_layout}")
    print(f"val_layout = {val_layout}")
    print(f"tiler_mn = {tiler_mn}")
    print(f"tv_layout = {tv_layout}")
    tiled_tensor_S = cute.zipped_divide(tensor_S, tiler_mn)  # ((M, N), (m', n'))
    tiled_tensor_D = cute.zipped_divide(tensor_D, tiler_mn)  # ((M, N), (m', n'))
    print("Block Tile Tensor:")
    print(f"tiled_tensor_S = {tiled_tensor_S}")
    print(f"tiled_tensor_D = {tiled_tensor_D}")
    # One block per block tile
    grid_dim = (
        cute.size(tiled_tensor_D, mode=[1]),
        1,
        1,
    )
    block_dim = (cute.size(tv_layout, mode=[0]), 1, 1)
    print("Grid and Block Configuration:")
    print(f"grid_dim = {grid_dim}")
    print(f"block_dim = {block_dim}")
    copy_kernel(tiled_tensor_S, tiled_tensor_D, tv_layout).launch(
        grid=grid_dim, block=block_dim
    )
    return
We let make_layout_tv figure out the tiler and the TV layout we need in the kernel to perform the composition of the tiled layout with the TV layout.
Note that L o TV-Layout takes a (thread, value) pair as input and maps it via L to the corresponding location in our matrix.
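As a plain-Python illustration of that composition (offset is a hypothetical helper evaluating the composed layout ((8,32),(8,4)):((8,16384),(1,4096)) printed by the kernel below):
def offset(tidx, vidx):
    # Maps a (thread, value) pair to a linear offset in the block tile
    t0, t1 = tidx % 8, tidx // 8  # thread modes (8, 32):(8, 16384)
    v0, v1 = vidx % 8, vidx // 8  # value modes (8, 4):(1, 4096)
    return t0 * 8 + t1 * 16384 + v0 * 1 + v1 * 4096

print([offset(0, v) for v in range(8)])  # [0, ..., 7]: contiguous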
@cute.kernel
def copy_kernel(S: cute.Tensor, D: cute.Tensor, tv_layout: cute.Layout):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    x = bidx  # one block per block tile
    # Slice into the tiled tensors, keeping the whole zeroth mode
    block_coordinate = ((None, None), x)
    tile_S = S[block_coordinate]
    tile_D = D[block_coordinate]
    print("Block Tile:")
    print(f"tile_S = {tile_S}")
    print(f"tile_D = {tile_D}")
    # Compose the tile with the TV layout: (thread, value) -> element
    thr_val_tile_S = cute.composition(tile_S, tv_layout)
    thr_val_tile_D = cute.composition(tile_D, tv_layout)
    # Fix the thread mode, keep all values of this thread
    thr_coord = (tidx, None)
    val_tile_S = thr_val_tile_S[thr_coord]
    val_tile_D = thr_val_tile_D[thr_coord]
    print("thr_val_tile_S:")
    print(f"thr_val_tile_S = {thr_val_tile_S}")
    print(f"thr_val_tile_D = {thr_val_tile_D}")
    print(f"val_tile_S = {val_tile_S}")
    print(f"val_tile_D = {val_tile_D}")
    fragment = cute.make_fragment_like(val_tile_S)
    print("Fragment:")
    print(f"fragment = {fragment}")
    # GMEM -> RMEM -> GMEM
    fragment.store(val_tile_S.load())
    val_tile_D.store(fragment.load())
    return
In the kernel we compose our tile with the tv_layout
.
This will look as follows:
Block Tile:
tile_S = tensor<ptr<bf16, gmem, align<16>> o (128,64):(4096,1)>
tile_D = tensor<ptr<bf16, gmem, align<16>> o (128,64):(4096,1)>
thr_val_tile_S:
thr_val_tile_S = tensor<ptr<bf16, gmem, align<16>> o ((8,32),(8,4)):((8,16384),(1,4096))>
thr_val_tile_D = tensor<ptr<bf16, gmem, align<16>> o ((8,32),(8,4)):((8,16384),(1,4096))>
val_tile_S = tensor<ptr<bf16, gmem, align<16>> o ((8,4)):((1,4096))>
val_tile_D = tensor<ptr<bf16, gmem, align<16>> o ((8,4)):((1,4096))>
Fragment:
fragment = tensor<ptr<bf16, rmem, align<16>> o ((8,4)):((1,8))>
Note how nicely we obtain val_tile_S = ((8,4)):((1,4096)) for each thread. The 8 contiguous elements in the first mode give us the ability to perform vectorized loads. We can confirm this in the profiler, and indeed:
LDG.E.128 R8, [R2.64]
LDG.E.128 R12, [R2.64+0x2000]
LDG.E.128 R16, [R2.64+0x4000]
LDG.E.128 R20, [R2.64+0x6000]
IADD3 R4, P0, R4, c[0x0][0x168], RZ
IADD3.X R5, R0, c[0x0][0x16c], RZ, P0, !PT
STG.E.128 [R4.64], R8
STG.E.128 [R4.64+0x2000], R12
STG.E.128 [R4.64+0x4000], R16
STG.E.128 [R4.64+0x6000], R20
which is what we expect and want: four vectorized loads. Their addresses differ by 0x2000 = 8192 bytes = 4096 bfloat16 elements, which corresponds to the second stride (4096) of the value tile. This kernel achieves 89.53% utilization.
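A one-line check of that address arithmetic:
# 4096 bf16 elements * 2 bytes each = 8192 bytes between loads
print(hex(4096 * 2))  # 0x2000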
Conclusion
I hope this blogpost showed that there are multiple ways of partitioning our work in the CuTe DSL. For a given situation we can pick what is most convenient and choose the corresponding partition pattern.
This blogpost was inspired by multiple examples given in the CUTLASS library. Please consider giving the library a star.
You can connect with me via LinkedIn to exchange ideas.