mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
call_is_ke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a29986f074 |
6 changed files with 17 additions and 10 deletions
|
|
@ -52,7 +52,7 @@ def flip_contract_kernel(dest:UOp, src:UOp):
|
||||||
j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
|
j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
|
||||||
vec = src[i, j].contract(j)
|
vec = src[i, j].contract(j)
|
||||||
store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
|
store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
|
||||||
return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
|
return store.end(i, j).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
|
||||||
|
|
||||||
def slice_sum_kernel(dest:UOp, src:UOp):
|
def slice_sum_kernel(dest:UOp, src:UOp):
|
||||||
G = UOp.range(src.shape[0], 0)
|
G = UOp.range(src.shape[0], 0)
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,8 @@ from hypothesis import assume, given, settings, strategies as strat
|
||||||
from tinygrad import nn, dtypes, Device, Tensor, Variable
|
from tinygrad import nn, dtypes, Device, Tensor, Variable
|
||||||
from tinygrad.device import is_dtype_supported
|
from tinygrad.device import is_dtype_supported
|
||||||
from tinygrad.dtype import DType, ImageDType
|
from tinygrad.dtype import DType, ImageDType
|
||||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
|
||||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||||
from tinygrad.schedule.rangeify import Kernel
|
|
||||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||||
|
|
||||||
class KernelCountException(Exception): pass
|
class KernelCountException(Exception): pass
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from tinygrad import Tensor, Device, dtypes
|
||||||
from tinygrad.device import is_dtype_supported
|
from tinygrad.device import is_dtype_supported
|
||||||
from tinygrad.dtype import DType, DTYPES_DICT
|
from tinygrad.dtype import DType, DTYPES_DICT
|
||||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||||
from tinygrad.helpers import Timing, fetch, OSX
|
from tinygrad.helpers import Timing, fetch, OSX, dedup
|
||||||
from test.helpers import slow
|
from test.helpers import slow
|
||||||
|
|
||||||
class TempDirTestCase(unittest.TestCase):
|
class TempDirTestCase(unittest.TestCase):
|
||||||
|
|
@ -172,7 +172,7 @@ class TestSafetensors(TempDirTestCase):
|
||||||
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
|
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
|
||||||
|
|
||||||
def test_save_all_dtypes(self):
|
def test_save_all_dtypes(self):
|
||||||
for dtype in DTYPES_DICT.values():
|
for dtype in dedup(DTYPES_DICT.values()):
|
||||||
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
|
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
|
||||||
if not is_dtype_supported(dtype): continue
|
if not is_dtype_supported(dtype): continue
|
||||||
path = self.tmp(f"ones.{dtype}.safetensors")
|
path = self.tmp(f"ones.{dtype}.safetensors")
|
||||||
|
|
|
||||||
|
|
@ -161,7 +161,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||||
rctx = IndexingContext()
|
rctx = IndexingContext()
|
||||||
|
|
||||||
# get ops to realize
|
# get ops to realize
|
||||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")
|
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
|
||||||
|
|
||||||
# get the consumer map
|
# get the consumer map
|
||||||
with cpu_profile("consumer map in rangeify", "TINY"):
|
with cpu_profile("consumer map in rangeify", "TINY"):
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,10 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
|
||||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||||
|
|
||||||
def resolve_call(c:UOp) -> UOp:
|
def resolve_call(c:UOp) -> UOp|None:
|
||||||
|
# don't resolve real kernel calls, sink or program
|
||||||
|
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||||
|
if c.src[0].op is Ops.PROGRAM: return None
|
||||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||||
args = c.src[1:]
|
args = c.src[1:]
|
||||||
# TODO: this check belongs in spec, not here
|
# TODO: this check belongs in spec, not here
|
||||||
|
|
@ -515,7 +518,8 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
||||||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
|
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
|
||||||
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
||||||
|
|
||||||
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
|
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
|
||||||
|
kernel_arg = Kernel(ret, metadata)
|
||||||
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
|
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
|
||||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
|
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
|
||||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
|
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
|
||||||
|
|
@ -579,7 +583,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||||
# bufferize -> store
|
# bufferize -> store
|
||||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
||||||
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
|
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
|
||||||
|
|
||||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||||
kernel_assign: dict[UOp, UOp] = {}
|
kernel_assign: dict[UOp, UOp] = {}
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,8 @@ def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||||
ret: dict[UOp, dict[UOp, None]] = {}
|
ret: dict[UOp, dict[UOp, None]] = {}
|
||||||
for u in lst:
|
for u in lst:
|
||||||
ret[u] = {}
|
ret[u] = {}
|
||||||
for s in u.src: ret[s][u] = None
|
for s in u.src:
|
||||||
|
if s in ret: ret[s][u] = None
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def pretty_print(x:UOp, cache=None, d=0)->str:
|
def pretty_print(x:UOp, cache=None, d=0)->str:
|
||||||
|
|
@ -310,6 +311,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
def ended_ranges(self):
|
def ended_ranges(self):
|
||||||
if self.op in range_start: return self.src[range_start[self.op]:]
|
if self.op in range_start: return self.src[range_start[self.op]:]
|
||||||
if self.op is Ops.AFTER: return tuple(flatten([x.ended_ranges for x in self.src[1:]]))
|
if self.op is Ops.AFTER: return tuple(flatten([x.ended_ranges for x in self.src[1:]]))
|
||||||
|
# TODO: copy isn't using range properly and isn't ending the range it uses, remove this
|
||||||
|
if self.op in {Ops.COPY, Ops.BUFFER_VIEW}: return self.src[0].ranges
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
# determine what ranges this is in
|
# determine what ranges this is in
|
||||||
|
|
@ -819,6 +822,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||||
|
|
||||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||||
|
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call"
|
||||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue