Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
278ca0388b scan with assign 2025-11-18 11:02:22 -08:00
3 changed files with 27 additions and 0 deletions

View file

@@ -87,6 +87,22 @@ class TestOuterScan(unittest.TestCase):
# TODO: testing allclose
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
def test_scan_with_assign(self):
  # Scan over an OUTER range of 3 steps where every step's result is written back
  # into a single (3, 1, 10) buffer through assign, then compared with the reference.
  init_vec, mats, expected = self._test_scan()
  step = UOp.range(3, -1, AxisType.OUTER)
  scan_buf = Tensor.empty(3, 1, 10)
  # step 0 takes the initial vector; later steps read slot step-1 of the buffer
  # (the index is clamped with maximum(0) so it is never negative)
  is_first = step.eq(0)
  prev_slot = scan_buf[(step-1).maximum(0)]
  carry = Tensor(is_first.where(init_vec.uop, prev_slot.uop))
  mats_at_step = Tensor(mats.uop.reduce_backward(step, arg=Ops.ADD))[step]
  stepped = carry @ mats_at_step
  # pad with `step` entries before and `2-step` after, so the first axis is always 3 long
  placed = stepped.reshape(1, 1, 10).pad(((step, (3-step)-1), None, None))
  summed = Tensor(placed.uop.reduce(step, arg=Ops.ADD))
  result = scan_buf.assign(summed)
  result.realize()
  np.testing.assert_allclose(expected.numpy(), result.numpy())
class TestOuterworld(unittest.TestCase):
def test_range_plus_1(self):
t = Tensor.arange(100).reshape(10,10).realize()

View file

@@ -18,6 +18,11 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
  """Update the realize set `ctx` for the assign UOp `a`.

  `ctx` is used as a set: a UOp being present as a key marks it, the value is always None.

  Args:
    ctx: the set of marked UOps, mutated in place.
    a: the assign UOp; `a.src[1]` is its second source (presumably the value being stored -- confirm).
  """
  rhs = a.src[1]
  if rhs.op not in ALWAYS_CONTIGUOUS: ctx[rhs] = None
  # if it's an outer reduce, we *don't* realize it
  if rhs.op is Ops.REDUCE and any(tr.arg[-1] == AxisType.OUTER for tr in rhs.src[1:]):
    # pop with a default instead of `del`: the key is only present if it was marked just above
    # or by an earlier rule, and a missing key must not raise KeyError here
    ctx.pop(rhs, None)
  # if it's a kernel, we don't realize it
  if rhs.op is not Ops.KERNEL: ctx[a] = None

View file

@@ -540,6 +540,12 @@ replace_contiguous = PatternMatcher([
pm_fix_vmap = PatternMatcher([
# x>=y and x<(y+1) means x==y (this can go in symbolic)
((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)),
# remove the reduce if it's compare reduce w assign (keep the outer range)
(UPat(Ops.BUFFERIZE, name="buf", src=(
UPat(Ops.ASSIGN, name="assign", allow_any_len=True, src=(UPat.var("assign_buf"),
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),)),), allow_any_len=True),
lambda r1,r2,val,buf,assign_buf,assign: buf.replace(src=(UOp(Ops.ASSIGN, val.dtype,
src=(assign_buf,val)+assign.src[2:]),)+buf.src[1:]).substitute({r1:r2}) if r1 in buf.src[1:] and r2.arg[-1] == AxisType.OUTER else None),
# remove the reduce if it's compare reduce (keep the outer range)
(UPat(Ops.BUFFERIZE, name="buf", src=(
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True),