mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
r_cleanups
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ddffb382b |
2 changed files with 29 additions and 24 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -606,6 +606,8 @@ jobs:
|
|||
- name: Test CPU=1 RANGEIFY=1
|
||||
# TODO: add more passing tests here
|
||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
|
||||
- name: Test CPU=1 RANGEIFY=1 PARTIAL_CONTIG=1
|
||||
run: PARTIAL_CONTIG=1 CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
|
||||
|
||||
testdevectorize:
|
||||
name: Linux (devectorize)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ earliest_rewrites = PatternMatcher([
|
|||
# assign only to buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x"))),
|
||||
lambda x,target: x if target.base.op is not Ops.BUFFER else None),
|
||||
# contiguous/buffer/copy/assign is already contiguous
|
||||
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
|
||||
])
|
||||
|
||||
# 1. add contiguous where we have to
|
||||
|
|
@ -92,35 +94,29 @@ pm_children = PatternMatcher([
|
|||
|
||||
@dataclass
|
||||
class RangeifyContext:
|
||||
idx: int = 0
|
||||
regs: int = 0
|
||||
# block on parent until all children have been seen
|
||||
seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
|
||||
seen_child: dict[UOp, Any] = field(default_factory=dict)
|
||||
progress: int = 0
|
||||
children: dict[UOp, list[UOp]]|None = None
|
||||
|
||||
# create ranges
|
||||
range_idx: int = 0
|
||||
def new_range(self, s:sint):
|
||||
ret = UOp.range(dtypes.int, s, self.idx)
|
||||
self.idx += 1
|
||||
ret = UOp.range(dtypes.int, s, self.range_idx)
|
||||
self.range_idx += 1
|
||||
return ret
|
||||
|
||||
def collapse_to_1(shp:tuple[sint, ...], idxs:tuple[UOp, ...]) -> UOp:
|
||||
def map_reshape(idx:UOp, r:UOp):
|
||||
acc = 1
|
||||
to_sum = []
|
||||
for s,src in list(zip(shp, idxs))[::-1]:
|
||||
for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]:
|
||||
to_sum.append(acc*src)
|
||||
acc *= s
|
||||
return sum(to_sum, start=UOp.const(dtypes.int, 0))
|
||||
|
||||
def map_reshape(idx:UOp, r:UOp):
|
||||
mish = collapse_to_1(idx.shape, idx.src[1:])
|
||||
mish = sum(to_sum, start=UOp.const(dtypes.int, 0))
|
||||
ret:list[UOp] = []
|
||||
for s in r.src[0].shape[::-1]:
|
||||
if resolve(s!=1):
|
||||
# this MOD should limit any ranges outside s
|
||||
ret.append(mish % s)
|
||||
mish //= s
|
||||
else:
|
||||
ret.append(UOp.const(dtypes.int, 0))
|
||||
ret.append(mish % s) # NOTE: simplify will turn this to CONST
|
||||
mish //= s
|
||||
tret = ret[0].sink(*ret[1:]).simplify().src[::-1] if len(ret) else ()
|
||||
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
|
||||
|
||||
|
|
@ -134,6 +130,7 @@ def map_pad(idx:UOp, r:UOp):
|
|||
if resolve(s > 0): where = where & (ret[i] >= s)
|
||||
bigwhere = bigwhere & where
|
||||
# this is safe but dumb
|
||||
# TODO (S-Lykles): switch to mixed index/valid
|
||||
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
|
||||
# PAD is with 0
|
||||
return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
|
||||
|
|
@ -232,8 +229,11 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
|
|||
out_rngs = list(idx.src[1:])
|
||||
idx_ranges, end_ranges = ctx.seen_child[c]
|
||||
for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr
|
||||
if len(idx_ranges) == 0: return c.index(*out_rngs)
|
||||
return c.index(*out_rngs).bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i in idx_ranges])
|
||||
# index based on the shared ranges
|
||||
ret = c.index(*out_rngs)
|
||||
# if all ranges aren't the same between children, we have to bufferize
|
||||
if len(idx_ranges) > 0: ret = ret.bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i in idx_ranges])
|
||||
return ret
|
||||
|
||||
def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
|
||||
if len(ctx.seen_children[c]) != c.arg: raise RuntimeError("all children should have been seen by now")
|
||||
|
|
@ -241,6 +241,7 @@ def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
|
|||
|
||||
def might_end_axis(idx:UOp):
|
||||
if idx.arg is None: return None
|
||||
# TODO: write a proper cost function here
|
||||
if all(x.op not in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.BUFFERIZE} for x in idx.toposort()): return None
|
||||
if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None
|
||||
to_end_axis = []
|
||||
|
|
@ -263,7 +264,7 @@ pm_rangeify = pm_mops+PatternMatcher([
|
|||
# if we come across this, remove it. it was a CHILD unused in an INDEX
|
||||
(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),
|
||||
|
||||
# CONST (or DEFINE_VAR) can't have axes. remove srcs when we idx
|
||||
# CONST (or DEFINE_VAR) can't have axes. remove srcs when we INDEX it
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c.replace(src=())),
|
||||
|
||||
# handle arg on any op with weight. old endrange stuff
|
||||
|
|
@ -278,6 +279,7 @@ pm_rangeify = pm_mops+PatternMatcher([
|
|||
# 3.5 cleanups
|
||||
|
||||
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
|
||||
# TODO: figure out how to reenable this
|
||||
def cleanup_dead_axes(b:UOp):
|
||||
parents = b.src[0].toposort()
|
||||
new_rng = []
|
||||
|
|
@ -361,9 +363,9 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
|
|||
|
||||
def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
|
||||
buf = assign.as_buf()
|
||||
assert buf not in ctx.map
|
||||
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
|
||||
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
|
||||
assert buf not in ctx.map
|
||||
ctx.map[buf] = assign
|
||||
return buf
|
||||
|
||||
|
|
@ -381,8 +383,9 @@ to_define_global = PatternMatcher([
|
|||
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
|
||||
lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store)),
|
||||
|
||||
# HACK
|
||||
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
# HACK in case any CONSTs were replaced
|
||||
# this is only needed if you are using symbolic
|
||||
#(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
])
|
||||
|
||||
def split_store(x:UOp):
|
||||
|
|
@ -396,7 +399,7 @@ def split_store(x:UOp):
|
|||
|
||||
# NOTE: the hack for COPY is here
|
||||
ret = ret.sink(arg=KernelInfo(name=name)) if ret.src[1].op is not Ops.COPY else ret.src[1]
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(ctx.map.values())+tuple(ctx.vars.keys()), arg=Kernel(ret, ()))
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(ctx.map.values())+tuple(ctx.vars.keys()), arg=Kernel(ret,()))
|
||||
return x.as_buf().assign(kernel)
|
||||
|
||||
split_kernels = PatternMatcher([
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue