simons blog

Swizzles and their usage in CuTeDSL Kernels

Introduction

Swizzling is used in performant GEMM kernels to design the layout of shared memory and is therefore highly important. Essentially, swizzling addresses the shared memory bank conflicts we would face with ordinary, unswizzled Layouts. However, it can be hard to understand at first. In this blogpost I aim to give a brief overview of how swizzling is used and implemented in the CuTeDSL.

Composed Layout

Swizzles can be applied via a so-called ComposedLayout in CuTeDSL.

The composed Layout is defined via the following formula:

R(c) = (inner o offset o outer)(c) = inner(offset + outer(c))

A typical example to create a swizzled Layout would be

sw = cute.make_swizzle(b, m, s) # BBits, MBase, SShift
L_swizzled = cute.make_composed_layout(sw, 0, L) # (sw o 0 o L)

Here the composed Layout is then given by R(c) = sw(L(c)).

Let us consider the following concrete example:

@cute.jit
def simple_swizzle(
    S: cute.Shape, D: cute.Stride, bms: cute.IntTuple, coord: cute.IntTuple
):
    L = cute.make_layout(S, stride=D)
    b, m, s = bms[0], bms[1], bms[2]
    sw = cute.make_swizzle(b, m, s)
    L_swizzled = cute.make_composed_layout(sw, 0, L)
    print(coord)
    print(cute.crd2idx(coord, L))
    print(cute.crd2idx(coord, L_swizzled))


if __name__ == "__main__":
    S = (8, 32)
    D = (32, 1)
    bms = (2, 4, 3)
    coord = (7, 25)
    simple_swizzle(S, D, bms, coord)

For c = (7, 25) the ordinary Layout will map as follows:

L((7, 25)) = 7 * 32 + 25 * 1 = 249

We can write in binary: 249 = 0b0000000011111001. In a previous blogpost I have shown that Swizzle<2,4,3> will act on a number of the form 0bxxxxxxxUVxXYxxxx such that X will get flipped if U=1 and Y will get flipped if V=1. Compare these two forms:

0b0000000011111001
0bxxxxxxxUVxXYxxxx

We see that U = 0, V = 1, so we will flip Y, which means

R(c) = sw(L(c)) = sw(0b0000000011111001) = 0b0000000011101001 = 233

You can run the above code to verify for yourself that this is exactly what we get.
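If you don't have CuTeDSL at hand, the same arithmetic can be checked with a small pure-Python sketch (the function names here are my own, not CuTeDSL API):

```python
def swizzle(offset: int, b: int, m: int, s: int) -> int:
    """XOR-swizzle: bits [m, m+b) are XORed with bits [m+s, m+s+b)."""
    mask = ((1 << b) - 1) << m
    return offset ^ ((offset >> s) & mask)

def layout(coord, shape, stride):
    """Map a coordinate to a linear index: sum of coord[i] * stride[i]."""
    return sum(c * d for c, d in zip(coord, stride))

idx = layout((7, 25), (8, 32), (32, 1))
print(idx)                     # 249
print(swizzle(idx, 2, 4, 3))   # 233
```

For Swizzle<2,4,3> the mask selects bits 4 and 5, and the bits three positions above them (bits 7 and 8) are XORed in, which is exactly the U/V-flips-X/Y rule from above.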

Swizzling in CuTe Kernels

In CuTeDSL the Swizzle is used to construct the Shared Memory Layouts, which are then used to construct the TMA atoms.

In the Hopper example the Shared Memory Layouts are constructed within _make_smem_layouts. Here we use the following logic to obtain them:

a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
	sm90_utils.get_smem_layout_atom(
		a_layout,
		a_dtype,
		a_major_mode_size,
	),
	a_dtype,
)
a_smem_layout_staged = cute.tile_to_shape(
	a_smem_layout_atom,
	cute.append(a_smem_shape, ab_stage),
	order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
)

This shows the two-stage approach used to build them:

  1. We construct a smem_layout_atom.
  2. We use tile_to_shape to cover the whole smem_shape concatenated with the number of stages.

I will now analyse each of these stages and assume for simplicity that stage = 1. This is not important for understanding the process, because it trivially extends from two to three dimensions.

Swizzle Atom

First we get the appropriate swizzling mode. The swizzling mode is determined by the bit width of the elements in the Tensor we want to create the Shared Memory Layout for and by the major mode size.

def get_smem_layout_atom(
    layout: LayoutEnum,
    element_type: Type[Numeric],
    major_mode_size: int,
    *,
    loc=None,
    ip=None,
):
    assert major_mode_size % 8 == 0
    sw128_num_contiguous_bits = 1024
    sw64_num_contiguous_bits = 512
    sw32_num_contiguous_bits = 256
    major_mode_size_bits = major_mode_size * element_type.width
    if layout.sm90_mma_major_mode() == OperandMajorMode.MN:
        if major_mode_size_bits % sw128_num_contiguous_bits == 0:
            return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128
        if major_mode_size_bits % sw64_num_contiguous_bits == 0:
            return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64
        if major_mode_size_bits % sw32_num_contiguous_bits == 0:
            return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32
        return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER
    if major_mode_size_bits % sw128_num_contiguous_bits == 0:
        return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128
    if major_mode_size_bits % sw64_num_contiguous_bits == 0:
        return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64
    if major_mode_size_bits % sw32_num_contiguous_bits == 0:
        return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32
    return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER

We see that we calculate the number of bits along the Major Mode.

For example, take BFloat16 as the datatype. Then the number of bits in one element is 16. If the Layout is (8, 32):(32, 1), it is K-Major, and accordingly major_mode_size = 32. This gives major_mode_size_bits = 32 * 16 = 512, so we choose cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64 as the SmemLayoutAtomKind.
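The selection logic can be reproduced with plain arithmetic; here is a sketch of the K-major branch (the string return values are mine, standing in for the SmemLayoutAtomKind enum):

```python
def pick_k_major_atom_kind(major_mode_size: int, element_bits: int) -> str:
    """Pick the widest swizzle whose contiguous-bit span divides the major mode."""
    assert major_mode_size % 8 == 0
    bits = major_mode_size * element_bits
    for kind, contiguous_bits in (("K_SW128", 1024), ("K_SW64", 512), ("K_SW32", 256)):
        if bits % contiguous_bits == 0:
            return kind
    return "K_INTER"

# BFloat16 (16 bits) with K-major layout (8, 32):(32, 1) -> major_mode_size = 32
print(pick_k_major_atom_kind(32, 16))   # K_SW64  (32 * 16 = 512 bits)
print(pick_k_major_atom_kind(64, 16))   # K_SW128 (64 * 16 = 1024 bits)
```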

This is then passed to make_smem_layout_atom, which looks as follows:

@dsl_user_op
def make_smem_layout_atom(
    kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
) -> core.ComposedLayout:
    if not isinstance(element_type, NumericMeta):
        raise TypeError(f"element_type must be a Numeric, but got {element_type}")

    if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER):
        num_contiguous_bits = 128
        sw = core.make_swizzle(0, 4, 3)
    elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32):
        num_contiguous_bits = 256
        sw = core.make_swizzle(1, 4, 3)
    elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64):
        num_contiguous_bits = 512
        sw = core.make_swizzle(2, 4, 3)
    elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
        num_contiguous_bits = 1024
        sw = core.make_swizzle(3, 4, 3)
    else:
        raise ValueError("unrecognized SMEM layout atom kind")
    num_contiguous_elems = num_contiguous_bits // element_type.width

    if kind in (
        SmemLayoutAtomKind.MN_INTER,
        SmemLayoutAtomKind.MN_SW32,
        SmemLayoutAtomKind.MN_SW64,
        SmemLayoutAtomKind.MN_SW128,
    ):
        # M/N-major layout
        return core.make_composed_layout(
            sw,
            0,
            core.make_layout(
                (num_contiguous_elems, 8), stride=(1, num_contiguous_elems)
            ),
            loc=loc,
            ip=ip,
        )
    else:
        # K-major layout
        return core.make_composed_layout(
            sw,
            0,
            core.make_layout(
                (8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
            ),
            loc=loc,
            ip=ip,
        )

This will simply create a ComposedLayout which will behave like the one we have seen above.

Let us analyse in more detail how the Layout is created. From the SmemLayoutAtomKind we know how many bits are in one slice of the major mode, and we choose the corresponding Swizzle pattern.

As we have learned previously, the Swizzle patterns above have an interesting structure and hierarchy.

Note that in the code above we adjust

num_contiguous_elems = num_contiguous_bits // element_type.width

This is because the Layout will be attached to a Tensor.

As we have learned above, a Tensor in CuTeDSL is composed of an Engine and a Layout. Indexing (or slicing) into a tensor then works as follows:

T(c) = (E o L)(c) = *(E + L(c))

which in words describes the operation: if we pass a coordinate to a Tensor, the pointer E is offset by L(c), and the result of this operation is then dereferenced.

The Swizzle Atom therefore does not contain information about the datatype, because this information is contained in the engine E. That is why we divide this information out and only care about the number of contiguous elements, not bits.
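As a toy model of this composition (everything here is illustrative Python, not CuTeDSL API), we can mimic a tensor as an engine plus a layout:

```python
class ToyTensor:
    """A tensor as the composition of an engine (storage + base pointer) and a layout."""

    def __init__(self, storage, base, shape, stride):
        self.storage, self.base = storage, base
        self.shape, self.stride = shape, stride

    def layout(self, coord):
        # L(c): coordinate -> linear offset in elements (datatype-agnostic)
        return sum(c * d for c, d in zip(coord, self.stride))

    def __call__(self, coord):
        # T(c) = *(E + L(c)): offset the "pointer", then dereference
        return self.storage[self.base + self.layout(coord)]

data = list(range(64))
T = ToyTensor(data, base=0, shape=(8, 8), stride=(8, 1))
print(T((2, 3)))   # 2 * 8 + 3 * 1 = 19
```

Note how the layout only deals in element offsets; the element's byte width lives entirely in the engine. That is exactly why the atom divides the contiguous bits by element_type.width.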

We should furthermore note that the Swizzle Atoms follow the majorness of the Layout we want to create our Shared Memory Layout for, i.e. a K-Major Layout gets a Swizzled K-Major Shared Memory Layout, and similarly for M-Major.

Let us look at the Swizzling Atoms if the datatype is BFloat16:

K-Major BFloat16 atoms

Swizzle<0,4,3> o 0 o (8,8):(8,1)

0_4_3_K_major

Swizzle<1,4,3> o 0 o (8, 16):(16,1)

1_4_3_K_major

Swizzle<2,4,3> o 0 o (8, 32):(32,1)

2_4_3_K_major

Swizzle<3,4,3> o 0 o (8, 64):(64,1)

3_4_3_K_major

M-Major BFloat16 atoms

Swizzle<0,4,3> o 0 o (8,8):(1,8)

0_4_3_M_major

Swizzle<1,4,3> o 0 o (16,8):(1,16)

1_4_3_M_major

Swizzle<2,4,3> o 0 o (32,8):(1,32)

2_4_3_M_major

Swizzle<3,4,3> o 0 o (64,8):(1,64)

3_4_3_M_major
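The atoms listed above can be reproduced from make_smem_layout_atom's arithmetic; here is a small sketch (plain Python, with the atoms encoded as shape/stride tuples):

```python
ELEMENT_BITS = 16  # BFloat16

def bf16_atoms():
    """For each swizzle B, derive both atom layouts from num_contiguous_bits // width."""
    atoms = []
    for b, contiguous_bits in ((0, 128), (1, 256), (2, 512), (3, 1024)):
        n = contiguous_bits // ELEMENT_BITS  # num_contiguous_elems
        k_major = ((8, n), (n, 1))           # K-major atom  (8, n):(n, 1)
        mn_major = ((n, 8), (1, n))          # MN-major atom (n, 8):(1, n)
        atoms.append((f"Swizzle<{b},4,3>", k_major, mn_major))
    return atoms

for sw, (ks, kd), (ms, md) in bf16_atoms():
    print(f"{sw}  K-major {ks}:{kd}  MN-major {ms}:{md}")
```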

Covering a Tile with an Atom

So far we have covered the creation of Swizzle Atoms. However, our tiling at the block level is oftentimes larger than the dimensions of the Swizzle Atom. This is where CuTe's tile_to_shape function comes in handy.

Looking again at the code

a_smem_layout_staged = cute.tile_to_shape(
	a_smem_layout_atom,
	cute.append(a_smem_shape, ab_stage),
	order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
)

We see that tile_to_shape takes an atom, a shape and an order parameter.

Let us consider the atoms corresponding to Swizzle<1,4,3>, once in K-Major and once in M-Major. For visualisation it is easier to consider the case where we only tile the SMEM shape with our atom.

K-Major

For K-Major our atom is Swizzle<1,4,3> o 0 o (8, 16):(16,1). Let us assume we want to tile a matrix of shape (32,32) with that atom. The order will be (0,1) because we are in K-Major.

We will now display once more the atom, and then below it the matrix tiled with this atom.

1_4_3_K_major

1_4_3_K_major_32_32

We see that this Layout was obtained by first replicating the atom along the first mode and then along the second mode. This meaning of the order parameter can also be found in the corresponding C++ implementation in layout.hpp:

// tile_to_shape -- Perform a product of a layout so that the result matches a target shape.
// This is similar to blocked_product, but specifies the result shape instead of the
//   product shape, which is more convenient in certain circumstances.
// @param block The layout to repeat
// @param trg_shape The target shape of the result
// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with.
//                  Defaults to GenColMajor, so @a layout will repeat
//                    across the first mode first, the second mode second, etc
//                  E.g. Step<_2,_1,_3> will cause @a layout to repeat
//                    across the second mode first, the first mode second, and the third mode last.
// @pre rank(@a block) <= rank(@a trg_shape)
// @post compatible(@a trg_shape, shape(@a result))
template <class Shape, class Stride,
          class TrgShape, class ModeOrder = LayoutLeft>
CUTE_HOST_DEVICE constexpr
auto
tile_to_shape(Layout<Shape,Stride> const& block,
              TrgShape             const& trg_shape,
              ModeOrder            const& ord_shape = {})

Note that the atom is repeated 32/8 = 4 times in the downward direction, and then this pattern is repeated 32/16 = 2 times in the rightward direction.

As explained above, the swizzle pattern repeats after index 256, i.e. after we have replicated the atom 256 / (8 * 16) = 2 times in the downward direction.

We can also print out this Layout: S<1,4,3> o 0 o ((8,4),(16,2)):((16,128),(1,512)). This makes sense because our original atom was of shape (8,16) and we needed 4 tiles to cover the first mode and 2 tiles to cover the second mode. Let us derive the strides visually:

  1. First mode (8,4):(d1,d2). d1 is the stride between two vertically adjacent (unswizzled) elements within one atom, which is just 16. d2 is the stride between corresponding elements of two vertically adjacent atoms, which can be read off directly as 128 - 0 = 128.
  2. Second mode (16,2):(d1,d2). d1 is obviously 1, because the stride between two horizontally adjacent (unswizzled) elements within one atom is 1. Similarly, d2 is the number of steps from the first element of one atom to the first element of the atom to its right, i.e. 512 - 0 = 512.

Note that we could also derive these strides by hand, without any visuals.
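Such a hand derivation can be written out as plain arithmetic (a sketch for the K-major case above; the variable names are mine):

```python
atom_shape, atom_stride = (8, 16), (16, 1)  # K-major Swizzle<1,4,3> atom
tile = (32, 32)                             # target SMEM tile

rep0 = tile[0] // atom_shape[0]             # 4 atoms stacked along mode 0
rep1 = tile[1] // atom_shape[1]             # 2 atoms along mode 1
atom_size = atom_shape[0] * atom_shape[1]   # 128 indices per atom

# Order (0, 1): atoms are laid out down-first, so the step between vertically
# adjacent atoms is one atom (128), and between horizontal atoms a full column
# of atoms (128 * 4 = 512).
shape = ((atom_shape[0], rep0), (atom_shape[1], rep1))
stride = ((atom_stride[0], atom_size), (atom_stride[1], atom_size * rep0))
print(shape)   # ((8, 4), (16, 2))
print(stride)  # ((16, 128), (1, 512))
```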

M-Major

For M-Major our atom is Swizzle<1,4,3> o 0 o (16,8):(1,16). Let us assume we want to tile a matrix of shape (32,32) with that atom. The order will be (1,0) because we are in M-Major.

1_4_3_M_major

1_4_3_M_major_32_32

We see that here we first take the atom and tile it to the right, and then tile this pattern downwards. The same repetition rule as above applies.

We can also print out this Layout: S<1,4,3> o 0 o ((16,2),(8,4)):((1,512),(16,128)). This makes sense because our original atom was of shape (16,8) and we needed 2 tiles to cover the first mode and 4 tiles to cover the second mode. We can derive the strides in a similar manner as above; of course, we can also derive them by hand (do this as an exercise if you are unsure).

Extension to 3 dimensions

If we use the stage count as an additional mode of the shape, the tile_to_shape functionality works exactly the same: after covering the plane with the atom, we extend into the third dimension (i.e. we obtain a stack of planes, each with the same pattern).
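For concreteness, a sketch of the staged K-major example from before, assuming a hypothetical ab_stage = 2: since one fully tiled (32,32) plane is compact and covers 32 * 32 = 1024 indices, the appended stage mode simply gets the plane size as its stride.

```python
plane_size = 32 * 32   # indices covered by one fully tiled, compact plane
ab_stage = 2           # hypothetical number of pipeline stages

# tile_to_shape with order (0, 1, 2) appends the stage mode last:
staged_shape = ((8, 4), (16, 2), ab_stage)
staged_stride = ((16, 128), (1, 512), plane_size)
print(staged_shape)    # ((8, 4), (16, 2), 2)
print(staged_stride)   # ((16, 128), (1, 512), 1024)
```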

Bank Conflicts

In this last section we will analyse bank conflicts. Let's look at the tile (64,64) covered with K-Major BFloat16 Atoms and plot the distribution onto the banks.

Note that for BFloat16 two elements fit into one bank, because a bank consists of a 32-bit word.

Note that our swizzle atom has Layout Swizzle<3,4,3> o 0 o (8,64):(64,1), so we only extend along the first mode (8 times).

3_4_3_K_major_64_64_Bank_Conflict

We see that in the K-Mode we don't have any bank conflicts (note that two consecutive elements with the same number belong to one bank here, because we are dealing with BFloat16). In the M-Mode we have some bank conflicts, but not as severe as in the unswizzled Layout. For the unswizzled Layout we would have the following situation:

Unswizzled_64_64_K_major_Bank_Conflict

Here we have severe bank conflicts: essentially any slice operation on a specific column (i.e. something like A[None, 0]) results in the maximum number of bank conflicts, which leads the GPU to serialize these memory accesses and reduces performance.
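We can estimate this with a simplified per-element model (a sketch: real kernels issue wider vectorized accesses, so this only illustrates the trend; the bank parameters are the usual 32 banks of 4 bytes):

```python
def swizzle(offset: int, b: int, m: int, s: int) -> int:
    """XOR-swizzle: bits [m, m+b) are XORed with bits [m+s, m+s+b)."""
    mask = ((1 << b) - 1) << m
    return offset ^ ((offset >> s) & mask)

ELEM_BYTES = 2   # BFloat16
BANK_BYTES = 4   # one shared memory bank word is 32 bits
NUM_BANKS = 32

def bank(elem_offset: int) -> int:
    """Bank index of the element at the given element offset."""
    return (elem_offset * ELEM_BYTES // BANK_BYTES) % NUM_BANKS

# K-major (64, 64):(64, 1) tile; 32 threads each read one element of column 0.
unswizzled_banks = {bank(r * 64) for r in range(32)}
swizzled_banks = {bank(swizzle(r * 64, 3, 4, 3)) for r in range(32)}

print(len(unswizzled_banks))  # 1 -> all 32 threads hit the same bank
print(len(swizzled_banks))    # 4 -> accesses spread over 4 distinct banks
```

The unswizzled column access funnels every thread into a single bank, while the Swizzle<3,4,3> layout spreads the same access over several banks, matching the figures above.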

Conclusion

I hope this served as a good introduction to Swizzling in general, with a focus on how it is used in performant CuTeDSL kernels to construct the SMEM Layout. Please contact me on LinkedIn if you are interested in exchanging ideas. The code to reproduce the figures shown in this blog is available in my GitHub repo.

Shared Memory Bank Conflicts are explained well by Lei Mao and Axel Feldmann. Lei Mao also gives an introduction to CuTe Swizzle with a more rigorous flavour than this blog. The visualisations in this blog are adapted from code that Cris Cecka provided in a CUTLASS issue on GitHub.