Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
ae6c2d2b96 remove dtype.vec(2) from image always 2026-05-14 19:34:04 +00:00
5 changed files with 70 additions and 36 deletions

View file

@ -20,8 +20,10 @@ def get_gated_load_uop(valid:UOp, idx:UOp):
))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
# idx is (idx_x, idx_y) — INDEX uses (buf, y, x) order
idx_x, idx_y = idx
return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.STACK, dtypes.weakint.vec(2), idx).valid(valid), ptr=True),
UOp(Ops.PARAM, dtypes.imagef(image_shape), arg=0).index(idx_y.valid(valid), idx_x.valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
))
@ -222,17 +224,18 @@ class TestValidIdxSimplification(unittest.TestCase):
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
load = simplify_image_idx(load.sink()).src[0]
off = load.src[0].src[1]
idx = off.get_idx()
self.assertEqual(idx.op, Ops.STACK)
self.assertEqual(len(idx.src), 2)
idx0, idx1 = idx.src[0], idx.src[1]
index = load.src[0]
# INDEX has 3 srcs: (buf, y_with_valid, x_with_valid)
self.assertEqual(len(index.src), 3, f"expected 3-src image INDEX, got {len(index.src)}")
off_y, off_x = index.src[1], index.src[2]
idx0 = off_x.get_idx() # sidx0 is x coordinate
idx1 = off_y.get_idx() # sidx1 is y coordinate
check_uop_against_string(self, idx0, sidx0)
check_uop_against_string(self, idx1, sidx1)
if svalid is not None:
check_uop_against_string(self, off.get_valid(), svalid)
check_uop_against_string(self, off_y.get_valid(), svalid)
else:
self.assertEqual(off.get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
self.assertEqual(off_y.get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
def test_idx_gt_c(self):
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid

View file

@ -4,14 +4,14 @@ from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate, invalid_pat
from tinygrad.helpers import getenv, flatten, prod
from tinygrad.renderer import Renderer
# ***** image load valid simplification *****
@functools.cache
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
def _drop_valid_stmts(valid:UOp, idx_x:UOp, idx_y:UOp, height:int, width:int) -> list[UOp]:
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in valid.split_uop(Ops.AND):
@ -20,15 +20,16 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx)
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
test_x = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx_x)
test_y = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx_y)
if test_x.vmax < 0 or test_y.vmax < 0:
drop_stmt.append(stmt)
continue
# if X <= c, check if it's out of bound when X = c+1
# if X >= c, check if it's out of bound when X = c-1
test_value = c + 1 if is_upper_bound else c - 1
for i,b in zip(idx.src, (width, height)):
for i,b in zip((idx_x, idx_y), (width, height)):
if i.is_increasing():
rw = i.substitute({X:X.const_like(test_value)})
if rw.vmin >= b or rw.vmax < 0:
@ -39,19 +40,27 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
idx = uop_given_valid(valid, start_idx)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)
return None # image simplification is handled by simplify_valid_image_load
# wait for it to be image indexed before running simplification
if start_idx.dtype.count != 2: return None
def simplify_valid_image_load(buf:UOp, start_y:UOp, start_x:UOp, valid:UOp) -> UOp|None:
idx_y = uop_given_valid(valid, start_y)
idx_x = uop_given_valid(valid, start_x)
drop_stmt = _drop_valid_stmts(valid, idx, buf.dtype.shape[0], buf.dtype.shape[1])
drop_stmt = _drop_valid_stmts(valid, idx_x, idx_y, buf.dtype.shape[0], buf.dtype.shape[1])
if not drop_stmt and idx is start_idx: return None
if not drop_stmt and idx_y is start_y and idx_x is start_x: return None
new_valid = UOp.uprod(*ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx.valid(new_valid) if new_valid is not None else idx, ptr=True)
if new_valid is not None:
return buf.index(idx_y.valid(new_valid), idx_x.valid(new_valid), ptr=True)
return buf.index(idx_y, idx_x, ptr=True)
load_store_indexing = PatternMatcher([
# image load valid idx simplification
# image load valid idx simplification (2D image index: buf, y, x each with valid gates)
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("cond_y").where(UPat.var("y"), invalid_pat),
UPat.var("cond_x").where(UPat.var("x"), invalid_pat))),
lambda buf,cond_y,y,cond_x,x,i: simplify_valid_image_load(buf, y, x, cond_y if cond_y is cond_x else cond_y & cond_x)),
# non-image valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
])
@ -64,11 +73,12 @@ def expand_index(ctx, buf:UOp, vec:UOp):
# search for dims that drop the most valid statements
best_drop, cands = -1, []
for ch, cw in ImageDType.valid_dims(dt, ctx.target.arch):
if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop:
best_drop, cands = dropped, [(ch, cw, cidx)]
elif dropped == best_drop: cands.append((ch, cw, cidx))
cidx_x, cidx_y = uop_given_valid(valid, (x//4)%cw), uop_given_valid(valid, x//(4*cw))
if (dropped:=len(_drop_valid_stmts(valid, cidx_x, cidx_y, ch, cw))) > best_drop:
best_drop, cands = dropped, [(ch, cw, cidx_x, cidx_y)]
elif dropped == best_drop: cands.append((ch, cw, cidx_x, cidx_y))
# and tiebreak with indexing complexity (ie. number of nodes)
h, w, _ = cands[0] if len(cands) == 1 else min(cands, key=lambda cand: len(cand[2].gep(1).simplify().backward_slice))
h, w, _, _ = cands[0] if len(cands) == 1 else min(cands, key=lambda cand: len(cand[3].simplify().backward_slice))
buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4)))
if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
# generate the individual indexes
@ -195,7 +205,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
def get_image_idx(idx:UOp, width:int):
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
idx_x, idx_y = (x // 4) % width, x // (4*width)
return idx.replace(src=(idx.src[0], UOp.vectorize(idx_x, idx_y).valid(valid)))
return idx.replace(src=(idx.src[0], idx_y.valid(valid), idx_x.valid(valid)))
def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index
@ -204,7 +214,8 @@ def image_fixup(ls:UOp):
return ls.replace(src=(get_image_idx(ls.src[0].src[0], dt.shape[1]),)+ls.src[1:])
# this is an unprocessed image without a cast, we should just make it a buffer
if isinstance(dt, ImageDType) and (off:=ls.src[0].src[1]).get_idx().dtype != dtypes.weakint.vec(2):
if isinstance(dt, ImageDType) and len(ls.src[0].src) != 3:
off = ls.src[0].src[1]
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(off)
return ls.replace(src=(idx,), dtype=new_dt).cast(dtypes.float) if ls.op is Ops.LOAD else ls.replace(src=(idx, ls.src[1].cast(new_dt)))
@ -231,7 +242,6 @@ def no_vectorized_wmma(wmma:UOp):
def no_vectorized_alu(alu:UOp):
if alu.dtype.vcount == 1: return None
if alu.op is Ops.WHERE and alu.src[2].arg is Invalid: return None # image load/store has cond.where(idx.vec(2), Invalid) as the index
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.STACK, alu.dtype, alus)

View file

@ -2,20 +2,32 @@
from tinygrad.uop.ops import PatternMatcher, UPat, Ops
from tinygrad.dtype import Invalid, dtypes
_invalid = UPat(arg=Invalid)
pm_move_gates_from_index = PatternMatcher([
# here we create the alt value for load to be 0s and remove the where Invalid
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
# here we create the alt value for load to be 0s and remove the where Invalid (non-image: 2 srcs)
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), _invalid)).or_casted(name="cast").load(name="l"),
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), _invalid)).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
# image: 3 srcs with matching gates on y and x
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("y"), _invalid),
UPat.var("gate").where(UPat.var("x"), _invalid)).or_casted(name="cast").load(name="l"),
lambda buf,gate,y,x,cast,l: buf.index(y, x, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("y"), _invalid),
UPat.var("gate").where(UPat.var("x"), _invalid)).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate,y,x,cast,data: buf.index(y, x, ptr=True).cast(cast.dtype).store(data, gate)),
# image: 3 srcs with mismatched gates on y and x — combine with AND
(UPat.var("buf").index(UPat.var("gate_y").where(UPat.var("y"), _invalid),
UPat.var("gate_x").where(UPat.var("x"), _invalid)).or_casted(name="cast").load(name="l"),
lambda buf,gate_y,y,gate_x,x,cast,l: buf.index(y, x, ptr=True).cast(cast.dtype).load(l.const_like(0), gate_y&gate_x, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate_y").where(UPat.var("y"), _invalid),
UPat.var("gate_x").where(UPat.var("x"), _invalid)).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate_y,y,gate_x,x,cast,data: buf.index(y, x, ptr=True).cast(cast.dtype).store(data, gate_y&gate_x)),
# Where after gated load becomes alt value
(UPat.var("gate").where(UPat().load(UPat(), UPat.var("gate", dtype=dtypes.bool), name="l").or_casted(), UPat.var("a")), lambda gate,l,a:
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
(UPat.var("gate").where(UPat.var("a"), UPat().load(UPat(), ~UPat.var("gate", dtype=dtypes.bool), name="l").or_casted()), lambda gate,l,a:
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
# images use 2D INDEX now (y,x)
(UPat(Ops.INDEX, src=(UPat(), UPat((Ops.CONST, Ops.VCONST, Ops.STACK), name="vec")), name="idx"),
lambda idx,vec: idx.replace(src=(idx.src[0], vec.gep(1).cast(dtypes.int), vec.gep(0).cast(dtypes.int))) if vec.dtype.count == 2 else None),
])

View file

@ -1559,6 +1559,13 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("gate").where(UPat.var("idx", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx,gate: buf.index(gate.where(idx, idx.const_like(Invalid)), ptr=True)),
# remove hanging casts for 3-src image INDEX (y, x)
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("y", dtypes.ints).cast(), UPat.var("x", dtypes.ints).cast())),
lambda buf,y,x: buf.index(y, x, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"),
UPat.var("gate_y").where(UPat.var("y", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)),
UPat.var("gate_x").where(UPat.var("x", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,gate_y,y,gate_x,x: buf.index(gate_y.where(y, y.const_like(Invalid)), gate_x.where(x, x.const_like(Invalid)), ptr=True)),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
])

View file

@ -445,9 +445,11 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
UPat.load(UPat(Ops.INDEX, name="index")))),
lambda index, gate, alt: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt)),
# fold gated LOAD/STORE
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())), lambda: UOp(Ops.NOOP)),
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
# fold gated LOAD/STORE: any INDEX src being Invalid makes the whole access invalid
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="idx"),
lambda idx: idx.const_like(Invalid) if any(s.op is Ops.CONST and s.arg is Invalid for s in idx.src[1:]) else None),
(UPat(Ops.STORE, src=(UPat(Ops.CONST, arg=Invalid).or_casted(), UPat())), lambda: UOp(Ops.NOOP)),
(UPat(Ops.LOAD, src=(UPat(Ops.CONST, arg=Invalid).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
(UPat(Ops.STORE, src=(UPat(), invalid_pat)), lambda i: UOp(Ops.NOOP)),
# store of where with invalid -> gated store