Bridging Math and Code: CuTe Layout Algebra in CuTeDSL
Introduction
This week the CUTLASS team released a new version of CUTLASS which introduces CuTeDSL, a Python interface which gives the user access to "core concepts such as layouts, tensors, hardware atoms, and full control over the hardware thread and data hierarchy", as can be read in their dedicated documentation.
In this blogpost we aim to give a basic understanding of some principles of CuTe layout algebra.
We will explain some basic concepts like the layout function, coalescing and complementation. This can be helpful for a deeper understanding of CuTeDSL.
Layout
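We define a layout as
$$L = (M_0, M_1, \dots, M_\alpha) : (d_0, d_1, \dots, d_\alpha).$$
Here $S = (M_0, M_1, \dots, M_\alpha)$ is the shape, where $M = M_0 \cdot M_1 \cdots M_\alpha$ is the size of the layout, and $D = (d_0, d_1, \dots, d_\alpha)$ is the stride.
The pairs $(M_k):(d_k)$ are called modes and can be considered layouts of length 1. In general we define the length of a layout to be its number of modes, here $\alpha + 1$.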
Layout function
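For a layout with shape $(M_0, \dots, M_\alpha)$, stride $(d_0, \dots, d_\alpha)$ and size $M$, let
$$\iota(x) = \left(x \bmod M_0,\ \left\lfloor \frac{x}{M_0} \right\rfloor \bmod M_1,\ \dots,\ \left\lfloor \frac{x}{M_0 \cdots M_{\alpha-1}} \right\rfloor \bmod M_\alpha\right),$$
which is an isomorphism $[0, M) \cong [0, M_0) \times \dots \times [0, M_\alpha)$. It maps a 1D index to a multidimensional index.
The layout function $f_L$ is then given as the mapping which applies the isomorphism to a number (if needed), multiplies each vector entry with the corresponding stride entry and sums these up, i.e.
$$f_L(x) = \sum_{k=0}^{\alpha} \iota(x)_k \, d_k.$$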
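Let's consider as an example the layout $L = (2, 4):(2, 2)$.
Let's calculate $f_L(3)$. The index $3$ maps to
$$\iota(3) = (3 \bmod 2,\ \lfloor 3/2 \rfloor \bmod 4) = (1, 1),$$
from here we get
$$f_L(3) = 1 \cdot 2 + 1 \cdot 2 = 4.$$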
We can calculate this in CuTeDSL:
import cutlass
import cutlass.cute as cute

@cute.jit
def layout_function_example():
    """
    Layout function in cutlass
    """
    S = (2, 4)
    D = (2, 2)
    L = cute.make_layout(shape=S, stride=D)
    # Evaluate the layout function at every 1D index of the layout
    for i in cutlass.range_constexpr(cute.size(S)):
        cute.printf("fL({}) = {}", i, L(i))

layout_function_example()
This will print:
fL(0) = 0
fL(1) = 2
fL(2) = 2
fL(3) = 4
fL(4) = 4
fL(5) = 6
fL(6) = 6
fL(7) = 8
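The same values can be reproduced without CuTeDSL. The following is a minimal pure-Python sketch of the layout function for flat (non-nested) layouts; the helper names `unflatten` and `layout_function` are chosen here for illustration and are not part of the CUTLASS API.

```python
def unflatten(x, shape):
    """Map a 1D index x to a multidimensional index (leftmost mode varies fastest)."""
    coords = []
    for m in shape:
        coords.append(x % m)
        x //= m
    return tuple(coords)


def layout_function(shape, stride, x):
    """Evaluate f_L(x): multiply each index entry by its stride and sum up."""
    return sum(c * d for c, d in zip(unflatten(x, shape), stride))


# Hypothetical helper names; the layout is the example (2, 4):(2, 2).
print(unflatten(3, (2, 4)))  # (1, 1)
values = [layout_function((2, 4), (2, 2), x) for x in range(8)]
print(values)  # [0, 2, 2, 4, 4, 6, 6, 8]
```

This only covers flat shapes and strides; nested (hierarchical) layouts would need a recursive version.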
Sorted layouts
We define a sorted layout such that all the strides are increasing, i.e. $d_0 \le d_1 \le \dots \le d_\alpha$.
Sorting does not in general preserve the layout function. Consider the following representations:
$$L_1 = (2, 2):(3, 1), \qquad L_2 = (2, 2):(1, 3).$$
The corresponding layout functions don't coincide, which we can easily verify with CuTeDSL:
import cutlass
import cutlass.cute as cute

@cute.jit
def sorted_example():
    """
    Sorting in cutlass
    """
    S1 = (2, 2)
    D1 = (3, 1)
    L1 = cute.make_layout(shape=S1, stride=D1)
    S2 = (2, 2)
    D2 = (1, 3)
    L2 = cute.make_layout(shape=S2, stride=D2)
    # Compare both layout functions index by index
    for i in cutlass.range_constexpr(cute.size(S1)):
        cute.printf("fL1({}) = {}, fL2({}) = {}", i, L1(i), i, L2(i))

sorted_example()
which will print:
fL1(0) = 0, fL2(0) = 0
fL1(1) = 3, fL2(1) = 1
fL1(2) = 1, fL2(2) = 3
fL1(3) = 4, fL2(3) = 4
which we could of course also have verified by hand for this simple example. We can also understand this by looking at the above example and seeing that $L_1$ is row major while $L_2$ is column major.
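To make the effect of sorting concrete, here is a small pure-Python sketch that reorders the modes of a flat layout by stride and compares the layout functions before and after. The helper `sort_layout` is hypothetical (not a CuTeDSL call), and breaking ties between equal strides by original order is an assumption of this sketch.

```python
def layout_function(shape, stride, x):
    """Evaluate the layout function of a flat layout at the 1D index x."""
    total = 0
    for m, d in zip(shape, stride):
        total += (x % m) * d
        x //= m
    return total


def sort_layout(shape, stride):
    """Reorder the modes so that the strides are non-decreasing (hypothetical helper)."""
    modes = sorted(zip(shape, stride), key=lambda mode: mode[1])
    return tuple(m for m, _ in modes), tuple(d for _, d in modes)


l1 = ((2, 2), (3, 1))
sorted_l1 = sort_layout(*l1)
print(sorted_l1)  # ((2, 2), (1, 3))

before = [layout_function(*l1, x) for x in range(4)]
after = [layout_function(*sorted_l1, x) for x in range(4)]
print(before, after)  # [0, 3, 1, 4] [0, 1, 3, 4]
```

Both lists contain the same values, only in a different order: sorting keeps the image of the layout function but changes which index maps to which offset.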
Coalescing
Coalescing is an operation that
- Preserves the size of the layout
- Preserves the associated layout function.
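Let's take a simple example:
$$L = (2, 1):(3, 1).$$
By looking at the above definition of $\iota$ we can recognize that the second entry of $\iota(x)$ will always be $0$ (because $\lfloor x/2 \rfloor \bmod 1 = 0$), i.e. the second mode won't contribute to the layout function for any value of its stride, i.e.
$$(2, 1):(3, 1) = 2:3.$$
We can verify this: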
import cutlass
import cutlass.cute as cute

@cute.jit
def coalesce_example():
    """
    Coalesce in cutlass
    """
    S = (2, 1)
    D = (3, 1)
    L = cute.make_layout(shape=S, stride=D)
    # Simplify the layout without changing its layout function
    cL = cute.coalesce(L)
    cute.printf("L = {}, cL = {}", L, cL)

coalesce_example()
which will print
L = (2,1):(3,1), cL = 2:3
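The two properties of coalescing can be checked with a simplified pure-Python sketch. It assumes flat layouts and applies only two rules: drop modes of shape 1, and merge a mode into its left neighbour when its stride equals the neighbour's shape times stride. This is an illustration of the idea, not a reimplementation of `cute.coalesce`, and the function names are made up for this example.

```python
def layout_function(shape, stride, x):
    """Evaluate the layout function of a flat layout at the 1D index x."""
    total = 0
    for m, d in zip(shape, stride):
        total += (x % m) * d
        x //= m
    return total


def coalesce(shape, stride):
    """Simplified coalesce for flat layouts (illustrative sketch only)."""
    out_shape, out_stride = [], []
    for m, d in zip(shape, stride):
        if m == 1:
            continue  # the index in this mode is always 0, so its stride never matters
        if out_shape and out_shape[-1] * out_stride[-1] == d:
            out_shape[-1] *= m  # contiguous with the previous mode: merge them
        else:
            out_shape.append(m)
            out_stride.append(d)
    if not out_shape:
        return (1,), (0,)  # every mode had shape 1
    return tuple(out_shape), tuple(out_stride)


layout = ((2, 1), (3, 1))
coalesced = coalesce(*layout)
print(coalesced)  # ((2,), (3,))

# The layout function is unchanged on every index of the layout.
same = all(layout_function(*layout, x) == layout_function(*coalesced, x) for x in range(2))
print(same)  # True
```

With the same rules, `coalesce((2, 4), (1, 2))` merges both modes into `((8,), (1,))`, because the second stride equals $2 \cdot 1$.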
Complementation
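Admissibility
Let $L = (M_0, \dots, M_\alpha):(d_0, \dots, d_\alpha)$ be a layout (sorted by stride) and $K$ be a positive integer. The pair $\{L, K\}$ is admissible for complementation if the following conditions are satisfied:
- $M_{i-1} d_{i-1}$ divides $d_i$ for all $1 \le i \le \alpha$
- $M_\alpha d_\alpha$ divides $K$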
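The two divisibility conditions translate directly into code. The sketch below assumes a flat layout with positive strides, given as shape and stride tuples with modes $(M_i, d_i)$, and a positive integer `k`; it sorts the modes by stride first and then checks that each $M_{i-1} d_{i-1}$ divides $d_i$ and that the last $M_\alpha d_\alpha$ divides `k`. The function name `is_admissible` is hypothetical.

```python
def is_admissible(shape, stride, k):
    """Check admissibility for complementation (flat layout, positive strides)."""
    modes = sorted(zip(shape, stride), key=lambda mode: mode[1])
    # Condition 1: M_{i-1} * d_{i-1} divides d_i for consecutive sorted modes.
    for (m_prev, d_prev), (_, d_next) in zip(modes, modes[1:]):
        if d_next % (m_prev * d_prev) != 0:
            return False
    # Condition 2: M_alpha * d_alpha divides k.
    m_last, d_last = modes[-1]
    return k % (m_last * d_last) == 0


print(is_admissible((2, 4), (1, 2), 16))  # True
print(is_admissible((2, 4), (1, 2), 12))  # False, 8 does not divide 12
print(is_admissible((2, 2), (1, 3), 16))  # False, 2 does not divide 3
```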
Complement
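If $\{L, K\}$ is admissible we define the complement operation as follows:
$$\operatorname{complement}(L, K) = \left(d_0,\ \frac{d_1}{M_0 d_0},\ \dots,\ \frac{d_\alpha}{M_{\alpha-1} d_{\alpha-1}},\ \frac{K}{M_\alpha d_\alpha}\right) : (1,\ M_0 d_0,\ \dots,\ M_\alpha d_\alpha).$$
Let's take a simple example:
$L = (2, 4):(1, 2)$, $K = 16$.
$\{L, K\}$ is admissible because $M_0 d_0 = 2$ divides $d_1 = 2$ and $M_1 d_1 = 8$ divides $K = 16$.
We can calculate the complement:
$$\operatorname{complement}(L, 16) = \left(1,\ \frac{2}{2 \cdot 1},\ \frac{16}{4 \cdot 2}\right) : (1, 2, 8) = (1, 1, 2):(1, 2, 8).$$
Using the same argument as above we can coalesce this to
$$\operatorname{complement}(L, 16) = 2:8.$$
We can calculate this in CuTeDSL as follows: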
import cutlass
import cutlass.cute as cute

@cute.jit
def complement_example():
    """
    Complement in cutlass
    """
    S = (2, 4)
    D = (1, 2)
    L = cute.make_layout(shape=S, stride=D)
    K = 16
    # Complement of L with respect to K
    cL = cute.complement(L, K)
    cute.printf("L = {}, cL = {}", L, cL)

complement_example()
which will give the same result as above:
L = (2,4):(1,2), cL = 2:8
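The complement can also be traced by hand in plain Python. The sketch below assumes a flat layout that is already sorted by stride and admissible for `k`; it builds the complement with shape $(d_0, d_1/(M_0 d_0), \dots, k/(M_\alpha d_\alpha))$ and stride $(1, M_0 d_0, \dots, M_\alpha d_\alpha)$, and then drops the modes of shape 1. The function name is chosen for illustration and is independent of `cute.complement`.

```python
def complement(shape, stride, k):
    """Uncoalesced complement of an admissible, stride-sorted flat layout."""
    comp_shape, comp_stride = [], []
    covered = 1  # equals M_{i-1} * d_{i-1} after processing mode i-1
    for m, d in zip(shape, stride):
        comp_shape.append(d // covered)
        comp_stride.append(covered)
        covered = m * d
    # Last mode: fill up to k.
    comp_shape.append(k // covered)
    comp_stride.append(covered)
    return tuple(comp_shape), tuple(comp_stride)


raw = complement((2, 4), (1, 2), 16)
print(raw)  # ((1, 1, 2), (1, 2, 8))

# Modes of shape 1 never contribute, so they can be dropped.
nontrivial = [(m, d) for m, d in zip(*raw) if m != 1]
print(nontrivial)  # [(2, 8)]
```

The remaining mode `(2, 8)` is the layout `2:8` printed by CuTeDSL above.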
We can interpret the complement as follows:
Let $C$ be the concatenation of the layout $L$ and its complement. The concatenation can be simply formed by combining all the modes into one layout. For the above example that is:
$$C = (2, 4, 2):(1, 2, 8).$$
We can prove that its layout function gives us a bijection $[0, K) \to [0, K)$. Please see proposition 2.7 of this blogpost.
We can verify this using CuTeDSL:
import cutlass
import cutlass.cute as cute

@cute.jit
def complement_example2():
    """
    Complement in cutlass
    """
    S = (2, 4, 2)
    D = (1, 2, 8)
    L = cute.make_layout(shape=S, stride=D)
    # Print the layout function of the concatenated layout
    for i in cutlass.range_constexpr(cute.size(L)):
        cute.printf("{} -> {}", i, L(i))

complement_example2()
This prints:
0 -> 0
1 -> 1
2 -> 2
3 -> 3
4 -> 4
5 -> 5
6 -> 6
7 -> 7
8 -> 8
9 -> 9
10 -> 10
11 -> 11
12 -> 12
13 -> 13
14 -> 14
15 -> 15
Note that the bijection is not necessarily the identity.
For example take
$L = 8:2$ and $K = 32$.
$\{L, K\}$ is admissible because $M_0 d_0 = 16$ divides 32.
The complement is $(2, 2):(1, 16)$, so the concatenated layout is
$$(8, 2, 2):(2, 1, 16)$$
and printing out the values of the layout function will give us a bijection unequal to the identity:
0 -> 0
1 -> 2
2 -> 4
3 -> 6
4 -> 8
5 -> 10
6 -> 12
7 -> 14
8 -> 1
9 -> 3
10 -> 5
11 -> 7
12 -> 9
13 -> 11
14 -> 13
15 -> 15
16 -> 16
17 -> 18
18 -> 20
19 -> 22
20 -> 24
21 -> 26
22 -> 28
23 -> 30
24 -> 17
25 -> 19
26 -> 21
27 -> 23
28 -> 25
29 -> 27
30 -> 29
31 -> 31
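Reading 32 lines of output is error-prone, so the bijection property can also be checked programmatically. The following pure-Python sketch evaluates the layout function of a flat layout on all of its indices and tests whether the image is exactly $0, \dots, \text{size} - 1$. It uses the layout $(8, 2, 2):(2, 1, 16)$, which reproduces the values listed above; the helper names are made up for this example.

```python
def layout_function(shape, stride, x):
    """Evaluate the layout function of a flat layout at the 1D index x."""
    total = 0
    for m, d in zip(shape, stride):
        total += (x % m) * d
        x //= m
    return total


def is_bijection(shape, stride):
    """True if the layout function maps [0, size) one-to-one onto [0, size)."""
    size = 1
    for m in shape:
        size *= m
    image = sorted(layout_function(shape, stride, x) for x in range(size))
    return image == list(range(size))


print(is_bijection((2, 4, 2), (1, 2, 8)))   # True (this one is the identity)
print(is_bijection((8, 2, 2), (2, 1, 16)))  # True (but not the identity)
print([layout_function((8, 2, 2), (2, 1, 16), x) for x in range(8, 12)])  # [1, 3, 5, 7]
```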
Conclusion
I hope this blogpost can serve as an easy intro to CuTeDSL by connecting mathematical concepts with programming.
For a deeper mathematical explanation of the concepts see Lei Mao's blogposts and Jay Shah's note on CuTe layout algebra.
For further examples of CuTeDSL see the CUTLASS repo.