mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
fix_slice_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4218cc9257 | ||
|
|
17419edc4a |
2 changed files with 9 additions and 6 deletions
|
|
@ -347,11 +347,12 @@ def late_buffer_view(t:UOp, b:UOp):
|
||||||
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
||||||
x = x.src[0]
|
x = x.src[0]
|
||||||
x = next(u for u in x.src if u.op is Ops.INDEX)
|
x = next(u for u in x.src if u.op is Ops.INDEX)
|
||||||
|
assert x.op is Ops.INDEX, "must be INDEX"
|
||||||
|
|
||||||
if len(shape) == 0: offset = x.src[1].arg
|
if len(shape) == 0: offset = x.src[1].arg
|
||||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||||
|
|
||||||
return b.replace(src=(UOp(Ops.SLICE, t.dtype, (x.base, UOp.const(dtypes.weakint, offset)), size), b.src[1]))
|
return b.replace(src=(UOp(Ops.SLICE, t.dtype, (x.src[0], UOp.const(dtypes.weakint, offset)), size),))
|
||||||
|
|
||||||
to_bufferview = PatternMatcher([
|
to_bufferview = PatternMatcher([
|
||||||
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||||
|
|
@ -413,7 +414,11 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
||||||
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
||||||
if sdtype.addrspace == AddrSpace.GLOBAL:
|
if sdtype.addrspace == AddrSpace.GLOBAL:
|
||||||
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
|
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
|
||||||
do_store = buf.index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
|
if x.src[0].op is Ops.SLICE:
|
||||||
|
# no INDEX on SLICE, this could be cleaner
|
||||||
|
do_store = buf.store(x.src[0]).end(*rngs)
|
||||||
|
else:
|
||||||
|
do_store = buf.index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
|
||||||
return buf.after(do_store)
|
return buf.after(do_store)
|
||||||
|
|
||||||
if allow_locals:
|
if allow_locals:
|
||||||
|
|
|
||||||
|
|
@ -224,12 +224,10 @@ spec_program = PatternMatcher([
|
||||||
# these are intermediate ops. everything should be deleted from here
|
# these are intermediate ops. everything should be deleted from here
|
||||||
spec_full = PatternMatcher([
|
spec_full = PatternMatcher([
|
||||||
# SLICE on BUFFER is allowed if BUFFER is
|
# SLICE on BUFFER is allowed if BUFFER is
|
||||||
(UPat(Ops.SLICE, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
(UPat(Ops.SLICE, src=(UPat(GroupOp.Movement.union({Ops.BUFFER, Ops.PARAM, Ops.STAGE, Ops.AFTER})),
|
||||||
|
UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
||||||
lambda bv: isinstance(bv.arg, int)),
|
lambda bv: isinstance(bv.arg, int)),
|
||||||
|
|
||||||
# TODO: SLICE shouldn't go on INDEX. why is this allowed? remove these both
|
|
||||||
(UPat(Ops.SLICE, src=(UPat((Ops.INDEX,)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
|
||||||
lambda bv: isinstance(bv.arg, int)),
|
|
||||||
(UPat(Ops.CALL, src=(UPat((Ops.SLICE,)),), allow_any_len=True), lambda: True),
|
(UPat(Ops.CALL, src=(UPat((Ops.SLICE,)),), allow_any_len=True), lambda: True),
|
||||||
|
|
||||||
# codegen may end ranges after gpudims has replaced RANGE with SPECIAL.
|
# codegen may end ranges after gpudims has replaced RANGE with SPECIAL.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue