mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
minor hip cleanups (#3237)
This commit is contained in:
parent
3205fd8481
commit
a8fbb03438
3 changed files with 16 additions and 16 deletions
|
|
@ -196,12 +196,8 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
|
|||
|
||||
inputs: List[LazyBuffer] = []
|
||||
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
|
||||
if out.op is LoadOps.COPY:
|
||||
op, inputs = LazyOp(LoadOps.COPY, (), out.srcs[0]), list(out.srcs)
|
||||
elif out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT}:
|
||||
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
|
||||
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
|
||||
elif out.op is LoadOps.EMPTY:
|
||||
op = LazyOp(LoadOps.EMPTY)
|
||||
else:
|
||||
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
|
||||
op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
|
||||
|
|
|
|||
|
|
@ -25,19 +25,17 @@ def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
|
|||
assert all(si.out.device == x.device for x in si.inputs) or si.ast.op in {LoadOps.COPY, LoadOps.WAIT}, \
|
||||
f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
|
||||
if si.ast.op is LoadOps.EMPTY: return None
|
||||
if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY} and si.out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
|
||||
from tinygrad.runtime.ops_hip import HIPSyncEvent, HIPWaitEvent
|
||||
if si.ast.op is LoadOps.SYNC: return HIPSyncEvent(si.out)
|
||||
if si.ast.op is LoadOps.WAIT: return HIPWaitEvent(si.out.device)
|
||||
if si.ast.op is LoadOps.COPY:
|
||||
if hasattr(Device[si.out.device].allocator, 'transfer') and type(Device[si.out.device]) is type(Device[si.inputs[0].device]): return BufferXfer()
|
||||
if si.inputs[0].device.startswith("DISK"): return BufferRead()
|
||||
return BufferCopy()
|
||||
if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
|
||||
# TODO: this doesn't have to be only HIP, check if it has the event functions
|
||||
if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT} and si.out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
|
||||
from tinygrad.runtime.ops_hip import SyncEvent, WaitEvent
|
||||
if si.ast.op is LoadOps.SYNC: return SyncEvent(si.out)
|
||||
if si.ast.op is LoadOps.WAIT: return WaitEvent(si.out.device)
|
||||
else:
|
||||
if si.ast.op is LoadOps.SYNC: return SyncOp(si.out.device) if isinstance(Device[si.out.device], Compiled) else None
|
||||
if si.ast.op is LoadOps.WAIT: return None
|
||||
if si.ast.op is LoadOps.SYNC: return SyncOp(si.out.device) if isinstance(Device[si.out.device], Compiled) else None
|
||||
if si.ast.op is LoadOps.WAIT: return None
|
||||
return Device[si.out.device].get_runner(si.ast)
|
||||
|
||||
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import ctypes, functools, subprocess, io
|
||||
from typing import Tuple, TypeVar, List, Any, cast
|
||||
from typing import Tuple, TypeVar, List, Any, cast, Set
|
||||
import gpuctypes.hip as hip
|
||||
from tinygrad.helpers import DEBUG, getenv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style
|
||||
from tinygrad.helpers import from_mv, round_up, to_mv, colored
|
||||
|
|
@ -107,6 +107,7 @@ class HIPDevice(Compiled):
|
|||
self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() if not MOCKHIP else "gfx1100" # noqa: E501
|
||||
self.pending_copyin: List[hip.hipDeviceptr_t] = []
|
||||
self.track_cross_buffer: List[Any] = []
|
||||
self.peers: Set[int] = set()
|
||||
|
||||
from tinygrad.runtime.graph.hip import HIPGraph
|
||||
super().__init__(device, MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
|
||||
|
|
@ -117,8 +118,13 @@ class HIPDevice(Compiled):
|
|||
for opaque in self.pending_copyin: check(hip.hipFree(opaque))
|
||||
self.track_cross_buffer.clear()
|
||||
self.pending_copyin.clear()
|
||||
def enable_peer(self, dnum):
|
||||
if self.device == dnum or dnum in self.peers: return
|
||||
check(hip.hipSetDevice(self.device))
|
||||
check(hip.hipDeviceEnablePeerAccess(dnum, 0))
|
||||
self.peers.add(dnum)
|
||||
|
||||
class SyncEvent(JITRunner):
|
||||
class HIPSyncEvent(JITRunner):
|
||||
def __init__(self, lb):
|
||||
self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
|
||||
super().__init__()
|
||||
|
|
@ -128,7 +134,7 @@ class SyncEvent(JITRunner):
|
|||
check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
|
||||
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, device=self.dname)
|
||||
|
||||
class WaitEvent(JITRunner):
|
||||
class HIPWaitEvent(JITRunner):
|
||||
def __init__(self, device):
|
||||
self.device, self.dname = cast(HIPDevice, Device[device]), device
|
||||
super().__init__()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue