Compare commits

...

6 commits

Author SHA1 Message Date
George Hotz
2c8867c3e9 all cpu tests pass local 2026-05-26 19:28:46 -07:00
George Hotz
c88941a957 bugfix 2026-05-26 19:19:44 -07:00
George Hotz
760acdf554 really fix llvm 2026-05-26 19:16:48 -07:00
George Hotz
f54a1d70d6 unneeded 2026-05-26 19:05:04 -07:00
George Hotz
5d3713a6be fix llvm 2026-05-26 18:47:34 -07:00
George Hotz
d0f956341c replace INDEX with SLICE 2026-05-26 18:37:03 -07:00
12 changed files with 52 additions and 33 deletions

View file

@ -73,10 +73,10 @@ class TestIdxUpcast(unittest.TestCase):
# Assert the dtype of the INDEX value, This will need be updated if UOp spec changes # Assert the dtype of the INDEX value, This will need be updated if UOp spec changes
store = next(uop for uop in uops if uop.op is Ops.STORE) store = next(uop for uop in uops if uop.op is Ops.STORE)
assert store.op is Ops.STORE assert store.op is Ops.STORE
idx = self._find_op(store, Ops.INDEX) idx = self._find_op(store, Ops.SLICE)
# PTX and NIR turn Ops.INDEX into pointer arithmetic earlier than cstyle, plus it's already cast to int64 # PTX and NIR turn Ops.SLICE into pointer arithmetic earlier than cstyle, plus it's already cast to int64
if not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)): if not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)):
assert idx.op is Ops.INDEX assert idx.op is Ops.SLICE
idx_val = idx.src[1] idx_val = idx.src[1]
self.assertIs(idx_val.dtype, dtype) self.assertIs(idx_val.dtype, dtype)

View file

@ -538,7 +538,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))]) uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
ld0 = uops[-2].src[-1] # -2 to skip SINK ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1 # the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int)) self.assertEqual(ld0, UOp.load(glbl2.slice(idx), dtype=dtypes.int))
def test_fold_gated_load_local(self): def test_fold_gated_load_local(self):
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0) glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
@ -552,7 +552,9 @@ class TestUOpGraph(unittest.TestCase):
ld0 = uops[-2].src[-1] # -2 to skip SINK ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1 # the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True)) new_barrier = ld0.src[0].src[0].src[1]
assert new_barrier.op is Ops.BARRIER
self.assertEqual(ld0.src[0], smem.after(new_barrier).slice(lidx+2))
def test_fold_gated_store(self): def test_fold_gated_store(self):
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0) glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
@ -564,7 +566,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([st0, st1]) uops = to_uops_list([st0, st1])
# only the second store happens # only the second store happens
self.assertEqual(len(uops), 6) # +1 for SINK self.assertEqual(len(uops), 6) # +1 for SINK
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK self.assertEqual(uops[-2], glbl.slice(idx1).store(val)) # -2 to skip SINK
@unittest.skip("this is a uop type error") @unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self): def test_asserts_bad_gate(self):

View file

@ -115,7 +115,7 @@ pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph # if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")), (UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF # gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))), (UPat(Ops.STORE, name="u", src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))])) lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
]) ])

View file

@ -283,6 +283,12 @@ pm_render = PatternMatcher([
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None), (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x), (UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
# rewrite non-image INDEX to SLICE
(UPat(Ops.INDEX, name="x"), lambda x: None if isinstance(x.src[0].dtype, ImageDType) else \
UOp(Ops.SLICE, dtype=x.dtype, src=x.src, arg=0 if x.dtype.count == 1 else x.dtype.count)),
# rewrite CAST on SLICE to just SLICE
(UPat(Ops.SLICE, name="bv").cast(name="x"),
lambda bv,x: bv.replace(dtype=x.dtype, arg=0 if x.dtype.count == 1 else x.dtype.count))
]) ])
# *** Ops.REDUCE -> Ops.DEFINE_ACC *** # *** Ops.REDUCE -> Ops.DEFINE_ACC ***

View file

@ -29,8 +29,8 @@ class Estimates:
def range_gate(x): return x.op is not Ops.RANGE def range_gate(x): return x.op is not Ops.RANGE
for u in uops: for u in uops:
if u.op in {Ops.LOAD, Ops.STORE}: if u.op in {Ops.LOAD, Ops.STORE}:
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER # if u.src[0] is SLICE, we have to include the buffer since it might be an AFTER
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate)) dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.SLICE else u.src[0]).toposort(range_gate))
# TODO: is this correct? this all needs to be cleaned up # TODO: is this correct? this all needs to be cleaned up
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort()) if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
elif u.op is Ops.IF: elif u.op is Ops.IF:
@ -40,7 +40,7 @@ class Estimates:
buf = u buf = u
while len(buf.src): buf = buf.src[0] while len(buf.src): buf = buf.src[0]
if buf.op is Ops.PARAM: if buf.op is Ops.PARAM:
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul) # u.src[0] is SLICE, cap at buffer size for re-reads (e.g. matmul)
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
if u.op is Ops.RANGE: if u.op is Ops.RANGE:

View file

@ -43,8 +43,10 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"), (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
# default const render # default const render
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)), (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# slice is ptr arithmetic
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var('idx')), name="x"),
lambda ctx,buf,idx,x: ctx.render_cast(x.dtype, f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})")),
# new load/store # new load/store
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"), (UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), (UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
@ -190,14 +192,14 @@ class CStyleLanguage(Renderer):
else: else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast", Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu") Ops.SLICE: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}" r[u] = f"{prefix}{c[prefix]}"
l = cast(str, self.string_rewrite.rewrite(u, ctx=self)) l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}" assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1 if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \ if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.SLICE, Ops.CUSTOMI} or \
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \ (u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \ (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))): (u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):

View file

@ -74,8 +74,9 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
base_rewrite = PatternMatcher([ base_rewrite = PatternMatcher([
# memory load/store # memory load/store
(UPat(Ops.INDEX, name="x"), lambda ctx,x: (UPat(Ops.SLICE, name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"), f" {ctx[x]}_o = getelementptr inbounds {ldt(x.src[0].dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}\n"
f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x]}_o to {ldt(x.dtype)}"),
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"), (UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
lambda ctx,x,idx,alt,mask: lambda ctx,x,idx,alt,mask:
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n" f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"

View file

@ -90,7 +90,7 @@ class WGSLRenderer(CStyleLanguage):
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1] # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \ f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
else f"{ctx[b]} = {ctx[v]};"), else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))), (UPat(Ops.SLICE, src=(UPat.var("b"), UPat.var("idx"))),
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
]) + base_rewrite ]) + base_rewrite

View file

@ -101,16 +101,18 @@ class PythonProgram:
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp] elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[i] = [arg] * warp_size elif uop is Ops.CONST: values[i] = [arg] * warp_size
elif uop is Ops.INDEX: elif uop is Ops.SLICE:
assert len(src_values) == 2, "non-image index must be 2 srcs"
ret:list = [] ret:list = []
if isinstance(src_dtypes[0], ImageDType): for m,o in zip(*src_values): ret.append((m,o))
assert len(src_values) == 3, "image index must be 3 srcs" values[i] = ret
for m,oy,ox in zip(*src_values): elif uop is Ops.INDEX:
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None)) assert isinstance(src_dtypes[0], ImageDType), "only image INDEX is supported"
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4)) ret = []
else: assert len(src_values) == 3, "image index must be 3 srcs"
assert len(src_values) == 2, "non-image index must be 2 srcs" for m,oy,ox in zip(*src_values):
for m,o in zip(*src_values): ret.append((m,o)) if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
values[i] = ret values[i] = ret
elif uop is Ops.CAST and isinstance(dtype, PtrDType): elif uop is Ops.CAST and isinstance(dtype, PtrDType):
values[i] = src_values[0] values[i] = src_values[0]

View file

@ -267,6 +267,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.BINARY: return (len(self.arg),) case Ops.BINARY: return (len(self.arg),)
case Ops.BUFFER: return (self.arg,) case Ops.BUFFER: return (self.arg,)
case Ops.SLICE: case Ops.SLICE:
if self.arg == 0: return ()
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX # HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
if self.src[0].op is Ops.INDEX: return () if self.src[0].op is Ops.INDEX: return ()
return (self.arg,) return (self.arg,)
@ -459,6 +460,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None])) return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs): def vectorize(self, *srcs):
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs) return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
def slice(self, offset:UOp|int, size:int=0):
return UOp(Ops.SLICE, self.dtype, (self, offset if isinstance(offset, UOp) else UOp.const(dtypes.int, offset)), arg=size)
def index(self, *srcs:UOp|None, ptr=False, **kwargs): def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): def __getitem__(self, idx):
@ -1074,9 +1077,8 @@ class ProgramInfo:
for u in sink.toposort(): for u in sink.toposort():
if u.op is Ops.DEFINE_VAR: _vars.append(u) if u.op is Ops.DEFINE_VAR: _vars.append(u)
if u.op is Ops.PARAM: _globals.append(u.arg) if u.op is Ops.PARAM: _globals.append(u.arg)
if u.op in (Ops.STORE, Ops.LOAD): if u.op in (Ops.STORE, Ops.LOAD) and (idx:=u.src[0]).op in (Ops.INDEX, Ops.SLICE) and (buf:=idx.src[0]).op is Ops.PARAM:
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX): (outs if u.op is Ops.STORE else ins).append(buf.arg)
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg)
if u.op is Ops.SPECIAL: if u.op is Ops.SPECIAL:
if u.arg[0] == 'i': local_size = None if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size special_size = local_size if u.arg[0] == 'l' else global_size

View file

@ -99,11 +99,11 @@ spec_shared = PatternMatcher([
(UPat(Ops.INS), lambda: True), (UPat(Ops.INS), lambda: True),
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE # LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
(UPat(Ops.INDEX, name="uidx").or_casted().load(), validate_index), (UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(), validate_index),
(UPat(Ops.INDEX, name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"), (UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
lambda uidx,gate,alt,load: validate_index(uidx, gate) if alt.dtype == load.dtype else False), lambda uidx,gate,alt,load: validate_index(uidx, gate) if alt.dtype == load.dtype else False),
(UPat(Ops.INDEX, name="uidx").or_casted().store(UPat()), validate_index), (UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat()), validate_index),
(UPat(Ops.INDEX, name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index), (UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
# STORE in tensor graph: store a value into a target # STORE in tensor graph: store a value into a target
(UPat(Ops.STORE, dtypes.void, (UPat(name="x"), UPat())), lambda x: True), (UPat(Ops.STORE, dtypes.void, (UPat(name="x"), UPat())), lambda x: True),
@ -200,6 +200,10 @@ spec_program = PatternMatcher([
# weakint is not allowed in programs # weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False), (UPat(GroupOp.All, dtypes.weakint), lambda: False),
# buffer view in program, Image only for ImageDType
(UPat(Ops.SLICE), lambda: True),
(UPat(Ops.INDEX, name="idx"), lambda idx: isinstance(idx.src[0].dtype, ImageDType)),
# movement ops are not allowed in programs # movement ops are not allowed in programs
(UPat(GroupOp.Movement), lambda: False), (UPat(GroupOp.Movement), lambda: False),

View file

@ -50,7 +50,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.SLICE: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6", Ops.SLICE: "#a2c148", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040", Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
Ops.LINEAR: "#7DF4FF", Ops.LINEAR: "#7DF4FF",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",