rangeify: buf limit (#12336)

* limit bufs

* g

* fix buffer limit

* um?

* fix

* only these?

* typo

* f

* cleaner
This commit is contained in:
nimlgen 2025-09-30 14:59:47 +03:00 committed by GitHub
commit 2c397eb2a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 30 additions and 10 deletions

View file

@ -146,7 +146,6 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
@unittest.skipIf(RANGEIFY, "rangeify doesn't implement input buffer limiting")
def test_add_chain_buffers(self):
N = 31
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
@ -1959,7 +1958,6 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(swizzle_cnt(new_uop), 0)
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
@unittest.skipIf(RANGEIFY, "rangeify doesn't implement input buffer limiting")
def test_limit_bufs_with_var(self):
N = 31
with Context(TRACK_MATCH_STATS=0, DEBUG=0):

View file

@ -1,5 +1,5 @@
from typing import Any, cast
import functools, operator
from typing import Any, cast, Iterator
import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
@ -134,11 +134,8 @@ class RangeifyContext:
progress: int = 0
# create ranges
range_idx: int = 0
# Variant of new_range that works with the integer `range_idx` field: the current counter value
# is handed to UOp.range as the range index, then the counter is advanced by one.
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
ret = UOp.range(s, self.range_idx, axistype)
# bump the counter so the next call passes a different index
self.range_idx += 1
return ret
range_idx: Iterator[int] = field(default_factory=itertools.count)
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
  """Build a range of size `s` via UOp.range, numbered with the next value drawn from `self.range_idx`."""
  next_id = next(self.range_idx)
  return UOp.range(s, next_id, axistype)
def map_reshape(idx:UOp, r:UOp):
acc = 1
@ -467,6 +464,30 @@ to_bufferview = PatternMatcher([
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
])
# Per-device buffer cap, keyed by device name; limit_bufs uses it as the default for the MAX_KERNEL_BUFFERS lookup.
# Devices that are not listed get 0 there, which limit_bufs treats as "no limit".
DEVICE_MAX_BUFS = {"METAL": 31, "WEBGPU": 8} # TODO: get from device?
def limit_bufs(ctx:RangeifyContext, root:UOp):
  """Bufferize `root`'s elementwise sources when `root` depends on more input buffers than the device allows.

  The limit is read from the MAX_KERNEL_BUFFERS environment variable, defaulting to the DEVICE_MAX_BUFS entry
  for the device's base name. A limit of 0 (the default for unlisted devices) disables the check.

  Args:
    ctx: rangeify context; its `range_idx` counter supplies the indices of the newly created ranges.
    root: the UOp matched by pm_limit_bufs.

  Returns:
    `root` with its sources rewritten when the limit is exceeded, otherwise None (no rewrite).
  """
  if (device:=root._device) is None: return None # no device, index related calculations
  device = device if isinstance(device, str) else device[0].split(":")[0]
  # Look the limit up by base device name. A plain string device can carry a ":N" suffix too (e.g. "METAL:1");
  # without stripping it the DEVICE_MAX_BUFS lookup misses and the limit is silently disabled.
  # NOTE(review): `device` itself is left as-is because it is also passed to BufferizeOpts below.
  if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device.split(":")[0], 0))): return None

  # collect the BUFFERIZE / BUFFER / DEFINE_VAR nodes reachable from root; the walk does not descend past them
  bufs: set[UOp] = set()
  def gate_input(u:UOp):
    # TODO: add cache to fix n^2
    if is_load:=(u.op in {Ops.BUFFERIZE, Ops.BUFFER, Ops.DEFINE_VAR}): bufs.add(u)
    return not is_load
  root.toposort(gate=gate_input)

  if len(bufs) > MAX_BUFS - 1: # NOTE: this -1 is for the output buffer
    srcs = []
    for s in root.src:
      if s.op in GroupOp.Elementwise:
        # Insert bufferize: every RANGE of the source is replaced by a fresh AxisType.LOOP range (whatever its
        # original axis type), the source is bufferized over those, then indexed back with the original ranges
        orig_ranges, end_ranges = s.ranges, [x.replace(arg=(next(ctx.range_idx), AxisType.LOOP)) if x.op is Ops.RANGE else x for x in s.ranges]
        s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=device)).index(*orig_ranges)
      srcs.append(s)
    return root.replace(src=tuple(srcs))
# run limit_bufs on every UOp whose op is in GroupOp.Binary or GroupOp.Ternary (bound to the name "root")
pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs)])
# *****************
# 4. put in buffers for bufferize
# TODO: should BUFFERIZE look a lot more like STORE
@ -662,10 +683,11 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tsink = graph_rewrite(tsink, pm_children, ctx=ChildrenContext(), bottom_up=True, name="get children")
# rangeify
tsink = graph_rewrite(tsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="rangeify")
tsink = graph_rewrite(tsink, pm_rangeify, ctx=(rangeify_ctx:=RangeifyContext()), bottom_up=True, name="rangeify")
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rangeify_ctx, name="limit buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
# MSTACK stacks multiple BUFFERIZEs in one tagged tensor