Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
1e07dff384 move simplify views to merge views 2025-07-28 13:55:23 -07:00
2 changed files with 20 additions and 2 deletions

18
extra/tvm_compute_at.py Normal file
View file

@ -0,0 +1,18 @@
import tvm
from tvm import te, tir
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1, name="B") # producer
C = te.compute((n,), lambda i: B[i] * 2, name="C") # consumer
prim_func = te.create_prim_func([A, C])
ir_mod = tvm.IRModule({"main": prim_func})
sch = tir.Schedule(ir_mod, debug_mask="all")
blk_B = sch.get_block("B", func_name="main")
blk_C = sch.get_block("C", func_name="main")
i_loop = sch.get_loops(blk_C)[0]
sch.compute_at(blk_B, i_loop)
print(sch.mod.script())

View file

@ -162,6 +162,8 @@ merge_views = PatternMatcher([
# only unmaksed VIEW on CONST replaces the ShapeTracker
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
# simplify views
(UPat(Ops.VIEW, src=(UPat.var('x'),), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
])
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
@ -411,8 +413,6 @@ finalize_contiguous = PatternMatcher([
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
# merge contiguous
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
# simplify views
(UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
])
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])