mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
move_simp_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e07dff384 |
2 changed files with 20 additions and 2 deletions
18
extra/tvm_compute_at.py
Normal file
18
extra/tvm_compute_at.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
import tvm
|
||||
from tvm import te, tir
|
||||
|
||||
n = te.var("n")
|
||||
A = te.placeholder((n,), name="A")
|
||||
B = te.compute((n,), lambda i: A[i] + 1, name="B") # producer
|
||||
C = te.compute((n,), lambda i: B[i] * 2, name="C") # consumer
|
||||
|
||||
prim_func = te.create_prim_func([A, C])
|
||||
ir_mod = tvm.IRModule({"main": prim_func})
|
||||
sch = tir.Schedule(ir_mod, debug_mask="all")
|
||||
|
||||
blk_B = sch.get_block("B", func_name="main")
|
||||
blk_C = sch.get_block("C", func_name="main")
|
||||
i_loop = sch.get_loops(blk_C)[0]
|
||||
sch.compute_at(blk_B, i_loop)
|
||||
|
||||
print(sch.mod.script())
|
||||
|
|
@ -162,6 +162,8 @@ merge_views = PatternMatcher([
|
|||
# only unmaksed VIEW on CONST replaces the ShapeTracker
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
|
||||
# simplify views
|
||||
(UPat(Ops.VIEW, src=(UPat.var('x'),), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
|
||||
])
|
||||
|
||||
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
|
||||
|
|
@ -411,8 +413,6 @@ finalize_contiguous = PatternMatcher([
|
|||
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
|
||||
# merge contiguous
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
|
||||
# simplify views
|
||||
(UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
|
||||
])
|
||||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue