mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
scan_assig
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
278ca0388b |
3 changed files with 27 additions and 0 deletions
|
|
@ -87,6 +87,22 @@ class TestOuterScan(unittest.TestCase):
|
||||||
# TODO: testing allclose
|
# TODO: testing allclose
|
||||||
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||||
|
|
||||||
|
def test_scan_with_assign(self):
|
||||||
|
vec, mats, ref = self._test_scan()
|
||||||
|
|
||||||
|
a = UOp.range(3, -1, AxisType.OUTER)
|
||||||
|
buf_out = Tensor.empty(3, 1, 10)
|
||||||
|
phi = Tensor(a.eq(0).where(vec.uop, buf_out[(a-1).maximum(0)].uop))
|
||||||
|
out = phi @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a]
|
||||||
|
out = out.reshape(1, 1, 10).pad(((a,(3-a)-1), None, None))
|
||||||
|
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||||
|
out = buf_out.assign(out)
|
||||||
|
|
||||||
|
out.realize()
|
||||||
|
|
||||||
|
# TODO: testing allclose
|
||||||
|
np.testing.assert_allclose(ref.numpy(), out.numpy())
|
||||||
|
|
||||||
class TestOuterworld(unittest.TestCase):
|
class TestOuterworld(unittest.TestCase):
|
||||||
def test_range_plus_1(self):
|
def test_range_plus_1(self):
|
||||||
t = Tensor.arange(100).reshape(10,10).realize()
|
t = Tensor.arange(100).reshape(10,10).realize()
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,11 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
||||||
|
|
||||||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
||||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||||
|
|
||||||
|
# if it's an outer reduce, we *don't* realize it
|
||||||
|
if a.src[1].op is Ops.REDUCE and any(tr.arg[-1] == AxisType.OUTER for tr in a.src[1].src[1:]):
|
||||||
|
del ctx[a.src[1]]
|
||||||
|
|
||||||
# if it's a kernel, we don't realize it
|
# if it's a kernel, we don't realize it
|
||||||
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -540,6 +540,12 @@ replace_contiguous = PatternMatcher([
|
||||||
pm_fix_vmap = PatternMatcher([
|
pm_fix_vmap = PatternMatcher([
|
||||||
# x>=y and x<(y+1) means x==y (this can go in symbolic)
|
# x>=y and x<(y+1) means x==y (this can go in symbolic)
|
||||||
((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)),
|
((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)),
|
||||||
|
# remove the reduce if it's a compare reduce with assign (keep the outer range)
|
||||||
|
(UPat(Ops.BUFFERIZE, name="buf", src=(
|
||||||
|
UPat(Ops.ASSIGN, name="assign", allow_any_len=True, src=(UPat.var("assign_buf"),
|
||||||
|
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),)),), allow_any_len=True),
|
||||||
|
lambda r1,r2,val,buf,assign_buf,assign: buf.replace(src=(UOp(Ops.ASSIGN, val.dtype,
|
||||||
|
src=(assign_buf,val)+assign.src[2:]),)+buf.src[1:]).substitute({r1:r2}) if r1 in buf.src[1:] and r2.arg[-1] == AxisType.OUTER else None),
|
||||||
# remove the reduce if it's compare reduce (keep the outer range)
|
# remove the reduce if it's compare reduce (keep the outer range)
|
||||||
(UPat(Ops.BUFFERIZE, name="buf", src=(
|
(UPat(Ops.BUFFERIZE, name="buf", src=(
|
||||||
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True),
|
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue