mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
disable_ru
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8aea58353a |
2 changed files with 3 additions and 3 deletions
|
|
@ -41,7 +41,7 @@ class TestFuse(unittest.TestCase):
|
|||
|
||||
def test_fuse_norm(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a / a.mean(axis=1), a)
|
||||
self._test_fuse(lambda a: a / a.mean(axis=1), a, atol=1e-6)
|
||||
|
||||
def test_fuse_argmax(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
|
|
|
|||
|
|
@ -91,8 +91,8 @@ kernelize_sym = symbolic_simple+PatternMatcher([
|
|||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
|
||||
if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
|
||||
# put UnaryOps before EXPANDs, if it can fuse with the input
|
||||
(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
|
||||
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
|
||||
#(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
|
||||
# lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
|
||||
])
|
||||
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue