This commit is contained in:
George Hotz 2025-07-31 19:57:15 -07:00
commit 58e79f3fc8

View file

@ -419,10 +419,15 @@ remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(
def contiguous_create_ranges(ctx:list[int], x:UOp):
if len(x.src) != 1: return None
ranges = [UOp.range(dtypes.int, s, ctx[0]+i) if resolve(s!=1) else UOp.const(dtypes.int, 0) for i,s in enumerate(x.shape)]
ctx[0] += len(ranges)
ranges = []
for s in x.shape:
if resolve(s!=1):
ranges.append(UOp.range(dtypes.int, s, ctx[0]))
ctx[0] += 1
else:
ranges.append(UOp.const(dtypes.int, 0))
mm = UOp(Ops.MAP, dtype=x.src[0].dtype, src=(x.src[0],)+tuple(ranges))
buf = UOp.new_buffer(x.device, prod(x.shape), x.dtype)
buf = UOp.new_buffer(x.device, prod(x.shape), x.dtype).reshape(x.shape)
mm2 = UOp(Ops.MAP, dtype=x.src[0].dtype, src=(buf,)+tuple(ranges))
return UOp(Ops.STORE, src=(mm2, mm)+tuple(ranges))
#return x.replace(src=(mm,)+tuple(ranges))
@ -516,7 +521,8 @@ index_pushing = PatternMatcher([
fix_buffers = PatternMatcher([
(UPat(Ops.BUFFER, name="x"), lambda x: UOp(Ops.DEFINE_GLOBAL, dtype=x.dtype.ptr(x.arg), arg=x.src[0].arg)),
(UPat(Ops.MAP, name="x"), lambda x: x.replace(op=Ops.INDEX, dtype=x.src[0].dtype)),
(UPat(Ops.MAP, name="x"), lambda x: x.replace(op=Ops.INDEX, dtype=x.src[0].dtype).load()),
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), name="x", allow_any_len=True), lambda x: x.replace(src=(x.src[0].src[0],)+x.src[1:])),
(UPat((Ops.RESHAPE, Ops.SHRINK, Ops.PERMUTE), name="x"), lambda x: x.src[0]),
(UPat(Ops.EXPAND, name="x"), lambda x: x.replace(arg=None, op=Ops.NOOP)),
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: x.replace(op=Ops.REDUCE, arg=x.arg[0])),