mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
d5d59a2be6
commit
5644605d92
1 changed files with 37 additions and 12 deletions
|
|
@ -2,12 +2,12 @@ from __future__ import annotations
|
|||
from typing import cast, Callable, TypeVar, Generic, Any
|
||||
import struct, functools, time, collections, importlib, itertools, weakref
|
||||
from dataclasses import replace, dataclass, field
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, flatten, pluralize
|
||||
from tinygrad.helpers import to_tuple
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, DEBUG, dedup, flatten, pluralize
|
||||
from tinygrad.helpers import to_tuple, round_up
|
||||
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
||||
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
||||
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
|
||||
|
|
@ -345,11 +345,10 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
|
|||
return cur.replace(src=tuple(waits) + cur.src)
|
||||
# single rule: an Ops.AFTER pattern with tag="deps" (bound to the name "deps") is handed to add_loads
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
|
||||
|
||||
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
|
||||
src = q.src
|
||||
if q.tag in ctx:
|
||||
devs, queue = q.arg
|
||||
src += (make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag),)
|
||||
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp|None:
  # Append a signal store to the queue `q` of a "submit" when q.tag is listed in ctx.
  # Returns None (no rewrite) for queues whose tag is not in ctx.
  if q.tag not in ctx: return None
  devs, queue = q.arg
  signal = make_signal(devs, queue=queue)
  # value written to the signal: element 0 of the signal value plus this queue's tag
  value = make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + q.tag
  # the tag is cleared on the rewritten queue, so it no longer matches `q.tag in ctx`
  tagged_off = q.replace(src=q.src + (signal.store(value),), tag=None)
  return submit.replace(src=(tagged_off,))
|
||||
# single rule: a CUSTOM_FUNCTION with arg="submit" whose only source is a LINEAR (bound as "q") is handed to add_stores
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
|
||||
|
||||
|
|
@ -451,7 +450,28 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
|
|||
])
|
||||
|
||||
# *****************
|
||||
# 5.3. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
||||
# 5.3. pack placeholder buffers
|
||||
|
||||
def pack_hcq_placeholders(call:UOp) -> UOp|None:
  # Pack same-tagged placeholder BUFFERs reachable from call.src[0] into one shared buffer per tag,
  # replacing each of them with a SLICE of that shared buffer. Returns None when nothing was packed.
  overlap_tags, append_tags = {"scratch"}, {"program", "kernargs"}
  packed_tags = overlap_tags | append_tags
  placeholders = [u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u.tag in packed_tags]

  # layout pass: "scratch" buffers all start at offset 0 and the tag size is the largest arg;
  # "program"/"kernargs" buffers are placed one after another, each start rounded up (0x1000 for "program", 128 otherwise)
  offsets:dict[UOp, int] = {}
  tag_sizes:dict[str, int] = {}
  for u in placeholders:
    used = tag_sizes.get(u.tag, 0)
    if u.tag in overlap_tags:
      tag_sizes[u.tag] = max(used, u.arg)
    else:
      offsets[u] = round_up(used, 0x1000 if u.tag == "program" else 128)
      tag_sizes[u.tag] = offsets[u] + u.arg

  # a tag used by a single placeholder is left alone, only repeated tags get a shared base buffer
  tag_counts = collections.Counter(u.tag for u in placeholders)
  last_per_tag = {u.tag:u for u in placeholders if tag_counts[u.tag] > 1}

  # the base takes src[1].arg and dtype from the last placeholder seen for the tag
  # NOTE(review): src[1].arg is presumably the device of the buffer -- confirm against UOp.new_buffer
  bases = {}
  for tag, ref in last_per_tag.items():
    bases[tag] = UOp.new_buffer(ref.src[1].arg, tag_sizes[tag], ref.dtype).rtag(tag)

  # every packed placeholder becomes a SLICE of its base at the computed offset, keeping its own arg
  subs = {}
  for u in placeholders:
    if u.tag not in bases: continue
    subs[u] = UOp(Ops.SLICE, u.dtype, (bases[u.tag], UOp.const(dtypes.weakint, offsets.get(u, 0))), u.arg)

  if not subs: return None
  return call.replace(src=(call.src[0].substitute(subs, walk=True), *call.src[1:]))
|
||||
# single rule: an Ops.CALL pattern with tag="hcq" (bound as "call") is handed to pack_hcq_placeholders
pm_pack_placeholders = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), pack_hcq_placeholders)])
|
||||
|
||||
# *****************
|
||||
# 5.4. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
||||
|
||||
def hold_call_buffers(call:UOp) -> UOp|None:
|
||||
if not (bufs:=tuple(dedup(u for u in call.src[0].toposort() if u.op is Ops.BUFFER and u not in call.src))): return None
|
||||
|
|
@ -496,13 +516,17 @@ pm_resolve_patches = PatternMatcher([
|
|||
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
||||
|
||||
# index on slice is index
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
|
||||
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
|
||||
|
||||
# getaddr
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||
(UPat(Ops.GETADDR, src=(UPat(name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
|
||||
|
||||
# folders
|
||||
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
||||
fold_const_store),
|
||||
]) + symbolic_simple
|
||||
|
||||
|
|
@ -546,10 +570,11 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
linear = graph_rewrite(linear, pm_replace_params, name="replace params")
|
||||
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_pack_placeholders, walk=True, name="pack placeholders")
|
||||
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
|
||||
|
||||
# realize starts from here
|
||||
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, walk=True, name="bufferize placeholders", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_parametrize_host_buffers, walk=True, name="parametrize host buffers")
|
||||
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue