mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Refactor hip_bfloat16 cast into uop (#7143)
* refactor hip_bfloat16 cast into uops * hotfix: linter issue * hotfix: comment decorator in test --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
8074c0ec8f
commit
87a1e76745
2 changed files with 21 additions and 15 deletions
|
|
@ -54,6 +54,8 @@ class MathTrait:
|
|||
def __rsub__(self, x): return self.ufix(x).alu(BinaryOps.ADD, -self)
|
||||
def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
|
||||
def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
|
||||
def __lshift__(self, x): return self.alu(BinaryOps.SHL, self.ufix(x))
|
||||
def __rshift__(self, x): return self.alu(BinaryOps.SHR, self.ufix(x))
|
||||
def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
|
||||
def __rfloordiv__(self, x): return self.ufix(x).alu(BinaryOps.IDIV, self)
|
||||
def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
|
||||
|
|
|
|||
|
|
@ -363,6 +363,15 @@ code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half
|
|||
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
|
||||
|
||||
def cast_float_bf16(x: UOp) -> UOp:
  """Build the uop graph that converts a float value to bfloat16 by bit manipulation.

  The input is reinterpreted as a 32-bit unsigned integer, its low 16 bits are
  folded into the high 16 bits (see the comments below), and the high half is
  reinterpreted as a bfloat16.

  Args:
    x: the value to convert; it is bitcast to ``dtypes.uint`` first, so it is
      presumably a 32-bit float (the only visible caller matches on float).

  Returns:
    A UOp of dtype ``dtypes.bfloat16``.
  """
  bits = x.bitcast(dtypes.uint)

  # Selector for the rounding path, tested against the exponent field (0x7f800000).
  # It is zero when the exponent is all ones with a non-zero low part (NaN) and for
  # +/-0; for +/-inf the carry from the negation makes it non-zero, and the rounding
  # path below leaves the upper 16 bits of an infinity unchanged.
  use_rounding = -bits & 0x7f800000
  # Non-zero when any of the 16 bits that get dropped is set.
  dropped_bits = bits & 0xffff

  # Rounding path: add 0x7fff plus the lowest bit that survives the shift (bit 16),
  # so an exact tie rounds towards an even result.
  rounded = bits + ((bits >> 16) & 1) + 0x7fff
  # Other path: if dropped bits are set, force bit 16 on so the value does not
  # collapse to a zero mantissa once the low half is discarded.
  preserved = dropped_bits.where((bits | 0x10000), bits)
  bits = use_rounding.where(rounded, preserved)

  # Keep only the upper 16 bits and reinterpret them as bfloat16.
  return (bits >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
|
||||
|
||||
class AMDRenderer(CStyleLanguage):
|
||||
device = "AMD"
|
||||
shared_max = 65536
|
||||
|
|
@ -388,12 +397,20 @@ class AMDRenderer(CStyleLanguage):
|
|||
float4 = "make_float4"
|
||||
type_map = {dtypes.bfloat16: "hip_bfloat16"}
|
||||
extra_matcher = PatternMatcher([
|
||||
# cast bfloat16 alus to float
|
||||
(UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
|
||||
lambda b,x,y: UOp(UOps.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
|
||||
(UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
|
||||
lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
|
||||
(UPat(UOps.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
|
||||
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg))]) + extra_pm
|
||||
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
|
||||
# add float intermediate casting for bfloat16
|
||||
(UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
|
||||
(UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
|
||||
# bfloat16 casting
|
||||
(UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
|
||||
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
|
||||
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_bf16)]) + extra_pm
|
||||
|
||||
def render_vector_prefix(self, dtype:DType) -> str:
|
||||
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
|
||||
|
|
@ -404,20 +421,7 @@ class AMDRenderer(CStyleLanguage):
|
|||
prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
|
||||
|
||||
# TODO: add BF16 vec dts
|
||||
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
|
||||
struct hip_bfloat16 {
|
||||
unsigned short data;
|
||||
inline __attribute__((device)) hip_bfloat16(float val) {
|
||||
union { float fp32; unsigned int u32; } u = {val};
|
||||
if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
|
||||
data = (u.u32 >> 16);
|
||||
}
|
||||
inline __attribute__((device)) operator float() const {
|
||||
unsigned int uval = data << 16;
|
||||
return *reinterpret_cast<float*>(&uval);
|
||||
}
|
||||
};
|
||||
""")
|
||||
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("struct hip_bfloat16 { unsigned short data; };")
|
||||
|
||||
for dtype in dedup(uop.dtype for uop in uops if uop.dtype.count > 1): prefix.append(self.render_vector_prefix(dtype))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue