add float4 support to LLVM (#8920)

* add float4 support to LLVM

* is_bool
This commit is contained in:
George Hotz 2025-02-06 12:15:50 +08:00 committed by GitHub
commit 3e082d4a9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 18 additions and 3 deletions

View file

@ -80,6 +80,8 @@ class dtypes:
@functools.lru_cache(None)
def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
@staticmethod
def is_bool(x: DType) -> bool:
    """Return True when ``x.scalar()`` compares equal to ``dtypes.bool``."""
    scalar_dt = x.scalar()
    return scalar_dt == dtypes.bool
@staticmethod
def from_py(x) -> DType:
if x.__class__ is float: return dtypes.default_float
if x.__class__ is int: return dtypes.default_int

View file

@ -5,6 +5,7 @@ from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
def ldt(dt:DType):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
@ -20,7 +21,7 @@ def lcast(input_type:DType, output_type:DType):
if dtypes.is_float(input_type):
if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
if dtypes.is_unsigned(input_type) or dtypes.is_bool(input_type):
if dtypes.is_float(output_type): return 'uitofp'
if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
if dtypes.is_int(input_type):
@ -49,12 +50,24 @@ llvm_rewrite = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
# GEP/VECTORIZE/CAST for float4 support
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
(UPat(Ops.VECTORIZE, src=UPat.var('y'), name="x"), lambda ctx,x,y:
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
(UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
(UPat(Ops.CAST, name="x"), lambda ctx,x:
f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
# unary/binary/ternary ops
(UPat(Ops.SQRT, name="x"), lambda ctx,x:
f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
(UPat(GroupOp.Binary, name="x"), lambda ctx,x:
f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
(UPat(Ops.WHERE, name="x"), lambda ctx,x:
f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
@ -79,7 +92,7 @@ def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
class LLVMRenderer(Renderer):
device = "LLVM"
supports_float4 = False
supports_float4 = True
has_local = False
has_shared = False
global_max = None