mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
2adedf5ccb
commit
4dc51aff6e
3 changed files with 78 additions and 176 deletions
|
|
@ -1,145 +0,0 @@
|
|||
from __future__ import annotations
|
||||
import time
|
||||
from typing import cast
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.jit import GraphRunner
|
||||
from tinygrad.engine.realize import get_call_outs_ins, get_runtime
|
||||
from tinygrad.helpers import round_up, ceildiv
|
||||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, graph_rewrite
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQ2DeviceCtx, HCQ2LowerCtx, pm_prep_runtime, pm_lower_ops
|
||||
from extra.hcq2.hcq2 import pm_split_into_queues, pm_add_barriers, pm_add_signals
|
||||
from extra.hcq2.hcq2 import pm_bufferize, pm_lift_patches_to_cmdbuf, pm_resolve_patches, pm_parametrize_host_buffers
|
||||
from extra.hcq2.hcq2 import pm_add_timeline_inc, pm_callify, pm_calc_kernargs_sizes
|
||||
|
||||
# **************** insert deps ****************
|
||||
|
||||
def insert_deps(ctx:HCQ2Graph, linear:UOp) -> UOp:
  """Wrap every call in `linear` in an AFTER node that carries the call's dependencies.

  Each call is tagged with its position in `linear.src`. The buffers of call `j` are read from the third field of
  `ctx.calls[j]`; they are handed to `ctx._access_resources` ordered outputs-first, together with the index range of
  the outputs (presumably marking those entries as written -- confirm against `_access_resources`).
  The resulting AFTER node has the tagged call as its first source, the returned dependencies after it, and the same tag.
  """
  def with_deps(idx:int, call:UOp) -> UOp:
    tagged = call.replace(tag=idx)
    _, _, call_bufs, _ = ctx.calls[idx]
    outs, ins = get_call_outs_ins(tagged)
    accessed = [call_bufs[i] for i in outs + ins]
    deps = ctx._access_resources(accessed, list(range(len(outs))), tagged)
    return UOp(Ops.AFTER, tagged.dtype, (tagged, *deps), tag=tagged.tag)

  # calls are processed strictly in order: the dependency tracker is stateful
  return linear.replace(src=tuple(with_deps(idx, call) for idx, call in enumerate(linear.src)))

# applies insert_deps to every LINEAR node
pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)])
|
||||
|
||||
def _slice_to_addr(ctx, bv, addr, off):
  """SLICE(INDEX addr, CONST off) -> uint64(addr) + off * itemsize, itemsize taken from the graph input that `addr` indexes."""
  itemsize = ctx.input_uops[addr.src[1].arg].dtype.itemsize
  return addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg * itemsize)

# Rewrites graph inputs into lookups in the graph's input address table (`ctx.input_addrs_uop`).
pm_replace_params = PatternMatcher([
  # PARAM n -> input_addrs_uop[n]
  (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
  # a slice of an (already rewritten) input -> byte-offset address arithmetic
  (UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"), _slice_to_addr),
])
|
||||
|
||||
# **************** graph-only passes ****************
|
||||
|
||||
def alloc_queue_sig(ctx:HCQ2Graph, q:UOp) -> None:
  """Register a signal buffer for queue `q` on `ctx` unless one is already recorded under `q.arg`.

  Used purely for its side effect on `ctx.queue_sig_bufs` / `ctx.queue_sigs`; it always returns None, so the
  matched node is never rewritten.
  """
  if q.arg not in ctx.queue_sigs:
    dev = q.arg[0][0] # TODO: multi device
    sig_spec = BufferSpec(host=True, uncached=True, cpu_access=True)
    sig_buf = Buffer(dev, 0x100, dtypes.uint8, options=sig_spec, preallocate=True)
    ctx.queue_sig_bufs.append(sig_buf)
    ctx.queue_sigs[q.arg] = UOp.from_buffer(sig_buf, dev)
  return None

# matches LINEAR nodes whose sources are PROGRAM / COPY ops
pm_alloc_queue_sigs = PatternMatcher([(UPat(Ops.LINEAR, src=UPat({Ops.PROGRAM, Ops.COPY}), name="q"), alloc_queue_sig)])
|
||||
|
||||
def lower_queue_deps(ctx:HCQ2Graph, after:UOp) -> UOp:
  """Turn the dependencies attached to an AFTER node into explicit WAIT / STORE ops on the queue signals.

  `after.src[0]` is the wrapped node, the remaining sources are its dependencies and `after.tag` is the call index.
  Every queue inside the wrapper gets one WAIT per dependency prepended and a store of the call index to its own
  signal appended. Each WAIT has three sources: the signal registered for `dep.src[0].arg`, the value `dep.tag`, and
  the store that writes that value to that signal.
  """
  u32 = dtypes.uint32
  wrapper, call_idx = after.src[0], after.tag

  def signal_store(queue_key, value): return ctx.queue_sigs[queue_key].store(UOp.const(u32, value))

  waits = []
  for dep in after.src[1:]:
    dep_queue = dep.src[0].arg
    waits.append(UOp(Ops.WAIT, dtypes.void, (ctx.queue_sigs[dep_queue], UOp.const(u32, dep.tag), signal_store(dep_queue, dep.tag))))

  new_queues = tuple(queue.replace(src=(*waits, *queue.src, signal_store(queue.arg, call_idx))) for queue in wrapper.src)
  return wrapper.replace(src=new_queues)

# matches AFTER nodes whose sources are LINEAR nodes
pm_lower_queue_deps = PatternMatcher([(UPat(Ops.AFTER, src=UPat(Ops.LINEAR), name="after"), lower_queue_deps)])
|
||||
|
||||
def optimize_queue_deps(ctx:HCQ2Graph, queue:UOp) -> UOp|None:
  """Remove redundant WAIT ops from a queue and merge consecutive waits on the same signal.

  A WAIT is dropped when it targets the queue's own signal, or when a wait on the same signal with an equal or larger
  value has already been emitted. Within a run of waits only the one with the largest value per signal is kept; the
  kept waits are emitted right before the next non-WAIT op (or at the end of the queue).
  Returns None when the resulting source tuple equals the original one.
  """
  own_sig = ctx.queue_sigs[queue.arg]
  out:list[UOp] = []
  waited:dict[UOp, int] = {}  # signal -> value of the last wait already emitted for it
  held:dict[UOp, UOp] = {}    # signal -> strongest wait seen since the last non-WAIT op, not emitted yet

  for op in queue.src:
    if op.op is not Ops.WAIT:
      # a real op: emit the held waits first and remember what has been waited for
      for w in held.values():
        out.append(w)
        waited[w.src[0]] = w.src[1].arg
      held.clear()
      out.append(op)
      continue

    sig, target = op.src[0], op.src[1].arg
    if sig is own_sig or waited.get(sig, -1) >= target: continue
    prev = held.get(sig)
    if prev is None or prev.src[1].arg < target: held[sig] = op

  # trailing waits with no op after them are still kept
  out.extend(held.values())
  new_src = tuple(out)
  return None if new_src == queue.src else queue.replace(src=new_src)

pm_optimize_queue_deps = PatternMatcher([
  (UPat(Ops.LINEAR, src=UPat({Ops.BARRIER, Ops.WAIT, Ops.STORE, Ops.PROGRAM, Ops.COPY}), name="queue"), optimize_queue_deps),
])
|
||||
|
||||
def drop_dead_stores(ctx:HCQ2Graph, outer:UOp) -> UOp:
  """Remove from every queue in `outer` the STORE ops that no WAIT refers to.

  A store counts as referenced when it is the third source of some WAIT reachable from `outer`. Non-STORE ops are
  always kept. `ctx` is not used.
  """
  awaited = set()
  for node in outer.toposort():
    if node.op is Ops.WAIT: awaited.add(node.src[2])

  def prune(queue:UOp) -> UOp:
    return queue.replace(src=tuple(op for op in queue.src if op.op is not Ops.STORE or op in awaited))

  return outer.replace(src=tuple(prune(queue) for queue in outer.src))

# matches a LINEAR node whose sources are all LINEAR nodes
pm_drop_dead_stores = PatternMatcher([(UPat(Ops.LINEAR, src=UPat(Ops.LINEAR), name="outer"), drop_dead_stores)])
|
||||
|
||||
def add_queue_sig_resets(ctx:HCQ2Graph, x:UOp, cmdbuf:UOp) -> UOp|None:
  """Append, for every queue signal buffer on `ctx`, a store of uint64 zero at index 0 to the sources of `x`.

  Nothing is rewritten (None is returned) when there are no queue signal buffers or when the command buffer's tag is
  neither "compute" nor "copy".
  """
  if not ctx.queue_sig_bufs: return None
  if cmdbuf.tag not in ("compute", "copy"): return None

  def zero_store(sig_buf:Buffer) -> UOp:
    sig = UOp.from_buffer(sig_buf)
    first = sig.index(UOp.const(dtypes.int, 0), dtype=sig.dtype.ptr())
    # the store goes through a uint64 pointer, i.e. it writes 8 bytes of zero
    return first.cast(dtypes.uint64.ptr()).store(UOp.const(dtypes.uint64, 0))

  return x.replace(src=x.src + tuple(zero_store(sig_buf) for sig_buf in ctx.queue_sig_bufs))

# matches AFTER nodes whose first source is a BUFFER (any number of further sources)
pm_add_queue_sig_resets = PatternMatcher([(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, name="cmdbuf"),), allow_any_len=True, name="x"),
                                           add_queue_sig_resets)])
|
||||
|
||||
# **************** Graph ****************
|
||||
|
||||
class HCQ2Graph(GraphRunner):
  """GraphRunner that lowers its linear through the hcq2 rewrite passes into a single host call run through the CPU runtime.

  Graph inputs are passed by address: `__call__` writes the device address of every input into `input_addrs`, and the
  lowered graph reads its inputs from that table.
  """
  def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
    super().__init__(linear, input_uops)
    self.dev = cast(HCQ2Compiled, Device[self.device])
    self.hcq_ctx = HCQ2LowerCtx(name="hcq_graph")

    # one uint64 slot per graph input (at least one slot), refreshed on every __call__
    self.input_addrs = Buffer("CPU", max(len(input_uops), 1), dtypes.uint64, preallocate=True)
    self.input_addrs_uop = UOp.from_buffer(self.input_addrs, "CPU")

    def rewrite(pm:PatternMatcher, ctx, name:str, **kwargs) -> None:
      # run one pass over the current linear and keep the result; the passes below are order-dependent
      self.linear = graph_rewrite(self.linear, pm, ctx=ctx, name=name, **kwargs)

    hcq = self.hcq_ctx
    rewrite(pm_insert_deps, self, "hcq: insert deps", walk=True)
    rewrite(pm_replace_params, self, "hcq: replace params", walk=True)
    rewrite(pm_prep_runtime, hcq, "hcq: prepare runtime")
    rewrite(pm_lower_ops, hcq, "hcq: lower ops")

    # per-queue signal state, filled in as a side effect of pm_alloc_queue_sigs walking the lowered linear
    # (the rewrite result itself is discarded)
    self.queue_sig_bufs:list[Buffer] = []
    self.queue_sigs:dict[tuple[str, str], UOp] = {}
    graph_rewrite(self.linear, pm_alloc_queue_sigs, ctx=self, name="hcq: alloc queue sigs", walk=True)

    rewrite(pm_lower_queue_deps, self, "hcq: lower queue deps")
    rewrite(pm_split_into_queues, hcq, "hcq: split into queues")
    rewrite(pm_add_barriers, hcq, "hcq: add barriers", walk=True)
    rewrite(pm_optimize_queue_deps, self, "hcq: optimize queue deps", walk=True)
    rewrite(pm_drop_dead_stores, self, "hcq: drop dead stores")
    rewrite(pm_add_signals, hcq, "hcq: add signals", walk=True)
    rewrite(pm_add_timeline_inc, hcq, "hcq: add submit", walk=True)
    rewrite(self.dev.pm_lower, hcq, f"hcq: encode cmdbuf {self.dev.device}", walk=True)

    # collect a size per device, then give each device a CPU-accessible buffer of that size
    sizes:dict = {}
    graph_rewrite(self.linear, pm_calc_kernargs_sizes, ctx=sizes, name=None)
    for dev_name, sz in sizes.items():
      kernargs = Buffer(dev_name, sz, dtypes.uint8, options=BufferSpec(cpu_access=True), preallocate=True)
      hcq.dev_ctx[dev_name] = HCQ2DeviceCtx(dev_name, UOp.from_buffer(kernargs, dev_name), UOp.const(dtypes.uint64, kernargs._buf.va_addr))

    rewrite(pm_bufferize, hcq, "realize binaries", bottom_up=True)
    rewrite(pm_lift_patches_to_cmdbuf, hcq, "lift patches to cmdbuf", bottom_up=False)
    rewrite(pm_resolve_patches, hcq, "simplify patches", bottom_up=False)
    rewrite(pm_add_queue_sig_resets, self, "hcq: add queue sig resets", walk=True)
    rewrite(pm_parametrize_host_buffers, hcq, "parametrize host buffers", bottom_up=True)
    self.host_call = graph_rewrite(self.linear, pm_callify, ctx=hcq, name="hcq: callify")

    host_prg = self.host_call.src[0]
    self.host_rt, self.host_globals = get_runtime("CPU", host_prg), host_prg.arg.globals

  def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False) -> float|None:
    """Publish the addresses of `input_uops` and run the host call.

    For a MultiBuffer input the member buffer living on this graph's device is used. With `wait` set, the device is
    synchronized and the returned value is the time spent in that synchronize call only; otherwise None is returned.
    """
    slots = self.input_addrs.as_memoryview(force_zero_copy=True).cast('Q')
    for slot, inp in enumerate(input_uops):
      target = inp.buffer
      if isinstance(target, MultiBuffer): target = next(b for b in target.bufs if b.device == self.dev.device)
      slots[slot] = target._buf.va_addr

    host_args = [self.hcq_ctx.inputs[g].get_buf("CPU") for g in self.host_globals]
    self.host_rt(*host_args, vals=self.host_call.src[0].arg.vals(var_vals), wait=True)

    if not wait: return None
    start = time.perf_counter()
    self.dev.synchronize()
    return time.perf_counter() - start

  @staticmethod
  def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
    """True for PROGRAM / COPY calls whose batch touches exactly one device, and that device is an HCQ2Compiled."""
    all_devs = GraphRunner._all_devs(batch_devs, new_call)
    if new_call.src[0].op not in (Ops.PROGRAM, Ops.COPY): return False
    return len(all_devs) == 1 and isinstance(all_devs[0], HCQ2Compiled)
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast, Callable, TypeVar, Generic, Any
|
||||
import struct, functools, time, collections, importlib, itertools, weakref
|
||||
from dataclasses import replace
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize, to_tuple
|
||||
from dataclasses import replace, dataclass, field
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, flatten, pluralize
|
||||
from tinygrad.helpers import to_tuple
|
||||
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
||||
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
||||
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
|
||||
|
|
@ -137,9 +136,9 @@ def unwrap_after(uop):
|
|||
while uop.op is Ops.AFTER: uop = uop.src[0]
|
||||
return uop
|
||||
|
||||
def make_getaddr(u, dev=None):
|
||||
def make_getaddr(u, device=None):
|
||||
if unwrap_after(u).op not in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT): return u
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=dev or to_tuple(u.device)[0])))
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(u, UOp(Ops.DEVICE, arg=device or to_tuple(u.device)[0])))
|
||||
|
||||
def make_ins(op, *srcs):
  """Build a void INS UOp carrying `op` as its fourth constructor argument, with every source as a uint32.

  Python ints become uint32 constants; any other source is cast to uint32.
  """
  def as_u32(s): return UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32)
  return UOp(Ops.INS, dtypes.void, tuple(map(as_u32, srcs)), op)
|
||||
|
|
@ -149,7 +148,7 @@ def make_cmdbuf(lin, devs, tag):
|
|||
for s in (s for ins in lin.src for s in ins.src):
|
||||
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
||||
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
||||
buf = UOp.new_buffer(devs if len(devs) > 1 else devs[0], len(blob), dtypes.uint8).rtag(tag)
|
||||
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
|
||||
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
|
||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
|
||||
|
||||
|
|
@ -172,6 +171,10 @@ class HCQInfo:
|
|||
name:str = ""
|
||||
estimates:Estimates = Estimates()
|
||||
outs:tuple[int, ...] = ()
|
||||
devs:tuple[str, ...] = ()
|
||||
|
||||
params:tuple[int, ...] = ()
|
||||
inputs:int|None = None
|
||||
|
||||
@staticmethod
|
||||
def from_call(call:UOp) -> HCQInfo: return HCQInfo(get_call_name(call, get_call_arg_uops(call)), estimate_uop(call), get_call_outs_ins(call)[0])
|
||||
|
|
@ -269,6 +272,14 @@ class DepsCtx:
|
|||
deps:DepsTracker = field(default_factory=DepsTracker)
|
||||
opid:itertools.count = field(default_factory=lambda: itertools.count(0))
|
||||
last_per_queue:weakref.WeakValueDictionary[tuple[Any, str], UOp] = field(default_factory=weakref.WeakValueDictionary)
|
||||
params:dict[tuple[int, int], Buffer] = field(default_factory=dict)
|
||||
|
||||
def get_dep_buf(ctx:DepsCtx, u:UOp, lane:int) -> Buffer:
  """Return the Buffer that stands for `u` on `lane` when tracking dependencies.

  * PARAM: a "NULL"-device Buffer cached in `ctx.params` under `(slot, lane)`, so the same parameter and lane always
    map to the same object.
  * MSTACK: resolved through its `lane`-th source, with lane 0.
  * MSELECT: resolved through its first source, with the lane replaced by `u.arg`.
  * SLICE: resolved through its first source, same lane.
  * anything else: `u.buffer`, or its `lane`-th member when that is a MultiBuffer.
  """
  # TODO: should this be a part of DepsTracker?
  if u.op is Ops.PARAM:
    key = (u.arg.slot, lane)
    # dict.setdefault would construct a throwaway Buffer on every lookup; only build one on a cache miss
    if key not in ctx.params: ctx.params[key] = Buffer("NULL", u.max_numel(), u.dtype.base)
    return ctx.params[key]
  if u.op is Ops.MSTACK: return get_dep_buf(ctx, u.src[lane], 0)
  if u.op is Ops.MSELECT: return get_dep_buf(ctx, u.src[0], u.arg)
  if u.op is Ops.SLICE: return get_dep_buf(ctx, u.src[0], lane)
  buf = u.buffer
  return buf.bufs[lane] if isinstance(buf, MultiBuffer) else buf
|
||||
|
||||
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
new_src = []
|
||||
|
|
@ -280,12 +291,10 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
|||
q = get_submit(call.src[0]).src[0]
|
||||
new_q = ctx.last_per_queue[q.arg] = q.rtag(next(ctx.opid))
|
||||
|
||||
deps = []
|
||||
refs = [b.buffer for b in get_call_arg_uops(call)]
|
||||
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
|
||||
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.arg.aux.outs, new_q)
|
||||
refs = get_call_arg_uops(call)
|
||||
deps = dedup(flatten(ctx.deps.access_resources([get_dep_buf(ctx, b, l) for b in refs], call.arg.aux.outs, new_q) for l in range(len(q.arg[0]))))
|
||||
|
||||
new_q = new_q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else new_q
|
||||
new_q = new_q.after(*deps).rtag("deps") if deps else new_q
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
|
||||
return linear.replace(src=tuple(new_src))
|
||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
||||
|
|
@ -298,11 +307,13 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
|
|||
zero = UOp.const(dtypes.int, 0)
|
||||
tl = make_signal_value(devs)
|
||||
|
||||
submit = make_submit(make_signal(devs).store(tl.index(zero) + 1), devs=devs, queue="COMPUTE:0")
|
||||
# queue is inc with deps
|
||||
submit = make_submit(make_signal(devs).store(tl.index(zero)), devs=devs, queue="COMPUTE:0")
|
||||
submit = submit.replace(src=(submit.src[0].after(*queues).rtag("deps"),))
|
||||
|
||||
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
|
||||
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]) \
|
||||
.sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
|
||||
patches = [s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]
|
||||
return UOp.barrier(*patches).sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
|
||||
|
||||
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
parts:dict[str, list[UOp]] = collections.defaultdict(list)
|
||||
|
|
@ -384,6 +395,29 @@ def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
|
|||
return submit.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), wait, *q.src)),))
|
||||
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
|
||||
|
||||
# *****************
|
||||
# 4.3. annotate exec devs
|
||||
|
||||
def _annotate_devs(call:UOp) -> UOp:
  """Copy the devices of the call's submit queue (`get_submit(body).src[0].arg[0]`) into `call.arg.aux.devs`."""
  devs = get_submit(call.src[0]).src[0].arg[0]
  return call.replace(arg=replace(call.arg, aux=replace(call.arg.aux, devs=devs)))

# applied to CALL nodes tagged "hcq"
pm_annotate_devs = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), _annotate_devs)])
|
||||
|
||||
# *****************
|
||||
# 4.4. replace params with per-submit input address loads
|
||||
|
||||
def replace_params(call:UOp) -> UOp|None:
  """Replace the PARAM nodes of a call body with loads from a new per-call "inputs" buffer.

  Only PARAMs whose `addrspace` is not None are considered; when there are none the call is left alone (None).
  Otherwise the call's aux info records the sorted, de-duplicated slots (`params`) and the current number of call
  arguments (`inputs`), a uint64 buffer with one entry per slot is created on the devices of the submit queue and
  tagged "inputs", every such PARAM becomes a load of its slot's entry, and the buffer is appended as the last source.
  """
  body = call.src[0]
  param_slots = {u:u.arg.slot for u in body.toposort() if u.op is Ops.PARAM and u.addrspace is not None}
  if not param_slots: return None

  # fill new info
  slots = tuple(sorted(set(param_slots.values())))
  info = replace(call.arg.aux, params=slots, inputs=len(get_call_arg_uops(call)))

  table = UOp.new_buffer(get_submit(body).src[0].arg[0], len(slots), dtypes.uint64).rtag("inputs")

  # slot number -> position of that slot inside the table
  position = {slot:pos for pos, slot in enumerate(slots)}
  loads = {u:table.index(UOp.const(dtypes.int, position[slot])).load() for u, slot in param_slots.items()}

  return call.replace(src=(body.substitute(loads), *call.src[1:], table), arg=replace(call.arg, aux=info))

# applied to CALL nodes tagged "hcq"
pm_replace_params = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), replace_params)])
|
||||
|
||||
# *****************
|
||||
# 5.1. encode cmdbufs
|
||||
|
||||
|
|
@ -423,7 +457,7 @@ pm_hold_call_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"),
|
|||
|
||||
def bufferize_buf(buf:UOp) -> UOp|None:
|
||||
if buf.tag is None: return None
|
||||
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), dev) for dev in to_tuple(buf.src[1].arg))
|
||||
uops = tuple(UOp.from_buffer((dv:=Device[dev]).pm_bufferize.rewrite(buf, ctx=dv), "CPU") for dev in to_tuple(buf.src[1].arg))
|
||||
return make_mstack(uops)
|
||||
pm_bufferize = PatternMatcher([(UPat(Ops.BUFFER, name="buf"), bufferize_buf)])
|
||||
|
||||
|
|
@ -443,18 +477,22 @@ def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
|
|||
return UOp(Ops.NOOP)
|
||||
|
||||
def resolve_getaddr(buf:UOp, g:UOp) -> UOp:
  """Resolve GETADDR `g` of `buf` into constant address(es).

  * `buf` is not a BUFFER / MSTACK / MSELECT: `buf` itself is returned, i.e. the GETADDR wrapper is dropped.
  * `buf.buffer` is a single Buffer: a uint64 constant holding the address of its buffer for the device named by
    `g.src[1].arg`.
  * otherwise (the object exposes `.bufs`): a uint64 STACK with one constant per member; every member is allocated
    if needed before its address is read.
  """
  u64 = dtypes.uint64
  if buf.op not in (Ops.BUFFER, Ops.MSTACK, Ops.MSELECT): return buf

  backing = buf.buffer
  if isinstance(backing, Buffer):
    return UOp.const(u64, backing.get_buf(g.src[1].arg).va_addr)

  lane_addrs = tuple(UOp.const(u64, lane.ensure_allocated()._buf.va_addr) for lane in backing.bufs)
  return UOp(Ops.STACK, u64.vec(len(lane_addrs)), lane_addrs)
|
||||
|
||||
def resolve_getaddr_slice(bv:UOp, dev:UOp) -> UOp:
  """Rewrite getaddr(slice(x)) as getaddr(x) plus a byte offset.

  The offset is `bv.src[1].arg` elements. The element size comes from the sliced source's dtype when that source
  (looking through AFTER wrappers) is a BUFFER / SLICE / MSTACK / MSELECT, and from the slice's own dtype otherwise.
  """
  base = bv.src[0]
  if unwrap_after(base).op in (Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT):
    itemsize = base.dtype.itemsize
  else:
    itemsize = bv.dtype.itemsize
  byte_off = UOp.const(dtypes.uint64, bv.src[1].arg * itemsize)
  return UOp(Ops.GETADDR, dtypes.uint64, src=(base, dev)) + byte_off
|
||||
|
||||
pm_resolve_patches = PatternMatcher([
|
||||
# multi
|
||||
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
||||
|
||||
# getaddr
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||
lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
|
||||
(UPat(Ops.GETADDR, src=(UPat({Ops.BUFFER, Ops.MSTACK, Ops.MSELECT}, name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||
(UPat(Ops.GETADDR, src=(UPat(name="buf"), UPat(Ops.DEVICE)), name="g"), resolve_getaddr),
|
||||
|
||||
# folders
|
||||
(UPat({Ops.BUFFER, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
|
|
@ -466,24 +504,25 @@ pm_resolve_patches = PatternMatcher([
|
|||
# 8. callify hcq programs
|
||||
|
||||
def to_param(bufs:list[UOp], ref:UOp) -> UOp:
|
||||
bufs.append(ref)
|
||||
return UOp.placeholder((ref.buffer.size,), ref.dtype, len(bufs)-1)
|
||||
if ref not in bufs: bufs.append(ref)
|
||||
return UOp.placeholder((ref.buffer.size,), ref.dtype, bufs.index(ref))
|
||||
pm_to_param = PatternMatcher([(UPat({Ops.MSELECT, Ops.MSTACK, Ops.BUFFER}, name="r"), lambda ctx, r: to_param(ctx, r))])
|
||||
|
||||
def parametrize_host_buffers(call:UOp) -> UOp:
|
||||
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=[]), bottom_up=True, name="parametrize host buffers")
|
||||
# preserve original order of args
|
||||
body = graph_rewrite(call.src[0], pm_to_param, ctx=(bufs:=list(get_call_arg_uops(call))), bottom_up=True, name="parametrize host buffers")
|
||||
|
||||
# move vars to new slots
|
||||
var_slots = {nm:len(bufs)+i for i,nm in enumerate(sorted({v.expr for v in body.variables() if v.op is Ops.PARAM}))}
|
||||
body = body.substitute({v:v.replace(arg=replace(v.arg, slot=var_slots[v.expr])) for v in body.variables() if v.op is Ops.PARAM})
|
||||
|
||||
return call.replace(src=(body, *bufs) + call.src[1:], tag="hcq_param")
|
||||
return call.replace(src=(body, *bufs) + tuple(x for x in call.src[1:] if x.op is Ops.BIND))
|
||||
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
|
||||
|
||||
def callify_hcq(call:UOp) -> UOp:
  """Compile the call body into a CPU program and re-issue the call as an "hcq" CUSTOM_FUNCTION.

  The body is sunk under a KernelInfo named "hcq_submit" (tag 1) and turned into a program with the CPU renderer; the
  new call keeps the original arguments (`call.src[1:]`) and aux info.
  """
  body_sink = call.src[0].sink(arg=KernelInfo("hcq_submit"), tag=1)
  prg = to_program(body_sink, Device["CPU"].renderer)
  host_fn = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(prg,), arg="hcq")
  return host_fn.call(*call.src[1:], aux=call.arg.aux)
|
||||
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
|
||||
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), callify_hcq)])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
|
||||
def hcq_schedule(linear:UOp) -> UOp:
|
||||
|
|
@ -497,6 +536,8 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_merge_queues, name="merge queues")
|
||||
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_annotate_devs, name="annotate devs")
|
||||
linear = graph_rewrite(linear, pm_replace_params, name="replace params")
|
||||
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
|
||||
|
|
@ -504,7 +545,7 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
# realize starts from here
|
||||
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_parametrize_host_buffers, name="parametrize host buffers")
|
||||
linear = graph_rewrite(linear, pm_parametrize_host_buffers, walk=True, name="parametrize host buffers")
|
||||
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")
|
||||
|
||||
return linear
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast, Iterator, Any, Sequence
|
||||
import time, random, itertools, math, contextlib, weakref
|
||||
import time, random, itertools, math, contextlib, weakref, array
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, all_int, TRACEMETA, prod, flatten, Context, getenv, dedup, to_tuple
|
||||
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, all_int, TRACEMETA, prod, flatten, Context, getenv, to_tuple
|
||||
from tinygrad.helpers import BEAM, size_to_str, time_to_str, VALIDATE_WITH_CPU, PROFILE, ProfilePointEvent, cpu_events
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite, ProgramInfo
|
||||
|
|
@ -175,10 +175,10 @@ def exec_copy(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
|||
|
||||
def exec_kernel(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
||||
et = None
|
||||
for bufs, device_vars in unwrap_multi(call, resolve_params(call, ctx.input_uops)):
|
||||
for device, (bufs, device_vars) in zip(to_tuple(call.src[1].device), unwrap_multi(call, resolve_params(call, ctx.input_uops))):
|
||||
var_vals = {**ctx.var_vals, **device_vars}
|
||||
prg_bufs = [bufs[i].ensure_allocated() for i in ast.arg.globals]
|
||||
rt = get_runtime(device:=bufs[0].device, ast, cache=ctx.cache)
|
||||
rt = get_runtime(device, ast, cache=ctx.cache)
|
||||
global_size, local_size = ast.arg.launch_dims(var_vals)
|
||||
with track_stats(ctx, call, device, prg_bufs, var_vals) as tm:
|
||||
et = tm[0] = rt(*[b.get_buf(device) for b in prg_bufs], global_size=global_size, local_size=local_size, vals=ast.arg.vals(var_vals),
|
||||
|
|
@ -209,8 +209,14 @@ def exec_graph(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
|||
return t[0]
|
||||
|
||||
def exec_hcq(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
||||
pm_exec.rewrite(call.replace(src=(ast,) + call.src[1:]), replace(ctx, update_stats=False))
|
||||
for d in dedup(flatten([to_tuple(u.device) for u in resolve_params(call, ctx.input_uops)])):
|
||||
if call.arg.aux.inputs is not None:
|
||||
for j,dev in enumerate(call.arg.aux.devs):
|
||||
addrs = [(b.bufs[j] if isinstance(b:=ctx.input_uops[i].buffer, MultiBuffer) else b).get_buf(dev).va_addr for i in call.arg.aux.params]
|
||||
call.src[1+call.arg.aux.inputs].buffer.ensure_allocated()._buf.cpu_view().view(fmt='Q')[:len(addrs)] = array.array('Q', addrs)
|
||||
|
||||
pm_exec.rewrite(call.replace(src=(ast,) + call.src[1:]), replace(ctx, update_stats=False, wait=True))
|
||||
|
||||
for d in call.arg.aux.devs:
|
||||
with track_stats(ctx, call, d, [], ctx.var_vals):
|
||||
if ctx.wait: Device[d].synchronize()
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue