Use PTX instructions in Mojo
Custom PTX instructions are a powerful way of leveraging the newest hardware features of NVIDIA GPUs and are often needed to achieve peak performance in kernels.
In this blog post we aim to give a guide on how to call custom PTX instructions in Mojo, and how they could potentially be included in Mojo's open-source std. We do this by concretely analysing the tanh function from the Mojo std library and implementing a missing instruction ourselves.
Introduction
We start by creating a project to work in:
magic init intro-mojo-ptx --format mojoproject
We will consider the following program:
from math import ceildiv, tanh
from memory import UnsafePointer
from gpu import thread_idx, block_idx, block_dim
from gpu.host import DeviceContext
from random import randn
from testing import assert_almost_equal
alias SIZE = 4096
alias THREADS_PER_BLOCK = 512
alias BLOCKS_PER_GRID = ceildiv(SIZE, THREADS_PER_BLOCK)
alias dtype = DType.bfloat16
fn simple_tanh(
out: UnsafePointer[Scalar[dtype]],
a: UnsafePointer[Scalar[dtype]],
size: Int,
):
i = block_dim.x * block_idx.x + thread_idx.x
if i < size:
out[i] = tanh(a[i])
def main():
with DeviceContext() as ctx:
out = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
a = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        with a.map_to_host() as a_host:
            # Fill the whole input buffer once with samples from N(0, 1).
            randn[dtype](a_host.unsafe_ptr(), SIZE, 0, 1)
ctx.enqueue_function[simple_tanh](
out.unsafe_ptr(),
a.unsafe_ptr(),
SIZE,
grid_dim=BLOCKS_PER_GRID,
block_dim=THREADS_PER_BLOCK,
)
expected = ctx.enqueue_create_host_buffer[dtype](SIZE).enqueue_fill(0)
ctx.synchronize()
with a.map_to_host() as a_host:
for i in range(SIZE):
expected[i] = tanh(a_host[i])
with out.map_to_host() as out_host:
print("out:", out_host)
print("expected:", expected)
for i in range(SIZE):
assert_almost_equal(
out_host[i], expected[i], atol=1e-2, rtol=1e-2
)
As we can see, we simply compute the elementwise tanh in our kernel.
After the kernel execution we assert correctness via assert_almost_equal. Note that because we work with lower precision here, we choose rtol and atol accordingly. See the Mojo manual for a definition of the two quantities.
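To make the tolerance choice concrete, the closeness criterion can be sketched as a hypothetical helper mirroring the documented behaviour (the authoritative definition is in the Mojo manual):

```mojo
# Hypothetical sketch of the check performed by assert_almost_equal:
# a value passes when |actual - expected| <= atol + rtol * |expected|.
fn almost_equal(
    actual: Float32, expected: Float32, atol: Float32, rtol: Float32
) -> Bool:
    return abs(actual - expected) <= atol + rtol * abs(expected)
```

With atol=1e-2 and rtol=1e-2 this tolerates the roughly two to three significant decimal digits that bfloat16 can represent.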
A simple extension of the std
We imported tanh from the math library. Let's take a look at the source code:
@always_inline
fn tanh[
dtype: DType, width: Int, //
](x: SIMD[dtype, width]) -> SIMD[dtype, width]:
"""Performs elementwise evaluation of the tanh function.
Parameters:
dtype: The `dtype` of the input and output SIMD vector.
width: The width of the input and output SIMD vector.
Args:
x: The vector to perform the elementwise tanh on.
Returns:
The result of the elementwise tanh operation.
"""
constrained[
dtype.is_floating_point(), "the input type must be floating point"
]()
@parameter
if is_nvidia_gpu():
alias instruction = "tanh.approx.f32"
@parameter
if sizeof[dtype]() < sizeof[DType.float32]():
return _call_ptx_intrinsic[
instruction=instruction, constraints="=f,f"
](x.cast[DType.float32]()).cast[dtype]()
elif dtype is DType.float32:
return _call_ptx_intrinsic[
instruction=instruction, constraints="=f,f"
](x)
var xc = x.clamp(-9, 9)
var x_squared = xc * xc
var numerator = xc * polynomial_evaluate[
List[SIMD[dtype, width]](
4.89352455891786e-03,
6.37261928875436e-04,
1.48572235717979e-05,
5.12229709037114e-08,
-8.60467152213735e-11,
2.00018790482477e-13,
-2.76076847742355e-16,
),
](x_squared)
var denominator = polynomial_evaluate[
List[SIMD[dtype, width]](
4.89352518554385e-03,
2.26843463243900e-03,
1.18534705686654e-04,
1.19825839466702e-06,
),
](x_squared)
return numerator / denominator
The relevant part for our kernel is the branch taken if we are on an NVIDIA GPU:
@parameter
if is_nvidia_gpu():
alias instruction = "tanh.approx.f32"
@parameter
if sizeof[dtype]() < sizeof[DType.float32]():
return _call_ptx_intrinsic[
instruction=instruction, constraints="=f,f"
](x.cast[DType.float32]()).cast[dtype]()
elif dtype is DType.float32:
return _call_ptx_intrinsic[
instruction=instruction, constraints="=f,f"
](x)
We can see that the approach currently taken in the Mojo std for bfloat16 is to cast the input to float32 and then use the PTX instruction with reference 9.7.3.22 in the PTX doc.
We might instead prefer the PTX instruction with reference 9.7.4.9, which offers a dedicated instruction for half-precision floating-point values.
The instruction signature reads as follows:
tanh.approx.type d, a;
.type = {.f16, .f16x2, .bf16, .bf16x2}
We can simply write a wrapper around this instruction as follows:
@always_inline
fn tanh_bfloat16[
dtype: DType, width: Int, //
](x: SIMD[dtype, width]) -> SIMD[dtype, width]:
constrained[
dtype.is_floating_point(), "the input type must be floating point"
]()
alias instruction = "tanh.approx.bf16"
return _call_ptx_intrinsic[instruction=instruction, constraints="=h,h"](x)
Note that this is a stripped-down version; a real implementation should include proper error checks like those seen in the std example above. Also note that we use constraints="=h,h" because the 16-bit values are passed in u16 registers.
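For context, the constraint letters come from LLVM's NVPTX inline-assembly support, where each letter selects a register class. The common ones (to the best of my knowledge) are listed in the comments below, together with a hypothetical analogous wrapper for float16 using the .f16 variant from the instruction signature above:

```mojo
# Common NVPTX inline-asm constraint letters (LLVM NVPTX backend):
#   "h" - 16-bit register (used here for bf16/f16 values)
#   "r" - 32-bit integer register
#   "l" - 64-bit integer register
#   "f" - 32-bit float register (as in "=f,f" for tanh.approx.f32)
#   "d" - 64-bit float register

# Hypothetical analogous wrapper for float16 values; a real
# implementation should again add proper error checks.
@always_inline
fn tanh_float16[
    dtype: DType, width: Int, //
](x: SIMD[dtype, width]) -> SIMD[dtype, width]:
    constrained[dtype is DType.float16, "the input type must be float16"]()
    return _call_ptx_intrinsic[
        instruction="tanh.approx.f16", constraints="=h,h"
    ](x)
```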
We can then simply adjust our kernel as follows:
fn simple_tanh(
out: UnsafePointer[Scalar[dtype]],
a: UnsafePointer[Scalar[dtype]],
size: Int,
):
i = block_dim.x * block_idx.x + thread_idx.x
if i < size:
out[i] = tanh_bfloat16(a[i])
We will see that we still obtain the correct result, this time leveraging the dedicated tanh.approx.bf16 PTX instruction.
Looking at the compiled code
Mojo offers compilation to LLVM IR and asm code via the --emit flag of the compiler, as I described in a previous blog post.
There are more options for looking at compiled versions of our code, many of which can be found in the max/kernels/test/gpu/compile part of the modular repo. I recommend checking these tests out, as they offer many hands-on examples of how to use Mojo.
We can use the following code to see what our kernel gets compiled to:
def print_compiled_simple_tanh():
print("== simple_tanh_compiled")
print(_compile_code_asm[simple_tanh, emission_kind="llvm"]())
def main():
print_compiled_simple_tanh()
This uses the following helper function:
@always_inline
fn _compile_code_asm[
func_type: AnyTrivialRegType, //,
func: func_type,
/,
*,
emission_kind: StaticString = "asm",
target: __mlir_type.`!kgen.target` = _get_gpu_target(),
compile_options: StaticString = HardwareInfo.from_target[
target
]().compile_options,
]() -> StaticString:
var asm = compile_info[
func,
emission_kind=emission_kind,
compile_options=compile_options,
target=target,
]().asm
return asm
to obtain the LLVM compile info, which will print out a file like the one below:
== simple_tanh_compiled
; ModuleID = 'compile_simple_tanh.mojo'
source_filename = "compile_simple_tanh.mojo"
target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
; Function Attrs: norecurse
define dso_local ptx_kernel void @compile_simple_tanh_simple_tanh6A6A_01673aab7e47065213f32fd5c44dfa83(ptr noundef %0, ptr noundef %1, i64 noundef %2) #0 {
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = sext i32 %4 to i64
%6 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%7 = sext i32 %6 to i64
%8 = mul i64 %5, %7
%9 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%10 = sext i32 %9 to i64
%11 = add i64 %8, %10
%12 = getelementptr inbounds bfloat, ptr %0, i64 %11
%13 = getelementptr inbounds bfloat, ptr %1, i64 %11
%14 = icmp ult i64 %11, %2
br i1 %14, label %15, label %20
15: ; preds = %3
%16 = load bfloat, ptr %13, align 2
%17 = fpext bfloat %16 to float
%18 = call float asm "tanh.approx.f32 $0, $1;", "=f,f"(float %17)
%19 = fptrunc float %18 to bfloat
store bfloat %19, ptr %12, align 2
br label %21
20: ; preds = %3
br label %21
21: ; preds = %20, %15
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 1, 1025) i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
attributes #0 = { norecurse "target-cpu"="sm_90a" "target-features"="+ptx85,+sm_90a" "tune-cpu"="sm_90a" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}
The part we want to focus on is here:
%16 = load bfloat, ptr %13, align 2
%17 = fpext bfloat %16 to float
%18 = call float asm "tanh.approx.f32 $0, $1;", "=f,f"(float %17)
%19 = fptrunc float %18 to bfloat
This shows that we read the bfloat16 value, cast it up to an ordinary float32, and then carry out the tanh.approx.f32 instruction. The result is then cast back to bfloat16.
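As an aside, the _compile_code_asm helper above can also emit the final PTX assembly instead of LLVM IR, since its emission_kind parameter defaults to "asm". A sketch under that assumption:

```mojo
def print_ptx_simple_tanh():
    # Emit PTX assembly instead of LLVM IR by relying on the helper's
    # default emission_kind = "asm".
    print("== simple_tanh_ptx")
    print(_compile_code_asm[simple_tanh]())
```

In the PTX output the inline instruction should appear verbatim, which is a quick way to confirm which tanh variant was used.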
Let's compare that with the output we get when we compile the custom kernel we wrote before.
== simple_tanh_compiled
; ModuleID = 'compile_ptx_tanh.mojo'
source_filename = "compile_ptx_tanh.mojo"
target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
; Function Attrs: norecurse
define dso_local ptx_kernel void @compile_ptx_tanh_simple_tanh6A6AoA6A6A_47582dcf3af05b6ac1bf59dcd20a45f8(ptr noundef %0, ptr noundef %1, i64 noundef %2) #0 {
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = sext i32 %4 to i64
%6 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%7 = sext i32 %6 to i64
%8 = mul i64 %5, %7
%9 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%10 = sext i32 %9 to i64
%11 = add i64 %8, %10
%12 = getelementptr inbounds bfloat, ptr %0, i64 %11
%13 = getelementptr inbounds bfloat, ptr %1, i64 %11
%14 = icmp ult i64 %11, %2
br i1 %14, label %15, label %18
15: ; preds = %3
%16 = load bfloat, ptr %13, align 2
%17 = call bfloat asm "tanh.approx.bf16 $0, $1;", "=h,h"(bfloat %16)
store bfloat %17, ptr %12, align 2
br label %19
18: ; preds = %3
br label %19
19: ; preds = %15, %18
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 1, 1025) i32 @llvm.nvvm.read.ptx.sreg.ntid.x() #1
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
attributes #0 = { norecurse "target-cpu"="sm_90a" "target-features"="+ptx85,+sm_90a" "tune-cpu"="sm_90a" }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}
We get confirmation of what we already expected:
%16 = load bfloat, ptr %13, align 2
%17 = call bfloat asm "tanh.approx.bf16 $0, $1;", "=h,h"(bfloat %16)
store bfloat %17, ptr %12, align 2
In our new version we don't need any casting instructions and feed the bfloat16 value directly into tanh.approx.bf16.
Conclusion
I hope this blog post was a helpful introduction to the process of including simple PTX instructions in Mojo.
The approach taken here can be used to extend the Mojo std with dedicated instructions for half-precision floating-point values.
The code can be found in my repo.