simons blog

Use PTX instructions in Mojo

Custom PTX instructions are a powerful way of leveraging the newest hardware features of NVIDIA GPUs and are often needed to achieve peak performance in kernels. In this blog post we aim to give a guide on how to call custom PTX instructions in Mojo and potentially include them in Mojo's open-source standard library. We do this by concretely analysing the tanh function from the Mojo standard library and implementing a missing instruction ourselves.

Introduction

We start by creating a project to work in:

magic init intro-mojo-ptx --format mojoproject

We will consider the following program:

from math import ceildiv, tanh
from memory import UnsafePointer
from gpu import thread_idx, block_idx, block_dim
from gpu.host import DeviceContext
from random import randn
from testing import assert_almost_equal

alias SIZE = 4096
alias THREADS_PER_BLOCK = 512
alias BLOCKS_PER_GRID = ceildiv(SIZE, THREADS_PER_BLOCK)
alias dtype = DType.bfloat16


fn simple_tanh(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    i = block_dim.x * block_idx.x + thread_idx.x
    if i < size:
        out[i] = tanh(a[i])


def main():
    with DeviceContext() as ctx:
        out = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        a = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        with a.map_to_host() as a_host:
            # randn fills all SIZE elements in one call, so no loop is needed
            randn[dtype](a_host.unsafe_ptr(), SIZE, 0, 1)

        ctx.enqueue_function[simple_tanh](
            out.unsafe_ptr(),
            a.unsafe_ptr(),
            SIZE,
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )

        expected = ctx.enqueue_create_host_buffer[dtype](SIZE).enqueue_fill(0)

        ctx.synchronize()

        with a.map_to_host() as a_host:
            for i in range(SIZE):
                expected[i] = tanh(a_host[i])

        with out.map_to_host() as out_host:
            print("out:", out_host)
            print("expected:", expected)
            for i in range(SIZE):
                assert_almost_equal(
                    out_host[i], expected[i], atol=1e-2, rtol=1e-2
                )

As we can see, we simply compute the elementwise tanh in our kernel. After the kernel execution we assert correctness via assert_almost_equal. Note that because we work with lower precision here, we choose rtol and atol accordingly. See the Mojo manual for a definition of the two quantities.
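The tolerance check can be sketched in Python as follows. This is a simplified model (the helper name almost_equal is my own), assuming the common combined-tolerance formula |actual - expected| <= atol + rtol * |expected|; check the Mojo testing module docs for the exact semantics:

```python
def almost_equal(actual: float, expected: float,
                 atol: float = 1e-2, rtol: float = 1e-2) -> bool:
    # Combined absolute/relative tolerance: the allowed error
    # grows with the magnitude of the expected value.
    return abs(actual - expected) <= atol + rtol * abs(expected)

# bfloat16 carries only ~3 decimal digits of precision, so a loose
# tolerance like 1e-2 is appropriate here.
print(almost_equal(0.761, 0.7615941))  # True: within bf16-level error
print(almost_equal(1.0, 2.0))          # False: far outside tolerance
```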

A simple extension of the standard library

We imported tanh from the math library. Let's take a look at the source code:

@always_inline
fn tanh[
    dtype: DType, width: Int, //
](x: SIMD[dtype, width]) -> SIMD[dtype, width]:
    """Performs elementwise evaluation of the tanh function.

    Parameters:
        dtype: The `dtype` of the input and output SIMD vector.
        width: The width of the input and output SIMD vector.

    Args:
        x: The vector to perform the elementwise tanh on.

    Returns:
        The result of the elementwise tanh operation.
    """

    constrained[
        dtype.is_floating_point(), "the input type must be floating point"
    ]()

    @parameter
    if is_nvidia_gpu():
        alias instruction = "tanh.approx.f32"

        @parameter
        if sizeof[dtype]() < sizeof[DType.float32]():
            return _call_ptx_intrinsic[
                instruction=instruction, constraints="=f,f"
            ](x.cast[DType.float32]()).cast[dtype]()
        elif dtype is DType.float32:
            return _call_ptx_intrinsic[
                instruction=instruction, constraints="=f,f"
            ](x)

    var xc = x.clamp(-9, 9)
    var x_squared = xc * xc

    var numerator = xc * polynomial_evaluate[
        List[SIMD[dtype, width]](
            4.89352455891786e-03,
            6.37261928875436e-04,
            1.48572235717979e-05,
            5.12229709037114e-08,
            -8.60467152213735e-11,
            2.00018790482477e-13,
            -2.76076847742355e-16,
        ),
    ](x_squared)

    var denominator = polynomial_evaluate[
        List[SIMD[dtype, width]](
            4.89352518554385e-03,
            2.26843463243900e-03,
            1.18534705686654e-04,
            1.19825839466702e-06,
        ),
    ](x_squared)

    return numerator / denominator
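To see why the polynomial fallback at the end of the listing works, we can evaluate the same rational approximation tanh(x) ≈ x·P(x²)/Q(x²) in Python, using the coefficients copied from the listing above. This is a numerical sketch (the helpers horner and tanh_approx are my own), assuming polynomial_evaluate takes coefficients from lowest to highest degree:

```python
import math

# Coefficients copied from the Mojo standard-library listing above,
# ordered from lowest to highest degree.
P = [4.89352455891786e-03, 6.37261928875436e-04, 1.48572235717979e-05,
     5.12229709037114e-08, -8.60467152213735e-11, 2.00018790482477e-13,
     -2.76076847742355e-16]
Q = [4.89352518554385e-03, 2.26843463243900e-03, 1.18534705686654e-04,
     1.19825839466702e-06]

def horner(coeffs, x):
    # Evaluate a polynomial given ascending-degree coefficients.
    result = 0.0
    for c in reversed(coeffs):
        result = result * x + c
    return result

def tanh_approx(x):
    xc = max(-9.0, min(9.0, x))  # clamp to [-9, 9], as in the Mojo fallback
    x2 = xc * xc
    return xc * horner(P, x2) / horner(Q, x2)

for x in (0.5, 1.0, 3.0):
    print(x, tanh_approx(x), math.tanh(x))  # agree to float32-level accuracy
```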

The relevant part for our kernel is the branch taken on NVIDIA GPUs:

    @parameter
    if is_nvidia_gpu():
        alias instruction = "tanh.approx.f32"

        @parameter
        if sizeof[dtype]() < sizeof[DType.float32]():
            return _call_ptx_intrinsic[
                instruction=instruction, constraints="=f,f"
            ](x.cast[DType.float32]()).cast[dtype]()
        elif dtype is DType.float32:
            return _call_ptx_intrinsic[
                instruction=instruction, constraints="=f,f"
            ](x)

We can see that the approach currently taken in the Mojo standard library for bfloat16 is to cast the input to float32 and then use the tanh.approx.f32 instruction (section 9.7.3.22 in the PTX documentation). We might instead prefer the instruction described in section 9.7.4.9, which offers a dedicated variant for half-precision floating-point values.

The instruction signature reads as follows:

tanh.approx.type d, a;

.type = {.f16, .f16x2, .bf16, .bf16x2}

We can simply write a wrapper around this instruction as follows:

@always_inline
fn tanh_bfloat16[
    dtype: DType, width: Int, //
](x: SIMD[dtype, width]) -> SIMD[dtype, width]:
    constrained[
        dtype.is_floating_point(), "the input type must be floating point"
    ]()

    alias instruction = "tanh.approx.bf16"

    return _call_ptx_intrinsic[instruction=instruction, constraints="=h,h"](x)

Note that this is a stripped-down version; a real implementation should perform proper error checks like those seen in the standard-library example above. Also note that we use constraints="=h,h" because the values are passed in 16-bit (u16) registers.
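As background for why a 16-bit register class suffices: bfloat16 is simply the upper 16 bits of an IEEE float32 bit pattern. A quick Python illustration (pure bit manipulation with my own helper names; note real hardware typically rounds to nearest even rather than truncating):

```python
import struct

def float_to_bf16_bits(x: float) -> int:
    # Reinterpret the float32 bit pattern and keep the upper 16 bits.
    # (Truncation for simplicity; hardware usually rounds to nearest even.)
    bits32 = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits32 >> 16

def bf16_bits_to_float(bits: int) -> float:
    # Restore a float32 by zero-filling the lower 16 bits.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

bits = float_to_bf16_bits(1.0)
print(hex(bits))                 # 0x3f80 -- fits in a 16-bit (u16) register
print(bf16_bits_to_float(bits))  # 1.0
```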

We can then simply adjust our kernel as follows:

fn simple_tanh(
    out: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    i = block_dim.x * block_idx.x + thread_idx.x
    if i < size:
        out[i] = tanh_bfloat16(a[i])

We will see that we still obtain the correct result, this time leveraging the dedicated tanh.approx.bf16 PTX instruction.

Looking at the compiled code

Mojo offers compilation to LLVM IR and assembly via the --emit flag of the compiler, as I described in a previous blog post. There are more options for looking at compiled versions of our code; many of them can be found in the max/kernels/test/gpu/compile part of the modular repo. I recommend checking these tests out, as they offer many hands-on examples of how to use Mojo. We can use the following code to see what our kernel gets compiled to:

def print_compiled_simple_tanh():
    print("== simple_tanh_compiled")

    print(_compile_code_asm[simple_tanh, emission_kind="llvm"]())


def main():
    print_compiled_simple_tanh()

This uses the following function

@always_inline
fn _compile_code_asm[
    func_type: AnyTrivialRegType, //,
    func: func_type,
    /,
    *,
    emission_kind: StaticString = "asm",
    target: __mlir_type.`!kgen.target` = _get_gpu_target(),
    compile_options: StaticString = HardwareInfo.from_target[
        target
    ]().compile_options,
]() -> StaticString:
    var asm = compile_info[
        func,
        emission_kind=emission_kind,
        compile_options=compile_options,
        target=target,
    ]().asm
    return asm

to obtain the LLVM compile info.

This will print output like the following:

== simple_tanh_compiled
; ModuleID = 'compile_simple_tanh.mojo'
source_filename = "compile_simple_tanh.mojo"
target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

; Function Attrs: norecurse
define dso_local ptx_kernel void @compile_simple_tanh_simple_tanh6A6A_01673aab7e47065213f32fd5c44dfa83(ptr noundef %0, ptr noundef %1, i64 noundef %2) #0 {
  %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
  %5 = sext i32 %4 to i64
  %6 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %7 = sext i32 %6 to i64
  %8 = mul i64 %5, %7
  %9 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %10 = sext i32 %9 to i64
  %11 = add i64 %8, %10
  %12 = getelementptr inbounds bfloat, ptr %0, i64 %11
  %13 = getelementptr inbounds bfloat, ptr %1, i64 %11
  %14 = icmp ult i64 %11, %2
  br i1 %14, label %15, label %20

15:                                               ; preds = %3
  %16 = load bfloat, ptr %13, align 2
  %17 = fpext bfloat %16 to float
  %18 = call float asm "tanh.approx.f32 $0, $1;", "=f,f"(float %17)
  %19 = fptrunc float %18 to bfloat
  store bfloat %19, ptr %12, align 2
  br label %21

20:                                               ; preds = %3
  br label %21

21:                                               ; preds = %20, %15
  ret void
}

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 1, 1025) i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1

attributes #0 = { norecurse "target-cpu"="sm_90a" "target-features"="+ptx85,+sm_90a" "tune-cpu"="sm_90a" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

!llvm.module.flags = !{!0}

!0 = !{i32 2, !"Debug Info Version", i32 3}

The part we want to focus on is this:

%16 = load bfloat, ptr %13, align 2
%17 = fpext bfloat %16 to float
%18 = call float asm "tanh.approx.f32 $0, $1;", "=f,f"(float %17)
%19 = fptrunc float %18 to bfloat

This shows that we load the bfloat16 value, extend it to an ordinary float32, and then carry out the tanh.approx.f32 instruction. The result is then truncated back to bfloat16.

Let's compare that with the output we get when we compile the custom kernel we wrote before.

== simple_tanh_compiled
; ModuleID = 'compile_ptx_tanh.mojo'
source_filename = "compile_ptx_tanh.mojo"
target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

; Function Attrs: norecurse
define dso_local ptx_kernel void @compile_ptx_tanh_simple_tanh6A6AoA6A6A_47582dcf3af05b6ac1bf59dcd20a45f8(ptr noundef %0, ptr noundef %1, i64 noundef %2) #0 {
  %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
  %5 = sext i32 %4 to i64
  %6 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %7 = sext i32 %6 to i64
  %8 = mul i64 %5, %7
  %9 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %10 = sext i32 %9 to i64
  %11 = add i64 %8, %10
  %12 = getelementptr inbounds bfloat, ptr %0, i64 %11
  %13 = getelementptr inbounds bfloat, ptr %1, i64 %11
  %14 = icmp ult i64 %11, %2
  br i1 %14, label %15, label %18

15:                                               ; preds = %3
  %16 = load bfloat, ptr %13, align 2
  %17 = call bfloat asm "tanh.approx.bf16 $0, $1;", "=h,h"(bfloat %16)
  store bfloat %17, ptr %12, align 2
  br label %19

18:                                               ; preds = %3
  br label %19

19:                                               ; preds = %15, %18
  ret void
}

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 1, 1025) i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1

attributes #0 = { norecurse "target-cpu"="sm_90a" "target-features"="+ptx85,+sm_90a" "tune-cpu"="sm_90a" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

!llvm.module.flags = !{!0}

!0 = !{i32 2, !"Debug Info Version", i32 3}

We get confirmation of what we already expected:

%16 = load bfloat, ptr %13, align 2
%17 = call bfloat asm "tanh.approx.bf16 $0, $1;", "=h,h"(bfloat %16)
store bfloat %17, ptr %12, align 2

In our new version we no longer need any casting instructions and feed the bfloat16 value directly into tanh.approx.bf16.

Conclusion

I hope this blog post was helpful in illustrating the process of including simple PTX instructions in Mojo. The approach taken here can be used to extend the Mojo standard library with dedicated instructions for half-precision floating-point values. The code can be found in my repo.