mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
load_view_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4689d29f49 | ||
|
|
07f45be2a7 | ||
|
|
064c8e1b8b |
2 changed files with 10 additions and 12 deletions
|
|
@ -119,6 +119,8 @@ def check_load_st(glbl:UOp, view:UOp):
|
|||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
|
||||
fix_kernel_ops = view_left_through_load+PatternMatcher([
|
||||
# add view to LOAD
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
|
|
|
|||
|
|
@ -149,18 +149,15 @@ create_kernels = PatternMatcher([
|
|||
|
||||
# **** fix kernel AST
|
||||
|
||||
early_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st).load()),
|
||||
# no SINK for meta ops
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||
])
|
||||
|
||||
replace_globals = PatternMatcher([
|
||||
replace_buffers = PatternMatcher([
|
||||
# replace ASSIGN with the target BUFFER
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
|
||||
(UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
|
||||
# HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).load()),
|
||||
# no SINK for meta ops
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||
])
|
||||
|
||||
def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||
|
|
@ -173,14 +170,13 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
|||
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
||||
bufs.append(s)
|
||||
# replace global memory ops with the BUFFER they write to
|
||||
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
||||
ast = graph_rewrite(ast, early_buffer_ops, bufs, bottom_up=True, name="replace buffer early")
|
||||
ast = graph_rewrite(k.arg.ast, replace_buffers, bufs, bottom_up=True, name="replace buffers")
|
||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||
# TODO: move these to codegen
|
||||
ast = graph_rewrite(ast, view_left, name="Main View Left")
|
||||
ast = graph_rewrite(ast, view_right, name="Main View Right")
|
||||
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="replace buffer")
|
||||
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="Finalize Kernel")
|
||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||
|
||||
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue