mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
gate-on-lo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c0c120bf58 | ||
|
|
9596d13550 |
15 changed files with 185 additions and 128 deletions
|
|
@ -375,22 +375,24 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32,
|
|||
"""Conditional memory store with sub-word support. Returns list of store UOps."""
|
||||
adt = dtypes.uint64 if addr_bits == 64 else dtypes.uint32
|
||||
word_addr = addr >> UOp.const(adt, 2)
|
||||
idx = mem.index(word_addr.cast(dtypes.int), active)
|
||||
if data_bits == 32: return [idx.store(active.where(_to_u32(val), idx))]
|
||||
bidx = mem.index(word_addr.cast(dtypes.int), ptr=True)
|
||||
if data_bits == 32: return [UOp(Ops.STORE, dtypes.void, (bidx, _to_u32(val), active))]
|
||||
# Sub-word store: read-modify-write with mask
|
||||
cur = bidx.load(active, _c(0, dtypes.uint32))
|
||||
byte_pos = addr.cast(dtypes.uint32) & _c(3)
|
||||
byte_shift = byte_pos * _c(8)
|
||||
val_u32, size_mask = val.cast(dtypes.uint32), _c(0xFF if data_bits == 8 else 0xFFFF)
|
||||
mask = size_mask << byte_shift
|
||||
new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val_u32 & size_mask) << byte_shift)
|
||||
if data_bits == 8: return [idx.store(active.where(new_word, idx))]
|
||||
new_word = (cur & (mask ^ _c(0xFFFFFFFF))) | ((val_u32 & size_mask) << byte_shift)
|
||||
if data_bits == 8: return [UOp(Ops.STORE, dtypes.void, (bidx, new_word, active))]
|
||||
# 16-bit cross-word case: byte_pos == 3 means value spans two words
|
||||
is_cross = byte_pos.eq(_c(3))
|
||||
cross_word0 = (idx & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))
|
||||
store0 = idx.store(active.where(is_cross.where(cross_word0, new_word), idx))
|
||||
next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int), active & is_cross)
|
||||
cross_word1 = (next_idx & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))
|
||||
return [store0, next_idx.store((active & is_cross).where(cross_word1, next_idx))]
|
||||
cross_word0 = (cur & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))
|
||||
store0 = UOp(Ops.STORE, dtypes.void, (bidx, is_cross.where(cross_word0, new_word), active))
|
||||
next_bidx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int), ptr=True)
|
||||
next_cur = next_bidx.load(active & is_cross, _c(0, dtypes.uint32))
|
||||
cross_word1 = (next_cur & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))
|
||||
return [store0, UOp(Ops.STORE, dtypes.void, (next_bidx, cross_word1, active & is_cross))]
|
||||
|
||||
def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int = 32) -> list[UOp]:
|
||||
"""Store to byte-addressable memory (scratch). addr is byte offset, mem is uint8 buffer."""
|
||||
|
|
@ -398,7 +400,8 @@ def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int
|
|||
val_u32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
|
||||
for i in range(data_bits // 8):
|
||||
byte_val = (val_u32 >> UOp.const(dtypes.uint32, i * 8)) & UOp.const(dtypes.uint32, 0xFF)
|
||||
stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int), active).store(byte_val.cast(dtypes.uint8)))
|
||||
bidx = mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int), ptr=True)
|
||||
stores.append(UOp(Ops.STORE, dtypes.void, (bidx, byte_val.cast(dtypes.uint8), active)))
|
||||
return stores
|
||||
|
||||
def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode_vars: dict | None = None, op_name: str = "") -> dict[int, UOp]:
|
||||
|
|
@ -516,14 +519,15 @@ class _Ctx:
|
|||
# Dynamic register access (takes UOp index instead of int)
|
||||
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read SGPR with dynamic register index."""
|
||||
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load()
|
||||
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load(valid, _c(0, dtypes.uint32))
|
||||
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
|
||||
|
||||
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
|
||||
"""Write SGPR with dynamic register index. On RDNA, index 124 = NULL (writes discarded). On CDNA, index 124 = M0 (read/write)."""
|
||||
# RDNA: NULL (124) discards writes. CDNA: M0 (124) is writable.
|
||||
valid = None if self.wave_size == 64 else reg.ne(_c(124))
|
||||
return self.sgpr.index(reg.cast(dtypes.int), valid).store(val.cast(dtypes.uint32))
|
||||
bidx = self.sgpr.index(reg.cast(dtypes.int), ptr=True)
|
||||
return UOp(Ops.STORE, dtypes.void, (bidx, val.cast(dtypes.uint32))+((valid,) if valid is not None else ()))
|
||||
|
||||
def wmask(self, reg: UOp, val: UOp) -> list[UOp]:
|
||||
"""Write a lane mask (VCC/EXEC). Splits into lo/hi for wave64."""
|
||||
|
|
@ -540,24 +544,26 @@ class _Ctx:
|
|||
def rvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read VGPR with dynamic register index."""
|
||||
idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
|
||||
return self.vgpr.index(idx, valid, ptr=True).load() if valid is not None else self.vgpr.index(idx, ptr=True).load()
|
||||
if valid is not None: return self.vgpr.index(idx, ptr=True).load(valid, _c(0, dtypes.uint32))
|
||||
return self.vgpr.index(idx, ptr=True).load()
|
||||
|
||||
def wvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
|
||||
"""Write VGPR with dynamic register index."""
|
||||
buf = self.vgpr.after(after) if after is not None else self.vgpr
|
||||
offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
|
||||
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
|
||||
return UOp(Ops.STORE, dtypes.void, (buf.index(offset, ptr=True), val.cast(dtypes.uint32), _lane_active(exec_mask, lane)))
|
||||
|
||||
def raccvgpr_dyn(self, reg: UOp, lane: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read ACCVGPR with dynamic register index (CDNA only)."""
|
||||
idx = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
|
||||
return self.accvgpr.index(idx, valid, ptr=True).load() if valid is not None else self.accvgpr.index(idx, ptr=True).load()
|
||||
if valid is not None: return self.accvgpr.index(idx, ptr=True).load(valid, _c(0, dtypes.uint32))
|
||||
return self.accvgpr.index(idx, ptr=True).load()
|
||||
|
||||
def waccvgpr_dyn(self, reg: UOp, lane: UOp, val: UOp, exec_mask: UOp, after: UOp | None = None) -> UOp:
|
||||
"""Write ACCVGPR with dynamic register index (CDNA only)."""
|
||||
buf = self.accvgpr.after(after) if after is not None else self.accvgpr
|
||||
offset = reg.cast(dtypes.int) * _c(self.wave_size, dtypes.int) + lane.cast(dtypes.int)
|
||||
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
|
||||
return UOp(Ops.STORE, dtypes.void, (buf.index(offset, ptr=True), val.cast(dtypes.uint32), _lane_active(exec_mask, lane)))
|
||||
|
||||
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp:
|
||||
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
|
||||
|
|
@ -713,7 +719,7 @@ class _Ctx:
|
|||
old = self.vgpr.index(val[0].cast(dtypes.int), ptr=True).load()
|
||||
new_val = _set_bits(old, _val_to_bits(val[1]), width, lo_bit).cast(dtypes.uint32)
|
||||
active = _lane_active(exec_mask, lane)
|
||||
raw_stores.append(('vgpr_direct', self.vgpr.index(val[0].cast(dtypes.int), active).store(new_val)))
|
||||
raw_stores.append(('vgpr_direct', UOp(Ops.STORE, dtypes.void, (self.vgpr.index(val[0].cast(dtypes.int), ptr=True), new_val, active))))
|
||||
continue
|
||||
if 'D0' in dest and '[laneId]' in dest:
|
||||
old_vcc = self.rmask(_c(VCC_LO.offset))
|
||||
|
|
@ -1847,15 +1853,16 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
|
|||
if data_bits < 32:
|
||||
# Sub-dword LDS write: read-modify-write within the uint32 slot
|
||||
word_addr = (addr >> addr_shift).cast(dtypes.int)
|
||||
idx = mem.index(word_addr, active)
|
||||
bidx = mem.index(word_addr, ptr=True)
|
||||
cur = bidx.load(active, _c(0, dtypes.uint32))
|
||||
byte_pos = addr.cast(dtypes.uint32) & _c(3)
|
||||
byte_shift = byte_pos * _c(8)
|
||||
size_mask = _c(0xFF if data_bits == 8 else 0xFFFF)
|
||||
mask = size_mask << byte_shift
|
||||
new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & size_mask) << byte_shift)
|
||||
return idx.store(active.where(new_word, idx))
|
||||
idx = mem.index((addr >> addr_shift).cast(dtypes.int))
|
||||
return idx.store(active.where(val, idx.load()))
|
||||
new_word = (cur & (mask ^ _c(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & size_mask) << byte_shift)
|
||||
return UOp(Ops.STORE, dtypes.void, (bidx, new_word, active))
|
||||
bidx = mem.index((addr >> addr_shift).cast(dtypes.int), ptr=True)
|
||||
return UOp(Ops.STORE, dtypes.void, (bidx, val, active))
|
||||
|
||||
def make_srcs(lane: UOp) -> dict:
|
||||
addr = make_addr(lane)
|
||||
|
|
@ -2005,17 +2012,18 @@ def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp:
|
|||
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
|
||||
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), ptr=True).load(), _c(0))
|
||||
lds_idx = ((lds_addr + _c(i * 4)) >> _c(2)).cast(dtypes.int)
|
||||
stores.append(ctx.lds.index(lds_idx, active).store(active.where(val, ctx.lds.index(lds_idx, active))))
|
||||
bidx = ctx.lds.index(lds_idx, ptr=True)
|
||||
stores.append(UOp(Ops.STORE, dtypes.void, (bidx, val, active)))
|
||||
elif is_store:
|
||||
for i in range(n_dwords):
|
||||
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
|
||||
idx = mem.index(word_addr.cast(dtypes.int64), in_bounds)
|
||||
idx = mem.index(word_addr.cast(dtypes.int64), ptr=True)
|
||||
val = (ctx.raccvgpr_dyn if use_acc else ctx.rvgpr_dyn)(vdata + _c(i), lane)
|
||||
stores.append(idx.store(in_bounds.where(_to_u32(val), idx)))
|
||||
stores.append(UOp(Ops.STORE, dtypes.void, (idx, _to_u32(val), in_bounds)))
|
||||
else:
|
||||
for i in range(n_dwords):
|
||||
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
|
||||
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), in_bounds, ptr=True).load(), _c(0))
|
||||
val = mem.index(word_addr.cast(dtypes.int64), ptr=True).load(in_bounds, _c(0, dtypes.uint32))
|
||||
stores.append((ctx.waccvgpr_dyn if use_acc else ctx.wvgpr_dyn)(vdata + _c(i), lane, val, exec_mask))
|
||||
return UOp.sink(UOp.group(*stores).end(lane), *ctx.inc_pc())
|
||||
|
||||
|
|
|
|||
|
|
@ -828,28 +828,30 @@ class Parser:
|
|||
assert mem is not None, "memory load requires _vmem or _lds"
|
||||
adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
|
||||
active = self.vars.get('_active')
|
||||
gate = (active,) if active is not None else ()
|
||||
# gate now lives on LOAD; helper to construct gated load with 0 alt
|
||||
def _gload(bidx, dtype):
|
||||
return bidx.load(active, _const(dtype.base, 0)) if active is not None else bidx.load()
|
||||
byte_mem = mem.dtype.base == dtypes.uint8
|
||||
if byte_mem:
|
||||
idx = addr.cast(dtypes.int)
|
||||
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
val = _u32(0).cast(dtypes.uint64)
|
||||
for i in range(8): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint64) << _u64(i * 8))
|
||||
for i in range(8): val = val | (_gload(mem.index(idx + _const(dtypes.int, i), ptr=True), mem.dtype).cast(dtypes.uint64) << _u64(i * 8))
|
||||
elif dt in (dtypes.uint8, dtypes.int8):
|
||||
val = mem.index(idx, *gate, ptr=True).load().cast(dt)
|
||||
val = _gload(mem.index(idx, ptr=True), mem.dtype).cast(dt)
|
||||
elif dt in (dtypes.uint16, dtypes.int16, dtypes.short):
|
||||
lo = mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32)
|
||||
hi = mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32)
|
||||
lo = _gload(mem.index(idx, ptr=True), mem.dtype).cast(dtypes.uint32)
|
||||
hi = _gload(mem.index(idx + _const(dtypes.int, 1), ptr=True), mem.dtype).cast(dtypes.uint32)
|
||||
val = (lo | (hi << _u32(8))).cast(dt)
|
||||
else:
|
||||
val = _u32(0)
|
||||
for i in range(4): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
|
||||
for i in range(4): val = val | (_gload(mem.index(idx + _const(dtypes.int, i), ptr=True), mem.dtype).cast(dtypes.uint32) << _u32(i * 8))
|
||||
else:
|
||||
idx = (addr >> _const(addr.dtype, 2)).cast(dtypes.int)
|
||||
val = mem.index(idx, *gate)
|
||||
val = _gload(mem.index(idx, ptr=True), mem.dtype)
|
||||
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)
|
||||
val = val.cast(dtypes.uint64) | (mem.index(idx2, *gate).cast(dtypes.uint64) << _u64(32))
|
||||
val = val.cast(dtypes.uint64) | (_gload(mem.index(idx2, ptr=True), mem.dtype).cast(dtypes.uint64) << _u64(32))
|
||||
elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
|
||||
elif dt in (dtypes.uint16, dtypes.int16):
|
||||
val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
|
||||
|
|
@ -862,7 +864,7 @@ class Parser:
|
|||
idx_native = (addr >> _const(adt, 2)).cast(dtypes.int64)
|
||||
idx_hi_native = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int64)
|
||||
safe_idx_hi = is_unaligned.where(idx_hi_native, idx_native)
|
||||
hi = mem.index(safe_idx_hi, *gate)
|
||||
hi = _gload(mem.index(safe_idx_hi, ptr=True), mem.dtype)
|
||||
combined = val.cast(dtypes.uint64) | (hi.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
val = is_unaligned.where((combined >> (byte_off.cast(dtypes.uint64) * UOp.const(dtypes.uint64, 8))).cast(dtypes.uint32), val)
|
||||
return _cast_to(val, dt)
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class TestValidIdxSimplification(unittest.TestCase):
|
|||
def check(self, load, sidx, svalid, extra=()):
|
||||
with Context(NOOPT=1, SPEC=0):
|
||||
load = full_rewrite_to_sink(UOp.sink(load, *extra)).src[0]
|
||||
idx, valid = load.src[0].src[1], load.src[0].src[2]
|
||||
idx, valid = load.src[0].src[1], load.src[1]
|
||||
check_uop_against_string(self, idx, sidx)
|
||||
check_uop_against_string(self, valid, svalid)
|
||||
|
||||
|
|
@ -225,9 +225,11 @@ class TestImageSimplification(unittest.TestCase):
|
|||
check_uop_against_string(self, idx0, sidx0)
|
||||
check_uop_against_string(self, idx1, sidx1)
|
||||
if svalid is not None:
|
||||
check_uop_against_string(self, load.src[0].src[2], svalid)
|
||||
check_uop_against_string(self, load.src[1], svalid)
|
||||
else:
|
||||
self.assertEqual(len(load.src[0].src), 2, "svalid is None but load still has a valid")
|
||||
# gate is at LOAD.src[1] when present; if simplified away, src[1] should not be bool
|
||||
self.assertFalse(len(load.src) >= 2 and load.src[1].dtype.scalar() == dtypes.bool,
|
||||
"svalid is None but load still has a valid")
|
||||
|
||||
def test_idx_gt_c(self):
|
||||
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
|
||||
|
|
@ -512,18 +514,34 @@ class TestUnfoldableImage(unittest.TestCase):
|
|||
self.assertEqual(res.src[0].src[0].dtype, dtypes.float.ptr(400))
|
||||
|
||||
class TestDropTrueGate(unittest.TestCase):
|
||||
def test_drop_true_gate_on_index(self):
|
||||
# test that INDEX with a constant True gate gets simplified to drop the gate
|
||||
def test_drop_true_gate_on_load(self):
|
||||
# test that LOAD with a constant True gate gets simplified to drop the gate
|
||||
from tinygrad.codegen.late.devectorizer import load_store_indexing
|
||||
from tinygrad.uop.ops import graph_rewrite
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
|
||||
idx = UOp.const(dtypes.weakint, 0)
|
||||
true_gate = UOp.const(dtypes.bool, True)
|
||||
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate))
|
||||
bidx = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx))
|
||||
load = UOp(Ops.LOAD, dtypes.int, (bidx, true_gate))
|
||||
# apply the optimization
|
||||
result = graph_rewrite(index_with_gate, load_store_indexing)
|
||||
# the True gate should be dropped (INDEX should only have 2 sources)
|
||||
self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX")
|
||||
result = graph_rewrite(load, load_store_indexing)
|
||||
# the True gate should be dropped (LOAD should only have 1 source)
|
||||
self.assertEqual(len(result.src), 1, "True gate should be dropped from LOAD")
|
||||
|
||||
def test_drop_true_gate_on_store(self):
|
||||
# test that STORE with a constant True gate gets simplified to drop the gate
|
||||
from tinygrad.codegen.late.devectorizer import load_store_indexing
|
||||
from tinygrad.uop.ops import graph_rewrite
|
||||
buf = UOp(Ops.PARAM, dtypes.int.ptr(), arg=0)
|
||||
idx = UOp.const(dtypes.weakint, 0)
|
||||
val = UOp.const(dtypes.int, 42)
|
||||
true_gate = UOp.const(dtypes.bool, True)
|
||||
bidx = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx))
|
||||
store = UOp(Ops.STORE, dtypes.void, (bidx, val, true_gate))
|
||||
# apply the optimization
|
||||
result = graph_rewrite(store, load_store_indexing)
|
||||
# the True gate should be dropped (STORE should only have 2 sources)
|
||||
self.assertEqual(len(result.src), 2, "True gate should be dropped from STORE")
|
||||
|
||||
class TestRangeShrink(unittest.TestCase):
|
||||
def get_ranges(self, sink):
|
||||
|
|
|
|||
|
|
@ -428,7 +428,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([w, red])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
|
||||
# alt is at src[2] in new gated LOAD shape (idx, gate, alt)
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg==5
|
||||
|
||||
def test_where_on_gated_load_folds_swapped_branches(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
|
|
@ -438,7 +439,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([w])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD: assert u.src[1].arg==5
|
||||
if u.op is Ops.LOAD: assert u.src[2].arg==5
|
||||
|
||||
def test_where_on_gated_load_with_cast(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
|
|
@ -451,7 +452,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([w, red])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[2].arg == 5
|
||||
|
||||
def test_where_on_casted_gated_load_extra_cond(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
|
|
|
|||
|
|
@ -108,9 +108,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, b
|
|||
pm_linearize_cleanups = PatternMatcher([
|
||||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
|
||||
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
lambda u, gate: (u.replace(src=u.src[:2]), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u.replace(src=u.src[:2]), UOp(Ops.ENDIF, src=(mif,))]))
|
||||
])
|
||||
|
||||
# requires lst be toposorted. like graph rewrite, but for lines
|
||||
|
|
|
|||
|
|
@ -53,10 +53,15 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
|||
load_store_indexing = PatternMatcher([
|
||||
# image load valid idx simplification
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate)), lambda buf,x,i,cond: simplify_valid_load(buf, x, cond)),
|
||||
# simplify away long after index has been lowered
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x", dtypes.long), UPat.var("c", dtypes.bool))), lambda buf,x,c: simplify_valid_load(buf, x, c)),
|
||||
# drop true gate
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x, ptr=True)),
|
||||
# drop true gate from gated LOAD with alt: also drop the now-unused alt
|
||||
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.const(dtypes.bool, True), UPat()), allow_any_len=True, name="ld"),
|
||||
lambda ld,idx: ld.replace(src=(idx,)+ld.src[3:])),
|
||||
# drop true gate from gated LOAD without alt
|
||||
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.const(dtypes.bool, True)), allow_any_len=True, name="ld"),
|
||||
lambda ld,idx: ld.replace(src=(idx,)+ld.src[2:])),
|
||||
# drop true gate from STORE
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val"), UPat.const(dtypes.bool, True))),
|
||||
lambda idx,val: idx.store(val)),
|
||||
])
|
||||
|
||||
# ***** load/store grouping *****
|
||||
|
|
@ -281,18 +286,18 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||
# give any loads that are masked an alt value
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
|
||||
if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
|
||||
# give any gated loads (gate at src[1]) an alt value at src[2]
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat(dtype=dtypes.bool)), allow_any_len=True, name="x"),
|
||||
lambda x: x.replace(src=(x.src[0], x.src[1], x.const_like(0))+x.src[2:])
|
||||
if len(x.src) == 2 or x.src[2].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
|
||||
# Where after gated load becomes alt value
|
||||
# NOTE: if a is CAST and a.src[0].dtype == l.dtype, use a.src[0] to avoid roundtrip cast (e.g. uint->float->uint)
|
||||
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c")).or_casted(),), allow_any_len=True, name="l").or_casted(),
|
||||
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
|
||||
l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat().index(UPat(), UPat.var("c", dtype=dtypes.bool).logical_not()).or_casted(),),
|
||||
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
|
||||
else a.cast(l.dtype))+l.src[2:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat.var("c")), allow_any_len=True, name="l").or_casted(),
|
||||
UPat.var("a")), lambda c,l,a: l.replace(src=(l.src[0], l.src[1], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype))+
|
||||
l.src[3:]).cast(a.dtype)),
|
||||
(UPat.var("c").where(UPat.var("a"), UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat.var("c", dtype=dtypes.bool).logical_not()),
|
||||
allow_any_len=True, name="l").or_casted()), lambda c,l,a: l.replace(src=(l.src[0], l.src[1], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype
|
||||
else a.cast(l.dtype))+l.src[3:]).cast(a.dtype)),
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Callable, cast
|
|||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import prod, Target
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
|
||||
from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.dtype import AddrSpace, PtrDType, dtypes
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.device import Compiler
|
||||
|
||||
|
|
@ -31,8 +31,11 @@ class Estimates:
|
|||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
|
||||
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
|
||||
# TODO: is this correct? this all needs to be cleaned up
|
||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
|
||||
# gate (bool-typed src) is part of indexing/predication, exclude its computation
|
||||
# LOAD: gate at src[1]; STORE: gate at src[2]
|
||||
gate_pos = 1 if u.op is Ops.LOAD else 2
|
||||
if len(u.src) > gate_pos and u.src[gate_pos].dtype.scalar() == dtypes.bool:
|
||||
dont_count = dont_count.union(u.src[gate_pos].toposort(range_gate))
|
||||
elif u.op is Ops.IF:
|
||||
dont_count = dont_count.union(u.src[0].toposort())
|
||||
for u in uops:
|
||||
|
|
|
|||
|
|
@ -44,9 +44,9 @@ base_rewrite = PatternMatcher([
|
|||
# default const render
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# new load/store
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
|
||||
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var"))),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("gate", dtype=dtypes.bool), UPat.var("var"))),
|
||||
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
|
|
@ -302,11 +302,12 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
|
||||
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("gate", dtype=dtypes.bool),
|
||||
UPat.var("var"))),
|
||||
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
|
||||
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), allow_any_len=True),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),
|
||||
UPat.var("var", dtypes.float.vec(4)))),
|
||||
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
|
||||
]) + base_rewrite
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ base_rewrite = PatternMatcher([
|
|||
# memory load/store
|
||||
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("mask"))).or_casted("idx"), UPat.var("alt")), allow_any_len=True, name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted("idx"), UPat.var("mask", dtype=dtypes.bool), UPat.var("alt")), allow_any_len=True, name="x"),
|
||||
lambda ctx,x,idx,alt,mask:
|
||||
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
|
||||
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
|
||||
|
|
|
|||
|
|
@ -125,19 +125,20 @@ class NIRRenderer(Renderer):
|
|||
(UPat.cvar("x", dtypes.uints), lambda x: UOp.const(x.dtype, x.dtype.max+x.arg+1) if x.arg < 0 else None),
|
||||
# from ptx
|
||||
(UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
|
||||
# load/store bool -> uint8
|
||||
# load/store bool -> uint8 (alt at src[2] in new gated shape; preserve gate at src[1])
|
||||
(UPat(Ops.LOAD, dtypes.bool, name="x"),
|
||||
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x"),
|
||||
lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint8)))),
|
||||
lambda x: x.replace(dtype=dtypes.uint8, src=tuple(s.cast(dtypes.uint8) if i == 2 and s.dtype.scalar() == dtypes.bool else s
|
||||
for i,s in enumerate(x.src))).cast(dtypes.bool)),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
|
||||
# NIR requires shift amount to be 32 bit: https://docs.mesa3d.org/nir/alu.html#nir-alu-op-ishl
|
||||
(UPat((Ops.SHL, Ops.SHR), name="x"), lambda x: x.replace(src=(x.src[0], x.src[1].cast(dtypes.uint))) if x.src[1].dtype.bitsize != 32 else None),
|
||||
# OpConvertFToU is undefined if Result Type is not wide enough, cast through int32
|
||||
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
|
||||
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
|
||||
# load/store use pointer arithmetic, and the cast does nothing
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace(
|
||||
src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
|
||||
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
|
||||
])
|
||||
|
||||
|
|
@ -146,9 +147,10 @@ class NIRRenderer(Renderer):
|
|||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))), UPat.var("val"))),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))), UPat.var("gate", dtype=dtypes.bool), UPat.var("alt")),
|
||||
allow_any_len=True, name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), allow_any_len=True, name="x"),
|
||||
|
|
@ -268,9 +270,9 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
|||
return _nload_img(ctx.b, ctx.r[img], ctx.r[coord], img.dtype)
|
||||
|
||||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), allow_any_len=True), UPat.var("val"))),
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("val"))),
|
||||
lambda ctx,img,coord,val: nstore_img(ctx.b, ctx.r[img], ctx.r[coord], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("alt"))),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))), UPat.var("gate", dtype=dtypes.bool), UPat.var("alt"))),
|
||||
lambda ctx,img,coord,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, coord), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('coord', dtypes.int.vec(2))),)), nload_img),
|
||||
]) + NIRRenderer.def_rewrite
|
||||
|
|
|
|||
|
|
@ -45,14 +45,15 @@ ptx_matcher = PatternMatcher([
|
|||
# upcast to float32 all the ops that don't support half
|
||||
(UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
|
||||
# load/store bool -> uint8
|
||||
# load/store bool -> uint8 (alt at src[2] in new gated shape; preserve gate at src[1])
|
||||
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
|
||||
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x"),
|
||||
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8)))),
|
||||
lambda x: UOp(x.op, dtypes.uint8, tuple(s.cast(dtypes.uint8) if i == 2 and s.dtype.scalar() == dtypes.bool else s
|
||||
for i,s in enumerate(x.src))).cast(dtypes.bool)),
|
||||
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
|
||||
# indexing on PTX is in uint64, we do the math while it's still in the graph
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op", allow_any_len=True), lambda buf,idx,op:
|
||||
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
|
||||
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)) \
|
||||
if op.dtype != dtypes.int64 and buf.dtype.addrspace != AddrSpace.REG else None),
|
||||
# load/store use pointer arithmetic, and the cast does nothing
|
||||
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
|
||||
|
|
@ -102,18 +103,18 @@ string_rewrite = PatternMatcher([
|
|||
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
|
||||
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
|
||||
# store / gated load / load
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc")), allow_any_len=True), UPat.var("var"))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))), UPat.var("var"))),
|
||||
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
|
||||
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
|
||||
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))), UPat.var("gate", dtype=dtypes.bool), UPat.var("alt"))),
|
||||
lambda ctx, x, loc, alt, gate, buf: flatten([
|
||||
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
|
||||
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
|
||||
]) if alt.dtype.count > 1 else [
|
||||
f"@{ctx.r[gate]} ld.{mem_type(buf)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
|
||||
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))),), allow_any_len=True),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))),)),
|
||||
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
|
||||
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
|
||||
# simple
|
||||
|
|
|
|||
|
|
@ -10,21 +10,27 @@ def sign_extend(val:UOp, sext_am:int):
|
|||
| val.bitcast(dtypes.uint32)).bitcast(dtypes.int)
|
||||
|
||||
# store for char: buf[idx/4] <- (var << (idx%4)*8))
|
||||
def packed_store(bidx:UOp, var:UOp):
|
||||
def packed_store(bidx:UOp, var:UOp, *extra:UOp):
|
||||
elems, mask = 4//var.dtype.itemsize, _mask(var.dtype)
|
||||
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*var.dtype.itemsize), bidx.src[1] // elems
|
||||
new_v, wmask = (var & mask).cast(dtypes.uint32) << shift_am, ((mask << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
|
||||
# preserve valid condition (bidx.src[2]) if it exists for gated stores
|
||||
idx_src = (bidx.src[0], div_idx) if len(bidx.src) == 2 else (bidx.src[0], div_idx, bidx.src[2])
|
||||
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, idx_src), dtype=dtypes.uint32)
|
||||
return UOp.store(UOp(Ops.INDEX, bidx.dtype, idx_src), (buf & wmask) | new_v)
|
||||
new_idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
|
||||
buf = UOp.load(new_idx, dtype=dtypes.uint32)
|
||||
# preserve trailing srcs (e.g. gate at src[2] for gated stores)
|
||||
return UOp(Ops.STORE, dtypes.void, (new_idx, (buf & wmask) | new_v) + extra)
|
||||
|
||||
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
|
||||
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
|
||||
elems, mask = 4//dtype.itemsize, _mask(dtype)
|
||||
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*dtype.itemsize), bidx.src[1] // elems
|
||||
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2]) if var is not None else (bidx.src[0], div_idx))
|
||||
load = UOp.load(idx, *([var] if var is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
|
||||
new_idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
|
||||
# rebuild LOAD srcs preserving gate at src[1] (if bool) and replacing alt with var if provided
|
||||
other_srcs = list(root.src[1:])
|
||||
if var is not None:
|
||||
alt_pos = 1 if (len(other_srcs) >= 1 and other_srcs[0].dtype.scalar() == dtypes.bool) else 0
|
||||
if alt_pos < len(other_srcs): other_srcs[alt_pos] = var
|
||||
else: other_srcs.append(var)
|
||||
load = UOp.load(new_idx, *other_srcs, dtype=dtypes.uint32, arg=root.arg)
|
||||
val = (load.cast(dtypes.uint32) >> shift_am) & mask
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
|
|
@ -40,12 +46,12 @@ def is_nan(a):
|
|||
wgsl_matcher = PatternMatcher([
|
||||
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
|
||||
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("c"), allow_any_len=True, name="l"),
|
||||
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype, b.dtype) else None),
|
||||
# TODO: load alt value doesnt have to be a const (alt is at src[2] in gated LOAD)
|
||||
(UPat.load(UPat.var("b"), UPat.var("g", dtype=dtypes.bool), UPat.cvar("c"), name="l"),
|
||||
lambda l,b,g,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True),
|
||||
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True, name="sto"),
|
||||
lambda bidx,var,sto: packed_store(bidx,var,*sto.src[2:]) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
|
||||
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
|
||||
# fix nan check: 'a != a -> is_nan()'
|
||||
|
|
@ -81,15 +87,15 @@ class WGSLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.BITCAST, dtype=dtypes.short, name="x"), lambda ctx,x: f"bitcast<i32>(vec2<f16>({ctx[x.src[0]]},0))" \
|
||||
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("v"), allow_any_len=True),
|
||||
lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
|
||||
# TODO: load alt value doesnt have to be a const (gated load: src[1]=gate, src[2]=alt)
|
||||
(UPat.load(UPat.var("b"), UPat.var("g", dtype=dtypes.bool), UPat.cvar("v")),
|
||||
lambda ctx,b,g,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[g]})"),
|
||||
(UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
|
||||
(UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
||||
]) + base_rewrite
|
||||
|
||||
|
|
|
|||
|
|
@ -18,8 +18,10 @@ def _load(m, i, dtype: DType):
|
|||
return from_storage_scalar(m[i], dtype)
|
||||
|
||||
def load(inp, j, dtype: DType):
|
||||
if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)]
|
||||
return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]]
|
||||
# inp is [index_values, gates, alts] (gated load with alt) or [index_values] (plain load)
|
||||
if len(inp) == 3: return [_load(m, x+j if x is not None else None, dtype) if g else default
|
||||
for (m,x),g,default in zip(inp[0], inp[1], inp[2])]
|
||||
return [_load(m, x+j if x is not None else None, dtype) for m,x in inp[0]]
|
||||
|
||||
def _store(m, i, v, dtype: DType):
|
||||
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
|
||||
|
|
@ -67,8 +69,10 @@ class PythonProgram:
|
|||
continue
|
||||
assert dtype is not None, f"{uop} is missing a dtype"
|
||||
if uop is Ops.STORE:
|
||||
# gate is at src[2] for gated stores; default to all-True for plain stores
|
||||
gates = src_values[2] if len(src_values) >= 3 else [True]*len(src_values[0])
|
||||
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
|
||||
for (m,o,g),v in zip(src_values[0], val):
|
||||
for (m,o),g,v in zip(src_values[0], gates, val):
|
||||
if g: _store(m, o+j, v, src_dtypes[1].scalar())
|
||||
i += 1
|
||||
continue
|
||||
|
|
@ -98,7 +102,7 @@ class PythonProgram:
|
|||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
else:
|
||||
for m,o in zip(src_values[0], src_values[1]): ret.append((m,o))
|
||||
values[i] = [(m,o,g) for (m,o),g in zip(ret, src_values[2] if len(src_values) == 3 else [True]*len(ret))] # set the gate last
|
||||
values[i] = ret
|
||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||
values[i] = src_values[0]
|
||||
elif uop is Ops.RANGE:
|
||||
|
|
|
|||
|
|
@ -1527,17 +1527,22 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
|
||||
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
|
||||
lambda var,val: var.bind(val).cast(dtypes.weakint)),
|
||||
# lower Invalid
|
||||
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
|
||||
# lower Invalid: lift gate from INDEX up to the parent LOAD/STORE
|
||||
(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"),
|
||||
UPat(Ops.CONST, arg=Invalid))).or_casted("bidx"),), allow_any_len=True, name="ld"),
|
||||
lambda ld,buf,cond,idx,bidx: ld.replace(src=((nbidx:=buf.index(idx, ptr=True)) if bidx.op is Ops.INDEX
|
||||
else bidx.replace(src=(buf.index(idx, ptr=True),)), cond) + ld.src[1:])),
|
||||
(UPat(Ops.STORE, src=(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"),
|
||||
UPat(Ops.CONST, arg=Invalid))).or_casted("bidx"), UPat.var("val")), name="st"),
|
||||
lambda st,buf,cond,idx,bidx,val: st.replace(src=(buf.index(idx, ptr=True) if bidx.op is Ops.INDEX
|
||||
else bidx.replace(src=(buf.index(idx, ptr=True),)), val, cond))),
|
||||
# remove hanging casts
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
|
||||
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
||||
# vectorized indexes (ie. images) must be int
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"),
|
||||
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:])))
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), name="idx"),
|
||||
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)))))
|
||||
])
|
||||
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ from tinygrad.uop.render import print_uops, pyrender
|
|||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
|
||||
|
||||
def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
|
||||
def validate_index(buf:UOp, idx:UOp):
|
||||
# gate now lives on LOAD/STORE; INDEX is always 2-src (buf, idx)
|
||||
if idx.op is Ops.CONST and idx.arg is Invalid: return True
|
||||
if gate is None: gate = UOp.const(dtypes.bool, True)
|
||||
# TODO: check for overflow
|
||||
if not CHECK_OOB or isinstance(buf.dtype, ImageDType) or (sz := buf.ptrdtype.size) == -1: return True
|
||||
|
||||
|
|
@ -17,12 +17,12 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
|
|||
# TODO: validate these
|
||||
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
|
||||
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
|
||||
for x in idx.toposort() | gate.toposort():
|
||||
for x in idx.toposort():
|
||||
if x.op in {Ops.BITCAST, Ops.STACK, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
|
||||
|
||||
# if all is good and CHECK_OOB=1, validate with z3
|
||||
from tinygrad.uop.validate import validate_index_with_z3
|
||||
return validate_index_with_z3(sz, idx, gate)
|
||||
return validate_index_with_z3(sz, idx, UOp.const(dtypes.bool, True))
|
||||
|
||||
# four specs:
|
||||
# shared_spec -- usable anywhere
|
||||
|
|
@ -174,10 +174,12 @@ shared_codegen_spec = PatternMatcher([
|
|||
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
|
||||
# LOAD(idx) / STORE(idx, val)
|
||||
# LOAD(idx) / STORE(idx, val) / LOAD(idx, gate, alt?) gated / STORE(idx, val, gate) gated
|
||||
(UPat().index(UPat()).or_casted().load(), lambda: True),
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(), lambda: True), # gated load (alt added in program_spec)
|
||||
(UPat(Ops.INDEX).or_casted().store(UPat()), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat(dtype=dtypes.bool))), lambda: True), # gated load
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat(dtype=dtypes.bool), UPat())), lambda: True), # gated load with alt
|
||||
(UPat(Ops.STORE, dtypes.void, src=(UPat(Ops.INDEX).or_casted(), UPat())), lambda: True),
|
||||
(UPat(Ops.STORE, dtypes.void, src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(dtype=dtypes.bool))), lambda: True), # gated store
|
||||
|
||||
# CUSTOM (inline and non inline)
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
|
||||
|
|
@ -185,9 +187,8 @@ shared_codegen_spec = PatternMatcher([
|
|||
# assembly instruction
|
||||
(UPat(Ops.INS), lambda: True),
|
||||
|
||||
# INDEX (2-arg and 3-arg with bool gate)
|
||||
# INDEX (always 2-arg, no gate; gate now lives on LOAD/STORE)
|
||||
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||
|
||||
# SPECIAL
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.weakint, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
|
|
@ -236,8 +237,8 @@ tensor_spec = PatternMatcher([
|
|||
# ***** UOp spec in linearized programs *****
|
||||
|
||||
program_spec = PatternMatcher([
|
||||
# LOAD (idx, alt_value), LOAD can have an alt value, but only if the index has a gate
|
||||
(UPat().index(UPat(), UPat(dtype=dtypes.bool)).or_casted().load(UPat()), lambda: True),
|
||||
# LOAD (idx, gate, alt_value), LOAD can have an alt value, but only if there's a gate
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX).or_casted(), UPat(dtype=dtypes.bool), UPat())), lambda: True),
|
||||
|
||||
# END closes ranges
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue