minor hip cleanups (#3237)

This commit is contained in:
George Hotz 2024-01-24 15:13:38 -08:00 committed by GitHub
commit a8fbb03438
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 16 additions and 16 deletions

View file

@ -196,12 +196,8 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
inputs: List[LazyBuffer] = []
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
if out.op is LoadOps.COPY:
op, inputs = LazyOp(LoadOps.COPY, (), out.srcs[0]), list(out.srcs)
elif out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT}:
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
elif out.op is LoadOps.EMPTY:
op = LazyOp(LoadOps.EMPTY)
else:
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})

View file

@ -25,19 +25,17 @@ def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
assert all(si.out.device == x.device for x in si.inputs) or si.ast.op in {LoadOps.COPY, LoadOps.WAIT}, \
f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
if si.ast.op is LoadOps.EMPTY: return None
if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY} and si.out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
from tinygrad.runtime.ops_hip import HIPSyncEvent, HIPWaitEvent
if si.ast.op is LoadOps.SYNC: return HIPSyncEvent(si.out)
if si.ast.op is LoadOps.WAIT: return HIPWaitEvent(si.out.device)
if si.ast.op is LoadOps.COPY:
if hasattr(Device[si.out.device].allocator, 'transfer') and type(Device[si.out.device]) is type(Device[si.inputs[0].device]): return BufferXfer()
if si.inputs[0].device.startswith("DISK"): return BufferRead()
return BufferCopy()
if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
# TODO: this doesn't have to be only HIP, check if it has the event functions
if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT} and si.out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
from tinygrad.runtime.ops_hip import SyncEvent, WaitEvent
if si.ast.op is LoadOps.SYNC: return SyncEvent(si.out)
if si.ast.op is LoadOps.WAIT: return WaitEvent(si.out.device)
else:
if si.ast.op is LoadOps.SYNC: return SyncOp(si.out.device) if isinstance(Device[si.out.device], Compiled) else None
if si.ast.op is LoadOps.WAIT: return None
if si.ast.op is LoadOps.SYNC: return SyncOp(si.out.device) if isinstance(Device[si.out.device], Compiled) else None
if si.ast.op is LoadOps.WAIT: return None
return Device[si.out.device].get_runner(si.ast)
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None

View file

@ -1,6 +1,6 @@
from __future__ import annotations
import ctypes, functools, subprocess, io
from typing import Tuple, TypeVar, List, Any, cast
from typing import Tuple, TypeVar, List, Any, cast, Set
import gpuctypes.hip as hip
from tinygrad.helpers import DEBUG, getenv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
from tinygrad.helpers import from_mv, round_up, to_mv, colored
@ -107,6 +107,7 @@ class HIPDevice(Compiled):
self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() if not MOCKHIP else "gfx1100" # noqa: E501
self.pending_copyin: List[hip.hipDeviceptr_t] = []
self.track_cross_buffer: List[Any] = []
self.peers: Set[int] = set()
from tinygrad.runtime.graph.hip import HIPGraph
super().__init__(device, MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
@ -117,8 +118,13 @@ class HIPDevice(Compiled):
for opaque in self.pending_copyin: check(hip.hipFree(opaque))
self.track_cross_buffer.clear()
self.pending_copyin.clear()
def enable_peer(self, dnum):
if self.device == dnum or dnum in self.peers: return
check(hip.hipSetDevice(self.device))
check(hip.hipDeviceEnablePeerAccess(dnum, 0))
self.peers.add(dnum)
class SyncEvent(JITRunner):
class HIPSyncEvent(JITRunner):
def __init__(self, lb):
self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
super().__init__()
@ -128,7 +134,7 @@ class SyncEvent(JITRunner):
check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, device=self.dname)
class WaitEvent(JITRunner):
class HIPWaitEvent(JITRunner):
def __init__(self, device):
self.device, self.dname = cast(HIPDevice, Device[device]), device
super().__init__()