update placeholder to not create DEFINE_LOCAL/DEFINE_REG (#16671)

* update placeholder to not create DEFINE_LOCAL/DEFINE_REG

* simpler

* define_local
This commit is contained in:
George Hotz 2026-06-18 21:21:06 -07:00 committed by GitHub
commit d7b10c69bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 23 additions and 19 deletions

View file

@ -1402,7 +1402,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
acc_dt = dtypes.int32 if is_int_out else dtypes.float32
# Use uint32 temp array to prevent optimizer from eliminating f16→f32 bitcast chains.
# The optimizer folds bitcast(uint32→float32) stores to float32 arrays, losing the conversion.
tmp = UOp(Ops.DEFINE_LOCAL, dtypes.uint32.ptr(n_a_elems + n_b_elems, addrspace=AddrSpace.LOCAL), arg=(n_a_elems + n_b_elems,))
tmp = UOp.placeholder((n_a_elems + n_b_elems,), dtypes.uint32, slot=0, addrspace=AddrSpace.LOCAL)
def cvt_elem(raw: UOp, sub_idx: int) -> UOp:
if is_i8:

View file

@ -4,12 +4,11 @@ import itertools
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
from tinygrad.uop.ops import ParamArg
from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
from tinygrad.renderer import Renderer, Estimates
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
@ -42,9 +41,6 @@ pm_remove_vec_dtypes = PatternMatcher([
# remove all vec dtypes
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
# replace DEFINE_LOCAL/DEFINE_REG with BUFFER
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
])+pm_clean_up_group_sink
def do_number_param(ctx:list[int], x:UOp):

View file

@ -245,11 +245,15 @@ def no_vectorized_alu(alu:UOp):
return UOp(Ops.STACK, alu.dtype, alus)
def no_vectorized_buf(buf:UOp):
if not isinstance(buf.dtype, PtrDType): return None
if buf.addrspace not in (AddrSpace.LOCAL, AddrSpace.REG): return None
# TODO: this fails on regs
#assert buf.max_numel() == buf.ptrdtype.size
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.addrspace)).cast(buf.dtype)
sz = buf.ptrdtype.size*buf.ptrdtype.count
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(sz, buf.addrspace), src=(UOp.const(dtypes.int, sz),)).cast(buf.dtype)
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
if buf.addrspace not in (AddrSpace.LOCAL, AddrSpace.REG): return None
cnt = cast.dtype.count
if bcast is not None and bcast.op is Ops.GEP:
# GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes
@ -264,11 +268,11 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.weakint.vec(len(pairs)), offsets), ptr=True)
devectorize_buf_and_index = PatternMatcher([
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
no_vectorized_index),
])

View file

@ -158,10 +158,9 @@ class CStyleLanguage(Renderer):
return f"({self[buf]}+{strip_parens(self[idx]) if idx.arg == Ops.ADD else self[idx]})"
def render_buffer(self, x:UOp):
shp = x.src[0].as_shape
lanes = 1
prefix = f"{self.smem_align}{self.smem_prefix}" if x.addrspace == AddrSpace.LOCAL else ""
suffix = f"[{shp[0]}]" if len(shp) else ""
suffix = f"[{x.max_numel()}]"
return f"{prefix}{self._render_dtype(x.dtype, sz=lanes)} {self[x]}{suffix};"
def _render_dtype(self, dtype:DType, sz:int=1, addrspace=AddrSpace.ALU, mutable=True, override_ptr=False):

View file

@ -508,7 +508,7 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
return PatternMatcher(pat)
pm_long_decomp = PatternMatcher([
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
(UPat((*GroupOp.Defines, Ops.BUFFER, Ops.INDEX), name="x"), lambda x:
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
@ -531,7 +531,7 @@ pm_long_decomp = PatternMatcher([
# float decomposition patterns - ctx is (fr, to) tuple
pm_float_decomp = PatternMatcher([
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x:
(UPat((*GroupOp.Defines, Ops.BUFFER, Ops.INDEX), name="x"), lambda ctx,x:
x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None),
(UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None),
# bitcasted load should just replace load

View file

@ -1060,10 +1060,13 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
@staticmethod
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
arg = ParamArg(slot, addrspace=addrspace) if addrspace is AddrSpace.GLOBAL else slot
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=arg)
if len(shape) > 1: ret = ret.reshape(shape)
if addrspace is AddrSpace.GLOBAL:
ret = UOp(Ops.PARAM, dtype.ptr(prod(shape), addrspace), arg=ParamArg(slot, addrspace=addrspace))
else:
assert addrspace in (AddrSpace.LOCAL, AddrSpace.REG)
buf_shape = (prod(shape),) + ((dtype.count,) if dtype.count > 1 else ())
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
return ret
def placeholder_like(self, slot:int):
assert all_int(self.shape), "no placeholder-like on symbolic shape"

View file

@ -79,6 +79,8 @@ spec_shared = PatternMatcher([
# PARAM
(UPat(Ops.PARAM, name="x"), lambda x: isinstance(x.arg, ParamArg)),
(UPat(Ops.BUFFER, src=(UPat(),), name="x"), lambda x:
isinstance(x.arg, ParamArg) and x.addrspace in (AddrSpace.REG, AddrSpace.LOCAL)),
# GROUP of stores (or groups, or NOOPs)
# TODO: remove UNROLL here, it's for SPEC=2