mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
GraphRunner (#4375)
* GraphRunner * new metal graph * update hsa for graph runner * put var_vals back * move that clear after the capture
This commit is contained in:
parent
077ea6926c
commit
272bea5100
6 changed files with 76 additions and 85 deletions
|
|
@ -42,6 +42,8 @@ Device = _Device()
|
|||
class Runner:
|
||||
def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
|
||||
self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate = True, display_name, dname, op_estimate, mem_estimate
|
||||
@property
|
||||
def device(self): return Device[self.dname]
|
||||
def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
|
||||
return self(rawbufs, {} if var_vals is None else var_vals)
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
|
||||
|
|
@ -157,9 +159,6 @@ class CompiledRunner(Runner):
|
|||
return CompiledRunner(self.display_name, self.prg, dname, self.global_size, self.local_size,
|
||||
self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount)
|
||||
|
||||
@property
|
||||
def device(self): return Device[self.dname]
|
||||
|
||||
def __reduce__(self):
|
||||
return self.__class__, (self.display_name, self.prg, self.dname, self.global_size, self.local_size,
|
||||
self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount)
|
||||
|
|
@ -181,10 +180,6 @@ class CompiledRunner(Runner):
|
|||
if local_size: lra['local_size'] = local_size
|
||||
return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait)
|
||||
|
||||
class MultiDeviceJITGraph(Runner):
  """Base class for graph runners; it only declares the call interface.

  Subclasses must provide their own `__call__`; the one here always raises.
  """
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
    """Placeholder entry point: always raises NotImplementedError."""
    raise NotImplementedError("override this")
|
||||
|
||||
method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] = {}
|
||||
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
|
||||
class Compiled:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import annotations
|
||||
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
|
||||
import functools, itertools, operator
|
||||
import functools, itertools
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException
|
||||
from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored
|
||||
from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, Device, Runner
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
|
|
@ -12,16 +12,6 @@ from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planne
|
|||
from tinygrad.nn.state import get_parameters
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
# TODO: these graph functions probably shouldn't exist here
|
||||
|
||||
def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
  """Total the op and memory estimates over the CompiledRunner items of `jit_cache`.

  Items whose `prg` is not a CompiledRunner contribute nothing.
  Returns a `(op_estimate, mem_estimate)` pair; both are 0 when nothing matches.
  """
  compiled_prgs = [ji.prg for ji in jit_cache if isinstance(ji.prg, CompiledRunner)]
  total_ops = sum(prg.op_estimate for prg in compiled_prgs)
  total_mem = sum(prg.mem_estimate for prg in compiled_prgs)
  return total_ops, total_mem
|
||||
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
  """Return the positions in `jit_cache` whose launch dimensions are not all ints.

  An item qualifies when its `prg` is a CompiledRunner and either its
  `global_size` or its `local_size` is set and fails `all_int`.
  """
  idxs: List[int] = []
  for j, ji in enumerate(jit_cache):
    prg = ji.prg
    if not isinstance(prg, CompiledRunner): continue
    if (prg.global_size and not all_int(prg.global_size)) or (prg.local_size and not all_int(prg.local_size)):
      idxs.append(j)
  return idxs
|
||||
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
  """Return the positions in `jit_cache` whose `prg` is a CompiledRunner with a non-empty `vars`."""
  idxs: List[int] = []
  for j, ji in enumerate(jit_cache):
    if isinstance(ji.prg, CompiledRunner) and ji.prg.vars: idxs.append(j)
  return idxs
|
||||
def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
|
||||
# Split JIT cache into batches for faster graph execution.
|
||||
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
|
||||
|
|
@ -34,7 +24,10 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
|
|||
nonlocal current_batch, current_device, max_batch_size
|
||||
try:
|
||||
if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
|
||||
graphed_jit_cache.append(ExecItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
|
||||
graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
|
||||
graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
|
||||
max_batch_size *= 2
|
||||
if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
|
||||
except GraphException as e:
|
||||
|
|
@ -51,7 +44,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
|
|||
|
||||
can_be_graphed = ji_graph_dev and ji_graph_dev.graph
|
||||
can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or
|
||||
(isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiDeviceJITGraph) and type(ji_graph_dev) == type(current_device))) #type:ignore
|
||||
(isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiGraphRunner) and type(ji_graph_dev) == type(current_device))) #type:ignore
|
||||
if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
|
||||
|
||||
if can_be_graphed: current_batch.append(ji)
|
||||
|
|
@ -70,6 +63,26 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
|
|||
input_replace[(j,i)] = input_rawbuffers.index(a)
|
||||
return input_replace
|
||||
|
||||
class GraphRunner(Runner):  # pylint: disable=abstract-method
  """Runner that stands in for a whole batch of jit cache items.

  The constructor gathers the bookkeeping that the device graph subclasses read:
  which (item, buffer) slots are fed from the jit inputs, which items carry
  variables, which items have launch dimensions that are not all ints, and the
  summed op/memory estimates of the batch.
  """
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    """Record per-item metadata for `jit_cache` and initialise the Runner base.

    jit_cache        -- the batch; the device name is taken from its first item, so it must not be empty.
    input_rawbuffers -- the jit input buffers, passed to get_input_replace to build `input_replace`.
    var_vals         -- only its keys are used here; they are stored in order as `self.vars`.
    """
    self.jit_cache = jit_cache
    # keyed by (j, i); callers index it as jit_cache[j].bufs[i]
    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
    self.jc_idx_with_updatable_launch_dims: List[int] = []
    self.jc_idx_with_updatable_var_vals: List[int] = []

    total_ops: sint = 0
    total_mem: sint = 0
    for idx, item in enumerate(jit_cache):
      prg = item.prg
      # estimates are accumulated for every item, whatever its runner type
      total_ops += prg.op_estimate
      total_mem += prg.mem_estimate
      if not isinstance(prg, CompiledRunner): continue
      if prg.vars: self.jc_idx_with_updatable_var_vals.append(idx)
      if (prg.global_size and not all_int(prg.global_size)) or (prg.local_size and not all_int(prg.local_size)):
        self.jc_idx_with_updatable_launch_dims.append(idx)

    self.vars = list(var_vals.keys())
    # device name: the first item's dname with anything from the first ":" onwards dropped
    base_dname = jit_cache[0].prg.dname.split(":")[0]
    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), base_dname, total_ops, total_mem)
|
||||
|
||||
class MultiGraphRunner(GraphRunner):  # pylint: disable=abstract-method
  """GraphRunner subclass that adds no behaviour of its own.

  apply_graph_to_jit tests `issubclass(graph, MultiGraphRunner)` when deciding
  whether an item on a different device of the same type may join the current batch.
  """
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
class TinyJit(Generic[ReturnType]):
|
||||
def __init__(self, fxn:Callable[..., ReturnType]):
|
||||
|
|
|
|||
|
|
@ -1,23 +1,21 @@
|
|||
import ctypes, collections
|
||||
from typing import Any, Optional, Tuple, Dict, List, cast
|
||||
import tinygrad.runtime.autogen.cuda as cuda
|
||||
from tinygrad.helpers import init_c_var, GraphException, getenv, colored
|
||||
from tinygrad.device import CompiledRunner, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
|
||||
from tinygrad.helpers import init_c_var, GraphException, getenv
|
||||
from tinygrad.device import CompiledRunner, Buffer, BufferXfer, Device, BufferOptions
|
||||
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
|
||||
class CUDAGraph(MultiDeviceJITGraph):
|
||||
class CUDAGraph(MultiGraphRunner):
|
||||
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
|
||||
# Check all jit items are compatible.
|
||||
if not all(isinstance(ji.prg, CompiledRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
|
||||
|
||||
self.jit_cache = jit_cache
|
||||
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
|
||||
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
|
||||
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
|
||||
self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
|
||||
self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
|
||||
self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
|
||||
|
||||
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
|
||||
|
|
@ -37,7 +35,7 @@ class CUDAGraph(MultiDeviceJITGraph):
|
|||
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
|
||||
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
|
||||
|
||||
if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
|
||||
if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
|
||||
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
|
||||
elif isinstance(ji.prg, BufferXfer):
|
||||
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
||||
|
|
@ -63,14 +61,10 @@ class CUDAGraph(MultiDeviceJITGraph):
|
|||
WidthInBytes=dest.nbytes, Height=1, Depth=1)
|
||||
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1,
|
||||
ctypes.byref(cp_params), dest_dev.context))
|
||||
if j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
|
||||
if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
|
||||
|
||||
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
|
||||
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
|
||||
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "CUDA", *get_jit_stats(jit_cache))
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
|
||||
# Update rawbuffers in the c_args struct.
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
|
|
@ -80,12 +74,12 @@ class CUDAGraph(MultiDeviceJITGraph):
|
|||
elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
|
||||
|
||||
# Update var_vals in the c_args struct.
|
||||
for j in self.jc_idxs_with_updatable_var_vals:
|
||||
for j in self.jc_idx_with_updatable_var_vals:
|
||||
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
|
||||
setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
|
||||
|
||||
# Update launch dims in the kern_params struct.
|
||||
for j in self.jc_idxs_with_updatable_launch_dims:
|
||||
for j in self.jc_idx_with_updatable_launch_dims:
|
||||
self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals))
|
||||
|
||||
# Update graph nodes with the updated structs.
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import ctypes, collections, time, itertools
|
||||
from typing import List, Any, Dict, cast, Optional, Union, Tuple
|
||||
from tinygrad.helpers import GraphException, init_c_var, round_up, colored
|
||||
from tinygrad.helpers import GraphException, init_c_var, round_up
|
||||
from tinygrad.buffer import Buffer, BufferOptions
|
||||
from tinygrad.device import Compiled, CompiledRunner, BufferXfer, MultiDeviceJITGraph, Device
|
||||
from tinygrad.device import Compiled, CompiledRunner, BufferXfer, Device
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
|
||||
|
||||
|
|
@ -25,12 +25,9 @@ class VirtAQLQueue(AQLQueue):
|
|||
self.packets_count += 1
|
||||
self.available_packet_slots -= 1
|
||||
|
||||
class HSAGraph(MultiDeviceJITGraph):
|
||||
class HSAGraph(MultiGraphRunner):
|
||||
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
self.jit_cache = jit_cache
|
||||
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
|
||||
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
|
||||
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
|
||||
# Check all jit items are compatible.
|
||||
compiled_devices = set()
|
||||
|
|
@ -111,10 +108,6 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
|
||||
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
|
||||
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
|
||||
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "HSA", *get_jit_stats(jit_cache))
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
|
||||
# Wait and restore signals
|
||||
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
|
|
@ -130,12 +123,12 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
|
||||
|
||||
# Update var_vals
|
||||
for j in self.jc_idxs_with_updatable_var_vals:
|
||||
for j in self.jc_idx_with_updatable_var_vals:
|
||||
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
|
||||
self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
|
||||
|
||||
# Update launch dims
|
||||
for j in self.jc_idxs_with_updatable_launch_dims:
|
||||
for j in self.jc_idx_with_updatable_launch_dims:
|
||||
gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals)
|
||||
self.packets[j].workgroup_size_x = lc[0]
|
||||
self.packets[j].workgroup_size_y = lc[1]
|
||||
|
|
|
|||
|
|
@ -1,79 +1,75 @@
|
|||
from typing import List, Any, Dict, cast, Optional
|
||||
import Metal
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import dedup, unwrap2, GraphException, colored
|
||||
from tinygrad.device import Buffer, CompiledRunner, Runner
|
||||
from tinygrad.helpers import dedup, unwrap2, GraphException
|
||||
from tinygrad.device import Buffer, CompiledRunner
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
|
||||
from tinygrad.engine.jit import GraphRunner
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.runtime.ops_metal import MetalDevice, wait_check
|
||||
from tinygrad.runtime.ops_metal import wait_check
|
||||
|
||||
class MetalGraph(Runner):
|
||||
def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
class MetalGraphRunner(GraphRunner):
|
||||
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
|
||||
|
||||
self.jit_cache = jit_cache
|
||||
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
|
||||
self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
|
||||
self.device: MetalDevice = device
|
||||
|
||||
# create metal batch exec
|
||||
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
|
||||
icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
|
||||
icb_descriptor.setInheritBuffers_(False)
|
||||
icb_descriptor.setInheritPipelineState_(False)
|
||||
icb_descriptor.setMaxKernelBufferBindCount_(31)
|
||||
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0)) # noqa: E501
|
||||
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
|
||||
Metal.MTLResourceOptions(0))
|
||||
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
||||
|
||||
if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
|
||||
all_resources = [self.int_buf] if len(var_vals) else []
|
||||
if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
|
||||
all_resources = [self.int_buf] if len(self.vars) else []
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
|
||||
descriptor = Metal.MTLComputePipelineDescriptor.new()
|
||||
descriptor.setComputeFunction_(prg.clprg.fxn)
|
||||
descriptor.setSupportIndirectCommandBuffers_(True)
|
||||
pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501
|
||||
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
|
||||
icb_command.setComputePipelineState_(pipeline_state)
|
||||
icb_command.setComputePipelineState_(unwrap2(
|
||||
self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
|
||||
for i,b in enumerate(ji.bufs):
|
||||
if b is not None:
|
||||
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
|
||||
all_resources.append(b._buf)
|
||||
var_vals_keys = list(var_vals.keys())
|
||||
for i,v in enumerate(prg.vars):
|
||||
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.bufs)+i)
|
||||
for i,v in enumerate(prg.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
|
||||
if j not in self.jc_idx_with_updatable_launch_dims:
|
||||
global_size, local_size = prg.launch_dims(var_vals)
|
||||
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
icb_command.setBarrier()
|
||||
|
||||
self.all_resources = dedup(all_resources)
|
||||
self.command_buffer: Any = None
|
||||
if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
|
||||
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
|
||||
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), device.dname, *get_jit_stats(jit_cache))
|
||||
if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
|
||||
# NOTE: you at least can't update the ints if this is running
|
||||
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
||||
all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
|
||||
all_resources = dedup(self.all_resources + [x._buf for x in input_rawbuffers])
|
||||
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
|
||||
for j in self.jc_idx_with_updatable_launch_dims:
|
||||
global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals)
|
||||
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501
|
||||
for j, value in enumerate(var_vals.values()): self.int_buf_view[j] = value
|
||||
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
|
||||
Metal.MTLSize(*local_size))
|
||||
for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
|
||||
|
||||
command_buffer = self.device.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
|
||||
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
|
||||
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache)))
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
self.command_buffer = command_buffer
|
||||
|
||||
if wait:
|
||||
wait_check(command_buffer)
|
||||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
self.device.mtl_buffers_in_flight.append(command_buffer)
|
||||
return None
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -97,9 +97,9 @@ class MetalDevice(Compiled):
|
|||
self.mtl_buffers_in_flight: List[Any] = []
|
||||
self.mv_in_metal: List[memoryview] = []
|
||||
self.track_cross_buffer: List[Any] = []
|
||||
from tinygrad.runtime.graph.metal import MetalGraph
|
||||
from tinygrad.runtime.graph.metal import MetalGraphRunner
|
||||
super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self),
|
||||
functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
|
||||
functools.partial(MetalProgram, self), MetalGraphRunner)
|
||||
def synchronize(self):
|
||||
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
||||
self.mv_in_metal.clear()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue