instant means instant

This commit is contained in:
George Hotz 2024-10-24 19:01:16 +08:00
commit e58d9161bf
2 changed files with 12 additions and 1 deletions

View file

@ -166,6 +166,13 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].op, UOps.CONST)
self.assertEqual(nout.src[1].arg, 3.0)
def test_instant_is_instant(self):
a = UOp.variable('a', 0, 1)
self.assertIs(a+0, a)
self.assertIs(0+a, a)
self.assertIs(a*0, a.const_like(0))
self.assertIs(a*1, a)
def test_commutative_work(self):
a = UOp.variable('a', 0, 1)
b = UOp.variable('b', 0, 1)

View file

@ -193,7 +193,11 @@ class UOp(MathTrait):
def __new__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
if (ret:=ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
if op is UOps.ALU and arg in COMMUTATIVE and (ret:=ucache.get((op, dtype, src[::-1], arg), None)) is not None: return ret
ucache[key] = ret = super().__new__(cls)
ret = super().__new__(cls)
ret.op, ret.dtype, ret.src, ret.arg = op, dtype, src, arg
if (nret:=instant.rewrite(ret)) is not None: ret = nret
else: del ret.op # this is a new UOp
ucache[key] = ret
return ret
__slots__ = ["op", "dtype", "src", "arg"]