Compare commits

...

1 commit

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| George Hotz | a29986f074 | small changes and test fixes from kernel is call | 2026-02-06 16:08:02 +08:00 |
6 changed files with 17 additions and 10 deletions

View file

@@ -52,7 +52,7 @@ def flip_contract_kernel(dest:UOp, src:UOp):
j = UOp.range(dest.shape[1], 1, AxisType.UPCAST) j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
vec = src[i, j].contract(j) vec = src[i, j].contract(j)
store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)]) store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=())) return store.end(i, j).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
def slice_sum_kernel(dest:UOp, src:UOp): def slice_sum_kernel(dest:UOp, src:UOp):
G = UOp.range(src.shape[0], 0) G = UOp.range(src.shape[0], 0)

View file

@@ -10,9 +10,8 @@ from hypothesis import assume, given, settings, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor, Variable from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType from tinygrad.dtype import DType, ImageDType
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.schedule.rangeify import Kernel
from tinygrad.engine.realize import CompiledRunner, run_schedule from tinygrad.engine.realize import CompiledRunner, run_schedule
class KernelCountException(Exception): pass class KernelCountException(Exception): pass

View file

@@ -4,7 +4,7 @@ from tinygrad import Tensor, Device, dtypes
from tinygrad.device import is_dtype_supported from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, DTYPES_DICT from tinygrad.dtype import DType, DTYPES_DICT
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, OSX from tinygrad.helpers import Timing, fetch, OSX, dedup
from test.helpers import slow from test.helpers import slow
class TempDirTestCase(unittest.TestCase): class TempDirTestCase(unittest.TestCase):
@@ -172,7 +172,7 @@ class TestSafetensors(TempDirTestCase):
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world' assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
def test_save_all_dtypes(self): def test_save_all_dtypes(self):
for dtype in DTYPES_DICT.values(): for dtype in dedup(DTYPES_DICT.values()):
if dtype in [dtypes.bfloat16]: continue # not supported in numpy if dtype in [dtypes.bfloat16]: continue # not supported in numpy
if not is_dtype_supported(dtype): continue if not is_dtype_supported(dtype): continue
path = self.tmp(f"ones.{dtype}.safetensors") path = self.tmp(f"ones.{dtype}.safetensors")

View file

@@ -161,7 +161,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
rctx = IndexingContext() rctx = IndexingContext()
# get ops to realize # get ops to realize
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize") graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
# get the consumer map # get the consumer map
with cpu_profile("consumer map in rangeify", "TINY"): with cpu_profile("consumer map in rangeify", "TINY"):

View file

@@ -68,7 +68,10 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)] placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders))) return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
def resolve_call(c:UOp) -> UOp: def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
if c.src[0].op is Ops.PROGRAM: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg) params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:] args = c.src[1:]
# TODO: this check belongs in spec, not here # TODO: this check belongs in spec, not here
@@ -515,7 +518,8 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts)) else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]) metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
kernel_arg = Kernel(ret, metadata)
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg) kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]): if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}") raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
@@ -579,7 +583,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# bufferize -> store # bufferize -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1 lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store") tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels") tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {} kernel_assign: dict[UOp, UOp] = {}

View file

@@ -67,7 +67,8 @@ def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {} ret: dict[UOp, dict[UOp, None]] = {}
for u in lst: for u in lst:
ret[u] = {} ret[u] = {}
for s in u.src: ret[s][u] = None for s in u.src:
if s in ret: ret[s][u] = None
return ret return ret
def pretty_print(x:UOp, cache=None, d=0)->str: def pretty_print(x:UOp, cache=None, d=0)->str:
@@ -310,6 +311,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def ended_ranges(self): def ended_ranges(self):
if self.op in range_start: return self.src[range_start[self.op]:] if self.op in range_start: return self.src[range_start[self.op]:]
if self.op is Ops.AFTER: return tuple(flatten([x.ended_ranges for x in self.src[1:]])) if self.op is Ops.AFTER: return tuple(flatten([x.ended_ranges for x in self.src[1:]]))
# TODO: copy isn't using range properly and isn't ending the range it uses, remove this
if self.op in {Ops.COPY, Ops.BUFFER_VIEW}: return self.src[0].ranges
return () return ()
# determine what ranges this is in # determine what ranges this is in
@@ -819,6 +822,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.PARAM, dtype, src, arg=slot) return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp: def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call"
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata)) return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]: def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs) contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)