Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
5ddffb382b small cleanups to rangeify 2025-08-21 10:45:13 -07:00
2 changed files with 29 additions and 24 deletions

View file

@ -606,6 +606,8 @@ jobs:
- name: Test CPU=1 RANGEIFY=1
# TODO: add more passing tests here
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
- name: Test CPU=1 RANGEIFY=1 PARTIAL_CONTIG=1
run: PARTIAL_CONTIG=1 CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
testdevectorize:
name: Linux (devectorize)

View file

@ -33,6 +33,8 @@ earliest_rewrites = PatternMatcher([
# assign only to buffer
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x"))),
lambda x,target: x if target.base.op is not Ops.BUFFER else None),
# contiguous/buffer/copy/assign is already contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
])
# 1. add contiguous where we have to
@ -92,35 +94,29 @@ pm_children = PatternMatcher([
@dataclass
class RangeifyContext:
idx: int = 0
regs: int = 0
# block on parent until all children have been seen
seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
seen_child: dict[UOp, Any] = field(default_factory=dict)
progress: int = 0
children: dict[UOp, list[UOp]]|None = None
# create ranges
range_idx: int = 0
def new_range(self, s:sint):
ret = UOp.range(dtypes.int, s, self.idx)
self.idx += 1
ret = UOp.range(dtypes.int, s, self.range_idx)
self.range_idx += 1
return ret
def collapse_to_1(shp:tuple[sint, ...], idxs:tuple[UOp, ...]) -> UOp:
def map_reshape(idx:UOp, r:UOp):
acc = 1
to_sum = []
for s,src in list(zip(shp, idxs))[::-1]:
for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]:
to_sum.append(acc*src)
acc *= s
return sum(to_sum, start=UOp.const(dtypes.int, 0))
def map_reshape(idx:UOp, r:UOp):
mish = collapse_to_1(idx.shape, idx.src[1:])
mish = sum(to_sum, start=UOp.const(dtypes.int, 0))
ret:list[UOp] = []
for s in r.src[0].shape[::-1]:
if resolve(s!=1):
# this MOD should limit any ranges outside s
ret.append(mish % s)
mish //= s
else:
ret.append(UOp.const(dtypes.int, 0))
ret.append(mish % s) # NOTE: simplify will turn this to CONST
mish //= s
tret = ret[0].sink(*ret[1:]).simplify().src[::-1] if len(ret) else ()
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
@ -134,6 +130,7 @@ def map_pad(idx:UOp, r:UOp):
if resolve(s > 0): where = where & (ret[i] >= s)
bigwhere = bigwhere & where
# this is safe but dumb
# TODO (S-Lykles): switch to mixed index/valid
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
# PAD is with 0
return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
@ -232,8 +229,11 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
out_rngs = list(idx.src[1:])
idx_ranges, end_ranges = ctx.seen_child[c]
for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr
if len(idx_ranges) == 0: return c.index(*out_rngs)
return c.index(*out_rngs).bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i in idx_ranges])
# index based on the shared ranges
ret = c.index(*out_rngs)
# if the ranges are not all the same between children, we have to bufferize
if len(idx_ranges) > 0: ret = ret.bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i in idx_ranges])
return ret
def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
if len(ctx.seen_children[c]) != c.arg: raise RuntimeError("all children should have been seen by now")
@ -241,6 +241,7 @@ def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
def might_end_axis(idx:UOp):
if idx.arg is None: return None
# TODO: write a proper cost function here
if all(x.op not in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.BUFFERIZE} for x in idx.toposort()): return None
if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None
to_end_axis = []
@ -263,7 +264,7 @@ pm_rangeify = pm_mops+PatternMatcher([
# if we come across this, remove it. it was a CHILD unused in an INDEX
(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),
# CONST (or DEFINE_VAR) can't have axes. remove srcs when we idx
# CONST (or DEFINE_VAR) can't have axes. remove srcs when we INDEX it
(UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c.replace(src=())),
# handle arg on any op with weight. old endrange stuff
@ -278,6 +279,7 @@ pm_rangeify = pm_mops+PatternMatcher([
# 3.5 cleanups
# you don't know in the first pass if axes are going to die; this happens if there's an EXPAND to the left
# TODO: figure out how to reenable this
def cleanup_dead_axes(b:UOp):
parents = b.src[0].toposort()
new_rng = []
@ -361,9 +363,9 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
buf = assign.as_buf()
assert buf not in ctx.map
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
assert buf not in ctx.map
ctx.map[buf] = assign
return buf
@ -381,8 +383,9 @@ to_define_global = PatternMatcher([
(UPat(Ops.STORE, name="store").f(Ops.INDEX, allow_any_len=True, name="idx").f(Ops.LOAD),
lambda store,idx: idx.replace(src=(store.as_buf(),)+idx.src[1:]).load(store)),
# HACK
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# HACK in case any CONSTs were replaced
# this is only needed if you are using symbolic
#(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
])
def split_store(x:UOp):
@ -396,7 +399,7 @@ def split_store(x:UOp):
# NOTE: the hack for COPY is here
ret = ret.sink(arg=KernelInfo(name=name)) if ret.src[1].op is not Ops.COPY else ret.src[1]
kernel = UOp(Ops.KERNEL, src=tuple(ctx.map.values())+tuple(ctx.vars.keys()), arg=Kernel(ret, ()))
kernel = UOp(Ops.KERNEL, src=tuple(ctx.map.values())+tuple(ctx.vars.keys()), arg=Kernel(ret,()))
return x.as_buf().assign(kernel)
split_kernels = PatternMatcher([