mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
instant means instant
This commit is contained in:
parent
d4b513ab05
commit
e58d9161bf
2 changed files with 12 additions and 1 deletions
|
|
@ -166,6 +166,13 @@ class TestGraphRewrite(unittest.TestCase):
|
|||
self.assertEqual(nout.src[1].op, UOps.CONST)
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
def test_instant_is_instant(self):
|
||||
a = UOp.variable('a', 0, 1)
|
||||
self.assertIs(a+0, a)
|
||||
self.assertIs(0+a, a)
|
||||
self.assertIs(a*0, a.const_like(0))
|
||||
self.assertIs(a*1, a)
|
||||
|
||||
def test_commutative_work(self):
|
||||
a = UOp.variable('a', 0, 1)
|
||||
b = UOp.variable('b', 0, 1)
|
||||
|
|
|
|||
|
|
@ -193,7 +193,11 @@ class UOp(MathTrait):
|
|||
def __new__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
if (ret:=ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
|
||||
if op is UOps.ALU and arg in COMMUTATIVE and (ret:=ucache.get((op, dtype, src[::-1], arg), None)) is not None: return ret
|
||||
ucache[key] = ret = super().__new__(cls)
|
||||
ret = super().__new__(cls)
|
||||
ret.op, ret.dtype, ret.src, ret.arg = op, dtype, src, arg
|
||||
if (nret:=instant.rewrite(ret)) is not None: ret = nret
|
||||
else: del ret.op # this is a new UOp
|
||||
ucache[key] = ret
|
||||
return ret
|
||||
|
||||
__slots__ = ["op", "dtype", "src", "arg"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue