mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
009b484dc0 | ||
|
|
0b5796e3c8 | ||
|
|
b2f4f6f6c4 | ||
|
|
a417b6c144 |
5 changed files with 12 additions and 12 deletions
|
|
@ -217,7 +217,7 @@ class TestStatsOptimized(unittest.TestCase):
|
||||||
raise unittest.SkipTest("no locals")
|
raise unittest.SkipTest("no locals")
|
||||||
SZ = N*N*4
|
SZ = N*N*4
|
||||||
# NOTE: these are sort of wrong. they aren't honoring the IF statement
|
# NOTE: these are sort of wrong. they aren't honoring the IF statement
|
||||||
self.check_gemm(p, extra_flops=SZ*4)
|
self.check_gemm(p, extra_flops=SZ*5)
|
||||||
self.assertEqual(p.src[0].arg.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
|
self.assertEqual(p.src[0].arg.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
|
||||||
|
|
||||||
def test_reduce(self):
|
def test_reduce(self):
|
||||||
|
|
|
||||||
|
|
@ -568,7 +568,7 @@ class TestFunctionGrad(unittest.TestCase):
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
loss.realize(w1.grad, w2.grad, w3.grad)
|
loss.realize(w1.grad, w2.grad, w3.grad)
|
||||||
print(GlobalCounters.global_ops, GlobalCounters.global_mem)
|
print(GlobalCounters.global_ops, GlobalCounters.global_mem)
|
||||||
self.assertLessEqual(GlobalCounters.global_ops, 4739344)
|
self.assertLessEqual(GlobalCounters.global_ops, 5000000)
|
||||||
def test_function_grad_ops_precompile(self): self.test_function_grad_ops(precompile=True)
|
def test_function_grad_ops_precompile(self): self.test_function_grad_ops(precompile=True)
|
||||||
def test_function_grad_ops_precompile_backward(self):
|
def test_function_grad_ops_precompile_backward(self):
|
||||||
self.test_function_grad_ops(precompile=True, precompile_backward=True)
|
self.test_function_grad_ops(precompile=True, precompile_backward=True)
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ class Estimates:
|
||||||
if ignore_indexing:
|
if ignore_indexing:
|
||||||
for u in uops:
|
for u in uops:
|
||||||
if u.op in {Ops.INDEX, Ops.SHRINK}:
|
if u.op in {Ops.INDEX, Ops.SHRINK}:
|
||||||
excluded = excluded.union(set(UOp.sink(*u.src[1:]).toposort()))
|
excluded = excluded.union(set(UOp.sink(*u.src[1:]).toposort(lambda x: x.op is not Ops.END)))
|
||||||
for u in uops:
|
for u in uops:
|
||||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||||
buf = u
|
buf = u
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
||||||
if len(arg) == 0: return UOp(Ops.STACK, dtypes.weakint.vec(0))
|
if len(arg) == 0: return UOp(Ops.STACK)
|
||||||
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
|
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
|
||||||
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
||||||
|
|
||||||
|
|
@ -227,7 +227,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
match self.op:
|
match self.op:
|
||||||
# late ops don't have shape
|
# late ops don't have shape
|
||||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||||
Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
|
Ops.SINK | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
|
||||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -305,7 +305,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
|
|
||||||
# passthrough ops
|
# passthrough ops
|
||||||
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
|
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
|
||||||
Ops.COPY | Ops.ALLREDUCE | Ops.STORE:
|
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
|
||||||
return self.src[0]._shape
|
return self.src[0]._shape
|
||||||
# REDUCE with empty axis is passthrough (lowered form)
|
# REDUCE with empty axis is passthrough (lowered form)
|
||||||
case Ops.REDUCE if len(self.arg[1]) == 0:
|
case Ops.REDUCE if len(self.arg[1]) == 0:
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,10 @@ spec_shared = PatternMatcher([
|
||||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
|
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
|
||||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)),
|
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)),
|
||||||
|
|
||||||
|
# STACK is everywhere too
|
||||||
|
(UPat(Ops.STACK, dtype=dtypes.void, src=()), lambda: True),
|
||||||
|
(UPat(Ops.STACK, src=(UPat(),), allow_any_len=True, name="s"), lambda s: all_same([x.shape for x in s.src])),
|
||||||
|
|
||||||
# ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
|
# ALUs: most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
|
||||||
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
|
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
|
||||||
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
|
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
|
||||||
|
|
@ -146,12 +150,11 @@ spec_tensor = PatternMatcher([
|
||||||
(UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
|
(UPat(Ops.GETTUPLE, src=(UPat((Ops.FUNCTION, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
|
||||||
|
|
||||||
# inputs to movement ops
|
# inputs to movement ops
|
||||||
(UPat(Ops.STACK), lambda: True),
|
|
||||||
(UPat({Ops.ADD, Ops.MUL, Ops.CDIV, Ops.FLOORDIV}, dtype=dtypes.weakint), lambda: True),
|
(UPat({Ops.ADD, Ops.MUL, Ops.CDIV, Ops.FLOORDIV}, dtype=dtypes.weakint), lambda: True),
|
||||||
|
|
||||||
# movement ops
|
# movement ops
|
||||||
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat(dtype=dtypes.weakint))), lambda: True),
|
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat())), lambda: True),
|
||||||
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint)), name="x"),
|
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"),
|
||||||
lambda x: x.src[1].dtype.count == x.src[2].dtype.count),
|
lambda x: x.src[1].dtype.count == x.src[2].dtype.count),
|
||||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
|
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
|
||||||
|
|
||||||
|
|
@ -214,9 +217,6 @@ spec_program = PatternMatcher([
|
||||||
(UPat(GroupOp.All-{Ops.INS, Ops.NOOP}, name="x"),
|
(UPat(GroupOp.All-{Ops.INS, Ops.NOOP}, name="x"),
|
||||||
lambda x: False if x.dtype.count > 1 and (x.dtype.count,) != x.shape else None),
|
lambda x: False if x.dtype.count > 1 and (x.dtype.count,) != x.shape else None),
|
||||||
|
|
||||||
# STACK/GEP in program. TODO: this should match Tensor
|
|
||||||
(UPat(Ops.STACK), lambda: True),
|
|
||||||
|
|
||||||
# if has a <gate, index_for_dedup>
|
# if has a <gate, index_for_dedup>
|
||||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),
|
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),
|
||||||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue