can ai finish this

This commit is contained in:
George Hotz 2026-05-27 01:12:59 +00:00
commit 4faa2b5bf3
16 changed files with 74 additions and 52 deletions

View file

@ -129,15 +129,16 @@ class TestSQTTMapBase(unittest.TestCase):
def test_sqtt_cli(self):
for pkl_path in sorted((EXAMPLES_DIR/self.target).glob("*.pkl")):
out = run_cli("--profile-path", str(pkl_path), "--ls")
no_rewrites = ("--rewrites-path", "")
out = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "--ls")
sqtt_traces = [l["value"].strip() for l in out if "SQTT" in l["value"]]
for name in sqtt_traces:
lines = run_cli("--profile-path", str(pkl_path), "-s", ansistrip(name))
lines = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "-s", ansistrip(name))
self.assertIn("Clk", lines[0]["value"])
waves = [r["clk"] for r in lines[2:] if "WAVE" in r["unit"]]
self.assertEqual(waves, sorted(waves), f"wave timestamps not monotonic in {name}")
with Context(DEBUG=2):
kernels = run_cli("--profile-path", str(pkl_path), "-s", "AMD")
kernels = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "-s", "AMD")
self.assertEqual(len(kernels), len(self.examples[pkl_path.stem][1]))
class TestSQTTMapRDNA3(TestSQTTMapBase):
  # Concrete variant of the SQTT map tests: the base class reads `self.target`
  # (e.g. EXAMPLES_DIR/self.target) to locate the example profiles for gfx1100.
  target = "gfx1100"

View file

@ -74,6 +74,9 @@ class TestQuantizeFP8(unittest.TestCase):
@needs_second_gpu
def test_multi(self):
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8))
try:
for dev in devs: Device[dev]
except Exception as e: self.skipTest(f"8 devices not available: {e}")
x = Tensor.empty(2048*8, 1024, dtype=dtypes.bfloat16, device=devs).uop.multi(0)
x = Tensor(x, device=devs)
amax_state = Tensor.full((), 2.0, dtype=dtypes.float32, device=devs).contiguous()

View file

@ -1267,7 +1267,7 @@ class TestBufferView(unittest.TestCase):
a = Tensor.arange(10*8).reshape(10, 8).clone().shard(devices, axis=1).realize()
run_linear(*check_schedule(a.shrink(((2, 8), None)).shrink(((1, 4), None)).contiguous(), 0))
# negative tests: these should NOT become BUFFER_VIEW (non-contiguous per shard)
# negative tests: these should NOT become SLICE (non-contiguous per shard)
def test_expand_multi_not_buffer_view(self):
devices = ("NULL:1", "NULL:2")
a = Tensor.arange(4*2).reshape(4, 1, 2).clone().shard(devices, axis=2).realize()

View file

@ -70,13 +70,13 @@ class TestIdxUpcast(unittest.TestCase):
def _assert(self, dtype: DType, a: Tensor):
uops = self._schedule_render(a)
# Assert the dtype of the INDEX value, This will need be updated if UOp spec changes
# Assert the dtype of the buffer index value.
store = next(uop for uop in uops if uop.op is Ops.STORE)
assert store.op is Ops.STORE
idx = self._find_op(store, Ops.INDEX)
# PTX and NIR turn Ops.INDEX into pointer arithmetic earlier than cstyle, plus it's already cast to int64
idx = self._find_op(store, Ops.SLICE)
# PTX and NIR turn buffer indexing into pointer arithmetic earlier than cstyle, plus it's already cast to int64
if not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)):
assert idx.op is Ops.INDEX
assert idx.op is Ops.SLICE
idx_val = idx.src[1]
self.assertIs(idx_val.dtype, dtype)

View file

@ -538,7 +538,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True).replace(op=Ops.SLICE, arg=0), dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
@ -552,7 +552,8 @@ class TestUOpGraph(unittest.TestCase):
ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
self.assertIs(ld0.src[0].op, Ops.SLICE)
self.assertEqual(ld0.src[0].src[1], lidx+2)
def test_fold_gated_store(self):
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
@ -564,7 +565,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 6) # +1 for SINK
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).replace(op=Ops.SLICE, arg=0).store(val)) # -2 to skip SINK
@unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self):

View file

@ -8,7 +8,7 @@ from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
from tinygrad.renderer import Renderer, Estimates
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, ImageDType
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
@ -111,12 +111,16 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
return sink
# inject IF/ENDIF. only needed if device doesn't support gated stores
def _gate_store(u:UOp, gate:UOp):
  # Turn a gated STORE into IF / STORE / ENDIF.
  # Returns a pair: the store reduced to its first two sources, and the list
  # [IF, store, ENDIF] that wraps it.
  ungated = u.replace(src=u.src[:2])
  # the IF carries (gate, u.src[0]); the ENDIF points back at that IF
  if_op = UOp(Ops.IF, src=(gate, u.src[0]))
  end_op = UOp(Ops.ENDIF, src=(if_op,))
  return ungated, [if_op, ungated, end_op]
pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.SLICE).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))), _gate_store),
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
lambda u, gate: _gate_store(u, gate) if isinstance((u.src[0].src[0] if u.src[0].op is Ops.CAST else u.src[0]).src[0].dtype, ImageDType) else None),
])
# requires lst be toposorted. like graph rewrite, but for lines

View file

@ -283,10 +283,11 @@ pm_render = PatternMatcher([
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
# rewrite INDEX to SLICE
(UPat(Ops.INDEX, name="x"), lambda x: UOp(Ops.SLICE, dtype=x.dtype, src=x.src, arg=0 if x.dtype.count == 1 else x.dtype.count)),
# rewrite non-image INDEX to SLICE
(UPat(Ops.INDEX, name="x"), lambda x: None if isinstance(x.src[0].dtype, ImageDType) else \
UOp(Ops.SLICE, dtype=x.dtype, src=x.src, arg=0 if x.dtype.count == 1 else x.dtype.count)),
# rewrite CAST on SLICE to SLICE
(UPat(Ops.SLICE, name="bv").cast(name="x"), lambda bv,x: bv.replace(dtype=x.dtype, arg=x.dtype.count))
(UPat(Ops.SLICE, name="bv").cast(name="x"), lambda bv,x: bv.replace(dtype=x.dtype, arg=0 if x.dtype.count == 1 else x.dtype.count))
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

View file

@ -3,7 +3,7 @@ from typing import Callable, cast
from dataclasses import dataclass
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
from tinygrad.dtype import AddrSpace, PtrDType, DType, ImageDType, dtypes
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.device import Compiler
@ -29,8 +29,10 @@ class Estimates:
def range_gate(x): return x.op is not Ops.RANGE
for u in uops:
if u.op in {Ops.LOAD, Ops.STORE}:
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
idx = u.src[0].src[0] if u.src[0].op is Ops.CAST else u.src[0]
# if u.src[0] is a buffer index/view, include only index arithmetic since the buffer might be an AFTER
is_image_index = idx.op is Ops.INDEX and isinstance(idx.src[0].dtype, ImageDType)
dont_count = dont_count.union((UOp.sink(*idx.src[1:]) if idx.op is Ops.SLICE or is_image_index else idx).toposort(range_gate))
# TODO: is this correct? this all needs to be cleaned up
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
elif u.op is Ops.IF:

View file

@ -62,7 +62,7 @@ base_rewrite = PatternMatcher([
extra_pm = PatternMatcher([
# devectorize any bools
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.SLICE), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
# CAST (from bool) can't be vectorized
(UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
# WHERE can't be vectorized
@ -192,14 +192,15 @@ class CStyleLanguage(Renderer):
else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}"
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI, Ops.SLICE} or \
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.CUSTOMI, Ops.SLICE} or \
(u.op is Ops.INDEX and isinstance(u.src[0].dtype, ImageDType)) or \
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
@ -305,7 +306,8 @@ class OpenCLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
# load/store image (OpenCL)
(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), lambda ctx,buf,idx_y,idx_x: f"IMAGE<{ctx[buf]}, {ctx[idx_y]}, {ctx[idx_x]}>"),
(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),
lambda ctx,buf,idx_y,idx_x: f"IMAGE<{ctx[buf]}, {ctx[idx_y]}, {ctx[idx_x]}>"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx_y,idx_x,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]})):{ctx[var]})"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),)),

View file

@ -173,15 +173,17 @@ extra_matcher = PatternMatcher([
# ***** X86 pre instruction selection *****
def bview(base:UOp, idx:UOp) -> UOp:
  # Build a SLICE UOp over (base, idx) that keeps base's dtype; arg is always 0 here.
  return UOp(Ops.SLICE, dtype=base.dtype, src=(base, idx), arg=0)
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count, AddrSpace.LOCAL), arg=next(ctx))
local_idx = local.index(UOp.const(dtypes.int32, 0), ptr=True)
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
local_idx = bview(local, UOp.const(dtypes.int32, 0))
ptr = gate.where(bview(base, idx), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
return ptr.cast(cast.dtype).load(dtype=x.dtype)
def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count, AddrSpace.LOCAL), arg=-1)
ptr = gate.where(base.index(idx, ptr=True), local.index(UOp.const(dtypes.int32, 0), ptr=True))
ptr = gate.where(bview(base, idx), bview(local, UOp.const(dtypes.int32, 0)))
return ptr.cast(cast.dtype).store(val)
# these must be done in a separate matcher because they violate the spec
@ -203,8 +205,10 @@ pre_isel_matcher = PatternMatcher([
(UPat(Ops.STACK, src=(UPat.var("y"),), allow_any_len=True, name="x"),
lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None),
# gated load/store become a conditional move on the index, the load/store are unconditional
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").store(UPat.var("val"), UPat.var("gate")), gated_store),
(UPat(Ops.SLICE, src=(UPat.var("base"), UPat.var("idx"))).or_casted(name="cast").load(
UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
(UPat(Ops.SLICE, src=(UPat.var("base"), UPat.var("idx"))).or_casted(name="cast").store(
UPat.var("val"), UPat.var("gate")), gated_store),
# TODO: remove this once we allow all flag producing ops in cmove
# if gate in scalar int cmove is not a comparison need to add one to set the flag
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
@ -322,7 +326,7 @@ def idiv(ctx:IselContext, x:UOp) -> UOp:
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp]:
def _disp(v:int) -> UOp: return imm(dtypes.int32 if abs(v) > dtypes.int8.max else dtypes.int8, v)
def _cast(v:UOp) -> UOp: return v.cast(dtypes.int64) if v.vmin < 0 else v
if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _disp(0))
if x.op is not Ops.SLICE: return (x, UOp(Ops.NOOP), _disp(0))
base, idx = x.src
disp_scale = base.dtype.itemsize if isinstance(base.dtype, PtrDType) else 1
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * disp_scale))
@ -550,7 +554,7 @@ isel_matcher = PatternMatcher([
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
# index
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
(UPat(Ops.SLICE, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
# copy, load, store
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
@ -851,13 +855,13 @@ class X86Renderer(ISARenderer):
def spill(self, disp:UOp, x:UOp) -> UOp:
nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype)
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(nx))
ret = isel_matcher.rewrite(bview(self.stack_pointer(), disp).store(nx))
assert ret is not None
return ret.replace(src=(s if s is not nx else x for s in ret.src))
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
ndt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=ndt, tag=reg))
ret = isel_matcher.rewrite(bview(self.stack_pointer(), disp).load(dtype=ndt, tag=reg))
assert ret is not None
return ret.replace(dtype=x.dtype)

View file

@ -74,7 +74,7 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
base_rewrite = PatternMatcher([
# memory load/store
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
(UPat(Ops.SLICE, name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
lambda ctx,x,idx,alt,mask:

View file

@ -136,11 +136,11 @@ class NIRRenderer(Renderer):
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
# load/store use pointer arithmetic, and the cast does nothing. NOTE: this doesn't apply to image indexing cause it's 1-D
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
# images need index to be int for nir
(UPat.var("buf").index(UPat.var("idx_y"), UPat.var("idx_x")),
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int))),
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int)) if isinstance(buf.dtype, ImageDType) else None),
])
def_rewrite = PatternMatcher([
@ -148,12 +148,12 @@ class NIRRenderer(Renderer):
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
(UPat(Ops.STORE, src=(UPat(Ops.SLICE, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
(UPat(Ops.LOAD, src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
(UPat(Ops.LOAD, src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
@ -189,7 +189,7 @@ class NIRRenderer(Renderer):
self.param_idx, ranges = 0, []
for u in uops:
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
if u.op in {Ops.NOOP, Ops.GROUP, Ops.SLICE} or (u.op is Ops.INDEX and isinstance(u.src[0].dtype, ImageDType)): pass
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]

View file

@ -51,8 +51,8 @@ ptx_matcher = PatternMatcher([
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
# indexing on PTX is in uint64, we do the math while it's still in the graph
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
UOp(op.op, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
if op.dtype != dtypes.int64 and buf.dtype.addrspace != AddrSpace.REG else None),
# ptx shr and shl instructions require y to be uint
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
@ -100,18 +100,19 @@ string_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
# store / gated load / load
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("var"))),
(UPat(Ops.STORE, src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("var"))),
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("alt"), UPat.var("gate"))),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(),
UPat.var("alt"), UPat.var("gate"))),
lambda ctx, x, loc, alt, gate, buf: flatten([
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
]) if alt.dtype.count > 1 else [
f"@{ctx.r[gate]} ld.{mem_type(buf)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(),)),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(),)),
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
# simple
@ -204,8 +205,9 @@ class PTXRenderer(Renderer):
if u.op is Ops.DEFINE_REG:
r[u] = [ssa("reg", u, self.types[u.dtype.base.scalar()]) for _ in range(u.ptrdtype.size)]
continue
if u.op in {Ops.INDEX, Ops.LOAD, Ops.STORE} and isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.addrspace == AddrSpace.REG:
if u.op is Ops.INDEX:
if u.op in {Ops.SLICE, Ops.LOAD, Ops.STORE} and isinstance(u.src[0].dtype, PtrDType) and \
u.src[0].dtype.addrspace == AddrSpace.REG:
if u.op is Ops.SLICE:
assert u.src[1].op == Ops.CONST, f"index on REG in ptx only supported on CONST, not {u.src[1].op}"
r[u] = r[u.src[0]][u.src[1].arg]
else:
@ -214,7 +216,7 @@ class PTXRenderer(Renderer):
typ = "pred" if u.src[1].dtype == dtypes.bool else ("b"+self.types[u.src[1].dtype][1:])
kernel.append(f"mov.{typ} {self.r[u.src[0]]}, {self.r[u.src[1]]};")
continue
if u.op is Ops.INDEX: continue # other index we can skip
if u.op is Ops.SLICE: continue
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
elif u.op is Ops.LOAD:

View file

@ -14,7 +14,7 @@ def packed_store(bidx:UOp, var:UOp, gate:UOp|None=None):
elems, mask = 4//var.dtype.itemsize, _mask(var.dtype)
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*var.dtype.itemsize), bidx.src[1] // elems
new_v, wmask = (var & mask).cast(dtypes.uint32) << shift_am, ((mask << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
idx = UOp(Ops.SLICE, bidx.dtype, (bidx.src[0], div_idx), bidx.arg)
buf = UOp.load(idx, *((UOp.const(dtypes.uint32, 0), gate) if gate is not None else ()), dtype=dtypes.uint32)
return UOp.store(idx, (buf & wmask) | new_v, *((gate,) if gate is not None else ()))
@ -22,7 +22,7 @@ def packed_store(bidx:UOp, var:UOp, gate:UOp|None=None):
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|None=None):
elems, mask = 4//dtype.itemsize, _mask(dtype)
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*dtype.itemsize), bidx.src[1] // elems
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
idx = UOp(Ops.SLICE, bidx.dtype, (bidx.src[0], div_idx), bidx.arg)
load = UOp.load(idx, *((var, gate) if var is not None and gate is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
val = (load.cast(dtypes.uint32) >> shift_am) & mask
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
@ -90,7 +90,7 @@ class WGSLRenderer(CStyleLanguage):
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
(UPat(Ops.SLICE, src=(UPat.var("b"), UPat.var("idx"))),
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
]) + base_rewrite

View file

@ -1076,7 +1076,8 @@ class ProgramInfo:
if u.op is Ops.DEFINE_VAR: _vars.append(u)
if u.op is Ops.PARAM: _globals.append(u.arg)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (idx:=u.src[0]).op in {Ops.INDEX, Ops.SLICE} or \
(u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op in {Ops.INDEX, Ops.SLICE}):
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg)
if u.op is Ops.SPECIAL:
if u.arg[0] == 'i': local_size = None

View file

@ -202,6 +202,7 @@ spec_program = PatternMatcher([
# slice in program
(UPat(Ops.SLICE), lambda: True),
(UPat(Ops.INDEX, name="idx"), lambda idx: isinstance(idx.src[0].dtype, ImageDType)),
# movement ops are not allowed in programs
(UPat(GroupOp.Movement), lambda: False),
@ -219,7 +220,7 @@ spec_program = PatternMatcher([
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.INDEX, Ops.SLICE)).or_casted())), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
])+spec_shared