remove DEFINE_LOCAL and DEFINE_REG (gpt) (#16673)

* remove define_local and define_reg (gpt)

* fix precommit

* cleanups

* regalloc fix

* cleanups 2
This commit is contained in:
George Hotz 2026-06-19 10:07:50 -07:00 committed by GitHub
commit 649971f02a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 53 additions and 61 deletions

View file

@ -458,7 +458,8 @@ def test_matmul():
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

View file

@ -126,7 +126,7 @@ def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
P_lds = QP_lds[:, :BLOCK_N]
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) — shaped store fails due to RESHAPE(DEFINE_LOCAL) surviving linearization
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) -- shaped store fails due to RESHAPE(local BUFFER) surviving linearization
rw1 = UOp.range(TM, 296, AxisType.LOOP)
rw2 = UOp.range(TN, 297, AxisType.LOOP)
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)

View file

@ -2619,7 +2619,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
gidx = UOp.special(NUM_WG, "gidx0")
insts = build_kernel(batch, M, N, K, A.dtype.base)
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
lds = UOp.placeholder((133_120,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),

View file

@ -219,7 +219,8 @@ def test_matmul():
def asm_kernel(A, B, C):
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(THREADS, "lidx0")]
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds')
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

View file

@ -70,7 +70,7 @@ def custom_lds_sync(A:UOp, arch:str) -> UOp:
num_threads = A.shape[0]
threads = UOp.special(num_threads, "lidx0")
wg = UOp.special(1, "gidx0")
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
lds = UOp.placeholder((512,), dtypes.uint8, 0, AddrSpace.LOCAL) # 128 * 4 bytes
isa = r4 if arch == "rdna4" else r3
wait_kmcnt = [isa.s_wait_kmcnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
wait_dscnt = [isa.s_wait_dscnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
@ -103,7 +103,7 @@ def custom_handwritten(A:UOp) -> UOp:
A = A.flatten()
threads = UOp.special(128, "lidx0")
wg = UOp.special(1, "gidx0")
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
lds = UOp.placeholder((512,), dtypes.uint8, 0, AddrSpace.LOCAL) # 128 * 4 bytes
pipes = {getenv("PIPE", "")} if getenv("PIPE", "") else {"SALU", "VALU", "TRANSCENDENTAL", "WMMA"}
k = Kernel()
# wrap in loop to filter out icache misses

View file

@ -76,7 +76,7 @@ class TestLinearizer(unittest.TestCase):
def _test_no_nested_ranges(self, lins, skip=None):
for l in lins:
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG])
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
for i,u in enumerate(ranges):
if skip and i in skip: continue
@ -161,7 +161,7 @@ class TestLinearizer(unittest.TestCase):
uops = tuple(to_program(replace_opts(r.schedule_linear().src[-1].src[0],
[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]), renderer=Device[Device.DEFAULT].renderer).src[2].src)
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
accs = [u for u in uops if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG]
stores = [u for u in uops if u.op is Ops.STORE]
assert len(accs) == 0 # it's removed now
assert len(stores) == 1
@ -210,14 +210,14 @@ class TestLinearizer(unittest.TestCase):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
realized_ast = a.schedule_linear().src[-1].src[0]
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
local = [uop for uop in tuple(program.src[2].src) if uop.op is Ops.BUFFER and uop.addrspace in (AddrSpace.LOCAL, AddrSpace.REG)]
assert local[0].dtype.base == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
realized_ast = c.schedule_linear().src[-1].src[0]
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
local = [uop for uop in tuple(program.src[2].src) if uop.op is Ops.BUFFER and uop.addrspace in (AddrSpace.LOCAL, AddrSpace.REG)]
self.assertEqual(local[0].dtype.base, expected_dtype)
tests = (
@ -243,7 +243,7 @@ class TestLinearizer(unittest.TestCase):
r = (x@y).relu()
opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
ast = helper_linearizer_opt(r, [opt])
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
# the uops graph is reg BUFFER -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
@ -361,7 +361,8 @@ class TestLinearizer(unittest.TestCase):
ast = helper_linearizer_opt(out, opts=[opt])
def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
local_stores = [u for u in uops if u.op is Ops.STORE and any(
x.op is Ops.BUFFER and x.addrspace is AddrSpace.LOCAL for x in get_recursive(u.src[0]))]
global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.PARAM for x in get_recursive(u.src[0]))]
barrier = [u for u in uops if u.op is Ops.BARRIER]
assert len(barrier) == 1

View file

@ -3,6 +3,7 @@ import numpy as np
import tempfile, unittest
from tinygrad import Tensor, Context, Device, dtypes, UOp
from tinygrad.uop.ops import Ops
from tinygrad.dtype import AddrSpace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import run_linear
from tinygrad.codegen import to_program
@ -80,7 +81,7 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
with Context(QUANTIZE=1):
linear = run_onnx({"input":inp})["output"].schedule_linear()
prg = to_program(linear.src[-2].src[0], renderer=Device[Device.DEFAULT].renderer)
daccs = [u for u in tuple(prg.src[2].src) if u.op is Ops.DEFINE_REG]
daccs = [u for u in tuple(prg.src[2].src) if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG]
assert all(u.dtype.scalar() is dtypes.int for u in daccs)
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")

View file

@ -1425,7 +1425,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
mant = h & UOp.const(dtypes.uint32, 0x3FF)
# Use bf16 path: shift left by 16 to create bf16 bits, then shift mantissa and adjust exponent in float domain
# bf16 bits = (sign << 15) | (exp_bf16 << 7) | mant_bf16 -- but f16 and bf16 have different formats
# Instead: construct f32 bits properly, use a DEFINE_LOCAL uint32 array to force materialization
# Instead: construct f32 bits properly, use a local uint32 array to force materialization
f32_bits = (sign << UOp.const(dtypes.uint32, 31)) | \
((exp + UOp.const(dtypes.uint32, 112)) << UOp.const(dtypes.uint32, 23)) | \
(mant << UOp.const(dtypes.uint32, 13))

View file

@ -35,11 +35,11 @@ pm_index_is_shrink = PatternMatcher([
pm_remove_vec_dtypes = PatternMatcher([
# rewrite PARAM to non pointer
(UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
(UPat((Ops.PARAM, Ops.BUFFER), name="buf"), lambda buf:
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
# remove all vec dtypes
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER}, name="x"),
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
])+pm_clean_up_group_sink

View file

@ -268,12 +268,10 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.weakint.vec(len(pairs)), offsets), ptr=True)
devectorize_buf_and_index = PatternMatcher([
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
no_vectorized_index),
(UPat(Ops.BUFFER, name="buf"), no_vectorized_buf),
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")), no_vectorized_index),
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")), no_vectorized_index),
])
devectorize_alu = PatternMatcher([

View file

@ -2,6 +2,7 @@ import heapq
from typing import Any
from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
def linearize(sink:UOp) -> list[UOp]:
@ -23,7 +24,7 @@ def linearize(sink:UOp) -> list[UOp]:
match u.op:
# the order and placement of these defines is important
case Ops.PARAM: priority, extra = -20, u.arg.slot
case Ops.BUFFER: priority = -18
case Ops.BUFFER: priority = -17 if u.addrspace == AddrSpace.LOCAL else -18
case Ops.LOAD: priority = -1 # place loads early
case Ops.STORE: priority = 1 # place stores late
case Ops.RANGE: priority = 5 # placing RANGE is good

View file

@ -83,10 +83,10 @@ class LinearScanRegallocContext:
live[v] = alloc(cons, i+1 if u.op is not Ops.RANGE else i)
self.reals.setdefault(i, {})[v] = live[v]
# allocate stack array, BUFFER size is in src[0]
# allocate stack array
if u.op is Ops.BUFFER:
self.locals[u] = UOp.const(dtypes.int32, self.stack_size)
self.stack_size += u.src[0].arg * u.dtype.itemsize
self.stack_size += u.max_numel() * u.dtype.itemsize
# loop prologue, avoid loading inside the loop
if u.op is Ops.RANGE:

View file

@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
replaces: dict[UOp, UOp] = {}
for u in included:
for s in u.src:
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL}: continue
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.BUFFER}: continue
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")

View file

@ -39,7 +39,7 @@ def assemble_linear(prg:UOp, lin:UOp, arch:str) -> bytes:
for u in sink.toposort():
if u.op is Ops.PARAM and u.addrspace is AddrSpace.ALU: n_vars += 1
elif u.op is Ops.PARAM: n_bufs += 1
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
elif u.op is Ops.BUFFER and u.addrspace is AddrSpace.LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
code_bytes = b"".join(inst.to_bytes() for inst in insts)
arch = next(v for k, v in _arch_map.items() if arch.startswith(k))

View file

@ -4,7 +4,7 @@ import sys, struct, functools
from typing import cast
from tinygrad.dtype import dtypes, DType, truncate, AddrSpace
from tinygrad.uop import FastEnum, auto, Ops, GroupOp
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, ParamArg
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
from tinygrad.renderer.isa import ISARenderer, IselContext, Register, PreRegAllocContext
from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
@ -174,7 +174,7 @@ extra_matcher = PatternMatcher([
# ***** X86 pre instruction selection *****
def scratch_buffer(elem_dt:DType, count:int, slot:int) -> UOp:
return UOp(Ops.BUFFER, elem_dt, src=(UOp.const(dtypes.int, count),), arg=ParamArg(slot, addrspace=AddrSpace.LOCAL))
return UOp.placeholder((count,), elem_dt, slot, AddrSpace.LOCAL)
def gated_load(ctx, addr:UOp, alt:UOp, gate:UOp, x:UOp):
local = scratch_buffer(addr.src[0].dtype.scalar(), x.dtype.count, next(ctx))

View file

@ -23,8 +23,8 @@ dsp_pm_late = PatternMatcher([
(UPat.var("x")+UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")*UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")//UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat(Ops.DEFINE_REG, src=(UPat(Ops.STACK, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
(UPat(Ops.BUFFER, src=(UPat(Ops.STACK, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:]) if d.addrspace is AddrSpace.REG else None),
])
# NOTE: this just increases readability of the generated code

View file

@ -8,8 +8,8 @@ from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_claus
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.SLICE,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.FUNCTION}
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.LOAD, Ops.CALL, Ops.FUNCTION}
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None

View file

@ -427,7 +427,7 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
ended_stores.append(store_target.replace(dtype=sdtype).store(store.src[1]).end(*end_rngs))
return buf.after(*ended_stores)
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
# NOTE: the local BUFFER needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
if x.src[0].op is Ops.SLICE:
@ -561,11 +561,11 @@ rangeify_codegen = PatternMatcher([
# TODO: this can be moved into codegen?
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0] if len(x.src) else None),
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, PtrDType) else
(UPat(Ops.BUFFER).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if dg.addrspace is not AddrSpace.LOCAL or isinstance(idx.dtype, PtrDType) else
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).gep(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if isinstance(idx.dtype, PtrDType) else
(UPat(Ops.BUFFER).f(Ops.AFTER, allow_any_len=True).gep(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
lambda dg,idx: None if dg.addrspace is not AddrSpace.LOCAL or isinstance(idx.dtype, PtrDType) else
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
])

View file

@ -19,10 +19,7 @@ class Ops(FastEnum):
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
# define LOCAL/REG allocate things
DEFINE_LOCAL = auto(); DEFINE_REG = auto()
# BUFFER is the new LOCAL/REG
# BUFFER allocates global/local/register storage depending on its addrspace
BUFFER = auto()
# ** 2 -- non op uops **
@ -125,7 +122,7 @@ class GroupOp:
# TODO: is BITCAST always Elementwise if it's shape changing?
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
Defines = {Ops.PARAM, Ops.BUFFER}
Irreducible = {Ops.CONST, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

View file

@ -279,7 +279,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
case Ops.GETADDR: return ()
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
case Ops.BINARY: return (len(self.arg),)
case Ops.BUFFER: return self.src[0].as_shape if isinstance(self.arg, ParamArg) else (self.arg,)
case Ops.BUFFER:
if isinstance(self.arg, ParamArg):
if len(self.src): return self.src[0].as_shape
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size, self.dtype.count) if self.dtype.count > 1 else (self.ptrdtype.size,)
return (self.dtype.count,) if self.dtype.count > 1 else ()
return (self.arg,)
case Ops.SLICE:
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
if self.src[0].op is Ops.INDEX: return ()
@ -288,13 +293,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
case Ops.STAGE:
# STAGE adds the existing shape to the front, opposite of INDEX
return tuple([int(r.vmax+1) for r in self.src[1:]])+self.src[0].shape
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG:
if len(self.src) >= 1:
# NOTE: this is the same as PARAM
return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
if isinstance(self.dtype, PtrDType):
return (self.ptrdtype.size, self.dtype.count) if self.dtype.count > 1 else (self.ptrdtype.size,)
return (self.dtype.count,) if self.dtype.count > 1 else ()
case Ops.PARAM:
if isinstance(self.dtype, ImageDType): return self.dtype.shape
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
@ -794,8 +792,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def addrspace(self) -> AddrSpace|None:
if self.op is Ops.PARAM: return self.arg.addrspace
if self.op is Ops.BUFFER: return self.arg.addrspace if isinstance(self.arg, ParamArg) else AddrSpace.GLOBAL
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
if self.op in {Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
if self.op is Ops.LOAD: return AddrSpace.ALU # LOAD brings things into the ALU
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:

View file

@ -86,12 +86,8 @@ spec_shared = PatternMatcher([
# TODO: remove UNROLL here, it's for SPEC=2
(UPat(Ops.GROUP, dtypes.void, src=UPat((Ops.GROUP, Ops.STORE, Ops.NOOP, Ops.UNROLL, Ops.INS))), lambda: True),
# TODO: these should be buffer with different addrspace everywhere.
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), lambda: True),
# AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.AFTER, Ops.MULTI,
Ops.BITCAST, Ops.INS})),),
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.AFTER, Ops.MULTI, Ops.BITCAST, Ops.INS})),),
allow_any_len=True), lambda: True),
# CUSTOM (inline and non inline)
@ -178,7 +174,7 @@ spec_tensor = PatternMatcher([
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None),
lambda root,x: root.dtype == x.dtype),
# TODO: this should not be here. STAGE is transformed to DEFINE_LOCAL later
# TODO: this should not be here. STAGE is transformed to BUFFER later
(UPat(Ops.STAGE, src=(UPat(),), allow_any_len=True), lambda: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
@ -198,7 +194,7 @@ spec_tensor = PatternMatcher([
# these ops can exist in programs but not the tensor spec. example: LOAD
spec_program = PatternMatcher([
# no more of these in programs
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.GEP)), lambda: False),
(UPat(Ops.GEP), lambda: False),
# weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False),

View file

@ -46,7 +46,7 @@ from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphE
from tinygrad.dtype import dtypes, AddrSpace
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B",
Ops.SHAPED_WMMA: "#FF5B5B",
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#D8F9E4", Ops.STACK: "#D8F9E4",
Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",