mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bc73e644d | ||
|
|
984a0edbc9 | ||
|
|
0b19e2dddd |
7 changed files with 17 additions and 13 deletions
|
|
@ -284,9 +284,10 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
|||
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
|
||||
# if we have a range
|
||||
if len(reduce_range) != 0:
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG),
|
||||
(red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
lst = [acc.load()] + lst # put acc as the first element
|
||||
input_ranges = tuple([x for x in inp.toposort(gate=lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in reduce_range])
|
||||
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), (identity,), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
lst = [acc.store(identity, UOp(Ops.NOOP, src=input_ranges)).load(*reduce_range)] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
|
||||
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
|
||||
(UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
|
||||
(UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
|
||||
# const
|
||||
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
|
||||
|
|
@ -60,9 +60,9 @@ base_rewrite = PatternMatcher([
|
|||
])
|
||||
|
||||
extra_pm = PatternMatcher([
|
||||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(Ops.BITCAST, name="x"),
|
||||
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op not in {Ops.NOOP, Ops.LOAD, Ops.CUSTOM} else None),
|
||||
# insert a PRECAST before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.PRECAST, x.src[0].dtype, x.src),))
|
||||
if x.src[0].op not in {Ops.PRECAST, Ops.LOAD, Ops.CUSTOM} else None),
|
||||
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
# devectorize any bools
|
||||
|
|
@ -135,6 +135,7 @@ class CStyleLanguage(Renderer):
|
|||
c: defaultdict[str, int] = defaultdict(int)
|
||||
name = "test"
|
||||
for u in uops:
|
||||
if u.op is Ops.NOOP: continue
|
||||
if u.op is Ops.SINK:
|
||||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
|
|
@ -154,7 +155,7 @@ class CStyleLanguage(Renderer):
|
|||
elif u.op is Ops.RANGE: r[u] = f"ridx{u.arg}"
|
||||
else:
|
||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
|
||||
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
r[u] = f"{prefix}{c[prefix]}"
|
||||
|
||||
|
|
|
|||
|
|
@ -160,6 +160,7 @@ class LLVMRenderer(Renderer):
|
|||
|
||||
name = "test"
|
||||
for u in uops:
|
||||
if u.op is Ops.NOOP: continue
|
||||
if u.op is Ops.SINK:
|
||||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -176,6 +176,7 @@ class PTXRenderer(Renderer):
|
|||
|
||||
name = "test"
|
||||
for u in uops:
|
||||
if u.op is Ops.NOOP: continue
|
||||
if u.op is Ops.SINK:
|
||||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class PythonProgram:
|
|||
loop_ends: dict[int, int] = {}
|
||||
while i < len(self.uops):
|
||||
uop, dtype, idp, arg = self.uops[i]
|
||||
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK}
|
||||
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP}
|
||||
if uop is Ops.DEFINE_REG: idp = [idp[0]]
|
||||
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
|
|
@ -49,7 +49,7 @@ class PythonProgram:
|
|||
loop_ends[idp[0]] = i
|
||||
i = idp[0]
|
||||
continue
|
||||
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK):
|
||||
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP):
|
||||
# in the python emulator, the warp is always in sync
|
||||
i += 1
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class FastEnum(IntEnum):
|
|||
# the order of these Ops controls the order of the toposort
|
||||
class Ops(FastEnum):
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto() # noqa: E702
|
||||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto() # noqa: E702
|
||||
|
||||
# buffer ops
|
||||
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ spec = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
|
||||
# LOAD on STORE
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
|
||||
|
||||
# LOAD takes a <bufidx, alt?, barrier?>
|
||||
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
|
||||
|
|
@ -199,7 +199,7 @@ spec = PatternMatcher([
|
|||
# NOTE: for testing, we let sinks be anything
|
||||
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
|
||||
(UPat(Ops.SINK, dtypes.void), lambda: True),
|
||||
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
|
||||
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
|
||||
|
||||
# PTX LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue