keep CONST/BUFFER uops in tensor_map [pr] (#9083)

This commit is contained in:
qazal 2025-02-14 14:50:08 +02:00 committed by GitHub
commit 82ad0d2e65
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -66,10 +66,6 @@ sym = symbolic_simple+PatternMatcher([
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
remove_movement_ops = merge_views+PatternMatcher([
@ -139,7 +135,7 @@ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs)
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.buf_uop, x) for x in s.src if x.base.op not in {Ops.CONST,Ops.BIND,Ops.BUFFER})),
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
# realize before expand or unsafe pad ops