mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
6 commits
master
...
index_to_s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c8867c3e9 | ||
|
|
c88941a957 | ||
|
|
760acdf554 | ||
|
|
f54a1d70d6 | ||
|
|
5d3713a6be | ||
|
|
d0f956341c |
12 changed files with 52 additions and 33 deletions
|
|
@ -73,10 +73,10 @@ class TestIdxUpcast(unittest.TestCase):
|
||||||
# Assert the dtype of the INDEX value, This will need be updated if UOp spec changes
|
# Assert the dtype of the INDEX value, This will need be updated if UOp spec changes
|
||||||
store = next(uop for uop in uops if uop.op is Ops.STORE)
|
store = next(uop for uop in uops if uop.op is Ops.STORE)
|
||||||
assert store.op is Ops.STORE
|
assert store.op is Ops.STORE
|
||||||
idx = self._find_op(store, Ops.INDEX)
|
idx = self._find_op(store, Ops.SLICE)
|
||||||
# PTX and NIR turn Ops.INDEX into pointer arithmetic earlier than cstyle, plus it's already cast to int64
|
# PTX and NIR turn Ops.SLICE into pointer arithmetic earlier than cstyle, plus it's already cast to int64
|
||||||
if not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)):
|
if not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)):
|
||||||
assert idx.op is Ops.INDEX
|
assert idx.op is Ops.SLICE
|
||||||
idx_val = idx.src[1]
|
idx_val = idx.src[1]
|
||||||
self.assertIs(idx_val.dtype, dtype)
|
self.assertIs(idx_val.dtype, dtype)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -538,7 +538,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||||
# the gate and invalid value are deleted from ld1
|
# the gate and invalid value are deleted from ld1
|
||||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
|
self.assertEqual(ld0, UOp.load(glbl2.slice(idx), dtype=dtypes.int))
|
||||||
|
|
||||||
def test_fold_gated_load_local(self):
|
def test_fold_gated_load_local(self):
|
||||||
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||||
|
|
@ -552,7 +552,9 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
|
|
||||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||||
# the gate and invalid value are deleted from ld1
|
# the gate and invalid value are deleted from ld1
|
||||||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
|
new_barrier = ld0.src[0].src[0].src[1]
|
||||||
|
assert new_barrier.op is Ops.BARRIER
|
||||||
|
self.assertEqual(ld0.src[0], smem.after(new_barrier).slice(lidx+2))
|
||||||
|
|
||||||
def test_fold_gated_store(self):
|
def test_fold_gated_store(self):
|
||||||
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||||
|
|
@ -564,7 +566,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
uops = to_uops_list([st0, st1])
|
uops = to_uops_list([st0, st1])
|
||||||
# only the second store happens
|
# only the second store happens
|
||||||
self.assertEqual(len(uops), 6) # +1 for SINK
|
self.assertEqual(len(uops), 6) # +1 for SINK
|
||||||
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK
|
self.assertEqual(uops[-2], glbl.slice(idx1).store(val)) # -2 to skip SINK
|
||||||
|
|
||||||
@unittest.skip("this is a uop type error")
|
@unittest.skip("this is a uop type error")
|
||||||
def test_asserts_bad_gate(self):
|
def test_asserts_bad_gate(self):
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,7 @@ pm_linearize_cleanups = PatternMatcher([
|
||||||
# if statements are not allowed in the graph
|
# if statements are not allowed in the graph
|
||||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||||
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
(UPat(Ops.STORE, name="u", src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||||
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
|
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -283,6 +283,12 @@ pm_render = PatternMatcher([
|
||||||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||||
|
# rewrite non-image INDEX to SLICE
|
||||||
|
(UPat(Ops.INDEX, name="x"), lambda x: None if isinstance(x.src[0].dtype, ImageDType) else \
|
||||||
|
UOp(Ops.SLICE, dtype=x.dtype, src=x.src, arg=0 if x.dtype.count == 1 else x.dtype.count)),
|
||||||
|
# rewrite CAST on SLICE to just SLICE
|
||||||
|
(UPat(Ops.SLICE, name="bv").cast(name="x"),
|
||||||
|
lambda bv,x: bv.replace(dtype=x.dtype, arg=0 if x.dtype.count == 1 else x.dtype.count))
|
||||||
])
|
])
|
||||||
|
|
||||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||||
|
|
|
||||||
|
|
@ -29,8 +29,8 @@ class Estimates:
|
||||||
def range_gate(x): return x.op is not Ops.RANGE
|
def range_gate(x): return x.op is not Ops.RANGE
|
||||||
for u in uops:
|
for u in uops:
|
||||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||||
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
|
# if u.src[0] is SLICE, we have to include the buffer since it might be an AFTER
|
||||||
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
|
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.SLICE else u.src[0]).toposort(range_gate))
|
||||||
# TODO: is this correct? this all needs to be cleaned up
|
# TODO: is this correct? this all needs to be cleaned up
|
||||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
|
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
|
||||||
elif u.op is Ops.IF:
|
elif u.op is Ops.IF:
|
||||||
|
|
@ -40,7 +40,7 @@ class Estimates:
|
||||||
buf = u
|
buf = u
|
||||||
while len(buf.src): buf = buf.src[0]
|
while len(buf.src): buf = buf.src[0]
|
||||||
if buf.op is Ops.PARAM:
|
if buf.op is Ops.PARAM:
|
||||||
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
|
# u.src[0] is SLICE, cap at buffer size for re-reads (e.g. matmul)
|
||||||
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
|
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
|
||||||
mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
|
mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
|
||||||
if u.op is Ops.RANGE:
|
if u.op is Ops.RANGE:
|
||||||
|
|
|
||||||
|
|
@ -43,8 +43,10 @@ base_rewrite = PatternMatcher([
|
||||||
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
|
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
|
||||||
# default const render
|
# default const render
|
||||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||||
|
# slice is ptr arithmetic
|
||||||
|
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var('idx')), name="x"),
|
||||||
|
lambda ctx,buf,idx,x: ctx.render_cast(x.dtype, f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})")),
|
||||||
# new load/store
|
# new load/store
|
||||||
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
|
||||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||||
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||||
|
|
@ -190,14 +192,14 @@ class CStyleLanguage(Renderer):
|
||||||
else:
|
else:
|
||||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
||||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
|
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
|
||||||
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
Ops.SLICE: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||||
r[u] = f"{prefix}{c[prefix]}"
|
r[u] = f"{prefix}{c[prefix]}"
|
||||||
|
|
||||||
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
|
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
|
||||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||||
|
|
||||||
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
|
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
|
||||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
|
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.SLICE, Ops.CUSTOMI} or \
|
||||||
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
|
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
|
||||||
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
|
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
|
||||||
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||||
|
|
|
||||||
|
|
@ -74,8 +74,9 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
|
||||||
|
|
||||||
base_rewrite = PatternMatcher([
|
base_rewrite = PatternMatcher([
|
||||||
# memory load/store
|
# memory load/store
|
||||||
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
|
(UPat(Ops.SLICE, name="x"), lambda ctx,x:
|
||||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
f" {ctx[x]}_o = getelementptr inbounds {ldt(x.src[0].dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}\n"
|
||||||
|
f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x]}_o to {ldt(x.dtype)}"),
|
||||||
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
|
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
|
||||||
lambda ctx,x,idx,alt,mask:
|
lambda ctx,x,idx,alt,mask:
|
||||||
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
|
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
||||||
else f"{ctx[b]} = {ctx[v]};"),
|
else f"{ctx[b]} = {ctx[v]};"),
|
||||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
|
(UPat(Ops.SLICE, src=(UPat.var("b"), UPat.var("idx"))),
|
||||||
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
||||||
]) + base_rewrite
|
]) + base_rewrite
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -101,16 +101,18 @@ class PythonProgram:
|
||||||
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
|
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
|
||||||
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
|
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
|
||||||
elif uop is Ops.CONST: values[i] = [arg] * warp_size
|
elif uop is Ops.CONST: values[i] = [arg] * warp_size
|
||||||
elif uop is Ops.INDEX:
|
elif uop is Ops.SLICE:
|
||||||
|
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||||
ret:list = []
|
ret:list = []
|
||||||
if isinstance(src_dtypes[0], ImageDType):
|
for m,o in zip(*src_values): ret.append((m,o))
|
||||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
values[i] = ret
|
||||||
for m,oy,ox in zip(*src_values):
|
elif uop is Ops.INDEX:
|
||||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
assert isinstance(src_dtypes[0], ImageDType), "only image INDEX is supported"
|
||||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
ret = []
|
||||||
else:
|
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
for m,oy,ox in zip(*src_values):
|
||||||
for m,o in zip(*src_values): ret.append((m,o))
|
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||||
|
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||||
values[i] = ret
|
values[i] = ret
|
||||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||||
values[i] = src_values[0]
|
values[i] = src_values[0]
|
||||||
|
|
|
||||||
|
|
@ -267,6 +267,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
case Ops.BINARY: return (len(self.arg),)
|
case Ops.BINARY: return (len(self.arg),)
|
||||||
case Ops.BUFFER: return (self.arg,)
|
case Ops.BUFFER: return (self.arg,)
|
||||||
case Ops.SLICE:
|
case Ops.SLICE:
|
||||||
|
if self.arg == 0: return ()
|
||||||
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||||
if self.src[0].op is Ops.INDEX: return ()
|
if self.src[0].op is Ops.INDEX: return ()
|
||||||
return (self.arg,)
|
return (self.arg,)
|
||||||
|
|
@ -459,6 +460,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||||
def vectorize(self, *srcs):
|
def vectorize(self, *srcs):
|
||||||
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
|
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
|
||||||
|
def slice(self, offset:UOp|int, size:int=0):
|
||||||
|
return UOp(Ops.SLICE, self.dtype, (self, offset if isinstance(offset, UOp) else UOp.const(dtypes.int, offset)), arg=size)
|
||||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
|
@ -1074,9 +1077,8 @@ class ProgramInfo:
|
||||||
for u in sink.toposort():
|
for u in sink.toposort():
|
||||||
if u.op is Ops.DEFINE_VAR: _vars.append(u)
|
if u.op is Ops.DEFINE_VAR: _vars.append(u)
|
||||||
if u.op is Ops.PARAM: _globals.append(u.arg)
|
if u.op is Ops.PARAM: _globals.append(u.arg)
|
||||||
if u.op in (Ops.STORE, Ops.LOAD):
|
if u.op in (Ops.STORE, Ops.LOAD) and (idx:=u.src[0]).op in (Ops.INDEX, Ops.SLICE) and (buf:=idx.src[0]).op is Ops.PARAM:
|
||||||
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
(outs if u.op is Ops.STORE else ins).append(buf.arg)
|
||||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg)
|
|
||||||
if u.op is Ops.SPECIAL:
|
if u.op is Ops.SPECIAL:
|
||||||
if u.arg[0] == 'i': local_size = None
|
if u.arg[0] == 'i': local_size = None
|
||||||
special_size = local_size if u.arg[0] == 'l' else global_size
|
special_size = local_size if u.arg[0] == 'l' else global_size
|
||||||
|
|
|
||||||
|
|
@ -99,11 +99,11 @@ spec_shared = PatternMatcher([
|
||||||
(UPat(Ops.INS), lambda: True),
|
(UPat(Ops.INS), lambda: True),
|
||||||
|
|
||||||
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
|
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
|
||||||
(UPat(Ops.INDEX, name="uidx").or_casted().load(), validate_index),
|
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(), validate_index),
|
||||||
(UPat(Ops.INDEX, name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||||
lambda uidx,gate,alt,load: validate_index(uidx, gate) if alt.dtype == load.dtype else False),
|
lambda uidx,gate,alt,load: validate_index(uidx, gate) if alt.dtype == load.dtype else False),
|
||||||
(UPat(Ops.INDEX, name="uidx").or_casted().store(UPat()), validate_index),
|
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat()), validate_index),
|
||||||
(UPat(Ops.INDEX, name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||||
|
|
||||||
# STORE in tensor graph: store a value into a target
|
# STORE in tensor graph: store a value into a target
|
||||||
(UPat(Ops.STORE, dtypes.void, (UPat(name="x"), UPat())), lambda x: True),
|
(UPat(Ops.STORE, dtypes.void, (UPat(name="x"), UPat())), lambda x: True),
|
||||||
|
|
@ -200,6 +200,10 @@ spec_program = PatternMatcher([
|
||||||
# weakint is not allowed in programs
|
# weakint is not allowed in programs
|
||||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||||
|
|
||||||
|
# buffer view in program, Image only for ImageDType
|
||||||
|
(UPat(Ops.SLICE), lambda: True),
|
||||||
|
(UPat(Ops.INDEX, name="idx"), lambda idx: isinstance(idx.src[0].dtype, ImageDType)),
|
||||||
|
|
||||||
# movement ops are not allowed in programs
|
# movement ops are not allowed in programs
|
||||||
(UPat(GroupOp.Movement), lambda: False),
|
(UPat(GroupOp.Movement), lambda: False),
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||||
Ops.SLICE: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
Ops.SLICE: "#a2c148", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||||
Ops.LINEAR: "#7DF4FF",
|
Ops.LINEAR: "#7DF4FF",
|
||||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue