mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
call_is_ke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a29986f074 |
6 changed files with 17 additions and 10 deletions
|
|
@ -52,7 +52,7 @@ def flip_contract_kernel(dest:UOp, src:UOp):
|
|||
j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
|
||||
vec = src[i, j].contract(j)
|
||||
store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
|
||||
return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
|
||||
return store.end(i, j).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
|
||||
|
||||
def slice_sum_kernel(dest:UOp, src:UOp):
|
||||
G = UOp.range(src.shape[0], 0)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,8 @@ from hypothesis import assume, given, settings, strategies as strat
|
|||
from tinygrad import nn, dtypes, Device, Tensor, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.schedule.rangeify import Kernel
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from tinygrad import Tensor, Device, dtypes
|
|||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, DTYPES_DICT
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import Timing, fetch, OSX
|
||||
from tinygrad.helpers import Timing, fetch, OSX, dedup
|
||||
from test.helpers import slow
|
||||
|
||||
class TempDirTestCase(unittest.TestCase):
|
||||
|
|
@ -172,7 +172,7 @@ class TestSafetensors(TempDirTestCase):
|
|||
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
|
||||
|
||||
def test_save_all_dtypes(self):
|
||||
for dtype in DTYPES_DICT.values():
|
||||
for dtype in dedup(DTYPES_DICT.values()):
|
||||
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
|
||||
if not is_dtype_supported(dtype): continue
|
||||
path = self.tmp(f"ones.{dtype}.safetensors")
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
rctx = IndexingContext()
|
||||
|
||||
# get ops to realize
|
||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")
|
||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
|
||||
|
||||
# get the consumer map
|
||||
with cpu_profile("consumer map in rangeify", "TINY"):
|
||||
|
|
|
|||
|
|
@ -68,7 +68,10 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
|
|||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||
|
||||
def resolve_call(c:UOp) -> UOp:
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
if c.src[0].op is Ops.PROGRAM: return None
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
|
|
@ -515,7 +518,8 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
|||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
|
||||
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
||||
|
||||
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
|
||||
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
|
||||
kernel_arg = Kernel(ret, metadata)
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
|
||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
|
||||
|
|
@ -579,7 +583,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
# bufferize -> store
|
||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
|
||||
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
|
||||
|
||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||
kernel_assign: dict[UOp, UOp] = {}
|
||||
|
|
|
|||
|
|
@ -67,7 +67,8 @@ def consumer_map_from_toposort(lst:Iterable[UOp]):
|
|||
ret: dict[UOp, dict[UOp, None]] = {}
|
||||
for u in lst:
|
||||
ret[u] = {}
|
||||
for s in u.src: ret[s][u] = None
|
||||
for s in u.src:
|
||||
if s in ret: ret[s][u] = None
|
||||
return ret
|
||||
|
||||
def pretty_print(x:UOp, cache=None, d=0)->str:
|
||||
|
|
@ -310,6 +311,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
def ended_ranges(self):
|
||||
if self.op in range_start: return self.src[range_start[self.op]:]
|
||||
if self.op is Ops.AFTER: return tuple(flatten([x.ended_ranges for x in self.src[1:]]))
|
||||
# TODO: copy isn't using range properly and isn't ending the range it uses, remove this
|
||||
if self.op in {Ops.COPY, Ops.BUFFER_VIEW}: return self.src[0].ranges
|
||||
return ()
|
||||
|
||||
# determine what ranges this is in
|
||||
|
|
@ -819,6 +822,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call"
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue