mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
scan_assig
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
278ca0388b |
3 changed files with 27 additions and 0 deletions
|
|
@ -87,6 +87,22 @@ class TestOuterScan(unittest.TestCase):
|
|||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||
|
||||
def test_scan_with_assign(self):
|
||||
vec, mats, ref = self._test_scan()
|
||||
|
||||
a = UOp.range(3, -1, AxisType.OUTER)
|
||||
buf_out = Tensor.empty(3, 1, 10)
|
||||
phi = Tensor(a.eq(0).where(vec.uop, buf_out[(a-1).maximum(0)].uop))
|
||||
out = phi @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a]
|
||||
out = out.reshape(1, 1, 10).pad(((a,(3-a)-1), None, None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
out = buf_out.assign(out)
|
||||
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
np.testing.assert_allclose(ref.numpy(), out.numpy())
|
||||
|
||||
class TestOuterworld(unittest.TestCase):
|
||||
def test_range_plus_1(self):
|
||||
t = Tensor.arange(100).reshape(10,10).realize()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
|
||||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||
|
||||
# if it's an outer reduce, we *don't* realize it
|
||||
if a.src[1].op is Ops.REDUCE and any(tr.arg[-1] == AxisType.OUTER for tr in a.src[1].src[1:]):
|
||||
del ctx[a.src[1]]
|
||||
|
||||
# if it's a kernel, we don't realize it
|
||||
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -540,6 +540,12 @@ replace_contiguous = PatternMatcher([
|
|||
pm_fix_vmap = PatternMatcher([
|
||||
# x>=y and x<(y+1) means x==y (this can go in symbolic)
|
||||
((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)),
|
||||
# remove the reduce if it's compare reduce w assign (keep the outer range)
|
||||
(UPat(Ops.BUFFERIZE, name="buf", src=(
|
||||
UPat(Ops.ASSIGN, name="assign", allow_any_len=True, src=(UPat.var("assign_buf"),
|
||||
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),)),), allow_any_len=True),
|
||||
lambda r1,r2,val,buf,assign_buf,assign: buf.replace(src=(UOp(Ops.ASSIGN, val.dtype,
|
||||
src=(assign_buf,val)+assign.src[2:]),)+buf.src[1:]).substitute({r1:r2}) if r1 in buf.src[1:] and r2.arg[-1] == AxisType.OUTER else None),
|
||||
# remove the reduce if it's compare reduce (keep the outer range)
|
||||
(UPat(Ops.BUFFERIZE, name="buf", src=(
|
||||
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue