Make dev a property of Allocator (#10286)

* Make `dev` a property of `Allocator`

(this is a prereq refactor for #10285)

At least `BufferXfer.copy` accesses it assuming it's always present.
Currently, most devices just add this property on their own, repeating
the same code over and over again.

This is also a bit of a footgun: see `RemoteAllocator`, which named this
property `device` instead of `dev`. I could obviously just change that
in one place, but doing it globally seems like a better solution (and it
reduces code duplication too).

`MallocAllocator` is a bit special, but passing `None` works just fine.

* typing

* ignore type instead of cast
This commit is contained in:
uuuvn 2025-05-14 05:01:01 +05:00 committed by GitHub
commit 7bc4864bc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 52 additions and 69 deletions

View file

@ -4,7 +4,7 @@ from examples.llama import Transformer, MODEL_PARAMS
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.nn.state import get_state_dict
from tinygrad.device import Allocator
from tinygrad.device import Allocator, Compiled
from tinygrad.engine.realize import method_cache
from tinygrad.helpers import Profiling
@ -12,7 +12,7 @@ class FakeProgram:
def __init__(self, name:str, prg:bytes): pass
def __call__(self, *bufs, global_size, local_size, vals=(), wait=False): pass
class FakeAllocator(Allocator):
class FakeAllocator(Allocator[Compiled]):
def _alloc(self, sz, options): return None
def _copyin(self, dest, src:memoryview): pass
@ -22,7 +22,7 @@ class TestLLaMASpeed(unittest.TestCase):
backup_allocator = Device[Device.DEFAULT].allocator
backup_compiler = Device[Device.DEFAULT].compiler
Device[Device.DEFAULT].runtime = FakeProgram
Device[Device.DEFAULT].allocator = FakeAllocator()
Device[Device.DEFAULT].allocator = FakeAllocator(Device.default)
print("testing llama python run time")
model = Transformer(**MODEL_PARAMS["1"]["7B"]["args"])

View file

@ -10,7 +10,7 @@ from tinygrad.dtype import ConstType, DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import T, unwrap, CI
from tinygrad.codegen import full_rewrite
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler
def derandomize_model(model):
for p in get_parameters(model):
@ -52,7 +52,7 @@ def timeit(fxn:Callable[..., T], *args, **kwargs) -> tuple[T, float]:
return ret, (time.perf_counter_ns()-st)*1e-6
def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
allocator = PythonAllocator()
allocator = Device['PYTHON'].allocator
bufs = []
for buf_dt, data in inputs or []:
bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Optional, Any, Iterator, Generator
from typing import Optional, Any, Generic, TypeVar, Iterator, Generator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE
@ -184,8 +184,11 @@ class Buffer:
if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
return Buffer(self.device, size, dtype, base=self, offset=offset)
DeviceType = TypeVar('DeviceType', bound='Compiled')
# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
class Allocator(Generic[DeviceType]):
def __init__(self, dev:DeviceType): self.dev: DeviceType = dev
# overridden in LRUAllocator
def alloc(self, size:int, options:Optional[BufferSpec]=None):
assert size > 0, f"alloc size must be positive, getting {size}"
@ -201,12 +204,14 @@ class Allocator:
# def _offset(self, buf, size:int, offset:int):
# def _transfer(self, dest, src, sz:int, src_dev, dest_dev):
class LRUAllocator(Allocator):
class LRUAllocator(Allocator, Generic[DeviceType]):
"""
The LRU Allocator is responsible for caching buffers.
It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
"""
def __init__(self): self.cache: dict[tuple[int, Optional[BufferSpec]], Any] = defaultdict(list)
def __init__(self, dev:DeviceType):
self.cache: dict[tuple[int, Optional[BufferSpec]], Any] = defaultdict(list)
super().__init__(dev)
def alloc(self, size:int, options:Optional[BufferSpec]=None):
if len(c := self.cache[(size, options)]): return c.pop()
try: return super().alloc(size, options)
@ -221,7 +226,7 @@ class LRUAllocator(Allocator):
if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
else: super().free(opaque, size, options)
class _MallocAllocator(LRUAllocator):
class _MallocAllocator(LRUAllocator['Compiled']):
def _alloc(self, size:int, options:BufferSpec):
# must be aligned to 0x20 for 256-bit ymm registers
# TODO: investigate if this is the cause of nondeterminism in speed
@ -236,7 +241,7 @@ class _MallocAllocator(LRUAllocator):
def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size])
MallocAllocator = _MallocAllocator()
MallocAllocator = _MallocAllocator(None) # type: ignore
# NOTE: MAP_JIT is added to mmap module in python 3.13
MAP_JIT = 0x0800

