bring hip graph back (#2385)

* bring hip graph back

* share with metal

* fix linter

* remove hasattrs

* Update ops_hip.py

* hip wrapper does not use _buf

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
nimlgen 2023-11-24 18:53:44 +03:00 committed by GitHub
commit e68aebfff9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 104 additions and 63 deletions

View file

@ -177,37 +177,16 @@ try:
c_struct: Any
context: Any = None
# Better to cache struct_types since they reused often and take a lot of time to create.
struct_type_cache: Dict[str, Any] = {}
def __get_struct(name, field_list):
global struct_type_cache
if name in struct_type_cache:
return struct_type_cache[name]
def getCStructForType(argtypes):
fields = []
for j,typ in enumerate(argtypes):
fields.append((f'field{j}', typ))
class CStructure(ctypes.Structure):
_fields_ = field_list
struct_type_cache[name] = CStructure
return struct_type_cache[name]
_fields_ = fields
return CStructure
def getStructTypeForArgs(*args):
types = ""
fields: List[Tuple[str, Any]] = []
for idx in range(len(args)):
if args[idx].__class__ is int:
types += 'i'
fields.append((f'field{idx}', ctypes.c_int))
else:
types += 'P'
fields.append((f'field{idx}', ctypes.c_void_p))
return __get_struct(types, fields)
def updateKernelNodeParams(npwrapper:kernelNodeParamsWrapper, *args, grid=(1,1,1), block=(1,1,1), updated_args=None):
_, struct, _ = npwrapper.context
if updated_args is not None:
for i in updated_args:
setattr(struct, f'field{i}', (args[i] if args[i].__class__ is int else args[i]._buf))
else:
for i,d in enumerate(args):
setattr(struct, f'field{i}', (d if d.__class__ is int else d._buf))
def setKernelNodeLaunchDims(npwrapper:kernelNodeParamsWrapper, grid, block):
npwrapper.c_struct.blockDimX = block[0]
npwrapper.c_struct.blockDimY = block[1]
npwrapper.c_struct.blockDimZ = block[2]
@ -215,10 +194,13 @@ try:
npwrapper.c_struct.gridDimY = grid[1]
npwrapper.c_struct.gridDimZ = grid[2]
def buildKernelNodeParams(*args, func=None, grid=(1,1,1), block=(1,1,1), sharedMemBytes=0, argsStructType=None):
data = [d if d.__class__ is int else d._buf for d in args]
if argsStructType is None: argsStructType = getStructTypeForArgs(*args)
struct = argsStructType(*data)
def setKernelNodeParams(npwrapper:kernelNodeParamsWrapper, args, ids):
for j,i in enumerate(ids):
setattr(npwrapper.context[1], f'field{i}', args[j])
def buildKernelNodeParams(args, argtypes, func, grid, block, sharedMemBytes=0):
c_struct_t = getCStructForType(argtypes)
struct = c_struct_t(*args)
size = ctypes.c_size_t(ctypes.sizeof(struct))
p_size = ctypes.c_void_p(ctypes.addressof(size))
p_struct = ctypes.c_void_p(ctypes.addressof(struct))

View file

@ -1,8 +1,8 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
import functools, itertools, operator
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv
from tinygrad.ops import RawBuffer, Device, JITRunner
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int
from tinygrad.ops import RawBuffer, Device, JITRunner, CompiledASTRunner
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node
@ -24,6 +24,10 @@ def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[RawBuffer]
input_replace[(j,i)] = input_rawbuffers.index(a)
assert len(set(input_replace.values())) == len(input_rawbuffers), "some input tensors not found"
return input_replace
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))]
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
class GraphException(Exception): pass

View file

@ -1,12 +1,14 @@
import numpy as np
import ctypes
import extra.hip_wrapper as hip
from typing import Tuple
from typing import Tuple, List, Any, Dict, cast, Optional, Callable
from tinygrad.helpers import DEBUG, getenv, diskcache
from tinygrad.ops import Compiled
from tinygrad.ops import Compiled, CompiledASTRunner, update_stats
from tinygrad.renderer.hip import HIPRenderer
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer, RawBuffer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException
# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
if DEBUG >= 6:
@ -48,9 +50,22 @@ def compile_hip(prg) -> bytes:
hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}'])
return hip.hiprtcGetCode(prog)
def time_execution(cb, enable=False):
if enable:
start, end = hip.hipEventCreate(), hip.hipEventCreate()
hip.hipEventRecord(start)
cb()
if enable:
hip.hipEventRecord(end)
hip.hipEventSynchronize(end)
ret = hip.hipEventElapsedTime(start, end)*1e-3
hip.hipEventDestroy(start)
hip.hipEventDestroy(end)
return ret
class HIPProgram:
def __init__(self, name:str, prg:bytes):
self.modules, self.prgs = [], []
self.modules, self.prgs, self.c_struct_t = [], [], None
if DEBUG >= 6:
asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
@ -63,20 +78,63 @@ class HIPProgram:
def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
hip.hipSetDevice(args[0]._device)
if wait:
start, end = hip.hipEventCreate(), hip.hipEventCreate()
hip.hipEventRecord(start)
struct = hip.getStructTypeForArgs(*args)(*[data._buf if not isinstance(data, int) else np.int32(data) for data in args])
hip.hipModuleLaunchKernel(self.prgs[args[0]._device], global_size[0], global_size[1], global_size[2], local_size[0], local_size[1], local_size[2], 0, 0, struct)
if wait:
hip.hipEventRecord(end)
hip.hipEventSynchronize(end)
ret = hip.hipEventElapsedTime(start, end)*1e-3
hip.hipEventDestroy(start)
hip.hipEventDestroy(end)
return ret
if self.c_struct_t is None: self.c_struct_t = hip.getCStructForType([(ctypes.c_void_p if not isinstance(x, int) else ctypes.c_int) for x in args])
c_params = cast(Callable, self.c_struct_t)(*[x._buf if not isinstance(x, int) else x for x in args])
return time_execution(lambda: hip.hipModuleLaunchKernel(self.prgs[args[0]._device], *global_size, *local_size, 0, 0, c_params), enable=wait)
def __del__(self):
for module in self.modules: hip.hipModuleUnload(module)
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize)
class HIPGraph:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
# TODO: Only HIPProgram can be captured for now.
if not all(isinstance(ji.prg, CompiledASTRunner) and isinstance(ji.prg.clprg, HIPProgram) for ji in jit_cache): raise GraphException
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
self.graph, graph_node = hip.hipGraphCreate(), None
self.updatable_nodes: Dict[int, Tuple[Any, hip.kernelNodeParamsWrapper]] = {} # Dict[jc index] = tuple(graph_node, node_params)
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for j,ji in enumerate(self.jit_cache):
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
assert all(x is not None for x in ji.rawbufs) and ji.rawbufs[0] is not None, "buffers could not be None" # for linters
args = [cast(RawBuffer, x)._buf for x in ji.rawbufs] + [var_vals[x] for x in prg.vars]
types = [ctypes.c_void_p] * len(ji.rawbufs) + [ctypes.c_int] * len(prg.vars)
c_params = hip.buildKernelNodeParams(args, types, prg.clprg.prgs[ji.rawbufs[0]._device], *prg.launch_dims(var_vals))
graph_node = hip.hipGraphAddKernelNode(self.graph, [graph_node] if graph_node else [], c_params)
if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
self.updatable_nodes[j] = (graph_node, c_params)
self.instance = hip.hipGraphInstantiate(self.graph)
def __del__(self):
hip.hipGraphExecDestroy(self.instance)
hip.hipGraphDestroy(self.graph)
def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# Update cached params structs with the new values.
for (j,i),input_idx in self.input_replace.items():
hip.setKernelNodeParams(self.updatable_nodes[j][1], [input_rawbuffers[input_idx]._buf], [i])
for j in self.jc_idxs_with_updatable_launch_dims:
hip.setKernelNodeLaunchDims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))
for j in self.jc_idxs_with_updatable_var_vals:
prg: CompiledASTRunner = cast(CompiledASTRunner, self.jit_cache[j].prg)
hip.setKernelNodeParams(self.updatable_nodes[j][1], [var_vals[x] for x in prg.vars], list(range(len(self.jit_cache[j].rawbufs), len(self.jit_cache[j].rawbufs) + len(prg.vars))))
# Update graph nodes with the updated structs.
for node, params in self.updatable_nodes.values():
hip.hipGraphExecKernelNodeSetParams(self.instance, node, params)
et = time_execution(lambda: hip.hipGraphLaunch(self.instance), enable=wait)
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
return et
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize, graph=HIPGraph)

View file

@ -1,13 +1,13 @@
import os, subprocess, pathlib, ctypes, tempfile
import Metal, libdispatch
from typing import List, Any, Tuple, Dict, Set, cast, Optional
from typing import List, Any, Tuple, Dict, cast, Optional
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
from tinygrad.ops import Compiled, CompiledASTRunner, update_stats
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
from tinygrad.shape.symbolic import Variable, Node
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, GraphException
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs):
@ -88,6 +88,7 @@ class MetalGraph:
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
# create metal batch exec
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
@ -99,7 +100,6 @@ class MetalGraph:
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
self.input_has_variable_dims: Set[int] = set()
read_resources, write_resources = [], []
for j,ji in enumerate(self.jit_cache):
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
@ -117,11 +117,8 @@ class MetalGraph:
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
global_size, local_size = prg.launch_dims(var_vals)
assert prg.global_size and prg.local_size, "need global and local size to JIT"
if any(isinstance(x, Node) for x in prg.global_size) or any(isinstance(x, Node) for x in prg.local_size):
self.input_has_variable_dims.add(j)
else:
if j not in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = prg.launch_dims(var_vals)
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
icb_command.setBarrier()
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
@ -134,7 +131,7 @@ class MetalGraph:
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
for (j,i),input_idx in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
for j in self.input_has_variable_dims:
for j in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
self.int_buf_view[:] = list(var_vals.values())