simons blog

CUDA streams

Introduction

In CUDA, a stream is a sequence of commands that execute in order. Multiple streams can be used to overlap different tasks: within one stream the tasks run in order, while tasks in different streams can potentially overlap if resources allow it. Here I will explain how to split a task across multiple streams. Each stream performs the following tasks: transfer from host to device -> kernel -> transfer from device to host, and the work of the different streams is overlapped.

Ordinary computation

In our example we will use a "dummy kernel":

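// Dummy kernel: b[i] = a[i] + sqrt(sin^2(x) + cos^2(x)) with x = i, i.e. b = a + 1.
// `offset` shifts the global index so that one launch can process one chunk of the arrays.
// There is no bounds check: the launch must cover exactly the intended elements.
__global__ void kernel(float *a, float *b, const int offset) {
  const auto i = offset + blockIdx.x * blockDim.x + threadIdx.x;
  const auto x = static_cast<float>(i);
  const auto s = sinf(x);
  const auto c = cosf(x);

  b[i] = a[i] + sqrtf(s * s + c * c);
}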

Readers familiar with elementary math might notice that, since sin^2(x) + cos^2(x) = 1, we essentially compute b = a + 1. This is only done to put some computation into the kernel. Note that we can provide an offset as a "starting point" for the index; its purpose will become clear later on.
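To make the role of offset concrete, here is a host-only C++ sketch (not GPU code) that emulates the kernel's indexing: a launch with a given offset covers one contiguous chunk, and processing all chunks fills the whole array with values close to a + 1. The helper names (kernelBody, emulateLaunch, chunkedMaxError), the initialization a = 1 and the chunk and thread counts are hypothetical choices for illustration; the C++ standard library's sin/cos/sqrt stand in for CUDA's sinf/cosf/sqrtf, so the last bits may differ from the GPU.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Host emulation of one thread of the dummy kernel for global index i.
void kernelBody(const float *a, float *b, int i) {
  const float x = static_cast<float>(i);
  const float s = std::sin(x);  // float overloads from <cmath>
  const float c = std::cos(x);
  b[i] = a[i] + std::sqrt(s * s + c * c);  // sin^2 + cos^2 = 1 up to rounding
}

// Emulate a launch of `blocks` blocks with `threads` threads each.
// Mirrors i = offset + blockIdx.x * blockDim.x + threadIdx.x.
void emulateLaunch(const float *a, float *b, int offset, int blocks, int threads) {
  for (int block = 0; block < blocks; ++block)
    for (int t = 0; t < threads; ++t)
      kernelBody(a, b, offset + block * threads + t);
}

// Process the array chunk by chunk (one chunk per hypothetical stream)
// and return the maximum deviation of b from a + 1.
float chunkedMaxError(int n, int numChunks, int threads) {
  assert(n % numChunks == 0 && (n / numChunks) % threads == 0);
  const int perChunk = n / numChunks;
  std::vector<float> a(n, 1.0f), b(n, 0.0f);  // unwritten b[i] would show up as error 2
  for (int chunk = 0; chunk < numChunks; ++chunk)
    emulateLaunch(a.data(), b.data(), chunk * perChunk, perChunk / threads, threads);
  float maxError = 0.0f;
  for (int i = 0; i < n; ++i)
    maxError = std::max(maxError, std::fabs(b[i] - (a[i] + 1.0f)));
  return maxError;
}
```

Calling chunkedMaxError(1 << 12, 4, 256) processes four chunks of 1024 elements and returns an error on the order of single-precision rounding, the same as a single launch with offset 0 over the whole array.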

Before starting the computation, we allocate pinned host memory and device memory:

CHECK_CUDA_ERROR( cudaMallocHost((void**) &a, bytes) ); 
CHECK_CUDA_ERROR( cudaMallocHost((void**) &b, bytes) ); 
CHECK_CUDA_ERROR( cudaMalloc(&d_a, bytes) );
CHECK_CUDA_ERROR( cudaMalloc(&d_b, bytes) );

cudaMalloc should be familiar. As we can learn from the CUDA Driver docs, cudaMallocHost "allocates size bytes of host memory that is page-locked and accessible to the device...Since the memory can be accessed directly by the device, it can be read or written with much higher bandwidth than pageable memory obtained with functions such as malloc()...Page-locking excessive amounts of memory with cudaMallocHost may degrade system performance, since it reduces the amount of memory available to the system for paging. As a result, this function is best used sparingly to allocate staging areas for data exchange between host and device." We need cudaMallocHost to see the effect of cudaMemcpyAsync, which we will use later on: a copy can only overlap with kernel execution if the host memory involved is page-locked.

We'll then benchmark the following sequence of tasks:

CHECK_CUDA_ERROR( cudaMemcpy(d_a, a, bytes, cudaMemcpyHostToDevice) );
kernel<<<blocksPerGrid, threadsPerBlock>>>(d_a, d_b, 0);
CHECK_CUDA_ERROR( cudaMemcpy(b, d_b, bytes, cudaMemcpyDeviceToHost) );

In NSYS this looks as follows:

[Figure: Nsight Systems timeline of the sequential version (Sync)]

I benchmarked 1000 iterations (after 100 warmup runs) on an array of 1 << 25 elements on an NVIDIA H100 80GB HBM3 and got the following result:

Time for sequential transfer and execution in ms = 6.56524
Max error = 1.19209e-07

Overlap

We can use cudaMemcpyAsync to split the data into chunks, process each chunk in its own stream, and overlap the workloads of different streams. The kernel makes clear that this is possible: each b[i] depends only on a[i] and i, so the chunks are independent.

To be able to do this, we first need to create the streams:

cudaStream_t stream[numStreams];
for(auto i = 0; i < numStreams; i++)
	CHECK_CUDA_ERROR( cudaStreamCreate(&stream[i]) );
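The overlap loop later in this post relies on elementsPerStream, bytesPerStream and blocksPerGridStream, whose definitions are not shown in the snippets. A minimal host-side sketch of how they could be derived, assuming the array length divides evenly across streams and blocks (the kernel has no bounds check); the struct and function names are hypothetical:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: the per-stream sizes used by the overlap loop.
struct StreamPartition {
  int elementsPerStream;       // elements handled by each stream
  std::size_t bytesPerStream;  // bytes copied per stream in each direction
  int blocksPerGridStream;     // blocks per kernel launch in each stream
};

// Splits n floats evenly over numStreams streams. The kernel has no bounds check,
// so we require exact divisibility instead of rounding the block count up.
StreamPartition makeStreamPartition(int n, int numStreams, int threadsPerBlock) {
  assert(numStreams > 0 && threadsPerBlock > 0);
  assert(n % numStreams == 0);
  const int elementsPerStream = n / numStreams;
  assert(elementsPerStream % threadsPerBlock == 0);
  return {elementsPerStream,
          static_cast<std::size_t>(elementsPerStream) * sizeof(float),
          elementsPerStream / threadsPerBlock};
}
```

For n = 1 << 25, four streams and an assumed 256 threads per block, this yields 8388608 elements, 33554432 bytes and 32768 blocks per stream.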

Note that when benchmarking, we can time the work of all streams with events recorded in the default stream:

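cudaEvent_t start, stop;
CHECK_CUDA_ERROR( cudaEventCreate(&start) );
CHECK_CUDA_ERROR( cudaEventCreate(&stop) );
CHECK_CUDA_ERROR( cudaEventRecord(start, 0) );
// ... tasks for all streams ...
CHECK_CUDA_ERROR( cudaEventRecord(stop, 0) );
CHECK_CUDA_ERROR( cudaEventSynchronize(stop) );
float elapsedTime;
CHECK_CUDA_ERROR( cudaEventElapsedTime(&elapsedTime, start, stop) );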

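This works because 0 denotes the default stream, and by default nvcc compiles with --default-stream legacy. With these legacy semantics, work in the default stream waits until all previously issued work in the other (blocking) streams has finished, so once the stop event has completed we can be sure that all work is done. With --default-stream per-thread, or with streams created with the cudaStreamNonBlocking flag, this implicit synchronization does not happen.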

We implement our overlapping strategy as follows:

for(auto s = 0; s < numStreams; s++) {
  auto offset = s * elementsPerStream;
  // copy this stream's chunk of a to the device
  CHECK_CUDA_ERROR( cudaMemcpyAsync(d_a + offset, a + offset,
                                    bytesPerStream, cudaMemcpyHostToDevice,
                                    stream[s]) );
  // process the chunk: the kernel gets the base pointers and adds the offset itself
  kernel<<<blocksPerGridStream, threadsPerBlock, 0, stream[s]>>>(d_a, d_b, offset);
  // copy this stream's chunk of b back to the host
  CHECK_CUDA_ERROR( cudaMemcpyAsync(b + offset, d_b + offset,
                                    bytesPerStream, cudaMemcpyDeviceToHost,
                                    stream[s]) );
}

Note that for the copies we shift the host and device pointers by the offset, while the kernel receives the unshifted pointers together with the offset. The kernel launch now uses the execution configuration <<<blocksPerGridStream, threadsPerBlock, 0, stream[s]>>>: the third argument is the amount of dynamic shared memory (none here) and the fourth selects the stream the kernel belongs to. We also divide the number of bytes to transfer and the number of blocks to launch by the number of streams. In NSYS this looks as follows for four streams:

[Figure: Nsight Systems timeline with four streams (Async)]

We nicely see the overlapping behavior: as soon as one Memcpy (Host -> Device or Device -> Host) finishes, the next one starts right away, while the kernel is still busy.

The benchmarking result is as follows:

Time for asynchronous 1 transfer and execution in ms = 4.86347
Max error = 1.19209e-07
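To put the timings into perspective, a small host-side C++ sketch can convert them into an effective transfer throughput. It assumes the reported times are averages per iteration, counts one host-to-device and one device-to-host copy of the full float array per iteration, and uses 1 GB = 1e9 bytes; the function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Effective host<->device throughput in GB/s (1 GB = 1e9 bytes), counting both the
// host-to-device and the device-to-host copy of an array of n floats.
// Assumes `milliseconds` is the average time of one iteration.
double effectiveGBps(std::size_t n, double milliseconds) {
  assert(milliseconds > 0.0);
  const double bytesMoved = 2.0 * static_cast<double>(n) * sizeof(float);
  return bytesMoved / (milliseconds * 1e-3) / 1e9;
}
```

With n = 1 << 25 (256 MiB moved per iteration in total), this gives about 40.9 GB/s for the sequential 6.56524 ms and about 55.2 GB/s for the overlapped 4.86347 ms.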

We see that this gives a sizeable speedup of about 1.35x (6.57 ms vs. 4.86 ms). I suspect this is because our kernel computation is relatively lightweight, so the runtime is dominated by the memory copies, which become more efficient when they are overlapped.
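A simple idealized model is consistent with this intuition. As a sketch under assumptions that are not measured here: let t_H, t_K and t_D be the times of the host-to-device copy, the kernel and the device-to-host copy for the whole array, split the data evenly into N chunks, let each of the three stages run on its own engine, and ignore launch and synchronization overheads. The chunks then flow through the stages like a pipeline:

$$
T_{\mathrm{seq}} = t_H + t_K + t_D, \qquad
T_{\mathrm{overlap}}(N) = \frac{t_H + t_K + t_D}{N} + \frac{N-1}{N}\,\max(t_H, t_K, t_D)
$$

As N grows, T_overlap approaches max(t_H, t_K, t_D), so in this model the speedup can never exceed T_seq / max(t_H, t_K, t_D). When the copies dominate (t_K much smaller than t_H ≈ t_D), this bound is only slightly above 2, and a lightweight kernel contributes little that could be hidden. For example, with t_H = t_D = 3, t_K = 1 (arbitrary units) and N = 4, the model gives T_seq = 7 and T_overlap = 7/4 + (3/4) * 3 = 4.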

Overlap II

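On devices that cannot run host-to-device and device-to-host transfers concurrently, we can rewrite the above logic to issue all host-to-device copies first, then all kernels, then all device-to-host copies: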

for(auto s = 0; s < numStreams; s++) {
  auto offset = s * elementsPerStream;
  CHECK_CUDA_ERROR( cudaMemcpyAsync(d_a + offset, a + offset, 
							bytesPerStream, cudaMemcpyHostToDevice, 
							stream[s]) );
}
for(auto s = 0; s < numStreams; s++) {
  auto offset = s * elementsPerStream;
  kernel<<<blocksPerGridStream, threadsPerBlock, 0, stream[s]>>>(d_a, d_b, offset);
}
for(auto s = 0; s < numStreams; s++) {
  auto offset = s * elementsPerStream;
  CHECK_CUDA_ERROR( cudaMemcpyAsync(b + offset, d_b + offset, 
							bytesPerStream, cudaMemcpyDeviceToHost, 
							stream[s]) );
}

to enable overlapping on such devices. On the H100 this makes no difference: the timing is the same as above.
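Why the issue order can matter is illustrated by a toy scheduling model. This is a hypothetical sketch, not a description of the actual hardware scheduler: each engine runs its commands strictly in issue order, and each command also waits for the previous command in its own stream. With a single copy engine, the first (depth-first) loop queues stream 0's device-to-host copy in front of stream 1's host-to-device copy, so the copy engine idles while kernel 0 runs; the breadth-first order avoids this. With separate copy engines for the two directions, both orders give the same time in this model. The names Engine, Op, simulate, depthFirst and breadthFirst are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy model (an assumption): every engine executes its commands strictly in issue
// order, and a command also waits for the previous command in its own stream.
enum Engine { CopyH2D = 0, Compute = 1, CopyD2H = 2 };

struct Op { int stream; Engine engine; double duration; };

// Simulate the ops in issue order and return the total time (makespan).
// With copyEngines == 1, H2D and D2H copies share a single copy queue.
double simulate(const std::vector<Op> &ops, int numStreams, int copyEngines) {
  assert(copyEngines == 1 || copyEngines == 2);
  double engineFree[3] = {0.0, 0.0, 0.0};
  std::vector<double> streamFree(numStreams, 0.0);
  double makespan = 0.0;
  for (const Op &op : ops) {
    int queue = op.engine;
    if (copyEngines == 1 && op.engine == CopyD2H) queue = CopyH2D;  // shared copy queue
    const double start = std::max(engineFree[queue], streamFree[op.stream]);
    const double end = start + op.duration;
    engineFree[queue] = end;
    streamFree[op.stream] = end;
    makespan = std::max(makespan, end);
  }
  return makespan;
}

// Issue order of the first loop: H2D, kernel, D2H for each stream in turn.
std::vector<Op> depthFirst(int n, double h, double k, double d) {
  std::vector<Op> ops;
  for (int s = 0; s < n; ++s) {
    ops.push_back({s, CopyH2D, h});
    ops.push_back({s, Compute, k});
    ops.push_back({s, CopyD2H, d});
  }
  return ops;
}

// Issue order of the rewritten version: all H2D copies, all kernels, all D2H copies.
std::vector<Op> breadthFirst(int n, double h, double k, double d) {
  std::vector<Op> ops;
  for (int s = 0; s < n; ++s) ops.push_back({s, CopyH2D, h});
  for (int s = 0; s < n; ++s) ops.push_back({s, Compute, k});
  for (int s = 0; s < n; ++s) ops.push_back({s, CopyD2H, d});
  return ops;
}
```

For example, with per-chunk times of 0.75, 0.25 and 0.75 (arbitrary units) and four streams, the model gives 7.0 for depth-first and 6.0 for breadth-first issue on one copy engine, and 4.0 for both orders with two copy engines.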

Conclusion

I hope this gave a good motivation for why streams are a useful concept in CUDA. You may find the accompanying code on my GitHub useful; you can use it to reproduce the benchmark on your own device. This post was adapted from this example.

Feel free to connect with me on LinkedIn to exchange ideas.