mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add device name to device, all are constructed (#3221)
This commit is contained in:
parent
91a1b2bd7a
commit
23b084e70a
10 changed files with 22 additions and 21 deletions
|
|
@ -25,9 +25,7 @@ class _Device:
|
|||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]:
|
||||
x = ix.split(":")[0].upper()
|
||||
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0] # noqa: E501
|
||||
if isinstance(ret, type): ret = ret(ix)
|
||||
return ret
|
||||
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
|
||||
@functools.cached_property
|
||||
def DEFAULT(self) -> str:
|
||||
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
|
||||
|
|
@ -189,8 +187,8 @@ class InterpretedASTRunner(JITRunner):
|
|||
return et
|
||||
|
||||
class Interpreted:
|
||||
def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
|
||||
self.allocator, self.fxn_for_op = allocator, fxn_for_op
|
||||
def __init__(self, device:str, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
|
||||
self.dname, self.allocator, self.fxn_for_op = device, allocator, fxn_for_op
|
||||
self.synchronize, self.codegen, self.graph = lambda: None, None, None
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
|
|
@ -305,21 +303,20 @@ class CompiledASTRunner(JITRunner):
|
|||
if local_size: lra['local_size'] = local_size
|
||||
et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
|
||||
if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit,
|
||||
lra=lra, device=rawbufs[0].device, first_run=self.first_run)
|
||||
lra=lra, device=self.device.dname, first_run=self.first_run)
|
||||
self.first_run = False
|
||||
return et
|
||||
|
||||
class Compiled:
|
||||
def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, compiler_cachekey, runtime, graph=None):
|
||||
self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph, self.compiler_cachekey = \
|
||||
allocator, linearizer_opts, renderer, compiler, runtime, graph, None if getenv("DISABLE_COMPILER_CACHE") else compiler_cachekey
|
||||
def __init__(self, device:str, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, compiler_cachekey, runtime, graph=None):
|
||||
self.dname, self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph, self.compiler_cachekey = \
|
||||
device, allocator, linearizer_opts, renderer, compiler, runtime, graph, None if getenv("DISABLE_COMPILER_CACHE") else compiler_cachekey
|
||||
def synchronize(self): pass # override this in your device
|
||||
|
||||
def to_program(self, k:Linearizer) -> CompiledASTRunner:
|
||||
assert self.compiler is not None, f"compiler is None, can't build {k.ast}"
|
||||
k.linearize()
|
||||
src = self.renderer(to_function_name(k.name), k.uops)
|
||||
return CompiledASTRunner(k.ast, k.name, src, self, k.global_size, k.local_size)
|
||||
return CompiledASTRunner(k.ast, k.name, self.renderer(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)
|
||||
|
||||
def get_linearizer(self, ast:LazyOp) -> Linearizer:
|
||||
if DEBUG >= 3:
|
||||
|
|
|
|||
|
|
@ -22,5 +22,7 @@ class ClangProgram:
|
|||
|
||||
def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait)
|
||||
|
||||
ClangDevice = Compiled(MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False),
|
||||
functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict")), compile_clang, "compile_clang", ClangProgram)
|
||||
class ClangDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
super().__init__(device, MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False),
|
||||
functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict")), compile_clang, "compile_clang", ClangProgram)
|
||||
|
|
|
|||
|
|
@ -42,4 +42,5 @@ class NumpyAllocator(Allocator):
|
|||
def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
|
||||
def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
|
||||
|
||||
CPUDevice = Interpreted(NumpyAllocator(), numpy_fxn_for_op)
|
||||
class CPUDevice(Interpreted):
|
||||
def __init__(self, device:str): super().__init__(device, NumpyAllocator(), numpy_fxn_for_op)
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class CUDADevice(Compiled):
|
|||
self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
|
||||
|
||||
from tinygrad.runtime.graph.cuda import CUDAGraph
|
||||
super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator,
|
||||
super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
|
||||
LinearizerOptions("CUDA", supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
|
||||
CUDARenderer, functools.partial(compile_cuda,arch=self.arch), f"compile_cuda_{self.arch}", functools.partial(CUDAProgram, self),
|
||||
graph=CUDAGraph if not CUDACPU else None)
|
||||
|
|
|
|||
|
|
@ -54,4 +54,4 @@ class DiskAllocator(Allocator):
|
|||
dest[:] = src._buf()
|
||||
|
||||
class DiskDevice(Interpreted):
|
||||
def __init__(self, device:str): super().__init__(DiskAllocator(device[len("disk:"):]), disk_fxn_for_op)
|
||||
def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), disk_fxn_for_op)
|
||||
|
|
@ -94,7 +94,7 @@ class CLDevice(Compiled):
|
|||
self.pending_copyin: List[memoryview] = []
|
||||
|
||||
compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
|
||||
super().__init__(CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer,
|
||||
super().__init__(device, CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer,
|
||||
functools.partial(compile_cl, self), f"compile_cl_{compile_key}", functools.partial(CLProgram, self))
|
||||
def synchronize(self):
|
||||
check(cl.clFinish(self.queue))
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class HIPDevice(Compiled):
|
|||
self.pending_events: List[hip.hipEvent_t] = []
|
||||
|
||||
from tinygrad.runtime.graph.hip import HIPGraph
|
||||
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
|
||||
super().__init__(device, MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
|
||||
functools.partial(compile_hip,arch=self.arch), f"compile_hip_{self.arch}", functools.partial(HIPProgram, self.device), HIPGraph)
|
||||
def synchronize(self):
|
||||
check(hip.hipSetDevice(self.device))
|
||||
|
|
|
|||
|
|
@ -38,5 +38,5 @@ class LLVMDevice(Compiled):
|
|||
backing_mod = llvm.parse_assembly(str())
|
||||
backing_mod.triple = llvm.get_process_triple()
|
||||
self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
|
||||
super().__init__(MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
|
||||
super().__init__(device, MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
|
||||
uops_to_llvm_ir, functools.partial(compile_llvm, self), "compile_llvm", functools.partial(LLVMProgram, self))
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class MetalDevice(Compiled):
|
|||
self.mtl_buffers_in_flight: List[Any] = []
|
||||
self.mv_in_metal: List[memoryview] = []
|
||||
from tinygrad.runtime.graph.metal import MetalGraph
|
||||
super().__init__(MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
|
||||
super().__init__(device, MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
|
||||
compile_metal_xcode if getenv("METAL_XCODE") else functools.partial(compile_metal, self.device), "compile_metal",
|
||||
functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
|
||||
def synchronize(self):
|
||||
|
|
|
|||
|
|
@ -42,4 +42,5 @@ class TorchAllocator(Allocator):
|
|||
def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
|
||||
def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())
|
||||
|
||||
TorchDevice = Interpreted(TorchAllocator(), torch_fxn_for_op)
|
||||
class TorchDevice(Interpreted):
|
||||
def __init__(self, device:str): super().__init__(device, TorchAllocator(), torch_fxn_for_op)
|
||||
Loading…
Add table
Add a link
Reference in a new issue