mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
codegen
This commit is contained in:
parent
1b1d81e3d5
commit
308eb13eae
3 changed files with 13 additions and 11 deletions
|
|
@ -74,7 +74,7 @@ param_to_ptr = PatternMatcher([
|
|||
])
|
||||
|
||||
def resolve_call(c:UOp) -> UOp:
|
||||
if c.src[0].op is Ops.SINK:
|
||||
if c.src[0].op in {Ops.SINK, Ops.PROGRAM}:
|
||||
# CALL is KERNEL...sort of
|
||||
return UOp(Ops.KERNEL, src=c.src[1:], arg=Kernel(graph_rewrite(c.src[0], param_to_ptr)))
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
|
|
|
|||
|
|
@ -58,6 +58,17 @@ shared_spec = PatternMatcher([
|
|||
|
||||
# RANGE/SPECIAL define loops, END closes them
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE))), lambda: True),
|
||||
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
])
|
||||
|
||||
# ***** UOp spec in the Tensor graph *****
|
||||
|
|
@ -266,16 +277,6 @@ full_spec = PatternMatcher([
|
|||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
(UPat(Ops.INDEX), lambda: True),
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
if u.op is Ops.KERNEL:
|
||||
ast_str = f"SINK{tuple(s.op for s in u.arg.ast.src)}" if u.arg.ast.op is Ops.SINK else repr(u.arg.ast.op)
|
||||
argst = f"<Kernel {len(list(u.arg.ast.toposort()))} {ast_str} {[str(m) for m in u.arg.metadata]}>"
|
||||
if u.op is Ops.BINARY: argst = f"<{len(u.arg)} bytes>"
|
||||
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
|
||||
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
|
||||
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue