mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
wgpu-f16-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc9ae53a05 | ||
|
|
4669e467ef |
3 changed files with 20 additions and 13 deletions
5
examples/test_f16.py
Normal file
5
examples/test_f16.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Example script: builds a half-precision tensor, multiplies it by 2.0 and prints the result
# converted with .numpy().
from tinygrad import Tensor, dtypes
|
||||
|
||||
# the body below only runs when the file is executed as a script
if __name__ == "__main__":
|
||||
a = Tensor([1.0,2.0,3.0], dtype=dtypes.half)
|
||||
print((a*2.0).numpy())
|
||||
|
|
@ -215,7 +215,7 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
|
|||
# NOTE: this requires bf16 buffer support
|
||||
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
|
||||
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
|
||||
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32]
|
||||
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
|
||||
# for CI GPU and OSX, cl_khr_fp16 isn't supported
|
||||
# for CI LLVM, it segfaults because it can't link to the casting function
|
||||
# CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from tinygrad.helpers import strip_parens
|
|||
import math
|
||||
|
||||
# utility functions for handling packed load/store of < 4-byte data types: bool, char/uchar, short/ushort (half goes through packed_store only)
|
||||
unpack_map = {dtypes.bool: dtypes.int, dtypes.char: dtypes.int, dtypes.uchar: dtypes.uint32, dtypes.short: dtypes.int, dtypes.ushort: dtypes.uint32}
|
||||
unpack_map = {dtypes.bool: dtypes.int, dtypes.char: dtypes.int, dtypes.uchar: dtypes.uint32, dtypes.short: dtypes.int, dtypes.ushort: dtypes.uint32, dtypes.half: dtypes.uint32}
|
||||
|
||||
def sign_extend(val:UOp, sext_am:int):
|
||||
return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \
|
||||
|
|
@ -16,26 +16,26 @@ def sign_extend(val:UOp, sext_am:int):
|
|||
# Rewrite a store of a sub-4-byte value `var` as a read-modify-write of the 4-byte word containing it:
# load the word, clear the element's bit field with a mask, OR in the new value, store the word back.
#   bidx: the store's index node; src[0] and src[1] are reused as the sources of the new INDEX nodes
#   var:  the value being stored; its dtype selects the word type through unpack_map
# NOTE(review): this block comes from a rendered diff. The two consecutive `new_v = ...` lines and the
# two consecutive `return ...` lines look like removed/added pairs of one statement — confirm which is live.
def packed_store(bidx:UOp, var:UOp):
|
||||
unpacked_type = unpack_map[var.dtype]
|
||||
# bit offset of the element inside its word: (index % elements-per-word) * bits-per-element
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//var.dtype.itemsize))*UOp.const(dtypes.uint32, 8*var.dtype.itemsize)
|
||||
# value truncated to 8 or 16 bits, widened to uint32 and shifted into position
new_v = (var & (0xFF if var.dtype.itemsize == 1 else 0xFFFF)).cast(dtypes.uint32) << shift_am
|
||||
# same as above, except that a half value is passed through untouched
new_v = ((var & (0xFF if var.dtype.itemsize == 1 else 0xFFFF)).cast(dtypes.uint32) << shift_am) if var.dtype != dtypes.half else var
|
||||
# all-ones word with only the element's bit field cleared
mask = (((0xFF if var.dtype.itemsize == 1 else 0xFFFF) << shift_am) ^ 0xFFFFFFFF).cast(unpacked_type)
|
||||
# current contents of the containing word (element index divided by elements-per-word)
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), dtype=unpacked_type)
|
||||
# store (old word & mask) | new value at the divided index
return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), ((buf & mask) | new_v.cast(unpacked_type)))
|
||||
# same store, but for half the undivided element index bidx.src[1] is used as the INDEX source
return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize) if var.dtype != dtypes.half else bidx.src[1])), ((buf & mask) | new_v.cast(unpacked_type)))
|
||||
|
||||
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
|
||||
# Rewrite a load of a sub-4-byte element of type `dtype` as a load of the 4-byte word containing it,
# followed by a shift and mask that extract the element.
#   root:  the original load; its arg and its extra sources are carried over to the new load
#   bidx:  the load's first source; src[0] and src[1]//elements-per-word become the new INDEX sources
#   dtype: element type being loaded; the word itself is loaded as unpack_map[dtype]
#   var:   optional; when given, the new load's sources are (index, var, root.src[2])
# char and short results are sign-extended before the final cast; other types are cast directly.
# NOTE(review): `shift_am` is assigned twice with identical expressions. In this rendered diff they look
# like the removed and re-added copy of one moved line — confirm which is live.
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:Optional[UOp]=None):
|
||||
# index of the word that holds the element
div_idx = bidx.src[1]//(4//dtype.itemsize)
|
||||
# bit offset of the element inside its word: (index % elements-per-word) * bits-per-element
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
|
||||
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=unpack_map[dtype], arg=root.arg)
|
||||
else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=unpack_map[dtype], arg=root.arg)
|
||||
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
|
||||
# move the element down to bit 0 and keep only its 8 or 16 bits
val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
wgsl_matcher = PatternMatcher([
|
||||
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat(name="b")), name="c"),
|
||||
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 else None),
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 and l.dtype != dtypes.half else None),
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
|
||||
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 else None),
|
||||
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 and l.dtype != dtypes.half else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var")), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None),
|
||||
# TODO: why is this needed, and only for this MUL order
|
||||
(UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
|
||||
|
|
@ -43,8 +43,8 @@ wgsl_matcher = PatternMatcher([
|
|||
]) + extra_pm
|
||||
|
||||
# type_map: WGSL type name used for each dtype (the later variant below maps half to "f32").
# NOTE(review): this span comes from a rendered diff. The two `dtypes.char: ...` continuation lines and
# the two `buffer_map = ...` lines look like removed/added pairs of one statement — confirm which is live.
type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
|
||||
dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool" }
|
||||
# buffer_map: type_map with bool overridden to "i32"
buffer_map = { **type_map, dtypes.bool: "i32" }
|
||||
dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool", dtypes.half: "f32" }
|
||||
# buffer_map: type_map with bool overridden to "i32" and half overridden to "u32"
buffer_map = { **type_map, dtypes.bool: "i32", dtypes.half: "u32" }
|
||||
|
||||
class WGSLRenderer(CStyleLanguage):
|
||||
device = "WEBGPU"
|
||||
|
|
@ -67,18 +67,20 @@ class WGSLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
|
||||
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), lambda ctx,b,v,g: f"select({ctx[v]}, {ctx[b]}, {ctx[g]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx[b]),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), lambda ctx,b,v,g: f"select({ctx[v]}, {ctx.cond_unpack_f16(ctx[b], ctx[b.src[1]], b.src[0].dtype.base)}, {ctx[g]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx.cond_unpack_f16(ctx[b], ctx[b.src[1]], b.src[0].dtype.base)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
|
||||
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
|
||||
lambda ctx,buf,idx: f"{ctx[buf]}[({strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})" + ("/2]" if buf.dtype.base == dtypes.half else "]")),
|
||||
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx.cond_pack_to_f16(ctx[v.src[1].src[0]] if b.src[0].dtype.base == dtypes.half else ctx[v.src[1]], ctx[b.src[1]], b.src[0].dtype.base)});" if b.src[0].dtype.itemsize < 4 \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
# fix nan check: 'a != a -> is_nan()'
|
||||
(UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"),
|
||||
]) + base_rewrite
|
||||
|
||||
def cond_unpack_f16(self, v, idx, dt):
  # v:   rendered expression for the loaded value
  # idx: rendered expression for the element index
  # dt:  base dtype of the buffer being read
  if dt == dtypes.half:
    # reinterpret the value as u32, unpack it into two floats and select lane idx%2
    return f"unpack2x16float(bitcast<u32>({v}))[({idx})%2]"
  # every other dtype is returned unchanged
  return v
|
||||
def cond_pack_to_f16(self, v, idx, dt):
  # v:   rendered expression for the value being stored
  # idx: rendered expression for the element index
  # dt:  base dtype of the destination buffer
  if dt == dtypes.half:
    # build a two-lane vector with f32(v) in lane idx%2 and 0.0 in the other, then pack it
    return f"pack2x16float(select(vec2(f32({v}), 0.0), vec2(0.0, f32({v})), ({idx}%2) == 1))"
  # every other dtype is returned unchanged
  return v
|
||||
def render_cast(self, dt:DType, val: str) -> str:
  # a cast is rendered as a call on the type name looked up in self.type_map
  target = self.type_map[dt]
  return f"{target}({val})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str:
  # always returns the literal "var"; dt and mutable are not consulted
  return "var"
|
||||
def render_buf_dt(self, dt:DType, rw=True) -> str:
  # read-write buffers of sub-4-byte dtypes are rendered as atomic<...>
  # NOTE(review): this branch looks up buffer_map[dt] while the fallback uses buffer_map[dt.base] —
  # confirm the asymmetry is intended.
  if rw and dt.itemsize < 4:
    return f"atomic<{buffer_map[dt]}>"
  # everything else uses the plain buffer type of the base dtype
  return f"{buffer_map[dt.base]}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue