simons blog

Understanding CuTe Swizzling - The Math Behind 32B, 64B, and 128B Patterns

Swizzling is a critical technique used in state-of-the-art GEMM kernels to avoid costly shared memory bank conflicts. For a quick refresher on what shared memory bank conflicts are and why they matter for performance, I recommend this excellent explanation.

In this post, we'll dive deep into the canonical swizzling patterns: Swizzle<0,4,3>, Swizzle<1,4,3>, Swizzle<2,4,3>, and Swizzle<3,4,3>. Through step-by-step derivations, we'll uncover why these are known as No-Swizzling, 32B-Swizzling, 64B-Swizzling, and 128B-Swizzling respectively.

Note: Throughout this analysis, we'll use the convention that the 0th least significant bit is the rightmost bit, the 1st LSB is second from right, etc.

Swizzle API Syntax

To understand the API better, we can take a look at swizzle.hpp in the CUTLASS repo:

// A generic Swizzle functor
/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
 *                               ^--^ MBase is the number of least-sig bits to keep constant
 *                  ^-^       ^-^     BBits is the number of bits in the mask
 *                    ^---------^     SShift is the distance to shift the YYY mask
 *                                       (pos shifts YYY to the right, neg shifts YYY to the left)
 *
 * e.g. Given
 * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
 * the result is
 * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
 */

In code, the Swizzle functor that applies this transformation to an offset looks like this:

template <int BBits, int MBase, int SShift = BBits>
struct Swizzle
{
  static constexpr int num_bits = BBits;
  static constexpr int num_base = MBase;
  static constexpr int num_shft = SShift;

  static_assert(num_base >= 0,             "MBase must be positive.");
  static_assert(num_bits >= 0,             "BBits must be positive.");
  static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits.");

  // using 'int' type here to avoid unintentially casting to unsigned... unsure.
  using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
  using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
  using zzz_msk = cute::constant<int, bit_msk{} << (num_base - min(0,num_shft))>;
  using msk_sft = cute::constant<int, num_shft>;

  static constexpr uint32_t swizzle_code = uint32_t(yyy_msk::value | zzz_msk::value);

  template <class Offset>
  CUTE_HOST_DEVICE constexpr static
  auto
  apply(Offset const& offset)
  {
    return offset ^ shiftr(offset & yyy_msk{}, msk_sft{});   // ZZZ ^= YYY
  }

  template <class Offset>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(Offset const& offset) const
  {
    return apply(offset);
  }

  template <int B, int M, int S>
  CUTE_HOST_DEVICE constexpr
  auto
  operator==(Swizzle<B,M,S> const&) const
  {
    return B == BBits && M == MBase && S == SShift;
  }
};
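To experiment with these masks outside of CuTe, here is a minimal standalone sketch that mirrors the mask arithmetic of `Swizzle::apply` on plain `uint32_t` offsets. `PlainSwizzle` is a hypothetical name, not part of CUTLASS: it drops `cute::constant` and `shiftr` in favor of ordinary shifts, so treat it as an illustration rather than a drop-in replacement (C++17, for `if constexpr`).

```cpp
#include <algorithm>
#include <cstdint>

// PlainSwizzle is a hypothetical, dependency-free mirror of the mask arithmetic in
// cute::Swizzle<BBits, MBase, SShift>. It is NOT the CuTe type.
template <int BBits, int MBase, int SShift = BBits>
struct PlainSwizzle {
  static_assert(BBits >= 0 && MBase >= 0, "BBits and MBase must be non-negative");
  static_assert((SShift < 0 ? -SShift : SShift) >= BBits, "|SShift| must be >= BBits");

  static constexpr std::uint32_t bit_msk = (1u << BBits) - 1u;
  static constexpr std::uint32_t yyy_msk = bit_msk << (MBase + std::max(0, SShift));
  static constexpr std::uint32_t zzz_msk = bit_msk << (MBase - std::min(0, SShift));

  static constexpr std::uint32_t apply(std::uint32_t offset) {
    const std::uint32_t yyy = offset & yyy_msk;  // isolate the YYY bits
    if constexpr (SShift >= 0) {
      return offset ^ (yyy >> SShift);           // positive shift: YYY moves right onto ZZZ
    } else {
      return offset ^ (yyy << -SShift);          // negative shift: YYY moves left onto ZZZ
    }
  }
};

// The four canonical patterns discussed in this post, checked at compile time.
static_assert(PlainSwizzle<0, 4, 3>::apply(127) == 127);
static_assert(PlainSwizzle<1, 4, 3>::apply(255) == 239);
static_assert(PlainSwizzle<2, 4, 3>::apply(511) == 463);
static_assert(PlainSwizzle<3, 4, 3>::apply(1023) == 911);
```

The negative-SShift branch is there only because the CuTe masks allow it; every canonical pattern in this post uses SShift = 3.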

Let's compute the canonical swizzles by hand to understand them better:

Remember:

// Swizzle<BBits, MBase, SShift>
static constexpr int num_bits = BBits;
static constexpr int num_base = MBase;
static constexpr int num_shft = SShift;
...
using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
using msk_sft = cute::constant<int, num_shft>;
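The hand calculations that follow can also be cross-checked mechanically. The sketch below records each intermediate value (bit_msk, yyy_msk, offset & yyy_msk, the shifted value and the result) for one offset. `SwizzleTrace` and `trace_swizzle` are hypothetical helpers, not CuTe API, and they assume SShift >= 0, which holds for every pattern in this post (C++17).

```cpp
#include <cstdint>

// Hypothetical helper (not part of CuTe): every intermediate of the hand calculation.
struct SwizzleTrace {
  std::uint32_t bit_msk, yyy_msk, masked, shifted, result;
};

// Assumes sshift >= 0, so max(0, num_shft) == sshift.
constexpr SwizzleTrace trace_swizzle(std::uint32_t offset, int bbits, int mbase, int sshift) {
  SwizzleTrace t{};
  t.bit_msk = (1u << bbits) - 1u;             // (1 << num_bits) - 1
  t.yyy_msk = t.bit_msk << (mbase + sshift);  // bit_msk << (num_base + max(0, num_shft))
  t.masked  = offset & t.yyy_msk;             // offset & yyy_msk
  t.shifted = t.masked >> sshift;             // shiftr(offset & yyy_msk, msk_sft)
  t.result  = offset ^ t.shifted;             // offset ^ shifted
  return t;
}

// Swizzle<1,4,3> applied to 255, step by step.
constexpr SwizzleTrace t143 = trace_swizzle(255, 1, 4, 3);
static_assert(t143.bit_msk == 0b1);
static_assert(t143.yyy_msk == 0b10000000);
static_assert(t143.masked  == 0b10000000);
static_assert(t143.shifted == 0b00010000);
static_assert(t143.result  == 239);
```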

Swizzle<0,4,3>

offset = 0b0000000001111111 = 127
// Swizzle<0, 4, 3>
-> num_bits = 0
-> num_base = 4
-> num_shft = 3

// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 0) - 1 = 0b0000000000000001 - 0b0000000000000001 = 0b0000000000000000

// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000000 << (4 + 3) = 0b0000000000000000 << 7 = 0b0000000000000000

// msk_sft
msk_sft = 3

// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_msk
0b0000000001111111 & 0b0000000000000000 = 0b0000000000000000
// (offset & yyy_msk) >> msk_sft
0b0000000000000000 >> 3 = 0b0000000000000000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000001111111 ^ 0b0000000000000000 = 0b0000000001111111 = 127

We see that Swizzle<0,4,3> always acts as the identity. With BBits = 0, bit_msk is all zeros, so yyy_msk is all zeros too. Hence offset & yyy_msk is 0 for every offset, and shifting it right by 3 still gives 0. XOR with 0 is the identity because 0 ^ 0 = 0 and 1 ^ 0 = 1.
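Since this argument does not depend on the offset, it can be confirmed exhaustively over a range. The sketch below uses a hypothetical `swizzle` helper (not CuTe API) that computes `offset ^ ((offset & yyy_msk) >> SShift)` for SShift >= 0, and checks at compile time that Swizzle<0,4,3> leaves every offset below 4096 unchanged (C++17).

```cpp
#include <cstdint>

// Hypothetical helper: Swizzle<bbits, mbase, sshift>::apply for sshift >= 0 (not CuTe).
constexpr std::uint32_t swizzle(std::uint32_t offset, int bbits, int mbase, int sshift) {
  const std::uint32_t bit_msk = (1u << bbits) - 1u;
  const std::uint32_t yyy_msk = bit_msk << (mbase + sshift);
  return offset ^ ((offset & yyy_msk) >> sshift);
}

// True if Swizzle<0,4,3> maps every offset in [0, limit) to itself.
constexpr bool is_identity_below(std::uint32_t limit) {
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (swizzle(offset, 0, 4, 3) != offset) return false;
  }
  return true;
}

static_assert(is_identity_below(4096));
```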

Swizzle<1,4,3>

offset = 0b0000000011111111 = 255
// Swizzle<1, 4, 3>
-> num_bits = 1
-> num_base = 4
-> num_shft = 3

// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 1) - 1 = 0b0000000000000010 - 0b0000000000000001 = 0b0000000000000001

// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000001 << (4 + 3) = 0b0000000000000001 << 7 = 0b0000000010000000

// msk_sft
msk_sft = 3

// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_msk
0b0000000011111111 & 0b0000000010000000 = 0b0000000010000000
// (offset & yyy_msk) >> msk_sft
0b0000000010000000 >> 3 = 0b0000000000010000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000011111111 ^ 0b0000000000010000 = 0b0000000011101111 = 239

We see that Swizzle<1,4,3> does the following. Write the offset as

0bxxxxxxxxXxxYxxxx

where X is bit 7 and Y is bit 4. Let us call the following the "flipped offset":

0bxxxxxxxxXxxZxxxx

where Z = ~Y.

We then have

Swizzle<1,4,3>(offset) == offset, if X = 0 (Y = 0 or Y = 1)
Swizzle<1,4,3>(offset) == flipped_offset, if X = 1 (Y = 0 or Y = 1)

I.e. we flip bit 4 (value 16) if bit 7 (value 128) is 1.

Note that this implies that all the numbers of the form

0bxxxxxxxx0xxYxxxx

won't get affected at all by the swizzle.
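The bit-7 rule can be checked the same way. The sketch below reuses a hypothetical `swizzle` helper (not CuTe API) and compares Swizzle<1,4,3> against "flip bit 4 when bit 7 is 1" for every offset below 4096; `matches_bit7_rule` is likewise an illustrative name (C++17).

```cpp
#include <cstdint>

// Hypothetical helper: Swizzle<bbits, mbase, sshift>::apply for sshift >= 0 (not CuTe).
constexpr std::uint32_t swizzle(std::uint32_t offset, int bbits, int mbase, int sshift) {
  const std::uint32_t bit_msk = (1u << bbits) - 1u;
  const std::uint32_t yyy_msk = bit_msk << (mbase + sshift);
  return offset ^ ((offset & yyy_msk) >> sshift);
}

// Rule derived above: Swizzle<1,4,3> flips bit 4 exactly when bit 7 is 1.
constexpr bool matches_bit7_rule(std::uint32_t limit) {
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    const std::uint32_t x = (offset >> 7) & 1u;         // bit 7 (X)
    const std::uint32_t expected = offset ^ (x << 4);  // flip bit 4 (Y) iff X == 1
    if (swizzle(offset, 1, 4, 3) != expected) return false;
  }
  return true;
}

static_assert(matches_bit7_rule(4096));
static_assert(swizzle(128, 1, 4, 3) == 144);  // 0b10000000 -> 0b10010000
static_assert(swizzle(144, 1, 4, 3) == 128);  // 0b10010000 -> 0b10000000
```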

Swizzle<2,4,3>

offset = 0b0000000111111111 = 511
// Swizzle<2, 4, 3>
-> num_bits = 2
-> num_base = 4
-> num_shft = 3

// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 2) - 1 = 0b0000000000000100 - 0b0000000000000001 = 0b0000000000000011

// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000011 << (4 + 3) = 0b0000000000000011 << 7 = 0b0000000110000000

// msk_sft
msk_sft = 3

// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_msk
0b0000000111111111 & 0b0000000110000000 = 0b0000000110000000
// (offset & yyy_msk) >> msk_sft
0b0000000110000000 >> 3 = 0b0000000000110000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000000111111111 ^ 0b0000000000110000 = 0b0000000111001111 = 463

We see that Swizzle<2,4,3> does the following:

0bxxxxxxxUVxXYxxxx

We will then have a flip in X (bit 5) if U (bit 8) is 1 and a flip in Y (bit 4) if V (bit 7) is 1. Note that this implies that numbers of the form:

0bxxxxxxx00xXYxxxx

won't get affected by the Swizzle at all.
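The two independent flips can be verified in the same style. This sketch (hypothetical `swizzle` and `matches_two_bit_rule` helpers, not CuTe API, C++17) checks that Swizzle<2,4,3> equals "bit 5 ^= bit 8 and bit 4 ^= bit 7" for every offset below 4096.

```cpp
#include <cstdint>

// Hypothetical helper: Swizzle<bbits, mbase, sshift>::apply for sshift >= 0 (not CuTe).
constexpr std::uint32_t swizzle(std::uint32_t offset, int bbits, int mbase, int sshift) {
  const std::uint32_t bit_msk = (1u << bbits) - 1u;
  const std::uint32_t yyy_msk = bit_msk << (mbase + sshift);
  return offset ^ ((offset & yyy_msk) >> sshift);
}

// Rule derived above: X (bit 5) ^= U (bit 8) and Y (bit 4) ^= V (bit 7).
constexpr bool matches_two_bit_rule(std::uint32_t limit) {
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    const std::uint32_t u = (offset >> 8) & 1u;
    const std::uint32_t v = (offset >> 7) & 1u;
    const std::uint32_t expected = offset ^ (u << 5) ^ (v << 4);
    if (swizzle(offset, 2, 4, 3) != expected) return false;
  }
  return true;
}

static_assert(matches_two_bit_rule(4096));
static_assert(swizzle(511, 2, 4, 3) == 463);
```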

Swizzle<3,4,3>

offset = 0b0000001111111111 = 1023
// Swizzle<3, 4, 3>
-> num_bits = 3
-> num_base = 4
-> num_shft = 3

// bit_msk = (1 << num_bits) - 1
bit_msk = (0b0000000000000001 << 3) - 1 = 0b0000000000001000 - 0b0000000000000001 = 0b0000000000000111

// yyy_msk = bit_msk << (num_base + max(0,num_shft))
yyy_msk = 0b0000000000000111 << (4 + 3) = 0b0000000000000111 << 7 = 0b0000001110000000

// msk_sft
msk_sft = 3

// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
// offset & yyy_msk
0b0000001111111111 & 0b0000001110000000 = 0b0000001110000000
// (offset & yyy_msk) >> msk_sft
0b0000001110000000 >> 3 = 0b0000000001110000
// offset ^ shiftr(offset & yyy_msk{}, msk_sft{});
0b0000001111111111 ^ 0b0000000001110000 = 0b0000001110001111 = 911

We see that Swizzle<3,4,3> does the following:

0bxxxxxxPQRSTUxxxx

We will then have a flip in U (bit 4) if R (bit 7) is 1, a flip in T (bit 5) if Q (bit 8) is 1, and a flip in S (bit 6) if P (bit 9) is 1. Note that this implies that numbers of the form:

0bxxxxxx000STUxxxx

won't get affected by the Swizzle at all.
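The three flips can also be read as a single XOR of 3-bit fields: call bits 4 to 6 a chunk index and bits 7 to 9 a row index, and Swizzle<3,4,3> replaces the chunk index with chunk ^ row. If the offsets are byte addresses, that is a permutation of the eight 16-byte chunks inside each 128-byte row. The sketch below checks this reading against a hypothetical `swizzle` helper; `row_chunk_view` and `views_agree` are illustrative names, not CuTe functions (C++17).

```cpp
#include <cstdint>

// Hypothetical helper: Swizzle<bbits, mbase, sshift>::apply for sshift >= 0 (not CuTe).
constexpr std::uint32_t swizzle(std::uint32_t offset, int bbits, int mbase, int sshift) {
  const std::uint32_t bit_msk = (1u << bbits) - 1u;
  const std::uint32_t yyy_msk = bit_msk << (mbase + sshift);
  return offset ^ ((offset & yyy_msk) >> sshift);
}

// Row/chunk reading of Swizzle<3,4,3>: chunk = bits 4..6, row = bits 7..9.
constexpr std::uint32_t row_chunk_view(std::uint32_t offset) {
  const std::uint32_t chunk = (offset >> 4) & 0b111u;
  const std::uint32_t row   = (offset >> 7) & 0b111u;
  // Clear the old chunk bits and write back chunk ^ row; all other bits are unchanged.
  return (offset & ~(0b111u << 4)) | ((chunk ^ row) << 4);
}

constexpr bool views_agree(std::uint32_t limit) {
  for (std::uint32_t offset = 0; offset < limit; ++offset) {
    if (row_chunk_view(offset) != swizzle(offset, 3, 4, 3)) return false;
  }
  return true;
}

static_assert(views_agree(4096));
static_assert(row_chunk_view(1023) == 911);
```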

Further analysis

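Let us recap, taking Swizzle<1,4,3> as an example.

Within the first 256 offsets, Swizzle<1,4,3> acts as the identity up to 0b0000000001111111 = 127, because bit 7 is 0 there. From 0b0000000010000000 = 128 onward bit 7 is 1, and swizzling takes place.

Concretely, the 32 values from 128 to 159 are mapped to their counterparts with bit 4 flipped: 128 to 143 move up by 16 (e.g. 0b0000000010000000 -> 0b0000000010010000), while 144 to 159 move down by 16 (e.g. 0b0000000010010000 -> 0b0000000010000000). The same swap of neighboring 16-offset chunks continues up to 255, after which the pattern starts over. For byte offsets into shared memory (32 banks of 4 bytes, so the bank is given by bits 2 to 6), offsets 0 and 128 would land in the same bank, whereas after the swizzle 128 becomes 144, which lands in bank 4. This is how swizzling avoids bank conflicts.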
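A simplified model makes the bank-conflict argument concrete. It assumes byte offsets, 32 shared-memory banks of 4 bytes each (bank = (offset / 4) % 32), and a tile stored as 8 rows of 2^(MBase + BBits) bytes, i.e. 32, 64 or 128 bytes. Eight accesses read the same 16-byte chunk column from rows 0 to 7, and the sketch counts how many of them start in a bank already used by an earlier row. The helpers `swizzle`, `bank` and `count_bank_collisions` are hypothetical, and the model ignores how the hardware splits wide loads into transactions (C++17).

```cpp
#include <cstdint>

// Hypothetical helper: Swizzle<bbits, mbase, sshift>::apply for sshift >= 0 (not CuTe).
constexpr std::uint32_t swizzle(std::uint32_t offset, int bbits, int mbase, int sshift) {
  const std::uint32_t bit_msk = (1u << bbits) - 1u;
  const std::uint32_t yyy_msk = bit_msk << (mbase + sshift);
  return offset ^ ((offset & yyy_msk) >> sshift);
}

// Simplified model: 32 banks of 4 bytes, so bank = bits 2..6 of a byte offset.
constexpr std::uint32_t bank(std::uint32_t byte_offset) { return (byte_offset / 4) % 32; }

// Rows are (16 << bbits) bytes wide. Row r reads 16-byte chunk `col`; returns how
// many rows start in a bank already used by an earlier row.
constexpr int count_bank_collisions(int bbits, bool swizzled, std::uint32_t col) {
  const std::uint32_t row_bytes = 16u << bbits;
  std::uint32_t banks[8] = {};
  int collisions = 0;
  for (std::uint32_t row = 0; row < 8; ++row) {
    std::uint32_t offset = row * row_bytes + col * 16;
    if (swizzled) offset = swizzle(offset, bbits, 4, 3);
    banks[row] = bank(offset);
    for (std::uint32_t prev = 0; prev < row; ++prev) {
      if (banks[prev] == banks[row]) { ++collisions; break; }
    }
  }
  return collisions;
}

static_assert(count_bank_collisions(1, false, 0) == 4);  // 32 B rows: rows r and r+4 collide
static_assert(count_bank_collisions(1, true, 0) == 0);   // Swizzle<1,4,3> separates them
static_assert(count_bank_collisions(3, false, 0) == 7);  // 128 B rows: all start in bank 0
static_assert(count_bank_collisions(3, true, 0) == 0);   // Swizzle<3,4,3>: banks 0, 4, ..., 28
```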

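Note that the swizzling pattern repeats with a period of 2^(MBase + BBits + SShift) offsets, and that the XOR only changes bits MBase to MBase + BBits - 1, so each offset stays inside its aligned block of 2^(MBase + BBits) offsets.

For Swizzle<1,4,3> the pattern repeats after 0b0000000011111111, i.e. every 256 offsets, and each offset stays inside its aligned 32-offset block.

For Swizzle<2,4,3> the pattern repeats after 0b0000000111111111, i.e. every 512 offsets, and each offset stays inside its aligned 64-offset block.

For Swizzle<3,4,3> the pattern repeats after 0b0000001111111111, i.e. every 1024 offsets, and each offset stays inside its aligned 128-offset block.

When the offsets are byte addresses, MBase = 4 keeps whole 16-byte chunks intact and these blocks are 32, 64 and 128 bytes wide, matching the names 32B, 64B and 128B. Reading the periods as bits instead (256, 512 and 1024 bits = 32, 64 and 128 bytes) gives the same numbers only because SShift = 3 and a byte has 2^3 = 8 bits.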
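The periods and block sizes for all three non-trivial patterns can be checked at once. The sketch below (hypothetical `swizzle` and `check_pattern` helpers, not CuTe API, C++17) verifies for Swizzle<BBits,4,3> that the low 4 bits never change, that each offset stays inside its aligned block of 2^(4 + BBits) offsets, and that the pattern repeats every 2^(4 + BBits + 3) offsets.

```cpp
#include <cstdint>

// Hypothetical helper: Swizzle<bbits, mbase, sshift>::apply for sshift >= 0 (not CuTe).
constexpr std::uint32_t swizzle(std::uint32_t offset, int bbits, int mbase, int sshift) {
  const std::uint32_t bit_msk = (1u << bbits) - 1u;
  const std::uint32_t yyy_msk = bit_msk << (mbase + sshift);
  return offset ^ ((offset & yyy_msk) >> sshift);
}

// Checks the three properties of Swizzle<bbits,4,3> for every offset in [0, limit).
constexpr bool check_pattern(int bbits, std::uint32_t limit) {
  const std::uint32_t block  = 1u << (4 + bbits);      // 32, 64, 128
  const std::uint32_t period = 1u << (4 + bbits + 3);  // 256, 512, 1024
  for (std::uint32_t o = 0; o < limit; ++o) {
    const std::uint32_t s = swizzle(o, bbits, 4, 3);
    if (s % 16 != o % 16) return false;                                // MBase bits constant
    if (s / block != o / block) return false;                          // stays in its block
    if (swizzle(o + period, bbits, 4, 3) != s + period) return false;  // periodic
  }
  return true;
}

static_assert(check_pattern(1, 4096));
static_assert(check_pattern(2, 4096));
static_assert(check_pattern(3, 4096));
```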

Conclusion

I hope this blog post sheds some light on the seemingly intimidating Swizzle notation used in CuTe. As we saw, it is actually pretty straightforward to calculate the swizzles and interpret them by hand. In the future I intend to write a post connecting the above calculations to the CuTeDSL. Please feel free to reach out on LinkedIn if you want to exchange ideas on GPU programming.