mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
bring hip graph back (#2385)
* bring hip graph back * share with metal * fix linter * remove hasattrs * Update ops_hip.py * hip wrapper does not use _buf --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
46b05daf7c
commit
e68aebfff9
4 changed files with 104 additions and 63 deletions
|
|
@ -177,37 +177,16 @@ try:
|
|||
c_struct: Any
|
||||
context: Any = None
|
||||
|
||||
# Better to cache struct types, since they are reused often and take a long time to create.
|
||||
struct_type_cache: Dict[str, Any] = {}
|
||||
def __get_struct(name, field_list):
|
||||
global struct_type_cache
|
||||
if name in struct_type_cache:
|
||||
return struct_type_cache[name]
|
||||
def getCStructForType(argtypes):
|
||||
fields = []
|
||||
for j,typ in enumerate(argtypes):
|
||||
fields.append((f'field{j}', typ))
|
||||
|
||||
class CStructure(ctypes.Structure):
|
||||
_fields_ = field_list
|
||||
struct_type_cache[name] = CStructure
|
||||
return struct_type_cache[name]
|
||||
_fields_ = fields
|
||||
return CStructure
|
||||
|
||||
def getStructTypeForArgs(*args):
  """Return the struct type that ``__get_struct`` provides for ``args``.

  Each argument whose class is exactly ``int`` becomes a ``ctypes.c_int`` field
  (type code ``'i'``); every other argument becomes a ``ctypes.c_void_p`` field
  (type code ``'P'``). Fields are named ``field0``, ``field1``, ... in argument
  order, and the concatenated type codes are the name handed to ``__get_struct``.
  """
  # Exact class check (not isinstance) is deliberate: int subclasses keep mapping to 'P'.
  kinds = [('i', ctypes.c_int) if arg.__class__ is int else ('P', ctypes.c_void_p) for arg in args]
  types = "".join(code for code, _ in kinds)
  fields: List[Tuple[str, Any]] = [(f'field{idx}', ctype) for idx, (_, ctype) in enumerate(kinds)]
  return __get_struct(types, fields)
|
||||
|
||||
def updateKernelNodeParams(npwrapper:kernelNodeParamsWrapper, *args, grid=(1,1,1), block=(1,1,1), updated_args=None):
  """Write kernel arguments into the params struct held by ``npwrapper``.

  The struct is the second element of ``npwrapper.context``. For every selected
  index ``i`` the attribute ``field{i}`` is set to ``args[i]`` when that argument's
  class is exactly ``int``, otherwise to ``args[i]._buf``.

  Args:
    npwrapper: wrapper whose ``context`` is a 3-tuple with the struct in position 1.
    *args: the full positional kernel argument list.
    grid, block: accepted for signature compatibility; not used by this function.
    updated_args: iterable of argument indices to refresh, or ``None`` to refresh all.
  """
  _, struct, _ = npwrapper.context
  indices = range(len(args)) if updated_args is None else updated_args
  for i in indices:
    arg = args[i]
    setattr(struct, f'field{i}', arg if arg.__class__ is int else arg._buf)
|
||||
def setKernelNodeLaunchDims(npwrapper:kernelNodeParamsWrapper, grid, block):
|
||||
npwrapper.c_struct.blockDimX = block[0]
|
||||
npwrapper.c_struct.blockDimY = block[1]
|
||||
npwrapper.c_struct.blockDimZ = block[2]
|
||||
|
|
@ -215,10 +194,13 @@ try:
|
|||
npwrapper.c_struct.gridDimY = grid[1]
|
||||
npwrapper.c_struct.gridDimZ = grid[2]
|
||||
|
||||
def buildKernelNodeParams(*args, func=None, grid=(1,1,1), block=(1,1,1), sharedMemBytes=0, argsStructType=None):
|
||||
data = [d if d.__class__ is int else d._buf for d in args]
|
||||
if argsStructType is None: argsStructType = getStructTypeForArgs(*args)
|
||||
struct = argsStructType(*data)
|
||||
def setKernelNodeParams(npwrapper:kernelNodeParamsWrapper, args, ids):
  """Overwrite selected fields of the params struct held by ``npwrapper``.

  The ``j``-th entry of ``ids`` names the field index (``field{i}``) that receives
  ``args[j]``; the struct is ``npwrapper.context[1]``.

  Args:
    npwrapper: wrapper whose ``context[1]`` is the struct to modify.
    args: new values, positionally matched to ``ids``.
    ids: field indices to update.
  """
  for j, i in enumerate(ids):
    # Looked up per iteration on purpose: with empty ``ids`` the context is never touched.
    setattr(npwrapper.context[1], f'field{i}', args[j])
|
||||
|
||||
def buildKernelNodeParams(args, argtypes, func, grid, block, sharedMemBytes=0):
|
||||
c_struct_t = getCStructForType(argtypes)
|
||||
struct = c_struct_t(*args)
|
||||
size = ctypes.c_size_t(ctypes.sizeof(struct))
|
||||
p_size = ctypes.c_void_p(ctypes.addressof(size))
|
||||
p_struct = ctypes.c_void_p(ctypes.addressof(struct))
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import annotations
|
||||
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv
|
||||
from tinygrad.ops import RawBuffer, Device, JITRunner
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int
|
||||
from tinygrad.ops import RawBuffer, Device, JITRunner, CompiledASTRunner
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, Node
|
||||
|
|
@ -24,6 +24,10 @@ def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[RawBuffer]
|
|||
input_replace[(j,i)] = input_rawbuffers.index(a)
|
||||
assert len(set(input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
|
||||
return input_replace
|
||||
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
  """Return the indices of ``jit_cache`` entries whose launch dimensions are not all ints.

  An entry qualifies when its ``prg`` is a ``CompiledASTRunner`` and either its
  ``global_size`` or its ``local_size`` is truthy and fails ``all_int``.
  """
  idxs: List[int] = []
  for j, ji in enumerate(jit_cache):
    prg = ji.prg
    if not isinstance(prg, CompiledASTRunner): continue
    # local_size is only inspected when global_size did not already qualify the entry.
    if (prg.global_size and not all_int(tuple(prg.global_size))) or (prg.local_size and not all_int(tuple(prg.local_size))):
      idxs.append(j)
  return idxs
|
||||
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
  """Return the indices of ``jit_cache`` entries whose program has variables.

  An entry qualifies when its ``prg`` is a ``CompiledASTRunner`` with a truthy ``vars``.
  """
  return [j for j, ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
|
||||
|
||||
class GraphException(Exception):
  """Raised by graph executors when a JIT cache cannot be turned into a device graph."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
import numpy as np
|
||||
import ctypes
|
||||
import extra.hip_wrapper as hip
|
||||
from typing import Tuple
|
||||
from typing import Tuple, List, Any, Dict, cast, Optional, Callable
|
||||
from tinygrad.helpers import DEBUG, getenv, diskcache
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.ops import Compiled, CompiledASTRunner, update_stats
|
||||
from tinygrad.renderer.hip import HIPRenderer
|
||||
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer, RawBuffer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException
|
||||
|
||||
# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
|
||||
if DEBUG >= 6:
|
||||
|
|
@ -48,9 +50,22 @@ def compile_hip(prg) -> bytes:
|
|||
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
|
||||
return hip.hiprtcGetCode(prog)
|
||||
|
||||
def time_execution(cb, enable=False):
  """Run ``cb()`` and, when ``enable`` is true, time it with a pair of HIP events.

  Args:
    cb: zero-argument callable to execute.
    enable: when true, HIP events are recorded before and after the call and the
      function waits on the second one.

  Returns:
    ``hip.hipEventElapsedTime(start, end) * 1e-3`` when ``enable`` is true (HIP reports
    milliseconds, so this is presumably seconds - confirm against the wrapper),
    otherwise ``None``.
  """
  if not enable:
    cb()
    return None

  start, end = hip.hipEventCreate(), hip.hipEventCreate()
  try:
    hip.hipEventRecord(start)
    cb()
    hip.hipEventRecord(end)
    hip.hipEventSynchronize(end)
    return hip.hipEventElapsedTime(start, end)*1e-3
  finally:
    # Destroy the events even when cb() or a HIP call raises, so they are not leaked.
    hip.hipEventDestroy(start)
    hip.hipEventDestroy(end)
|
||||
|
||||
class HIPProgram:
|
||||
def __init__(self, name:str, prg:bytes):
|
||||
self.modules, self.prgs = [], []
|
||||
self.modules, self.prgs, self.c_struct_t = [], [], None
|
||||
|
||||
if DEBUG >= 6:
|
||||
asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
|
||||
|
|
@ -63,20 +78,63 @@ class HIPProgram:
|
|||
|
||||
def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
|
||||
hip.hipSetDevice(args[0]._device)
|
||||
if wait:
|
||||
start, end = hip.hipEventCreate(), hip.hipEventCreate()
|
||||
hip.hipEventRecord(start)
|
||||
struct = hip.getStructTypeForArgs(*args)(*[data._buf if not isinstance(data, int) else np.int32(data) for data in args])
|
||||
hip.hipModuleLaunchKernel(self.prgs[args[0]._device], global_size[0], global_size[1], global_size[2], local_size[0], local_size[1], local_size[2], 0, 0, struct)
|
||||
if wait:
|
||||
hip.hipEventRecord(end)
|
||||
hip.hipEventSynchronize(end)
|
||||
ret = hip.hipEventElapsedTime(start, end)*1e-3
|
||||
hip.hipEventDestroy(start)
|
||||
hip.hipEventDestroy(end)
|
||||
return ret
|
||||
if self.c_struct_t is None: self.c_struct_t = hip.getCStructForType([(ctypes.c_void_p if not isinstance(x, int) else ctypes.c_int) for x in args])
|
||||
c_params = cast(Callable, self.c_struct_t)(*[x._buf if not isinstance(x, int) else x for x in args])
|
||||
return time_execution(lambda: hip.hipModuleLaunchKernel(self.prgs[args[0]._device], *global_size, *local_size, 0, 0, c_params), enable=wait)
|
||||
|
||||
def __del__(self):
|
||||
for module in self.modules: hip.hipModuleUnload(module)
|
||||
|
||||
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize)
|
||||
class HIPGraph:
  """Replays a JIT cache of HIP kernels as a single HIP graph.

  Construction adds one kernel node per ``jit_cache`` entry, each depending on the
  previously added node, and instantiates the graph. Calling the object patches the
  parameters of the nodes that can change (input buffers, launch dimensions,
  variable values) and launches the instantiated graph.
  """

  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
    """Build and instantiate the graph.

    Raises:
      GraphException: if any cache entry is not a ``CompiledASTRunner`` backed by a ``HIPProgram``.
    """
    # TODO: Only HIPProgram can be captured for now.
    if not all(isinstance(ji.prg, CompiledASTRunner) and isinstance(ji.prg.clprg, HIPProgram) for ji in jit_cache): raise GraphException

    self.jit_cache = jit_cache
    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
    self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
    self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
    # Unique jit_cache indices that have at least one replaceable input buffer.
    self.jc_idxs_with_updatable_rawbufs = list({j for j, _ in self.input_replace})

    self.graph, graph_node = hip.hipGraphCreate(), None
    self.updatable_nodes: Dict[int, Tuple[Any, hip.kernelNodeParamsWrapper]] = {} # Dict[jc index] = tuple(graph_node, node_params)

    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
    for j,ji in enumerate(self.jit_cache):
      prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
      assert all(x is not None for x in ji.rawbufs) and ji.rawbufs[0] is not None, "buffers could not be None" # for linters

      # Kernel arguments are the buffer handles followed by the current variable values.
      args = [cast(RawBuffer, x)._buf for x in ji.rawbufs] + [var_vals[x] for x in prg.vars]
      types = [ctypes.c_void_p] * len(ji.rawbufs) + [ctypes.c_int] * len(prg.vars)
      c_params = hip.buildKernelNodeParams(args, types, prg.clprg.prgs[ji.rawbufs[0]._device], *prg.launch_dims(var_vals))
      # Each node depends on the previously added one (no dependency for the first).
      graph_node = hip.hipGraphAddKernelNode(self.graph, [graph_node] if graph_node else [], c_params)

      if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
        self.updatable_nodes[j] = (graph_node, c_params)

    self.instance = hip.hipGraphInstantiate(self.graph)

  def __del__(self):
    """Destroy the instantiated graph and the graph, whichever of them exist."""
    # __init__ may raise (e.g. GraphException) before either handle is assigned, and
    # __del__ still runs on the partially built object, so guard both lookups.
    instance, graph = getattr(self, "instance", None), getattr(self, "graph", None)
    if instance is not None: hip.hipGraphExecDestroy(instance)
    if graph is not None: hip.hipGraphDestroy(graph)

  def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
    """Patch the updatable node parameters and launch the graph.

    Returns:
      Whatever ``time_execution`` returns for the launch with ``enable=wait``.
    """
    # Update cached params structs with the new values.
    for (j,i),input_idx in self.input_replace.items():
      hip.setKernelNodeParams(self.updatable_nodes[j][1], [input_rawbuffers[input_idx]._buf], [i])
    for j in self.jc_idxs_with_updatable_launch_dims:
      hip.setKernelNodeLaunchDims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))
    for j in self.jc_idxs_with_updatable_var_vals:
      prg: CompiledASTRunner = cast(CompiledASTRunner, self.jit_cache[j].prg)
      # Variable values occupy the struct fields right after the buffer fields.
      first_var = len(self.jit_cache[j].rawbufs)
      hip.setKernelNodeParams(self.updatable_nodes[j][1], [var_vals[x] for x in prg.vars], list(range(first_var, first_var + len(prg.vars))))

    # Update graph nodes with the updated structs.
    for node, params in self.updatable_nodes.values():
      hip.hipGraphExecKernelNodeSetParams(self.instance, node, params)

    et = time_execution(lambda: hip.hipGraphLaunch(self.instance), enable=wait)
    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
    return et
|
||||
|
||||
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize, graph=HIPGraph)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
import os, subprocess, pathlib, ctypes, tempfile
|
||||
import Metal, libdispatch
|
||||
from typing import List, Any, Tuple, Dict, Set, cast, Optional
|
||||
from typing import List, Any, Tuple, Dict, cast, Optional
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
|
||||
from tinygrad.ops import Compiled, CompiledASTRunner, update_stats
|
||||
from tinygrad.renderer.metal import MetalRenderer
|
||||
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
|
||||
from tinygrad.shape.symbolic import Variable, Node
|
||||
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, GraphException
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
|
||||
|
||||
class MetalAllocator(LRUAllocator):
|
||||
def _do_alloc(self, size, dtype, device, **kwargs):
|
||||
|
|
@ -88,6 +88,7 @@ class MetalGraph:
|
|||
self.jit_cache = jit_cache
|
||||
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
|
||||
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
|
||||
self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
|
||||
|
||||
# create metal batch exec
|
||||
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
|
||||
|
|
@ -99,7 +100,6 @@ class MetalGraph:
|
|||
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
||||
|
||||
self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
|
||||
self.input_has_variable_dims: Set[int] = set()
|
||||
read_resources, write_resources = [], []
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
|
||||
|
|
@ -117,11 +117,8 @@ class MetalGraph:
|
|||
var_vals_keys = list(var_vals.keys())
|
||||
for i,v in enumerate(prg.vars):
|
||||
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
|
||||
global_size, local_size = prg.launch_dims(var_vals)
|
||||
assert prg.global_size and prg.local_size, "need global and local size to JIT"
|
||||
if any(isinstance(x, Node) for x in prg.global_size) or any(isinstance(x, Node) for x in prg.local_size):
|
||||
self.input_has_variable_dims.add(j)
|
||||
else:
|
||||
if j not in self.jc_idx_with_updatable_launch_dims:
|
||||
global_size, local_size = prg.launch_dims(var_vals)
|
||||
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
icb_command.setBarrier()
|
||||
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
|
||||
|
|
@ -134,7 +131,7 @@ class MetalGraph:
|
|||
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
|
||||
for j in self.input_has_variable_dims:
|
||||
for j in self.jc_idx_with_updatable_launch_dims:
|
||||
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
|
||||
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
self.int_buf_view[:] = list(var_vals.values())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue