Analyze CUDA programs by looking at GPU assembly.
This is a short note on how to achieve higher performance in memory-bound CUDA programs and how to analyze SASS code.
Simple vector copy
Let's consider the following two kernels that copy a vector.
#define threadsPerBlock 1024

__global__ void vectorCopy(float *input, float *output, int N) {
  const int i = threadIdx.x + blockIdx.x * threadsPerBlock;
  if (i < N) {
    output[i] = input[i];
  }
}

__global__ void vectorCopyVectorized(float4 *input, float4 *output, int N) {
  const int i = threadIdx.x + blockIdx.x * threadsPerBlock;
  if (i < (N >> 2)) {
    output[i] = input[i];
  }
}
It turns out that the vectorized version is faster for N = 1 << 30 elements.
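One way to verify this is to time both kernels with CUDA events. Below is a minimal benchmark sketch (my own harness, not part of the original code); it assumes the two kernels above are in scope, uses N = 1 << 30, and omits error checking for brevity:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const int N = 1 << 30;
  float *in, *out;
  cudaMalloc(&in, N * sizeof(float));
  cudaMalloc(&out, N * sizeof(float));

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // Time a launch: run once as warm-up, then time a second launch.
  auto time = [&](auto launch) {
    launch();
    cudaEventRecord(start);
    launch();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    return ms;
  };

  float msScalar = time([&] {
    vectorCopy<<<N / threadsPerBlock, threadsPerBlock>>>(in, out, N);
  });
  float msVectorized = time([&] {
    vectorCopyVectorized<<<(N / 4) / threadsPerBlock, threadsPerBlock>>>(
        reinterpret_cast<float4 *>(in), reinterpret_cast<float4 *>(out), N);
  });

  printf("scalar: %.3f ms, vectorized: %.3f ms\n", msScalar, msVectorized);

  cudaFree(in);
  cudaFree(out);
  return 0;
}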
Analysis
To understand why that is, it is helpful to look at the SASS code of each kernel. We can obtain the SASS code by using Godbolt (Compiler Explorer) or by profiling the kernels with NVIDIA Nsight Compute (ncu).
On Godbolt we should select the appropriate nvcc version and pass compiler arguments like so: -arch sm_90 -use_fast_math -O3. It is important to select the architecture that matches the GPU we want to analyze.
We can then look at the SASS code. The NVIDIA docs describe SASS as the low-level assembly language that compiles to binary microcode, which executes natively on NVIDIA GPU hardware.
Below is the SASS code for the two kernels on an H100:
vectorCopy(float*, float*, int):
LDC R1, c[0x0][0x28]
S2R R7, SR_TID.X
ULDC UR4, c[0x0][0x220]
S2R R0, SR_CTAID.X
LEA R7, R0, R7, 0xa
ISETP.GE.AND P0, PT, R7, UR4, PT
@P0 EXIT
LDC.64 R2, c[0x0][0x210]
ULDC.64 UR4, c[0x0][0x208]
LDC.64 R4, c[0x0][0x218]
IMAD.WIDE R2, R7, 0x4, R2
LDG.E R3, desc[UR4][R2.64]
IMAD.WIDE R4, R7, 0x4, R4
STG.E desc[UR4][R4.64], R3
EXIT
vectorCopyVectorized(float4*, float4*, int):
LDC R1, c[0x0][0x28]
S2R R7, SR_TID.X
ULDC UR4, c[0x0][0x220]
USHF.R.S32.HI UR4, URZ, 0x2, UR4
S2R R0, SR_CTAID.X
LEA R7, R0, R7, 0xa
ISETP.GE.AND P0, PT, R7, UR4, PT
@P0 EXIT
LDC.64 R4, c[0x0][0x210]
ULDC.64 UR4, c[0x0][0x208]
LDC.64 R2, c[0x0][0x218]
IMAD.WIDE R4, R7, 0x10, R4
LDG.E.128 R8, desc[UR4][R4.64]
IMAD.WIDE R2, R7, 0x10, R2
STG.E.128 desc[UR4][R2.64], R8
EXIT
We see that the vectorized version has one additional instruction, because we perform a bitshift to calculate N / 4 = N >> 2. We could optimize that away by passing N / 4 to the kernel and thereby drop the extra instruction (USHF.R.S32.HI UR4, URZ, 0x2, UR4), but this doesn't make much of a difference. We therefore neglect the bitshift in the following analysis.
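For completeness, here is a minimal sketch of that variant (the kernel name and signature are my own, not from the original code); it assumes N is a multiple of 4 and receives the number of float4 elements directly:

__global__ void vectorCopyVectorizedN4(float4 *input, float4 *output, int n4) {
  // n4 = N / 4 is computed on the host, so the kernel needs no shift.
  const int i = threadIdx.x + blockIdx.x * threadsPerBlock;
  if (i < n4) {
    output[i] = input[i];
  }
}

// Host side: pass N / 4 instead of N (assuming N % 4 == 0).
// vectorCopyVectorizedN4<<<(N / 4 + threadsPerBlock - 1) / threadsPerBlock,
//                          threadsPerBlock>>>(
//     reinterpret_cast<float4 *>(input), reinterpret_cast<float4 *>(output),
//     N / 4);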
The interesting part of the logic (i.e. where we perform the copy) is here:
LDC.64 R2, c[0x0][0x210]
ULDC.64 UR4, c[0x0][0x208]
LDC.64 R4, c[0x0][0x218]
IMAD.WIDE R2, R7, 0x4, R2
LDG.E R3, desc[UR4][R2.64]
IMAD.WIDE R4, R7, 0x4, R4
STG.E desc[UR4][R4.64], R3
vs. the vectorized version:
LDC.64 R4, c[0x0][0x210]
ULDC.64 UR4, c[0x0][0x208]
LDC.64 R2, c[0x0][0x218]
IMAD.WIDE R4, R7, 0x10, R4
LDG.E.128 R8, desc[UR4][R4.64]
IMAD.WIDE R2, R7, 0x10, R2
STG.E.128 desc[UR4][R2.64], R8
We see that by using LDG.E.128/STG.E.128 instead of LDG.E/STG.E we load/store 128 bits per thread instead of 32 bits. That means each thread issues the same number of instructions, but we need significantly fewer threads and therefore fewer blocks.
To understand the difference in block counts, let's compare:
template <int threadsPerBlock>
__global__ void vectorCopy(float *input, float *output, int N) {
  const int i = threadIdx.x + blockIdx.x * threadsPerBlock;
  if (i < N) {
    output[i] = input[i];
  }
}

template <int threadsPerBlock>
void launchVectorCopy(float *input, float *output, int N) {
  const int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
  vectorCopy<threadsPerBlock>
      <<<blocksPerGrid, threadsPerBlock>>>(input, output, N);
}
with
template <int threadsPerBlock>
__global__ void vectorCopyVectorized(float4 *input, float4 *output, int N) {
  const int i = threadIdx.x + blockIdx.x * threadsPerBlock;
  if (i < (N >> 2)) {
    output[i] = input[i];
  }
}

template <int threadsPerBlock>
void launchVectorCopyVectorized(float *input, float *output, int N) {
  const int blocksPerGrid = (N / 4 + threadsPerBlock - 1) / threadsPerBlock;
  vectorCopyVectorized<threadsPerBlock><<<blocksPerGrid, threadsPerBlock>>>(
      reinterpret_cast<float4 *>(input), reinterpret_cast<float4 *>(output), N);
}
If we take N = 1 << 30 and threadsPerBlock = 1 << 10, we launch 1048576 blocks in the first version and 262144 blocks in the second version.
If we neglect the cost of the bitshift instruction (or simply eliminate it as described above), we can understand why the second kernel is much faster: we execute far fewer instructions in total because we launch only a quarter of the blocks of the unvectorized version.
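Since these kernels are memory bound, it is also useful to translate the measured time into achieved bandwidth and compare it against the GPU's peak memory bandwidth. Below is a minimal sketch, assuming a time in milliseconds from a benchmark like the one above (the helper name and the numbers in the comments are my own illustrations):

// Effective bandwidth of a copy: we read N floats and write N floats,
// so 2 * N * sizeof(float) bytes move through device memory.
double copyBandwidthGBps(long long N, float milliseconds) {
  const double bytes = 2.0 * static_cast<double>(N) * sizeof(float);
  return bytes / (milliseconds * 1e-3) / 1e9;
}

// Example (hypothetical timing): for N = 1 << 30 we move 8 GiB, so a 3 ms
// copy corresponds to roughly 2860 GB/s, which can then be compared against
// the peak HBM bandwidth of the target GPU (about 3.35 TB/s on an H100 SXM).
// printf("%.1f GB/s\n", copyBandwidthGBps(1LL << 30, measuredMs));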
I hope this blog post helped you understand vectorized loads and stores a bit better.
The code can be found on my GitHub. The Makefile also includes commands to profile the kernels with NVIDIA Nsight Compute.
If you liked this blog post, you can connect with me on LinkedIn. I like to exchange ideas on CUDA and MLSys in general.