mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
update placeholder to not create DEFINE_LOCAL/DEFINE_REG (#16671)
* update placeholder to not create DEFINE_LOCAL/DEFINE_REG
* simpler
* define_local
This commit is contained in:
parent
091ec8d10d
commit
d7b10c69bc
7 changed files with 23 additions and 19 deletions
|
|
@@ -1402,7 +1402,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
|
|||
acc_dt = dtypes.int32 if is_int_out else dtypes.float32
|
||||
# Use uint32 temp array to prevent optimizer from eliminating f16→f32 bitcast chains.
|
||||
# The optimizer folds bitcast(uint32→float32) stores to float32 arrays, losing the conversion.
|
||||
tmp = UOp(Ops.DEFINE_LOCAL, dtypes.uint32.ptr(n_a_elems + n_b_elems, addrspace=AddrSpace.LOCAL), arg=(n_a_elems + n_b_elems,))
|
||||
tmp = UOp.placeholder((n_a_elems + n_b_elems,), dtypes.uint32, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
def cvt_elem(raw: UOp, sub_idx: int) -> UOp:
|
||||
if is_i8:
|
||||
|
|
|
|||
|
|
@@ -4,12 +4,11 @@ import itertools
|
|||
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
|
||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
|
||||
from tinygrad.uop.ops import ParamArg
|
||||
from tinygrad.uop.render import pyrender
|
||||
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
|
|
@@ -42,9 +41,6 @@ pm_remove_vec_dtypes = PatternMatcher([
|
|||
# remove all vec dtypes
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
|
||||
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
|
||||
# replace DEFINE_LOCAL/DEFINE_REG with BUFFER
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
|
||||
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
def do_number_param(ctx:list[int], x:UOp):
|
||||
|
|
|
|||
|
|
@@ -245,11 +245,15 @@ def no_vectorized_alu(alu:UOp):
|
|||
return UOp(Ops.STACK, alu.dtype, alus)
|
||||
|
||||
def no_vectorized_buf(buf:UOp):
|
||||
if not isinstance(buf.dtype, PtrDType): return None
|
||||
if buf.addrspace not in (AddrSpace.LOCAL, AddrSpace.REG): return None
|
||||
# TODO: this fails on regs
|
||||
#assert buf.max_numel() == buf.ptrdtype.size
|
||||
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.addrspace)).cast(buf.dtype)
|
||||
sz = buf.ptrdtype.size*buf.ptrdtype.count
|
||||
return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(sz, buf.addrspace), src=(UOp.const(dtypes.int, sz),)).cast(buf.dtype)
|
||||
|
||||
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
|
||||
if buf.addrspace not in (AddrSpace.LOCAL, AddrSpace.REG): return None
|
||||
cnt = cast.dtype.count
|
||||
if bcast is not None and bcast.op is Ops.GEP:
|
||||
# GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes
|
||||
|
|
@@ -264,11 +268,11 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
|
|||
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.weakint.vec(len(pairs)), offsets), ptr=True)
|
||||
|
||||
devectorize_buf_and_index = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER), name="buf"), no_vectorized_buf),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index),
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@@ -158,10 +158,9 @@ class CStyleLanguage(Renderer):
|
|||
return f"({self[buf]}+{strip_parens(self[idx]) if idx.arg == Ops.ADD else self[idx]})"
|
||||
|
||||
def render_buffer(self, x:UOp):
|
||||
shp = x.src[0].as_shape
|
||||
lanes = 1
|
||||
prefix = f"{self.smem_align}{self.smem_prefix}" if x.addrspace == AddrSpace.LOCAL else ""
|
||||
suffix = f"[{shp[0]}]" if len(shp) else ""
|
||||
suffix = f"[{x.max_numel()}]"
|
||||
return f"{prefix}{self._render_dtype(x.dtype, sz=lanes)} {self[x]}{suffix};"
|
||||
|
||||
def _render_dtype(self, dtype:DType, sz:int=1, addrspace=AddrSpace.ALU, mutable=True, override_ptr=False):
|
||||
|
|
|
|||
|
|
@@ -508,7 +508,7 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
|
|||
return PatternMatcher(pat)
|
||||
|
||||
pm_long_decomp = PatternMatcher([
|
||||
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
(UPat((*GroupOp.Defines, Ops.BUFFER, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
|
||||
(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
|
||||
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
|
||||
|
|
@@ -531,7 +531,7 @@ pm_long_decomp = PatternMatcher([
|
|||
|
||||
# float decomposition patterns - ctx is (fr, to) tuple
|
||||
pm_float_decomp = PatternMatcher([
|
||||
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x:
|
||||
(UPat((*GroupOp.Defines, Ops.BUFFER, Ops.INDEX), name="x"), lambda ctx,x:
|
||||
x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None),
|
||||
(UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None),
|
||||
# bitcasted load should just replace load
|
||||
|
|
|
|||
|
|
@@ -1060,10 +1060,13 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
|
||||
@staticmethod
|
||||
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
arg = ParamArg(slot, addrspace=addrspace) if addrspace is AddrSpace.GLOBAL else slot
|
||||
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=arg)
|
||||
if len(shape) > 1: ret = ret.reshape(shape)
|
||||
if addrspace is AddrSpace.GLOBAL:
|
||||
ret = UOp(Ops.PARAM, dtype.ptr(prod(shape), addrspace), arg=ParamArg(slot, addrspace=addrspace))
|
||||
else:
|
||||
assert addrspace in (AddrSpace.LOCAL, AddrSpace.REG)
|
||||
buf_shape = (prod(shape),) + ((dtype.count,) if dtype.count > 1 else ())
|
||||
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
|
||||
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
|
||||
return ret
|
||||
def placeholder_like(self, slot:int):
|
||||
assert all_int(self.shape), "no placeholder-like on symbolic shape"
|
||||
|
|
|
|||
|
|
@@ -79,6 +79,8 @@ spec_shared = PatternMatcher([
|
|||
|
||||
# PARAM
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: isinstance(x.arg, ParamArg)),
|
||||
(UPat(Ops.BUFFER, src=(UPat(),), name="x"), lambda x:
|
||||
isinstance(x.arg, ParamArg) and x.addrspace in (AddrSpace.REG, AddrSpace.LOCAL)),
|
||||
|
||||
# GROUP of stores (or groups, or NOOPs)
|
||||
# TODO: remove UNROLL here, it's for SPEC=2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue