Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
b910f1d5c0 something 2026-05-08 17:32:30 -07:00
George Hotz
e14b2b41c6 move image index 2026-05-08 17:27:34 -07:00
George Hotz
bf05a2762e
Merge branch 'master' into image_no_vec 2026-05-08 16:32:08 -07:00
George Hotz
08747264cf fixes 2026-05-08 11:07:09 -07:00
George Hotz
f68c224b71 don't use vec(2) for image index 2026-05-08 10:52:24 -07:00
10 changed files with 159 additions and 65 deletions

View file

@ -21,7 +21,7 @@ def get_gated_load_uop(valid:UOp, idx:UOp):
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]): def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), ( return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.STACK, dtypes.weakint.vec(2), idx).valid(valid), ptr=True), UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(idx[0].valid(valid), idx[1].valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4) UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
)) ))
@ -222,17 +222,15 @@ class TestValidIdxSimplification(unittest.TestCase):
class TestImageSimplification(unittest.TestCase): class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1): def check(self, load, svalid, sidx0, sidx1):
load = simplify_image_idx(load.sink()).src[0] load = simplify_image_idx(load.sink()).src[0]
off = load.src[0].src[1] off = load.src[0]
idx = off.get_idx() idx0, idx1 = off.src[1].get_idx(), off.src[2].get_idx()
self.assertEqual(idx.op, Ops.STACK)
self.assertEqual(len(idx.src), 2)
idx0, idx1 = idx.src[0], idx.src[1]
check_uop_against_string(self, idx0, sidx0) check_uop_against_string(self, idx0, sidx0)
check_uop_against_string(self, idx1, sidx1) check_uop_against_string(self, idx1, sidx1)
self.assertEqual(off.src[1].get_valid(), off.src[2].get_valid())
if svalid is not None: if svalid is not None:
check_uop_against_string(self, off.get_valid(), svalid) check_uop_against_string(self, off.src[1].get_valid(), svalid)
else: else:
self.assertEqual(off.get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True") self.assertEqual(off.src[1].get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
def test_idx_gt_c(self): def test_idx_gt_c(self):
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid # (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid

View file

@ -17,7 +17,7 @@ from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_f
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
from tinygrad.codegen.opt.postrange import apply_opts from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.late.gater import pm_move_gates_from_index from tinygrad.codegen.late.gater import pm_image_index, pm_move_gates_from_index
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
@ -77,6 +77,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize") if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
# convert image linear offsets to image coordinates before symbolic/index dtype cleanup
sink = graph_rewrite(sink, pm_image_index, name="image indexing")
# lower the index dtype to a concrete int # lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes") sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic") sink = graph_rewrite(sink, symbolic, name="post index symbolic")

View file

@ -38,19 +38,24 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
idx = uop_given_valid(valid, start_idx) idx = uop_given_valid(valid, start_idx)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True) return None if isinstance(buf.dtype, ImageDType) or idx is start_idx else buf.index(idx.valid(valid), ptr=True)
# wait for it to be image indexed before running simplification def simplify_valid_image_load(buf:UOp, start_x:UOp, start_y:UOp, valid:UOp) -> UOp|None:
if start_idx.dtype.count != 2: return None if not isinstance(buf.dtype, ImageDType) or start_x.dtype.scalar() is not dtypes.weakint or \
start_y.dtype.scalar() is not dtypes.weakint: return None
drop_stmt = _drop_valid_stmts(valid, idx, buf.dtype.shape[0], buf.dtype.shape[1]) x, y = uop_given_valid(valid, start_x), uop_given_valid(valid, start_y)
drop_stmt = _drop_valid_stmts(valid, UOp.vectorize(x, y), buf.dtype.shape[0], buf.dtype.shape[1])
if not drop_stmt and idx is start_idx: return None if not drop_stmt and x is start_x and y is start_y: return None
new_valid = UOp.uprod(*ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None new_valid = UOp.uprod(*ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True) return buf.index(x.valid(new_valid) if new_valid is not None else x, y.valid(new_valid) if new_valid is not None else y, ptr=True)
image_invalid_gate_x = UPat.var("cond").where(UPat.var("x"), UPat(Ops.CONST, arg=Invalid))
image_invalid_gate_y = UPat.var("cond").where(UPat.var("y"), UPat(Ops.CONST, arg=Invalid))
load_store_indexing = PatternMatcher([ load_store_indexing = PatternMatcher([
# image load valid idx simplification with scalar x/y coordinates
(UPat(Ops.INDEX, src=(UPat.var("buf"), image_invalid_gate_x, image_invalid_gate_y)),
lambda buf,x,y,cond: simplify_valid_image_load(buf, x, y, cond)),
# image load valid idx simplification # image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)), (UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
]) ])
@ -192,27 +197,9 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
if len(ret) <= 1: return None if len(ret) <= 1: return None
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret) return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
def get_image_idx(idx:UOp, width:int):
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
idx_x, idx_y = (x // 4) % width, x // (4*width)
return idx.replace(src=(idx.src[0], UOp.vectorize(idx_x, idx_y).valid(valid)))
def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index
if isinstance(dt:=ls.src[0].src[0].dtype, ImageDType) and ls.src[0].op is Ops.CAST:
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
return ls.replace(src=(get_image_idx(ls.src[0].src[0], dt.shape[1]),)+ls.src[1:])
# this is an unprocessed image without a cast, we should just make it a buffer
if isinstance(dt, ImageDType) and (off:=ls.src[0].src[1]).get_idx().dtype != dtypes.weakint.vec(2):
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(off)
return ls.replace(src=(idx,), dtype=new_dt).cast(dtypes.float) if ls.op is Ops.LOAD else ls.replace(src=(idx, ls.src[1].cast(new_dt)))
correct_load_store = PatternMatcher([ correct_load_store = PatternMatcher([
# split LOAD/STORE # split LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store), (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
# image indexing, including unfoldable images
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
]) ])
# *** uop expander *** # *** uop expander ***
@ -231,7 +218,7 @@ def no_vectorized_wmma(wmma:UOp):
def no_vectorized_alu(alu:UOp): def no_vectorized_alu(alu:UOp):
if alu.dtype.vcount == 1: return None if alu.dtype.vcount == 1: return None
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # image load/store has cond.where(idx.vec(2), Invalid) as the index if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # gated indexes use cond.where(idx, Invalid)
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount)) alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.STACK, alu.dtype, alus) return UOp(Ops.STACK, alu.dtype, alus)

View file

@ -1,6 +1,48 @@
# this is a temporary intermediate step while we remove this index style # this is a temporary intermediate step while we remove this index style
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.dtype import Invalid, dtypes from tinygrad.dtype import Invalid, dtypes, ImageDType
def move_image_load_gate(buf:UOp, gate:UOp, x:UOp, y:UOp, cast:UOp, l:UOp):
if not isinstance(buf.dtype, ImageDType): return None
return buf.index(x, y, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)
def move_image_store_gate(buf:UOp, gate:UOp, x:UOp, y:UOp, cast:UOp, data:UOp):
if not isinstance(buf.dtype, ImageDType): return None
return buf.index(x, y, ptr=True).cast(cast.dtype).store(data, gate)
def image_coords_to_int(idx:UOp, buf:UOp, x:UOp, y:UOp):
if not isinstance(buf.dtype, ImageDType) or (x.dtype != dtypes.long and y.dtype != dtypes.long): return None
return idx.replace(src=(buf, x.cast(dtypes.int) if x.dtype == dtypes.long else x, y.cast(dtypes.int) if y.dtype == dtypes.long else y))
def index_and_valid(idx:UOp) -> tuple[UOp, UOp]:
if idx.dtype.scalar() is dtypes.weakint: return idx.get_idx(), idx.get_valid()
if idx.op is Ops.WHERE and idx.src[2].arg is Invalid: return idx.src[1], idx.src[0]
return idx, UOp.const(dtypes.bool, idx.arg is not Invalid)
def valid_idx(idx:UOp, valid:UOp) -> UOp:
return idx if valid.op is Ops.CONST and valid.arg is True else valid.where(idx, idx.const_like(Invalid))
def get_image_idx(idx:UOp, height:int, width:int) -> UOp:
x, valid = index_and_valid(idx.src[1])
px = x // 4
idx_x, idx_y = (px, px.const_like(0)) if height == 1 else (px % width, px // width)
return idx.replace(src=(idx.src[0], valid_idx(idx_x, valid), valid_idx(idx_y, valid)))
def image_fixup(ls:UOp):
# normal image load/store from split_load_store: casted linear offset -> image x/y coordinates
if ls.src[0].op is Ops.CAST and (cast_idx:=ls.src[0].src[0]).op is Ops.INDEX and isinstance(dt:=cast_idx.src[0].dtype, ImageDType):
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
return ls.replace(src=(cast_idx if len(cast_idx.src) == 3 else get_image_idx(cast_idx, dt.shape[0], dt.shape[1]),)+ls.src[1:])
if ls.src[0].op is not Ops.INDEX or not isinstance(dt:=ls.src[0].src[0].dtype, ImageDType) or len(ls.src[0].src) == 3: return None
# this is an unprocessed image without a cast, we should just make it a buffer
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(ls.src[0].src[1])
return ls.replace(src=(idx,), dtype=new_dt).cast(dtypes.float) if ls.op is Ops.LOAD else ls.replace(src=(idx, ls.src[1].cast(new_dt)))
pm_image_index = PatternMatcher([
(UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])
pm_move_gates_from_index = PatternMatcher([ pm_move_gates_from_index = PatternMatcher([
# here we create the alt value for load to be 0s and remove the where Invalid # here we create the alt value for load to be 0s and remove the where Invalid
@ -8,6 +50,12 @@ pm_move_gates_from_index = PatternMatcher([
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)), lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")), (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)), lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("x"), UPat(arg=Invalid)),
UPat.var("gate").where(UPat.var("y"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
move_image_load_gate),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("x"), UPat(arg=Invalid)),
UPat.var("gate").where(UPat.var("y"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
move_image_store_gate),
# Where after gated load becomes alt value # Where after gated load becomes alt value
(UPat.var("gate").where(UPat().load(UPat(), UPat.var("gate"), name="l").or_casted(), UPat.var("a")), lambda gate,l,a: (UPat.var("gate").where(UPat().load(UPat(), UPat.var("gate"), name="l").or_casted(), UPat.var("a")), lambda gate,l,a:
@ -15,7 +63,8 @@ pm_move_gates_from_index = PatternMatcher([
(UPat.var("gate").where(UPat.var("a"), UPat().load(UPat(), ~UPat.var("gate", dtype=dtypes.bool), name="l").or_casted()), lambda gate,l,a: (UPat.var("gate").where(UPat.var("a"), UPat().load(UPat(), ~UPat.var("gate", dtype=dtypes.bool), name="l").or_casted()), lambda gate,l,a:
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)), l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
# vectorized indexes (ie. images) must be int # vectorized indexes must be int
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"), (UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"),
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:]))) lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:]))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.var("y")), name="idx"), image_coords_to_int),
]) ])

View file

@ -97,6 +97,14 @@ pm_manual_bf16_cast = PatternMatcher([
]) ])
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType))) def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
def image_coord(ctx, x:UOp, y:UOp) -> str: return f"(int2)({ctx[x]}, {ctx[y]})"
def render_image_load(ctx, buf:UOp, x:UOp, y:UOp, var:UOp|None=None, gate:UOp|None=None) -> str|None:
if not isinstance(buf.dtype, ImageDType): return None
load = f"read_imagef({ctx[buf]}, smp, {image_coord(ctx, x, y)})"
return f"({ctx[gate]}?{load}:{ctx[var]})" if gate is not None and var is not None else load
def render_image_store(ctx, buf:UOp, x:UOp, y:UOp, var:UOp) -> str|None:
if not isinstance(buf.dtype, ImageDType): return None
return f"write_imagef({ctx[buf]}, {image_coord(ctx, x, y)}, {ctx[var]});"
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes) # (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
def wmma_args(uops:list[UOp]): def wmma_args(uops:list[UOp]):
@ -301,13 +309,15 @@ class OpenCLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtypes.bfloat16, name="x"), (UPat(Ops.CONST, dtypes.bfloat16, name="x"),
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"), lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
# load/store image (OpenCL) # load/store image (OpenCL)
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))), (UPat(Ops.INDEX, src=(UPat.var('buf'), UPat.var('x'), UPat.var('y')), name="idx"),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"), lambda ctx,buf,x,y,idx: image_coord(ctx, x, y) if isinstance(buf.dtype, ImageDType) else None),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)), (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"), lambda ctx,buf,x,y,var,gate: render_image_load(ctx, buf, x, y, var, gate)),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')),)),
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True), lambda ctx,buf,x,y: render_image_load(ctx, buf, x, y)),
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"), (UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('x'), UPat.var('y')),
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda ctx,buf,x,y,var: render_image_store(ctx, buf, x, y, var)),
]) + base_rewrite ]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str: def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:

View file

@ -114,6 +114,11 @@ def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long)))) lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
return if_phi(b, gate, f, lambda: buf) if gate is not None else f() return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
def cast_global_index(x:UOp, buf:UOp, off:UOp):
if isinstance(buf.dtype, ImageDType) or not isinstance(buf.dtype, PtrDType) or buf.dtype.addrspace == AddrSpace.REG or \
off.op in (Ops.CAST, Ops.STACK): return None
return x.replace(src=(buf, off.cast(dtypes.long))+x.src[2:])
class NIRRenderer(Renderer): class NIRRenderer(Renderer):
suffix = "NIR" suffix = "NIR"
nir_options: bytes nir_options: bytes
@ -136,8 +141,7 @@ class NIRRenderer(Renderer):
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU # ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)), (UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
# load/store use pointer arithmetic, and the cast does nothing # load/store use pointer arithmetic, and the cast does nothing
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace( (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), cast_global_index),
src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None), (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
]) ])
@ -249,30 +253,39 @@ class LVPRenderer(NIRRenderer):
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)]) self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
# FIXME: this should be a rewrite rule # FIXME: this should be a rewrite rule
def tovec(b, coord): return nalu(b, "vec4", nchannel(b, coord, 0), nchannel(b, coord, 1), nundef(b, dtypes.int), nundef(b, dtypes.int)) def tovec(b, x, y): return nalu(b, "vec4", x, y, nundef(b, dtypes.int), nundef(b, dtypes.int))
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32 def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components, nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components,
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)}, intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
srcs=lambda b,img,coord,val:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])( srcs=lambda b,img,x,y,val:[nsrc(z) for z in [img, tovec(b, x, y), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
lambda b,img,coord,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store"))) lambda b,img,x,y,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)}, _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
nc=4, bs=32, num_components=4, srcs=lambda b,img,coord:[nsrc(x) for x in [img, tovec(b, coord), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])( nc=4, bs=32, num_components=4, srcs=lambda b,img,x,y:[nsrc(z) for z in [img, tovec(b, x, y), nundef(b, dtypes.int), nimm(b, 0, dtypes.int)]])(
lambda b,img,coord,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load"))) lambda b,img,x,y,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
def nstore_img_checked(ctx, img:UOp, x:UOp, y:UOp, val:UOp):
if not isinstance(img.dtype, ImageDType): return None
return nstore_img(ctx.b, ctx.r[img], ctx.r[x], ctx.r[y], ctx.r[val], val.dtype)
def nload_img_gated(ctx, img:UOp, x:UOp, y:UOp, alt:UOp, gate:UOp):
if not isinstance(img.dtype, ImageDType): return None
return if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, x, y), lambda: ctx.r[alt])
class IR3Renderer(NIRRenderer, OpenCLRenderer): class IR3Renderer(NIRRenderer, OpenCLRenderer):
has_aux = True has_aux = True
def nload_img(ctx,img,coord): def nload_img(ctx,img,x,y):
if not isinstance(img.dtype, ImageDType): return None
ctx.texs.add(img) ctx.texs.add(img)
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype) return _nload_img(ctx.b, ctx.r[img], ctx.r[x], ctx.r[y], img.dtype)
def_rewrite = PatternMatcher([ def_rewrite = PatternMatcher([
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("val")), allow_any_len=True), (UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')), UPat.var("val")), allow_any_len=True),
lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)), nstore_img_checked),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("alt"), UPat.var("gate"))), (UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')), UPat.var("alt"), UPat.var("gate"))),
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])), nload_img_gated),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img), (UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('x'), UPat.var('y')),)), nload_img),
]) + NIRRenderer.def_rewrite ]) + NIRRenderer.def_rewrite
_param = LVPRenderer.param _param = LVPRenderer.param

View file

@ -5,7 +5,7 @@
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
import pickle, base64, itertools, time, sys, functools import pickle, base64, itertools, time, sys, functools
from dataclasses import replace from dataclasses import replace
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar, Invalid
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, Target from tinygrad.helpers import all_same, getenv, flatten, get_single_element, Target
from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.opt import tc from tinygrad.codegen.opt import tc
@ -92,11 +92,13 @@ class PythonProgram:
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp] elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[i] = [arg] * warp_size elif uop is Ops.CONST: values[i] = [arg] * warp_size
elif uop is Ops.INDEX: elif uop is Ops.INDEX:
if len(src_values) != 2: raise RuntimeError("gates must be on LOAD/STORE, not INDEX") if len(src_values) != 2 and not isinstance(src_dtypes[0], ImageDType): raise RuntimeError("gates must be on LOAD/STORE, not INDEX")
ret:list = [] ret:list = []
if isinstance(src_dtypes[0], ImageDType): if isinstance(src_dtypes[0], ImageDType):
for m,ox,oy in zip(src_values[0], src_values[1][0], src_values[1][1]): xs, ys = (src_values[1][0], src_values[1][1]) if len(src_values) == 2 else (src_values[1], src_values[2])
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None)) for m,ox,oy in zip(src_values[0], xs, ys):
invalid = ox is Invalid or oy is Invalid
if invalid or ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4)) else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
else: else:
for m,o in zip(src_values[0], src_values[1]): ret.append((m,o)) for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))

View file

@ -1531,6 +1531,19 @@ def sint_to_uop(x:sint, dtype=dtypes.weakint) -> UOp: return UOp.const(dtype, x)
def to_max_shape(shape:tuple[sint, ...]) -> tuple[int, ...]: return tuple(int(x.vmax) if isinstance(x, UOp) else x for x in shape) def to_max_shape(shape:tuple[sint, ...]) -> tuple[int, ...]: return tuple(int(x.vmax) if isinstance(x, UOp) else x for x in shape)
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count) def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
def lower_index_casts(idx:UOp) -> UOp|None:
new_src, changed = [idx.src[0]], False
for s in idx.src[1:]:
ns = None
if s.op is Ops.CAST and s.dtype == dtypes.weakint and s.src[0].dtype.scalar() in dtypes.ints:
ns = s.src[0]
elif s.op is Ops.WHERE and s.dtype.scalar() is dtypes.weakint and s.src[2].op is Ops.CONST and s.src[2].arg is Invalid:
val = s.src[1]
if val.op is Ops.CAST and val.dtype == dtypes.weakint and val.src[0].dtype.scalar() in dtypes.ints:
ns = s.src[0].where(val.src[0], val.src[0].const_like(Invalid))
new_src.append(ns if ns is not None else s)
changed = changed or ns is not None
return idx.replace(src=tuple(new_src)) if changed else None
pm_lower_index_dtype = PatternMatcher([ pm_lower_index_dtype = PatternMatcher([
# There are no Unary ops at this point in symbolic, those are introduced later # There are no Unary ops at this point in symbolic, those are introduced later
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y: (UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y:
@ -1552,6 +1565,7 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)), lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
(UPat(Ops.INDEX, name="idx"), lower_index_casts),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"), (UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))), lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
]) ])

View file

@ -24,6 +24,10 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
from tinygrad.uop.validate import validate_index_with_z3 from tinygrad.uop.validate import validate_index_with_z3
return validate_index_with_z3(sz, idx, gate) return validate_index_with_z3(sz, idx, gate)
def validate_image_index(buf:UOp, idx0:UOp, idx1:UOp, gate:UOp|None=None):
if not isinstance(buf.dtype, ImageDType): return None
return validate_index(buf, idx0, gate) and validate_index(buf, idx1, gate)
# four specs: # four specs:
# shared_spec -- usable anywhere # shared_spec -- usable anywhere
# tensor_spec -- usable in tensor graph # tensor_spec -- usable in tensor graph
@ -180,6 +184,13 @@ shared_codegen_spec = PatternMatcher([
lambda buf,idx,gate,alt,load: validate_index(buf, idx, gate) if alt.dtype == load.dtype else False), lambda buf,idx,gate,alt,load: validate_index(buf, idx, gate) if alt.dtype == load.dtype else False),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat()), validate_index), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat()), validate_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).or_casted().store(UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().load(), validate_image_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().load(
UPat.var("alt"), UPat.var("gate", dtype=dtypes.bool), name="load"),
lambda buf,idx0,idx1,gate,alt,load: validate_image_index(buf, idx0, idx1, gate) if alt.dtype == load.dtype else False),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().store(UPat()), validate_image_index),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx0"), UPat.var("idx1"))).or_casted().store(
UPat(), UPat.var("gate", dtype=dtypes.bool)), validate_image_index),
# CUSTOM (inline and non inline) # CUSTOM (inline and non inline)
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True), (UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
@ -189,6 +200,8 @@ shared_codegen_spec = PatternMatcher([
# INDEX is just address calculation. OOB validation is on LOAD/STORE where the gate is available. # INDEX is just address calculation. OOB validation is on LOAD/STORE where the gate is available.
(UPat(GroupOp.Defines|{Ops.AFTER}).index(UPat()), lambda: True), (UPat(GroupOp.Defines|{Ops.AFTER}).index(UPat()), lambda: True),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat(), UPat())),
lambda buf: True if isinstance(buf.dtype, ImageDType) else None),
# SPECIAL # SPECIAL
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)), (UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),

View file

@ -447,8 +447,13 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
lambda index, gate, alt: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt)), lambda index, gate, alt: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt)),
# fold gated LOAD/STORE # fold gated LOAD/STORE
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())), lambda: UOp(Ops.NOOP)), (UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())), lambda: UOp(Ops.NOOP)),
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid), UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())),
lambda: UOp(Ops.NOOP)),
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"), (UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid),
UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)),
(UPat(Ops.STORE, src=(UPat(), invalid_pat)), lambda i: UOp(Ops.NOOP)), (UPat(Ops.STORE, src=(UPat(), invalid_pat)), lambda i: UOp(Ops.NOOP)),
# store of where with invalid -> gated store # store of where with invalid -> gated store
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="index"), UPat.var("cond").where(UPat.var("val"), invalid_pat))), (UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="index"), UPat.var("cond").where(UPat.var("val"), invalid_pat))),