Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
a29986f074 small changes and test fixes from kernel is call 2026-02-06 16:08:02 +08:00
6 changed files with 17 additions and 10 deletions

View file

@ -52,7 +52,7 @@ def flip_contract_kernel(dest:UOp, src:UOp):
j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
vec = src[i, j].contract(j)
store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
return store.end(i, j).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
def slice_sum_kernel(dest:UOp, src:UOp):
G = UOp.range(src.shape[0], 0)

View file

@ -10,9 +10,8 @@ from hypothesis import assume, given, settings, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.schedule.rangeify import Kernel
from tinygrad.engine.realize import CompiledRunner, run_schedule
class KernelCountException(Exception): pass

View file

@ -4,7 +4,7 @@ from tinygrad import Tensor, Device, dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, DTYPES_DICT
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import Timing, fetch, OSX
from tinygrad.helpers import Timing, fetch, OSX, dedup
from test.helpers import slow
class TempDirTestCase(unittest.TestCase):
@ -172,7 +172,7 @@ class TestSafetensors(TempDirTestCase):
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
def test_save_all_dtypes(self):
for dtype in DTYPES_DICT.values():
for dtype in dedup(DTYPES_DICT.values()):
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
if not is_dtype_supported(dtype): continue
path = self.tmp(f"ones.{dtype}.safetensors")

View file

@ -161,7 +161,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
rctx = IndexingContext()
# get ops to realize
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
# get the consumer map
with cpu_profile("consumer map in rangeify", "TINY"):

View file

@ -68,7 +68,10 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
def resolve_call(c:UOp) -> UOp:
def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
if c.src[0].op is Ops.PROGRAM: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here
@ -515,7 +518,8 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
kernel_arg = Kernel(ret, metadata)
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
@ -579,7 +583,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# bufferize -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}

View file

@ -67,7 +67,8 @@ def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {}
for u in lst:
ret[u] = {}
for s in u.src: ret[s][u] = None
for s in u.src:
if s in ret: ret[s][u] = None
return ret
def pretty_print(x:UOp, cache=None, d=0)->str:
@ -310,6 +311,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def ended_ranges(self):
if self.op in range_start: return self.src[range_start[self.op]:]
if self.op is Ops.AFTER: return tuple(flatten([x.ended_ranges for x in self.src[1:]]))
# TODO: copy isn't using range properly and isn't ending the range it uses, remove this
if self.op in {Ops.COPY, Ops.BUFFER_VIEW}: return self.src[0].ranges
return ()
# determine what ranges this is in
@ -819,6 +822,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call"
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)