mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
392a21b82b |
1 changed files with 3 additions and 3 deletions
|
|
@ -121,7 +121,7 @@ def map_reshape(idx:UOp, r:UOp):
|
||||||
mish //= s
|
mish //= s
|
||||||
else:
|
else:
|
||||||
ret.append(UOp.const(dtypes.int, 0))
|
ret.append(UOp.const(dtypes.int, 0))
|
||||||
tret = ret[0].sink(*ret[1:]).simplify(tracked=True).src[::-1] if len(ret) else ()
|
tret = ret[0].sink(*ret[1:]).simplify().src[::-1] if len(ret) else ()
|
||||||
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
|
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
|
||||||
|
|
||||||
def map_pad(idx:UOp, r:UOp):
|
def map_pad(idx:UOp, r:UOp):
|
||||||
|
|
@ -184,7 +184,7 @@ def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
|
||||||
passthrough_idx.append(idx.src[1+i])
|
passthrough_idx.append(idx.src[1+i])
|
||||||
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
|
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
|
||||||
new_ranges.append(ranges[-1])
|
new_ranges.append(ranges[-1])
|
||||||
ret = x.src[0].index(*ranges).bufferize(*new_ranges, arg=x.device)
|
ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=x.device)
|
||||||
return ret.index(*passthrough_idx)
|
return ret.index(*passthrough_idx)
|
||||||
|
|
||||||
def map_contiguous(ctx:RangeifyContext, x:UOp):
|
def map_contiguous(ctx:RangeifyContext, x:UOp):
|
||||||
|
|
@ -192,7 +192,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
|
||||||
ranges = []
|
ranges = []
|
||||||
for s in x.shape:
|
for s in x.shape:
|
||||||
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
|
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
|
||||||
ret = x.src[0].index(*ranges).bufferize(*ranges, arg=x.device)
|
ret = x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
|
||||||
return ret.forced_reshape(x.shape)
|
return ret.forced_reshape(x.shape)
|
||||||
|
|
||||||
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
|
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue