mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
can ai finish this
This commit is contained in:
parent
ee8ea27637
commit
4faa2b5bf3
16 changed files with 74 additions and 52 deletions
|
|
@ -129,15 +129,16 @@ class TestSQTTMapBase(unittest.TestCase):
|
|||
|
||||
def test_sqtt_cli(self):
|
||||
for pkl_path in sorted((EXAMPLES_DIR/self.target).glob("*.pkl")):
|
||||
out = run_cli("--profile-path", str(pkl_path), "--ls")
|
||||
no_rewrites = ("--rewrites-path", "")
|
||||
out = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "--ls")
|
||||
sqtt_traces = [l["value"].strip() for l in out if "SQTT" in l["value"]]
|
||||
for name in sqtt_traces:
|
||||
lines = run_cli("--profile-path", str(pkl_path), "-s", ansistrip(name))
|
||||
lines = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "-s", ansistrip(name))
|
||||
self.assertIn("Clk", lines[0]["value"])
|
||||
waves = [r["clk"] for r in lines[2:] if "WAVE" in r["unit"]]
|
||||
self.assertEqual(waves, sorted(waves), f"wave timestamps not monotonic in {name}")
|
||||
with Context(DEBUG=2):
|
||||
kernels = run_cli("--profile-path", str(pkl_path), "-s", "AMD")
|
||||
kernels = run_cli(*no_rewrites, "--profile-path", str(pkl_path), "-s", "AMD")
|
||||
self.assertEqual(len(kernels), len(self.examples[pkl_path.stem][1]))
|
||||
|
||||
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
|
||||
|
|
|
|||
|
|
@ -74,6 +74,9 @@ class TestQuantizeFP8(unittest.TestCase):
|
|||
@needs_second_gpu
|
||||
def test_multi(self):
|
||||
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8))
|
||||
try:
|
||||
for dev in devs: Device[dev]
|
||||
except Exception as e: self.skipTest(f"8 devices not available: {e}")
|
||||
x = Tensor.empty(2048*8, 1024, dtype=dtypes.bfloat16, device=devs).uop.multi(0)
|
||||
x = Tensor(x, device=devs)
|
||||
amax_state = Tensor.full((), 2.0, dtype=dtypes.float32, device=devs).contiguous()
|
||||
|
|
|
|||
|
|
@ -1267,7 +1267,7 @@ class TestBufferView(unittest.TestCase):
|
|||
a = Tensor.arange(10*8).reshape(10, 8).clone().shard(devices, axis=1).realize()
|
||||
run_linear(*check_schedule(a.shrink(((2, 8), None)).shrink(((1, 4), None)).contiguous(), 0))
|
||||
|
||||
# negative tests: these should NOT become BUFFER_VIEW (non-contiguous per shard)
|
||||
# negative tests: these should NOT become SLICE (non-contiguous per shard)
|
||||
def test_expand_multi_not_buffer_view(self):
|
||||
devices = ("NULL:1", "NULL:2")
|
||||
a = Tensor.arange(4*2).reshape(4, 1, 2).clone().shard(devices, axis=2).realize()
|
||||
|
|
|
|||
|
|
@ -70,13 +70,13 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
|
||||
def _assert(self, dtype: DType, a: Tensor):
|
||||
uops = self._schedule_render(a)
|
||||
# Assert the dtype of the INDEX value, This will need be updated if UOp spec changes
|
||||
# Assert the dtype of the buffer index value.
|
||||
store = next(uop for uop in uops if uop.op is Ops.STORE)
|
||||
assert store.op is Ops.STORE
|
||||
idx = self._find_op(store, Ops.INDEX)
|
||||
# PTX and NIR turn Ops.INDEX into pointer arithmetic earlier than cstyle, plus it's already cast to int64
|
||||
idx = self._find_op(store, Ops.SLICE)
|
||||
# PTX and NIR turn buffer indexing into pointer arithmetic earlier than cstyle, plus it's already cast to int64
|
||||
if not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)):
|
||||
assert idx.op is Ops.INDEX
|
||||
assert idx.op is Ops.SLICE
|
||||
idx_val = idx.src[1]
|
||||
self.assertIs(idx_val.dtype, dtype)
|
||||
|
||||
|
|
|
|||
|
|
@ -538,7 +538,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True).replace(op=Ops.SLICE, arg=0), dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
|
|
@ -552,7 +552,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
|
||||
self.assertIs(ld0.src[0].op, Ops.SLICE)
|
||||
self.assertEqual(ld0.src[0].src[1], lidx+2)
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp(Ops.PARAM, dtypes.int.ptr(), (), 0)
|
||||
|
|
@ -564,7 +565,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops), 6) # +1 for SINK
|
||||
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK
|
||||
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).replace(op=Ops.SLICE, arg=0).store(val)) # -2 to skip SINK
|
||||
|
||||
@unittest.skip("this is a uop type error")
|
||||
def test_asserts_bad_gate(self):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from tinygrad.uop.render import pyrender
|
|||
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
|
|
@ -111,12 +111,16 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
return sink
|
||||
|
||||
# inject IF/ENDIF. only needed if device doesn't support gated stores
|
||||
def _gate_store(u:UOp, gate:UOp):
|
||||
return (st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]
|
||||
|
||||
pm_linearize_cleanups = PatternMatcher([
|
||||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.SLICE).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))), _gate_store),
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
lambda u, gate: _gate_store(u, gate) if isinstance((u.src[0].src[0] if u.src[0].op is Ops.CAST else u.src[0]).src[0].dtype, ImageDType) else None),
|
||||
])
|
||||
|
||||
# requires lst be toposorted. like graph rewrite, but for lines
|
||||
|
|
|
|||
|
|
@ -283,10 +283,11 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||
# rewrite INDEX to SLICE
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: UOp(Ops.SLICE, dtype=x.dtype, src=x.src, arg=0 if x.dtype.count == 1 else x.dtype.count)),
|
||||
# rewrite non-image INDEX to SLICE
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: None if isinstance(x.src[0].dtype, ImageDType) else \
|
||||
UOp(Ops.SLICE, dtype=x.dtype, src=x.src, arg=0 if x.dtype.count == 1 else x.dtype.count)),
|
||||
# rewrite CAST on SLICE to SLICE
|
||||
(UPat(Ops.SLICE, name="bv").cast(name="x"), lambda bv,x: bv.replace(dtype=x.dtype, arg=x.dtype.count))
|
||||
(UPat(Ops.SLICE, name="bv").cast(name="x"), lambda bv,x: bv.replace(dtype=x.dtype, arg=0 if x.dtype.count == 1 else x.dtype.count))
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Callable, cast
|
|||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
|
||||
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
|
||||
from tinygrad.dtype import AddrSpace, PtrDType, DType, ImageDType, dtypes
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.device import Compiler
|
||||
|
||||
|
|
@ -29,8 +29,10 @@ class Estimates:
|
|||
def range_gate(x): return x.op is not Ops.RANGE
|
||||
for u in uops:
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
|
||||
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
|
||||
idx = u.src[0].src[0] if u.src[0].op is Ops.CAST else u.src[0]
|
||||
# if u.src[0] is a buffer index/view, include only index arithmetic since the buffer might be an AFTER
|
||||
is_image_index = idx.op is Ops.INDEX and isinstance(idx.src[0].dtype, ImageDType)
|
||||
dont_count = dont_count.union((UOp.sink(*idx.src[1:]) if idx.op is Ops.SLICE or is_image_index else idx).toposort(range_gate))
|
||||
# TODO: is this correct? this all needs to be cleaned up
|
||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
|
||||
elif u.op is Ops.IF:
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ base_rewrite = PatternMatcher([
|
|||
|
||||
extra_pm = PatternMatcher([
|
||||
# devectorize any bools
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.SLICE), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
|
||||
# CAST (from bool) can't be vectorized
|
||||
(UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
|
||||
# WHERE can't be vectorized
|
||||
|
|
@ -192,14 +192,15 @@ class CStyleLanguage(Renderer):
|
|||
else:
|
||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
|
||||
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
r[u] = f"{prefix}{c[prefix]}"
|
||||
|
||||
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
|
||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||
|
||||
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI, Ops.SLICE} or \
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.CUSTOMI, Ops.SLICE} or \
|
||||
(u.op is Ops.INDEX and isinstance(u.src[0].dtype, ImageDType)) or \
|
||||
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
|
||||
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
|
||||
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
|
|
@ -305,7 +306,8 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
|
||||
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), lambda ctx,buf,idx_y,idx_x: f"IMAGE<{ctx[buf]}, {ctx[idx_y]}, {ctx[idx_x]}>"),
|
||||
(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),
|
||||
lambda ctx,buf,idx_y,idx_x: f"IMAGE<{ctx[buf]}, {ctx[idx_y]}, {ctx[idx_x]}>"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda ctx,buf,idx_y,idx_x,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]})):{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),)),
|
||||
|
|
|
|||
|
|
@ -173,15 +173,17 @@ extra_matcher = PatternMatcher([
|
|||
|
||||
# ***** X86 pre instruction selection *****
|
||||
|
||||
def bview(base:UOp, idx:UOp) -> UOp: return UOp(Ops.SLICE, base.dtype, (base, idx), 0)
|
||||
|
||||
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count, AddrSpace.LOCAL), arg=next(ctx))
|
||||
local_idx = local.index(UOp.const(dtypes.int32, 0), ptr=True)
|
||||
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
|
||||
local_idx = bview(local, UOp.const(dtypes.int32, 0))
|
||||
ptr = gate.where(bview(base, idx), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
|
||||
return ptr.cast(cast.dtype).load(dtype=x.dtype)
|
||||
|
||||
def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count, AddrSpace.LOCAL), arg=-1)
|
||||
ptr = gate.where(base.index(idx, ptr=True), local.index(UOp.const(dtypes.int32, 0), ptr=True))
|
||||
ptr = gate.where(bview(base, idx), bview(local, UOp.const(dtypes.int32, 0)))
|
||||
return ptr.cast(cast.dtype).store(val)
|
||||
|
||||
# these must be done in a separate matcher because they violate the spec
|
||||
|
|
@ -203,8 +205,10 @@ pre_isel_matcher = PatternMatcher([
|
|||
(UPat(Ops.STACK, src=(UPat.var("y"),), allow_any_len=True, name="x"),
|
||||
lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None),
|
||||
# gated load/store become a conditional move on the index, the load/store are unconditional
|
||||
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").load(UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
|
||||
(UPat.var("base").index(UPat.var("idx")).or_casted(name="cast").store(UPat.var("val"), UPat.var("gate")), gated_store),
|
||||
(UPat(Ops.SLICE, src=(UPat.var("base"), UPat.var("idx"))).or_casted(name="cast").load(
|
||||
UPat.var("alt"), UPat.var("gate"), name="x"), gated_load),
|
||||
(UPat(Ops.SLICE, src=(UPat.var("base"), UPat.var("idx"))).or_casted(name="cast").store(
|
||||
UPat.var("val"), UPat.var("gate")), gated_store),
|
||||
# TODO: remove this once we allow all flag producing ops in cmove
|
||||
# if gate in scalar int cmove is not a comparison need to add one to set the flag
|
||||
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
|
||||
|
|
@ -322,7 +326,7 @@ def idiv(ctx:IselContext, x:UOp) -> UOp:
|
|||
def fold_address(x:UOp) -> tuple[UOp, UOp, UOp]:
|
||||
def _disp(v:int) -> UOp: return imm(dtypes.int32 if abs(v) > dtypes.int8.max else dtypes.int8, v)
|
||||
def _cast(v:UOp) -> UOp: return v.cast(dtypes.int64) if v.vmin < 0 else v
|
||||
if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _disp(0))
|
||||
if x.op is not Ops.SLICE: return (x, UOp(Ops.NOOP), _disp(0))
|
||||
base, idx = x.src
|
||||
disp_scale = base.dtype.itemsize if isinstance(base.dtype, PtrDType) else 1
|
||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * disp_scale))
|
||||
|
|
@ -550,7 +554,7 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.ins(X86Ops.VMOVDm)),
|
||||
(UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.ins(X86Ops.VMOVQm)),
|
||||
# index
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
|
||||
(UPat(Ops.SLICE, name="x"), lambda x: x.ins(X86Ops.LEA, src=fold_address(x))),
|
||||
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
|
||||
# copy, load, store
|
||||
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
|
||||
|
|
@ -851,13 +855,13 @@ class X86Renderer(ISARenderer):
|
|||
|
||||
def spill(self, disp:UOp, x:UOp) -> UOp:
|
||||
nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype)
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).store(nx))
|
||||
ret = isel_matcher.rewrite(bview(self.stack_pointer(), disp).store(nx))
|
||||
assert ret is not None
|
||||
return ret.replace(src=(s if s is not nx else x for s in ret.src))
|
||||
|
||||
def fill(self, disp:UOp, x:UOp, reg:Register) -> UOp:
|
||||
ndt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype
|
||||
ret = isel_matcher.rewrite(self.stack_pointer().index(disp).load(dtype=ndt, tag=reg))
|
||||
ret = isel_matcher.rewrite(bview(self.stack_pointer(), disp).load(dtype=ndt, tag=reg))
|
||||
assert ret is not None
|
||||
return ret.replace(dtype=x.dtype)
|
||||
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
|
|||
|
||||
base_rewrite = PatternMatcher([
|
||||
# memory load/store
|
||||
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
|
||||
(UPat(Ops.SLICE, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
|
||||
lambda ctx,x,idx,alt,mask:
|
||||
|
|
|
|||
|
|
@ -136,11 +136,11 @@ class NIRRenderer(Renderer):
|
|||
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
|
||||
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
|
||||
# load/store use pointer arithmetic, and the cast does nothing. NOTE: this doesn't apply to image indexing cause it's 1-D
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
|
||||
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
|
||||
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
# images need index to be int for nir
|
||||
(UPat.var("buf").index(UPat.var("idx_y"), UPat.var("idx_x")),
|
||||
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int))),
|
||||
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int)) if isinstance(buf.dtype, ImageDType) else None),
|
||||
])
|
||||
|
||||
def_rewrite = PatternMatcher([
|
||||
|
|
@ -148,12 +148,12 @@ class NIRRenderer(Renderer):
|
|||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.SLICE, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
|
||||
|
|
@ -189,7 +189,7 @@ class NIRRenderer(Renderer):
|
|||
self.param_idx, ranges = 0, []
|
||||
|
||||
for u in uops:
|
||||
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
|
||||
if u.op in {Ops.NOOP, Ops.GROUP, Ops.SLICE} or (u.op is Ops.INDEX and isinstance(u.src[0].dtype, ImageDType)): pass
|
||||
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
|
||||
elif u.op is Ops.AFTER:
|
||||
self.r[u] = self.r[u.src[0]]
|
||||
|
|
|
|||
|
|
@ -51,8 +51,8 @@ ptx_matcher = PatternMatcher([
|
|||
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.void, (x.src[0], x.src[1].cast(dtypes.uint8))+x.src[2:])),
|
||||
# indexing on PTX is in uint64, we do the math while it's still in the graph
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
|
||||
UOp(Ops.INDEX, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
|
||||
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("idx")), name="op"), lambda buf,idx,op:
|
||||
UOp(op.op, dtype=dtypes.int64, src=(buf, buf.cast(dtypes.int64)+idx.cast(dtypes.int64)*buf.dtype.itemsize)+op.src[2:]) \
|
||||
if op.dtype != dtypes.int64 and buf.dtype.addrspace != AddrSpace.REG else None),
|
||||
# ptx shr and shl instructions require y to be uint
|
||||
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
|
||||
|
|
@ -100,18 +100,19 @@ string_rewrite = PatternMatcher([
|
|||
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
|
||||
lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.cast_types[x.dtype]}.{ctx.cast_types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
|
||||
# store / gated load / load
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("var"))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("var"))),
|
||||
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
|
||||
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
|
||||
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("alt"), UPat.var("gate"))),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(),
|
||||
UPat.var("alt"), UPat.var("gate"))),
|
||||
lambda ctx, x, loc, alt, gate, buf: flatten([
|
||||
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
|
||||
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
|
||||
]) if alt.dtype.count > 1 else [
|
||||
f"@{ctx.r[gate]} ld.{mem_type(buf)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
|
||||
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(),)),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(),)),
|
||||
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
|
||||
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
|
||||
# simple
|
||||
|
|
@ -204,8 +205,9 @@ class PTXRenderer(Renderer):
|
|||
if u.op is Ops.DEFINE_REG:
|
||||
r[u] = [ssa("reg", u, self.types[u.dtype.base.scalar()]) for _ in range(u.ptrdtype.size)]
|
||||
continue
|
||||
if u.op in {Ops.INDEX, Ops.LOAD, Ops.STORE} and isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.addrspace == AddrSpace.REG:
|
||||
if u.op is Ops.INDEX:
|
||||
if u.op in {Ops.SLICE, Ops.LOAD, Ops.STORE} and isinstance(u.src[0].dtype, PtrDType) and \
|
||||
u.src[0].dtype.addrspace == AddrSpace.REG:
|
||||
if u.op is Ops.SLICE:
|
||||
assert u.src[1].op == Ops.CONST, f"index on REG in ptx only supported on CONST, not {u.src[1].op}"
|
||||
r[u] = r[u.src[0]][u.src[1].arg]
|
||||
else:
|
||||
|
|
@ -214,7 +216,7 @@ class PTXRenderer(Renderer):
|
|||
typ = "pred" if u.src[1].dtype == dtypes.bool else ("b"+self.types[u.src[1].dtype][1:])
|
||||
kernel.append(f"mov.{typ} {self.r[u.src[0]]}, {self.r[u.src[1]]};")
|
||||
continue
|
||||
if u.op is Ops.INDEX: continue # other index we can skip
|
||||
if u.op is Ops.SLICE: continue
|
||||
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg
|
||||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
|
||||
elif u.op is Ops.LOAD:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def packed_store(bidx:UOp, var:UOp, gate:UOp|None=None):
|
|||
elems, mask = 4//var.dtype.itemsize, _mask(var.dtype)
|
||||
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*var.dtype.itemsize), bidx.src[1] // elems
|
||||
new_v, wmask = (var & mask).cast(dtypes.uint32) << shift_am, ((mask << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
|
||||
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
|
||||
idx = UOp(Ops.SLICE, bidx.dtype, (bidx.src[0], div_idx), bidx.arg)
|
||||
buf = UOp.load(idx, *((UOp.const(dtypes.uint32, 0), gate) if gate is not None else ()), dtype=dtypes.uint32)
|
||||
return UOp.store(idx, (buf & wmask) | new_v, *((gate,) if gate is not None else ()))
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ def packed_store(bidx:UOp, var:UOp, gate:UOp|None=None):
|
|||
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|None=None):
|
||||
elems, mask = 4//dtype.itemsize, _mask(dtype)
|
||||
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*dtype.itemsize), bidx.src[1] // elems
|
||||
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx))
|
||||
idx = UOp(Ops.SLICE, bidx.dtype, (bidx.src[0], div_idx), bidx.arg)
|
||||
load = UOp.load(idx, *((var, gate) if var is not None and gate is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
|
||||
val = (load.cast(dtypes.uint32) >> shift_am) & mask
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
|
@ -90,7 +90,7 @@ class WGSLRenderer(CStyleLanguage):
|
|||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
(UPat(Ops.SLICE, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
||||
]) + base_rewrite
|
||||
|
||||
|
|
|
|||
|
|
@ -1076,7 +1076,8 @@ class ProgramInfo:
|
|||
if u.op is Ops.DEFINE_VAR: _vars.append(u)
|
||||
if u.op is Ops.PARAM: _globals.append(u.arg)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
||||
if (idx:=u.src[0]).op in {Ops.INDEX, Ops.SLICE} or \
|
||||
(u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op in {Ops.INDEX, Ops.SLICE}):
|
||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg)
|
||||
if u.op is Ops.SPECIAL:
|
||||
if u.arg[0] == 'i': local_size = None
|
||||
|
|
|
|||
|
|
@ -202,6 +202,7 @@ spec_program = PatternMatcher([
|
|||
|
||||
# slice in program
|
||||
(UPat(Ops.SLICE), lambda: True),
|
||||
(UPat(Ops.INDEX, name="idx"), lambda idx: isinstance(idx.src[0].dtype, ImageDType)),
|
||||
|
||||
# movement ops are not allowed in programs
|
||||
(UPat(GroupOp.Movement), lambda: False),
|
||||
|
|
@ -219,7 +220,7 @@ spec_program = PatternMatcher([
|
|||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
|
||||
# if has a <gate, index_for_dedup>
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX)))), lambda: True),
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.INDEX, Ops.SLICE)).or_casted())), lambda: True),
|
||||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||
])+spec_shared
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue