simons blog

MMA Atoms in CuTe

Introduction

MMA Atoms are fundamental building blocks of CuTe kernels, and the CuTe docs dedicate a section to them. In this blog post we approach the same topic from a different angle, taking a different "route of explanation". To make the two easy to compare, we work through one of the examples that the CuTe docs use as well.

Understanding MMA Atoms

The following program prints a LaTeX file to stdout, which can in turn be compiled into a visualization of the MMA Atom.

#include <cute/tensor.hpp>

int main() {
    using namespace cute;

    MMA_Atom mma = MMA_Atom<SM70_8x8x4_F32F16F16F32_NT>{};
    print_latex(mma);
    return 0;
}

[Figure: MMA 1, the visualization produced by the program above]

To make sense of this picture, it helps to first look at the code.

In cutlass/include/cute/arch/mma_sm70.hpp we can find:

struct SM70_8x8x4_F32F16F16F32_NT
{
  using DRegisters = float[8];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[2];
  using CRegisters = float[8];

  // Register asm fma
  CUTE_HOST_DEVICE static void
  fma(float         & d0, float         & d1, float      & d2, float      & d3,
      float         & d4, float         & d5, float      & d6, float      & d7,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0, uint32_t const& b1,
      float    const& c0, float    const& c1, float const& c2, float const& c3,
      float    const& c4, float    const& c5, float const& c6, float const& c7)
  {
#if defined(CUTE_ARCH_MMA_SM70_ENABLED)
    asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32"
                 "{%0,  %1,  %2,  %3,  %4,  %5,  %6,  %7},"
                 "{%8,  %9},"
                 "{%10, %11},"
                 "{%12, %13, %14, %15, %16, %17, %18, %19};"
        : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3),
          "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7)
        :  "r"(a0),  "r"(a1),
           "r"(b0),  "r"(b1),
           "f"(c0),  "f"(c1),  "f"(c2),  "f"(c3),
           "f"(c4),  "f"(c5),  "f"(c6),  "f"(c7));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED");
#endif
  }
};

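This is essentially a wrapper around a single PTX instruction. Its qualifiers read as follows: m8n8k4 is the problem shape (M = 8, N = 8, K = 4), .col.row gives the layouts of A and B, and .f32.f16.f16.f32 gives the element types of D, A, B and C: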

mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32

- A: 8 x 4 -> Float 16, column major
- B: 4 x 8 -> Float 16, row major (same as 8 x 4 in column major)
- C: 8 x 8 -> Float 32
- D: 8 x 8 -> Float 32

Section 9.7.14.5.1 of the PTX docs tells us two things. First, 8 threads cooperate on one such MMA. Second, the values of A and B are passed as a vector expression containing two .f16x2 registers.

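This is exactly what the PTX instruction above receives. For D and for C, each thread passes 8 * 8 / 8 = 8 float values. For each of A and B, it passes 8 * 4 / 8 / 2 = 2 registers, because each thread holds 8 * 4 / 8 = 4 half values packed two per 32-bit .f16x2 register.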

The PTX docs also specify the fragment layouts for this instruction. A warp executing mma.m8n8k4 with .f16 inputs performs four MMA computations at once, so the elements of four sets of matrices have to be distributed across the threads of the warp. The PTX docs give the following distribution:

MMA computation 1: Threads 0-3, Threads 16-19
MMA computation 2: Threads 4-7, Threads 20-23
MMA computation 3: Threads 8-11, Threads 24-27
MMA computation 4: Threads 12-15, Threads 28-31

A Layout

[Figure: A_Layout, the fragment layout of A from the PTX docs]

Restricted to MMA computation 1, the PTX picture above matches what we see in the CuTe visualization. From it we can derive the TV Layout of A.

For one MMA computation, the threads holding A form two groups of four: T0-T3 and T16-T19. The thread mode of the layout therefore has two sub-modes, and we need to derive the shape and stride of each. Strides are counted in the column-major order of the 8 x 4 matrix A, where element (m, k) sits at position m + 8 * k.

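Going from (T0, a0) to (T1, a0) moves one column to the right, which takes 8 steps. This group contains the 4 elements (T0, a0), (T1, a0), (T2, a0), (T3, a0), so the first sub-mode is 4:8.

Going from (T0, a0) to (T16, a0) moves four rows down, which takes 4 steps. This group contains the 2 elements (T0, a0) and (T16, a0), so the second sub-mode is 2:4.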

The value mode is simple. Each thread holds 4 values that are adjacent in column-major order (consecutive rows of the same column), so the value mode is 4:1. Together this gives ((4, 2), 4):((8, 4), 1) for the TV Layout of A.

B Layout

From the PTX docs:

[Figure: Layout B, the fragment layout of B from the PTX docs]

Unsurprisingly, this again agrees with the CuTe visualization. Note that we may equivalently interpret B as an 8 x 4 matrix in column-major order. To see why, consider:

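1) M: an N x K matrix in column-major order
2) M^T: the K x N transpose in row-major order

For every (i, j) in {0, ..., N-1} x {0, ..., K-1}:
element (i, j) of M is stored at offset j * N + i,
element (j, i) of M^T is stored at offset j * N + i.

Both views place every element at the same offset, so they describe the same data.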
We can therefore use the same TV Layout for B as for A.

C Layout

[Figure: Layout C, the fragment layout of C from the PTX docs]

Once more, MMA computation 1 matches the CuTe visualization above.

Where do the Layouts come from?

In cutlass/include/cute/atom/mma_atom.hpp we find:

template <class MMAOperation>
struct MMA_Atom<MMAOperation> : MMA_Atom<MMA_Traits<MMAOperation>>
{};

In cutlass/include/cute/atom/mma_traits_sm70.hpp we find:

namespace {
// Logical thread id to thread idx (quadpair)
using SM70_QuadPair = Layout<Shape <_4, _2>,
                             Stride<_1,_16>>;
// (T8,V4) -> (M8,K4)
using SM70_8x4_Row  = Layout<Shape <_8,_4>,
                             Stride<_1,_8>>;
// (T8,V4) -> (M8,K4)
using SM70_8x4_Col  = Layout<Shape <Shape <_4,_2>,_4>,
                             Stride<Stride<_8,_4>,_1>>;
// (T8,V8) -> (M8,N8)
using SM70_8x8_16b  = Layout<Shape <_8,_8>,
                             Stride<_1,_8>>;
// (T8,V8) -> (M8,N8)
using SM70_8x8_32b  = Layout<Shape <Shape <_2, _2,_2>,Shape <_2,_2, _2>>,
                             Stride<Stride<_1,_16,_4>,Stride<_8,_2,_32>>>;
} 
...
template <>
struct MMA_Traits<SM70_8x8x4_F32F16F16F32_NT>
{
  using ValTypeD = float;
  using ValTypeA = half_t;
  using ValTypeB = half_t;
  using ValTypeC = float;

  using Shape_MNK = Shape<_8,_8,_4>;
  using ThrID   = SM70_QuadPair;
  using ALayout = SM70_8x4_Col;
  using BLayout = SM70_8x4_Col;
  using CLayout = SM70_8x8_32b;
};

This is where the layout information is hardcoded into the traits of the operation. Although B has the same TV Layout as A, the picture above displays its transpose. We explained above why the two representations (4 x 8 row major and 8 x 4 column major) are equivalent. The traits also record the value types of the matrices and the problem shape, Shape_MNK = (8, 8, 4). Finally, ThrID maps the logical thread ids within one MMA computation to physical thread indices. One computation uses 8 threads, split into two groups of four. Within a group the stride is 1 (from thread 0 to thread 1, etc.). Going from one group to the other takes a stride of 16 (from thread 0 to thread 16, etc.). This is exactly SM70_QuadPair = (4,2):(1,16).

Visualization

Code for visualization can be found in cutlass/include/cute/util/print_latex.hpp.

template <class LayoutC, class LayoutA, class LayoutB, class Tile_MNK,
          class TikzColorFn = TikzColor_TV>
CUTE_HOST_DEVICE
void
print_latex_mma(LayoutC const& C,         // (tid,vid) -> (m,n) coord
                LayoutA const& A,         // (tid,vid) -> (m,k) coord
                LayoutB const& B,         // (tid,vid) -> (n,k) coord
                Tile_MNK const& tile_mnk, // (M,N,K)
                TikzColorFn color = {})   // lambda(tid,vid) -> tikz color string
{
  CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{});

  // Commented prints
  printf("%% LayoutC: "); print(C); printf("\n");
  printf("%% LayoutA: "); print(A); printf("\n");
  printf("%% LayoutB: "); print(B); printf("\n");
  // Header
  printf("\\documentclass[convert]{standalone}\n"
         "\\usepackage{tikz}\n\n"
         "\\begin{document}\n"
         "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n");

  auto [M, N, K] = product_each(shape(tile_mnk));
  Tensor filled = make_tensor<bool>(make_shape(M, N, K));
  clear(filled);

  // C starting at 0,0
  for (int tid = 0; tid < size<0>(C); ++tid) {
    for (int vid = 0; vid < size<1>(C); ++vid) {
      auto [m, n] = C(tid, vid);
      if (not filled(m, n, 0)) {
        filled(m, n, 0) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(m), int(n),
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
          0, 0, int(M), int(N));

  clear(filled);

  // A starting at 0,-K-1
  for (int tid = 0; tid < size<0>(A); ++tid) {
    for (int vid = 0; vid < size<1>(A); ++vid) {
      auto [m, k] = A(tid, vid);
      if (not filled(m, 0, k)) {
        filled(m, 0, k) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(m), int(k-K-1),
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
         0, -int(K)-1, int(M), -1);
  // A labels
  for (int m =  0, k = -1; m < M; ++m) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), m);
  }
  for (int m = -1, k =  0; k < K; ++k) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, int(k-K-1), k);
  }

  clear(filled);

  // B starting at -K-1,0
  for (int tid = 0; tid < size<0>(B); ++tid) {
    for (int vid = 0; vid < size<1>(B); ++vid) {
      auto [n, k] = B(tid, vid);
      if (not filled(0, n, k)) {
        filled(0, n, k) = true;
        printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(k)-int(K)-1, int(n),
               tid, vid);
      }
    }
  }
  // Grid
  printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n",
         -int(K)-1, 0, -1, int(N));
  // B labels
  for (int n =  0, k = -1; n < N; ++n) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, n);
  }
  for (int n = -1, k =  0; k < K; ++k) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", int(k-K-1), n, k);
  }

  // Footer
  printf("\\end{tikzpicture}\n"
         "\\end{document}\n");
}

We see that

auto [n, k] = B(tid, vid);
...
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
               color(tid, vid),
               int(k)-int(K)-1, int(n),
               tid, vid);

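So the layout returns an (n, k) coordinate, but the node is placed at (k - K - 1, n). The tikzpicture is set up with x={(0cm,-1cm)} and y={(1cm,0cm)}, so the first coordinate runs downward (rows) and the second to the right (columns). B is therefore drawn in the equivalent K x N row-major view, above the C tile.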

Conclusion

I hope this serves as a helpful complement to the CUTLASS explanation of MMA Atoms. The same analysis can of course be applied to other MMA Atoms. If you would like to exchange ideas on similar topics, feel free to connect with me on LinkedIn.