add device name to device, all are constructed (#3221)

This commit is contained in:
George Hotz 2024-01-23 20:34:56 -08:00 committed by GitHub
commit 23b084e70a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 22 additions and 21 deletions

View file

@ -25,9 +25,7 @@ class _Device:
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]:
x = ix.split(":")[0].upper()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0] # noqa: E501
if isinstance(ret, type): ret = ret(ix)
return ret
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
@functools.cached_property
def DEFAULT(self) -> str:
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
@ -189,8 +187,8 @@ class InterpretedASTRunner(JITRunner):
return et
class Interpreted:
def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
self.allocator, self.fxn_for_op = allocator, fxn_for_op
def __init__(self, device:str, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
self.dname, self.allocator, self.fxn_for_op = device, allocator, fxn_for_op
self.synchronize, self.codegen, self.graph = lambda: None, None, None
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
@ -305,21 +303,20 @@ class CompiledASTRunner(JITRunner):
if local_size: lra['local_size'] = local_size
et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit,
lra=lra, device=rawbufs[0].device, first_run=self.first_run)
lra=lra, device=self.device.dname, first_run=self.first_run)
self.first_run = False
return et
class Compiled:
def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, compiler_cachekey, runtime, graph=None):
self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph, self.compiler_cachekey = \
allocator, linearizer_opts, renderer, compiler, runtime, graph, None if getenv("DISABLE_COMPILER_CACHE") else compiler_cachekey
def __init__(self, device:str, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, compiler_cachekey, runtime, graph=None):
self.dname, self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph, self.compiler_cachekey = \
device, allocator, linearizer_opts, renderer, compiler, runtime, graph, None if getenv("DISABLE_COMPILER_CACHE") else compiler_cachekey
def synchronize(self): pass # override this in your device
def to_program(self, k:Linearizer) -> CompiledASTRunner:
assert self.compiler is not None, f"compiler is None, can't build {k.ast}"
k.linearize()
src = self.renderer(to_function_name(k.name), k.uops)
return CompiledASTRunner(k.ast, k.name, src, self, k.global_size, k.local_size)
return CompiledASTRunner(k.ast, k.name, self.renderer(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)
def get_linearizer(self, ast:LazyOp) -> Linearizer:
if DEBUG >= 3:

View file

@ -22,5 +22,7 @@ class ClangProgram:
def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait)
ClangDevice = Compiled(MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False),
functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict")), compile_clang, "compile_clang", ClangProgram)
class ClangDevice(Compiled):
def __init__(self, device:str):
super().__init__(device, MallocAllocator, LinearizerOptions("CLANG", supports_float4=False, has_local=False),
functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict")), compile_clang, "compile_clang", ClangProgram)

View file

@ -42,4 +42,5 @@ class NumpyAllocator(Allocator):
def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
CPUDevice = Interpreted(NumpyAllocator(), numpy_fxn_for_op)
class CPUDevice(Interpreted):
def __init__(self, device:str): super().__init__(device, NumpyAllocator(), numpy_fxn_for_op)

View file

@ -83,7 +83,7 @@ class CUDADevice(Compiled):
self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
from tinygrad.runtime.graph.cuda import CUDAGraph
super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator,
super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
LinearizerOptions("CUDA", supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
CUDARenderer, functools.partial(compile_cuda,arch=self.arch), f"compile_cuda_{self.arch}", functools.partial(CUDAProgram, self),
graph=CUDAGraph if not CUDACPU else None)

View file

@ -54,4 +54,4 @@ class DiskAllocator(Allocator):
dest[:] = src._buf()
class DiskDevice(Interpreted):
def __init__(self, device:str): super().__init__(DiskAllocator(device[len("disk:"):]), disk_fxn_for_op)
def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), disk_fxn_for_op)

View file

@ -94,7 +94,7 @@ class CLDevice(Compiled):
self.pending_copyin: List[memoryview] = []
compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
super().__init__(CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer,
super().__init__(device, CLAllocator(self), LinearizerOptions("GPU"), OpenCLRenderer,
functools.partial(compile_cl, self), f"compile_cl_{compile_key}", functools.partial(CLProgram, self))
def synchronize(self):
check(cl.clFinish(self.queue))

View file

@ -104,7 +104,7 @@ class HIPDevice(Compiled):
self.pending_events: List[hip.hipEvent_t] = []
from tinygrad.runtime.graph.hip import HIPGraph
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
super().__init__(device, MallocAllocator if MOCKHIP else HIPAllocator(self), LinearizerOptions("HIP"), HIPRenderer,
functools.partial(compile_hip,arch=self.arch), f"compile_hip_{self.arch}", functools.partial(HIPProgram, self.device), HIPGraph)
def synchronize(self):
check(hip.hipSetDevice(self.device))

View file

@ -38,5 +38,5 @@ class LLVMDevice(Compiled):
backing_mod = llvm.parse_assembly(str())
backing_mod.triple = llvm.get_process_triple()
self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
super().__init__(MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
super().__init__(device, MallocAllocator, LinearizerOptions("LLVM", supports_float4=False, has_local=False, has_shared=False),
uops_to_llvm_ir, functools.partial(compile_llvm, self), "compile_llvm", functools.partial(LLVMProgram, self))

View file

@ -78,7 +78,7 @@ class MetalDevice(Compiled):
self.mtl_buffers_in_flight: List[Any] = []
self.mv_in_metal: List[memoryview] = []
from tinygrad.runtime.graph.metal import MetalGraph
super().__init__(MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
super().__init__(device, MetalAllocator(self), LinearizerOptions("METAL"), MetalRenderer,
compile_metal_xcode if getenv("METAL_XCODE") else functools.partial(compile_metal, self.device), "compile_metal",
functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
def synchronize(self):

View file

@ -42,4 +42,5 @@ class TorchAllocator(Allocator):
def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())
TorchDevice = Interpreted(TorchAllocator(), torch_fxn_for_op)
class TorchDevice(Interpreted):
def __init__(self, device:str): super().__init__(device, TorchAllocator(), torch_fxn_for_op)