mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
6 commits
master
...
index_to_s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c8867c3e9 | ||
|
|
c88941a957 | ||
|
|
760acdf554 | ||
|
|
f54a1d70d6 | ||
|
|
5d3713a6be | ||
|
|
d0f956341c |
12 changed files with 52 additions and 33 deletions
|
|
@ -73,10 +73,10 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
# Assert the dtype of the INDEX value, This will need be updated if UOp spec changes
|
||||
store = next(uop for uop in uops if uop.op is Ops.STORE)
|
||||
assert store.op is Ops.STORE
|
||||
idx = self._find_op(store, Ops.INDEX)
|
||||
# PTX and NIR turn Ops.INDEX into pointer arithmetic earlier than cstyle, plus it's already cast to int64
|
||||
idx = self._find_op(store, Ops.SLICE)
|
||||
# PTX and NIR turn Ops.SLICE into pointer arithmetic earlier than cstyle, plus it's already cast to int64
|
||||
if not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)):
|
||||
assert idx.op is Ops.INDEX
|
||||
assert idx.op is Ops.SLICE
|
||||
idx_val = idx.src[1]
|
||||
self.assertIs(idx_val.dtype, dtype)
|
||||
|
||||
|
|
|
|||
|
|
@ -538,7 +538,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
|
||||
self.assertEqual(ld0, UOp.load(glbl2.slice(idx), dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
|
|
@ -552,7 +552,9 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
|
||||
new_barrier = ld0.src[0].src[0].src[1]
|
||||
assert new_barrier.op is Ops.BARRIER
|
||||
self.assertEqual(ld0.src[0], smem.after(new_barrier).slice(lidx+2))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
|
|
@ -564,7 +566,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops), 6) # +1 for SINK
|
||||
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK
|
||||
self.assertEqual(uops[-2], glbl.slice(idx1).store(val)) # -2 to skip SINK
|
||||
|
||||
@unittest.skip("this is a uop type error")
|
||||
def test_asserts_bad_gate(self):
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ pm_linearize_cleanups = PatternMatcher([
|
|||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -283,6 +283,12 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||
# rewrite non-image INDEX to SLICE
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: None if isinstance(x.src[0].dtype, ImageDType) else \
|
||||
UOp(Ops.SLICE, dtype=x.dtype, src=x.src, arg=0 if x.dtype.count == 1 else x.dtype.count)),
|
||||
# rewrite CAST on SLICE to just SLICE
|
||||
(UPat(Ops.SLICE, name="bv").cast(name="x"),
|
||||
lambda bv,x: bv.replace(dtype=x.dtype, arg=0 if x.dtype.count == 1 else x.dtype.count))
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ class Estimates:
|
|||
def range_gate(x): return x.op is not Ops.RANGE
|
||||
for u in uops:
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
|
||||
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
|
||||
# if u.src[0] is SLICE, we have to include the buffer since it might be an AFTER
|
||||
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.SLICE else u.src[0]).toposort(range_gate))
|
||||
# TODO: is this correct? this all needs to be cleaned up
|
||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
|
||||
elif u.op is Ops.IF:
|
||||
|
|
@ -40,7 +40,7 @@ class Estimates:
|
|||
buf = u
|
||||
while len(buf.src): buf = buf.src[0]
|
||||
if buf.op is Ops.PARAM:
|
||||
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
|
||||
# u.src[0] is SLICE, cap at buffer size for re-reads (e.g. matmul)
|
||||
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
|
||||
mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
|
||||
if u.op is Ops.RANGE:
|
||||
|
|
|
|||
|
|
@ -43,8 +43,10 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
|
||||
# default const render
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# slice is ptr arithmetic
|
||||
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var('idx')), name="x"),
|
||||
lambda ctx,buf,idx,x: ctx.render_cast(x.dtype, f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})")),
|
||||
# new load/store
|
||||
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
|
|
@ -190,14 +192,14 @@ class CStyleLanguage(Renderer):
|
|||
else:
|
||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
|
||||
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
Ops.SLICE: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
r[u] = f"{prefix}{c[prefix]}"
|
||||
|
||||
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
|
||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||
|
||||
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.SLICE, Ops.CUSTOMI} or \
|
||||
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
|
||||
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
|
||||
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
|
|
|
|||
|
|
@ -74,8 +74,9 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
|
|||
|
||||
base_rewrite = PatternMatcher([
|
||||
# memory load/store
|
||||
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.SLICE, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]}_o = getelementptr inbounds {ldt(x.src[0].dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}\n"
|
||||
f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x]}_o to {ldt(x.dtype)}"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
|
||||
lambda ctx,x,idx,alt,mask:
|
||||
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class WGSLRenderer(CStyleLanguage):
|
|||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
(UPat(Ops.SLICE, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
||||
]) + base_rewrite
|
||||
|
||||
|
|
|
|||
|
|
@ -101,16 +101,18 @@ class PythonProgram:
|
|||
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
|
||||
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
|
||||
elif uop is Ops.CONST: values[i] = [arg] * warp_size
|
||||
elif uop is Ops.INDEX:
|
||||
elif uop is Ops.SLICE:
|
||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
for m,oy,ox in zip(*src_values):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
else:
|
||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
values[i] = ret
|
||||
elif uop is Ops.INDEX:
|
||||
assert isinstance(src_dtypes[0], ImageDType), "only image INDEX is supported"
|
||||
ret = []
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
for m,oy,ox in zip(*src_values):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
values[i] = ret
|
||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||
values[i] = src_values[0]
|
||||
|
|
|
|||
|
|
@ -267,6 +267,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
case Ops.BINARY: return (len(self.arg),)
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.SLICE:
|
||||
if self.arg == 0: return ()
|
||||
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
if self.src[0].op is Ops.INDEX: return ()
|
||||
return (self.arg,)
|
||||
|
|
@ -459,6 +460,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||
def vectorize(self, *srcs):
|
||||
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
|
||||
def slice(self, offset:UOp|int, size:int=0):
|
||||
return UOp(Ops.SLICE, self.dtype, (self, offset if isinstance(offset, UOp) else UOp.const(dtypes.int, offset)), arg=size)
|
||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def __getitem__(self, idx):
|
||||
|
|
@ -1074,9 +1077,8 @@ class ProgramInfo:
|
|||
for u in sink.toposort():
|
||||
if u.op is Ops.DEFINE_VAR: _vars.append(u)
|
||||
if u.op is Ops.PARAM: _globals.append(u.arg)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg)
|
||||
if u.op in (Ops.STORE, Ops.LOAD) and (idx:=u.src[0]).op in (Ops.INDEX, Ops.SLICE) and (buf:=idx.src[0]).op is Ops.PARAM:
|
||||
(outs if u.op is Ops.STORE else ins).append(buf.arg)
|
||||
if u.op is Ops.SPECIAL:
|
||||
if u.arg[0] == 'i': local_size = None
|
||||
special_size = local_size if u.arg[0] == 'l' else global_size
|
||||
|
|
|
|||
|
|
@ -99,11 +99,11 @@ spec_shared = PatternMatcher([
|
|||
(UPat(Ops.INS), lambda: True),
|
||||
|
||||
# LOAD(idx) / STORE(idx, val) with gates on the LOAD/STORE
|
||||
(UPat(Ops.INDEX, name="uidx").or_casted().load(), validate_index),
|
||||
(UPat(Ops.INDEX, name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().load(UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
|
||||
lambda uidx,gate,alt,load: validate_index(uidx, gate) if alt.dtype == load.dtype else False),
|
||||
(UPat(Ops.INDEX, name="uidx").or_casted().store(UPat()), validate_index),
|
||||
(UPat(Ops.INDEX, name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat()), validate_index),
|
||||
(UPat((Ops.INDEX, Ops.SLICE), name="uidx").or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
|
||||
|
||||
# STORE in tensor graph: store a value into a target
|
||||
(UPat(Ops.STORE, dtypes.void, (UPat(name="x"), UPat())), lambda x: True),
|
||||
|
|
@ -200,6 +200,10 @@ spec_program = PatternMatcher([
|
|||
# weakint is not allowed in programs
|
||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||
|
||||
# buffer view in program, Image only for ImageDType
|
||||
(UPat(Ops.SLICE), lambda: True),
|
||||
(UPat(Ops.INDEX, name="idx"), lambda idx: isinstance(idx.src[0].dtype, ImageDType)),
|
||||
|
||||
# movement ops are not allowed in programs
|
||||
(UPat(GroupOp.Movement), lambda: False),
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.SLICE: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.SLICE: "#a2c148", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||
Ops.LINEAR: "#7DF4FF",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue