Understanding CuTe Swizzling - The Math Behind 32B, 64B, and 128B Patterns
Swizzling is a critical technique used in state-of-the-art GEMM kernels to avoid costly shared memory bank conflicts. For a quick refresher on what shared memory bank conflicts are and why they matter for performance, I recommend this excellent explanation.
In this post, we'll dive deep into the canonical swizzling patterns: Swizzle<0,4,3>, Swizzle<1,4,3>, Swizzle<2,4,3>, and Swizzle<3,4,3>. Through step-by-step derivations, we'll uncover why these are known as No-Swizzling, 32B-Swizzling, 64B-Swizzling, and 128B-Swizzling respectively.
Note: Throughout this analysis, we count bits starting from 1: the 1st least significant bit (LSB) is the rightmost bit, the 2nd LSB is second from the right, and so on. For example, the 8th LSB is the bit with value 2^7 = 128.
Swizzle API Syntax
To understand the API better, we can take a look at swizzle.hpp in the CUTLASS repo:
// A generic Swizzle functor
/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
 *                               ^--^ MBase is the number of least-sig bits to keep constant
 *                  ^-^       ^-^     BBits is the number of bits in the mask
 *                    ^---------^     SShift is the distance to shift the YYY mask
 *                                       (pos shifts YYY to the right, neg shifts YYY to the left)
 *
 * e.g. Given
 * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
 * the result is
 * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
 */
In code, the Swizzle functor looks like this:
template <int BBits, int MBase, int SShift = BBits>
struct Swizzle
{
  static constexpr int num_bits = BBits;
  static constexpr int num_base = MBase;
  static constexpr int num_shft = SShift;

  static_assert(num_base >= 0, "MBase must be positive.");
  static_assert(num_bits >= 0, "BBits must be positive.");
  static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits.");

  // using 'int' type here to avoid unintentially casting to unsigned... unsure.
  using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
  using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
  using zzz_msk = cute::constant<int, bit_msk{} << (num_base - min(0,num_shft))>;
  using msk_sft = cute::constant<int, num_shft>;

  static constexpr uint32_t swizzle_code = uint32_t(yyy_msk::value | zzz_msk::value);

  template <class Offset>
  CUTE_HOST_DEVICE constexpr static
  auto
  apply(Offset const& offset)
  {
    return offset ^ shiftr(offset & yyy_msk{}, msk_sft{});   // ZZZ ^= YYY
  }

  template <class Offset>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(Offset const& offset) const
  {
    return apply(offset);
  }

  template <int B, int M, int S>
  CUTE_HOST_DEVICE constexpr
  auto
  operator==(Swizzle<B,M,S> const&) const
  {
    return B == BBits && M == MBase && S == SShift;
  }
};
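To experiment with the functor outside of CUTLASS, the following is a minimal standalone sketch in plain C++17. The name SwizzleSketch is hypothetical and not part of CuTe: it replaces cute::constant and shiftr with plain unsigned integers and, as a simplifying assumption, only supports a non-negative SShift, which covers all four canonical swizzles discussed below.

```cpp
#include <cstdint>

// Hypothetical standalone re-implementation of CuTe's Swizzle, for experimenting
// without CUTLASS. Plain unsigned integers replace cute::constant and shiftr,
// and (an assumption of this sketch) only a non-negative SShift is supported.
template <int BBits, int MBase, int SShift = BBits>
struct SwizzleSketch {
  static_assert(BBits >= 0 && MBase >= 0, "BBits and MBase must be non-negative.");
  static_assert(SShift >= BBits, "This sketch only handles SShift >= BBits.");

  static constexpr std::uint32_t bit_msk = (1u << BBits) - 1u;           // BBits ones
  static constexpr std::uint32_t yyy_msk = bit_msk << (MBase + SShift);  // YYY: source bits
  static constexpr std::uint32_t zzz_msk = bit_msk << MBase;             // ZZZ: target bits

  // ZZZ ^= YYY: isolate the YYY bits, shift them down onto ZZZ, XOR them in.
  static constexpr std::uint32_t apply(std::uint32_t offset) {
    return offset ^ ((offset & yyy_msk) >> SShift);
  }
};

// Masks of Swizzle<3,4,3>: YYY = 8th..10th LSB, ZZZ = 5th..7th LSB.
static_assert(SwizzleSketch<3, 4, 3>::yyy_msk == 0b1110000000u, "");
static_assert(SwizzleSketch<3, 4, 3>::zzz_msk == 0b0001110000u, "");
```

For example, SwizzleSketch<1,4,3>::apply(255u) evaluates to 239u, matching the hand calculation for Swizzle<1,4,3> below. Supporting a negative SShift would additionally need the left-shift branch that the comment above attributes to a negative shift.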
Let's calculate the canonical Swizzles by hand to understand them better.
Remember:
// Swizzle<BBits, MBase, SShift>
static constexpr int num_bits = BBits;
static constexpr int num_base = MBase;
static constexpr int num_shft = SShift;
...
using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
using msk_sft = cute::constant<int, num_shft>;
Swizzle<0,4,3>
offset = 0b0000000001111111 = 127
// Swizzle<0, 4, 3>
-> num_bits = 0
-> num_base = 4
-> num_shft = 3
// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 0) - 1 = 0b0000000000000001 - 0b0000000000000001 = 0b0000000000000000
// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000000 << (4 + 3) = 0b0000000000000000 << 7 = 0b0000000000000000
// msk_sft
msk_sft = 3
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_msk
0b0000000001111111 & 0b0000000000000000 = 0b0000000000000000
// (offset & yyy_msk) >> msk_sft
0b0000000000000000 >> 3 = 0b0000000000000000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000001111111 ^ 0b0000000000000000 = 0b0000000001111111 = 127
We see that Swizzle<0,4,3> always acts as the identity. With BBits = 0, bit_msk is a sequence of 0 bits, so yyy_msk is all zeros as well. Consequently offset & yyy_msk is zero for every offset, and shifting it 3 bits to the right again gives a sequence of 0s. XOR with a sequence of 0s is the identity, because XOR(0,0) = 0 and XOR(1,0) = 1.
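The identity claim can also be checked exhaustively at compile time. This sketch assumes a hypothetical helper swizzle_apply (not a CuTe function) that mirrors Swizzle::apply for a non-negative SShift, and asserts that Swizzle<0,4,3> leaves every offset below 2048 unchanged; 2048 = 2^11 covers every bit that any of the canonical swizzles reads.

```cpp
#include <cstdint>

// Hypothetical helper (not part of CuTe): mirrors Swizzle<BBits,MBase,SShift>::apply
// for SShift >= 0.
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle_apply(std::uint32_t offset) {
  constexpr std::uint32_t yyy_msk = ((1u << BBits) - 1u) << (MBase + SShift);
  return offset ^ ((offset & yyy_msk) >> SShift);
}

// True iff Swizzle<0,4,3> maps every offset in [0, limit) to itself.
constexpr bool is_identity_below(std::uint32_t limit) {
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (swizzle_apply<0, 4, 3>(offset) != offset) return false;
  }
  return true;
}

static_assert(swizzle_apply<0, 4, 3>(127u) == 127u, "worked example above");
static_assert(is_identity_below(2048u), "Swizzle<0,4,3> is the identity");
```

Since everything is a static_assert, the file only compiles if the claims hold (for example with g++ -std=c++17 -c).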
Swizzle<1,4,3>
offset = 0b0000000011111111 = 255
// Swizzle<1, 4, 3>
-> num_bits = 1
-> num_base = 4
-> num_shft = 3
// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 1) - 1 = 0b0000000000000010 - 0b0000000000000001 = 0b0000000000000001
// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000001 << (4 + 3) = 0b0000000000000001 << 7 = 0b0000000010000000
// msk_sft
msk_sft = 3
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_msk
0b0000000011111111 & 0b0000000010000000 = 0b0000000010000000
// (offset & yyy_msk) >> msk_sft
0b0000000010000000 >> 3 = 0b0000000000010000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000011111111 ^ 0b0000000000010000 = 0b0000000011101111 = 239
We see that Swizzle<1,4,3> does the following:
- Calculate a 1-bit mask and shift that bit 7 positions to the left. The result is a mask that is 0 everywhere except at the 8th least significant bit. We call this mask yyy_msk.
- AND the offset with yyy_msk. This leaves us with the expression 0b00000000X0000000, where X (0 or 1) is the 8th least significant bit of the offset.
- Shift that expression 3 bits to the right. This gives 0b00000000000X0000, i.e. everything is 0 except the (8-3) = 5th least significant bit, which now holds the 8th least significant bit of the offset.
- XOR this expression with the original offset.
We can therefore conclude that Swizzle<1,4,3> acts on the offset as follows. Let the offset be
0bxxxxxxxxXxxYxxxx
and call
0bxxxxxxxxXxxZxxxx
where Z = ~Y, the "flipped offset".
We then have
Swizzle<1,4,3>(offset) == offset, if X = 0, Y = 0 or X = 0, Y = 1
Swizzle<1,4,3>(offset) == flipped_offset, if X = 1, Y = 0 or X = 1, Y = 1
I.e. we flip the 5th least significant bit if the 8th least significant bit is 1.
Note that this implies that all numbers of the form
0bxxxxxxxx0xxxxxxx
won't get affected at all by the swizzle.
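A compile-time sketch of the Swizzle<1,4,3> rule: it compares a hypothetical swizzle_apply helper (mirroring Swizzle::apply for a non-negative SShift, not a CuTe function) against the rule stated above, written directly as "flip the 5th LSB if the 8th LSB is 1", for every offset below 2048.

```cpp
#include <cstdint>

// Hypothetical helper (not part of CuTe): mirrors Swizzle<BBits,MBase,SShift>::apply
// for SShift >= 0.
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle_apply(std::uint32_t offset) {
  constexpr std::uint32_t yyy_msk = ((1u << BBits) - 1u) << (MBase + SShift);
  return offset ^ ((offset & yyy_msk) >> SShift);
}

// The rule derived above: flip the 5th LSB (1u << 4) iff the 8th LSB (1u << 7) is 1.
constexpr std::uint32_t flip_rule_1(std::uint32_t offset) {
  return (offset & (1u << 7)) ? (offset ^ (1u << 4)) : offset;
}

// True iff Swizzle<1,4,3> and the rule agree on every offset in [0, limit).
constexpr bool rule_1_matches_below(std::uint32_t limit) {
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (swizzle_apply<1, 4, 3>(offset) != flip_rule_1(offset)) return false;
  }
  return true;
}

static_assert(swizzle_apply<1, 4, 3>(255u) == 239u, "worked example above");
static_assert(rule_1_matches_below(2048u), "flip 5th LSB iff 8th LSB is 1");
```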
Swizzle<2,4,3>
offset = 0b0000000111111111 = 511
// Swizzle<2, 4, 3>
-> num_bits = 2
-> num_base = 4
-> num_shft = 3
// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 2) - 1 = 0b0000000000000100 - 0b0000000000000001 = 0b0000000000000011
// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000011 << (4 + 3) = 0b0000000000000011 << 7 = 0b0000000110000000
// msk_sft
msk_sft = 3
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_msk
0b0000000111111111 & 0b0000000110000000 = 0b0000000110000000
// (offset & yyy_msk) >> msk_sft
0b0000000110000000 >> 3 = 0b0000000000110000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000111111111 ^ 0b0000000000110000 = 0b0000000111001111 = 463
We see that Swizzle<2,4,3> does the following:
- Calculate a 2-bit mask and shift those 2 bits 7 positions to the left. The result is a mask that is 0 everywhere except at the 8th and 9th least significant bits. We call this mask yyy_msk.
- AND the offset with yyy_msk. This leaves us with the expression 0b0000000UV0000000, where U and V (each 0 or 1) are the 9th and 8th least significant bits of the offset.
- Shift that expression 3 bits to the right. This gives 0b0000000000UV0000, i.e. everything is 0 except the 6th and 5th least significant bits, which now hold the 9th and 8th least significant bits of the offset.
- XOR this expression with the original offset.
We can therefore conclude that Swizzle<2,4,3> acts on the offset as follows. Let the offset be
0bxxxxxxxUVxXYxxxx
We then have a flip in X if U is equal to 1 and a flip in Y if V is equal to 1.
Note that this implies that numbers of the form:
0bxxxxxxx00xXYxxxx
won't get affected by the Swizzle at all.
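The same check for Swizzle<2,4,3>, as a sketch. The hypothetical helper flip_rule generalizes the rule to b mask bits (flip the (5+i)th LSB if the (8+i)th LSB is 1, for i = 0 to b-1) and is compared against a hypothetical swizzle_apply helper (not a CuTe function) for b = 2.

```cpp
#include <cstdint>

// Hypothetical helper (not part of CuTe): mirrors Swizzle<BBits,MBase,SShift>::apply
// for SShift >= 0.
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle_apply(std::uint32_t offset) {
  constexpr std::uint32_t yyy_msk = ((1u << BBits) - 1u) << (MBase + SShift);
  return offset ^ ((offset & yyy_msk) >> SShift);
}

// The rule derived above, for b mask bits:
// for i = 0 .. b-1, flip the (5+i)th LSB iff the (8+i)th LSB is 1.
constexpr std::uint32_t flip_rule(std::uint32_t offset, int b) {
  std::uint32_t result = offset;
  for (int i = 0; i < b; ++i) {
    if (offset & (1u << (7 + i))) result ^= (1u << (4 + i));
  }
  return result;
}

// True iff Swizzle<2,4,3> and flip_rule(., 2) agree on every offset in [0, limit).
constexpr bool rule_2_matches_below(std::uint32_t limit) {
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (swizzle_apply<2, 4, 3>(offset) != flip_rule(offset, 2)) return false;
  }
  return true;
}

static_assert(swizzle_apply<2, 4, 3>(511u) == 463u, "worked example above");
static_assert(rule_2_matches_below(2048u), "");
```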
Swizzle<3,4,3>
offset = 0b0000001111111111 = 1023
// Swizzle<3, 4, 3>
-> num_bits = 3
-> num_base = 4
-> num_shft = 3
// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 3) - 1 = 0b0000000000001000 - 0b0000000000000001 = 0b0000000000000111
// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000111 << (4 + 3) = 0b0000000000000111 << 7 = 0b0000001110000000
// msk_sft
msk_sft = 3
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_msk
0b0000001111111111 & 0b0000001110000000 = 0b0000001110000000
// (offset & yyy_msk) >> msk_sft
0b0000001110000000 >> 3 = 0b0000000001110000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000001111111111 ^ 0b0000000001110000 = 0b0000001110001111 = 911
We see that Swizzle<3,4,3> does the following:
- Calculate a 3-bit mask and shift those 3 bits 7 positions to the left. The result is a mask that is 0 everywhere except at the 8th, 9th, and 10th least significant bits. We call this mask yyy_msk.
- AND the offset with yyy_msk. This leaves us with the expression 0b000000PQR0000000, where P, Q, and R (each 0 or 1) are the 10th, 9th, and 8th least significant bits of the offset.
- Shift that expression 3 bits to the right. This gives 0b000000000PQR0000, i.e. everything is 0 except the 7th, 6th, and 5th least significant bits, which now hold the 10th, 9th, and 8th least significant bits of the offset.
- XOR this expression with the original offset.
We can therefore conclude that Swizzle<3,4,3> acts on the offset as follows. Let the offset be
0bxxxxxxPQRSTUxxxx
We then have a flip in U if R is equal to 1, a flip in T if Q is equal to 1, and a flip in S if P is equal to 1.
Note that this implies that numbers of the form:
0bxxxxxx000STUxxxx
won't get affected by the Swizzle at all.
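Another way to read the Swizzle<3,4,3> rule is to treat the 5th to 7th LSBs as a 3-bit number c and the 8th to 10th LSBs as a 3-bit number r; the swizzle replaces c by c XOR r. The sketch below (hypothetical helpers xor_chunk_rule and swizzle_apply, not CuTe functions) checks that this reading agrees with the functor for every offset below 2048 and evaluates offset 1023.

```cpp
#include <cstdint>

// Hypothetical helper (not part of CuTe): mirrors Swizzle<BBits,MBase,SShift>::apply
// for SShift >= 0.
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle_apply(std::uint32_t offset) {
  constexpr std::uint32_t yyy_msk = ((1u << BBits) - 1u) << (MBase + SShift);
  return offset ^ ((offset & yyy_msk) >> SShift);
}

// Equivalent reading of Swizzle<3,4,3>: replace c (5th..7th LSB) by c XOR r,
// where r is the 8th..10th LSB; all other bits pass through unchanged.
constexpr std::uint32_t xor_chunk_rule(std::uint32_t offset) {
  const std::uint32_t c = (offset >> 4) & 0x7u;  // 5th..7th LSB
  const std::uint32_t r = (offset >> 7) & 0x7u;  // 8th..10th LSB
  return (offset & ~(0x7u << 4)) | ((c ^ r) << 4);
}

// True iff both formulations agree on every offset in [0, limit).
constexpr bool rule_3_matches_below(std::uint32_t limit) {
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (swizzle_apply<3, 4, 3>(offset) != xor_chunk_rule(offset)) return false;
  }
  return true;
}

static_assert(swizzle_apply<3, 4, 3>(1023u) == 911u, "offset 1023 maps to 911");
static_assert(rule_3_matches_below(2048u), "");
```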
Further analysis
Let us recap:
- Swizzle<0,4,3> is the identity.
- Swizzle<1,4,3> flips the 5th LSB of a number if the 8th LSB is equal to 1.
- Swizzle<2,4,3> flips the 5th/6th LSB of a number if the 8th/9th LSB, respectively, is equal to 1.
- Swizzle<3,4,3> flips the 5th/6th/7th LSB of a number if the 8th/9th/10th LSB, respectively, is equal to 1.
Note that:
- Swizzle<1,4,3>(offset) = Swizzle<0,4,3>(offset) iff the 8th LSB is 0.
- Swizzle<2,4,3>(offset) = Swizzle<1,4,3>(offset) iff the 9th LSB is 0.
- Swizzle<3,4,3>(offset) = Swizzle<2,4,3>(offset) iff the 10th LSB is 0.
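The three equivalences above can be verified in one compile-time sketch; swizzle_apply is again a hypothetical helper mirroring Swizzle::apply for a non-negative SShift, not a CuTe function.

```cpp
#include <cstdint>

// Hypothetical helper (not part of CuTe): mirrors Swizzle<BBits,MBase,SShift>::apply
// for SShift >= 0.
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle_apply(std::uint32_t offset) {
  constexpr std::uint32_t yyy_msk = ((1u << BBits) - 1u) << (MBase + SShift);
  return offset ^ ((offset & yyy_msk) >> SShift);
}

// True iff, for every offset in [0, limit), consecutive canonical swizzles agree
// exactly when the corresponding higher bit (8th, 9th, 10th LSB) is 0.
constexpr bool successive_swizzles_agree(std::uint32_t limit) {
  for (std::uint32_t x = 0; x < limit; ++x) {
    const bool lsb8_zero  = (x & (1u << 7)) == 0;
    const bool lsb9_zero  = (x & (1u << 8)) == 0;
    const bool lsb10_zero = (x & (1u << 9)) == 0;
    if ((swizzle_apply<1, 4, 3>(x) == swizzle_apply<0, 4, 3>(x)) != lsb8_zero) return false;
    if ((swizzle_apply<2, 4, 3>(x) == swizzle_apply<1, 4, 3>(x)) != lsb9_zero) return false;
    if ((swizzle_apply<3, 4, 3>(x) == swizzle_apply<2, 4, 3>(x)) != lsb10_zero) return false;
  }
  return true;
}

static_assert(successive_swizzles_agree(2048u), "");
```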
Consider Swizzle<1,4,3>:
Up to 0b0000000001111111 = 127, Swizzle<1,4,3> acts as the identity. Then, from 0b0000000010000000 = 128 on, swizzling takes place. That simply means the following:
- 0b0000000010000000 -> 0b0000000010010000
- 0b0000000010000001 -> 0b0000000010010001
- ...
- 0b0000000010001111 -> 0b0000000010011111
That is, the 16 values ranging from 128 to 143 get mapped to their counterparts with the 5th LSB flipped, i.e. to the value that is 16 larger. Conversely, the 16 values from 144 to 159 get mapped to the value that is 16 smaller, e.g. 0b0000000010010000 -> 0b0000000010000000. In other words, the two halves of the 32 values from 128 to 159 swap places, and the same happens for every aligned block of 32 values whose 8th LSB is 1. This is how swizzling avoids bank conflicts.
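The Swizzle<1,4,3> mapping on the first 160 offsets (identity below 128, +16 for 128 to 143, -16 for 144 to 159) can be checked with this sketch, using the same kind of hypothetical swizzle_apply helper (not a CuTe function).

```cpp
#include <cstdint>

// Hypothetical helper (not part of CuTe): mirrors Swizzle<BBits,MBase,SShift>::apply
// for SShift >= 0.
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle_apply(std::uint32_t offset) {
  constexpr std::uint32_t yyy_msk = ((1u << BBits) - 1u) << (MBase + SShift);
  return offset ^ ((offset & yyy_msk) >> SShift);
}

// Checks the Swizzle<1,4,3> mapping on the offsets 0..159.
constexpr bool mapping_0_to_159_ok() {
  for (std::uint32_t x = 0; x < 128; ++x) {
    if (swizzle_apply<1, 4, 3>(x) != x) return false;        // identity below 128
  }
  for (std::uint32_t x = 128; x < 144; ++x) {
    if (swizzle_apply<1, 4, 3>(x) != x + 16) return false;   // 128..143 move up by 16
  }
  for (std::uint32_t x = 144; x < 160; ++x) {
    if (swizzle_apply<1, 4, 3>(x) != x - 16) return false;   // 144..159 move down by 16
  }
  return true;
}

static_assert(mapping_0_to_159_ok(), "");
```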
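Note that each swizzling pattern repeats after a fixed number of offsets, because Swizzle<B,4,3> only reads bits up to the (7+B)th LSB and never changes any bit above it:
- For Swizzle<1,4,3>, the last offset before the pattern repeats is 0b0000000011111111 = 255, i.e. the pattern repeats every 256 offsets.
- For Swizzle<2,4,3>, it is 0b0000000111111111 = 511, i.e. the pattern repeats every 512 offsets.
- For Swizzle<3,4,3>, it is 0b0000001111111111 = 1023, i.e. the pattern repeats every 1024 offsets.
The names make sense once the offsets are read as byte addresses. MBase = 4 keeps the 4 lowest bits fixed, so 16-byte chunks are never split, and only the 5th to (4+B)th LSBs change, so each 16-byte chunk only moves within its aligned span of 2^B chunks: 32B for Swizzle<1,4,3>, 64B for Swizzle<2,4,3>, and 128B for Swizzle<3,4,3>. The corresponding patterns repeat every 256, 512, and 1024 bytes, i.e. every 8 times the swizzle width.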
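The repetition claim can be checked mechanically as well. The sketch below, again built on a hypothetical swizzle_apply helper (not a CuTe function), verifies for offsets below 2048 that Swizzle<B,4,3> never changes the 4 lowest bits, never moves a value outside its aligned block of 16 * 2^B offsets (32, 64, or 128), and satisfies f(x + P) = f(x) + P for P = 2^(7+B) (256, 512, or 1024), while half of that period does not work.

```cpp
#include <cstdint>

// Hypothetical helper (not part of CuTe): mirrors Swizzle<BBits,MBase,SShift>::apply
// for SShift >= 0.
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle_apply(std::uint32_t offset) {
  constexpr std::uint32_t yyy_msk = ((1u << BBits) - 1u) << (MBase + SShift);
  return offset ^ ((offset & yyy_msk) >> SShift);
}

// For offsets in [0, limit), checks that Swizzle<B,4,3>
//  (1) keeps the 4 lowest bits, i.e. never splits a chunk of 16 offsets,
//  (2) keeps every value inside its aligned block of 16 * 2^B offsets,
//  (3) repeats with period P = 2^(7+B): f(x + P) == f(x) + P.
template <int B>
constexpr bool check_span_and_period(std::uint32_t limit) {
  constexpr std::uint32_t span = 16u << B;         // 32, 64, 128 for B = 1, 2, 3
  constexpr std::uint32_t period = 1u << (7 + B);  // 256, 512, 1024
  for (std::uint32_t x = 0; x < limit; ++x) {
    const std::uint32_t y = swizzle_apply<B, 4, 3>(x);
    if ((y & 0xFu) != (x & 0xFu)) return false;
    if (y / span != x / span) return false;
    if (swizzle_apply<B, 4, 3>(x + period) != y + period) return false;
  }
  return true;
}

static_assert(check_span_and_period<1>(2048u), "32B span, period 256");
static_assert(check_span_and_period<2>(2048u), "64B span, period 512");
static_assert(check_span_and_period<3>(2048u), "128B span, period 1024");

// Half the period is not a period: f(0) == 0, but f(P/2) != P/2.
static_assert(swizzle_apply<1, 4, 3>(128u) != 128u, "");
static_assert(swizzle_apply<2, 4, 3>(256u) != 256u, "");
static_assert(swizzle_apply<3, 4, 3>(512u) != 512u, "");
```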
Conclusion
I hope this blog post sheds some light on the seemingly intimidating Swizzle notation used in CuTe. As we saw, it is actually pretty straightforward to calculate the Swizzles and interpret them by hand. In the future, I intend to write a post connecting the above calculations to the CuTeDSL. Please feel free to reach out on LinkedIn if you want to exchange ideas on GPU programming.