Understanding CuTe Swizzling - The Math Behind 32B, 64B, and 128B Patterns
Swizzling is a critical technique used in state-of-the-art GEMM kernels to avoid costly shared memory bank conflicts. For a quick refresher on what shared memory bank conflicts are and why they matter for performance, I recommend this excellent explanation.
In this post, we'll dive deep into the canonical swizzling patterns: Swizzle<0,4,3>, Swizzle<1,4,3>, Swizzle<2,4,3>, and Swizzle<3,4,3>. Through step-by-step derivations, we'll uncover why these are known as No-Swizzling, 32B-Swizzling, 64B-Swizzling, and 128B-Swizzling respectively.
Swizzle API Syntax
To understand the API better, we can take a look at swizzle.hpp in the CUTLASS repo:
// A generic Swizzle functor
/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
 *                               ^--^ MBase is the number of least-sig bits to keep constant
 *                  ^-^       ^-^     BBits is the number of bits in the mask
 *                    ^---------^     SShift is the distance to shift the YYY mask
 *                                    (pos shifts YYY to the right, neg shifts YYY to the left)
 *
 * e.g. Given
 * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
 * the result is
 * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
 */
In code, the Swizzle functor that applies this to an offset looks like this:
template <int BBits, int MBase, int SShift = BBits>
struct Swizzle
{
  static constexpr int num_bits = BBits;
  static constexpr int num_base = MBase;
  static constexpr int num_shft = SShift;

  static_assert(num_base >= 0, "MBase must be positive.");
  static_assert(num_bits >= 0, "BBits must be positive.");
  static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits.");

  // using 'int' type here to avoid unintentially casting to unsigned... unsure.
  using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
  using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
  using zzz_msk = cute::constant<int, bit_msk{} << (num_base - min(0,num_shft))>;
  using msk_sft = cute::constant<int, num_shft>;

  static constexpr uint32_t swizzle_code = uint32_t(yyy_msk::value | zzz_msk::value);

  template <class Offset>
  CUTE_HOST_DEVICE constexpr static
  auto
  apply(Offset const& offset)
  {
    return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY
  }

  template <class Offset>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(Offset const& offset) const
  {
    return apply(offset);
  }

  template <int B, int M, int S>
  CUTE_HOST_DEVICE constexpr
  auto
  operator==(Swizzle<B,M,S> const&) const
  {
    return B == BBits && M == MBase && S == SShift;
  }
};
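To experiment with these formulas outside of CUTLASS, the functor can be mimicked with plain integers. The following is a minimal sketch, not the CuTe implementation: the free function swizzle is a hypothetical helper, it only covers SShift >= 0 (the only case used in this post), and it works on std::uint32_t instead of CuTe's integral constants.

```cpp
#include <cstdint>

// Hypothetical stand-alone helper (not part of CuTe): mirrors the arithmetic of
// Swizzle<BBits, MBase, SShift>::apply for non-negative SShift on plain integers.
template <int BBits, int MBase, int SShift = BBits>
constexpr std::uint32_t swizzle(std::uint32_t offset)
{
  static_assert(BBits >= 0 && MBase >= 0 && SShift >= BBits,
                "this sketch only handles SShift >= BBits >= 0");
  constexpr std::uint32_t bit_msk = (1u << BBits) - 1u;           // BBits ones
  constexpr std::uint32_t yyy_msk = bit_msk << (MBase + SShift);  // where the YYY bits live
  return offset ^ ((offset & yyy_msk) >> SShift);                 // ZZZ ^= YYY
}

// With no YYY bit set, nothing changes.
static_assert(swizzle<1, 4, 3>(0x0Fu) == 0x0Fu, "offset below the YYY field is untouched");
// With the single YYY bit (bit 7) set, bit 4 is flipped.
static_assert(swizzle<1, 4, 3>(0x80u) == 0x90u, "bit 7 set flips bit 4");
```

For an integer offset and a non-negative shift, swizzle<1, 4, 3>(offset) should give the same number as Swizzle<1,4,3>{}(offset).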
Let's calculate the canonical Swizzles by hand to understand them better:
Remember:
// Swizzle<BBits, MBase, SShift>
static constexpr int num_bits = BBits;
static constexpr int num_base = MBase;
static constexpr int num_shft = SShift;
...
using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
using msk_sft = cute::constant<int, num_shft>;
Swizzle<0,4,3>
offset = 0b0000000001111111 = 127
// Swizzle<0, 4, 3>
-> num_bits = 0
-> num_base = 4
-> num_shft = 3
// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 0) - 1 = 0b0000000000000001 - 0b0000000000000001 = 0b0000000000000000
// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000000 << (4 + 3) = 0b0000000000000000 << 7 = 0b0000000000000000
// msk_shift
msk_shift = 3
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_mask
0b0000000001111111 & 0b0000000000000000 = 0b0000000000000000
// (offset & yyy_mask) >> msk_shift
0b0000000000000000 >> 3 = 0b0000000000000000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000001111111 ^ 0b0000000000000000 = 0b0000000001111111 = 127
We see that Swizzle<0,4,3> always acts as the identity. That is because with num_bits = 0 the bit_msk is always a sequence of 0 bits. This always results in a yyy_msk of zero bits, so offset & yyy_msk is zero, and shifting it 3 bits to the right again yields a sequence of 0s. XOR with a sequence of 0s is the identity because XOR(0,0) = 0 and XOR(1,0) = 1.
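The identity claim can also be checked mechanically. This is a small sketch under the assumption that a plain-integer stand-in (the hypothetical function swizzle below, valid for non-negative shifts only) reproduces the arithmetic of Swizzle::apply; is_identity_below is likewise a made-up helper name.

```cpp
#include <cstdint>

// Hypothetical plain-integer stand-in for Swizzle<BBits, MBase, SShift>::apply
// (non-negative SShift only).
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle(std::uint32_t offset)
{
  return offset ^ ((offset & (((1u << BBits) - 1u) << (MBase + SShift))) >> SShift);
}

// Returns true if Swizzle<0,4,3> leaves every offset in [0, limit) unchanged.
constexpr bool is_identity_below(std::uint32_t limit)
{
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (swizzle<0, 4, 3>(offset) != offset) { return false; }
  }
  return true;
}

// Checked at compile time for the first 4096 offsets.
static_assert(is_identity_below(4096u), "Swizzle<0,4,3> should not move any offset");
```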
Swizzle<1,4,3>
offset = 0b0000000011111111 = 255
// Swizzle<1, 4, 3>
-> num_bits = 1
-> num_base = 4
-> num_shft = 3
// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 1) - 1 = 0b0000000000000010 - 0b0000000000000001 = 0b0000000000000001
// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000001 << (4 + 3) = 0b0000000000000001 << 7 = 0b0000000010000000
// msk_shift
msk_shift = 3
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_mask
0b0000000011111111 & 0b0000000010000000 = 0b0000000010000000
// (offset & yyy_mask) >> msk_shift
0b0000000010000000 >> 3 = 0b0000000000010000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000011111111 ^ 0b0000000000010000 = 0b0000000011101111 = 239
We see that Swizzle<1,4,3> does the following:
- Calculate a 1 bit mask and shift that 1 bit 7 bits to the left. This results in a mask that is 0 everywhere except at the 8th least significant bit. We call this mask yyy_msk.
- AND offset with yyy_msk. This leaves us with an expression of the form 0b00000000X0000000, where X={0,1} is the 8th least significant bit of the offset.
- Shift that expression by 3 to the right. This results in an expression 0b00000000000X0000, i.e. everything is 0 except the 8-3=5th least significant bit, which now holds the 8th least significant bit of the offset.
- XOR this expression with the original offset.
We can therefore conclude that Swizzle<1,4,3> acts on the offset as follows. Let the offset be
0bxxxxxxxxXxxYxxxx
Let us call the following the "flipped offset":
0bxxxxxxxxXxxZxxxx
where Z = ~Y.
We will then have
Swizzle<1,4,3>(offset) == offset, if X = 0 (for Y = 0 or Y = 1)
Swizzle<1,4,3>(offset) == flipped_offset, if X = 1 (for Y = 0 or Y = 1)
I.e. we flip the 5th least significant bit if the 8th least significant bit is 1.
Note that this implies that all the numbers of the form
0bxxxxxxxx0xxYxxxx
won't get affected at all by the swizzle.
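The "flip the 5th LSB if the 8th LSB is 1" rule can be compared against the mask-and-shift formula directly. The sketch below uses two hypothetical helpers, swizzle_1_4_3 (the formula with the constants for Swizzle<1,4,3> plugged in) and flip_rule (the rule in words), and checks that they agree on a range of offsets.

```cpp
#include <cstdint>

// Plain-integer form of Swizzle<1,4,3>: yyy_msk = 0b10000000, shifted right by 3.
constexpr std::uint32_t swizzle_1_4_3(std::uint32_t offset)
{
  return offset ^ ((offset & 0x80u) >> 3);
}

// The rule in words: flip the 5th LSB (bit 4) if the 8th LSB (bit 7) is 1.
constexpr std::uint32_t flip_rule(std::uint32_t offset)
{
  return ((offset >> 7) & 1u) != 0u ? (offset ^ 0x10u) : offset;
}

// Returns true if both formulations agree for every offset in [0, limit).
constexpr bool rule_matches_below(std::uint32_t limit)
{
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (swizzle_1_4_3(offset) != flip_rule(offset)) { return false; }
  }
  return true;
}

static_assert(swizzle_1_4_3(255u) == 239u, "255 -> 239");
static_assert(rule_matches_below(4096u), "formula and rule agree on [0, 4096)");
```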
Swizzle<2,4,3>
offset = 0b0000000111111111 = 511
// Swizzle<2, 4, 3>
-> num_bits = 2
-> num_base = 4
-> num_shft = 3
// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 2) - 1 = 0b0000000000000100 - 0b0000000000000001 = 0b0000000000000011
// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000011 << (4 + 3) = 0b0000000000000011 << 7 = 0b0000000110000000
// msk_shift
msk_shift = 3
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_mask
0b0000000111111111 & 0b0000000110000000 = 0b0000000110000000
// (offset & yyy_mask) >> msk_shift
0b0000000110000000 >> 3 = 0b0000000000110000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000111111111 ^ 0b0000000000110000 = 0b0000000111001111 = 463
We see that Swizzle<2,4,3> does the following:
- Calculate a 2 bit mask and shift those 2 bits 7 bits to the left. This results in a mask that is 0 everywhere except at the 8th and the 9th least significant bit. We call this mask yyy_msk.
- AND offset with yyy_msk. This leaves us with an expression of the form 0b0000000UV0000000, where U,V={0,1} are the 9th/8th least significant bit of the offset.
- Shift that expression by 3 to the right. This results in an expression 0b0000000000UV0000, i.e. everything is 0 except the 6th/5th least significant bit, which now hold the 9th/8th least significant bit of the offset.
- XOR this expression with the original offset.
We can therefore conclude that Swizzle<2,4,3> acts on the offset as follows. Let the offset be
0bxxxxxxxUVxXYxxxx
We will then have a flip in X if U is equal to 1 and a flip in Y if V is equal to 1.
Note that this implies that numbers of the form:
0bxxxxxxx00xXYxxxx
won't get affected by the Swizzle at all.
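Another way to read this is as an operation on two 2-bit fields: the field at bits 4 and 5 (XY) is XORed with the field at bits 7 and 8 (UV). The sketch below spells that out; bits and swizzle_2_4_3 are hypothetical helper names, and bit positions are counted from 0 at the LSB, so the 5th LSB is bit 4.

```cpp
#include <cstdint>

// Hypothetical helper: extract `count` bits of `value` starting at bit `lo` (bit 0 = LSB).
constexpr std::uint32_t bits(std::uint32_t value, int lo, int count)
{
  return (value >> lo) & ((1u << count) - 1u);
}

// Swizzle<2,4,3> as a field operation: XOR the 2-bit field at bits [4,5] (XY)
// with the 2-bit field at bits [7,8] (UV); all other bits are left alone.
constexpr std::uint32_t swizzle_2_4_3(std::uint32_t offset)
{
  return offset ^ (bits(offset, 7, 2) << 4);
}

static_assert(bits(511u, 7, 2) == 3u, "U and V are both 1 for offset 511");
static_assert(swizzle_2_4_3(511u) == 463u, "511 -> 463");
static_assert(swizzle_2_4_3(0x07Fu) == 0x07Fu, "U = V = 0: unchanged");
static_assert(swizzle_2_4_3(0x080u) == 0x090u, "V = 1 flips Y (bit 4)");
static_assert(swizzle_2_4_3(0x100u) == 0x120u, "U = 1 flips X (bit 5)");
```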
Swizzle<3,4,3>
offset = 0b0000001111111111 = 1023
// Swizzle<3, 4, 3>
-> num_bits = 3
-> num_base = 4
-> num_shft = 3
// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 3) - 1 = 0b0000000000001000 - 0b0000000000000001 = 0b0000000000000111
// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000111 << (4 + 3) = 0b0000000000000111 << 7 = 0b0000001110000000
// msk_shift
msk_shift = 3
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_mask
0b0000001111111111 & 0b0000001110000000 = 0b0000001110000000
// (offset & yyy_mask) >> msk_shift
0b0000001110000000 >> 3 = 0b0000000001110000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000001111111111 ^ 0b0000000001110000 = 0b0000001110001111 = 911
We see that Swizzle<3,4,3> does the following:
- Calculate a 3 bit mask and shift those 3 bits 7 bits to the left. This results in a mask that is 0 everywhere except at the 8th, the 9th and the 10th least significant bit. We call this mask yyy_msk.
- AND offset with yyy_msk. This leaves us with an expression of the form 0b000000PQR0000000, where P,Q,R={0,1} are the 10th/9th/8th least significant bit of the offset.
- Shift that expression by 3 to the right. This results in an expression 0b000000000PQR0000, i.e. everything is 0 except the 7th/6th/5th least significant bit, which now hold the 10th/9th/8th least significant bit of the offset.
- XOR this expression with the original offset.
We can therefore conclude that Swizzle<3,4,3> acts on the offset as follows. Let the offset be
0bxxxxxxPQRSTUxxxx
We will then have a flip in U if R is equal to 1, a flip in T if Q is equal to 1 and a flip in S if P is equal to 1.
Note that this implies that numbers of the form:
0bxxxxxx000STUxxxx
won't get affected by the Swizzle at all.
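The bit picture 0bxxxxxxPQRSTUxxxx can also be written as arithmetic on named fields. In the sketch below the names low, chunk, row and high are my own labels for xxxx, STU, PQR and the remaining upper bits (they are not CuTe terminology), and swizzle_by_fields rebuilds the offset with chunk replaced by chunk XOR row. The check confirms that this agrees with the mask-and-shift formula.

```cpp
#include <cstdint>

// Plain-integer form of Swizzle<3,4,3>: yyy_msk = 0b1110000000 = 0x380.
constexpr std::uint32_t swizzle_3_4_3(std::uint32_t offset)
{
  return offset ^ ((offset & 0x380u) >> 3);
}

// Hypothetical field view of the same operation:
//   low = bits 0..3 (xxxx), chunk = bits 4..6 (STU), row = bits 7..9 (PQR),
//   high = everything above bit 9. Only chunk changes: chunk ^= row.
constexpr std::uint32_t swizzle_by_fields(std::uint32_t offset)
{
  const std::uint32_t low   = offset & 0xFu;
  const std::uint32_t chunk = (offset >> 4) & 0x7u;
  const std::uint32_t row   = (offset >> 7) & 0x7u;
  const std::uint32_t high  = offset >> 10;
  return (high << 10) | (row << 7) | ((chunk ^ row) << 4) | low;
}

// Returns true if both formulations agree for every offset in [0, limit).
constexpr bool forms_agree_below(std::uint32_t limit)
{
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (swizzle_3_4_3(offset) != swizzle_by_fields(offset)) { return false; }
  }
  return true;
}

static_assert(swizzle_3_4_3(1023u) == 911u, "1023 -> 911");
static_assert(forms_agree_below(4096u), "both forms agree on [0, 4096)");
```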
Further analysis
Let us recap:
- Swizzle<0,4,3> is the identity.
- Swizzle<1,4,3> flips the 5th LSB of a number if the 8th LSB is equal to 1.
- Swizzle<2,4,3> flips the 5th/6th LSB of a number if the 8th/9th LSB is equal to 1.
- Swizzle<3,4,3> flips the 5th/6th/7th LSB of a number if the 8th/9th/10th LSB is equal to 1.
Note that:
- Swizzle<1,4,3>(offset) = Swizzle<0,4,3>(offset) iff the 8th LSB is 0.
- Swizzle<2,4,3>(offset) = Swizzle<1,4,3>(offset) iff the 9th LSB is 0.
- Swizzle<3,4,3>(offset) = Swizzle<2,4,3>(offset) iff the 10th LSB is 0.
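These three "iff" statements can be verified exhaustively over a range of offsets. The sketch assumes a plain-integer stand-in for the functor (the hypothetical swizzle below, non-negative shifts only); bit_set and nesting_holds_below are made-up helper names, and bits are counted from 0, so the 8th LSB is bit 7.

```cpp
#include <cstdint>

// Hypothetical plain-integer stand-in for Swizzle<BBits, MBase, SShift>::apply
// (non-negative SShift only).
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle(std::uint32_t offset)
{
  return offset ^ ((offset & (((1u << BBits) - 1u) << (MBase + SShift))) >> SShift);
}

// True if bit `bit` (0 = LSB) of `value` is 1.
constexpr bool bit_set(std::uint32_t value, int bit)
{
  return ((value >> bit) & 1u) != 0u;
}

// Checks for every offset in [0, limit):
// Swizzle<B,4,3> and Swizzle<B-1,4,3> agree exactly when bit (6 + B) is 0.
constexpr bool nesting_holds_below(std::uint32_t limit)
{
  for (std::uint32_t o = 0; o < limit; ++o) {
    const bool same_1_0 = swizzle<1, 4, 3>(o) == swizzle<0, 4, 3>(o);
    const bool same_2_1 = swizzle<2, 4, 3>(o) == swizzle<1, 4, 3>(o);
    const bool same_3_2 = swizzle<3, 4, 3>(o) == swizzle<2, 4, 3>(o);
    if (same_1_0 == bit_set(o, 7)) { return false; }  // 8th LSB
    if (same_2_1 == bit_set(o, 8)) { return false; }  // 9th LSB
    if (same_3_2 == bit_set(o, 9)) { return false; }  // 10th LSB
  }
  return true;
}

static_assert(nesting_holds_below(4096u), "the three iff statements hold on [0, 4096)");
```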
Consider Swizzle<1,4,3>:
Up to 0b0000000001111111 = 127, Swizzle<1,4,3> acts as the identity. From 0b0000000010000000 = 128 onwards, swizzling takes place. That simply means the following:
- 0b0000000010000000 -> 0b0000000010010000
- 0b0000000010000001 -> 0b0000000010010001
- ...
- 0b0000000010001111 -> 0b0000000010011111
That is, the 16 values from 128 to 143 get mapped to their counterparts which have the 5th LSB flipped, i.e. they take a value which is 16 larger. In turn, the 16 values from 144 to 159 get mapped to values which are 16 smaller, e.g. 0b0000000010010000 -> 0b0000000010000000. The two 16-value halves of the 32 values from 128 to 159 simply swap places. This kind of permutation is how swizzling avoids bank conflicts.
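A quick check of the mapping for offsets 128 to 159, as a sketch with hypothetical helper names: swizzle_1_4_3 is the formula with the Swizzle<1,4,3> constants plugged in, swaps_halves checks that 128..143 and 144..159 trade places, and is_involution_below checks that applying the swizzle twice gives back the original offset, which also means no two offsets are mapped to the same place.

```cpp
#include <cstdint>

// Plain-integer form of Swizzle<1,4,3>.
constexpr std::uint32_t swizzle_1_4_3(std::uint32_t offset)
{
  return offset ^ ((offset & 0x80u) >> 3);
}

// Offsets 128..143 move up by 16 and offsets 144..159 move down by 16.
constexpr bool swaps_halves()
{
  for (std::uint32_t o = 128; o < 144; ++o) {
    if (swizzle_1_4_3(o) != o + 16u) { return false; }
    if (swizzle_1_4_3(o + 16u) != o) { return false; }
  }
  return true;
}

// Applying the swizzle twice returns the original offset on [0, limit).
constexpr bool is_involution_below(std::uint32_t limit)
{
  for (std::uint32_t o = 0; o < limit; ++o) {
    if (swizzle_1_4_3(swizzle_1_4_3(o)) != o) { return false; }
  }
  return true;
}

static_assert(swaps_halves(), "128..143 and 144..159 swap places");
static_assert(is_involution_below(4096u), "swizzling twice is the identity");
```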
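Note that the swizzling pattern repeats after a certain number of bytes, because the bits above the YYY field are never touched.
For Swizzle<1,4,3> the highest bit involved is the 8th LSB, so the pattern repeats after 0b0000000011111111, i.e. every 256 offsets. Read as bits, that is 256 bit / 32B.
For Swizzle<2,4,3> the highest bit involved is the 9th LSB, so the pattern repeats after 0b0000000111111111, i.e. every 512 offsets. Read as bits, that is 512 bit / 64B.
For Swizzle<3,4,3> the highest bit involved is the 10th LSB, so the pattern repeats after 0b0000001111111111, i.e. every 1024 offsets. Read as bits, that is 1024 bit / 128B.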
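The repetition can be stated precisely as swizzle(offset + period) == swizzle(offset) + period. The sketch below checks this for periods of 256, 512 and 1024 offsets, again using a hypothetical plain-integer stand-in for the functor (non-negative shifts only) and a made-up helper repeats_with_period.

```cpp
#include <cstdint>

// Hypothetical plain-integer stand-in for Swizzle<BBits, MBase, SShift>::apply
// (non-negative SShift only).
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle(std::uint32_t offset)
{
  return offset ^ ((offset & (((1u << BBits) - 1u) << (MBase + SShift))) >> SShift);
}

// True if shifting the input by `period` shifts the output of Swizzle<BBits,4,3>
// by exactly `period`, for every offset in [0, limit).
template <int BBits>
constexpr bool repeats_with_period(std::uint32_t period, std::uint32_t limit)
{
  for (std::uint32_t o = 0; o < limit; ++o) {
    if (swizzle<BBits, 4, 3>(o + period) != swizzle<BBits, 4, 3>(o) + period) { return false; }
  }
  return true;
}

static_assert(repeats_with_period<1>(256u, 4096u), "Swizzle<1,4,3> repeats every 256 offsets");
static_assert(repeats_with_period<2>(512u, 4096u), "Swizzle<2,4,3> repeats every 512 offsets");
static_assert(repeats_with_period<3>(1024u, 4096u), "Swizzle<3,4,3> repeats every 1024 offsets");
// A shorter period does not work: 512 is too short for Swizzle<3,4,3>.
static_assert(!repeats_with_period<3>(512u, 4096u), "512 is not a period of Swizzle<3,4,3>");
```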
Conclusion
I hope this blog post shed some light on the seemingly intimidating Swizzle notation used in CuTe. As we saw, it is actually pretty straightforward to calculate the Swizzles and interpret them by hand. In the future I intend to do a post where I connect the above calculations to the CuTeDSL. Please feel free to reach out on LinkedIn if you want to exchange ideas on GPU programming.