Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
4218cc9257 fix spec 2026-05-27 17:35:46 -07:00
George Hotz
17419edc4a fix slice store to remove the index 2026-05-27 17:21:49 -07:00
2 changed files with 9 additions and 6 deletions

View file

@ -347,11 +347,12 @@ def late_buffer_view(t:UOp, b:UOp):
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise" assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
x = x.src[0] x = x.src[0]
x = next(u for u in x.src if u.op is Ops.INDEX) x = next(u for u in x.src if u.op is Ops.INDEX)
assert x.op is Ops.INDEX, "must be INDEX"
if len(shape) == 0: offset = x.src[1].arg if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0) else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
return b.replace(src=(UOp(Ops.SLICE, t.dtype, (x.base, UOp.const(dtypes.weakint, offset)), size), b.src[1])) return b.replace(src=(UOp(Ops.SLICE, t.dtype, (x.src[0], UOp.const(dtypes.weakint, offset)), size),))
to_bufferview = PatternMatcher([ to_bufferview = PatternMatcher([
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view), (UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
@ -413,7 +414,11 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
# NOTE: the DEFINE_LOCAL needs to be disambiguated here # NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL: if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size) buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
do_store = buf.index(idx, dtype=sdtype).store(x.src[0]).end(*rngs) if x.src[0].op is Ops.SLICE:
# no INDEX on SLICE, this could be cleaner
do_store = buf.store(x.src[0]).end(*rngs)
else:
do_store = buf.index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
return buf.after(do_store) return buf.after(do_store)
if allow_locals: if allow_locals:

View file

@ -224,12 +224,10 @@ spec_program = PatternMatcher([
# these are intermediate ops. everything should be deleted from here # these are intermediate ops. everything should be deleted from here
spec_full = PatternMatcher([ spec_full = PatternMatcher([
# SLICE on BUFFER is allowed if BUFFER is # SLICE on BUFFER is allowed if BUFFER is
(UPat(Ops.SLICE, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"), (UPat(Ops.SLICE, src=(UPat(GroupOp.Movement.union({Ops.BUFFER, Ops.PARAM, Ops.STAGE, Ops.AFTER})),
UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
lambda bv: isinstance(bv.arg, int)), lambda bv: isinstance(bv.arg, int)),
# TODO: SLICE shouldn't go on INDEX. why is this allowed? remove these both
(UPat(Ops.SLICE, src=(UPat((Ops.INDEX,)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
lambda bv: isinstance(bv.arg, int)),
(UPat(Ops.CALL, src=(UPat((Ops.SLICE,)),), allow_any_len=True), lambda: True), (UPat(Ops.CALL, src=(UPat((Ops.SLICE,)),), allow_any_len=True), lambda: True),
# codegen may end ranges after gpudims has replaced RANGE with SPECIAL. # codegen may end ranges after gpudims has replaced RANGE with SPECIAL.