Swizzles and their usage in CuTeDSL Kernels
Introduction
Swizzling is used in performant GEMM kernels to design the layout of shared memory and is therefore highly important. Essentially, swizzling addresses the shared memory bank conflicts we would face with ordinary, unswizzled Layouts. However, it can be hard to understand in the beginning. In this blogpost I aim to give a brief overview of how swizzling is used and implemented in the CuTeDSL.
Composed Layout
Swizzles can be applied via the so-called ComposedLayout in CuTeDSL.
The composed Layout is defined via the following formula:
R(c) = (inner o offset o outer)(c) = inner(offset + outer(c))
A typical example of creating a swizzled Layout would be:
sw = cute.make_swizzle(b, m, s) # BBits, MBase, SShift
L_swizzled = cute.make_composed_layout(sw, 0, L) # (sw o 0 o L)
Here the composed Layout is then given by R(c) = sw(L(c)).
Let us consider the following concrete example:
@cute.jit
def simple_swizzle(
    S: cute.Shape, D: cute.Stride, bms: cute.IntTuple, coord: cute.IntTuple
):
    L = cute.make_layout(S, stride=D)
    b, m, s = bms[0], bms[1], bms[2]
    sw = cute.make_swizzle(b, m, s)
    L_swizzled = cute.make_composed_layout(sw, 0, L)
    print(coord)
    print(cute.crd2idx(coord, L))
    print(cute.crd2idx(coord, L_swizzled))

if __name__ == "__main__":
    S = (8, 32)
    D = (32, 1)
    bms = (2, 4, 3)
    coord = (7, 25)
    simple_swizzle(S, D, bms, coord)
For c = (7, 25) the ordinary Layout will map as follows:
L((7, 25)) = 7 * 32 + 1 * 25 = 249
We can write 249 in binary: 0b0000000011111001. In a previous blogpost I have shown that Swizzle<2,4,3> will act on a number of the form 0bxxxxxxxUVxXYxxxx such that X will get flipped if U = 1 and Y will get flipped if V = 1.
Compare these two forms:
0b0000000011111001
0bxxxxxxxUVxXYxxxx
We see that U = 0 and V = 1, so we will flip Y, which means
R(c) = sw(L(c)) = sw(0b0000000011111001) = 0b0000000011101001 = 233
You can run the above code to verify for yourself that this is exactly what we get.
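If you want to see the bit manipulation itself without going through the DSL, here is a minimal plain-Python sketch of the XOR logic; this mirrors the Swizzle<B,M,S> behaviour described above (for a positive shift) and is my own helper, not CuTeDSL code:

def apply_swizzle(idx: int, b: int, m: int, s: int) -> int:
    # The B source bits sit at positions [m + s, m + s + b);
    # XOR them into the B target bits at positions [m, m + b).
    src_mask = ((1 << b) - 1) << (m + s)
    return idx ^ ((idx & src_mask) >> s)

# Reproduce the example above: Swizzle<2,4,3> applied to 249.
assert apply_swizzle(249, 2, 4, 3) == 233
print(apply_swizzle(249, 2, 4, 3))  # 233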
Swizzling in CuTe Kernels
In CuTeDSL the Swizzle is used to construct the Shared Memory Layouts, which are then used to construct the TMA atoms.
In the Hopper example the Shared Memory Layouts are constructed within _make_smem_layouts. Here we use the following logic to obtain them:
a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0]
a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
    sm90_utils.get_smem_layout_atom(
        a_layout,
        a_dtype,
        a_major_mode_size,
    ),
    a_dtype,
)
a_smem_layout_staged = cute.tile_to_shape(
    a_smem_layout_atom,
    cute.append(a_smem_shape, ab_stage),
    order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
)
This tells us the two-stage approach that is used to build them:
- We construct a smem_layout_atom.
- We use tile_to_shape to cover the whole smem_shape concatenated with the number of stages.
I will now analyse each of these stages and assume for simplicity that stage = 1. This is not important for understanding the process because it trivially extends from two to three dimensions.
Swizzle Atom
First we get the appropriate swizzling mode. The swizzling mode is determined by the bit width of the elements of the Tensor we want to create the Shared Memory Layout for, and by the Major Mode Size.
def get_smem_layout_atom(
    layout: LayoutEnum,
    element_type: Type[Numeric],
    major_mode_size: int,
    *,
    loc=None,
    ip=None,
):
    assert major_mode_size % 8 == 0
    sw128_num_contiguous_bits = 1024
    sw64_num_contiguous_bits = 512
    sw32_num_contiguous_bits = 256
    major_mode_size_bits = major_mode_size * element_type.width
    if layout.sm90_mma_major_mode() == OperandMajorMode.MN:
        if major_mode_size_bits % sw128_num_contiguous_bits == 0:
            return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128
        if major_mode_size_bits % sw64_num_contiguous_bits == 0:
            return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64
        if major_mode_size_bits % sw32_num_contiguous_bits == 0:
            return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32
        return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER
    if major_mode_size_bits % sw128_num_contiguous_bits == 0:
        return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128
    if major_mode_size_bits % sw64_num_contiguous_bits == 0:
        return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64
    if major_mode_size_bits % sw32_num_contiguous_bits == 0:
        return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32
    return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER
We see that we calculate the number of bits along the Major Mode.
For example, take BFloat16 as datatype. Then the number of bits in one element is obviously 16. If the Layout is (8, 32):(32, 1), it is K-Major, and accordingly major_mode_size = 32. This gives major_mode_size_bits = 32 * 16 = 512, so we choose cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64 as the SmemLayoutAtomKind.
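For this running example the selection boils down to a few modulo checks. Here is a small plain-Python sketch of that decision, with the kind names as plain strings rather than the real enum values:

# BFloat16 (16 bits per element), K-major layout (8, 32):(32, 1).
major_mode_size = 32
element_bits = 16
major_mode_size_bits = major_mode_size * element_bits  # 512

for kind, contiguous_bits in [("K_SW128", 1024), ("K_SW64", 512), ("K_SW32", 256)]:
    if major_mode_size_bits % contiguous_bits == 0:
        print(kind)  # prints "K_SW64"
        break
else:
    print("K_INTER")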
This is then passed to make_smem_layout_atom, which looks as follows:
@dsl_user_op
def make_smem_layout_atom(
    kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
) -> core.ComposedLayout:
    if not isinstance(element_type, NumericMeta):
        raise TypeError(f"element_type must be a Numeric, but got {element_type}")
    if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER):
        num_contiguous_bits = 128
        sw = core.make_swizzle(0, 4, 3)
    elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32):
        num_contiguous_bits = 256
        sw = core.make_swizzle(1, 4, 3)
    elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64):
        num_contiguous_bits = 512
        sw = core.make_swizzle(2, 4, 3)
    elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
        num_contiguous_bits = 1024
        sw = core.make_swizzle(3, 4, 3)
    else:
        raise ValueError("unrecognized SMEM layout atom kind")
    num_contiguous_elems = num_contiguous_bits // element_type.width
    if kind in (
        SmemLayoutAtomKind.MN_INTER,
        SmemLayoutAtomKind.MN_SW32,
        SmemLayoutAtomKind.MN_SW64,
        SmemLayoutAtomKind.MN_SW128,
    ):
        # M/N-major layout
        return core.make_composed_layout(
            sw,
            0,
            core.make_layout(
                (num_contiguous_elems, 8), stride=(1, num_contiguous_elems)
            ),
            loc=loc,
            ip=ip,
        )
    else:
        # K-major layout
        return core.make_composed_layout(
            sw,
            0,
            core.make_layout(
                (8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
            ),
            loc=loc,
            ip=ip,
        )
This simply creates a ComposedLayout which behaves like the one we have seen above.
Let us analyse in more detail how the Layout is created. From the SmemLayoutAtomKind we know how many bits are in one slice of the major mode, and we choose the corresponding Swizzle pattern.
As we have learned in the past, the Swizzle patterns above have an interesting structure and hierarchy:
- Swizzle<0,4,3> is the identity. It won't affect the Layout it acts on in the ComposedLayout at all.
- Swizzle<1,4,3> acts on numbers of the form 0bxxxxxxxxXxxZxxxx (we assume 16-bit numbers here; it works the same way for 32-bit numbers) such that Z gets flipped when X = 1. This pattern will obviously repeat after we reach 0b0000000100000000 = 256 bit = (256/8) B = 32 B, because only the 8 rightmost bits determine the swizzling pattern and the other bits effectively just add their part on top of the swizzled first 8 bits.
- Swizzle<2,4,3> acts on numbers of the form 0bxxxxxxxUVxYZxxxx such that Y will flip iff U = 1 and Z will flip iff V = 1. This pattern will repeat after 0b0000001000000000 = 512 bit = (512/8) B = 64 B, with similar reasoning as above.
- Swizzle<3,4,3> acts on numbers of the form 0bxxxxxxPQRSTUxxxx such that S will flip iff P = 1, T will flip iff Q = 1 and U will flip iff R = 1. This pattern will repeat after 0b0000010000000000 = 1024 bit = (1024/8) B = 128 B, with similar reasoning as above.
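We can sanity-check the repetition claim numerically with the same plain-Python helper as before (again my own sketch of the swizzle formula, not CuTeDSL code): adding the period to an index never touches the bits the swizzle looks at, so the swizzled value simply shifts by the same period.

def apply_swizzle(idx: int, b: int, m: int, s: int) -> int:
    src_mask = ((1 << b) - 1) << (m + s)
    return idx ^ ((idx & src_mask) >> s)

# (b, m, s) together with the period after which the pattern repeats.
for (b, m, s), period in [((1, 4, 3), 256), ((2, 4, 3), 512), ((3, 4, 3), 1024)]:
    assert all(
        apply_swizzle(i + period, b, m, s) == apply_swizzle(i, b, m, s) + period
        for i in range(4 * period)
    )
print("repetition periods verified")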
Note that in the code above we compute num_contiguous_elems = num_contiguous_bits // element_type.width. That is because the Layout will be attached to a Tensor.
As we have learned above, a Tensor in CuTeDSL is composed of an Engine and a Layout. Indexing (or slicing) into a Tensor then works as follows:
T(c) = (E o L)(c) = *(E + L(c))
In words: if we pass a coordinate to a Tensor, the pointer E is offset by L(c), and the result of this operation is then dereferenced.
The Swizzle Atom will therefore not contain any information about the datatype, because this information is contained in the engine E. That is why we divide this information out and only care about the number of contiguous elements, not bits.
We should furthermore note that the Swizzle Atoms follow the major mode of the Layout we want to create our Shared Memory Layout for, i.e. a K-Major Layout will get a swizzled K-Major Shared Memory Layout, and similarly for M-Major.
Let us look at the Swizzle Atoms if the datatype is BFloat16:
K-Major BFloat16 atoms
Swizzle<0,4,3> o 0 o (8,8):(8,1)
Swizzle<1,4,3> o 0 o (8,16):(16,1)
Swizzle<2,4,3> o 0 o (8,32):(32,1)
Swizzle<3,4,3> o 0 o (8,64):(64,1)
M-Major BFloat16 atoms
Swizzle<0,4,3> o 0 o (8,8):(1,8)
Swizzle<1,4,3> o 0 o (16,8):(1,16)
Swizzle<2,4,3> o 0 o (32,8):(1,32)
Swizzle<3,4,3> o 0 o (64,8):(1,64)
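You can reproduce these atoms by calling make_smem_layout_atom for every kind and printing the result, as sketched below. The imports and the cutlass.BFloat16 type name follow the usual CuTeDSL conventions and are my assumption here:

import cutlass
import cutlass.cute as cute

@cute.jit
def print_bf16_atoms():
    kinds = [
        cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER,
        cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32,
        cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64,
        cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128,
        cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER,
        cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32,
        cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64,
        cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128,
    ]
    for kind in kinds:
        # Each atom is a ComposedLayout of the form Swizzle o 0 o Layout.
        atom = cute.nvgpu.warpgroup.make_smem_layout_atom(kind, cutlass.BFloat16)
        print(atom)

if __name__ == "__main__":
    print_bf16_atoms()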
Covering a Tile with an Atom
So far we have covered the creation of Swizzle Atoms. However, our tiling on a block level is oftentimes larger than the dimensions of the Swizzle Atoms. This is where CuTe's tile_to_shape function comes in handy.
Looking again at the code:
a_smem_layout_staged = cute.tile_to_shape(
    a_smem_layout_atom,
    cute.append(a_smem_shape, ab_stage),
    order=(0, 1, 2) if a_is_k_major else (1, 0, 2),
)
We see that tile_to_shape takes an atom, a shape and an order parameter.
- The atom is the Swizzle Atom.
- The shape is the shape for the SMEM concatenated with the number of stages.
- The order parameter will be explained visually below.
Let us consider the atoms corresponding to Swizzle<1,4,3>, once in K-Major and once in M-Major. For visualisation it is easier to consider the case where we only tile the two-dimensional SMEM shape (without the stage mode) with our atom.
K-Major
For K-Major our atom is Swizzle<1,4,3> o 0 o (8,16):(16,1).
Let us assume we want to tile a matrix of shape (32,32) with that atom. The order will be (0,1) because we are in K-Major.
Below we display once more the atom and then, underneath it, the matrix tiled with this atom.
We see that this Layout was obtained by first replicating the atom along the first mode and then along the second mode. This meaning of the order parameter can also be found in the corresponding C++ implementation in layout.hpp:
// tile_to_shape -- Perform a product of a layout so that the result matches a target shape.
// This is similar to blocked_product, but specifies the result shape instead of the
// product shape, which is more convenient in certain circumstances.
// @param block The layout to repeat
// @param trg_shape The target shape of the result
// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with.
// Defaults to GenColMajor, so @a layout will repeat
// across the first mode first, the second mode second, etc
// E.g. Step<_2,_1,_3> will cause @a layout to repeat
// across the second mode first, the first mode second, and the third mode last.
// @pre rank(@a block) <= rank(@a trg_shape)
// @post compatible(@a trg_shape, shape(@a result))
template <class Shape, class Stride,
class TrgShape, class ModeOrder = LayoutLeft>
CUTE_HOST_DEVICE constexpr
auto
tile_to_shape(Layout<Shape,Stride> const& block,
TrgShape const& trg_shape,
ModeOrder const& ord_shape = {})
Note that this pattern is repeated 32/8 = 4 times in the downward direction, and then the resulting column of atoms is repeated 32/16 = 2 times to the right.
As explained above, the swizzle pattern repeats after index 256, i.e. after we have replicated the atom 256 / (8 * 16) = 2 times in the downward direction.
We can also print out this Layout: S<1,4,3> o 0 o ((8,4),(16,2)):((16,128),(1,512)). This makes sense because our original atom was of shape (8,16) and we needed 4 tiles to cover the first mode and 2 tiles to cover the second mode.
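To reproduce this printout, a minimal sketch along the lines of the first example could look like this (I am assuming here that plain Python tuples are accepted for the target shape and the order):

import cutlass.cute as cute

@cute.jit
def tile_k_major_atom():
    # K-major Swizzle<1,4,3> atom: (8,16):(16,1)
    sw = cute.make_swizzle(1, 4, 3)
    atom = cute.make_composed_layout(
        sw, 0, cute.make_layout((8, 16), stride=(16, 1))
    )
    # Tile it to a (32, 32) SMEM shape, filling the first mode first.
    tiled = cute.tile_to_shape(atom, (32, 32), order=(0, 1))
    print(tiled)  # S<1,4,3> o 0 o ((8,4),(16,2)):((16,128),(1,512))

if __name__ == "__main__":
    tile_k_major_atom()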
Let us derive visually the strides for the first mode, (8,4):(d1,d2). d1 corresponds to the stride between two vertically adjacent (unswizzled) elements within one atom, and that is just 16. d2 is the stride between corresponding unswizzled elements in one column of two vertically adjacent atoms; it can be directly read off as 128 - 0 = 128.
Let us derive visually the strides for the second mode, (16,2):(d1,d2). d1 is obviously 1 because the stride between two adjacent unswizzled elements within one row of an atom is 1. Similarly, d2 is the number of steps we need to go from the first element in one atom to the first element in the adjacent atom on its right, i.e. 512 - 0 = 512.
Note that we could also derive them by hand, without visuals:
- Our atom Layout is (8,16):(16,1).
- We first fill along the first mode. We need four atoms (because 32/8 = 4) to cover that mode. The stride between each of these atoms along the first mode is 8 * 16 = 128 for the unswizzled version. Thus we have a Layout along that mode of (8,4):(16,128).
- We then fill the second mode with these atoms. We need two atoms (because 32/16 = 2), and the stride between them will be 8 * 16 * 4 = 512, because we laid out 4 atoms before we lay out the first one in the direction of the second mode. Thus we have a Layout of (16,2):(1,512) in this direction.
- We apply the swizzle to the whole Layout we just obtained: thus Swizzle<1,4,3> o 0 o ((8,4),(16,2)):((16,128),(1,512)).
M-Major
For M-Major our atom is Swizzle<1,4,3> o 0 o (16,8):(1,16).
Let us assume we want to tile a matrix of shape (32,32) with that atom. The order will be (1,0) because we are in M-Major.
We see that here we first tile the atom to the right and then tile the resulting row of atoms downwards. The same repetition rules as above apply.
We can also print out this Layout: S<1,4,3> o 0 o ((16,2),(8,4)):((1,512),(16,128)). This makes sense because our original atom was of shape (16,8) and we needed 2 tiles to cover the first mode and 4 tiles to cover the second mode.
We can derive the strides in a similar manner as above. Of course, here too we can derive them purely by hand calculation (do this as an exercise if you are unsure).
Extension to 3 dimensions
If we append the stage count as an additional mode of the shape, the tile_to_shape functionality works exactly the same. Only after covering the plane with the atom do we extend into the third dimension (i.e. we get a stack of planes, all with the same pattern).
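As a quick illustration, the sketch below appends a stage count of 2 to the (32,32) example from before (again assuming plain tuples are accepted for the shape and order):

import cutlass.cute as cute

@cute.jit
def tile_k_major_atom_staged():
    sw = cute.make_swizzle(1, 4, 3)
    atom = cute.make_composed_layout(
        sw, 0, cute.make_layout((8, 16), stride=(16, 1))
    )
    # Same tiling as before, but with 2 stages appended as a third mode.
    staged = cute.tile_to_shape(atom, (32, 32, 2), order=(0, 1, 2))
    print(staged)

if __name__ == "__main__":
    tile_k_major_atom_staged()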
Bank Conflicts
In this last section we will analyse bank conflicts.
Let's look at the tile (64,64) covered with K-Major BFloat16 Atoms and plot the distribution onto the banks.
Note that for BFloat16 two elements will share one bank, because a bank consists of a 32-bit word.
Note that our swizzle atom has Layout Swizzle<3,4,3> o 0 o (8,64):(64,1), so we only extend into the first mode (8 times).
We see that in the K-Mode we don't have any bank conflicts (note that two consecutive elements with the same number belong to one bank here because we deal with BFloat16). In the M-Mode we have some bank conflicts, but not as severe as in the unswizzled Layout. For the unswizzled Layout we would have the following situation:
Here we have severe bank conflicts. Essentially any slice operation on a specific column (i.e. something like A[None, 0]) will produce as many bank conflicts as possible, which forces the GPU to serialize these memory accesses and costs performance.
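To make this concrete, here is a small plain-Python sketch that counts how many distinct banks a column access touches with and without the swizzle. It uses my own simplified model: 32 banks of 4 bytes, the element-index swizzle helper from above, and the unswizzled inner layout ((8,8),64):((64,512),1) obtained by tiling the (8,64):(64,1) atom to (64,64):

def apply_swizzle(idx: int, b: int = 3, m: int = 4, s: int = 3) -> int:
    src_mask = ((1 << b) - 1) << (m + s)
    return idx ^ ((idx & src_mask) >> s)

def element_index(mi: int, k: int) -> int:
    # Unswizzled inner layout ((8,8),64):((64,512),1) of the tiled (64,64) SMEM.
    return (mi % 8) * 64 + (mi // 8) * 512 + k

def bank(elem_idx: int) -> int:
    byte = elem_idx * 2            # BFloat16 -> 2 bytes per element
    return (byte // 4) % 32        # 32 banks, 4 bytes each

col = 0
rows = range(32)                   # one element per "thread" of a warp
unswizzled_banks = {bank(element_index(mi, col)) for mi in rows}
swizzled_banks = {bank(apply_swizzle(element_index(mi, col))) for mi in rows}
print(len(unswizzled_banks))       # 1 -> every access hits the same bank
print(len(swizzled_banks))         # 4 -> the accesses are spread over more banks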
Conclusion
I hope this served as a good introduction to swizzling in general, with a focus on how it is used in performant CuTeDSL kernels to construct the SMEM Layout. Please contact me on Linkedin if you are interested in exchanging ideas. The code to reproduce the figures shown in this blog is available in my Github repo.
Shared memory bank conflicts are explained well by Lei Mao and Axel Feldmann. Lei Mao also gives an introduction to CuTe Swizzle with a more rigorous flavour than this blog. The visualisations in this blog are adapted from code that Cris Cecka provided in a CUTLASS issue on Github.