Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
5bc73e644d noop continue 2025-07-24 16:20:53 -07:00
George Hotz
984a0edbc9 identity store for DEFINE_REG 2025-07-24 16:13:43 -07:00
George Hotz
0b19e2dddd identity store for DEFINE_REG 2025-07-24 16:04:58 -07:00
7 changed files with 17 additions and 13 deletions

View file

@ -284,9 +284,10 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
# if we have a range
if len(reduce_range) != 0:
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG),
(red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
lst = [acc.load()] + lst # put acc as the first element
input_ranges = tuple([x for x in inp.toposort(gate=lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in reduce_range])
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), (identity,), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
lst = [acc.store(identity, UOp(Ops.NOOP, src=input_ranges)).load(*reduce_range)] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret

View file

@ -25,7 +25,7 @@ base_rewrite = PatternMatcher([
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
# const
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
@ -60,9 +60,9 @@ base_rewrite = PatternMatcher([
])
extra_pm = PatternMatcher([
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(Ops.BITCAST, name="x"),
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op not in {Ops.NOOP, Ops.LOAD, Ops.CUSTOM} else None),
# insert a PRECAST before BITCAST to force it to be rendered. not needed on all backends?
(UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.PRECAST, x.src[0].dtype, x.src),))
if x.src[0].op not in {Ops.PRECAST, Ops.LOAD, Ops.CUSTOM} else None),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
# devectorize any bools
@ -135,6 +135,7 @@ class CStyleLanguage(Renderer):
c: defaultdict[str, int] = defaultdict(int)
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue
@ -154,7 +155,7 @@ class CStyleLanguage(Renderer):
elif u.op is Ops.RANGE: r[u] = f"ridx{u.arg}"
else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}"

View file

@ -160,6 +160,7 @@ class LLVMRenderer(Renderer):
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

View file

@ -176,6 +176,7 @@ class PTXRenderer(Renderer):
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

View file

@ -40,7 +40,7 @@ class PythonProgram:
loop_ends: dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK}
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP}
if uop is Ops.DEFINE_REG: idp = [idp[0]]
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
@ -49,7 +49,7 @@ class PythonProgram:
loop_ends[idp[0]] = i
i = idp[0]
continue
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK):
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP):
# in the python emulator, the warp is always in sync
i += 1
continue

View file

@ -9,7 +9,7 @@ class FastEnum(IntEnum):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto() # noqa: E702
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto() # noqa: E702
# buffer ops
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702

View file

@ -159,7 +159,7 @@ spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
# LOAD on STORE
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
# LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
@ -199,7 +199,7 @@ spec = PatternMatcher([
# NOTE: for testing, we let sinks be anything
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOMI, Ops.CUSTOM, Ops.PRECAST)), lambda: True),
# PTX LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),