mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fixedvars in all graphs (#10365)
* cuda fixedvars * metal: fixevars * f * ups * count fixedvars
This commit is contained in:
parent
efa8dfe7fb
commit
90c4bb10c0
3 changed files with 14 additions and 11 deletions
|
|
@ -75,7 +75,7 @@ class GraphRunner(Runner):
|
|||
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
|
||||
self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
|
||||
self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
|
||||
self.var_vals_replace:dict[int, list[int]] = {}
|
||||
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
|
||||
self.launch_dims_replace:dict[int, tuple[Optional[int], Optional[int]]] = {}
|
||||
self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
|
||||
|
||||
|
|
@ -90,7 +90,7 @@ class GraphRunner(Runner):
|
|||
for j,ji in enumerate(jit_cache):
|
||||
estimates += ji.prg.estimates
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars if v not in ji.fixedvars]
|
||||
if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v)) for i, v in enumerate(ji.prg.p.vars) if v not in ji.fixedvars]
|
||||
|
||||
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
|
||||
if global_dim_idx is not None or local_dim_idx is not None:
|
||||
|
|
@ -107,7 +107,7 @@ class GraphRunner(Runner):
|
|||
def updated_vars(self, var_vals: dict[Variable, int]):
|
||||
vals = [var_vals[v] for v in self.vars]
|
||||
for j, vidxs in self.var_vals_replace.items():
|
||||
for i, v in enumerate(vidxs): yield j, i, vals[v]
|
||||
for i, v in vidxs: yield j, i, vals[v]
|
||||
|
||||
def updated_launch_dims(self, var_vals: dict[Variable, int]):
|
||||
dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class CUDAGraph(MultiGraphRunner):
|
|||
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
|
||||
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
|
||||
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x, ji.fixedvars.get(x)) for x in ji.prg.p.vars])
|
||||
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
|
||||
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, cast
|
||||
import ctypes, re
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import dedup, getenv
|
||||
from tinygrad.helpers import dedup, getenv, merge_dicts
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
|
|
@ -34,9 +34,11 @@ class MetalGraph(GraphRunner):
|
|||
icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
|
||||
self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+
|
||||
|
||||
if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
|
||||
all_resources = [self.int_buf.buf] if len(self.vars) else []
|
||||
all_pipelines = []
|
||||
self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
|
||||
self.varlist = self.vars + list(self.fixedvars.keys())
|
||||
if len(self.varlist): self.int_buf = self.dev.allocator.alloc(len(self.varlist)*dtypes.int32.itemsize)
|
||||
|
||||
all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
|
||||
for j,ji in enumerate(jit_cache):
|
||||
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
|
||||
icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
|
||||
|
|
@ -46,7 +48,7 @@ class MetalGraph(GraphRunner):
|
|||
if b is not None and b not in input_rawbuffers:
|
||||
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
|
||||
all_resources.append(b._buf.buf)
|
||||
for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
|
||||
for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v)*4, len(ji.bufs)+i)
|
||||
|
||||
global_size, local_size = prg.p.launch_dims(var_vals)
|
||||
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
|
||||
|
|
@ -55,7 +57,8 @@ class MetalGraph(GraphRunner):
|
|||
self.all_resources = dedup(all_resources)
|
||||
self.all_pipelines = dedup(all_pipelines)
|
||||
self.command_buffer: Any = None
|
||||
if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
|
||||
if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
|
||||
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
|
||||
self.range = to_struct(0, len(jit_cache))
|
||||
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
||||
|
|
@ -69,7 +72,7 @@ class MetalGraph(GraphRunner):
|
|||
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
|
||||
computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
|
||||
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
|
||||
for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
|
||||
for var in self.vars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
|
||||
|
||||
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
|
||||
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue