simons blog

Indexing in CUDA

Intro

In this blogpost I want to explain what it means for a matrix to be in row-major format. This is essential for understanding CUDA kernels and how they index into the matrices they process.

Let's consider a 2d array A with shape (M, N). In CUDA such an array is by default linearised in row-major format, to account for the flat structure of the computer's memory. In practice this simply means that the matrix coordinate (i, j) gets mapped to the linear index i * N + j. Let's call this mapping f.

Looking at this formula we can see why this is called row major. Let's take the difference of the memory locations of two distinct coordinates:

d = f(i2, j2)-f(i1,j1) = (i2-i1) * N + (j2-j1)

We see that for adjacent columns d = 1 and for adjacent rows d = N, i.e. each row is stored contiguously in memory.

We can generalize this further to a 3d array of shape (M1, M2, M3). Here the coordinate (i, j, l) gets mapped to l + M3 * (j + i * M2) = i * M2 * M3 + j * M3 + l.

Code Analysis

Let's now work through the indexing of a CUDA kernel with this paradigm. As our example we use a kernel with 2d block tiling; for a full explanation of the technique itself I refer to this excellent blogpost, which covers it in detail. Here we will focus on the indexing. The full code can be found on github; to follow the explanation below you should read it and try to understand it before continuing.

In a first step we allocate shared memory. The buffers can be interpreted as 2d matrices of shape (BM, BK) for As and (BK, BN) for Bs.

__shared__ float As[BM * BK];
__shared__ float Bs[BK * BN];

This shared memory is then populated with the corresponding entries of the matrices in global memory.

The code we want to analyze closer is the following part:

// calculate per-thread results
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
  // block into registers
  for (uint i = 0; i < TM; ++i) {
    regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
  }
  for (uint i = 0; i < TN; ++i) {
    regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
  }
  for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
    for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
      threadResults[resIdxM * TN + resIdxN] +=
          regM[resIdxM] * regN[resIdxN];
    }
  }
}

When I first looked at these indices I was not sure how to derive them myself, but with the explanation from above it is actually not very difficult.

Analyze indexing into As

(threadRow * TM + i) * BK + dotIdx
= threadRow * TM * BK + i * BK + dotIdx

From here we can read off that As is interpreted as a 3d array with shape (..., TM, BK). From above we know that the original matrix had shape (BM, BK), and because the matrix is still the same, the number of elements stays constant, so our 3d interpretation has shape (BM/TM, TM, BK).

From here it is clear why we index like this: we want to tile the block of memory further, and we do this precisely by transforming (BM, BK) -> (BM/TM, TM, BK).

threadRow * TM * BK + i * BK + dotIdx -> (threadRow, i, dotIdx).

Analyze indexing into Bs

dotIdx * BN + threadCol * TN + i
= dotIdx * (BN/TN) * TN + threadCol * TN + i

Using the same technique as above, we see that Bs is interpreted as a 3d array with shape (BK, BN/TN, TN).

dotIdx * (BN/TN) * TN + threadCol * TN + i -> (dotIdx, threadCol, i)

Putting it all together

We are now capable of understanding the full loop. The result array is declared as float threadResults[TM * TN] = {0.0};, so the result is initially a (TM, TN) 2d array filled with zeros.

threadResults[resIdxM * TN + resIdxN] +=
              regM[resIdxM] * regN[resIdxN];

The algorithm works as follows: for each dotIdx, the thread loads a slice of length TM from As into regM and a slice of length TN from Bs into regN, then accumulates the outer product of the two slices into threadResults. Summed over all dotIdx, each entry of threadResults becomes a dot product of length BK.

After we have performed this operation, the result is written to the result matrix as follows:

for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
  for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
    C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN] =
        alpha * threadResults[resIdxM * TN + resIdxN] +
        beta * C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN];
  }
}

(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN
= (threadRow * TM + resIdxM) * (N/TN) * TN + threadCol * TN + resIdxN

C is initially thought of as a matrix of shape (M, N) and is now interpreted as a 4d array of shape (M/TM, TM, N/TN, TN); that means we write back to memory in a tiled way as well.

(threadRow * TM + resIdxM) * (N/TN) * TN + threadCol * TN + resIdxN -> (threadRow, resIdxM, threadCol, resIdxN), so each thread writes its own TM x TN tile of C.

I hope this blogpost helped you to understand indexing in CUDA better.