Predication in Cutlass
The CUTLASS documentation on CuTe touches on the topic of predication briefly but doesn't give a full code example. In this blogpost I will explain how to use predication to perform appropriate boundary checking in a CuTe program.
Introduction
The kernel we start from is taken from the CuTe tutorial and performs an efficient tiled copy. Before we start with the topic of predication let's briefly focus on the non vectorised version of tiled copy.
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"
/// Tiled copy kernel (placeholder in this listing).
///
/// Only the signature is shown here so that the host code below is complete;
/// the kernel body is given in the kernel listings later in this post.
///
/// @param S  Source tensor, tiled by the host before launch.
/// @param D  Destination tensor, tiled the same way as S.
/// The third, unnamed parameter carries the thread layout by type only.
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout) {
  // Body omitted here: see the kernel listings further below.
}
/// Main function
///
/// Host driver for the tiled copy example: fills a source buffer, launches
/// copy_kernel over a tiled view of it, copies the result back and compares
/// it element by element against the source.
///
/// Returns 0 on success, -1 if the CUDA runtime reports an error or once
/// kErrorLimit mismatches have been printed.
int main(int argc, char** argv) {
//
// Given a 2D shape, perform an efficient copy
//
using namespace cute;
using Element = float;
// Problem size. Both extents are exact multiples of the block shape chosen
// below (256 x 128), so every block tile is completely filled.
int M = 32768;
int N = 16384;
auto tensor_shape = make_shape(M, N);
// Host buffers: the source holds the element index converted to Element,
// the destination starts out value-initialised.
thrust::host_vector<Element> h_S(size(tensor_shape));
thrust::host_vector<Element> h_D(size(tensor_shape));
for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(i);
h_D[i] = Element{};
}
// Device copies of both buffers.
thrust::device_vector<Element> d_S = h_S;
thrust::device_vector<Element> d_D = h_D;
// Wrap the raw device pointers in CuTe tensors of shape (M, N).
Tensor tensor_S =
make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())),
make_layout(tensor_shape));
Tensor tensor_D =
make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())),
make_layout(tensor_shape));
// Static block tile shape; one thread block copies one such tile.
// NOTE(review): there is no divisibility check here and copy_kernel does no
// bounds checking, so M and N must stay multiples of the block shape.
auto block_shape = make_shape(Int<256>{}, Int<128>{});
Tensor tiled_tensor_S =
tiled_divide(tensor_S, block_shape); // ((BlkM, BlkN), m', n')
Tensor tiled_tensor_D =
tiled_divide(tensor_D, block_shape); // ((BlkM, BlkN), m', n')
// Thread arrangement
Layout thr_layout =
make_layout(make_shape(Int<32>{}, Int<8>{})); // (32,8) -> thr_idx
// One thread block per tile: grid x/y follow the tile counts m' and n'.
dim3 gridDim(
size<1>(tiled_tensor_D),
size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n'
// One thread per entry of thr_layout (32 * 8 = 256 threads per block).
dim3 blockDim(size(thr_layout));
copy_kernel<<<gridDim, blockDim>>>(tiled_tensor_S, tiled_tensor_D,
thr_layout);
// Wait for the device and report any error returned by the runtime.
cudaError result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result)
<< std::endl;
return -1;
}
// Bring the destination buffer back to the host for verification.
h_D = d_D;
// Element-wise comparison; give up after kErrorLimit reported mismatches.
int32_t errors = 0;
int32_t const kErrorLimit = 10;
for (size_t i = 0; i < h_D.size(); ++i) {
if (h_S[i] != h_D[i]) {
std::cerr << "Error. S[" << i << "]: " << h_S[i] << ", D[" << i
<< "]: " << h_D[i] << std::endl;
if (++errors >= kErrorLimit) {
std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl;
return -1;
}
}
}
std::cout << "Success." << std::endl;
return 0;
}
This example is adapted from the CuTe tutorial in the CUTLASS repo. Below we explain, step by step, what the main function does before it launches the kernel.
Tensor tensor_S =
make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())),
make_layout(tensor_shape));
Tensor tensor_D =
make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())),
make_layout(tensor_shape));
This simply initialises the source and destination tensors.
auto block_shape = make_shape(Int<256>{}, Int<128>{});
if ((size<0>(tensor_shape) % size<0>(block_shape)) ||
(size<1>(tensor_shape) % size<1>(block_shape))) {
std::cerr << "The tensor shape must be divisible by the block shape."
<< std::endl;
return -1;
}
Tensor tiled_tensor_S =
tiled_divide(tensor_S, block_shape);
Tensor tiled_tensor_D =
tiled_divide(tensor_D, block_shape);
Here we simply tile the tensors. This will transform (M,N) -> ((blkM, blkN), ceil(M/blkM), ceil(N/blkN))
, i.e. we tile our initial matrix into smaller matrices with shape (blkM, blkN)
. The last two modes of the shape correspond to the number of blocks we create in the x
and y
dimensions of the grid.
For the example above we will have
(32768, 16384) -> ((256, 128), 128, 128)
Layout thr_layout =
make_layout(make_shape(Int<32>{}, Int<8>{}));
dim3 gridDim(
size<1>(tiled_tensor_D),
size<2>(tiled_tensor_D));
dim3 blockDim(size(thr_layout));
copy_kernel<<<gridDim, blockDim>>>(tiled_tensor_S, tiled_tensor_D,
thr_layout);
Here we make a thread layout.
This will further tile our block tiles onto the threads inside the kernel as we see below.
(256, 128) -> (256/32, 128/8) = (8, 16)
We then launch the kernel with the number of blocks given by the tiled layout (i.e. 128
blocks in x and 128
blocks in y direction) and the number of threads we need for the thread tiling (i.e. 32 * 8 = 256
) per block.
The kernel without predication looks like this:
/// Tiled copy kernel without bounds checking.
///
/// Each thread block handles the tile selected by (blockIdx.x, blockIdx.y);
/// each thread copies its own partition of that tile from S to D through a
/// temporary fragment.
///
/// @param S  Source tensor, tiled as ((BlockShape_M, BlockShape_N), m', n').
/// @param D  Destination tensor, tiled the same way as S.
/// The third, unnamed parameter carries the thread layout by type only; it is
/// re-created inside the kernel as ThreadLayout{}.
///
/// No predicate is applied, so every element of every tile is read and
/// written; the tensor extents must be multiples of the block tile shape.
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout) {
using namespace cute;
// Select this thread block's tile of the source and the destination.
Tensor tile_S = S(make_coord(_, _), blockIdx.x,
blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_, _), blockIdx.x,
blockIdx.y); // (BlockShape_M, BlockShape_N)
// Split the block tile across the threads; threadIdx.x picks this thread's part.
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{},
threadIdx.x); // (ThrValM, ThrValN)
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{},
threadIdx.x); // (ThrValM, ThrValN)
// Staging tensor shaped like this thread's partition.
Tensor fragment = make_tensor_like(thr_tile_S); // (ThrValM, ThrValN)
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(thr_tile_S, fragment);
copy(fragment, thr_tile_D);
}
It simply takes the block's tile of the matrix, then creates a local partition of it as described above. Each thread then copies its partition of the tile from GMEM -> RMEM -> GMEM
.
This process is highly efficient and achieves a bandwidth of ~3 TB/s
on H100
. We could further increase this by tuning the different tile sizes but this is not our focus in this blogpost.
Why do we need predication?
Let's imagine we want to process a matrix of dimension (M, N) = (32768 + 1, 16384 + 1)
, the above kernel will not work then.
Why? Because our tiling will result in a layout of (32768 + 1, 16384 + 1) -> ((256, 128), 129, 129)
. The problem here is that in the last blocks we will attempt to copy data that lies outside of the matrix. That is because the last block in x direction contains only one valid row, and similarly the last block in y direction contains only one valid column. We don't want to copy the whole thread tile for these blocks!
You can try this yourself by running the program with adjusted M
and N
and you will get an error like:
CUDA Runtime error: an illegal memory access was encountered
terminate called after throwing an instance of 'thrust::THRUST_200700_900_NS::system::system_error'
what(): CUDA free failed: cudaErrorIllegalAddress: an illegal memory access was encountered
Aborted (core dumped)
This should in fact not be surprising for anyone who has worked with CUDA kernels. We quite frequently need to do proper boundary checking.
Predication with CuTe
In this part of the blogpost we will give the kernel that solves the above problem.
/// Tiled copy kernel with per-element bounds checking (predication).
///
/// Works like the unpredicated copy kernel, but each thread first builds a
/// boolean mask over its partition and then copies only the elements whose
/// global coordinate lies inside the M x N matrix.
///
/// @param S  Source tensor, tiled as ((BlockShape_M, BlockShape_N), m', n').
/// @param D  Destination tensor, tiled the same way as S.
/// @param M  Number of valid rows; used as an exclusive upper bound.
/// @param N  Number of valid columns; used as an exclusive upper bound.
/// The third, unnamed parameter carries the thread layout by type only; it is
/// re-created inside the kernel as ThreadLayout{}.
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel_predicate(TensorS S, TensorD D, ThreadLayout, int M,
int N) {
using namespace cute;
// Select this thread block's tile of the source and the destination.
Tensor tile_S = S(make_coord(_, _), blockIdx.x,
blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_, _), blockIdx.x,
blockIdx.y); // (BlockShape_M, BlockShape_N)
// Split the block tile across the threads; threadIdx.x picks this thread's part.
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{},
threadIdx.x); // (ThrValM, ThrValN)
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{},
threadIdx.x); // (ThrValM, ThrValN)
// Coordinate tensor over the block tile, partitioned exactly like the data,
// so entry (i, j) gives the in-tile coordinate of this thread's element (i, j).
auto identity_tensor = make_identity_tensor(make_shape(
size<0>(tile_S), size<1>(tile_S))); // (BlockShape_M, BlockShape_N)
auto thread_identity_tensor = local_partition(
identity_tensor, ThreadLayout{}, threadIdx.x); // (ThrValM, ThrValN)
// Staging tensor shaped like this thread's partition.
Tensor fragment = make_tensor_like(thr_tile_S); // (ThrValM, ThrValN)
// One bool per element of the fragment: true if the element may be copied.
auto predicator = make_tensor<bool>(
make_shape(size<0>(fragment), size<1>(fragment))); // (ThrValM, ThrValN)
CUTE_UNROLL
for (int i = 0; i < size<0>(predicator); ++i) {
CUTE_UNROLL
for (int j = 0; j < size<1>(predicator); ++j) {
auto thread_identity = thread_identity_tensor(i, j);
// Global coordinate = tile origin (block index * tile extent) + in-tile offset.
int global_row = blockIdx.x * size<0>(tile_S) + get<0>(thread_identity);
int global_col = blockIdx.y * size<1>(tile_S) + get<1>(thread_identity);
// In bounds only if both coordinates are strictly below the matrix extents.
predicator(i, j) = (global_row < M) && (global_col < N);
}
}
// Copy from GMEM to RMEM and from RMEM to GMEM with predicate
copy_if(predicator, thr_tile_S, fragment);
copy_if(predicator, fragment, thr_tile_D);
}
We see that the kernel is very similar to the version without predication. The logic for predication was adapted from Lei Mao, who used a similar technique for the task of matrix transpose. Check out his blogposts, they are very nice to read! Now we will explain what we needed to change in our kernel to make it work for matrices which are not exactly divisible by the block tile dimensions.
auto identity_tensor = make_identity_tensor(make_shape(
size<0>(tile_S), size<1>(tile_S))); // (BlockShape_M, BlockShape_N)
auto thread_identity_tensor = local_partition(
identity_tensor, ThreadLayout{}, threadIdx.x); // (ThrValM, ThrValN)
We create an identity tensor with exactly the same tiling as the tensors we want to copy. The identity tensor will simply map (x,y)->(x,y)
.
auto predicator = make_tensor<bool>(
make_shape(size<0>(fragment), size<1>(fragment))); // (ThrValM, ThrValN)
We initialise a predicator matrix. This matrix will be 1 for all tuples (x,y)
which are within the bounds [0, M) x [0, N)
, i.e. it will be a simple indicator variable which is 1 when the element in question lies within our matrix.
CUTE_UNROLL
for (int i = 0; i < size<0>(predicator); ++i) {
CUTE_UNROLL
for (int j = 0; j < size<1>(predicator); ++j) {
auto thread_identity = thread_identity_tensor(i, j);
int global_row = blockIdx.x * size<0>(tile_S) + get<0>(thread_identity);
int global_col = blockIdx.y * size<1>(tile_S) + get<1>(thread_identity);
predicator(i, j) = (global_row < M) && (global_col < N);
}
}
We go over all elements of the thread tile. We calculate the corresponding global row and col for each pair (i, j)
. We can simply do that by multiplying blockIdx
with the corresponding block tile length in that dimension and adding the corresponding offset due to the thread tiling.
copy_if(predicator, thr_tile_S, fragment);
copy_if(predicator, fragment, thr_tile_D);
This will simply copy the entries that are within the bounds of our matrix.
You can convince yourself that the kernel will happily copy over the matrix with shape (M, N) = (32768 + 1, 16384 + 1)
without complaining and give the correct result.
The performance is on par with the copy kernel from above for matrix shapes that are evenly divisible by the block tile dimensions. I guess that is because the compiler recognises that it can optimise out the predication. For matrices that are not evenly divisible, the kernel will perform slightly worse due to warp divergence.
I hope this blogpost helped you to understand predication with CuTe better.