Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
c0c120bf58 cleanups 2026-05-01 23:52:45 +00:00
George Hotz
9596d13550 move gate to load/store 2026-05-01 23:37:31 +00:00
15 changed files with 185 additions and 128 deletions

View file

@ -375,22 +375,24 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32,
"""Conditional memory store with sub-word support. Returns list of store UOps."""
adt = dtypes.uint64 if addr_bits == 64 else dtypes.uint32
word_addr = addr >> UOp.const(adt, 2)
idx = mem.index(word_addr.cast(dtypes.int), active)
if data_bits == 32: return [idx.store(active.where(_to_u32(val), idx))]
bidx = mem.index(word_addr.cast(dtypes.int), ptr=True)
if data_bits == 32: return [UOp(Ops.STORE, dtypes.void, (bidx, _to_u32(val), active))]
# Sub-word store: read-modify-write with mask
cur = bidx.load(active, _c(0, dtypes.uint32))
byte_pos = addr.cast(dtypes.uint32) & _c(3)
byte_shift = byte_pos * _c(8)
val_u32, size_mask = val.cast(dtypes.uint32), _c(0xFF if data_bits == 8 else 0xFFFF)
mask = size_mask << byte_shift
new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val_u32 & size_mask) << byte_shift)
if data_bits == 8: return [idx.store(active.where(new_word, idx))]
new_word = (cur & (mask ^ _c(0xFFFFFFFF))) | ((val_u32 & size_mask) << byte_shift)
if data_bits == 8: return [UOp(Ops.STORE, dtypes.void, (bidx, new_word, active))]
# 16-bit cross-word case: byte_pos == 3 means value spans two words
is_cross = byte_pos.eq(_c(3))
cross_word0 = (idx & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))
store0 = idx.store(active.where(is_cross.where(cross_word0, new_word), idx))
next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int), active & is_cross)
cross_word1 = (next_idx & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))
return [store0, next_idx.store((active & is_cross).where(cross_word1, next_idx))]
cross_word0 = (cur & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))
store0 = UOp(Ops.STORE, dtypes.void, (bidx, is_cross.where(cross_word0, new_word), active))
next_bidx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int), ptr=True)
next_cur = next_bidx.load(active & is_cross, _c(0, dtypes.uint32))
cross_word1 = (next_cur & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))
return [store0, UOp(Ops.STORE, dtypes.void, (next_bidx, cross_word1, active & is_cross))]
def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int = 32) -> list[UOp]:
"""Store to byte-addressable memory (scratch). addr is byte offset, mem is uint8 buffer."""
@ -398,7 +400,8 @@ def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int
val_u32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
for i in range(data_bits // 8):
byte_val = (val_u32 >> UOp.const(dtypes.uint32, i * 8)) & UOp.const(dtypes.uint32, 0xFF)
stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int), active).store(byte_val.cast(dtypes.uint8)))
bidx = mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int), ptr=True)
stores.append(UOp(Ops.STORE, dtypes.void, (bidx, byte_val.cast(dtypes.uint8), active)))
return stores
def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode_vars: dict | None = None, op_name: str = "") -> dict[int, UOp]:
@ -516,14 +519,15 @@ class _Ctx:
# Dynamic register access (takes UOp index instead of int)
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
"""Read SGPR with dynamic register index."""
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load()
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load(valid, _c(0, dtypes.uint32))
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
"""Write SGPR with dynamic register index. On RDNA, index 124 = NULL (writes discarded). On CDNA, index 124 = M0 (read/write)."""
# RDNA: NULL (124) discards writes. CDNA: M0 (124) is writable.
valid = None if self.wave_size == 64 else reg.ne(_c(124))
return self.sgpr.index(reg.cast(dtypes.int), valid).store(val.cast(dtypes.uint32))
bidx = self.sgpr.index(reg.cast(dtypes.int), ptr=True)
return UOp(Ops.STORE, dtypes.void, (bidx, val.cast(dtypes.uint32))+((valid,) if valid is not None else ()))
def wmask(self, reg: UOp, val: UOp) -> list[UOp]:
"""Write a lane mask (VCC/EXEC). Splits into lo/hi for wave64."""
@ -540,24 +544,26 @@ class _Ctx:
def rvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
"""Read VGPR with dynamic register index."""
idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
return self.vgpr.index(idx, valid, ptr=True).load() if valid is not None else self.vgpr.index(idx, ptr=True).load()
if valid is not None: return self.vgpr.index(idx, ptr=True).load(valid, _c(0, dtypes.uint32))
return self.vgpr.index(idx, ptr=True).load()
def wvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
"""Write VGPR with dynamic register index."""
buf = self.vgpr.after(after) if after is not None else self.vgpr
offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
return UOp(Ops.STORE, dtypes.void, (buf.index(offset, ptr=True), val.cast(dtypes.uint32), _lane_active(exec_mask, lane)))
def raccvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
"""Read ACCVGPR with dynamic register index (CDNA only)."""
idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
return self.accvgpr.index(idx, valid, ptr=True).load() if valid is not None else self.accvgpr.index(idx, ptr=True).load()
if valid is not None: return self.accvgpr.index(idx, ptr=True).load(valid, _c(0, dtypes.uint32))
return self.accvgpr.index(idx, ptr=True).load()
def waccvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
"""Write ACCVGPR with dynamic register index (CDNA only)."""
buf = self.accvgpr.after(after) if after is not None else self.accvgpr
offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
return UOp(Ops.STORE, dtypes.void, (buf.index(offset, ptr=True), val.cast(dtypes.uint32), _lane_active(exec_mask, lane)))
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp:
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
@ -713,7 +719,7 @@ class _Ctx:
old = self.vgpr.index(val[0].cast(dtypes.int), ptr=True).load()
new_val = _set_bits(old, _val_to_bits(val[1]), width, lo_bit).cast(dtypes.uint32)
active = _lane_active(exec_mask, lane)
raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int), active).store(new_val)))
raw_stores.append(('vgpr_direct', UOp(Ops.STORE, dtypes.void, (self.vgpr.index(val[0].cast(dtypes.int), ptr=True), new_val, active))))
continue
if 'D0' in dest and '[laneId]' in dest:
old_vcc = self.rmask(_c(VCC_LO.offset))
@ -1847,15 +1853,16 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
if data_bits < 32:
# Sub-dword LDS write: read-modify-write within the uint32 slot
word_addr = (addr >> addr_shift).cast(dtypes.int)
idx = mem.index(word_addr, active)
bidx = mem.index(word_addr, ptr=True)
cur = bidx.load(active, _c(0, dtypes.uint32))
byte_pos = addr.cast(dtypes.uint32) & _c(3)
byte_shift = byte_pos * _c(8)
size_mask = _c(0xFF if data_bits == 8 else 0xFFFF)
mask = size_mask << byte_shift
new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & size_mask) << byte_shift)
return idx.store(active.where(new_word, idx))
idx = mem.index((addr >> addr_shift).cast(dtypes.int))
return idx.store(active.where(val, idx.load()))
new_word = (cur & (mask ^ _c(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & size_mask) << byte_shift)
return UOp(Ops.STORE, dtypes.void, (bidx, new_word, active))
bidx = mem.index((addr >> addr_shift).cast(dtypes.int), ptr=True)
return UOp(Ops.STORE, dtypes.void, (bidx, val, active))
def make_srcs(lane: UOp) -> dict:
addr = make_addr(lane)
@ -2005,17 +2012,18 @@ def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp:
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), ptr=True).load(), _c(0))
lds_idx = ((lds_addr + _c(i * 4)) >> _c(2)).cast(dtypes.int)
stores.append(ctx.lds.index(lds_idx, active).store(active.where(val, ctx.lds.index(lds_idx, active))))
bidx = ctx.lds.index(lds_idx, ptr=True)
stores.append(UOp(Ops.STORE, dtypes.void, (bidx, val, active)))
elif is_store:
for i in range(n_dwords):
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
idx = mem.index(word_addr.cast(dtypes.int64), in_bounds)
idx = mem.index(word_addr.cast(dtypes.int64), ptr=True)
val = (ctx.raccvgpr_dyn if use_acc else ctx.rvgpr_dyn)(vdata + _c(i), lane)
stores.append(idx.store(in_bounds.where(_to_u32(val), idx)))
stores.append(UOp(Ops.STORE, dtypes.void, (idx, _to_u32(val), in_bounds)))
else:
for i in range(n_dwords):
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), in_bounds, ptr=True).load(), _c(0))
val = mem.index(word_addr.cast(dtypes.int64), ptr=True).load(in_bounds, _c(0, dtypes.uint32))
stores.append((ctx.waccvgpr_dyn if use_acc else ctx.wvgpr_dyn)(vdata + _c(i), lane, val, exec_mask))
return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())

View file

@ -828,28 +828,30 @@ class Parser:
assert mem is not None, "memory load requires _vmem or _lds"
adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
active = self.vars.get('_active')
gate = (active,) if active is not None else ()
# gate now lives on LOAD; helper to construct gated load with 0 alt
def _gload(bidx, dtype):
return bidx.load(active, _const(dtype.base, 0)) if active is not None else bidx.load()
byte_mem = mem.dtype.base == dtypes.uint8
if byte_mem:
idx = addr.cast(dtypes.int)
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
val = _u32(0).cast(dtypes.uint64)
for i in range(8): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint64) << _u64(i * 8))
for i in range(8): val = val | (_gload(mem.index(idx + _const(dtypes.int, i), ptr=True), mem.dtype).cast(dtypes.uint64) << _u64(i * 8))
elif dt in (dtypes.uint8, dtypes.int8):
val = mem.index(idx, *gate, ptr=True).load().cast(dt)
val = _gload(mem.index(idx, ptr=True), mem.dtype).cast(dt)
elif dt in (dtypes.uint16, dtypes.int16, dtypes.short):
lo = mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32)
hi = mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32)
lo = _gload(mem.index(idx, ptr=True), mem.dtype).cast(dtypes.uint32)
hi = _gload(mem.index(idx + _const(dtypes.int, 1), ptr=True), mem.dtype).cast(dtypes.uint32)
val = (lo | (hi << _u32(8))).cast(dt)
else:
val = _u32(0)
for i in range(4): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
for i in range(4): val = val | (_gload(mem.index(idx + _const(dtypes.int, i), ptr=True), mem.dtype).cast(dtypes.uint32) << _u32(i * 8))
else:
idx = (addr >> _const(addr.dtype, 2)).cast(dtypes.int)
val = mem.index(idx, *gate)
val = _gload(mem.index(idx, ptr=True), mem.dtype)
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)
val = val.cast(dtypes.uint64) | (mem.index(idx2, *gate).cast(dtypes.uint64) << _u64(32))
val = val.cast(dtypes.uint64) | (_gload(mem.index(idx2, ptr=True), mem.dtype).cast(dtypes.uint64) << _u64(32))
elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
elif dt in (dtypes.uint16, dtypes.int16):
val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
@ -862,7 +864,7 @@ class Parser:
idx_native = (addr >> _const(adt, 2)).cast(dtypes.int64)
idx_hi_native = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int64)
safe_idx_hi = is_unaligned.where(idx_hi_native, idx_native)
hi = mem.index(safe_idx_hi, *gate)
hi = _gload(mem.index(safe_idx_hi, ptr=True), mem.dtype)
combined = val.cast(dtypes.uint64) | (hi.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
val = is_unaligned.where((combined >> (byte_off.cast(dtypes.uint64) * UOp.const(dtypes.uint64, 8))).cast(dtypes.uint32), val)
return _cast_to(val, dt)

View file

@ -49,7 +49,7 @@ class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid, extra=()):
with Context(NOOPT=1, SPEC=0):
load = full_rewrite_to_sink(UOp.sink(load, *extra)).src[0]
idx, valid = load.src[0].src[1], load.src[0].src[2]
idx, valid = load.src[0].src[1], load.src[1]
check_uop_against_string(self, idx, sidx)
check_uop_against_string(self, valid, svalid)
@ -225,9 +225,11 @@ class TestImageSimplification(unittest.TestCase):
check_uop_against_string(self, idx0, sidx0)
check_uop_against_string(self, idx1, sidx1)
if svalid is not None:
check_uop_against_string(self, load.src[0].src[2], svalid)
check_uop_against_string(self, load.src[1], svalid)
else:
self.assertEqual(len(load.src[0].src), 2, "svalid is None but load still has a valid")
# gate is at LOAD.src[1] when present; if simplified away, src[1] should not be bool
self.assertFalse(len(load.src) >= 2 and load.src[1].dtype.scalar() == dtypes.bool,
"svalid is None but load still has a valid")
def test_idx_gt_c(self):
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
@ -512,18 +514,34 @@ class TestUnfoldableImage(unittest.TestCase):
self.assertEqual(res.src[0].src[0].dtype, dtypes.float.ptr(400))
class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True gate gets simplified to drop the gate
def test_drop_true_gate_on_load(self):
# test that LOAD with a constant True gate gets simplified to drop the gate
from tinygrad.codegen.late.devectorizer import load_store_indexing
from tinygrad.uop.ops import graph_rewrite
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
idx = UOp.const(dtypes.weakint, 0)
true_gate = UOp.const(dtypes.bool, True)
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate))
bidx = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx))
load = UOp(Ops.LOAD, dtypes.int, (bidx, true_gate))
# apply the optimization
result = graph_rewrite(index_with_gate, load_store_indexing)
# the True gate should be dropped (INDEX should only have 2 sources)
self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX")
result = graph_rewrite(load, load_store_indexing)
# the True gate should be dropped (LOAD should only have 1 source)
self.assertEqual(len(result.src), 1, "True gate should be dropped from LOAD")
def test_drop_true_gate_on_store(self):
# test that STORE with a constant True gate gets simplified to drop the gate
from tinygrad.codegen.late.devectorizer import load_store_indexing
from tinygrad.uop.ops import graph_rewrite
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
idx = UOp.const(dtypes.weakint, 0)
val = UOp.const(dtypes.int, 42)
true_gate = UOp.const(dtypes.bool, True)
bidx = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx))
store = UOp(Ops.STORE, dtypes.void, (bidx, val, true_gate))
# apply the optimization
result = graph_rewrite(store, load_store_indexing)
# the True gate should be dropped (STORE should only have 2 sources)
self.assertEqual(len(result.src), 2, "True gate should be dropped from STORE")
class TestRangeShrink(unittest.TestCase):
def get_ranges(self, sink):

View file

@ -428,7 +428,8 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([w, red])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
# alt is at src[2] in new gated LOAD shape (idx, gate, alt)
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg==5
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
@ -438,7 +439,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([w])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD: assert u.src[1].arg==5
if u.op is Ops.LOAD: assert u.src[2].arg==5
def test_where_on_gated_load_with_cast(self):
ridx0 = UOp.range(100, 0)
@ -451,7 +452,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([w, red])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg == 5
def test_where_on_casted_gated_load_extra_cond(self):
ridx0 = UOp.range(100, 0)

View file

@ -108,9 +108,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, b
pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
lambda u, gate: (u.replace(src=u.src[:2]), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u.replace(src=u.src[:2]), UOp(Ops.ENDIF, src=(mif,))]))
])
# requires lst be toposorted. like graph rewrite, but for lines

View file

@ -53,10 +53,15 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
load_store_indexing = PatternMatcher([
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
# simplify away long after index has been lowered
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
# drop true gate
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x, ptr=True)),
# drop true gate from gated LOAD with alt: also drop the now-unused alt
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.const(dtypes.bool, True), UPat()), allow_any_len=True, name="ld"),
lambda ld,idx: ld.replace(src=(idx,)+ld.src[3:])),
# drop true gate from gated LOAD without alt
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.const(dtypes.bool, True)), allow_any_len=True, name="ld"),
lambda ld,idx: ld.replace(src=(idx,)+ld.src[2:])),
# drop true gate from STORE
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val"), UPat.const(dtypes.bool, True))),
lambda idx,val: idx.store(val)),
])
# ***** load/store grouping *****
@ -281,18 +286,18 @@ pm_render = PatternMatcher([
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
# give any loads that are masked an alt value
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
# give any gated loads (gate at src[1]) an alt value at src[2]
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat(dtype=dtypes.bool)), allow_any_len=True, name="x"),
lambda x: x.replace(src=(x.src[0], x.src[1], x.const_like(0))+x.src[2:])
if len(x.src) == 2 or x.src[2].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
# Where after gated load becomes alt value
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c", dtype=dtypes.bool).logical_not()).or_casted(),),
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat.var("c")), allow_any_len=True, name="l").or_casted(),
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], l.src[1], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
l.src[3:]).cast(a.dtype)),
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat.var("c", dtype=dtypes.bool).logical_not()),
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], l.src[1], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
else a.cast(l.dtype))+l.src[3:]).cast(a.dtype)),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

View file

@ -3,7 +3,7 @@ from typing import Callable, cast
from dataclasses import dataclass
from tinygrad.helpers import prod, Target
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.dtype import AddrSpace, PtrDType, dtypes
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.device import Compiler
@ -31,8 +31,11 @@ class Estimates:
if u.op in {Ops.LOAD, Ops.STORE}:
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
# TODO: is this correct? this all needs to be cleaned up
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
# gate (bool-typed src) is part of indexing/predication, exclude its computation
# LOAD: gate at src[1]; STORE: gate at src[2]
gate_pos = 1 if u.op is Ops.LOAD else 2
if len(u.src) > gate_pos and u.src[gate_pos].dtype.scalar() == dtypes.bool:
dont_count = dont_count.union(u.src[gate_pos].toposort(range_gate))
elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].toposort())
for u in uops:

View file

@ -44,9 +44,9 @@ base_rewrite = PatternMatcher([
# default const render
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var"))),
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("gate", dtype=dtypes.bool), UPat.var("var"))),
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
@ -302,11 +302,12 @@ class OpenCLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
# load/store image (OpenCL)
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("gate", dtype=dtypes.bool),
UPat.var("var"))),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), allow_any_len=True),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),
UPat.var("var", dtypes.float.vec(4)))),
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
]) + base_rewrite

View file

@ -76,7 +76,7 @@ base_rewrite = PatternMatcher([
# memory load/store
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("mask"))).or_casted("idx"), UPat.var("alt")), allow_any_len=True, name="x"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted("idx"), UPat.var("mask", dtype=dtypes.bool), UPat.var("alt")), allow_any_len=True, name="x"),
lambda ctx,x,idx,alt,mask:
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"

View file

@ -125,19 +125,20 @@ class NIRRenderer(Renderer):
(UPat.cvar("x", dtypes.uints), lambda x: UOp.const(x.dtype, x.dtype.max+x.arg+1) if x.arg < 0 else None),
# from ptx
(UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
# load/store bool -> uint8
# load/store bool -> uint8 (alt at src[2] in new gated shape; preserve gate at src[1])
(UPat(Ops.LOAD, dtypes.bool, name="x"),
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x"),
lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint8)))),
lambda x: x.replace(dtype=dtypes.uint8, src=tuple(s.cast(dtypes.uint8) if i == 2 and s.dtype.scalar() == dtypes.bool else s
for i,s in enumerate(x.src))).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
# NIR requires shift amount to be 32 bit: https://docs.mesa3d.org/nir/alu.html#nir-alu-op-ishl
(UPat((Ops.SHL, Ops.SHR), name="x"), lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint))) if x.src[1].dtype.bitsize != 32 else None),
# OpConvertFToU is undefined if Result Type is not wide enough, cast through int32
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
# load/store use pointer arithmetic, and the cast does nothing
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace(
src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
])
@ -146,9 +147,10 @@ class NIRRenderer(Renderer):
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))), UPat.var("gate", dtype=dtypes.bool), UPat.var("alt")),
allow_any_len=True, name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), allow_any_len=True, name="x"),
@ -268,9 +270,9 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)
def_rewrite = PatternMatcher([
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), allow_any_len=True), UPat.var("val"))),
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("val"))),
lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("alt"))),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("gate", dtype=dtypes.bool), UPat.var("alt"))),
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img),
]) + NIRRenderer.def_rewrite

View file

@ -45,14 +45,15 @@ ptx_matcher = PatternMatcher([
# upcast to float32 all the ops that don't support half
(UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
# load/store bool -> uint8
# load/store bool -> uint8 (alt at src[2] in new gated shape; preserve gate at src[1])
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x"),
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8)))),
lambda x: UOp(x.op, dtypes.uint8, tuple(s.cast(dtypes.uint8) if i == 2 and s.dtype.scalar() == dtypes.bool else s
for i,s in enumerate(x.src))).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
# indexing on PTX is in uint64, we do the math while it's still in the graph
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op", allow_any_len=True), lambda buf,idx,op:
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)) \
if op.dtype != dtypes.int64 and buf.dtype.addrspace != AddrSpace.REG else None),
# load/store use pointer arithmetic, and the cast does nothing
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
@ -102,18 +103,18 @@ string_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
# store / gated load / load
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc")), allow_any_len=True), UPat.var("var"))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))), UPat.var("var"))),
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))), UPat.var("gate", dtype=dtypes.bool), UPat.var("alt"))),
lambda ctx, x, loc, alt, gate, buf: flatten([
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
]) if alt.dtype.count > 1 else [
f"@{ctx.r[gate]} ld.{mem_type(buf)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))),), allow_any_len=True),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))),)),
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
# simple

View file

@ -10,21 +10,27 @@ def sign_extend(val:UOp, sext_am:int):
| val.bitcast(dtypes.uint32)).bitcast(dtypes.int)
# store for char: buf[idx/4] <- (var << (idx%4)*8))
def packed_store(bidx:UOp, var:UOp):
def packed_store(bidx:UOp, var:UOp, *extra:UOp):
elems, mask = 4//var.dtype.itemsize, _mask(var.dtype)
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*var.dtype.itemsize), bidx.src[1] // elems
new_v, wmask = (var & mask).cast(dtypes.uint32) << shift_am, ((mask << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
# preserve valid condition (bidx.src[2]) if it exists for gated stores
idx_src = (bidx.src[0], div_idx) if len(bidx.src) == 2 else (bidx.src[0], div_idx, bidx.src[2])
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, idx_src), dtype=dtypes.uint32)
return UOp.store(UOp(Ops.INDEX, bidx.dtype, idx_src), (buf & wmask) | new_v)
new_idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
buf = UOp.load(new_idx, dtype=dtypes.uint32)
# preserve trailing srcs (e.g. gate at src[2] for gated stores)
return UOp(Ops.STORE, dtypes.void, (new_idx, (buf & wmask) | new_v) + extra)
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
elems, mask = 4//dtype.itemsize, _mask(dtype)
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*dtype.itemsize), bidx.src[1] // elems
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2]) if var is not None else (bidx.src[0], div_idx))
load = UOp.load(idx, *([var] if var is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
new_idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
# rebuild LOAD srcs preserving gate at src[1] (if bool) and replacing alt with var if provided
other_srcs = list(root.src[1:])
if var is not None:
alt_pos = 1 if (len(other_srcs) >= 1 and other_srcs[0].dtype.scalar() == dtypes.bool) else 0
if alt_pos < len(other_srcs): other_srcs[alt_pos] = var
else: other_srcs.append(var)
load = UOp.load(new_idx, *other_srcs, dtype=dtypes.uint32, arg=root.arg)
val = (load.cast(dtypes.uint32) >> shift_am) & mask
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
@ -40,12 +46,12 @@ def is_nan(a):
wgsl_matcher = PatternMatcher([
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("c"), allow_any_len=True, name="l"),
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype, b.dtype) else None),
# TODO: load alt value doesnt have to be a const (alt is at src[2] in gated LOAD)
(UPat.load(UPat.var("b"), UPat.var("g", dtype=dtypes.bool), UPat.cvar("c"), name="l"),
lambda l,b,g,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype, b.dtype) else None),
(UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True),
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True, name="sto"),
lambda bidx,var,sto: packed_store(bidx,var,*sto.src[2:]) if is_packed(var.dtype, bidx.dtype) else None),
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
# fix nan check: 'a != a -> is_nan()'
@ -81,15 +87,15 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.BITCAST, dtype=dtypes.short, name="x"), lambda ctx,x: f"bitcast<i32>(vec2<f16>({ctx[x.src[0]]},0))" \
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("v"), allow_any_len=True),
lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
# TODO: load alt value doesnt have to be a const (gated load: src[1]=gate, src[2]=alt)
(UPat.load(UPat.var("b"), UPat.var("g", dtype=dtypes.bool), UPat.cvar("v")),
lambda ctx,b,g,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[g]})"),
(UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
(UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
]) + base_rewrite

View file

@ -18,8 +18,10 @@ def _load(m, i, dtype: DType):
return from_storage_scalar(m[i], dtype)
def load(inp, j, dtype: DType):
if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)]
return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]]
# inp is [index_values, gates, alts] (gated load with alt) or [index_values] (plain load)
if len(inp) == 3: return [_load(m, x+j if x is not None else None, dtype) if g else default
for (m,x),g,default in zip(inp[0], inp[1], inp[2])]
return [_load(m, x+j if x is not None else None, dtype) for m,x in inp[0]]
def _store(m, i, v, dtype: DType):
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
@ -67,8 +69,10 @@ class PythonProgram:
continue
assert dtype is not None, f"{uop} is missing a dtype"
if uop is Ops.STORE:
# gate is at src[2] for gated stores; default to all-True for plain stores
gates = src_values[2] if len(src_values) >= 3 else [True]*len(src_values[0])
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
for (m,o,g),v in zip(src_values[0], val):
for (m,o),g,v in zip(src_values[0], gates, val):
if g: _store(m, o+j, v, src_dtypes[1].scalar())
i += 1
continue
@ -98,7 +102,7 @@ class PythonProgram:
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
else:
for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last
values[i] = ret
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
values[i] = src_values[0]
elif uop is Ops.RANGE:

View file

@ -1527,17 +1527,22 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
lambda var,val: var.bind(val).cast(dtypes.weakint)),
# lower Invalid
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
# lower Invalid: lift gate from INDEX up to the parent LOAD/STORE
(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"),
UPat(Ops.CONST, arg=Invalid))).or_casted("bidx"),), allow_any_len=True, name="ld"),
lambda ld,buf,cond,idx,bidx: ld.replace(src=((nbidx:=buf.index(idx, ptr=True)) if bidx.op is Ops.INDEX
else bidx.replace(src=(buf.index(idx, ptr=True),)), cond) + ld.src[1:])),
(UPat(Ops.STORE, src=(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"),
UPat(Ops.CONST, arg=Invalid))).or_casted("bidx"), UPat.var("val")), name="st"),
lambda st,buf,cond,idx,bidx,val: st.replace(src=(buf.index(idx, ptr=True) if bidx.op is Ops.INDEX
else bidx.replace(src=(buf.index(idx, ptr=True),)), val, cond))),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
# vectorized indexes (ie. images) must be int
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"),
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:])))
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), name="idx"),
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)))))
])
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View file

@ -5,9 +5,9 @@ from tinygrad.uop.render import print_uops, pyrender
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
def validate_index(buf:UOp, idx:UOp):
# gate now lives on LOAD/STORE; INDEX is always 2-src (buf, idx)
if idx.op is Ops.CONST and idx.arg is Invalid: return True
if gate is None: gate = UOp.const(dtypes.bool, True)
# TODO: check for overflow
if not CHECK_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True
@ -17,12 +17,12 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
# TODO: validate these
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
for x in idx.toposort() | gate.toposort():
for x in idx.toposort():
if x.op in {Ops.BITCAST, Ops.STACK, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
# if all is good and CHECK_OOB=1, validate with z3
from tinygrad.uop.validate import validate_index_with_z3
return validate_index_with_z3(sz, idx, gate)
return validate_index_with_z3(sz, idx, UOp.const(dtypes.bool, True))
# four specs:
# shared_spec -- usable anywhere
@ -174,10 +174,12 @@ shared_codegen_spec = PatternMatcher([
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# LOAD(idx) / STORE(idx, val)
# LOAD(idx) / STORE(idx, val) / LOAD(idx, gate, alt?) gated / STORE(idx, val, gate) gated
(UPat().index(UPat()).or_casted().load(), lambda: True),
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(), lambda: True), # gated load (alt added in program_spec)
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat(dtype=dtypes.bool))), lambda: True), # gated load
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat(dtype=dtypes.bool), UPat())), lambda: True), # gated load with alt
(UPat(Ops.STORE, dtypes.void, src=(UPat(Ops.INDEX).or_casted(), UPat())), lambda: True),
(UPat(Ops.STORE, dtypes.void, src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(dtype=dtypes.bool))), lambda: True), # gated store
# CUSTOM (inline and non inline)
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
@ -185,9 +187,8 @@ shared_codegen_spec = PatternMatcher([
# assembly instruction
(UPat(Ops.INS), lambda: True),
# INDEX (2-arg and 3-arg with bool gate)
# INDEX (always 2-arg, no gate; gate now lives on LOAD/STORE)
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
# SPECIAL
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
@ -236,8 +237,8 @@ tensor_spec = PatternMatcher([
# ***** UOp spec in linearized programs *****
program_spec = PatternMatcher([
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
# LOAD (idx, gate, alt_value), LOAD can have an alt value, but only if there's a gate
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat(dtype=dtypes.bool), UPat())), lambda: True),
# END closes ranges
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),