Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
b910f1d5c0 something 2026-05-08 17:32:30 -07:00
George Hotz
e14b2b41c6 move image index 2026-05-08 17:27:34 -07:00
George Hotz
bf05a2762e
Merge branch 'master' into image_no_vec 2026-05-08 16:32:08 -07:00
George Hotz
08747264cf fixes 2026-05-08 11:07:09 -07:00
George Hotz
f68c224b71 don't use vec(2) for image index 2026-05-08 10:52:24 -07:00
10 changed files with 159 additions and 65 deletions

View file

@ -21,7 +21,7 @@ def get_gated_load_uop(valid:UOp, idx:UOp):
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.STACK, dtypes.weakint.vec(2), idx).valid(valid), ptr=True),
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(idx[0].valid(valid), idx[1].valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
))
@ -222,17 +222,15 @@ class TestValidIdxSimplification(unittest.TestCase):
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
load = simplify_image_idx(load.sink()).src[0]
off = load.src[0].src[1]
idx = off.get_idx()
self.assertEqual(idx.op, Ops.STACK)
self.assertEqual(len(idx.src), 2)
idx0, idx1 = idx.src[0], idx.src[1]
off = load.src[0]
idx0, idx1 = off.src[1].get_idx(), off.src[2].get_idx()
check_uop_against_string(self, idx0, sidx0)
check_uop_against_string(self, idx1, sidx1)
self.assertEqual(off.src[1].get_valid(), off.src[2].get_valid())
if svalid is not None:
check_uop_against_string(self, off.get_valid(), svalid)
check_uop_against_string(self, off.src[1].get_valid(), svalid)
else:
self.assertEqual(off.get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
self.assertEqual(off.src[1].get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
def test_idx_gt_c(self):
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid

View file

@ -17,7 +17,7 @@ from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_f
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.late.gater import pm_move_gates_from_index
from tinygrad.codegen.late.gater import pm_image_index, pm_move_gates_from_index
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
@ -77,6 +77,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
# convert image linear offsets to image coordinates before symbolic/index dtype cleanup
sink = graph_rewrite(sink, pm_image_index, name="image indexing")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic")

View file

@ -38,19 +38,24 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
idx = uop_given_valid(valid, start_idx)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)
return None if isinstance(buf.dtype, ImageDType) or idx is start_idx else buf.index(idx.valid(valid), ptr=True)
# wait for it to be image indexed before running simplification
if start_idx.dtype.count != 2: return None
drop_stmt = _drop_valid_stmts(valid, idx, buf.dtype.shape[0], buf.dtype.shape[1])
if not drop_stmt and idx is start_idx: return None
def simplify_valid_image_load(buf:UOp, start_x:UOp, start_y:UOp, valid:UOp) -> UOp|None:
if not isinstance(buf.dtype, ImageDType) or start_x.dtype.scalar() is not dtypes.weakint or \
start_y.dtype.scalar() is not dtypes.weakint: return None
x, y = uop_given_valid(valid, start_x), uop_given_valid(valid, start_y)
drop_stmt = _drop_valid_stmts(valid, UOp.vectorize(x, y), buf.dtype.shape[0], buf.dtype.shape[1])
if not drop_stmt and x is start_x and y is start_y: return None
new_valid = UOp.uprod(*ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True)
return buf.index(x.valid(new_valid) if new_valid is not None else x, y.valid(new_valid) if new_valid is not None else y, ptr=True)
image_invalid_gate_x = UPat.var("cond").where(UPat.var("x"), UPat(Ops.CONST, arg=Invalid))
image_invalid_gate_y = UPat.var("cond").where(UPat.var("y"), UPat(Ops.CONST, arg=Invalid))
load_store_indexing = PatternMatcher([
# image load valid idx simplification with scalar x/y coordinates
(UPat(Ops.INDEX, src=(UPat.var("buf"), image_invalid_gate_x, image_invalid_gate_y)),
lambda buf,x,y,cond: simplify_valid_image_load(buf, x, y, cond)),
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
])
@ -192,27 +197,9 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
if len(ret) <= 1: return None
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
def get_image_idx(idx:UOp, width:int):
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
idx_x, idx_y = (x // 4) % width, x // (4*width)
return idx.replace(src=(idx.src[0], UOp.vectorize(idx_x, idx_y).valid(valid)))
def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index
if isinstance(dt:=ls.src[0].src[0].dtype, ImageDType) and ls.src[0].op is Ops.CAST:
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
return ls.replace(src=(get_image_idx(ls.src[0].src[0], dt.shape[1]),)+ls.src[1:])
# this is an unprocessed image without a cast, we should just make it a buffer
if isinstance(dt, ImageDType) and (off:=ls.src[0].src[1]).get_idx().dtype != dtypes.weakint.vec(2):
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(off)
return ls.replace(src=(idx,), dtype=new_dt).cast(dtypes.float) if ls.op is Ops.LOAD else ls.replace(src=(idx, ls.src[1].cast(new_dt)))
correct_load_store = PatternMatcher([
# split LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
# image indexing, including unfoldable images
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])
# *** uop expander ***
@ -231,7 +218,7 @@ def no_vectorized_wmma(wmma:UOp):
def no_vectorized_alu(alu:UOp):
if alu.dtype.vcount == 1: return None
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # image load/store has cond.where(idx.vec(2), Invalid) as the index
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # gated indexes use cond.where(idx, Invalid)
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.STACK, alu.dtype, alus)

View file

@ -1,6 +1,48 @@
# this is a temporary intermediate step while we remove this index style
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.dtype import Invalid, dtypes
from tinygrad.dtype import Invalid, dtypes, ImageDType
def move_image_load_gate(buf:UOp, gate:UOp, x:UOp, y:UOp, cast:UOp, l:UOp):
if not isinstance(buf.dtype, ImageDType): return None
return buf.index(x, y, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)
def move_image_store_gate(buf:UOp, gate:UOp, x:UOp, y:UOp, cast:UOp, data:UOp):
if not isinstance(buf.dtype, ImageDType): return None
return buf.index(x, y, ptr=True).cast(cast.dtype).store(data, gate)
def image_coords_to_int(idx:UOp, buf:UOp, x:UOp, y:UOp):
if not isinstance(buf.dtype, ImageDType) or (x.dtype != dtypes.long and y.dtype != dtypes.long): return None
return idx.replace(src=(buf, x.cast(dtypes.int) if x.dtype == dtypes.long else x, y.cast(dtypes.int) if y.dtype == dtypes.long else y))
def index_and_valid(idx:UOp) -> tuple[UOp, UOp]:
if idx.dtype.scalar() is dtypes.weakint: return idx.get_idx(), idx.get_valid()
if idx.op is Ops.WHERE and idx.src[2].arg is Invalid: return idx.src[1], idx.src[0]
return idx, UOp.const(dtypes.bool, idx.arg is not Invalid)
def valid_idx(idx:UOp, valid:UOp) -> UOp:
return idx if valid.op is Ops.CONST and valid.arg is True else valid.where(idx, idx.const_like(Invalid))
def get_image_idx(idx:UOp, height:int, width:int) -> UOp:
x, valid = index_and_valid(idx.src[1])
px = x // 4
idx_x, idx_y = (px, px.const_like(0)) if height == 1 else (px % width, px // width)
return idx.replace(src=(idx.src[0], valid_idx(idx_x, valid), valid_idx(idx_y, valid)))
def image_fixup(ls:UOp):
# normal image load/store from split_load_store: casted linear offset -> image x/y coordinates
if ls.src[0].op is Ops.CAST and (cast_idx:=ls.src[0].src[0]).op is Ops.INDEX and isinstance(dt:=cast_idx.src[0].dtype, ImageDType):
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
return ls.replace(src=(cast_idx if len(cast_idx.src) == 3 else get_image_idx(cast_idx, dt.shape[0], dt.shape[1]),)+ls.src[1:])
if ls.src[0].op is not Ops.INDEX or not isinstance(dt:=ls.src[0].src[0].dtype, ImageDType) or len(ls.src[0].src) == 3: return None
# this is an unprocessed image without a cast, we should just make it a buffer
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(ls.src[0].src[1])
return ls.replace(src=(idx,), dtype=new_dt).cast(dtypes.float) if ls.op is Ops.LOAD else ls.replace(src=(idx, ls.src[1].cast(new_dt)))
pm_image_index = PatternMatcher([
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])
pm_move_gates_from_index = PatternMatcher([
# here we create the alt value for load to be 0s and remove the where Invalid
@ -8,6 +50,12 @@ pm_move_gates_from_index = PatternMatcher([
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("x"), UPat(arg=Invalid)),
UPat.var("gate").where(UPat.var("y"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
move_image_load_gate),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("x"), UPat(arg=Invalid)),
UPat.var("gate").where(UPat.var("y"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
move_image_store_gate),
# Where after gated load becomes alt value
(UPat.var("gate").where(UPat().load(UPat(), UPat.var("gate"), name="l").or_casted(), UPat.var("a")), lambda gate,l,a:
@ -15,7 +63,8 @@ pm_move_gates_from_index = PatternMatcher([
(UPat.var("gate").where(UPat.var("a"), UPat().load(UPat(), ~UPat.var("gate", dtype=dtypes.bool), name="l").or_casted()), lambda gate,l,a:
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
# vectorized indexes (ie. images) must be int
# vectorized indexes must be int
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"),
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:])))
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:]))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.var("y")), name="idx"), image_coords_to_int),
])

View file

@ -97,6 +97,14 @@ pm_manual_bf16_cast = PatternMatcher([
])
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
def image_coord(ctx, x:UOp, y:UOp) -> str: return f"(int2)({ctx[x]}, {ctx[y]})"
def render_image_load(ctx, buf:UOp, x:UOp, y:UOp, var:UOp|None=None, gate:UOp|None=None) -> str|None:
if not isinstance(buf.dtype, ImageDType): return None
load = f"read_imagef({ctx[buf]}, smp, {image_coord(ctx, x, y)})"
return f"({ctx[gate]}?{load}:{ctx[var]})" if gate is not None and var is not None else load
def render_image_store(ctx, buf:UOp, x:UOp, y:UOp, var:UOp) -> str|None:
if not isinstance(buf.dtype, ImageDType): return None
return f"write_imagef({ctx[buf]}, {image_coord(ctx, x, y)}, {ctx[var]});"
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
def wmma_args(uops:list[UOp]):
@ -301,13 +309,15 @@ class OpenCLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
# load/store image (OpenCL)
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
(UPat(Ops.INDEX, src=(UPat.var('buf'), UPat.var('x'), UPat.var('y')), name="idx"),
lambda ctx,buf,x,y,idx: image_coord(ctx, x, y) if isinstance(buf.dtype, ImageDType) else None),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,x,y,var,gate: render_image_load(ctx, buf, x, y, var, gate)),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')),)),
lambda ctx,buf,x,y: render_image_load(ctx, buf, x, y)),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')),
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda ctx,buf,x,y,var: render_image_store(ctx, buf, x, y, var)),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:

View file

@ -114,6 +114,11 @@ def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
def cast_global_index(x:UOp, buf:UOp, off:UOp):
if isinstance(buf.dtype, ImageDType) or not isinstance(buf.dtype, PtrDType) or buf.dtype.addrspace == AddrSpace.REG or \
off.op in (Ops.CAST, Ops.STACK): return None
return x.replace(src=(buf, off.cast(dtypes.long))+x.src[2:])
class NIRRenderer(Renderer):
suffix = "NIR"
nir_options: bytes
@ -136,8 +141,7 @@ class NIRRenderer(Renderer):
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
# load/store use pointer arithmetic, and the cast does nothing
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace(
src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), cast_global_index),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
])
@ -249,30 +253,39 @@ class LVPRenderer(NIRRenderer):
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
# FIXME: this should be a rewrite rule
def tovec(b, coord): return nalu(b, "vec4", nchannel(b, coord, 0), nchannel(b, coord, 1), nundef(b, dtypes.int), nundef(b, dtypes.int))
def tovec(b, x, y): return nalu(b, "vec4", x, y, nundef(b, dtypes.int), nundef(b, dtypes.int))
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components,
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
srcs=lambda b,img,coord,val:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
lambda b,img,coord,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
srcs=lambda b,img,x,y,val:[nsrc(z) for z in [img, tovec(b, x, y), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
lambda b,img,x,y,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
nc=4, bs=32, num_components=4, srcs=lambda b,img,x,y:[nsrc(z) for z in [img, tovec(b, x, y), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
lambda b,img,x,y,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
def nstore_img_checked(ctx, img:UOp, x:UOp, y:UOp, val:UOp):
if not isinstance(img.dtype, ImageDType): return None
return nstore_img(ctx.b, ctx.r[img], ctx.r[x], ctx.r[y], ctx.r[val], val.dtype)
def nload_img_gated(ctx, img:UOp, x:UOp, y:UOp, alt:UOp, gate:UOp):
if not isinstance(img.dtype, ImageDType): return None
return if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, x, y), lambda: ctx.r[alt])
class IR3Renderer(NIRRenderer, OpenCLRenderer):
has_aux = True
def nload_img(ctx,img,coord):
def nload_img(ctx,img,x,y):
if not isinstance(img.dtype, ImageDType): return None
ctx.texs.add(img)
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)
return _nload_img(ctx.b, ctx.r[img], ctx.r[x], ctx.r[y], img.dtype)
def_rewrite = PatternMatcher([
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("val")), allow_any_len=True),
lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("alt"), UPat.var("gate"))),
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img),
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')), UPat.var("val")), allow_any_len=True),
nstore_img_checked),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')), UPat.var("alt"), UPat.var("gate"))),
nload_img_gated),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')),)), nload_img),
]) + NIRRenderer.def_rewrite
_param = LVPRenderer.param

View file

@ -5,7 +5,7 @@
from typing import Any, TYPE_CHECKING
import pickle, base64, itertools, time, sys, functools
from dataclasses import replace
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar, Invalid
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, Target
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.opt import tc
@ -92,11 +92,13 @@ class PythonProgram:
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[i] = [arg] * warp_size
elif uop is Ops.INDEX:
if len(src_values) != 2: raise RuntimeError("gates must be on LOAD/STORE, not INDEX")
if len(src_values) != 2 and not isinstance(src_dtypes[0], ImageDType): raise RuntimeError("gates must be on LOAD/STORE, not INDEX")
ret:list = []
if isinstance(src_dtypes[0], ImageDType):
for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]):
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
xs, ys = (src_values[1][0], src_values[1][1]) if len(src_values) == 2 else (src_values[1], src_values[2])
for m,ox,oy in zip(src_values[0], xs, ys):
invalid = ox is Invalid or oy is Invalid
if invalid or ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
else:
for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))

View file

@ -1531,6 +1531,19 @@ def sint_to_uop(x:sint, dtype=dtypes.weakint) -> UOp: return UOp.const(dtype, x)
def to_max_shape(shape:tuple[sint, ...]) -> tuple[int, ...]: return tuple(int(x.vmax) if isinstance(x, UOp) else x for x in shape)
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
def lower_index_casts(idx:UOp) -> UOp|None:
new_src, changed = [idx.src[0]], False
for s in idx.src[1:]:
ns = None
if s.op is Ops.CAST and s.dtype == dtypes.weakint and s.src[0].dtype.scalar() in dtypes.ints:
ns = s.src[0]
elif s.op is Ops.WHERE and s.dtype.scalar() is dtypes.weakint and s.src[2].op is Ops.CONST and s.src[2].arg is Invalid:
val = s.src[1]
if val.op is Ops.CAST and val.dtype == dtypes.weakint and val.src[0].dtype.scalar() in dtypes.ints:
ns = s.src[0].where(val.src[0], val.src[0].const_like(Invalid))
new_src.append(ns if ns is not None else s)
changed = changed or ns is not None
return idx.replace(src=tuple(new_src)) if changed else None
pm_lower_index_dtype = PatternMatcher([
# There are no Unary ops at this point in symbolic, those are introduced later
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y:
@ -1552,6 +1565,7 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
(UPat(Ops.INDEX, name="idx"), lower_index_casts),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
])

View file

@ -24,6 +24,10 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
from tinygrad.uop.validate import validate_index_with_z3
return validate_index_with_z3(sz, idx, gate)
def validate_image_index(buf:UOp, idx0:UOp, idx1:UOp, gate:UOp|None=None):
if not isinstance(buf.dtype, ImageDType): return None
return validate_index(buf, idx0, gate) and validate_index(buf, idx1, gate)
# four specs:
# shared_spec -- usable anywhere
# tensor_spec -- usable in tensor graph
@ -180,6 +184,13 @@ shared_codegen_spec = PatternMatcher([
lambda buf,idx,gate,alt,load: validate_index(buf, idx, gate) if alt.dtype == load.dtype else False),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat()), validate_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().load(), validate_image_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().load(
UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
lambda buf,idx0,idx1,gate,alt,load: validate_image_index(buf, idx0, idx1, gate) if alt.dtype == load.dtype else False),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().store(UPat()), validate_image_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().store(
UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_image_index),
# CUSTOM (inline and non inline)
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
@ -189,6 +200,8 @@ shared_codegen_spec = PatternMatcher([
# INDEX is just address calculation. OOB validation is on LOAD/STORE where the gate is available.
(UPat(GroupOp.Defines|{Ops.AFTER}).index(UPat()), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat(), UPat())),
lambda buf: True if isinstance(buf.dtype, ImageDType) else None),
# SPECIAL
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),

View file

@ -447,8 +447,13 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
lambda index, gate, alt: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt)),
# fold gated LOAD/STORE
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())), lambda: UOp(Ops.NOOP)),
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid), UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())),
lambda: UOp(Ops.NOOP)),
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid),
UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)),
(UPat(Ops.STORE, src=(UPat(), invalid_pat)), lambda i: UOp(Ops.NOOP)),
# store of where with invalid -> gated store
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="index"), UPat.var("cond").where(UPat.var("val"), invalid_pat))),