Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
8aea58353a start disabling rules there's no tests for 2025-08-19 23:36:49 -07:00
2 changed files with 3 additions and 3 deletions

View file

@ -41,7 +41,7 @@ class TestFuse(unittest.TestCase):
def test_fuse_norm(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a / a.mean(axis=1), a)
self._test_fuse(lambda a: a / a.mean(axis=1), a, atol=1e-6)
def test_fuse_argmax(self):
a = Tensor.rand(50,50).realize()

View file

@ -91,8 +91,8 @@ kernelize_sym = symbolic_simple+PatternMatcher([
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
# put UnaryOps before EXPANDs, if it can fuse with the input
(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
#(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
# lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
])
# support for using a contiguous permuted view instead of the parent view if one exists