mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
rangeify: buf limit (#12336)
* limit bufs * g * fix buffer limit * um? * fix * only these? * typo * f * cleaner
This commit is contained in:
parent
a83f219253
commit
2c397eb2a2
2 changed files with 30 additions and 10 deletions
|
|
@ -146,7 +146,6 @@ class TestSchedule(unittest.TestCase):
|
|||
np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
|
||||
@unittest.skipIf(RANGEIFY, "rangeify doesn't implement input buffer limiting")
|
||||
def test_add_chain_buffers(self):
|
||||
N = 31
|
||||
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
|
||||
|
|
@ -1959,7 +1958,6 @@ class TestSchedule(unittest.TestCase):
|
|||
self.assertEqual(swizzle_cnt(new_uop), 0)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
|
||||
@unittest.skipIf(RANGEIFY, "rangeify doesn't implement input buffer limiting")
|
||||
def test_limit_bufs_with_var(self):
|
||||
N = 31
|
||||
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Any, cast
|
||||
import functools, operator
|
||||
from typing import Any, cast, Iterator
|
||||
import functools, operator, itertools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
|
||||
|
|
@ -134,11 +134,8 @@ class RangeifyContext:
|
|||
progress: int = 0
|
||||
|
||||
# create ranges
|
||||
range_idx: int = 0
|
||||
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP):
|
||||
ret = UOp.range(s, self.range_idx, axistype)
|
||||
self.range_idx += 1
|
||||
return ret
|
||||
range_idx: Iterator[int] = field(default_factory=itertools.count)
|
||||
# build a range by calling UOp.range with s, the next value drawn from the self.range_idx counter, and axistype
def new_range(self, s:sint, axistype:AxisType=AxisType.LOOP): return UOp.range(s, next(self.range_idx), axistype)
|
||||
|
||||
def map_reshape(idx:UOp, r:UOp):
|
||||
acc = 1
|
||||
|
|
@ -467,6 +464,30 @@ to_bufferview = PatternMatcher([
|
|||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
|
||||
])
|
||||
|
||||
# per-kernel buffer limit by device name, read by limit_bufs; a device missing here gets no limit unless MAX_KERNEL_BUFFERS is set
DEVICE_MAX_BUFS = {"METAL": 31, "WEBGPU": 8} # TODO: get from device?
|
||||
def limit_bufs(ctx:RangeifyContext, root:UOp):
  """Rewrite `root` so it stays under the per-kernel buffer limit of its device.

  The graph below `root` is walked with a gate that records every BUFFERIZE, BUFFER and DEFINE_VAR node and returns
  False on them. If more than `MAX_BUFS - 1` such nodes are found (one slot is kept for the output buffer), each source
  of `root` whose op is in GroupOp.Elementwise is substituted onto freshly numbered LOOP ranges, bufferized over those
  ranges, and indexed with its original ranges. Other sources are kept unchanged.

  The limit comes from the MAX_KERNEL_BUFFERS env var, falling back to DEVICE_MAX_BUFS for the device; a limit of 0
  disables the rewrite.

  Args:
    ctx: rangeify context; only its `range_idx` counter is used, to number the new ranges.
    root: the op to check.

  Returns:
    `root` with its sources replaced when the limit is exceeded, otherwise None.
  """
  if (device:=root._device) is None: return None # no device, index related calculations
  # look the limit up by base device name: a plain string is stripped of any ":N" suffix too, so "METAL:1" is limited like "METAL"
  device_name = (device if isinstance(device, str) else device[0]).split(":")[0]
  if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device_name, 0))): return None
  # device handed to the new BUFFERIZE: a string device as-is, a tuple reduced to the base name of its first entry
  # NOTE(review): the tuple case drops the remaining devices -- confirm this is intended
  buf_device = device if isinstance(device, str) else device_name

  bufs: set[UOp] = set()
  def gate_input(u:UOp):
    # TODO: add cache to fix n^2
    if is_load:=(u.op in {Ops.BUFFERIZE, Ops.BUFFER, Ops.DEFINE_VAR}): bufs.add(u)
    return not is_load
  root.toposort(gate=gate_input)

  if len(bufs) <= MAX_BUFS - 1: return None # NOTE: this -1 is for the output buffer
  srcs = []
  for s in root.src:
    if s.op in GroupOp.Elementwise:
      # Insert bufferize: all AxisType.REDUCE before bufferize are AxisType.LOOP
      orig_ranges = s.ranges
      end_ranges = [x.replace(arg=(next(ctx.range_idx), AxisType.LOOP)) if x.op is Ops.RANGE else x for x in orig_ranges]
      s = s.substitute(dict(zip(orig_ranges, end_ranges))).bufferize(*end_ranges, arg=BufferizeOpts(device=buf_device)).index(*orig_ranges)
    srcs.append(s)
  return root.replace(src=tuple(srcs))
|
||||
# run limit_bufs on every op that is in GroupOp.Binary or GroupOp.Ternary
pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs)])
|
||||
|
||||
# *****************
|
||||
# 4. put in buffers for bufferize
|
||||
# TODO: should BUFFERIZE look a lot more like STORE
|
||||
|
|
@ -662,10 +683,11 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
tsink = graph_rewrite(tsink, pm_children, ctx=ChildrenContext(), bottom_up=True, name="get children")
|
||||
|
||||
# rangeify
|
||||
tsink = graph_rewrite(tsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="rangeify")
|
||||
tsink = graph_rewrite(tsink, pm_rangeify, ctx=(rangeify_ctx:=RangeifyContext()), bottom_up=True, name="rangeify")
|
||||
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
|
||||
tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding
|
||||
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
|
||||
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rangeify_ctx, name="limit buffers")
|
||||
|
||||
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
|
||||
# MSTACK stacks multiple BUFFERIZEs in one tagged tensor
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue