simons blog

Making vector sum really fast

In this blog post we briefly describe how to achieve state-of-the-art (SOTA) performance for reduction on a vector, i.e. our program should do the following: given a vector v, return the sum of all elements in v. We assume that the vector is large, i.e. it contains N = 1 << 30 = 2^30 entries.
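For reference, the whole task fits in a few lines of sequential C++. The sketch below is our own CPU reference, not code from the repo (the name reference_sum is hypothetical). It accumulates in long long so that the reference itself cannot overflow, and it records that N = 2^30 four-byte ints amount to 4 GiB of input that every kernel has to stream from global memory.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Problem size used throughout the post: N = 1 << 30 elements.
constexpr std::size_t kN = std::size_t{1} << 30;

// Bytes of input for N ints (4 GiB when sizeof(int) == 4, as on CUDA GPUs).
constexpr std::size_t kInputBytes = kN * sizeof(int);

// Sequential CPU reference: the value every GPU kernel must reproduce.
// Accumulates in long long so the reference cannot overflow on test data.
long long reference_sum(const std::vector<int>& v) {
  return std::accumulate(v.begin(), v.end(), 0LL);
}
```

A GPU result for a small test vector can be compared against reference_sum before moving on to the large benchmark.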

Baseline

template <unsigned int threadsPerBlock>
__global__ void kernel_0(const int *d_in, int *d_out, size_t N) {
  __shared__ int sums[threadsPerBlock];  // statically sized shared memory
  int sum = 0;
  const int tid = threadIdx.x;
  const int global_tid = blockIdx.x * threadsPerBlock + tid;
  const int threads_in_grid = threadsPerBlock * gridDim.x;

  for (int i = global_tid; i < N; i += threads_in_grid) {
    sum += d_in[i];
  }
  sums[tid] = sum;
  __syncthreads();

  for (int activeThreads = threadsPerBlock >> 1; activeThreads;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      sums[tid] += sums[tid + activeThreads];
    }
    __syncthreads();
  }

  if (tid == 0) {
    d_out[blockIdx.x] = sums[tid];
  }
}

template <int threadsPerBlock>
void kernel_0_launch(const int *d_in, int *d_first, int *d_out, size_t N) {
  const int numBlocks = (N + threadsPerBlock - 1) / threadsPerBlock;
  kernel_0<threadsPerBlock><<<numBlocks, threadsPerBlock>>>(
      d_in, d_first, N);
  kernel_0<threadsPerBlock><<<1, threadsPerBlock>>>(
      d_first, d_out, numBlocks);
}

Our baseline is a simple two-pass approach: we launch kernel_0 twice. The first launch reduces each block's part of the vector to one partial sum per block, and the second launch sums up these partial sums. The algorithm works as follows:

  1. Each thread adds up its vector entries and stores the result in shared memory. In the first pass this is exactly one entry, since we launch one thread per element; the grid-stride loop handles the second pass, where a single block has to cover all partial sums.
  2. We then divide the block of threads into two halves. Thread i in the left half adds the value of thread i in the right half to its own.
  3. We synchronise with __syncthreads(), i.e. we wait until this step has finished for all threads, so that the left half now contains the accumulated results.
  4. We ignore the right half and repeat the procedure on the left half until only the very first thread is left. It then holds the sum of the whole block. Finally, we treat the per-block sums as another vector and reduce it in a single block to obtain the total sum.
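To make these steps concrete, here is a CPU model of kernel_0 in plain C++. It is a sketch under stated assumptions, not the author's code: the helper names (tree_reduce_block, kernel0_model, two_pass_sum) are hypothetical, the inner loops stand in for threads that run in parallel on the GPU, and threadsPerBlock is assumed to be a power of two.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU model of the shared-memory tree reduction inside one block of kernel_0.
// `sums` stands in for the __shared__ array; each pass of the outer loop is one
// step between two __syncthreads() calls, and the inner loop plays the role of
// the threads tid < activeThreads running in parallel.
// Assumes sums.size() (threadsPerBlock) is a power of two.
int tree_reduce_block(std::vector<int> sums) {
  for (std::size_t active = sums.size() / 2; active > 0; active /= 2) {
    for (std::size_t tid = 0; tid < active; ++tid) {
      sums[tid] += sums[tid + active];
    }
  }
  return sums.empty() ? 0 : sums[0];
}

// Model of one launch of kernel_0 with `numBlocks` blocks of `T` threads:
// every thread first runs the grid-stride loop, then each block reduces its
// shared-memory values and writes one partial sum.
std::vector<int> kernel0_model(const std::vector<int>& in, std::size_t T,
                               std::size_t numBlocks) {
  const std::size_t threadsInGrid = T * numBlocks;
  std::vector<int> out(numBlocks, 0);
  for (std::size_t block = 0; block < numBlocks; ++block) {
    std::vector<int> sums(T, 0);
    for (std::size_t tid = 0; tid < T; ++tid) {
      for (std::size_t i = block * T + tid; i < in.size(); i += threadsInGrid) {
        sums[tid] += in[i];
      }
    }
    out[block] = tree_reduce_block(sums);
  }
  return out;
}

// Two-pass scheme of kernel_0_launch: partial sums per block, then one block.
int two_pass_sum(const std::vector<int>& in, std::size_t T) {
  const std::size_t numBlocks = (in.size() + T - 1) / T;
  return kernel0_model(kernel0_model(in, T, numBlocks), T, 1)[0];
}
```

For example, two_pass_sum on the values 1..1000 with T = 64 first produces 16 partial sums and then reduces them to 500500 in a single block.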

The baseline achieves a bandwidth of 639.103 GB/s, corresponding to 19.3668% of the possible bandwidth of an H100 GPU.
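The utilisation figures in this post are consistent with a peak of 3300 GB/s (639.103 / 0.193668 ≈ 3300). That peak, and the convention that GB means 10^9 bytes and only the input read is counted, are our assumptions rather than something taken from the benchmark code. A minimal sketch of the arithmetic:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Assumed peak bandwidth, inferred from the percentages quoted in this post.
constexpr double kAssumedPeakGBs = 3300.0;

// Effective bandwidth in GB/s (1 GB = 1e9 bytes). Assumption: only the input
// read is counted; the few bytes of output are negligible.
double bandwidth_gb_per_s(std::size_t n_elements, std::size_t bytes_per_element,
                          double seconds) {
  return static_cast<double>(n_elements * bytes_per_element) / seconds / 1e9;
}

// Percentage of the assumed peak.
double utilisation_percent(double gb_per_s) {
  return 100.0 * gb_per_s / kAssumedPeakGBs;
}
```

With this assumed peak, utilisation_percent(639.103) gives about 19.3668, matching the baseline figure above.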

Using warps

On a GPU, one warp consists of 32 threads. Once 32 or fewer threads remain active in the reduction, all of them belong to the first warp of the block, so we can handle that part separately and use the cheaper __syncwarp() instead of __syncthreads() for synchronisation. The volatile pointer forces every access to go through shared memory instead of being kept in a register. Note that in the first version of this blog post I assumed that all threads in a warp execute in lockstep. That turns out to be wrong: although we get the correct result most of the time, it can lead to race conditions, which can be detected by running compute-sanitizer --tool racecheck on the compiled program. Luckily, using __syncwarp() costs only about 1 GB/s of bandwidth. Thanks to Pauleonix for pointing that out!

template <unsigned int threadsPerBlock>
__global__ void kernel_1(const int *d_in, int *d_out, size_t N) {
  __shared__ int sums[threadsPerBlock];  // statically sized shared memory
  int sum = 0;
  const int tid = threadIdx.x;
  const int global_tid = blockIdx.x * threadsPerBlock + tid;
  const int threads_in_grid = threadsPerBlock * gridDim.x;

  for (int i = global_tid; i < N; i += threads_in_grid) {
    sum += d_in[i];
  }
  sums[tid] = sum;
  __syncthreads();

#pragma unroll
  for (int activeThreads = threadsPerBlock >> 1; activeThreads > 32;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      sums[tid] += sums[tid + activeThreads];
    }
    __syncthreads();
  }

  volatile int *volatile_sums = sums;
#pragma unroll
  for (int activeThreads = 32; activeThreads; activeThreads >>= 1) {
    if (tid < activeThreads) {
      volatile_sums[tid] += volatile_sums[tid + activeThreads];
    }
    __syncwarp();
  }

  if (tid == 0) {
    d_out[blockIdx.x] = volatile_sums[tid];
  }
}

template <int threadsPerBlock>
void kernel_1_launch(const int *d_in, int *d_first, int *d_out, size_t N) {
  const int numBlocks = (N + threadsPerBlock - 1) / threadsPerBlock;
  kernel_1<threadsPerBlock><<<numBlocks, threadsPerBlock>>>(d_in, d_first, N);
  kernel_1<threadsPerBlock><<<1, threadsPerBlock>>>(d_first, d_out, numBlocks);
}

This gives a small boost, to 661.203 GB/s, corresponding to 20.0365% utilisation.
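Why is __syncwarp() needed at all? The C++ sketch below (CPU-only, hypothetical helper names) contrasts two orders in which warp 0 could execute the last six steps on 64 partial sums. With a synchronisation after each step, every lane finishes step s before any lane starts step s/2. Without it, the programming model does not rule out lane 0 running ahead and reading sums[16], sums[8], ... before the lanes that own those entries have updated them.

```cpp
#include <array>
#include <cassert>

// The 64 partial sums left in sums[0..63] after the block-level loop.
using Partials = std::array<int, 64>;

// Ordering enforced by __syncwarp() after each step: all active lanes finish
// step s before any lane starts step s / 2.
int warp_stage_synced(Partials sums) {
  for (int s = 32; s > 0; s >>= 1) {
    for (int lane = 0; lane < s; ++lane) {
      sums[lane] += sums[lane + s];
    }
  }
  return sums[0];
}

// One unsynchronised interleaving: lane 0 executes all of its steps before the
// other lanes run, so it reads sums[32], sums[16], ..., sums[1] before lanes
// 16, 8, ... have added their partners into them.
int warp_stage_lane0_runs_ahead(Partials sums) {
  for (int s = 32; s > 0; s >>= 1) {
    sums[0] += sums[s];  // lane 0 alone
  }
  return sums[0];  // later writes by other lanes never reach sums[0]
}
```

With all 64 partials set to 1, the synchronised order returns 64, while the run-ahead order returns only 7.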

One Pass

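CUDA's atomicAdd lets threads from different blocks safely add to the same memory location. We can use it to implement a simple one-pass kernel: each block reduces its elements to a partial sum in shared memory, and thread 0 of every block atomically adds that partial sum to the single output value.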

template <unsigned int threadsPerBlock>
__global__ void kernel_2(const int *d_in, int *d_out, size_t N) {
  __shared__ int sums[threadsPerBlock];  // statically sized shared memory
  int sum = 0;
  const int tid = threadIdx.x;
  const int global_tid = blockIdx.x * threadsPerBlock + tid;

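  // d_out is zeroed by cudaMemset in kernel_2_launch before the launch.
  // Zeroing it here as well could race with atomicAdds from other blocks.
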
  if (global_tid < N) {
    sum += d_in[global_tid];
  }
  sums[tid] = sum;
  __syncthreads();

#pragma unroll
  for (int activeThreads = threadsPerBlock >> 1; activeThreads > 32;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      sums[tid] += sums[tid + activeThreads];
    }
    __syncthreads();
  }

  volatile int *volatile_sums = sums;
#pragma unroll
  for (int activeThreads = 32; activeThreads; activeThreads >>= 1) {
    if (tid < activeThreads) {
      volatile_sums[tid] += volatile_sums[tid + activeThreads];
    }
    __syncwarp();
  }

  if (tid == 0) {
    atomicAdd(d_out, volatile_sums[tid]);
  }
}

template <int threadsPerBlock>
void kernel_2_launch(const int *d_in, int *d_out, size_t N) {
  const int numBlocks = (N + threadsPerBlock - 1) / threadsPerBlock;
  cudaMemset(d_out, 0, sizeof(int));
  kernel_2<threadsPerBlock><<<numBlocks, threadsPerBlock>>>(d_in, d_out, N);
}

This boosts our performance to 882.823 GB/s corresponding to 26.7522% utilisation.

Please note that an earlier version of the kernel initialised d_out inside the kernel (in the thread with global_tid == 0). As pointed out by Learninmou, this could lead to a race condition, because other blocks may already have added their partial sums before d_out is zeroed. Luckily, the fix has a very small effect on the bandwidth, and even with the proper cudaMemset we remain ahead of the NVIDIA library. I have updated the numbers accordingly.
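A rough CPU analogue of the one-pass scheme is sketched below, using std::thread and std::atomic<int>::fetch_add in place of CUDA blocks and atomicAdd. It is an illustration only (the function name and the chunking are ours): each worker reduces its chunk locally and performs exactly one atomic add, and the accumulator is initialised before any worker starts, which mirrors calling cudaMemset before the launch.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <thread>
#include <vector>

// Each worker stands in for one CUDA block: it sums its chunk into a local
// variable and then performs a single atomic add on the shared total.
// `chunk` must be greater than zero.
int one_pass_sum(const std::vector<int>& in, std::size_t chunk) {
  assert(chunk > 0);
  std::atomic<int> total{0};  // zeroed before any worker starts (like cudaMemset)
  std::vector<std::thread> workers;
  for (std::size_t begin = 0; begin < in.size(); begin += chunk) {
    workers.emplace_back([&in, &total, begin, chunk] {
      const std::size_t end = std::min(begin + chunk, in.size());
      int local = 0;  // this "block's" partial sum
      for (std::size_t i = begin; i < end; ++i) local += in[i];
      total.fetch_add(local, std::memory_order_relaxed);  // one atomic per worker
    });
  }
  for (auto& w : workers) w.join();
  return total.load();
}
```

Moving the initialisation of total into one of the workers would reintroduce the race described above: that worker could reset the sum after others had already added to it.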

Increasing arithmetic intensity

Looking at kernel_2, each thread loads just one entry of the vector and writes it into shared memory, so the cost of the shared-memory reduction and of the atomicAdd is spread over very few loads. We can do better by letting each thread first sum a batch of consecutive elements in a register.

template <unsigned int threadsPerBlock, unsigned int batchSize>
__global__ void kernel_3(const int *d_in, int *d_out, size_t N) {
  __shared__ int sums[threadsPerBlock];  // statically sized shared memory
  int sum = 0;
  const int tid = threadIdx.x;  
  const int global_tid = blockIdx.x * threadsPerBlock + tid;
  const int threads_in_grid = threadsPerBlock * gridDim.x;

  if (global_tid < N) {
#pragma unroll
    for (int j = 0; j < batchSize; j++) {
      if (global_tid * batchSize + j < N) {
        sum += d_in[global_tid * batchSize + j];
      }
    }
  }
  sums[tid] = sum;
  __syncthreads();

#pragma unroll
  for (int activeThreads = threadsPerBlock >> 1; activeThreads > 32;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      sums[tid] += sums[tid + activeThreads];
    }
    __syncthreads();
  }

  volatile int *volatile_sums = sums;
#pragma unroll
  for (int activeThreads = 32; activeThreads; activeThreads >>= 1) {
    if (tid < activeThreads) {
      volatile_sums[tid] += volatile_sums[tid + activeThreads];
    }
    __syncwarp();
  }

  if (tid == 0) {
    atomicAdd(d_out, volatile_sums[tid]);
  }
}

template <int threadsPerBlock, int batchSize>
void kernel_3_launch(const int *d_in, int *d_out, size_t N) {
  const int numBlocks = (N + threadsPerBlock * batchSize - 1) /
                        (threadsPerBlock * batchSize);
  cudaMemset(d_out, 0, sizeof(int));
  kernel_3<threadsPerBlock, batchSize><<<numBlocks, threadsPerBlock>>>(d_in,
                                                                       d_out, N);
}

As we can see, we now launch fewer blocks, because each thread processes batchSize elements. This increases the work done by each thread and gives a huge boost in performance! With this approach we reach 3226.86 GB/s, which is very close to the physical maximum, at 97.7838% utilisation.
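The index arithmetic behind batching can be checked on the CPU. The sketch below uses hypothetical helper names, and the values threadsPerBlock = 256 and batchSize = 16 in the checks are example values, not necessarily the ones used for the measurements. It reproduces the block count from kernel_3_launch and verifies that every element is read exactly once when thread g handles indices g * batchSize to g * batchSize + batchSize - 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Number of blocks kernel_3_launch requests: ceil(N / (threadsPerBlock * batchSize)).
std::size_t num_blocks_batched(std::size_t n, std::size_t threadsPerBlock,
                               std::size_t batchSize) {
  return (n + threadsPerBlock * batchSize - 1) / (threadsPerBlock * batchSize);
}

// Counts how often each input index is read when thread g reads the contiguous
// range [g * batchSize, g * batchSize + batchSize), using the same bounds check
// as kernel_3. Every entry of the result should be exactly 1.
std::vector<int> coverage_batched(std::size_t n, std::size_t threadsPerBlock,
                                  std::size_t batchSize) {
  std::vector<int> hits(n, 0);
  const std::size_t blocks = num_blocks_batched(n, threadsPerBlock, batchSize);
  for (std::size_t g = 0; g < blocks * threadsPerBlock; ++g) {
    for (std::size_t j = 0; j < batchSize; ++j) {
      const std::size_t idx = g * batchSize + j;
      if (idx < n) ++hits[idx];
    }
  }
  return hits;
}
```

For N = 2^30 with these example values, batching shrinks the grid from 4,194,304 blocks to 262,144, i.e. by a factor of batchSize.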

Vectorise load

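CUDA offers the built-in vector type int4, which packs four ints into 16 bytes. Reading through an int4 pointer lets each thread fetch four elements with a single 128-bit load instead of four 32-bit loads. The reinterpret_cast in the launcher relies on d_in being 16-byte aligned, which holds for pointers returned by cudaMalloc, and batchSize must be a multiple of 4.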

template <unsigned int threadsPerBlock, unsigned int batchSize>
__global__ void kernel_4(const int4 *d_in, int *d_out, size_t N) {
  __shared__ int sums[threadsPerBlock];  // statically sized shared memory
  int sum = 0;
  const int tid = threadIdx.x;  
  const int global_tid = blockIdx.x * threadsPerBlock + tid;
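
  static_assert(batchSize % 4 == 0, "batchSize must be a multiple of 4");
  // Each thread reads batchSize / 4 int4 values (batchSize ints). Assumes N is
  // a multiple of 4, so every int4 lies completely inside the vector.
#pragma unroll
  for (int i = 0; i < batchSize / 4; i++) {
    // Check the bounds before loading so we never read past the end of d_in.
    if (global_tid * batchSize + i * 4 < N) {
      const int4 val = d_in[global_tid * (batchSize / 4) + i];
      sum += val.x + val.y + val.z + val.w;
    }
  }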
  sums[tid] = sum;
  __syncthreads();

#pragma unroll
  for (int activeThreads = threadsPerBlock >> 1; activeThreads > 32;
       activeThreads >>= 1) {
    if (tid < activeThreads) {
      sums[tid] += sums[tid + activeThreads];
    }
    __syncthreads();
  }

  volatile int *volatile_sums = sums;
#pragma unroll
  for (int activeThreads = 32; activeThreads; activeThreads >>= 1) {
    if (tid < activeThreads) {
      volatile_sums[tid] += volatile_sums[tid + activeThreads];
    }
    __syncwarp();
  }

  if (tid == 0) {
    atomicAdd(d_out, volatile_sums[tid]);
  }
}

template <int threadsPerBlock, int batchSize>
void kernel_4_launch(const int *d_in, int *d_out, size_t N) {
  const int numBlocks = (N + threadsPerBlock * batchSize - 1) /
                        (threadsPerBlock * batchSize);
  const int4 *d_in_cast = reinterpret_cast<const int4 *>(d_in);
  cudaMemset(d_out, 0, sizeof(int));
  kernel_4<threadsPerBlock, batchSize><<<numBlocks, threadsPerBlock>>>(d_in_cast,
                                                                       d_out, N);
}

This gives a tiny improvement over the previous version, to 3229.98 GB/s, which corresponds to 97.8783% utilisation.
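The int4 version only changes how a batch is loaded: the i-th vector load of thread g starts at int offset g * batchSize + 4 * i. The host-side sketch below uses a stand-in struct Int4 (our own, mirroring CUDA's int4) and std::memcpy for the load to avoid aliasing issues in standard C++; it assumes batchSize % 4 == 0 and N % 4 == 0, as the kernel does. The function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <numeric>
#include <vector>

// Host-side stand-in for CUDA's int4: four ints in 16 bytes.
struct Int4 {
  int x, y, z, w;
};

// Sums thread g's batch the way kernel_4 does: batchSize / 4 vector loads, the
// i-th one at int4 index g * (batchSize / 4) + i, i.e. int offset
// g * batchSize + 4 * i. On the GPU each load is a single 16-byte access.
int batch_sum_vectorised(const std::vector<int>& in, std::size_t g,
                         std::size_t batchSize) {
  int sum = 0;
  for (std::size_t i = 0; i < batchSize / 4; ++i) {
    const std::size_t int4_index = g * (batchSize / 4) + i;  // in int4 units
    const std::size_t int_offset = int4_index * 4;           // in int units
    if (int_offset + 4 <= in.size()) {  // same as "< N" when N % 4 == 0
      Int4 val;
      std::memcpy(&val, in.data() + int_offset, sizeof(Int4));
      sum += val.x + val.y + val.z + val.w;
    }
  }
  return sum;
}
```

With 64 elements holding 0..63 and batchSize = 16, thread 1 covers elements 16..31 and returns 376; threads 0 to 3 together cover the whole vector.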

Benchmark nvidia library

We can benchmark NVIDIA's own implementation of the same operation, cub::DeviceReduce::Sum from the CUB library, as follows:

#include <cassert>
#include <cub/cub.cuh>

void kernel_5_launch(const int *d_in, int *d_out, size_t N) {
  void* d_temp = nullptr;
  size_t temp_storage = 0;

  // First call (d_temp == nullptr) only determines the temporary storage size
  cub::DeviceReduce::Sum(d_temp, temp_storage, d_in, d_out, N);

  // Allocate temporary storage
  assert(temp_storage > 0);
  cudaMalloc(&d_temp, temp_storage);

  // Second call performs the reduction
  cub::DeviceReduce::Sum(d_temp, temp_storage, d_in, d_out, N);

  // Release the temporary storage
  cudaFree(d_temp);
}

This gives us 3191.42 GB/s and 96.7097% utilisation. That means our best kernel (kernel_4, 3229.98 GB/s) outperformed NVIDIA's implementation by about 1.2% for the problem size (N = 1 << 30) and hardware (H100) we chose.
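The two calls to cub::DeviceReduce::Sum follow CUB's usual convention: with a null temporary-storage pointer the call only reports how many bytes it needs, and the second call performs the reduction. The sketch below imitates that convention with a hypothetical CPU function sum_two_phase; how it uses the temporary buffer (per-chunk partial sums with a chunk size of 256) is our own choice and says nothing about CUB's internals.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical function with a CUB-style calling convention: if temp is null,
// only temp_bytes is set; otherwise the sum of in[0..n) is written to *out.
void sum_two_phase(void* temp, std::size_t& temp_bytes, const int* in, int* out,
                   std::size_t n) {
  constexpr std::size_t kChunk = 256;  // arbitrary chunk size for this sketch
  const std::size_t num_chunks = (n + kChunk - 1) / kChunk;
  if (temp == nullptr) {
    // Size query: one int per chunk, at least one int.
    temp_bytes = (num_chunks > 0 ? num_chunks : 1) * sizeof(int);
    return;
  }
  assert(temp_bytes >= num_chunks * sizeof(int));
  int* partials = static_cast<int*>(temp);
  for (std::size_t c = 0; c < num_chunks; ++c) {
    partials[c] = 0;
    for (std::size_t i = c * kChunk; i < n && i < (c + 1) * kChunk; ++i) {
      partials[c] += in[i];
    }
  }
  int total = 0;
  for (std::size_t c = 0; c < num_chunks; ++c) total += partials[c];
  *out = total;
}

// Caller side, following the same steps as kernel_5_launch.
int sum_via_two_phase(const std::vector<int>& in) {
  std::size_t temp_bytes = 0;
  sum_two_phase(nullptr, temp_bytes, in.data(), nullptr, in.size());  // 1) query
  std::vector<int> temp((temp_bytes + sizeof(int) - 1) / sizeof(int));  // 2) allocate
  int result = 0;
  sum_two_phase(temp.data(), temp_bytes, in.data(), &result, in.size());  // 3) run
  return result;
}
```

Splitting the work this way leaves the allocation to the caller, which is why kernel_5_launch calls cudaMalloc between the two calls.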

References

This blog post was inspired by the discussion on reduction in The CUDA Handbook. The idea for batching comes from the fast.cu repository, as does the code for benchmarking the CUB library. Some of the approaches taken there should improve the performance of our kernel even further, but I chose to stop at a point where the code is still easy to follow for a beginner. I highly recommend checking out this repo and its author's blog post for valuable insights into writing performant CUDA kernels.

You can reproduce the experiments and find my code in this repo. I ran the experiments on an H100 using a Docker image for CUDA 12.8.

Update: I integrated two new kernels using warp-level intrinsics. For the code and benchmarks, check out my repo.