This commit is contained in:
George Hotz 2025-12-16 22:54:35 +00:00
commit 70747d760f

View file

@ -6,9 +6,22 @@ from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
from tinygrad.renderer import Renderer
from tinygrad.helpers import prod, get_single_element
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
from tinygrad.codegen.opt import tc
def render_val(x, dtype):
if dtypes.is_float(dtype):
# Check if this is an inlineable float constant
fval = float(x)
if fval == 0.0: return "0"
if fval == 0.5: return "0.5"
if fval == 1.0: return "1.0"
if fval == 2.0: return "2.0"
if fval == 4.0: return "4.0"
if fval == -0.5: return "-0.5"
if fval == -1.0: return "-1.0"
if fval == -2.0: return "-2.0"
if fval == -4.0: return "-4.0"
# Non-inlineable float - use hex representation
if dtype == dtypes.double: return "0x%016X" % struct.unpack("Q", struct.pack("d", x))[0]
if dtype == dtypes.half: return "0x%04X" % struct.unpack("H", struct.pack("e", x))[0]
return "0x%08X" % struct.unpack("I", struct.pack("f", x))[0]
@ -19,7 +32,20 @@ def render_val(x, dtype):
return f"0x{val:08X}"
return str(val)
def can_inline_const(val, dtype) -> bool:
  """Check if a constant can be inlined in RDNA3 instructions.

  For float dtypes only +0.0, +-0.5, +-1.0, +-2.0 and +-4.0 qualify.
  For every other dtype the value qualifies when int(val) lies in -16..64 (inclusive).

  Returns False instead of raising when val cannot be converted to an int.
  """
  if dtypes.is_float(dtype):
    import math  # local import: only needed for the signed-zero check below
    fval = float(val)
    # -0.0 compares equal to 0.0, but the inline constant 0 has no sign bit set:
    # inlining -0.0 would silently turn it into +0.0, so it is reported as not inlineable
    if fval == 0.0: return math.copysign(1.0, fval) > 0
    return fval in (0.5, 1.0, 2.0, 4.0, -0.5, -1.0, -2.0, -4.0)
  # Integer inline constants: -16 to 64
  try:
    return -16 <= int(val) <= 64
  except (TypeError, ValueError, OverflowError):
    # int() raises OverflowError for +-inf and ValueError for NaN or non-numeric strings
    return False
# RDNA3 uses different instruction names and formats than PTX
# NOTE: These are used via string_rewrite which passes r[v] (register strings), not UOps
# For literal constant embedding, we handle ADD/MUL specially in string_rewrite
asm_for_op: dict[Ops, Callable] = {
Ops.RECIPROCAL: lambda d,a,dt,name: f"v_rcp_{name} {d}, {a}",
Ops.EXP2: lambda d,a,dt,name: f"v_exp_{name} {d}, {a}", Ops.LOG2: lambda d,a,dt,name: f"v_log_{name} {d}, {a}",
@ -65,23 +91,19 @@ def mem_type(x:UOp) -> str:
case Ops.DEFINE_GLOBAL: return 'global'
case _: raise RuntimeError(f"{x.op} needs to be memory")
def mem_size_suffix(dt:DType) -> str:
  """Map the byte width of a dtype to the size suffix of a memory instruction.

  Widths 1, 2, 4 and 8 give "byte", "short", "dword" and "dwordx2".
  Any other width raises RuntimeError.
  """
  width = dt.itemsize
  suffix_by_width = {1: "byte", 2: "short", 4: "dword", 8: "dwordx2"}
  if width in suffix_by_width:
    return suffix_by_width[width]
  raise RuntimeError(f"Unsupported dtype size: {dt.itemsize}")
# Build the text of a global-memory store instruction. The mnemonic is chosen from
# dt.itemsize and the operands are emitted in the order addr, data, base.
# NOTE(review): two alternative bodies appear back to back below - one delegating to
# mem_size_suffix, one spelling out the byte/b16/b32/b64 mnemonics. Read top to bottom,
# everything after the first unconditional return is unreachable. Presumably only one of
# the two is meant to remain (this looks like old and new lines of a diff) - confirm.
def global_store(addr:str, data:str, base:str, dt:DType) -> str:
# variant 1: suffix comes from mem_size_suffix (byte/short/dword/dwordx2)
suffix = mem_size_suffix(dt)
return f"global_store_{suffix} {addr}, {data}, {base}"
# variant 2: explicit mnemonic per item size (1, 2, 4 or 8 bytes)
if dt.itemsize == 1: return f"global_store_byte {addr}, {data}, {base}"
if dt.itemsize == 2: return f"global_store_b16 {addr}, {data}, {base}"
if dt.itemsize == 4: return f"global_store_b32 {addr}, {data}, {base}"
if dt.itemsize == 8: return f"global_store_b64 {addr}, {data}, {base}"
# any other width is rejected
raise RuntimeError(f"Unsupported store dtype size: {dt.itemsize}")
# Build the text of a global-memory load instruction. The mnemonic is chosen from
# dt.itemsize and the operands are emitted in the order dest, addr, base.
# NOTE(review): two alternative bodies appear back to back below - one delegating to
# mem_size_suffix, one spelling out the ubyte/u16/b32/b64 mnemonics. Read top to bottom,
# everything after the first unconditional return is unreachable. Presumably only one of
# the two is meant to remain (this looks like old and new lines of a diff) - confirm.
def global_load(dest:str, addr:str, base:str, dt:DType) -> str:
# variant 1: suffix comes from mem_size_suffix, with a special case for 1-byte loads
suffix = mem_size_suffix(dt)
if dt.itemsize == 1:
return f"global_load_ubyte {dest}, {addr}, {base}" # unsigned byte load
return f"global_load_{suffix} {dest}, {addr}, {base}"
# variant 2: explicit mnemonic per item size (1, 2, 4 or 8 bytes)
if dt.itemsize == 1: return f"global_load_ubyte {dest}, {addr}, {base}"
if dt.itemsize == 2: return f"global_load_u16 {dest}, {addr}, {base}"
if dt.itemsize == 4: return f"global_load_b32 {dest}, {addr}, {base}"
if dt.itemsize == 8: return f"global_load_b64 {dest}, {addr}, {base}"
# any other width is rejected
raise RuntimeError(f"Unsupported load dtype size: {dt.itemsize}")
def render_const_64(ctx, x):
"""Render 64-bit constant as two v_mov_b32 instructions"""
@ -96,11 +118,116 @@ def render_const_64(ctx, x):
hi = (bits >> 32) & 0xFFFFFFFF
return [f"v_mov_b32 v{reg_num}, 0x{lo:08X}", f"v_mov_b32 v{reg_num+1}, 0x{hi:08X}"]
def render_64bit_mul(ctx, x):
  """Emit the two 32-bit multiplies that stand in for a 64-bit integer MUL.

  Targets the (a * magic_const) step of division-by-multiplication: the low 32 bits of
  the product go to a scratch VGPR and the HIGH 32 bits go to the destination register,
  ready for a following SHR by 32 or more.

  Returns a list with the v_mul_lo_u32 and v_mul_hi_u32 instruction strings.
  """
  def low_half(reg):
    # "v[n:n+1]" names a 64-bit pair: keep only its low register "vn"
    return f"v{reg[2:reg.index(':')]}" if '[' in reg else reg

  dst = ctx.r[x]
  lhs, rhs = x.src[0], x.src[1]

  # left operand: a CAST from a 32-bit value is bypassed and its source register used directly
  lhs_reg = ctx.r[lhs]
  if lhs.op is Ops.CAST and lhs.src[0].dtype.itemsize == 4:
    lhs_reg = ctx.r[lhs.src[0]]
  else:
    lhs_reg = low_half(lhs_reg)

  # right operand: constants are rendered as a uint32 immediate
  rhs_reg = ctx.r[rhs]
  rhs_reg = render_val(rhs.arg, dtypes.uint32) if rhs.op is Ops.CONST else low_half(rhs_reg)

  # the destination is a single VGPR, so the (usually unneeded) low bits land in scratch;
  # the common pattern is (x*magic)>>N with N>=32, where only the high bits matter
  scratch = ctx.get_scratch_vgpr()
  lo_instr = f"v_mul_lo_u32 v{scratch}, {lhs_reg}, {rhs_reg}"
  hi_instr = f"v_mul_hi_u32 {dst}, {lhs_reg}, {rhs_reg}"
  return [lo_instr, hi_instr]
def render_64bit_shr(ctx, x, a, b):
  """Emit a 64-bit right shift by a constant, working on the high 32 bits only.

  The source is expected to come from render_64bit_mul, which leaves the high 32 bits of
  the product in a single VGPR. If the source is a pair v[n:n+1], register n+1 is read.

  Raises RuntimeError when the shift amount is not a constant.
  """
  def first_index(reg):
    # "v[n:m]" -> n, "vn" -> n
    return int(reg[2:reg.index(':')]) if '[' in reg else int(reg[1:])

  dst_reg, src_reg = ctx.r[x], ctx.r[a]
  amount = b.arg if b.op is Ops.CONST else None
  if amount is None:
    raise RuntimeError("64-bit SHR requires constant shift amount")

  # the result always goes to one 32-bit register
  dst = first_index(dst_reg)
  # a pair keeps its high half in the second register; a single register already IS the high half
  hi = first_index(src_reg) + 1 if '[' in src_reg else first_index(src_reg)

  if amount < 32:
    # both halves would be needed here, but only the high bits exist; this is not expected
    # for division-by-multiplication, so shift the high bits (loses precision, avoids a crash)
    return f"v_lshrrev_b32 v{dst}, {amount}, v{hi}"

  # shifting by 32 or more only involves the high word, shifted by the remainder
  leftover = amount - 32
  if leftover == 0:
    return f"v_mov_b32 v{dst}, v{hi}"
  return f"v_lshrrev_b32 v{dst}, {leftover}, v{hi}"
def render_wmma(ctx, x):
  """Render a WMMA instruction for RDNA3.

  Emitted form: <instr> dst, src_a, src_b, src_c, e.g. v_wmma_f32_16x16x16_f16.
  x.src[0] and x.src[1] are the A and B matrices, x.src[2] is the accumulator, and
  x.arg[2] / x.arg[3] are the input / output dtypes that select the mnemonic.

  Raises RuntimeError for an input/output dtype combination with no mnemonic here.
  """
  def as_range(regs):
    # a list like ["v4", "v5", ...] collapses to "v[min:max]"; a ready-made range string passes through
    if not isinstance(regs, list):
      return regs
    indices = [int(name[1:]) for name in regs]
    return f"v[{min(indices)}:{max(indices)}]"

  dtype_in, dtype_out = x.arg[2], x.arg[3]

  # look the operands up in the order A, B, accumulator, destination, then normalize each one
  looked_up = [ctx.r[x.src[0]], ctx.r[x.src[1]], ctx.r[x.src[2]], ctx.r[x]]
  a_range, b_range, c_range, d_range = [as_range(regs) for regs in looked_up]

  # pick the mnemonic from the (output, input) dtype combination
  if dtype_out == dtypes.float and dtype_in == dtypes.half:
    instr = "v_wmma_f32_16x16x16_f16"
  elif dtype_out == dtypes.float and dtype_in == dtypes.bfloat16:
    instr = "v_wmma_f32_16x16x16_bf16"
  elif dtype_out == dtypes.float:
    raise RuntimeError(f"Unsupported WMMA dtype_in: {dtype_in}")
  elif dtype_out == dtypes.half:
    instr = "v_wmma_f16_16x16x16_f16"
  else:
    raise RuntimeError(f"Unsupported WMMA dtype_out: {dtype_out}")

  return f"{instr} {d_range}, {a_range}, {b_range}, {c_range}"
string_rewrite = PatternMatcher([
# WMMA for tensor cores
(UPat(Ops.WMMA, name="x"), render_wmma),
# const rendering
(UPat.cvar("x", dtypes.bool), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {1 if x.arg else 0}"),
# 64-bit float constants need two mov instructions
(UPat.cvar("x", dtypes.float64), render_const_64),
# 64-bit integer constants: just use low 32 bits (sufficient for most patterns)
(UPat.cvar("x", dtypes.long), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.int32)}"),
(UPat.cvar("x", dtypes.ulong), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.uint32)}"),
(UPat.cvar("x"), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, x.dtype)}"),
# special registers
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: ctx.render_special(x)),
@ -131,6 +258,12 @@ string_rewrite = PatternMatcher([
lambda ctx, x, a, b: ctx.render_bool_and(x, a, b)),
(UPat(Ops.OR, name="x", dtype=dtypes.bool, src=(UPat.var("a", dtype=dtypes.bool), UPat.var("b", dtype=dtypes.bool))),
lambda ctx, x, a, b: ctx.render_bool_or(x, a, b)),
# 64-bit integer MUL: need full 64-bit product for division-by-multiplication pattern
(UPat(Ops.MUL, name="x", dtype=dtypes.long), render_64bit_mul),
(UPat(Ops.MUL, name="x", dtype=dtypes.ulong), render_64bit_mul),
# 64-bit integer SHR: for shifts >= 32, use high bits only
(UPat(Ops.SHR, name="x", dtype=dtypes.long, src=(UPat.var("a"), UPat.cvar("b"))), render_64bit_shr),
(UPat(Ops.SHR, name="x", dtype=dtypes.ulong, src=(UPat.var("a"), UPat.cvar("b"))), render_64bit_shr),
# alu ops
(UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
# bitcast/cast
@ -147,6 +280,13 @@ string_rewrite = PatternMatcher([
lambda ctx, x, a: [
f"v_cmp_{'neq' if dtypes.is_float(a.dtype) else 'ne'}_{ctx.types[a.dtype]} vcc_lo, {ctx.r[a]}, 0",
f"v_cndmask_b32 {ctx.r[x]}, 0, 1, vcc_lo"]),
# cast TO 64-bit int: just move (we treat 64-bit ints as 32-bit for register allocation)
(UPat(Ops.CAST, name="x", dtype=dtypes.long, src=(UPat.var("a"),)),
lambda ctx, x, a: f"v_mov_b32 {ctx.r[x]}, {ctx.r[a]}" if not ctx.r[a].startswith('s') else
f"v_cndmask_b32 {ctx.r[x]}, 0, 1, {ctx.r[a]}"),
(UPat(Ops.CAST, name="x", dtype=dtypes.ulong, src=(UPat.var("a"),)),
lambda ctx, x, a: f"v_mov_b32 {ctx.r[x]}, {ctx.r[a]}" if not ctx.r[a].startswith('s') else
f"v_cndmask_b32 {ctx.r[x]}, 0, 1, {ctx.r[a]}"),
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)), lambda ctx, x, a: ctx.render_cast(x, a)),
# store / load for global memory
# store boolean value - if SGPR (comparison result), convert via cndmask; if VGPR, store directly
@ -209,6 +349,8 @@ class RDNARenderer(Renderer):
shared_max = 65536
code_for_op = asm_for_op
extra_matcher = rdna_matcher
# TODO: WMMA needs scheduler changes to interleave loads/computes for register efficiency
# tensor_cores = tc.amd_rdna3 # RDNA3 WMMA tensor cores - disabled pending register optimization
def __init__(self, arch:str="gfx1100"):
self.arch = arch
@ -274,6 +416,14 @@ class RDNARenderer(Renderer):
elif x.dtype == dtypes.float64:
# int to float64: first convert to f32, then to f64
return self.render_mov_64(x, a) # TODO: proper int->f64 conversion
# 64-bit to 32-bit integer: just use low 32 bits
elif x.dtype.itemsize == 4 and a.dtype in (dtypes.long, dtypes.ulong):
ra = self.r[a]
# Extract low register from pair v[n:n+1]
if '[' in ra:
src_num = int(ra[2:ra.index(':')])
return f"v_mov_b32 {self.r[x]}, v{src_num}"
return f"v_mov_b32 {self.r[x]}, {ra}"
# fallback: just move (same size types)
return f"v_mov_b32 {self.r[x]}, {self.r[a]}"
@ -315,6 +465,34 @@ class RDNARenderer(Renderer):
src_num = get_reg_num(ra)
return [f"v_mov_b32 v{dst_num}, v{src_num}", f"v_mov_b32 v{dst_num+1}, v{src_num+1}"]
def render_cast_to_64(self, x: UOp, a: UOp, signed: bool = False) -> list[str]:
  """Render a 32-bit to 64-bit integer cast as two instructions (low word, then high word).

  self.r[x] must be a pair written as v[n:n+1]; self.r[a] is the 32-bit source.
  With signed=True the high word is the sign of the source, otherwise it is zero.
  A source held in an SGPR (a comparison result) is always zero-extended.
  """
  dst_pair, src = self.r[x], self.r[a]
  # pull n out of "v[n:n+1]"
  lo = int(dst_pair[2:dst_pair.index(':')])
  hi = lo + 1

  if src.startswith('s'):
    # SGPR comparison result: materialize 0/1 in the low word, high word is zero
    return [f"v_cndmask_b32 v{lo}, 0, 1, {src}", f"v_mov_b32 v{hi}, 0"]

  if not src.startswith('v'):
    src = f"v{src}"

  # an arithmetic shift right by 31 replicates bit 31 across the whole high word
  high_word = f"v_ashrrev_i32 v{hi}, 31, {src}" if signed else f"v_mov_b32 v{hi}, 0"
  return [f"v_mov_b32 v{lo}, {src}", high_word]
def render_kernel(self, kernel, function_name, bufs, v_cnt, s_cnt, uops) -> str:
# Build metadata for kernel
args = []
@ -449,34 +627,36 @@ class RDNARenderer(Renderer):
elif u.op is Ops.LOAD:
aliases[u] = u.src[0] # LOAD from REG aliases to INDEX (which aliases to DEFINE_REG)
# Second pass: extend last_use for values used inside loops
# Values used inside a loop must stay alive until the END of the loop
# Second pass: extend last_use for values DEFINED OUTSIDE a loop but USED INSIDE
# Only these need their lifetime extended to the END of the loop (for loop-carried dependencies)
# Values defined INSIDE a loop can be freed normally
uop_positions = {u: i for i, u in enumerate(uops)}
for uop, use_pos in list(last_use.items()):
# Check if this use position is inside any loop
if uop not in uop_positions:
continue
def_pos = uop_positions[uop]
# Check if defined OUTSIDE loop but used INSIDE loop
for range_pos, end_pos in loop_ranges.items():
if range_pos < use_pos <= end_pos:
# This value is used inside a loop, extend its lifetime to loop END
last_use[uop] = max(last_use[uop], end_pos)
# Also check if the uop itself is defined inside a loop
if uop in uop_positions:
def_pos = uop_positions[uop]
for range_pos, end_pos in loop_ranges.items():
if range_pos < def_pos <= end_pos:
# This value is defined inside a loop
# If it's used inside the loop, extend to END
if use_pos <= end_pos:
last_use[uop] = max(last_use[uop], end_pos)
# Value defined before the loop starts
if def_pos <= range_pos:
# And used inside the loop
if range_pos < use_pos <= end_pos:
# Extend lifetime to end of loop
last_use[uop] = max(last_use[uop], end_pos)
# === REGISTER ALLOCATOR ===
# Track free registers (available for reuse)
free_vgprs: list[int] = []
free_vgpr_pairs: list[int] = [] # Track free aligned pairs (base register numbers)
free_sgprs: list[int] = []
# Track SGPR pairs to prevent individual registers from being freed
sgpr_pairs: set[int] = set()
vgpr_pairs: set[int] = set() # Track which VGPRs are part of pairs
# Track which UOp uses which register (for freeing)
vgpr_owner: dict[int, UOp] = {}
sgpr_owner: dict[int, UOp] = {}
# Constant deduplication: (dtype, value) -> register
const_cache: dict[tuple, str] = {}
# v[0:2] is local_xyz, we start allocating from v3
next_vgpr = 3
# s[0:1] is kernarg ptr, s[2:4] is group id xyz, we start from s5
@ -495,9 +675,36 @@ class RDNARenderer(Renderer):
root = get_root_owner(u)
alias_groups[root].append(u)
# === IDENTIFY COMPILE-TIME CONSTANTS ===
# Constants don't need VGPRs if they're only used as indices into DEFINE_REG (compile-time array indices)
# First count all uses of each CONST
const_use_count: dict[UOp, int] = defaultdict(int)
reg_index_const_uses: dict[UOp, int] = defaultdict(int)
store_const_uses: set[UOp] = set() # Constants used in STORE (must have VGPR)
for u in uops:
for src in u.src:
if src.op is Ops.CONST:
const_use_count[src] += 1
# Track uses as REG indices specifically
if u.op is Ops.INDEX and len(u.src) > 1:
buf = u.src[0]
idx = u.src[1]
if isinstance(buf.dtype, PtrDType) and buf.dtype.addrspace == AddrSpace.REG:
if idx.op is Ops.CONST:
reg_index_const_uses[idx] += 1
# Track constants used in STORE - these MUST have VGPRs (can't use immediate in store data)
if u.op is Ops.STORE and len(u.src) >= 2:
if u.src[1].op is Ops.CONST:
store_const_uses.add(u.src[1])
# Skip allocation for constants that are ONLY used for REG indexing
skip_alloc_consts: set[UOp] = set()
for const_uop, reg_uses in reg_index_const_uses.items():
if reg_uses == const_use_count[const_uop]:
skip_alloc_consts.add(const_uop)
def free_dead_regs(pos: int):
"""Free registers whose owners (and all aliases) are no longer live after position pos"""
nonlocal free_vgprs, free_sgprs
nonlocal free_vgprs, free_vgpr_pairs, free_sgprs
dead_vgprs = []
for reg, owner in vgpr_owner.items():
# Check if owner and all its aliases are dead
@ -514,9 +721,22 @@ class RDNARenderer(Renderer):
owner_last_use = max(owner_last_use, last_use.get(alias_uop, -1))
if owner_last_use < pos:
dead_sgprs.append(reg)
# Process dead VGPRs - handle pairs specially
dead_vgprs_set = set(dead_vgprs)
for reg in dead_vgprs:
del vgpr_owner[reg]
free_vgprs.append(reg)
# If this is part of a pair, check if both regs are dead and return as pair
if reg in vgpr_pairs:
base_reg = reg if reg % 2 == 0 else reg - 1
other_reg = base_reg + 1 if reg == base_reg else base_reg
if other_reg in dead_vgprs_set and base_reg not in free_vgpr_pairs:
# Both regs of pair are dead - return as pair
free_vgpr_pairs.append(base_reg)
vgpr_pairs.discard(base_reg)
vgpr_pairs.discard(other_reg)
# Don't add to free_vgprs - pairs are handled separately
else:
free_vgprs.append(reg)
for reg in dead_sgprs:
# Don't free SGPRs that are part of a pair (used for 64-bit values like buffer addresses)
if reg in sgpr_pairs:
@ -538,19 +758,25 @@ class RDNARenderer(Renderer):
# Hands out two consecutive VGPRs starting at an even index, records `owner` for both in
# vgpr_owner, marks both in vgpr_pairs and returns the range as "v[reg:reg+1]".
# NOTE(review): two allocation sequences appear back to back below - an unconditional bump
# of next_vgpr followed by a reuse-or-bump version. Read top to bottom, a fresh pair would be
# taken and then immediately superseded. Presumably only the second sequence is meant to
# remain (this looks like old and new lines of a diff) - confirm.
def alloc_vgpr_pair(owner: UOp) -> str:
"""Allocate an aligned pair of VGPRs for 64-bit values (float64, int64, uint64)"""
nonlocal next_vgpr, max_vgpr
# sequence 1: always take the next even-aligned pair
# Align to even for 64-bit values
if next_vgpr % 2 != 0:
next_vgpr += 1
reg = next_vgpr
next_vgpr += 2
max_vgpr = max(max_vgpr, next_vgpr)
# sequence 2: prefer a previously freed pair, otherwise take the next even-aligned pair
# Try to reuse a free pair first
if free_vgpr_pairs:
reg = free_vgpr_pairs.pop()
else:
# Align to even for 64-bit values
if next_vgpr % 2 != 0:
next_vgpr += 1
reg = next_vgpr
next_vgpr += 2
max_vgpr = max(max_vgpr, next_vgpr)
# both halves share the same owner and are tagged as pair members
vgpr_owner[reg] = owner
vgpr_owner[reg+1] = owner
vgpr_pairs.add(reg)
vgpr_pairs.add(reg+1)
return f"v[{reg}:{reg+1}]"
def needs_vgpr_pair(dtype: DType) -> bool:
"""Check if a dtype needs a VGPR pair (64-bit types)"""
# Only float64 needs pairs - int64/uint64 are lowered to 32-bit
# Only float64 needs pairs - int64/uint64 use special patterns that don't need persistent pairs
# NOTE(review): the two comments above give different reasons for the same rule; the code
# itself only tests for float64, so every other dtype gets a single register
return dtype == dtypes.float64
def alloc_sgpr(owner: UOp) -> str:
@ -607,10 +833,60 @@ class RDNARenderer(Renderer):
if u.arg is not None: name = u.arg.function_name
continue
if u.op is Ops.VECTORIZE:
r[u] = [cast(str,r[x]) for x in u.src]
# For WMMA inputs (half16), we need contiguous packed VGPRs
# half16 = 16 halfs = 8 VGPRs (2 halfs per 32-bit VGPR)
if u.dtype.scalar() == dtypes.half and u.dtype.count == 16:
# Allocate 8 contiguous VGPRs for packed half16
base = next_vgpr
if base % 2 != 0: # Align to even for better access
base = next_vgpr = next_vgpr + 1
next_vgpr += 8
max_vgpr = max(max_vgpr, next_vgpr)
r[u] = f"v[{base}:{base+7}]"
# Pack the source halfs into the destination VGPRs
# Each VGPR holds 2 halfs: low 16 bits and high 16 bits
for i in range(8):
src_lo = r[u.src[i*2]]
src_hi = r[u.src[i*2+1]]
# Pack two halfs into one VGPR using v_pack_b32_f16
kernel.append(f"v_pack_b32_f16 v{base+i}, {src_lo}, {src_hi}")
# For float8 (WMMA accumulator), check if sources are contiguous
elif u.dtype.scalar() == dtypes.float and u.dtype.count == 8:
# Check if all sources are from contiguous registers
src_regs = []
all_contiguous = True
for src in u.src:
src_str = r[src]
if isinstance(src_str, str) and src_str.startswith('v') and '[' not in src_str:
src_regs.append(int(src_str[1:]))
else:
all_contiguous = False
break
if all_contiguous and len(src_regs) == 8:
# Check if contiguous
if src_regs == list(range(src_regs[0], src_regs[0] + 8)):
# Already contiguous - just create range reference
r[u] = f"v[{src_regs[0]}:{src_regs[0]+7}]"
continue
# Not contiguous - allocate new registers and copy
base = next_vgpr
next_vgpr += 8
max_vgpr = max(max_vgpr, next_vgpr)
r[u] = f"v[{base}:{base+7}]"
for i, src in enumerate(u.src):
kernel.append(f"v_mov_b32 v{base+i}, {r[src]}")
else:
r[u] = [cast(str,r[x]) for x in u.src]
continue
if u.op is Ops.GEP:
r[u] = r[u.src[0]][get_single_element(u.arg)]
src_reg = r[u.src[0]]
idx = get_single_element(u.arg)
if isinstance(src_reg, str) and src_reg.startswith('v['):
# Extract base from v[base:end] range
base = int(src_reg[2:src_reg.index(':')])
r[u] = f"v{base + idx}"
else:
r[u] = src_reg[idx]
continue
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
r[u] = r[u.src[0]]
@ -637,7 +913,29 @@ class RDNARenderer(Renderer):
r[u] = alloc_vgpr_pair(u) if needs_vgpr_pair(u.dtype) else alloc_vgpr(u)
pending_waits.add(u)
elif u.op is Ops.CONST:
# Skip allocation for constants used as compile-time REG indices
if u in skip_alloc_consts:
r[u] = render_val(u.arg, u.dtype) # Store the value as string, no register needed
continue # Skip rendering - it's a compile-time index
# Check if constant can be inlined (no VGPR needed) - but not if used in STORE
if can_inline_const(u.arg, u.dtype) and u not in store_const_uses:
r[u] = render_val(u.arg, u.dtype) # Store value string, not register
continue # Skip rendering - will be inlined
# Deduplicate non-inlineable constants - reuse register if same value already loaded
const_key = (u.dtype, u.arg)
if const_key in const_cache:
r[u] = const_cache[const_key]
# Extend the original owner's lifetime to include this use
# by updating last_use for the original constant
reg_str = const_cache[const_key]
reg_num = int(reg_str[1:]) if reg_str.startswith('v') and '[' not in reg_str else None
if reg_num is not None and reg_num in vgpr_owner:
original_owner = vgpr_owner[reg_num]
# Extend last_use to include all uses of this UOp
last_use[original_owner] = max(last_use.get(original_owner, -1), last_use.get(u, i))
continue # Skip rendering - already loaded
r[u] = alloc_vgpr_pair(u) if needs_vgpr_pair(u.dtype) else alloc_vgpr(u)
const_cache[const_key] = r[u]
elif u.op is Ops.RANGE:
r[u] = alloc_vgpr(u)
elif u.op is Ops.END:
@ -657,6 +955,21 @@ class RDNARenderer(Renderer):
r[u] = alloc_vgpr_pair(u) if needs_vgpr_pair(u.dtype) else alloc_vgpr(u)
elif u.op is Ops.IF:
r[u] = alloc_sgpr(u)
elif u.op is Ops.WMMA:
# WMMA outputs a vector of floats (8 for RDNA3)
# For RDNA3 WMMA, we can do in-place accumulation if dst == C source
# Check if we can reuse the accumulator register range
acc_src = u.src[2] # accumulator input
acc_reg = r.get(acc_src)
if isinstance(acc_reg, str) and acc_reg.startswith('v['):
# Accumulator is already a contiguous range - reuse it for output
r[u] = acc_reg
else:
# Allocate contiguous VGPRs for the output
base = next_vgpr
next_vgpr += 8
max_vgpr = max(max_vgpr, next_vgpr)
r[u] = f"v[{base}:{base+7}]"
elif u.op is Ops.DEFINE_REG:
# For 64-bit types, allocate VGPR pairs
if needs_vgpr_pair(u.ptrdtype.base):