mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
308eb13eae | ||
|
|
1b1d81e3d5 | ||
|
|
50b7b283dc | ||
|
|
39da624581 |
5 changed files with 26 additions and 13 deletions
|
|
@ -2,7 +2,7 @@ import time
|
|||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec, kernel_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
|
|
@ -144,7 +144,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
if SPEC: type_verify(big_sink, tensor_spec+kernel_spec)
|
||||
|
||||
# hack to preserve metadata
|
||||
graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx=({}, {}), name="preserve metadata")
|
||||
|
|
|
|||
|
|
@ -68,7 +68,15 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
|
|||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||
|
||||
param_to_ptr = PatternMatcher([
|
||||
(UPat(Ops.PARAM, name="x"), lambda x:
|
||||
None if isinstance(x.dtype, PtrDType) else x.replace(src=(), dtype=x.dtype.ptr(size=x.size)).reshape(x.shape)),
|
||||
])
|
||||
|
||||
def resolve_call(c:UOp) -> UOp:
|
||||
if c.src[0].op in {Ops.SINK, Ops.PROGRAM}:
|
||||
# CALL is KERNEL...sort of
|
||||
return UOp(Ops.KERNEL, src=c.src[1:], arg=Kernel(graph_rewrite(c.src[0], param_to_ptr)))
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
|
|
|
|||
|
|
@ -248,7 +248,10 @@ class Tensor(OpMixin):
|
|||
|
||||
This API is alpha and may change.
|
||||
"""
|
||||
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||
contig_srcs = tuple(x.contiguous() if x.uop.op is not Ops.AFTER else x for x in ((self,)+lst))
|
||||
params = [x.as_param(i) for i,x in enumerate(contig_srcs)]
|
||||
kernel = UOp.call(*[x.uop for x in contig_srcs], fxn=fxn(*[x.uop for x in params]), arg=grad_fxn)
|
||||
return [Tensor(s.uop.after(kernel), device=s.device) for s in contig_srcs]
|
||||
|
||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -58,6 +58,17 @@ shared_spec = PatternMatcher([
|
|||
|
||||
# RANGE/SPECIAL define loops, END closes them
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE))), lambda: True),
|
||||
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
])
|
||||
|
||||
# ***** UOp spec in the Tensor graph *****
|
||||
|
|
@ -266,16 +277,6 @@ full_spec = PatternMatcher([
|
|||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
(UPat(Ops.INDEX), lambda: True),
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
if u.op is Ops.KERNEL:
|
||||
ast_str = f"SINK{tuple(s.op for s in u.arg.ast.src)}" if u.arg.ast.op is Ops.SINK else repr(u.arg.ast.op)
|
||||
argst = f"<Kernel {len(list(u.arg.ast.toposort()))} {ast_str} {[str(m) for m in u.arg.metadata]}>"
|
||||
if u.op is Ops.BINARY: argst = f"<{len(u.arg)} bytes>"
|
||||
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
|
||||
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
|
||||
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue