Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
392a21b82b don't bufferize 0s 2025-08-20 21:02:09 -07:00

View file

@@ -121,7 +121,7 @@ def map_reshape(idx:UOp, r:UOp):
mish //= s
else:
ret.append(UOp.const(dtypes.int, 0))
-tret = ret[0].sink(*ret[1:]).simplify(tracked=True).src[::-1] if len(ret) else ()
+tret = ret[0].sink(*ret[1:]).simplify().src[::-1] if len(ret) else ()
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
def map_pad(idx:UOp, r:UOp):
@@ -184,7 +184,7 @@ def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
passthrough_idx.append(idx.src[1+i])
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
new_ranges.append(ranges[-1])
-ret = x.src[0].index(*ranges).bufferize(*new_ranges, arg=x.device)
+ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=x.device)
return ret.index(*passthrough_idx)
def map_contiguous(ctx:RangeifyContext, x:UOp):
@@ -192,7 +192,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
ranges = []
for s in x.shape:
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
-ret = x.src[0].index(*ranges).bufferize(*ranges, arg=x.device)
+ret = x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
return ret.forced_reshape(x.shape)
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):