This commit is contained in:
George Hotz 2026-02-03 18:09:15 +08:00
commit 308eb13eae
3 changed files with 13 additions and 11 deletions

View file

@ -74,7 +74,7 @@ param_to_ptr = PatternMatcher([
])
def resolve_call(c:UOp) -> UOp:
if c.src[0].op is Ops.SINK:
if c.src[0].op in {Ops.SINK, Ops.PROGRAM}:
# CALL is KERNEL...sort of
return UOp(Ops.KERNEL, src=c.src[1:], arg=Kernel(graph_rewrite(c.src[0], param_to_ptr)))
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)

View file

@ -58,6 +58,17 @@ shared_spec = PatternMatcher([
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
])
# ***** UOp spec in the Tensor graph *****
@ -266,16 +277,6 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),
(UPat(Ops.INDEX), lambda: True),

View file

@ -109,6 +109,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u.op is Ops.KERNEL:
ast_str = f"SINK{tuple(s.op for s in u.arg.ast.src)}" if u.arg.ast.op is Ops.SINK else repr(u.arg.ast.op)
argst = f"<Kernel {len(list(u.arg.ast.toposort()))} {ast_str} {[str(m) for m in u.arg.metadata]}>"
if u.op is Ops.BINARY: argst = f"<{len(u.arg)} bytes>"
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):