GraphRunner (#4375)

* GraphRunner

* new metal graph

* update hsa for graph runner

* put var_vals back

* move that clear after the capture
This commit is contained in:
George Hotz 2024-05-01 10:27:13 -07:00 committed by GitHub
commit 272bea5100
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 76 additions and 85 deletions

View file

@ -42,6 +42,8 @@ Device = _Device()
class Runner:
def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate = True, display_name, dname, op_estimate, mem_estimate
@property
def device(self): return Device[self.dname]
def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
return self(rawbufs, {} if var_vals is None else var_vals)
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
@ -157,9 +159,6 @@ class CompiledRunner(Runner):
return CompiledRunner(self.display_name, self.prg, dname, self.global_size, self.local_size,
self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount)
@property
def device(self): return Device[self.dname]
def __reduce__(self):
return self.__class__, (self.display_name, self.prg, self.dname, self.global_size, self.local_size,
self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount)
@ -181,10 +180,6 @@ class CompiledRunner(Runner):
if local_size: lra['local_size'] = local_size
return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait)
class MultiDeviceJITGraph(Runner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
raise NotImplementedError("override this")
method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] = {}
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
class Compiled:

View file

@ -1,10 +1,10 @@
from __future__ import annotations
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional
import functools, itertools, operator
import functools, itertools
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException
from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored
from tinygrad.device import Buffer, CompiledRunner, BufferXfer, Compiled, Device, Runner
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
@ -12,16 +12,6 @@ from tinygrad.engine.realize import ExecItem, capturing, _internal_memory_planne
from tinygrad.nn.state import get_parameters
from weakref import WeakKeyDictionary
# TODO: these graph functions probably shouldn't exist here
def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0), \
functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0)
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and \
((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))]
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ji.prg.vars]
def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
@ -34,7 +24,10 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
nonlocal current_batch, current_device, max_batch_size
try:
if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
graphed_jit_cache.append(ExecItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
max_batch_size *= 2
if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
except GraphException as e:
@ -51,7 +44,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
can_be_graphed = ji_graph_dev and ji_graph_dev.graph
can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or
(isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiDeviceJITGraph) and type(ji_graph_dev) == type(current_device))) #type:ignore
(isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiGraphRunner) and type(ji_graph_dev) == type(current_device))) #type:ignore
if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
if can_be_graphed: current_batch.append(ji)
@ -70,6 +63,26 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
class GraphRunner(Runner): # pylint: disable=abstract-method
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.jc_idx_with_updatable_launch_dims = []
self.jc_idx_with_updatable_var_vals = []
op_estimate: sint = 0
mem_estimate: sint = 0
for j,ji in enumerate(jit_cache):
op_estimate += ji.prg.op_estimate
mem_estimate += ji.prg.mem_estimate
if isinstance(ji.prg, CompiledRunner):
if ji.prg.vars: self.jc_idx_with_updatable_var_vals.append(j)
if (ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)):
self.jc_idx_with_updatable_launch_dims.append(j)
self.vars = list(var_vals.keys())
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
ReturnType = TypeVar('ReturnType')
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]):

View file

@ -1,23 +1,21 @@
import ctypes, collections
from typing import Any, Optional, Tuple, Dict, List, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, GraphException, getenv, colored
from tinygrad.device import CompiledRunner, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
from tinygrad.helpers import init_c_var, GraphException, getenv
from tinygrad.device import CompiledRunner, Buffer, BufferXfer, Device, BufferOptions
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
from tinygrad.engine.jit import MultiGraphRunner
class CUDAGraph(MultiDeviceJITGraph):
class CUDAGraph(MultiGraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
# Check all jit items are compatible.
if not all(isinstance(ji.prg, CompiledRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
@ -37,7 +35,7 @@ class CUDAGraph(MultiDeviceJITGraph):
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
@ -63,14 +61,10 @@ class CUDAGraph(MultiDeviceJITGraph):
WidthInBytes=dest.nbytes, Height=1, Depth=1)
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, (cuda.CUgraphNode*1)(node_to), 1,
ctypes.byref(cp_params), dest_dev.context))
if j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "CUDA", *get_jit_stats(jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Update rawbuffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():
@ -80,12 +74,12 @@ class CUDAGraph(MultiDeviceJITGraph):
elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
# Update var_vals in the c_args struct.
for j in self.jc_idxs_with_updatable_var_vals:
for j in self.jc_idx_with_updatable_var_vals:
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
# Update launch dims in the kern_params struct.
for j in self.jc_idxs_with_updatable_launch_dims:
for j in self.jc_idx_with_updatable_launch_dims:
self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals))
# Update graph nodes with the updated structs.

View file

@ -1,12 +1,12 @@
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Union, Tuple
from tinygrad.helpers import GraphException, init_c_var, round_up, colored
from tinygrad.helpers import GraphException, init_c_var, round_up
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledRunner, BufferXfer, MultiDeviceJITGraph, Device
from tinygrad.device import Compiled, CompiledRunner, BufferXfer, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
from tinygrad.engine.jit import MultiGraphRunner
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
@ -25,12 +25,9 @@ class VirtAQLQueue(AQLQueue):
self.packets_count += 1
self.available_packet_slots -= 1
class HSAGraph(MultiDeviceJITGraph):
class HSAGraph(MultiGraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
super().__init__(jit_cache, input_rawbuffers, var_vals)
# Check all jit items are compatible.
compiled_devices = set()
@ -111,10 +108,6 @@ class HSAGraph(MultiDeviceJITGraph):
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "HSA", *get_jit_stats(jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
@ -130,12 +123,12 @@ class HSAGraph(MultiDeviceJITGraph):
elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
# Update var_vals
for j in self.jc_idxs_with_updatable_var_vals:
for j in self.jc_idx_with_updatable_var_vals:
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
# Update launch dims
for j in self.jc_idxs_with_updatable_launch_dims:
for j in self.jc_idx_with_updatable_launch_dims:
gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.packets[j].workgroup_size_x = lc[0]
self.packets[j].workgroup_size_y = lc[1]

View file

@ -1,79 +1,75 @@
from typing import List, Any, Dict, cast, Optional
import Metal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, unwrap2, GraphException, colored
from tinygrad.device import Buffer, CompiledRunner, Runner
from tinygrad.helpers import dedup, unwrap2, GraphException
from tinygrad.device import Buffer, CompiledRunner
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.engine.jit import GraphRunner
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import MetalDevice, wait_check
from tinygrad.runtime.ops_metal import wait_check
class MetalGraph(Runner):
def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
class MetalGraphRunner(GraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
self.device: MetalDevice = device
# create metal batch exec
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
icb_descriptor.setInheritBuffers_(False)
icb_descriptor.setInheritPipelineState_(False)
icb_descriptor.setMaxKernelBufferBindCount_(31)
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0)) # noqa: E501
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
Metal.MTLResourceOptions(0))
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
all_resources = [self.int_buf] if len(var_vals) else []
if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
all_resources = [self.int_buf] if len(self.vars) else []
for j,ji in enumerate(self.jit_cache):
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(prg.clprg.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
icb_command.setComputePipelineState_(pipeline_state)
icb_command.setComputePipelineState_(unwrap2(
self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
for i,b in enumerate(ji.bufs):
if b is not None:
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
all_resources.append(b._buf)
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.bufs)+i)
for i,v in enumerate(prg.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
if j not in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = prg.launch_dims(var_vals)
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
icb_command.setBarrier()
self.all_resources = dedup(all_resources)
self.command_buffer: Any = None
if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), device.dname, *get_jit_stats(jit_cache))
if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# NOTE: you at least can't update the ints if this is running
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
all_resources = dedup(self.all_resources + [x._buf for x in input_rawbuffers])
for (j,i),input_idx in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
for j in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501
for j, value in enumerate(var_vals.values()): self.int_buf_view[j] = value
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
Metal.MTLSize(*local_size))
for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache)))
encoder.endEncoding()
command_buffer.commit()
self.command_buffer = command_buffer
if wait:
wait_check(command_buffer)
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
self.device.mtl_buffers_in_flight.append(command_buffer)
return None
return None

View file

@ -97,9 +97,9 @@ class MetalDevice(Compiled):
self.mtl_buffers_in_flight: List[Any] = []
self.mv_in_metal: List[memoryview] = []
self.track_cross_buffer: List[Any] = []
from tinygrad.runtime.graph.metal import MetalGraph
from tinygrad.runtime.graph.metal import MetalGraphRunner
super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self),
functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
functools.partial(MetalProgram, self), MetalGraphRunner)
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
self.mv_in_metal.clear()