Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
4689d29f49 rewrite name 2025-08-05 20:28:15 -07:00
George Hotz
07f45be2a7 simpler replace buffers 2025-08-05 20:14:07 -07:00
George Hotz
064c8e1b8b add the load view later 2025-08-05 20:09:09 -07:00
2 changed files with 10 additions and 12 deletions

View file

@ -119,6 +119,8 @@ def check_load_st(glbl:UOp, view:UOp):
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = view_left_through_load+PatternMatcher([
# add view to LOAD
(UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
# STORE (except for meta ops)
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),

View file

@ -149,18 +149,15 @@ create_kernels = PatternMatcher([
# **** fix kernel AST
early_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st).load()),
# no SINK for meta ops
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
])
replace_globals = PatternMatcher([
replace_buffers = PatternMatcher([
# replace ASSIGN with the target BUFFER
(UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
(UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
# HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
(UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).load()),
# no SINK for meta ops
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
])
def fix_kernel_ast(k:UOp) -> UOp|None:
@ -173,14 +170,13 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
bufs.append(s)
# replace global memory ops with the BUFFER they write to
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
ast = graph_rewrite(ast, early_buffer_ops, bufs, bottom_up=True, name="replace buffer early")
ast = graph_rewrite(k.arg.ast, replace_buffers, bufs, bottom_up=True, name="replace buffers")
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
# TODO: move these to codegen
ast = graph_rewrite(ast, view_left, name="Main View Left")
ast = graph_rewrite(ast, view_right, name="Main View Right")
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="replace buffer")
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="Finalize Kernel")
return k.replace(arg=Kernel(ast, k.arg.metadata))
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])