Refactor hip_bfloat16 cast into uop (#7143)

* refactor hip_bfloat16 cast into uops

* hotfix: linter issue

* hotfix: comment decorator in test

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
ignaciosica 2024-10-21 04:17:14 -03:00 committed by GitHub
commit 87a1e76745
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 15 deletions

View file

@ -54,6 +54,8 @@ class MathTrait:
def __rsub__(self, x):
    """Reflected subtraction (`x - self`), built as an ADD of `self.ufix(x)` and `-self`."""
    # The left-hand operand is passed through ufix so it can carry the alu call.
    minuend = self.ufix(x)
    return minuend.alu(BinaryOps.ADD, -self)
def __mul__(self, x):
    """Multiplication (`self * x`) as a MUL alu op; `x` is passed through `self.ufix` first."""
    factor = self.ufix(x)
    return self.alu(BinaryOps.MUL, factor)
def __rmul__(self, x):
    """Reflected multiplication (`x * self`): `self.ufix(x)` is the left operand of the MUL alu op."""
    left = self.ufix(x)
    return left.alu(BinaryOps.MUL, self)
def __lshift__(self, x):
    """Left shift (`self << x`) as a SHL alu op; `x` is passed through `self.ufix` first."""
    shift_amount = self.ufix(x)
    return self.alu(BinaryOps.SHL, shift_amount)
def __rshift__(self, x):
    """Right shift (`self >> x`) as a SHR alu op; `x` is passed through `self.ufix` first."""
    shift_amount = self.ufix(x)
    return self.alu(BinaryOps.SHR, shift_amount)
def __floordiv__(self, x):
    """`self // x`, mapped onto the IDIV alu op; `x` is passed through `self.ufix` first."""
    divisor = self.ufix(x)
    return self.alu(BinaryOps.IDIV, divisor)
def __rfloordiv__(self, x):
    """Reflected `x // self`: `self.ufix(x)` is the dividend of the IDIV alu op."""
    dividend = self.ufix(x)
    return dividend.alu(BinaryOps.IDIV, self)
def __truediv__(self, x):
    """True division (`self / x`), expressed as `self` multiplied by the RECIP of `self.ufix(x)`."""
    # Division is built from a reciprocal followed by a multiply.
    reciprocal = self.ufix(x).alu(UnaryOps.RECIP)
    return self.alu(BinaryOps.MUL, reciprocal)

View file

@ -363,6 +363,15 @@ code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
def cast_float_bf16(x: UOp) -> UOp:
    """Build the uop graph that converts a float value to bfloat16.

    The input is reinterpreted as a 32-bit unsigned integer, its upper 16 bits
    are selected (after rounding or NaN fix-up), and those 16 bits are
    reinterpreted as bfloat16.

    Args:
        x: the value to convert; it is bitcast to ``dtypes.uint`` first.

    Returns:
        A UOp of dtype ``dtypes.bfloat16``.
    """
    bits = x.bitcast(dtypes.uint)

    # Selector: exponent-field bits (0x7f800000) of the *negated* bit pattern.
    # NOTE(review): assuming unary minus wraps modulo 2**32 on a uint UOp, this
    # is zero exactly for +/-0 and for NaNs; +/-inf take the rounding branch,
    # which leaves their upper 16 bits unchanged. Confirm the wraparound.
    selector = -bits & 0x7f800000
    low_half = bits & 0xffff

    # Rounding branch: add 0x7fff plus the lowest kept bit (bit 16), which
    # rounds the upper 16 bits to nearest with ties going to even.
    kept_lsb = (bits >> 16) & 1
    rounded = bits + kept_lsb + 0x7fff

    # Fallback branch: if any discarded low bit is set, force bit 16 on,
    # presumably so a NaN whose payload lives only in the low bits does not
    # truncate to infinity; otherwise the bits pass through untouched.
    patched = low_half.where(bits | 0x10000, bits)

    chosen = selector.where(rounded, patched)
    return (chosen >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
class AMDRenderer(CStyleLanguage):
device = "AMD"
shared_max = 65536
@ -388,12 +397,20 @@ class AMDRenderer(CStyleLanguage):
float4 = "make_float4"
type_map = {dtypes.bfloat16: "hip_bfloat16"}
extra_matcher = PatternMatcher([
# cast bfloat16 alus to float
(UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(UOps.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
(UPat(UOps.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg))]) + extra_pm
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
# add float intermediate casting for bfloat16
(UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
# bfloat16 casting
(UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_bf16)]) + extra_pm
def render_vector_prefix(self, dtype:DType) -> str:
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
@ -404,20 +421,7 @@ class AMDRenderer(CStyleLanguage):
prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
# TODO: add BF16 vec dts
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
struct hip_bfloat16 {
unsigned short data;
inline __attribute__((device)) hip_bfloat16(float val) {
union { float fp32; unsigned int u32; } u = {val};
if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
data = (u.u32 >> 16);
}
inline __attribute__((device)) operator float() const {
unsigned int uval = data << 16;
return *reinterpret_cast<float*>(&uval);
}
};
""")
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("struct hip_bfloat16 { unsigned short data; };")
for dtype in dedup(uop.dtype for uop in uops if uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))