mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
image_idx_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03daefc625 | ||
|
|
fc68fcafa5 |
9 changed files with 109 additions and 49 deletions
|
|
@ -15,7 +15,7 @@ def get_gated_load_uop(valid:UOp, idx:UOp):
|
|||
|
||||
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
||||
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
||||
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.STACK, dtypes.weakint.vec(2), idx).valid(valid), ptr=True),
|
||||
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(idx[0], idx[1], valid, ptr=True),
|
||||
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
|
||||
))
|
||||
|
||||
|
|
@ -218,16 +218,13 @@ class TestImageSimplification(unittest.TestCase):
|
|||
def check(self, load, svalid, sidx0, sidx1):
|
||||
with Context(NOOPT=1, SPEC=0):
|
||||
load = full_rewrite_to_sink(load.sink()).src[0]
|
||||
idx = load.src[0].src[1]
|
||||
self.assertEqual(idx.op, Ops.STACK)
|
||||
self.assertEqual(len(idx.src), 2)
|
||||
idx0, idx1 = idx.src[0], idx.src[1]
|
||||
idx0, idx1 = load.src[0].src[1], load.src[0].src[2]
|
||||
check_uop_against_string(self, idx0, sidx0)
|
||||
check_uop_against_string(self, idx1, sidx1)
|
||||
if svalid is not None:
|
||||
check_uop_against_string(self, load.src[0].src[2], svalid)
|
||||
check_uop_against_string(self, load.src[0].src[3], svalid)
|
||||
else:
|
||||
self.assertEqual(len(load.src[0].src), 2, "svalid is None but load still has a valid")
|
||||
self.assertEqual(len(load.src[0].src), 3, "svalid is None but load still has a valid")
|
||||
|
||||
def test_idx_gt_c(self):
|
||||
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
|
||||
|
|
@ -353,7 +350,7 @@ class TestImageSimplification(unittest.TestCase):
|
|||
load = get_load_image_uop(shape, valid, idx)
|
||||
|
||||
self.check(load,
|
||||
"((((idx2*2)+r0)<11)&((((idx1*8)+r1)<3)!=True))",
|
||||
"(((idx2*2)+r0)<11)",
|
||||
"(idx0+(idx1*512+r1*64)+-192)",
|
||||
"((((idx2*2)+r0)+(((idx1+((r1+5)//8))+1)//2))+-4)")
|
||||
|
||||
|
|
@ -481,7 +478,7 @@ class TestImageSimplification(unittest.TestCase):
|
|||
self.check(load, None, "(gidx0+lidx0*1024+r0*1024+lidx1*128+-3168)", "0")
|
||||
except AssertionError:
|
||||
# TODO: fold valid
|
||||
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<19))",
|
||||
self.check(load, "(((lidx1<1)!=True)&((lidx0+r0)<19))",
|
||||
"(gidx0+lidx1*128+(lidx0*1024+r0*1024)+-3168)", "0")
|
||||
|
||||
def test_simplify10(self):
|
||||
|
|
@ -500,7 +497,7 @@ class TestImageSimplification(unittest.TestCase):
|
|||
self.check(load, None, "(lidx2+gidx0*4+lidx0*1024+r0*1024+lidx1*256+-3264)", "0")
|
||||
except AssertionError:
|
||||
# TODO: fold valid
|
||||
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
|
||||
self.check(load, "(((lidx1<1)!=True)&((lidx0+r0)<11))",
|
||||
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
|
||||
|
||||
class TestUnfoldableImage(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -109,6 +109,8 @@ pm_linearize_cleanups = PatternMatcher([
|
|||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
|
||||
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))])),
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
|
||||
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,10 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
|||
|
||||
if not drop_stmt and idx is start_idx: return None
|
||||
new_valid = UOp.uprod(*ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
|
||||
return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True)
|
||||
return buf.index(idx.gep(0), idx.gep(1), new_valid, ptr=True) if new_valid is not None else buf.index(idx.gep(0), idx.gep(1), ptr=True)
|
||||
|
||||
def simplify_image_valid(buf:UOp, idx_x:UOp, idx_y:UOp, c:UOp) -> UOp|None:
|
||||
return simplify_valid_load(buf, UOp(Ops.STACK, idx_x.dtype.vec(2), (idx_x, idx_y)), c) if isinstance(buf.dtype, ImageDType) else None
|
||||
|
||||
|
||||
load_store_indexing = PatternMatcher([
|
||||
|
|
@ -55,8 +58,11 @@ load_store_indexing = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
|
||||
# simplify away long after index has been lowered
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx_x"), UPat.var("idx_y"), UPat.var("c", dtypes.bool))), simplify_image_valid),
|
||||
# drop true gate
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx_x"), UPat.var("idx_y"), UPat.const(dtypes.bool, True)),),
|
||||
lambda buf,idx_x,idx_y: buf.index(idx_x, idx_y, ptr=True) if isinstance(buf.dtype, ImageDType) else None),
|
||||
])
|
||||
|
||||
# ***** load/store grouping *****
|
||||
|
|
@ -197,8 +203,9 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
|||
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
|
||||
|
||||
def get_image_idx(idx:UOp, width:int):
|
||||
oidx = UOp(Ops.STACK, dtypes.weakint.vec(2), (((x:=idx.src[1].get_idx()) // 4) % width, (x // (4*width))))
|
||||
return idx.replace(src=(idx.src[0], oidx.valid(idx.src[1].get_valid())))
|
||||
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
|
||||
idx_x, idx_y = (x // 4) % width, x // (4*width)
|
||||
return idx.replace(src=(idx.src[0], idx_x, idx_y) + (() if valid.op is Ops.CONST and valid.arg is True else (valid,)))
|
||||
|
||||
def image_fixup(ls:UOp):
|
||||
# normal image load or store, with the CAST from expand_index
|
||||
|
|
@ -207,7 +214,7 @@ def image_fixup(ls:UOp):
|
|||
return ls.replace(src=(get_image_idx(ls.src[0].src[0], dt.shape[1]),)+ls.src[1:])
|
||||
|
||||
# this is an unprocessed image without a cast, we should just make it a buffer
|
||||
if isinstance(dt, ImageDType) and (off:=ls.src[0].src[1]).get_idx().dtype != dtypes.weakint.vec(2):
|
||||
if isinstance(dt, ImageDType) and len(ls.src[0].src) == 2 and (off:=ls.src[0].src[1]).get_idx().dtype.count != 2:
|
||||
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(off)
|
||||
return ls.replace(src=(idx,), dtype=new_dt).cast(dtypes.float) if ls.op is Ops.LOAD else ls.replace(src=(idx, ls.src[1].cast(new_dt)))
|
||||
|
||||
|
|
@ -234,13 +241,26 @@ def no_vectorized_wmma(wmma:UOp):
|
|||
|
||||
def no_vectorized_alu(alu:UOp):
|
||||
if alu.dtype.vcount == 1: return None
|
||||
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # image load/store has cond.where(idx.vec(2), Invalid) as the index
|
||||
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # masked indexes use cond.where(idx, Invalid)
|
||||
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
|
||||
return UOp(Ops.STACK, alu.dtype, alus)
|
||||
|
||||
def no_vectorized_buf(buf:UOp):
|
||||
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype)
|
||||
|
||||
def drop_load_alt(x:UOp):
|
||||
idx = x.src[0].src[0] if x.src[0].op is Ops.CAST else x.src[0]
|
||||
if idx.op is not Ops.INDEX: return None
|
||||
if len(idx.src) == 2 or (isinstance(idx.src[0].dtype, ImageDType) and len(idx.src) == 3): return x.replace(src=(x.src[0],)+x.src[2:])
|
||||
return None
|
||||
|
||||
def add_load_alt(x:UOp):
|
||||
idx = x.src[0].src[0] if x.src[0].op is Ops.CAST else x.src[0]
|
||||
if idx.op is not Ops.INDEX: return None
|
||||
if not (len(idx.src) == 3 and idx.src[2].dtype == dtypes.bool) and \
|
||||
not (isinstance(idx.src[0].dtype, ImageDType) and len(idx.src) == 4 and idx.src[3].dtype == dtypes.bool): return None
|
||||
return x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None
|
||||
|
||||
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
|
||||
cnt = cast.dtype.count
|
||||
if bcast is not None and bcast.op is Ops.GEP:
|
||||
|
|
@ -282,8 +302,8 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||
# give any loads that are masked an alt value
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
|
||||
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
|
||||
add_load_alt),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat()), allow_any_len=True, name="x"), drop_load_alt),
|
||||
# Where after gated load becomes alt value
|
||||
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
|
||||
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# new load/store
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
|
||||
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
lambda ctx,buf,idx: None if isinstance(buf.dtype, ImageDType) else f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var"))),
|
||||
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
|
|
@ -296,19 +296,27 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
dtypes.bfloat16: "ushort" }
|
||||
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast + extra_pm
|
||||
|
||||
@staticmethod
|
||||
def image_coord(ctx, idx_x, idx_y): return f"({ctx.render_dtype(dtypes.int.vec(2))})({ctx[idx_x]},{ctx[idx_y]})"
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
|
||||
# bfloat16 constants need to be rendered as their bit pattern since bf16 is stored as ushort
|
||||
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
|
||||
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
|
||||
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
|
||||
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), allow_any_len=True),
|
||||
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var('buf'), UPat.var('idx_x', dtypes.int), UPat.var('idx_y', dtypes.int)), allow_any_len=True),
|
||||
lambda ctx,buf,idx_x,idx_y: ctx.image_coord(ctx, idx_x, idx_y) if isinstance(buf.dtype, ImageDType) else None),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(
|
||||
UPat.var('buf').index(UPat.var('idx_x', dtypes.int), UPat.var('idx_y', dtypes.int), UPat.var("gate")), UPat.var("var"))),
|
||||
lambda ctx,buf,idx_x,idx_y,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx.image_coord(ctx, idx_x, idx_y)}):{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(
|
||||
UPat.var('buf').index(UPat.var('idx_x', dtypes.int), UPat.var('idx_y', dtypes.int)),)),
|
||||
lambda ctx,buf,idx_x,idx_y: f"read_imagef({ctx[buf]}, smp, {ctx.image_coord(ctx, idx_x, idx_y)})"),
|
||||
(UPat(Ops.STORE, src=(
|
||||
UPat.var('buf').index(UPat.var('idx_x', dtypes.int), UPat.var('idx_y', dtypes.int), allow_any_len=True),
|
||||
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda ctx,buf,idx_x,idx_y,var: f"write_imagef({ctx[buf]}, {ctx.image_coord(ctx, idx_x, idx_y)}, {ctx[var]});"),
|
||||
]) + base_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
|
|
|
|||
|
|
@ -136,8 +136,9 @@ class NIRRenderer(Renderer):
|
|||
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
|
||||
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
|
||||
# load/store use pointer arithmetic, and the cast does nothing
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace(
|
||||
src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off:
|
||||
x.replace(src=(buf,off.cast(dtypes.long))+x.src[2:]) if not isinstance(buf.dtype, ImageDType) and \
|
||||
buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
|
||||
])
|
||||
|
||||
|
|
@ -249,30 +250,31 @@ class LVPRenderer(NIRRenderer):
|
|||
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
|
||||
|
||||
# FIXME: this should be a rewrite rule
|
||||
def tovec(b, coord): return nalu(b, "vec4", nchannel(b, coord, 0), nchannel(b, coord, 1), nundef(b, dtypes.int), nundef(b, dtypes.int))
|
||||
def tovec(b, coord_x, coord_y): return nalu(b, "vec4", coord_x, coord_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
|
||||
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
|
||||
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components,
|
||||
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
|
||||
srcs=lambda b,img,coord,val:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,coord,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
|
||||
srcs=lambda b,img,coord_x,coord_y,val:[nsrc(x) for x in [img, tovec(b, coord_x, coord_y), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,coord_x,coord_y,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
|
||||
|
||||
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
|
||||
nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
|
||||
nc=4, bs=32, num_components=4,
|
||||
srcs=lambda b,img,coord_x,coord_y:[nsrc(x) for x in [img, tovec(b, coord_x, coord_y), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,coord_x,coord_y,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
|
||||
|
||||
class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
||||
has_aux = True
|
||||
|
||||
def nload_img(ctx,img,coord):
|
||||
def nload_img(ctx,img,idx_x,idx_y):
|
||||
ctx.texs.add(img)
|
||||
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)
|
||||
return _nload_img(ctx.b, ctx.r[img], ctx.r[idx_x], ctx.r[idx_y], img.dtype)
|
||||
|
||||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), allow_any_len=True), UPat.var("val")),
|
||||
allow_any_len=True), lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("alt"))),
|
||||
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img),
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('idx_x', dtypes.int), UPat.var('idx_y', dtypes.int), allow_any_len=True), UPat.var("val")),
|
||||
allow_any_len=True), lambda ctx,img,idx_x,idx_y,val: nstore_img(ctx.b, ctx.r[img], ctx.r[idx_x], ctx.r[idx_y], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_x', dtypes.int), UPat.var('idx_y', dtypes.int), UPat.var("gate")), UPat.var("alt"))),
|
||||
lambda ctx,img,idx_x,idx_y,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, idx_x, idx_y), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_x', dtypes.int), UPat.var('idx_y', dtypes.int)),)), nload_img),
|
||||
]) + NIRRenderer.def_rewrite
|
||||
|
||||
_param = LVPRenderer.param
|
||||
|
|
|
|||
|
|
@ -93,12 +93,14 @@ class PythonProgram:
|
|||
elif uop is Ops.INDEX:
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]):
|
||||
for m,ox,oy in zip(src_values[0], src_values[1], src_values[2]):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
gate = src_values[3] if len(src_values) == 4 else [True]*len(ret)
|
||||
else:
|
||||
for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
|
||||
values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last
|
||||
gate = src_values[2] if len(src_values) == 3 else [True]*len(ret)
|
||||
values[i] = [(m,o,g) for (m,o),g in zip(ret, gate)] # set the gate last
|
||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||
values[i] = src_values[0]
|
||||
elif uop is Ops.RANGE:
|
||||
|
|
|
|||
|
|
@ -1136,8 +1136,8 @@ class UPat(OpMixin):
|
|||
|
||||
# copied from UOp
|
||||
def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def index(self, idx:UPat, valid:UPat|None=None, **kwargs):
|
||||
return UPat(Ops.INDEX, self.match_dtype, (self,idx,valid) if valid is not None else (self,idx), **kwargs)
|
||||
def index(self, *srcs:UPat|None, **kwargs):
|
||||
return UPat(Ops.INDEX, self.match_dtype, (self,)+tuple(x for x in srcs if x is not None), **kwargs)
|
||||
def cast(self, dtype=None, **kwargs):
|
||||
if dtype is not None and self.match_dtype == (dtype,): return self
|
||||
return UPat(Ops.CAST, dtype, (self,), **kwargs)
|
||||
|
|
@ -1506,6 +1506,14 @@ def sint_to_uop(x:sint, dtype=dtypes.weakint) -> UOp: return UOp.const(dtype, x)
|
|||
def to_max_shape(shape:tuple[sint, ...]) -> tuple[int, ...]: return tuple(int(x.vmax) if isinstance(x, UOp) else x for x in shape)
|
||||
|
||||
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
|
||||
def strip_index_casts(idx:UOp):
|
||||
new_src = (idx.src[0],) + tuple(s.src[0] if s.op is Ops.CAST and s.dtype.scalar() is dtypes.weakint and
|
||||
s.src[0].dtype.scalar() in dtypes.ints else s for s in idx.src[1:])
|
||||
return new_src[0].index(*new_src[1:], ptr=True) if new_src != idx.src else None
|
||||
def lower_image_index(idx:UOp):
|
||||
if not isinstance(idx.src[0].dtype, ImageDType) or len(idx.src) not in (3, 4): return None
|
||||
new_src = (idx.src[0],) + tuple(s if s.dtype == dtypes.int else s.cast(dtypes.int) for s in idx.src[1:3]) + idx.src[3:]
|
||||
return idx.replace(src=new_src) if new_src != idx.src else None
|
||||
pm_lower_index_dtype = PatternMatcher([
|
||||
# There are no Unary ops at this point in symbolic, those are introduced later
|
||||
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y:
|
||||
|
|
@ -1526,9 +1534,9 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
# lower Invalid
|
||||
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
|
||||
# remove hanging casts
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
|
||||
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="idx"), strip_index_casts),
|
||||
# images are indexed with separate int x/y coordinates
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat()), allow_any_len=True, name="idx"), lower_image_index),
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
||||
# vectorized indexes (ie. images) must be int
|
||||
|
|
|
|||
|
|
@ -23,6 +23,11 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
|
|||
from tinygrad.uop.validate import validate_index_with_z3
|
||||
return validate_index_with_z3(sz, idx, gate)
|
||||
|
||||
def validate_image_index(buf:UOp, idx_x:UOp, idx_y:UOp, gate:UOp|None=None):
|
||||
if not isinstance(buf.dtype, ImageDType): return None
|
||||
return idx_x.dtype in dtypes.ints+(dtypes.weakint,) and idx_y.dtype in dtypes.ints+(dtypes.weakint,) and \
|
||||
(gate is None or gate.dtype == dtypes.bool)
|
||||
|
||||
# four specs:
|
||||
# shared_spec -- usable anywhere
|
||||
# tensor_spec -- usable in tensor graph
|
||||
|
|
@ -175,6 +180,7 @@ shared_codegen_spec = PatternMatcher([
|
|||
|
||||
# LOAD(idx) / STORE(idx, val)
|
||||
(UPat().index(UPat()).or_casted().load(), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted().load(), lambda: True),
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(), lambda: True), # gated load (alt added in program_spec)
|
||||
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
||||
|
||||
|
|
@ -187,6 +193,9 @@ shared_codegen_spec = PatternMatcher([
|
|||
# INDEX (2-arg and 3-arg with bool gate)
|
||||
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx_x"), UPat.var("idx_y"))), validate_image_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx_x"), UPat.var("idx_y"),
|
||||
UPat.var("gate", dtype=dtypes.bool))), validate_image_index),
|
||||
|
||||
# SPECIAL
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
|
|
@ -237,6 +246,7 @@ tensor_spec = PatternMatcher([
|
|||
program_spec = PatternMatcher([
|
||||
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(), UPat(dtype=dtypes.bool))).or_casted().load(UPat()), lambda: True),
|
||||
|
||||
# END closes ranges
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import math, struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, ImageDType, can_lossless_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
|
||||
from tinygrad.uop.decompositions import threefry2x32, xpow
|
||||
from tinygrad.uop.divandmod import div_and_mod_symbolic
|
||||
|
|
@ -398,6 +398,13 @@ def where_on_load(cond:UOp, buf:UOp, idx:UOp, or_cast:UOp) -> UOp|None:
|
|||
ret_idx = idx.cast(or_cast.dtype) if or_cast.op is Ops.CAST else idx
|
||||
return UOp.const(dtypes.bool, True).uprod(*keep).where(ret_idx, ret_idx.const_like(0))
|
||||
|
||||
def index_with_gate(index:UOp, gate:UOp) -> UOp:
|
||||
if isinstance(index.src[0].dtype, ImageDType) and len(index.src) in (3, 4):
|
||||
return index.replace(src=index.src[:3] + (gate if len(index.src) == 3 else index.src[3] & gate,))
|
||||
if len(index.src) == 3 and index.src[2].dtype == dtypes.bool:
|
||||
return index.replace(src=(index.src[0], index.src[1], index.src[2] & gate))
|
||||
return index.replace(src=(index.src[0], gate.where(index.src[1], UOp.invalid())) + index.src[2:])
|
||||
|
||||
# where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
|
||||
pm_move_where_on_load = PatternMatcher([
|
||||
(UPat.var("cond").where(UPat.var("buf").index(UPat.var("idx")).or_casted("or_cast"), 0), where_on_load),
|
||||
|
|
@ -443,15 +450,19 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
|
|||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
|
||||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
|
||||
UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
|
||||
lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])),
|
||||
lambda index, gate, alt, store: UOp.store(index_with_gate(index, gate), alt, *store.src[2:])),
|
||||
# fold gated LOAD/STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, False)), name="index"),), allow_any_len=True, name="x"),
|
||||
lambda index,x: UOp(Ops.NOOP) if isinstance(index.src[0].dtype, ImageDType) else None),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(), UPat.const(dtypes.bool, False)), name="index"),), allow_any_len=True, name="x"),
|
||||
lambda index,x: (x.src[1] if len(x.src) > 1 else x.const_like(0)) if isinstance(index.src[0].dtype, ImageDType) else None),
|
||||
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"), lambda x: UOp(Ops.NOOP)),
|
||||
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
|
||||
(UPat(Ops.STORE, src=(UPat(), invalid_pat), allow_any_len=True), lambda i: UOp(Ops.NOOP)),
|
||||
# store of where with invalid -> gated store
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="index"), UPat.var("cond").where(UPat.var("val"), invalid_pat)), allow_any_len=True, name="store"),
|
||||
lambda index, cond, val, store, i: UOp.store(index.src[0].index(cond.where(index.src[1], UOp.invalid())), val, *store.src[2:])),
|
||||
lambda index, cond, val, store, i: UOp.store(index_with_gate(index, cond), val, *store.src[2:])),
|
||||
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
|
||||
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
|
||||
((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue