Compare commits

...

4 commits

Author SHA1 Message Date
wozeparrot
e2e69dfe51
Merge branch 'master' into default_late_allreduce 2026-03-10 01:10:25 +08:00
wozeparrot
d7c915874a
Merge branch 'master' into default_late_allreduce 2026-03-05 17:53:11 +08:00
George Hotz
091443349c fix spec 2026-03-04 18:05:00 +08:00
George Hotz
6a912250c7 make late allreduce the default 2026-03-04 17:49:12 +08:00
3 changed files with 3 additions and 3 deletions

View file

@ -37,7 +37,7 @@ replace_allreduce = PatternMatcher([
_early_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
])
if not getenv("LATE_ALLREDUCE", 0): replace_allreduce = _early_allreduce + replace_allreduce
if not getenv("LATE_ALLREDUCE", 1): replace_allreduce = _early_allreduce + replace_allreduce
# ***** multi functions *****

View file

@ -1519,7 +1519,7 @@ pm_pyrender_extra = PatternMatcher([
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
# TODO: index shouldn't mismatch dtype
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+''.join([f"{ctx[xx]}, " for xx in x.src[2:]])+
(f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
# TODO: movement ops simplify stuff, this can break SPEC=2
#(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),

View file

@ -84,7 +84,7 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE, Ops.NOOP)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
# BUFFER_VIEW on BUFFER is allowed if BUFFER is