Demystifying numeric conversions in CuTeDSL
CuTeDSL has vector primitives which the user can convert between. On a high level, a conversion can be summarized in a top-to-bottom fashion as follows: successively extract packed subsequences of the vector under consideration and apply the corresponding cvt instruction from PTX, which turns each packed source value into its packed converted counterpart.
However, currently not all possible conversions are implemented. In this short note I show how to implement such a conversion yourself, which may help your code run faster. On the GEMV task in the current GPU MODE competition I leveraged the newly introduced conversion below to improve performance by ~10%.
The problem
Let's consider FP8 format f8E4M3FN which is for example used as the data format of scaling factors for the FP4 format Float4E2M1FN. We can convert a RMEM tensor in this format to F32 precision using the following minimal code example:
import cutlass.cute as cute
from cutlass import Float8E4M3FN, Float32
@cute.jit
def cvt():
    """Minimal example: convert a 32-element Float8E4M3FN RMEM tensor to Float32.

    NOTE(review): replacing Float32 with Float16 here triggers the
    'arith.extf ... cast incompatible' compile error discussed in the text,
    because that conversion path is not implemented upstream.
    """
    # Source tensor: 32 values in FP8 e4m3 format.
    vec_f8e4m3 = cute.make_rmem_tensor_like(
        cute.make_layout(32),
        Float8E4M3FN
    )
    # Destination tensor: same layout, Float32 elements.
    vec_f32 = cute.make_rmem_tensor_like(
        cute.make_layout(32),
        Float32
    )
    # Load the FP8 vector, convert element-wise to FP32, store the result.
    vec_f32.store(vec_f8e4m3.load().to(Float32))


if __name__ == '__main__':
    cvt()
Note that if we replace Float32 with Float16 we will get a compiler error like so:
error: "vec_f16.store(vec_f8e4m3.load().to(Float16))"
'arith.extf' op operand type 'vector<32xf16>' and result type 'vector<32xf16>' are cast incompatible
During the current GEMV competition I noticed that a conversion to FP16 already exists for the NVFP4 datatype, and that casting into FP16 instead of FP32 brings a small but measurable improvement in memory-bound GEMV kernel performance.
How it's done for FP4
In the tensor class we can find code that is used to convert to FP16 when we deal with the NVFP4 data format (the 4-bit format is called Float4E2M1FN):
if src_dtype == Float4E2M1FN and dtype in (Float16, Float32):
res_vect = cvt_f4e2m1_f16_intrinsic(
src, size(self.shape), loc=loc, ip=ip
)
if dtype == Float32:
res_vect = cutlass_arith.cvtf(
res_vect, dtype.mlir_type, loc=loc, ip=ip
)
We can see that an intrinsic is called which can be found here.
@dsl_user_op
def cvt_f4e2m1_f16_intrinsic(vec_f4e2m1, length, *, loc=None, ip=None):
    """
    Convert a vector of float4e2m1 to a vector of float16.

    Greedily peels the widest packed slice available (8, then 4, then 2
    lanes, then one trailing scalar) off the source vector, so as many
    elements as possible go through the vectorized PTX conversion helpers.

    :param vec_f4e2m1: The input vector of float4e2m1.
    :type vec_f4e2m1: 1D vector of float4e2m1
    :param length: The length of the input vector.
    :type length: int
    :return: The output 1D vector of float16 with the same length as the input vector.
    :rtype: 1D vector of float16
    """
    # Current offset into both the source and destination vectors.
    src_pos = 0
    # Reinterpret the f4e2m1 payload as raw 4-bit integers so we can take
    # strided slices of packed lanes below.
    vec_src_i4 = builtin.unrealized_conversion_cast(
        [ir.VectorType.get([length], Int4.mlir_type, loc=loc)],
        [vec_f4e2m1],
        loc=loc,
        ip=ip,
    )
    # Slice types for the 8-, 4- and 2-wide packed conversion paths.
    vec_i4x8_type = ir.VectorType.get([8], Int4.mlir_type, loc=loc)
    vec_i4x4_type = ir.VectorType.get([4], Int4.mlir_type, loc=loc)
    vec_i4x2_type = ir.VectorType.get([2], Int4.mlir_type, loc=loc)
    vec_dst_type = ir.VectorType.get([length], Float16.mlir_type, loc=loc)
    # Zero-initialized result; converted slices are inserted in place.
    vec_dst = llvm.mlir_zero(vec_dst_type, loc=loc, ip=ip)
    # try to use vectorized version
    if length >= 8:
        num_vec8 = length // 8
        for _ in range(num_vec8):
            vec_f4e2m1x8 = vector.extract_strided_slice(
                vec_i4x8_type, vec_src_i4, [src_pos], [8], [1], loc=loc, ip=ip
            )
            vec_f16x8 = cvt_f4e2m1x8_to_f16x8(vec_f4e2m1x8, loc=loc, ip=ip)
            vec_dst = vector.insert_strided_slice(
                vec_f16x8, vec_dst, [src_pos], [1], loc=loc, ip=ip
            )
            src_pos += 8
            length -= 8
    # Fewer than 8 elements remain now: at most one 4-wide slice, one
    # 2-wide slice, and a single scalar tail.
    if length >= 4:
        vec_f4e2m1x4 = vector.extract_strided_slice(
            vec_i4x4_type, vec_src_i4, [src_pos], [4], [1], loc=loc, ip=ip
        )
        vec_f16x4 = cvt_f4e2m1x4_to_f16x4(vec_f4e2m1x4, loc=loc, ip=ip)
        vec_dst = vector.insert_strided_slice(
            vec_f16x4, vec_dst, [src_pos], [1], loc=loc, ip=ip
        )
        src_pos += 4
        length -= 4
    if length >= 2:
        vec_f4e2m1x2 = vector.extract_strided_slice(
            vec_i4x2_type, vec_src_i4, [src_pos], [2], [1], loc=loc, ip=ip
        )
        vec_f16x2 = cvt_f4e2m1x2_to_f16x2(vec_f4e2m1x2, loc=loc, ip=ip)
        vec_dst = vector.insert_strided_slice(
            vec_f16x2, vec_dst, [src_pos], [1], loc=loc, ip=ip
        )
        src_pos += 2
        length -= 2
    if length >= 1:
        # Scalar tail: convert the single remaining element.
        val_f16 = cvt_f4e2m1_f16(
            vector.extractelement(
                vec_src_i4,
                position=arith.constant(Int32.mlir_type, src_pos),
                loc=loc,
                ip=ip,
            ),
            loc=loc,
            ip=ip,
        )
        vec_dst = vector.insertelement(
            val_f16,
            vec_dst,
            position=arith.constant(Int32.mlir_type, src_pos),
            loc=loc,
            ip=ip,
        )
    return vec_dst
Note that we iteratively extract packed bit subsequences and convert each pack. The implementation processes larger packs first whenever the shape of the vector allows it. Let's take a look at how these individual conversions are done. The code can be found here:
@dsl_user_op
def cvt_f4e2m1_f16(src, *, loc=None, ip=None):
    """Convert a single float4e2m1 value (given as its 4-bit payload) to float16."""
    # 0 padding for upper 4 bits
    # The packed x2 helper needs two lanes, so pair the value with a zero.
    zero = arith.constant(src.type, 0, loc=loc, ip=ip)
    vec2 = vector.from_elements(
        ir.VectorType.get([2], src.type, loc=loc), [src, zero], loc=loc, ip=ip
    )
    rst_vec2 = cvt_f4e2m1x2_to_f16x2(vec2, loc=loc, ip=ip)
    # only the 1st element is valid
    rst = vector.extract(
        rst_vec2, dynamic_position=[], static_position=[0], loc=loc, ip=ip
    )
    return rst
# Convert 2 float4e2m1 values to 2 float16 values
@dsl_user_op
def cvt_f4e2m1x2_to_f16x2(src_vec2, *, loc=None, ip=None):
# pack 2 float4e2m1 into 1 int8 value and fill upper bits with 0
src_i8 = llvm.bitcast(Int8.mlir_type, src_vec2, loc=loc, ip=ip)
src_i16 = llvm.zext(Int16.mlir_type, src_i8, loc=loc, ip=ip)
rst_i32 = llvm.inline_asm(
Int32.mlir_type,
[src_i16],
"""{\n\t
.reg .b8 b;\n\t
mov.b16 {b,_}, $1;\n\t
cvt.rn.f16x2.e2m1x2 $0, b;\n\t
}""",
"=r,h",
)
vec_f16x2_type = ir.VectorType.get([2], Float16.mlir_type, loc=loc)
vec_f16x2 = llvm.bitcast(vec_f16x2_type, rst_i32, loc=loc, ip=ip)
return vec_f16x2
# Convert 4 float4e2m1 values to 4 float16 values
@dsl_user_op
def cvt_f4e2m1x4_to_f16x4(src_vec4, *, loc=None, ip=None):
    """Convert a 4-lane float4e2m1 vector to a 4-lane float16 vector via PTX.

    :param src_vec4: packed input values
    :type src_vec4: vector of 4 float4e2m1 (as 4-bit lanes)
    :return: vector of 4 float16 values
    """
    # pack 4 float4e2m1 into 1 int16 value
    src_i16 = llvm.bitcast(Int16.mlir_type, src_vec4, loc=loc, ip=ip)
    # Each byte holds two e2m1 values; one cvt.rn.f16x2.e2m1x2 per byte
    # expands it into an i32 carrying two f16 values.
    rst_i32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32()]),
        [src_i16],
        """{\n\t
        .reg .b8 b0, b1;\n\t
        mov.b16 {b0, b1}, $2;\n\t
        cvt.rn.f16x2.e2m1x2 $0, b0;\n\t
        cvt.rn.f16x2.e2m1x2 $1, b1;\n\t
        }""",
        "=r,=r,h",
    )
    res0 = llvm.extractvalue(T.i32(), rst_i32x2, [0])
    res1 = llvm.extractvalue(T.i32(), rst_i32x2, [1])
    # Renamed from vec_f32x2*: the lanes are Int32, not Float32 (matches
    # the vec_i32x2 naming used by the e4m3 helpers).
    vec_i32x2_type = ir.VectorType.get([2], Int32.mlir_type, loc=loc)
    vec_i32x2 = vector.from_elements(vec_i32x2_type, [res0, res1], loc=loc, ip=ip)
    # Reinterpret the two packed i32 results as four f16 lanes.
    vec_f16x4_type = ir.VectorType.get([4], Float16.mlir_type, loc=loc)
    vec_f16x4 = llvm.bitcast(vec_f16x4_type, vec_i32x2, loc=loc, ip=ip)
    return vec_f16x4
# Convert 8 float4e2m1 values to 8 float16 values
@dsl_user_op
def cvt_f4e2m1x8_to_f16x8(src_vec8, *, loc=None, ip=None):
    """Convert a packed vector of 8 float4e2m1 values to 8 float16 values.

    The 8 nibbles fill one 32-bit register exactly; each of its 4 bytes
    holds two e2m1 values and is expanded by one cvt.rn.f16x2.e2m1x2 into
    an i32 carrying two f16 values.

    :param src_vec8: packed input values
    :type src_vec8: vector of 8 float4e2m1 (as 4-bit lanes)
    :return: vector of 8 float16 values
    """
    # pack 8 float4e2m1 into 1 int32 value (32 bits are fully occupied,
    # so unlike the narrower variants no zero-padding is involved)
    src_i32 = llvm.bitcast(Int32.mlir_type, src_vec8, loc=loc, ip=ip)
    rst_i32x4 = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]),
        [src_i32],
        """{\n\t
        .reg .b8 b0, b1, b2, b3;\n\t
        mov.b32 {b0, b1, b2, b3}, $4;\n\t
        cvt.rn.f16x2.e2m1x2 $0, b0;\n\t
        cvt.rn.f16x2.e2m1x2 $1, b1;\n\t
        cvt.rn.f16x2.e2m1x2 $2, b2;\n\t
        cvt.rn.f16x2.e2m1x2 $3, b3;\n\t
        }""",
        "=r,=r,=r,=r,r",
    )
    res0 = llvm.extractvalue(T.i32(), rst_i32x4, [0])
    res1 = llvm.extractvalue(T.i32(), rst_i32x4, [1])
    res2 = llvm.extractvalue(T.i32(), rst_i32x4, [2])
    res3 = llvm.extractvalue(T.i32(), rst_i32x4, [3])
    # Renamed from vec_f32x4*: the lanes are Int32, not Float32 (matches
    # the vec_i32x4 naming used by the e4m3 helpers).
    vec_i32x4_type = ir.VectorType.get([4], Int32.mlir_type, loc=loc)
    vec_i32x4 = vector.from_elements(
        vec_i32x4_type, [res0, res1, res2, res3], loc=loc, ip=ip
    )
    # Reinterpret the four packed i32 results as eight f16 lanes.
    vec_f16x8_type = ir.VectorType.get([8], Float16.mlir_type, loc=loc)
    vec_f16x8 = llvm.bitcast(vec_f16x8_type, vec_i32x4, loc=loc, ip=ip)
    return vec_f16x8
Let's analyze this more closely:
@dsl_user_op
def cvt_f4e2m1_f16(src, *, loc=None, ip=None):
    """Scalar fallback: convert one float4e2m1 payload to one float16."""
    # 0 padding for upper 4 bits
    zero = arith.constant(src.type, 0, loc=loc, ip=ip)
    # Build a 2-lane vector [src, 0] so the packed x2 conversion applies.
    vec2 = vector.from_elements(
        ir.VectorType.get([2], src.type, loc=loc), [src, zero], loc=loc, ip=ip
    )
    rst_vec2 = cvt_f4e2m1x2_to_f16x2(vec2, loc=loc, ip=ip)
    # only the 1st element is valid
    rst = vector.extract(
        rst_vec2, dynamic_position=[], static_position=[0], loc=loc, ip=ip
    )
    return rst
We pack two 4-bit sequences, one of them containing the bits of the original number, the other containing zeros. I assume that is because PTX does not allow sub-byte operations. We then call the packed instruction and extract the relevant element from the two packed f16 values.
# Convert 2 float4e2m1 values to 2 float16 values
@dsl_user_op
def cvt_f4e2m1x2_to_f16x2(src_vec2, *, loc=None, ip=None):
    """Expand two packed e2m1 values into two f16 values with one PTX cvt."""
    # pack 2 float4e2m1 into 1 int8 value and fill upper bits with 0
    src_i8 = llvm.bitcast(Int8.mlir_type, src_vec2, loc=loc, ip=ip)
    # Zero-extend to i16 for the 16-bit ("h") inline-asm input operand.
    src_i16 = llvm.zext(Int16.mlir_type, src_i8, loc=loc, ip=ip)
    rst_i32 = llvm.inline_asm(
        Int32.mlir_type,
        [src_i16],
        """{\n\t
        .reg .b8 b;\n\t
        mov.b16 {b,_}, $1;\n\t
        cvt.rn.f16x2.e2m1x2 $0, b;\n\t
        }""",
        "=r,h",
    )
    # View the packed i32 result as a 2-lane f16 vector.
    vec_f16x2_type = ir.VectorType.get([2], Float16.mlir_type, loc=loc)
    vec_f16x2 = llvm.bitcast(vec_f16x2_type, rst_i32, loc=loc, ip=ip)
    return vec_f16x2
We pack the two FP4 values into one Int8, then use the zext instruction. zext is an LLVM instruction that widens a value to a larger integer type by zero extension.
Below is the assembly code that first allocates an 8 bit register and then moves the packed numbers we convert there. We then call cvt instruction which will make the data conversion from e2m1x2 to f16x2.
Finally we convert the result (which was obtained as an Int32) to the form we want to use it in, i.e. we interpret the 32 bits of the Int32 as 2 F16 values. The other intrinsic functions work similarly.
Implement conversion for FP8
Let's now implement a conversion from f8e4m3 to f16.
On the upper level we have a similar wrapper to above
@dsl_user_op
def cvt_f8e4m3_f16_intrinsic(vec_f8e4m3, length, *, loc=None, ip=None):
    """
    Convert a vector of float8e4m3 to a vector of float16.

    Same peeling strategy as the f4e2m1 intrinsic: consume the widest
    available packed slice first (8, then 4, then 2 lanes, then a single
    scalar), maximizing use of the vectorized PTX conversion helpers.

    :param vec_f8e4m3: The input vector of float8e4m3.
    :type vec_f8e4m3: 1D vector of float8e4m3
    :param length: The length of the input vector.
    :type length: int
    :return: The output 1D vector of float16 with the same length as the input vector.
    :rtype: 1D vector of float16
    """
    # Current offset into both the source and destination vectors.
    src_pos = 0
    # Reinterpret the f8e4m3 payload as raw 8-bit integers so that
    # strided slices of packed lanes can be taken below.
    vec_src_i8 = builtin.unrealized_conversion_cast(
        [ir.VectorType.get([length], Int8.mlir_type, loc=loc)],
        [vec_f8e4m3],
        loc=loc,
        ip=ip,
    )
    # Slice types for the 8-, 4- and 2-wide packed conversion paths.
    vec_i8x8_type = ir.VectorType.get([8], Int8.mlir_type, loc=loc)
    vec_i8x4_type = ir.VectorType.get([4], Int8.mlir_type, loc=loc)
    vec_i8x2_type = ir.VectorType.get([2], Int8.mlir_type, loc=loc)
    vec_dst_type = ir.VectorType.get([length], Float16.mlir_type, loc=loc)
    # Zero-initialized result; converted slices are inserted in place.
    vec_dst = llvm.mlir_zero(vec_dst_type, loc=loc, ip=ip)
    # try to use vectorized version
    if length >= 8:
        num_vec8 = length // 8
        for _ in range(num_vec8):
            vec_f8e4m3x8 = vector.extract_strided_slice(
                vec_i8x8_type, vec_src_i8, [src_pos], [8], [1], loc=loc, ip=ip
            )
            vec_f16x8 = cvt_f8e4m3x8_to_f16x8(vec_f8e4m3x8, loc=loc, ip=ip)
            vec_dst = vector.insert_strided_slice(
                vec_f16x8, vec_dst, [src_pos], [1], loc=loc, ip=ip
            )
            src_pos += 8
            length -= 8
    # Fewer than 8 elements remain now: at most one 4-wide slice, one
    # 2-wide slice, and a single scalar tail.
    if length >= 4:
        vec_f8e4m3x4 = vector.extract_strided_slice(
            vec_i8x4_type, vec_src_i8, [src_pos], [4], [1], loc=loc, ip=ip
        )
        vec_f16x4 = cvt_f8e4m3x4_to_f16x4(vec_f8e4m3x4, loc=loc, ip=ip)
        vec_dst = vector.insert_strided_slice(
            vec_f16x4, vec_dst, [src_pos], [1], loc=loc, ip=ip
        )
        src_pos += 4
        length -= 4
    if length >= 2:
        vec_f8e4m3x2 = vector.extract_strided_slice(
            vec_i8x2_type, vec_src_i8, [src_pos], [2], [1], loc=loc, ip=ip
        )
        vec_f16x2 = cvt_f8e4m3x2_to_f16x2(vec_f8e4m3x2, loc=loc, ip=ip)
        vec_dst = vector.insert_strided_slice(
            vec_f16x2, vec_dst, [src_pos], [1], loc=loc, ip=ip
        )
        src_pos += 2
        length -= 2
    if length >= 1:
        # Scalar tail: convert the single remaining element.
        val_f16 = cvt_f8e4m3_f16(
            vector.extractelement(
                vec_src_i8,
                position=arith.constant(Int32.mlir_type, src_pos),
                loc=loc,
                ip=ip,
            ),
            loc=loc,
            ip=ip,
        )
        vec_dst = vector.insertelement(
            val_f16,
            vec_dst,
            position=arith.constant(Int32.mlir_type, src_pos),
            loc=loc,
            ip=ip,
        )
    return vec_dst
Note that here we use Int8 types because we deal with 8 bit format now. Otherwise it's very similar to above.
Next we need to wrap our ASM instructions:
@dsl_user_op
def cvt_f8e4m3_f16(src, *, loc=None, ip=None):
    """Convert a single float8e4m3 value (given as its 8-bit payload) to float16."""
    # 0 padding for upper 8 bits
    # The packed x2 helper needs two lanes, so pair the value with a zero.
    zero = arith.constant(src.type, 0, loc=loc, ip=ip)
    vec2 = vector.from_elements(
        ir.VectorType.get([2], src.type, loc=loc), [src, zero], loc=loc, ip=ip
    )
    rst_vec2 = cvt_f8e4m3x2_to_f16x2(vec2, loc=loc, ip=ip)
    # only the 1st element is valid
    rst = vector.extract(
        rst_vec2, dynamic_position=[], static_position=[0], loc=loc, ip=ip
    )
    return rst
# Convert 2 float8e4m3 values to 2 float16 values
@dsl_user_op
def cvt_f8e4m3x2_to_f16x2(src_vec2, *, loc=None, ip=None):
    """Convert a 2-lane float8e4m3 vector to a 2-lane float16 vector via PTX."""
    # pack 2 float8e4m3 into 1 int16 value
    src_i16 = llvm.bitcast(Int16.mlir_type, src_vec2, loc=loc, ip=ip)
    # The two packed e4m3 bytes fill the 16-bit operand exactly, so the
    # cvt instruction can consume $1 directly (no register shuffling).
    rst_i32 = llvm.inline_asm(
        Int32.mlir_type,
        [src_i16],
        """{\n\t
        cvt.rn.f16x2.e4m3x2 $0, $1;\n\t
        }""",
        "=r,h",
    )
    # Reinterpret the packed 32-bit result as two f16 lanes.
    vec_f16x2_type = ir.VectorType.get([2], Float16.mlir_type, loc=loc)
    vec_f16x2 = llvm.bitcast(vec_f16x2_type, rst_i32, loc=loc, ip=ip)
    return vec_f16x2
# Convert 4 float8e4m3 values to 4 float16 values
@dsl_user_op
def cvt_f8e4m3x4_to_f16x4(src_vec4, *, loc=None, ip=None):
    """Convert a 4-lane float8e4m3 vector to a 4-lane float16 vector via PTX."""
    # pack 4 float8e4m3 into 1 int32 value
    src_i32 = llvm.bitcast(Int32.mlir_type, src_vec4, loc=loc, ip=ip)
    # Split the i32 into two 16-bit halves; each half carries two e4m3
    # values and expands to one i32 holding two f16 values.
    rst_i32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32()]),
        [src_i32],
        """{\n\t
        .reg .b16 h0, h1;\n\t
        mov.b32 {h0, h1}, $2;\n\t
        cvt.rn.f16x2.e4m3x2 $0, h0;\n\t
        cvt.rn.f16x2.e4m3x2 $1, h1;\n\t
        }""",
        "=r,=r,r",
    )
    res0 = llvm.extractvalue(T.i32(), rst_i32x2, [0])
    res1 = llvm.extractvalue(T.i32(), rst_i32x2, [1])
    vec_i32x2_type = ir.VectorType.get([2], Int32.mlir_type, loc=loc)
    vec_i32x2 = vector.from_elements(vec_i32x2_type, [res0, res1], loc=loc, ip=ip)
    # Reinterpret the two packed i32 results as four f16 lanes.
    vec_f16x4_type = ir.VectorType.get([4], Float16.mlir_type, loc=loc)
    vec_f16x4 = llvm.bitcast(vec_f16x4_type, vec_i32x2, loc=loc, ip=ip)
    return vec_f16x4
# Convert 8 float8e4m3 values to 8 float16 values
@dsl_user_op
def cvt_f8e4m3x8_to_f16x8(src_vec8, *, loc=None, ip=None):
    """Convert an 8-lane float8e4m3 vector to an 8-lane float16 vector.

    The 8 bytes (64 bits) are handled as two independent i32 halves
    because CuTeDSL has no Int64 datatype; each half is converted by the
    same two-instruction PTX sequence as the 4-lane variant.
    """
    # Split into two i32 values instead of using i64
    vec_i32x2_type = ir.VectorType.get([2], Int32.mlir_type, loc=loc)
    src_i32x2 = llvm.bitcast(vec_i32x2_type, src_vec8, loc=loc, ip=ip)
    src_lo = llvm.extractelement(src_i32x2, arith.constant(Int32.mlir_type, 0), loc=loc, ip=ip)
    src_hi = llvm.extractelement(src_i32x2, arith.constant(Int32.mlir_type, 1), loc=loc, ip=ip)
    # Process lower 4 bytes (4 fp8 values)
    rst_lo_i32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32()]),
        [src_lo],
        """{\n\t
        .reg .b16 h0, h1;\n\t
        mov.b32 {h0, h1}, $2;\n\t
        cvt.rn.f16x2.e4m3x2 $0, h0;\n\t
        cvt.rn.f16x2.e4m3x2 $1, h1;\n\t
        }""",
        "=r,=r,r",
    )
    # Process upper 4 bytes (4 fp8 values)
    rst_hi_i32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32()]),
        [src_hi],
        """{\n\t
        .reg .b16 h0, h1;\n\t
        mov.b32 {h0, h1}, $2;\n\t
        cvt.rn.f16x2.e4m3x2 $0, h0;\n\t
        cvt.rn.f16x2.e4m3x2 $1, h1;\n\t
        }""",
        "=r,=r,r",
    )
    # Gather the four packed i32 results, low half first, to preserve the
    # original element order.
    res0 = llvm.extractvalue(T.i32(), rst_lo_i32x2, [0])
    res1 = llvm.extractvalue(T.i32(), rst_lo_i32x2, [1])
    res2 = llvm.extractvalue(T.i32(), rst_hi_i32x2, [0])
    res3 = llvm.extractvalue(T.i32(), rst_hi_i32x2, [1])
    vec_i32x4_type = ir.VectorType.get([4], Int32.mlir_type, loc=loc)
    vec_i32x4 = vector.from_elements(
        vec_i32x4_type, [res0, res1, res2, res3], loc=loc, ip=ip
    )
    # Reinterpret the four packed i32 results as eight f16 lanes.
    vec_f16x8_type = ir.VectorType.get([8], Float16.mlir_type, loc=loc)
    vec_f16x8 = llvm.bitcast(vec_f16x8_type, vec_i32x4, loc=loc, ip=ip)
    return vec_f16x8
@dsl_user_op
def cvt_f8e4m3_f16(src, *, loc=None, ip=None):
    """Scalar fallback: convert one float8e4m3 payload to one float16."""
    # 0 padding for upper 8 bits
    zero = arith.constant(src.type, 0, loc=loc, ip=ip)
    # Build a 2-lane vector [src, 0] so the packed x2 conversion applies.
    vec2 = vector.from_elements(
        ir.VectorType.get([2], src.type, loc=loc), [src, zero], loc=loc, ip=ip
    )
    rst_vec2 = cvt_f8e4m3x2_to_f16x2(vec2, loc=loc, ip=ip)
    # only the 1st element is valid
    rst = vector.extract(
        rst_vec2, dynamic_position=[], static_position=[0], loc=loc, ip=ip
    )
    return rst
For one element to convert we pad the upper 8 bits. That is because we want to again use packed instruction as above and therefore will need two input values (one of them being set to zero here).
# Convert 2 float8e4m3 values to 2 float16 values
@dsl_user_op
def cvt_f8e4m3x2_to_f16x2(src_vec2, *, loc=None, ip=None):
    """Expand two packed e4m3 values into two f16 values with one PTX cvt."""
    # pack 2 float8e4m3 into 1 int16 value
    src_i16 = llvm.bitcast(Int16.mlir_type, src_vec2, loc=loc, ip=ip)
    # Both e4m3 bytes fill the 16-bit input exactly, so the instruction
    # reads $1 directly; no zero-extension or register moves are needed.
    rst_i32 = llvm.inline_asm(
        Int32.mlir_type,
        [src_i16],
        """{\n\t
        cvt.rn.f16x2.e4m3x2 $0, $1;\n\t
        }""",
        "=r,h",
    )
    # View the packed i32 result as a 2-lane f16 vector.
    vec_f16x2_type = ir.VectorType.get([2], Float16.mlir_type, loc=loc)
    vec_f16x2 = llvm.bitcast(vec_f16x2_type, rst_i32, loc=loc, ip=ip)
    return vec_f16x2
Here we pack two 8-bit values into one 16-bit value and then call the PTX instruction. Afterwards we again convert from Int32 by interpreting each 16-bit half as one fp16 value.
The other two wrappers can be derived similarly. One peculiarity: for 8 packed values (i.e. 64 bits) we need to interpret the input as two 32-bit sequences, because as of today CuTeDSL does not have an Int64 datatype. Everything else is straightforward and as before.
Conclusion
I don't want to give full details of exactly how I used this until the end of the competition. But by using the above conversion I achieved a solid performance boost of ~10% on the GEMV task in the GPU MODE competition. I hope this blog post demystifies numeric conversions in CuTeDSL and helps others improve their own pipelines. Feel free to contact me on LinkedIn