mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
default_la
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2e69dfe51 |
||
|
|
d7c915874a |
||
|
|
091443349c | ||
|
|
6a912250c7 |
3 changed files with 3 additions and 3 deletions
|
|
@ -37,7 +37,7 @@ replace_allreduce = PatternMatcher([
|
|||
_early_allreduce = PatternMatcher([
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
|
||||
])
|
||||
if not getenv("LATE_ALLREDUCE", 0): replace_allreduce = _early_allreduce + replace_allreduce
|
||||
if not getenv("LATE_ALLREDUCE", 1): replace_allreduce = _early_allreduce + replace_allreduce
|
||||
|
||||
# ***** multi functions *****
|
||||
|
||||
|
|
|
|||
|
|
@ -1519,7 +1519,7 @@ pm_pyrender_extra = PatternMatcher([
|
|||
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
|
||||
# TODO: index shouldn't mismatch dtype
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
|
||||
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
|
||||
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+''.join([f"{ctx[xx]}, " for xx in x.src[2:]])+
|
||||
(f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
|
||||
# TODO: movement ops simplify stuff, this can break SPEC=2
|
||||
#(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ _tensor_spec = PatternMatcher([
|
|||
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
|
||||
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
|
||||
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
|
||||
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE, Ops.NOOP)), UPat(Ops.DEVICE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
|
||||
# BUFFER_VIEW on BUFFER is allowed if BUFFER is
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue