Compare commits

...

17 commits

Author SHA1 Message Date
George Hotz
f597ec8d5f
Merge branch 'master' into minigen 2026-05-20 11:22:40 -07:00
George Hotz
4524c78e87 broadcast 2026-05-19 15:02:05 -07:00
George Hotz
651031e3da
Merge branch 'master' into minigen 2026-05-19 14:12:51 -07:00
George Hotz
79ad8f3ef4
Merge branch 'master' into minigen 2026-05-18 22:14:50 -07:00
George Hotz
c54151235e expanded 2026-05-18 21:00:57 -07:00
George Hotz
8a25d18135
Merge branch 'master' into minigen 2026-05-18 20:40:14 -07:00
George Hotz
4e3c3e66b0 fix weakint 2026-05-18 19:06:29 -07:00
George Hotz
474ec6e44d more amd fixes 2026-05-18 18:34:16 -07:00
George Hotz
7274239315 emu fixes 2026-05-18 18:23:10 -07:00
George Hotz
3378626308 fix custom kernel 2026-05-18 16:12:46 -07:00
George Hotz
9e408e239f fix rangeify tests 2026-05-18 15:56:57 -07:00
George Hotz
b7e3b92c54 fix pow 2026-05-18 15:37:26 -07:00
George Hotz
96ad6f05bf
Merge branch 'master' into minigen 2026-05-18 14:14:41 -07:00
George Hotz
f3e7efd39e control flow 2026-05-15 19:42:29 -07:00
George Hotz
a263eca378
Merge branch 'master' into minigen 2026-05-15 18:41:57 -07:00
George Hotz
fede39db53 minigen can render more 2026-05-15 09:07:20 -07:00
George Hotz
0e90a46ae0 start minigen, a minimal correct codegen 2026-05-15 08:43:06 -07:00
9 changed files with 169 additions and 16 deletions

View file

@ -375,7 +375,7 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32,
"""Conditional memory store with sub-word support. Returns list of store UOps."""
adt = dtypes.uint64 if addr_bits == 64 else dtypes.uint32
word_addr = addr >> UOp.const(adt, 2)
idx = mem.index(word_addr.cast(dtypes.int).valid(active))
idx = mem.index(word_addr.valid(active))
if data_bits == 32: return [idx.store(active.where(_to_u32(val), idx))]
# Sub-word store: read-modify-write with mask
byte_pos = addr.cast(dtypes.uint32) & _c(3)
@ -388,7 +388,7 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32,
is_cross = byte_pos.eq(_c(3))
cross_word0 = (idx & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))
store0 = idx.store(active.where(is_cross.where(cross_word0, new_word), idx))
next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int).valid(active & is_cross))
next_idx = mem.index((word_addr + UOp.const(adt, 1)).valid(active & is_cross))
cross_word1 = (next_idx & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))
return [store0, next_idx.store((active & is_cross).where(cross_word1, next_idx))]
@ -398,7 +398,7 @@ def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int
val_u32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
for i in range(data_bits // 8):
byte_val = (val_u32 >> UOp.const(dtypes.uint32, i * 8)) & UOp.const(dtypes.uint32, 0xFF)
stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int).valid(active)).store(byte_val.cast(dtypes.uint8)))
stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).valid(active)).store(byte_val.cast(dtypes.uint8)))
return stores
def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode_vars: dict | None = None, op_name: str = "") -> dict[int, UOp]:
@ -463,7 +463,7 @@ class _Ctx:
"""Read instruction dword from vmem at PC + dword_idx*4."""
pc = self.rpc()
addr = pc if dword_idx == 0 else pc + UOp.const(dtypes.uint64, dword_idx * 4)
return self.vmem.index((addr >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int), ptr=True).load()
return self.vmem.index(addr >> UOp.const(dtypes.uint64, 2), ptr=True).load()
def inst_field(self, field) -> UOp:
"""Extract field bits from instruction encoding. Tracks field for canonical key computation."""
@ -516,8 +516,8 @@ class _Ctx:
# Dynamic register access (takes UOp index instead of int)
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
"""Read SGPR with dynamic register index."""
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int).valid(valid), ptr=True).load()
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
if valid is not None: return self.sgpr.index(reg.valid(valid), ptr=True).load()
return self.sgpr.index(reg, ptr=True).load()
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
"""Write SGPR with dynamic register index. On RDNA, index 124 = NULL (writes discarded). On CDNA, index 124 = M0 (read/write)."""
@ -833,7 +833,7 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
nval = int(part.removeprefix('DWORD').removeprefix('X') or '1') if 'DWORD' in part else int(part[1:]) / 32 * (-1 if part[0] == 'I' else 1)
ndwords = max(1, int(abs(nval)))
dword_base = addr >> UOp.const(dtypes.uint64, 2)
vals = [ctx.vmem.index((dword_base + UOp.const(dtypes.uint64, i)).cast(dtypes.int)) for i in range(ndwords)]
vals = [ctx.vmem.index(dword_base + UOp.const(dtypes.uint64, i)) for i in range(ndwords)]
if abs(nval) < 1:
nbits = int(abs(nval) * 32)
byte_off = (addr & UOp.const(dtypes.uint64, 3)).cast(dtypes.uint32) * UOp.const(dtypes.uint32, 8)
@ -1847,7 +1847,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
def wmem(addr: UOp, val: UOp, active: UOp, data_bits: int = 32) -> UOp:
if data_bits < 32:
# Sub-dword LDS write: read-modify-write within the uint32 slot
word_addr = (addr >> addr_shift).cast(dtypes.int)
word_addr = addr >> addr_shift
idx = mem.index(word_addr.valid(active))
byte_pos = addr.cast(dtypes.uint32) & _c(3)
byte_shift = byte_pos * _c(8)
@ -1855,7 +1855,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
mask = size_mask << byte_shift
new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & size_mask) << byte_shift)
return idx.store(active.where(new_word, idx))
idx = mem.index((addr >> addr_shift).cast(dtypes.int))
idx = mem.index(addr >> addr_shift)
return idx.store(active.where(val, idx.load()))
def make_srcs(lane: UOp) -> dict:
@ -2005,7 +2005,7 @@ def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp:
for i in range(n_dwords):
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), ptr=True).load(), _c(0))
lds_idx = ((lds_addr + _c(i * 4)) >> _c(2)).cast(dtypes.int)
lds_idx = (lds_addr + _c(i * 4)) >> _c(2)
lds_slot = ctx.lds.index(lds_idx.valid(active))
stores.append(lds_slot.store(active.where(val, lds_slot)))
elif is_store:

View file

@ -845,10 +845,10 @@ class Parser:
val = _u32(0)
for i in range(4): val = val | (mindex(idx + _const(dtypes.int, i), ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
else:
idx = (addr >> _const(addr.dtype, 2)).cast(dtypes.int)
idx = (addr >> _const(addr.dtype, 2))
val = mindex(idx)
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2))
val = val.cast(dtypes.uint64) | (mindex(idx2).cast(dtypes.uint64) << _u64(32))
elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
elif dt in (dtypes.uint16, dtypes.int16):

View file

@ -2,7 +2,7 @@ from typing import cast
from dataclasses import replace
import itertools
from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, MINIGEN
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo
from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
@ -22,6 +22,7 @@ from tinygrad.codegen.late.gater import pm_move_gates_from_index
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.minigen import minigen_to_sink
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
@ -29,6 +30,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if DEBUG >= 5: print(pyrender(ast))
if SPEC: type_verify(ast, spec_tensor)
# mini codegen
if MINIGEN: return minigen_to_sink(ast, ren, optimize)
# preprocess
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)

142
tinygrad/codegen/minigen.py Normal file
View file

@ -0,0 +1,142 @@
from dataclasses import dataclass
from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, GroupOp, UPat, Ops, AxisType
from tinygrad.uop.spec import type_verify, spec_program
from tinygrad.dtype import dtypes, Invalid, AddrSpace, least_upper_dtype
from tinygrad.helpers import SPEC
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import pm_add_loads, reduce_to_acc, merge_reduce_ends, pm_render
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow
from tinygrad.schedule.rangeify import pm_index_on_index, pm_mops
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.uop.symbolic import sym
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.codegen.late.expander import pm_group_for_reduce
def lower_weakint(x:UOp):
  """Return `x` with a concrete dtype in place of its current one.

  A UOp without sources becomes `dtypes.int`. Otherwise the new dtype is
  `least_upper_dtype` applied to `dtypes.int` and the dtypes of all sources.
  Used as the rewrite function of `pm_lower_weakint`.
  """
  # no sources, it's an int
  concrete = least_upper_dtype(dtypes.int, *[s.dtype for s in x.src]) if x.src else dtypes.int
  return x.replace(dtype=concrete)
# Rewrite rules for the "remove weakint" pass of minigen_to_sink:
# every matched UOp is handed to lower_weakint, which replaces its dtype.
pm_lower_weakint = PatternMatcher([
# any op (GroupOp.All) carrying dtypes.weakint, bound as "x"
(UPat(GroupOp.All, dtypes.weakint, name="x"), lower_weakint),
])
# Mutable counters passed as ctx to the pm_minimal_reduce rewrite ("remove reduce/stage").
@dataclass
class ReduceContext:
# not read in the visible code; presumably consumed by reduce_to_acc to number accumulators -- TODO confirm
acc_num: int = 0
# arg given to the next DEFINE_LOCAL built by stage_to_local; incremented after each one is created
local_num: int = 0
def stage_to_local(ctx:ReduceContext, x:UOp):
  """Replace a STAGE UOp with a DEFINE_LOCAL buffer that has been written with the staged value.

  The buffer is numbered with `ctx.local_num` (which is then incremented), sized to
  `x.numel()` elements and reshaped to `vmax+1` of each trailing source of `x`. The value
  `x.src[0]` is stored at the position given by those trailing sources, the store is ended
  over them, and the buffer is returned ordered after that store.
  """
  idxs = x.src[1:]
  # TODO: addrspace shouldn't be on dtype
  local_dtype = x.dtype.ptr(size=x.numel(), addrspace=AddrSpace.LOCAL)
  buf = UOp(Ops.DEFINE_LOCAL, local_dtype, (), arg=ctx.local_num).reshape(*[u.vmax+1 for u in idxs])
  ctx.local_num += 1
  fill = buf.index(*idxs).store(x.src[0]).end(*idxs)
  return buf.after(fill)
# Rewrite rules for the "remove reduce/stage" pass of minigen_to_sink (run with a ReduceContext as ctx).
pm_minimal_reduce = PatternMatcher([
# REDUCE is handled by reduce_to_acc (imported from codegen.late.devectorizer)
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
# STAGE is replaced by a written DEFINE_LOCAL buffer, see stage_to_local
(UPat(Ops.STAGE, name="x"), stage_to_local),
# TODO: this should not be here! this is caused by pm_load_collapse
(UPat(Ops.SINK, name="sink"), merge_reduce_ends),
])
def _gated_index_pat() -> UPat:
  """Build a fresh pattern for `buf.index(gate.where(idx, Invalid))` passed through `or_casted` (bound as "cast")."""
  return UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast")

def _gated_load(buf:UOp, gate:UOp, idx:UOp, cast:UOp, l:UOp):
  """Rebuild load `l` on an ungated pointer INDEX, passing zeros as the alt value and `gate` to the load."""
  # here we create the alt value for load to be 0s and remove the where Invalid
  ptr = buf.index(idx, ptr=True).cast(cast.dtype)
  return ptr.load(l.const_like(0), gate, dtype=l.dtype)

def _gated_store(buf:UOp, gate:UOp, idx:UOp, cast:UOp, data:UOp):
  """Rebuild the store on an ungated pointer INDEX, passing `gate` to the store itself."""
  ptr = buf.index(idx, ptr=True).cast(cast.dtype)
  return ptr.store(data, gate)

# Moves the gate out of INDEX (where it appears as `gate.where(idx, Invalid)`) and onto the LOAD/STORE.
pm_minimal_move_gates_from_index = PatternMatcher([
  (_gated_index_pat().load(name="l"), _gated_load),
  (_gated_index_pat().store(UPat.var("data")), _gated_store),
])
# State passed as ctx to the "expander 2.0" rewrite in minigen_to_sink.
@dataclass
class ExpanderState:
# number of axes in the shape expand_range reshapes to; minigen_to_sink passes the count of UPCAST/UNROLL RANGEs
total_dim: int
# index of the axis the next expanded RANGE is placed on; incremented by expand_range after each expansion
current_dim: int = 0
def expand_range(ctx:ExpanderState, r:UOp):
  """Replace an UPCAST/UNROLL RANGE with a constant holding all of its values.

  Returns None (no replacement) when the last element of `r.arg` is neither
  `AxisType.UPCAST` nor `AxisType.UNROLL`. Otherwise builds a const of `r.dtype` with the
  values `0..r.vmax`, reshapes it to `ctx.total_dim` axes where the axis at
  `ctx.current_dim` is -1 and every other axis is 1, then advances `ctx.current_dim`.
  """
  if r.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}: return None
  values = UOp.const(r.dtype, tuple(range(r.vmax+1)))
  # each expanded range gets its own axis: -1 on the current one, 1 everywhere else
  shape = [-1 if dim == ctx.current_dim else 1 for dim in range(ctx.total_dim)]
  expanded = values.reshape(shape)
  ctx.current_dim += 1
  return expanded
# Rewrite rules for the "expander 2.0" pass (run with an ExpanderState as ctx, combined with pm_group_for_reduce).
pm_mini_expander = PatternMatcher([
# every RANGE is offered to expand_range, which only rewrites UPCAST/UNROLL ranges
(UPat(Ops.RANGE, name="r"), expand_range),
])
def minigen_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
"""Lower `ast` to a program sink by running the minimal codegen pass pipeline.

The function threads `sink` through a fixed sequence of `graph_rewrite` passes; the order
is significant (see the per-pass comments below).

Args:
ast: the input UOp graph. `ast.arg.beam` is read and forwarded to `apply_opts`.
ren: the Renderer; its `target`, `code_for_op` and `extra_matcher` are used by the late passes.
optimize: gates the early optimization passes that follow `if optimize:`.

Returns:
The rewritten sink UOp. When SPEC is set it is checked with `type_verify` against `spec_program`.

NOTE(review): indentation is not preserved in this view, so which of the statements after
`if optimize:` belong to its body cannot be determined here -- confirm against the real file.
"""
sink = ast
if optimize:
# collapse loads reduce (indexing by a tensor), must run while REDUCE is still in the graph
sink = graph_rewrite(sink, pm_load_collapse, name="load collapse")
# split ranges
sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
# optimize (schedule) the AST
sink = graph_rewrite(sink, pm_flatten_range+pm_simplify_ranges, ctx={}, name="simplify ranges")
# do postrange optimization, BEAM or hand_coded_optimizations
sink = apply_opts(sink, ren, beam=ast.arg.beam)
# expanded
# count the UPCAST/UNROLL RANGEs up front: this is the number of axes expand_range reshapes to
rcount = sum([1 for u in sink.toposort() if u.op is Ops.RANGE and u.arg[-1] in {AxisType.UPCAST, AxisType.UNROLL}])
sink = graph_rewrite(sink, pm_mini_expander+pm_group_for_reduce, ctx=ExpanderState(rcount), name="expander 2.0")
# REDUCE is not allowed in programs, we need to do that REDUCE somewhere
# this creates a register where the reduce happens
# STAGE is also not allowed in programs, this is similar to REDUCE
sink = graph_rewrite(sink, pm_minimal_reduce, ctx=ReduceContext(), name="remove reduce/stage")
# if there's any movement ops left, we need to remove them. removing stage might add movement ops
# also handle INDEX on INDEX
sink = graph_rewrite(sink, pm_index_on_index+pm_mops, name="remove movement ops")
# do single symbolic (this rewrites POW)
sink = graph_rewrite(sink, sym, name="symbolic")
# add gpu dims (late). this works after devectorize, but it's faster here
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
# we need to add loads
# this is really a store to DEFINE_REG, but load is simpler
# LOAD(DATA) is anonymous store -> AFTER(anon_buf, STORE(anon_buf, DATA))
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
# we need to lower weakint for the program
# this will be simpler when we have implicit dtype
sink = graph_rewrite(sink, pm_lower_weakint, name="remove weakint")
# move gates from INDEX to LOAD/STORE (Invalid isn't renderable)
sink = graph_rewrite(sink, pm_minimal_move_gates_from_index, name="move gates")
# split ENDs, renderable ENDs can only end one RANGE
sink = graph_rewrite(sink, pm_split_ends, name="split ends")
# **** enter decanonicalize *****
# decompose dtypes we don't support into renderable versions
# NOTE: this adds Ops.FLOORDIV, so it needs to come before decompose ops
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren.target), name="decomp dtypes")
# decompose ops like SIN/THREEFRY into renderable versions
# the ops the renderer can emit directly are the keys of its code_for_op table
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = get_late_rewrite_patterns(supported_ops, disable_fast_idiv=True) + \
get_transcendental_patterns(supported_ops, force_transcendental=False)
sink = graph_rewrite(sink, pm_decomp, ctx=ren.target, name="decompose ops to renderable")
# extra matcher from the renderer
# fall back to an empty PatternMatcher so the + below works when the renderer has none
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
sink = graph_rewrite(sink, pm_render+extra_matcher, ctx=ren.target, name="final rewrite")
# this was the linearizer, add control flow edges where they are needed on RANGEs
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
if SPEC: type_verify(sink, spec_program)
return sink

View file

@ -91,7 +91,7 @@ class DType(metaclass=DTypeMetaClass):
return float("inf") if dtypes.is_float(self) else True
def const(self, val: tuple[ConstType, ...]|ConstType):
if isinstance(val, tuple):
assert len(val) == self.count, f"mismatch {val} {self}"
#assert len(val) == self.count, f"mismatch {val} {self}"
return tuple(map(self.const, val))
if isinstance(val, InvalidType): return val
# NOTE: float('nan') != float('nan'), so we canonicalize here

View file

@ -268,6 +268,8 @@ ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# enable mini codegen
MINIGEN = ContextVar("MINIGEN", 1)
@dataclass(frozen=True)
class Metadata:

View file

@ -37,10 +37,13 @@ pm_store_ranges = PatternMatcher([
(UPat(Ops.STORE, name="x"), add_ranges_to_store),
])
pm_syntactic_sugar = PatternMatcher([
pm_index_on_index = PatternMatcher([
# INDEX on ptr INDEX concats them
(UPat(Ops.INDEX, name="i1").f(Ops.INDEX, name="i2", allow_any_len=True),
lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
])
pm_syntactic_sugar = pm_index_on_index+PatternMatcher([
# early rangeify
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise | {Ops.CONST}, name="x"),), allow_any_len=True, name="idx"),
lambda idx,x: x.replace(src=tuple([s.index(*idx.src[1:]) for s in x.src]))),

View file

@ -247,6 +247,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return tuple(shp) + self.src[0].shape[len(self.src[1:]):]
# TODO: these should have the shape of the dtype.count
case Ops.VCONST: return (len(self.arg),)
case Ops.CONST | Ops.DEFINE_VAR: return ()
case Ops.GEP | Ops.STACK | Ops.VCAT | Ops.GETADDR: return ()

View file

@ -205,7 +205,8 @@ spec_program = PatternMatcher([
(UPat(Ops.CONST, arg=Invalid), lambda: False),
# STACK/GEP in program. TODO: this should match Tensor
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
#(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat(Ops.STACK, name="x"), lambda x: True),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# if has a <gate, index_for_dedup>