This commit is contained in:
George Hotz 2026-03-04 13:01:35 +08:00
commit ccb5dcf3b8

View file

@ -136,7 +136,7 @@ def soft_allreduce(c:UOp, a:UOp):
to = c.src[1].param_like(0)
src = c.src[2].param_like(1)
red = UOp(Ops.ALLREDUCE, dtype=a.arg, src=(src, a.src[1]), arg=a.arg)
return handle_allreduce(src, red).assign(to).sink().call(*c.src[1:])
return to.assign(handle_allreduce(src, red)).sink().call(*c.src[1:])
pm_schedule = PatternMatcher([
(UPat(Ops.SINK, name="function"), lower_sink_to_linear),