simons blog

Tensor Slicing in CuTe

Introduction

Tensor slicing is fundamental to many applications in machine learning. For example, in a GEMM kernel we might access a particular tile by indexing into one dimension of a tensor and slicing along the other. It is easy to take for granted that this operation simply works, but it takes some effort to understand how it is implemented in practice. In this short note I explain how Tensor slicing is done in CuTe and calculate a few Tensor slices "by hand" to build intuition.

Tensors in CuTe

The CuTeDSL examples define a Tensor formally as

T(c) = (E o L)(c) = *(E + L(c))

where E is the Engine and L is the Layout.
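To make the definition concrete, here is a minimal pure-Python sketch (not CuTe code). It assumes memory is a flat Python list, models the Engine E as an integer index into that list, and represents the Layout only by its strides, since the shape merely bounds the valid coordinates. The names `layout_offset` and `tensor_at` are hypothetical helpers for illustration.

```python
def layout_offset(stride, coord):
    """L(c): inner product of the coordinate with the strides."""
    return sum(c * d for c, d in zip(coord, stride))


def tensor_at(memory, engine, stride, coord):
    """T(c) = *(E + L(c)); the engine E is modelled as an index into memory."""
    return memory[engine + layout_offset(stride, coord)]


# A 2x2 tensor with layout (2,2):(2,1) stored in a flat buffer.
memory = [10.0, 11.0, 12.0, 13.0]
print(tensor_at(memory, 0, (2, 1), (1, 0)))  # 12.0
```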

For now we can think of the Engine as a pointer. A pointer stores a memory address as its value, and dereferencing the pointer yields the value stored at that address. Furthermore, pointer arithmetic lets us compute new addresses by offsetting forward or backward from the original one: ptr + n is the address n elements past ptr, i.e. n times the size of the pointed-to type in bytes.
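Python's standard `ctypes` module lets us mimic this with a real buffer. The sketch below is our own illustration, not CuTe code: it computes ptr + n by hand in bytes and then dereferences the resulting address; `deref` is a hypothetical helper name.

```python
import ctypes

# Four contiguous 32-bit floats; ctypes exposes their real address.
buf = (ctypes.c_float * 4)(1.5, -2.0, 0.25, 3.0)
base = ctypes.addressof(buf)
size = ctypes.sizeof(ctypes.c_float)  # 4 bytes, like sizeof(f32)


def deref(ptr):
    """Read the f32 stored at byte address ptr (the '*' in *(E + L(c)))."""
    return ctypes.c_float.from_address(ptr).value


# ptr + n advances by n elements, i.e. n * sizeof(f32) bytes.
n = 3
addr = base + n * size
print(hex(addr - base), deref(addr))  # 0xc 3.0
```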

The Layout is a function from coordinates to offsets, [0, M) -> [0, M1) x ... x [0, Mn) -> N, where M = M1 * ... * Mn is the size of the layout: the first arrow turns a linear index into an n-dimensional coordinate, and the second arrow takes the inner product of that coordinate with the strides to produce an offset. In the context of a Tensor, the Layout is what lets us index the Tensor with a coordinate: it tells us how far we need to offset the pointer to the first element to reach the element associated with that coordinate.
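Both arrows can be sketched in a few lines of pure Python, assuming a flat (non-nested) layout and CuTe's convention that the leftmost mode varies fastest when a linear index is converted into a coordinate. The function names below are hypothetical and do not mirror CuTe's own signatures.

```python
def idx2crd(idx, shape):
    """Linear index -> n-D coordinate, first mode varying fastest (colexicographic)."""
    coord = []
    for s in shape:
        coord.append(idx % s)
        idx //= s
    return tuple(coord)


def crd2offset(coord, stride):
    """n-D coordinate -> offset: inner product with the strides."""
    return sum(c * d for c, d in zip(coord, stride))


def layout(shape, stride, idx):
    """Evaluate a flat layout shape:stride at a linear index in [0, size)."""
    return crd2offset(idx2crd(idx, shape), stride)


shape, stride = (2, 2), (2, 1)
size = shape[0] * shape[1]
print([layout(shape, stride, i) for i in range(size)])  # [0, 2, 1, 3]
```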

A simple example

Consider the following simple example:

import cutlass.cute as cute

@cute.jit
def simple_offset(src: cute.Tensor):
    cute.print_tensor(src)     # pointer, layout and data
    cute.printf(src[1, 1])     # element at coordinate (1, 1)

This prints something like the following (the address and the values differ from run to run):

tensor(raw_ptr(0x0000000045abf1c0: f32, generic, align<4>) o (2,2):(2,1), data=
       [[ 1.125501, -0.262254, ],
        [-0.393889, -1.043588, ]])
-1.043588

Here 0x0000000045abf1c0 is a 64-bit address pointing to f32 data in the generic memory space, with an alignment of 4 bytes (align<4>). We could specify a different alignment when creating the tensor, e.g. with from_dlpack; by default the alignment is taken to be the size of the datatype, in this case 4 bytes.
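An alignment of align<4> simply means the address is a multiple of 4. The short check below (our own illustration; `is_aligned` is a hypothetical helper) applies this to the base addresses printed in this post and to the sliced addresses derived from them later on; the sliced addresses stay 4-byte aligned but need not be 16-byte aligned.

```python
def is_aligned(addr, align):
    """True if a byte address is a multiple of the alignment in bytes."""
    return addr % align == 0


# Addresses that appear in the printouts of this post.
for addr in (0x0000000045abf1c0, 0x0000000029b5c004,
             0x000000001db97310, 0x00000000222f6ec4):
    print(hex(addr), is_aligned(addr, 4), is_aligned(addr, 16))
```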

When we provide the coordinate c = (1, 1), the Layout computes the corresponding offset L(c) = 1 * 2 + 1 * 1 = 3. Our datatype is f32, so the corresponding address is 0x0000000045abf1c0 + 3 * sizeof(f32) = 0x0000000045abf1c0 + 0x000000000000000c = 0x0000000045abf1cc. We then dereference this address to obtain the element in the lower right of the matrix, -1.043588.
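The same arithmetic can be replayed in plain Python. The flat `memory` list is reconstructed from the printout under the assumption, implied by the stride (2,1), that element (i, j) sits at offset 2 * i + j; the variable names are our own.

```python
SIZEOF_F32 = 4  # bytes

# Layout (2,2):(2,1) and the printed data, listed in memory order.
stride = (2, 1)
memory = [1.125501, -0.262254, -0.393889, -1.043588]
base = 0x0000000045abf1c0

coord = (1, 1)
offset = sum(c * d for c, d in zip(coord, stride))  # L(c) = 1*2 + 1*1
addr = base + offset * SIZEOF_F32
print(offset, hex(addr), memory[offset])  # 3 0x45abf1cc -1.043588
```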

Tensor slicing

In CuTeDSL we can slice a tensor by passing None for the corresponding mode of the coordinate; e.g. (None, 1) gives us the second column of the matrix.

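Slicing gives us another Tensor, which we can compute as follows. We split the Layout into two sub-layouts: L' collects the modes that receive a concrete coordinate, and L'' collects the modes that receive None. For the slice (None, 1) of the layout (2,2):(2,1), the second mode is fixed and the first mode is free:

L' = (2):(1)
L'' = (2):(2)

Evaluating L' at the fixed part of the coordinate gives the offset of the new base pointer:

L'(1) = 1

and the sliced Tensor is the shifted Engine composed with the remaining Layout:

T' = (E + L'(1)) o L''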
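The recipe generalizes directly. Below is a pure-Python sketch for flat (non-nested) layouts; `slice_layout` is a hypothetical helper, not a CuTeDSL function, and real CuTe also handles hierarchical modes, which this sketch ignores. It returns the pointer offset L'(c') together with the shape and stride of L''.

```python
def slice_layout(shape, stride, coord):
    """Split a flat layout by a coordinate that may contain None.

    Returns (offset, new_shape, new_stride) such that the sliced tensor is
    (E + offset) o new_shape:new_stride, i.e. offset = L'(fixed coords)
    and new_shape:new_stride = L''.
    """
    offset, new_shape, new_stride = 0, [], []
    for s, d, c in zip(shape, stride, coord):
        if c is None:  # free mode: kept in L''
            new_shape.append(s)
            new_stride.append(d)
        else:  # fixed mode: contributes c * d to L'(c')
            if not 0 <= c < s:
                raise IndexError(f"coordinate {c} out of range for extent {s}")
            offset += c * d
    return offset, tuple(new_shape), tuple(new_stride)


print(slice_layout((2, 2), (2, 1), (None, 1)))              # (1, (2,), (2,))
print(slice_layout((2, 2, 2), (4, 2, 1), (1, None, None)))  # (4, (2, 2), (2, 1))
print(slice_layout((2, 2, 2), (4, 2, 1), (0, None, 1)))     # (1, (2,), (2,))
```

The last two calls anticipate the 3D slices worked out by hand further down.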

We can verify our reasoning:

import cutlass.cute as cute

@cute.jit
def slice_column(src: cute.Tensor):
    cute.print_tensor(src)
    cute.print_tensor(src[None, 1])  # all rows, column 1

This will output

tensor(raw_ptr(0x0000000029b5c000: f32, generic, align<4>) o (2,2):(2,1), data=
       [[ 1.418778,  0.503520, ],
        [-0.635310, -0.606532, ]])
tensor(raw_ptr(0x0000000029b5c004: f32, generic, align<4>) o (2):(2), data=
       [ 0.503520, ],
       [-0.606532, ])

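Note that 0x0000000029b5c004 = 0x0000000029b5c000 + 0x0000000000000004 = 0x0000000029b5c000 + 1 * sizeof(f32): the base pointer has moved by L'(1) = 1 element, and the remaining Layout (2):(2) steps over the row stride of 2 elements.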
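We can check the printed column by hand in plain Python. The flat `memory` list is reconstructed from the printout assuming row-major storage, as the layout (2,2):(2,1) implies; the offset and the layout of the slice are the L'(1) and L'' computed above.

```python
SIZEOF_F32 = 4
base = 0x0000000029b5c000
# Printed matrix in memory order (layout (2,2):(2,1) is row-major).
memory = [1.418778, 0.503520, -0.635310, -0.606532]

offset, shape, stride = 1, (2,), (2,)  # L'(1) and L'' = (2):(2)
column = [memory[offset + j * stride[0]] for j in range(shape[0])]
print(hex(base + offset * SIZEOF_F32), column)
# 0x29b5c004 [0.50352, -0.606532]
```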

To deepen our understanding, let us consider a 3D Tensor:

tensor(raw_ptr(0x000000001db97300: f32, generic, align<4>) o (2,2,2):(4,2,1), data=
       [[[ 0.986613,  1.827215, ],
         [-0.073819, -1.699444, ]],

        [[ 1.114080, -2.388388, ],
         [-0.561012, -0.000348, ]]])

What will we obtain when we perform the slice (1, None, None)? Here only the first mode is fixed:

L' = (2):(4)
L'' = (2,2):(2,1)

L'(1) = 4

T' = (E + L'(1)) o L''

Performing the slice yields:

tensor(raw_ptr(0x000000001db97310: f32, generic, align<4>) o (2,2):(2,1), data=
       [[-0.073819, -0.561012, ],
        [-1.699444, -0.000348, ]])

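We have 0x000000001db97300 + 4 * sizeof(f32) = 0x000000001db97300 + 0x0000000000000010 = 0x000000001db97310, and the Layout composed with the shifted pointer is exactly the L'' = (2,2):(2,1) we computed by hand. Judging from these outputs, print_tensor displays a rank-3 tensor as a sequence of 2D blocks, where block k shows T(:, :, k). The element T(1, j, k) of our slice therefore sits in row 1, column j of block k, which is why the printed slice does not simply look like one of the blocks.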
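The sketch below rebuilds the flat buffer from the 3D printout and reads the slice back. It assumes, as the outputs suggest, that printed block k shows T(:, :, k); the variable names are our own.

```python
# Printed blocks of the (2,2,2):(4,2,1) tensor; block k is assumed to show T(:, :, k).
blocks = [
    [[0.986613, 1.827215], [-0.073819, -1.699444]],   # k = 0
    [[1.114080, -2.388388], [-0.561012, -0.000348]],  # k = 1
]
stride = (4, 2, 1)

# Rebuild the flat buffer: T(i, j, k) lives at offset 4*i + 2*j + 1*k.
memory = [0.0] * 8
for k in range(2):
    for i in range(2):
        for j in range(2):
            memory[i * stride[0] + j * stride[1] + k * stride[2]] = blocks[k][i][j]

# Slice (1, None, None): base offset L'(1) = 4, remaining layout (2,2):(2,1).
offset = 1 * stride[0]
sliced = [[memory[offset + j * 2 + k * 1] for k in range(2)] for j in range(2)]
print(sliced)  # [[-0.073819, -0.561012], [-1.699444, -0.000348]]
```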

Let us do one more exercise before we finish. Consider a 3D Tensor with the same Layout as above and perform the slice (0, None, 1) on it. This time the first and third modes are fixed and only the second mode is free:

L' = (2,2):(4,1)
L'' = (2):(2)

L'((0, 1)) = 0 * 4 + 1 * 1 = 1

T' = (E + L'((0, 1))) o L''

The output will be:

tensor(raw_ptr(0x00000000222f6ec0: f32, generic, align<4>) o (2,2,2):(4,2,1), data=
       [[[ 0.029855, -0.207999, ],
         [ 0.170911, -0.917588, ]],

        [[ 0.904916,  0.288193, ],
         [-1.245411, -2.225127, ]]])
tensor(raw_ptr(0x00000000222f6ec4: f32, generic, align<4>) o (2):(2), data=
       [ 0.904916, ],
       [ 0.288193, ])

This is exactly what we expected, since 0x00000000222f6ec4 - 0x00000000222f6ec0 = 0x0000000000000004 = 1 * sizeof(f32), and the remaining Layout (2):(2) picks out T(0, 0, 1) = 0.904916 and T(0, 1, 1) = 0.288193.
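As a final sanity check, the following pure-Python sketch (our own, not CuTe code) verifies the rule T' = (E + L'(c')) o L'' exhaustively for every slice pattern of the layout (2,2,2):(4,2,1), including the two slices used above: slicing first and then indexing the slice always lands on the same offset as indexing the original layout.

```python
from itertools import product

shape, stride = (2, 2, 2), (4, 2, 1)


def layout_offset(stride, coord):
    """Inner product of a coordinate with the strides."""
    return sum(c * d for c, d in zip(coord, stride))


checked = 0
# Every coordinate entry is either a concrete index or None (sliced).
for pattern in product(*[[None, *range(s)] for s in shape]):
    base = sum(c * d for d, c in zip(stride, pattern) if c is not None)  # L'(c')
    free_shape = [s for s, c in zip(shape, pattern) if c is None]
    free_stride = [d for d, c in zip(stride, pattern) if c is None]      # L''
    for free_coord in product(*[range(s) for s in free_shape]):
        # Rebuild the full coordinate by filling the None slots in order.
        it = iter(free_coord)
        full = tuple(next(it) if c is None else c for c in pattern)
        assert base + layout_offset(free_stride, free_coord) == layout_offset(stride, full)
        checked += 1
print(checked)  # 64
```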

Conclusion

I hope this blog post made Tensor slicing in CuTeDSL more accessible. Feel free to reach out to me via LinkedIn to exchange ideas.