Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
009b484dc0 bump 2026-06-17 16:22:13 -07:00
George Hotz
0b5796e3c8 fix gemm group + END shape 2026-06-17 16:16:29 -07:00
George Hotz
b2f4f6f6c4 spec for stack 2026-06-17 16:07:50 -07:00
George Hotz
a417b6c144 STACK 0 is dtype void 2026-06-17 16:04:04 -07:00
5 changed files with 12 additions and 12 deletions

View file

@ -217,7 +217,7 @@ class TestStatsOptimized(unittest.TestCase):
raise unittest.SkipTest("no locals") raise unittest.SkipTest("no locals")
SZ = N*N*4 SZ = N*N*4
# NOTE: these are sort of wrong. they aren't honoring the IF statement # NOTE: these are sort of wrong. they aren't honoring the IF statement
self.check_gemm(p, extra_flops=SZ*4) self.check_gemm(p, extra_flops=SZ*5)
self.assertEqual(p.src[0].arg.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4) self.assertEqual(p.src[0].arg.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
def test_reduce(self): def test_reduce(self):

View file

@ -568,7 +568,7 @@ class TestFunctionGrad(unittest.TestCase):
GlobalCounters.reset() GlobalCounters.reset()
loss.realize(w1.grad, w2.grad, w3.grad) loss.realize(w1.grad, w2.grad, w3.grad)
print(GlobalCounters.global_ops, GlobalCounters.global_mem) print(GlobalCounters.global_ops, GlobalCounters.global_mem)
self.assertLessEqual(GlobalCounters.global_ops, 4739344) self.assertLessEqual(GlobalCounters.global_ops, 5000000)
def test_function_grad_ops_precompile(self): self.test_function_grad_ops(precompile=True) def test_function_grad_ops_precompile(self): self.test_function_grad_ops(precompile=True)
def test_function_grad_ops_precompile_backward(self): def test_function_grad_ops_precompile_backward(self):
self.test_function_grad_ops(precompile=True, precompile_backward=True) self.test_function_grad_ops(precompile=True, precompile_backward=True)

View file

@ -28,7 +28,7 @@ class Estimates:
if ignore_indexing: if ignore_indexing:
for u in uops: for u in uops:
if u.op in {Ops.INDEX, Ops.SHRINK}: if u.op in {Ops.INDEX, Ops.SHRINK}:
excluded = excluded.union(set(UOp.sink(*u.src[1:]).toposort())) excluded = excluded.union(set(UOp.sink(*u.src[1:]).toposort(lambda x: x.op is not Ops.END)))
for u in uops: for u in uops:
if u.op in {Ops.LOAD, Ops.STORE}: if u.op in {Ops.LOAD, Ops.STORE}:
buf = u buf = u

View file

@ -83,7 +83,7 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
return ret return ret
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp: def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.STACK, dtypes.weakint.vec(0)) if len(arg) == 0: return UOp(Ops.STACK)
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg) elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg)) else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
@ -227,7 +227,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
match self.op: match self.op:
# late ops don't have shape # late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \ Ops.SINK | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION: Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None return None
@ -305,7 +305,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# passthrough ops # passthrough ops
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \ case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
Ops.COPY | Ops.ALLREDUCE | Ops.STORE: Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
return self.src[0]._shape return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form) # REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0: case Ops.REDUCE if len(self.arg[1]) == 0:

View file

@ -55,6 +55,10 @@ spec_shared = PatternMatcher([
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))), (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)),
# STACK is everywhere too
(UPat(Ops.STACK, dtype=dtypes.void, src=()), lambda: True),
(UPat(Ops.STACK, src=(UPat(),), allow_any_len=True, name="s"), lambda s: all_same([x.shape for x in s.src])),
# ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE # ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype), (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base), (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
@ -146,12 +150,11 @@ spec_tensor = PatternMatcher([
(UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)), (UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
# inputs to movement ops # inputs to movement ops
(UPat(Ops.STACK), lambda: True),
(UPat({Ops.ADD, Ops.MUL, Ops.CDIV, Ops.FLOORDIV}, dtype=dtypes.weakint), lambda: True), (UPat({Ops.ADD, Ops.MUL, Ops.CDIV, Ops.FLOORDIV}, dtype=dtypes.weakint), lambda: True),
# movement ops # movement ops
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat(dtype=dtypes.weakint))), lambda: True), (UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat())), lambda: True),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint)), name="x"), (UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"),
lambda x: x.src[1].dtype.count == x.src[2].dtype.count), lambda x: x.src[1].dtype.count == x.src[2].dtype.count),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)), (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
@ -214,9 +217,6 @@ spec_program = PatternMatcher([
(UPat(GroupOp.All-{Ops.INS, Ops.NOOP}, name="x"), (UPat(GroupOp.All-{Ops.INS, Ops.NOOP}, name="x"),
lambda x: False if x.dtype.count > 1 and (x.dtype.count,) != x.shape else None), lambda x: False if x.dtype.count > 1 and (x.dtype.count,) != x.shape else None),
# STACK/GEP in program. TODO: this should match Tensor
(UPat(Ops.STACK), lambda: True),
# if has a <gate, index_for_dedup> # if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True), (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),