View file

@ -62,10 +62,7 @@ class CUDAProgram:
for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, self.smem, None, None, self.vargs)), enable=wait)
class CUDAAllocator(LRUAllocator):
def __init__(self, dev:CUDADevice):
self.dev = dev
super().__init__()
class CUDAAllocator(LRUAllocator['CUDADevice']):
def _alloc(self, size, options:BufferSpec):
check(cuda.cuCtxSetCurrent(self.dev.context))
if options.external_ptr: return cuda.CUdeviceptr_v2(options.external_ptr)

View file

@ -94,11 +94,7 @@ class DSPBuffer:
def __init__(self, va_addr:int, size:int, share_info, offset:int=0):
self.va_addr, self.size, self.share_info, self.offset = va_addr, size, share_info, offset
class DSPAllocator(Allocator):
def __init__(self, dev:DSPDevice):
self.dev = dev
super().__init__()
class DSPAllocator(Allocator['DSPDevice']):
def _alloc(self, size:int, options:BufferSpec):
b = qcom_dsp.ION_IOC_ALLOC(self.dev.ion_fd, len=size, align=0x200, heap_id_mask=1<<qcom_dsp.ION_SYSTEM_HEAP_ID, flags=qcom_dsp.ION_FLAG_CACHED)
share_info = qcom_dsp.ION_IOC_SHARE(self.dev.ion_fd, handle=b.handle)

View file

@ -58,10 +58,7 @@ class CLProgram:
return float(end.value-start.value) * OSX_TIMING_RATIO * 1e-9
return None
class CLAllocator(LRUAllocator):
def __init__(self, dev:CLDevice):
self.dev = dev
super().__init__()
class CLAllocator(LRUAllocator['CLDevice']):
def _alloc(self, size:int, options:BufferSpec) -> tuple[ctypes._CData, BufferSpec]:
if options.image is not None:
return (checked(cl.clCreateImage2D(self.dev.context, cl.CL_MEM_READ_WRITE,

View file

@ -50,10 +50,7 @@ class HIPProgram:
check(hip.hipEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), self.dev.time_event_st, self.dev.time_event_en))
return ret.value * 1e-3
class HIPAllocator(LRUAllocator):
def __init__(self, dev:HIPDevice):
self.dev = dev
super().__init__()
class HIPAllocator(LRUAllocator[HIPDevice]):
def _alloc(self, size:int, options:BufferSpec):
check(hip.hipSetDevice(self.dev.device_id))
return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))

View file

@ -189,10 +189,7 @@ class MetalProgram:
class MetalBuffer:
def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
class MetalAllocator(LRUAllocator):
def __init__(self, dev:MetalDevice):
self.dev:MetalDevice = dev
super().__init__()
class MetalAllocator(LRUAllocator[MetalDevice]):
def _alloc(self, size:int, options) -> MetalBuffer:
if options.external_ptr: return MetalBuffer(objc_id(options.external_ptr), size)

View file

@ -2,9 +2,9 @@ import numpy as np
from tinygrad.helpers import flat_mv
from tinygrad.device import Compiled, Allocator
class NpyAllocator(Allocator):
class NpyAllocator(Allocator['NpyDevice']):
def _as_buffer(self, src:np.ndarray) -> memoryview: return flat_mv(np.require(src, requirements='C').data)
def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = self._as_buffer(src)
class NpyDevice(Compiled):
def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None, None)
def __init__(self, device:str): super().__init__(device, NpyAllocator(self), None, None, None)

View file

@ -13,8 +13,7 @@ class NullProgram:
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
return 1e-4
class NullAllocator(Allocator):
dev = None
class NullAllocator(Allocator['NullDevice']):
def _alloc(self, size, options): pass
def _copyin(self, dest, src:memoryview): pass
def _copyout(self, dest:memoryview, src): pass
@ -24,4 +23,4 @@ class NullGraph(MultiGraphRunner):
def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-3
class NullDevice(Compiled):
def __init__(self, device:str): super().__init__(device, NullAllocator(), NullRenderer(), Compiler(), NullProgram, NullGraph)
def __init__(self, device:str): super().__init__(device, NullAllocator(self), NullRenderer(), Compiler(), NullProgram, NullGraph)

View file

@ -216,10 +216,10 @@ class PythonRenderer(Renderer):
class PythonCompiler(Compiler):
def compile(self, src:str) -> bytes: return base64.b64decode(src)
class PythonAllocator(Allocator):
class PythonAllocator(Allocator['PythonDevice']):
def _alloc(self, size, options): return memoryview(bytearray(size))
def _copyin(self, dest, src:memoryview): dest[:] = src
def _copyout(self, dest:memoryview, src): dest[:] = src
class PythonDevice(Compiled):
def __init__(self, device:str): super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
def __init__(self, device:str): super().__init__(device, PythonAllocator(self), PythonRenderer(), PythonCompiler(), PythonProgram)

View file

@ -196,21 +196,18 @@ def remote_server(port:int):
# ***** frontend *****
class RemoteAllocator(Allocator):
def __init__(self, dev:RemoteDevice):
self.device = dev
super().__init__()
class RemoteAllocator(Allocator['RemoteDevice']):
# TODO: ideally we shouldn't have to deal with images here
def _alloc(self, size:int, options:BufferSpec) -> int:
self.device.buffer_num += 1
self.device.q(BufferAlloc(self.device.buffer_num, size, options))
return self.device.buffer_num
self.dev.buffer_num += 1
self.dev.q(BufferAlloc(self.dev.buffer_num, size, options))
return self.dev.buffer_num
# TODO: options should not be here in any Allocator
def _free(self, opaque:int, options): self.device.q(BufferFree(opaque))
def _copyin(self, dest:int, src:memoryview): self.device.q(CopyIn(dest, self.device.conn.req.h(bytes(src))))
def _free(self, opaque:int, options): self.dev.q(BufferFree(opaque))
def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(bytes(src))))
def _copyout(self, dest:memoryview, src:int):
self.device.q(CopyOut(src))
resp = self.device.conn.batch_submit()
self.dev.q(CopyOut(src))
resp = self.dev.conn.batch_submit()
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
dest[:] = resp

View file

@ -175,8 +175,7 @@ class WebGPUProgram:
return time
return None
class WebGpuAllocator(Allocator):
def __init__(self, dev:WGPUDevPtr): self.dev = dev
class WebGpuAllocator(Allocator['WGPUDevPtr']):
def _alloc(self, size:int, options:BufferSpec) -> WGPUBufPtr:
# WebGPU buffers have to be 4-byte aligned
return webgpu.wgpuDeviceCreateBuffer(self.dev, webgpu.WGPUBufferDescriptor(size=round_up(size, 4),

View file

@ -51,7 +51,7 @@ if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockFileIOInterf
# **************** for HCQ Compatible Devices ****************
SignalType = TypeVar('SignalType', bound='HCQSignal')
DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQCompiled')
ProgramType = TypeVar('ProgramType', bound='HCQProgram')
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
QueueType = TypeVar('QueueType', bound='HWQueue')
@ -65,14 +65,14 @@ class BumpAllocator:
self.ptr = (res:=round_up(self.ptr, alignment)) + size
return res + self.base
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
"""
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
"""
def __init__(self):
self._q:Any = []
self.binded_device:DeviceType|None = None
self.binded_device:HCQDeviceType|None = None
self.q_sints:list[tuple[int, int]] = []
self.mv_sints:list[tuple[MMIOInterface, int, int, int|None]] = []
self.syms:list[sint] = []
@ -158,7 +158,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
# *** submit and bind commands ***
def bind(self, dev:DeviceType):
def bind(self, dev:HCQDeviceType):
"""
Associates the queue with a specific device for optimized execution.
@ -197,7 +197,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
self._prev_resolved_syms = cast(list[int|None], resolved_syms)
def submit(self, dev:DeviceType, var_vals:dict[Variable, int]|None=None):
def submit(self, dev:HCQDeviceType, var_vals:dict[Variable, int]|None=None):
"""
Submits the command queue to a specific device for execution.
@ -208,15 +208,15 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
if var_vals is not None: self._apply_var_vals(var_vals)
self._submit(dev)
return self
def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")
def _submit(self, dev:HCQDeviceType): raise NotImplementedError("need _submit")
class HCQSignal(Generic[DeviceType]):
def __init__(self, base_buf:HCQBuffer|None=None, value:int=0, dev_t:Type[DeviceType]|None=None, timeline_for_device:DeviceType|None=None,
class HCQSignal(Generic[HCQDeviceType]):
def __init__(self, base_buf:HCQBuffer|None=None, value:int=0, dev_t:Type[HCQDeviceType]|None=None, timeline_for_device:HCQDeviceType|None=None,
timestamp_divider=1, value_off=0, timestamp_off=8):
self.base_buf = cast(HCQBuffer, dev_t._alloc_signal() if dev_t is not None and base_buf is None else base_buf)
self.value_addr, self.timestamp_addr, self.dev_t = self.base_buf.va_addr+value_off, self.base_buf.va_addr+timestamp_off, dev_t
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
self.timeline_for_device:DeviceType|None = timeline_for_device
self.timeline_for_device:HCQDeviceType|None = timeline_for_device
if isinstance(self.base_buf.va_addr, int):
self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(value_off, 8, 'Q'), self.base_buf.cpu_view().view(timestamp_off, 8, 'Q')
@ -296,8 +296,8 @@ class CLikeArgsState(HCQArgsState[ProgramType]):
self.bind_sints_to_buf(*[b.va_addr for b in bufs], buf=self.buf, fmt='Q', offset=len(prefix or []) * 4)
self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)
class HCQProgram(Generic[DeviceType]):
def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int, lib:bytes|None=None, base:int|None=None):
class HCQProgram(Generic[HCQDeviceType]):
def __init__(self, args_state_t:Type[HCQArgsState], dev:HCQDeviceType, name:str, kernargs_alloc_size:int, lib:bytes|None=None, base:int|None=None):
self.args_state_t, self.dev, self.name, self.kernargs_alloc_size = args_state_t, dev, name, kernargs_alloc_size
if PROFILE: Compiled.profile_events += [ProfileProgramEvent(dev.device, name, lib, base)]
@ -429,23 +429,22 @@ class HCQBuffer:
assert self.view is not None, "buffer has no cpu_view"
return self.view
class HCQAllocatorBase(LRUAllocator, Generic[DeviceType]):
class HCQAllocatorBase(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
"""
A base allocator class compatible with the HCQ (Hardware Command Queue) API.
This class implements basic copy operations following the HCQ API, utilizing both types of `HWQueue`.
"""
def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32, copy_bufs=None, max_copyout_size:int|None=None):
self.dev:DeviceType = dev
def __init__(self, dev:HCQDeviceType, batch_size:int=(2 << 20), batch_cnt:int=32, copy_bufs=None, max_copyout_size:int|None=None):
super().__init__(dev)
self.b = copy_bufs or [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
self.b_timeline, self.b_next, self.max_copyout_size = [0] * len(self.b), 0, max_copyout_size
super().__init__()
def map(self, buf:HCQBuffer): pass
def _offset(self, buf, size:int, offset:int) -> HCQBuffer: return buf.offset(offset=offset, size=size)
class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
class HCQAllocator(HCQAllocatorBase, Generic[HCQDeviceType]):
def _copyin(self, dest:HCQBuffer, src:memoryview):
assert self.dev.hw_copy_queue_t is not None
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
@ -488,7 +487,7 @@ class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
self.dev.timeline_signal.wait(self.dev.timeline_value - 1)
dest[i:i+lsize] = self.b[0].cpu_view().view(size=lsize, fmt='B')[:]
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:DeviceType, dest_dev:DeviceType):
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:HCQDeviceType, dest_dev:HCQDeviceType):
cast(HCQAllocator, src_dev.allocator).map(dest)
assert src_dev.hw_copy_queue_t is not None