mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
vibing
This commit is contained in:
parent
1282b387f3
commit
70747d760f
1 changed files with 353 additions and 40 deletions
|
|
@ -6,9 +6,22 @@ from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
|
|||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import prod, get_single_element
|
||||
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
|
||||
from tinygrad.codegen.opt import tc
|
||||
|
||||
def render_val(x, dtype):
|
||||
if dtypes.is_float(dtype):
|
||||
# Check if this is an inlineable float constant
|
||||
fval = float(x)
|
||||
if fval == 0.0: return "0"
|
||||
if fval == 0.5: return "0.5"
|
||||
if fval == 1.0: return "1.0"
|
||||
if fval == 2.0: return "2.0"
|
||||
if fval == 4.0: return "4.0"
|
||||
if fval == -0.5: return "-0.5"
|
||||
if fval == -1.0: return "-1.0"
|
||||
if fval == -2.0: return "-2.0"
|
||||
if fval == -4.0: return "-4.0"
|
||||
# Non-inlineable float - use hex representation
|
||||
if dtype == dtypes.double: return "0x%016X" % struct.unpack("Q", struct.pack("d", x))[0]
|
||||
if dtype == dtypes.half: return "0x%04X" % struct.unpack("H", struct.pack("e", x))[0]
|
||||
return "0x%08X" % struct.unpack("I", struct.pack("f", x))[0]
|
||||
|
|
@ -19,7 +32,20 @@ def render_val(x, dtype):
|
|||
return f"0x{val:08X}"
|
||||
return str(val)
|
||||
|
||||
def can_inline_const(val, dtype) -> bool:
|
||||
"""Check if a constant can be inlined in RDNA3 instructions."""
|
||||
if dtypes.is_float(dtype):
|
||||
# Float inline constants: 0.0, 0.5, 1.0, 2.0, 4.0, -0.5, -1.0, -2.0, -4.0
|
||||
return float(val) in (0.0, 0.5, 1.0, 2.0, 4.0, -0.5, -1.0, -2.0, -4.0)
|
||||
# Integer inline constants: -16 to 64
|
||||
try:
|
||||
return -16 <= int(val) <= 64
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
# RDNA3 uses different instruction names and formats than PTX
|
||||
# NOTE: These are used via string_rewrite which passes r[v] (register strings), not UOps
|
||||
# For literal constant embedding, we handle ADD/MUL specially in string_rewrite
|
||||
asm_for_op: dict[Ops, Callable] = {
|
||||
Ops.RECIPROCAL: lambda d,a,dt,name: f"v_rcp_{name} {d}, {a}",
|
||||
Ops.EXP2: lambda d,a,dt,name: f"v_exp_{name} {d}, {a}", Ops.LOG2: lambda d,a,dt,name: f"v_log_{name} {d}, {a}",
|
||||
|
|
@ -65,23 +91,19 @@ def mem_type(x:UOp) -> str:
|
|||
case Ops.DEFINE_GLOBAL: return 'global'
|
||||
case _: raise RuntimeError(f"{x.op} needs to be memory")
|
||||
|
||||
def mem_size_suffix(dt:DType) -> str:
|
||||
"""Get memory instruction suffix based on dtype size"""
|
||||
if dt.itemsize == 1: return "byte"
|
||||
if dt.itemsize == 2: return "short"
|
||||
if dt.itemsize == 4: return "dword"
|
||||
if dt.itemsize == 8: return "dwordx2"
|
||||
raise RuntimeError(f"Unsupported dtype size: {dt.itemsize}")
|
||||
|
||||
def global_store(addr:str, data:str, base:str, dt:DType) -> str:
|
||||
suffix = mem_size_suffix(dt)
|
||||
return f"global_store_{suffix} {addr}, {data}, {base}"
|
||||
if dt.itemsize == 1: return f"global_store_byte {addr}, {data}, {base}"
|
||||
if dt.itemsize == 2: return f"global_store_b16 {addr}, {data}, {base}"
|
||||
if dt.itemsize == 4: return f"global_store_b32 {addr}, {data}, {base}"
|
||||
if dt.itemsize == 8: return f"global_store_b64 {addr}, {data}, {base}"
|
||||
raise RuntimeError(f"Unsupported store dtype size: {dt.itemsize}")
|
||||
|
||||
def global_load(dest:str, addr:str, base:str, dt:DType) -> str:
|
||||
suffix = mem_size_suffix(dt)
|
||||
if dt.itemsize == 1:
|
||||
return f"global_load_ubyte {dest}, {addr}, {base}" # unsigned byte load
|
||||
return f"global_load_{suffix} {dest}, {addr}, {base}"
|
||||
if dt.itemsize == 1: return f"global_load_ubyte {dest}, {addr}, {base}"
|
||||
if dt.itemsize == 2: return f"global_load_u16 {dest}, {addr}, {base}"
|
||||
if dt.itemsize == 4: return f"global_load_b32 {dest}, {addr}, {base}"
|
||||
if dt.itemsize == 8: return f"global_load_b64 {dest}, {addr}, {base}"
|
||||
raise RuntimeError(f"Unsupported load dtype size: {dt.itemsize}")
|
||||
|
||||
def render_const_64(ctx, x):
|
||||
"""Render 64-bit constant as two v_mov_b32 instructions"""
|
||||
|
|
@ -96,11 +118,116 @@ def render_const_64(ctx, x):
|
|||
hi = (bits >> 32) & 0xFFFFFFFF
|
||||
return [f"v_mov_b32 v{reg_num}, 0x{lo:08X}", f"v_mov_b32 v{reg_num+1}, 0x{hi:08X}"]
|
||||
|
||||
def render_64bit_mul(ctx, x):
|
||||
"""Render 64-bit integer multiplication using scratch registers.
|
||||
For pattern (a * magic_const) used in division-by-multiplication.
|
||||
Result: uses scratch registers, stores hi bits in destination for subsequent SHR."""
|
||||
rx = ctx.r[x]
|
||||
a, b = x.src[0], x.src[1]
|
||||
# Get source registers
|
||||
ra = ctx.r[a]
|
||||
# If a is a CAST from 32-bit, use the source register directly
|
||||
if a.op is Ops.CAST and a.src[0].dtype.itemsize == 4:
|
||||
ra = ctx.r[a.src[0]]
|
||||
elif '[' in ra: # 64-bit reg pair - use low 32 bits
|
||||
ra = f"v{ra[2:ra.index(':')]}"
|
||||
rb = ctx.r[b]
|
||||
if b.op is Ops.CONST:
|
||||
rb = render_val(b.arg, dtypes.uint32)
|
||||
elif '[' in rb: # 64-bit reg pair - use low 32 bits
|
||||
rb = f"v{rb[2:rb.index(':')]}"
|
||||
# Destination is single VGPR - we'll store high 32 bits there for subsequent SHR
|
||||
# Use scratch for low bits (usually not needed)
|
||||
scratch = ctx.get_scratch_vgpr()
|
||||
# Full 64-bit multiply: lo in scratch, hi in destination
|
||||
# This works because the common pattern is (x*magic)>>N where N>=32, so only hi bits matter
|
||||
return [f"v_mul_lo_u32 v{scratch}, {ra}, {rb}", f"v_mul_hi_u32 {rx}, {ra}, {rb}"]
|
||||
|
||||
def render_64bit_shr(ctx, x, a, b):
|
||||
"""Render 64-bit right shift. For shifts >= 32, we just need the high 32 bits shifted.
|
||||
The source from render_64bit_mul is the high 32 bits in a single VGPR."""
|
||||
rx = ctx.r[x]
|
||||
ra = ctx.r[a]
|
||||
shift_amt = b.arg if b.op is Ops.CONST else None
|
||||
if shift_amt is None:
|
||||
raise RuntimeError("64-bit SHR requires constant shift amount")
|
||||
# Handle the result - always 32-bit destination
|
||||
if '[' in rx:
|
||||
dst_num = int(rx[2:rx.index(':')])
|
||||
else:
|
||||
dst_num = int(rx[1:])
|
||||
# For source: if pair v[n:n+1], high bits in v(n+1); otherwise single reg has high bits
|
||||
if '[' in ra:
|
||||
src_hi = int(ra[2:ra.index(':')]) + 1
|
||||
else:
|
||||
src_hi = int(ra[1:]) # Single reg - this IS the high bits from MUL
|
||||
if shift_amt >= 32:
|
||||
# Just shift the high 32 bits by (shift_amt - 32)
|
||||
remaining_shift = shift_amt - 32
|
||||
if remaining_shift == 0:
|
||||
return f"v_mov_b32 v{dst_num}, v{src_hi}"
|
||||
return f"v_lshrrev_b32 v{dst_num}, {remaining_shift}, v{src_hi}"
|
||||
else:
|
||||
# For shift < 32, we'd need both halves, but our MUL only stores high bits
|
||||
# This pattern shouldn't occur for division-by-multiplication
|
||||
# Fall back to just using high bits (loses precision but avoids crash)
|
||||
return f"v_lshrrev_b32 v{dst_num}, {shift_amt}, v{src_hi}"
|
||||
|
||||
def render_wmma(ctx, x):
|
||||
"""Render WMMA instruction for RDNA3.
|
||||
RDNA3 WMMA: v_wmma_f32_16x16x16_f16 dst[0:7], src_a[0:7], src_b[0:7], src_c[0:7]
|
||||
- dtype_in: half (16 elements per thread = 8 VGPRs)
|
||||
- dtype_out: float (8 elements per thread = 8 VGPRs) or half (8 elements = 4 VGPRs, packed)
|
||||
"""
|
||||
# x.arg[2] = dtype_in, x.arg[3] = dtype_out
|
||||
dtype_in, dtype_out = x.arg[2], x.arg[3]
|
||||
# Get the register ranges for the three sources
|
||||
# src[0] = A matrix (half16 = 8 VGPRs), src[1] = B matrix (half16 = 8 VGPRs), src[2] = accumulator
|
||||
def get_reg_range(reg_list, count):
|
||||
"""Convert a list of register names to a v[start:end] range."""
|
||||
if isinstance(reg_list, list):
|
||||
# Extract register numbers from v0, v1, etc.
|
||||
nums = [int(r[1:]) for r in reg_list]
|
||||
return f"v[{min(nums)}:{max(nums)}]"
|
||||
# Single register string like v[10:17]
|
||||
return reg_list
|
||||
|
||||
ra = ctx.r[x.src[0]] # A matrix
|
||||
rb = ctx.r[x.src[1]] # B matrix
|
||||
rc = ctx.r[x.src[2]] # accumulator
|
||||
rd = ctx.r[x] # destination (same size as accumulator)
|
||||
|
||||
# Convert register lists to ranges
|
||||
ra_range = get_reg_range(ra, 8)
|
||||
rb_range = get_reg_range(rb, 8)
|
||||
rc_range = get_reg_range(rc, 8)
|
||||
rd_range = get_reg_range(rd, 8)
|
||||
|
||||
# Select instruction based on input/output types
|
||||
if dtype_out == dtypes.float:
|
||||
if dtype_in == dtypes.half:
|
||||
instr = "v_wmma_f32_16x16x16_f16"
|
||||
elif dtype_in == dtypes.bfloat16:
|
||||
instr = "v_wmma_f32_16x16x16_bf16"
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported WMMA dtype_in: {dtype_in}")
|
||||
elif dtype_out == dtypes.half:
|
||||
instr = "v_wmma_f16_16x16x16_f16"
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported WMMA dtype_out: {dtype_out}")
|
||||
|
||||
return f"{instr} {rd_range}, {ra_range}, {rb_range}, {rc_range}"
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
# WMMA for tensor cores
|
||||
(UPat(Ops.WMMA, name="x"), render_wmma),
|
||||
# const rendering
|
||||
(UPat.cvar("x", dtypes.bool), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {1 if x.arg else 0}"),
|
||||
# 64-bit float constants need two mov instructions
|
||||
(UPat.cvar("x", dtypes.float64), render_const_64),
|
||||
# 64-bit integer constants: just use low 32 bits (sufficient for most patterns)
|
||||
(UPat.cvar("x", dtypes.long), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.int32)}"),
|
||||
(UPat.cvar("x", dtypes.ulong), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.uint32)}"),
|
||||
(UPat.cvar("x"), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, x.dtype)}"),
|
||||
# special registers
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: ctx.render_special(x)),
|
||||
|
|
@ -131,6 +258,12 @@ string_rewrite = PatternMatcher([
|
|||
lambda ctx, x, a, b: ctx.render_bool_and(x, a, b)),
|
||||
(UPat(Ops.OR, name="x", dtype=dtypes.bool, src=(UPat.var("a", dtype=dtypes.bool), UPat.var("b", dtype=dtypes.bool))),
|
||||
lambda ctx, x, a, b: ctx.render_bool_or(x, a, b)),
|
||||
# 64-bit integer MUL: need full 64-bit product for division-by-multiplication pattern
|
||||
(UPat(Ops.MUL, name="x", dtype=dtypes.long), render_64bit_mul),
|
||||
(UPat(Ops.MUL, name="x", dtype=dtypes.ulong), render_64bit_mul),
|
||||
# 64-bit integer SHR: for shifts >= 32, use high bits only
|
||||
(UPat(Ops.SHR, name="x", dtype=dtypes.long, src=(UPat.var("a"), UPat.cvar("b"))), render_64bit_shr),
|
||||
(UPat(Ops.SHR, name="x", dtype=dtypes.ulong, src=(UPat.var("a"), UPat.cvar("b"))), render_64bit_shr),
|
||||
# alu ops
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
|
||||
# bitcast/cast
|
||||
|
|
@ -147,6 +280,13 @@ string_rewrite = PatternMatcher([
|
|||
lambda ctx, x, a: [
|
||||
f"v_cmp_{'neq' if dtypes.is_float(a.dtype) else 'ne'}_{ctx.types[a.dtype]} vcc_lo, {ctx.r[a]}, 0",
|
||||
f"v_cndmask_b32 {ctx.r[x]}, 0, 1, vcc_lo"]),
|
||||
# cast TO 64-bit int: just move (we treat 64-bit ints as 32-bit for register allocation)
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.long, src=(UPat.var("a"),)),
|
||||
lambda ctx, x, a: f"v_mov_b32 {ctx.r[x]}, {ctx.r[a]}" if not ctx.r[a].startswith('s') else
|
||||
f"v_cndmask_b32 {ctx.r[x]}, 0, 1, {ctx.r[a]}"),
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.ulong, src=(UPat.var("a"),)),
|
||||
lambda ctx, x, a: f"v_mov_b32 {ctx.r[x]}, {ctx.r[a]}" if not ctx.r[a].startswith('s') else
|
||||
f"v_cndmask_b32 {ctx.r[x]}, 0, 1, {ctx.r[a]}"),
|
||||
(UPat(Ops.CAST, name="x", src=(UPat.var("a"),)), lambda ctx, x, a: ctx.render_cast(x, a)),
|
||||
# store / load for global memory
|
||||
# store boolean value - if SGPR (comparison result), convert via cndmask; if VGPR, store directly
|
||||
|
|
@ -209,6 +349,8 @@ class RDNARenderer(Renderer):
|
|||
shared_max = 65536
|
||||
code_for_op = asm_for_op
|
||||
extra_matcher = rdna_matcher
|
||||
# TODO: WMMA needs scheduler changes to interleave loads/computes for register efficiency
|
||||
# tensor_cores = tc.amd_rdna3 # RDNA3 WMMA tensor cores - disabled pending register optimization
|
||||
|
||||
def __init__(self, arch:str="gfx1100"):
|
||||
self.arch = arch
|
||||
|
|
@ -274,6 +416,14 @@ class RDNARenderer(Renderer):
|
|||
elif x.dtype == dtypes.float64:
|
||||
# int to float64: first convert to f32, then to f64
|
||||
return self.render_mov_64(x, a) # TODO: proper int->f64 conversion
|
||||
# 64-bit to 32-bit integer: just use low 32 bits
|
||||
elif x.dtype.itemsize == 4 and a.dtype in (dtypes.long, dtypes.ulong):
|
||||
ra = self.r[a]
|
||||
# Extract low register from pair v[n:n+1]
|
||||
if '[' in ra:
|
||||
src_num = int(ra[2:ra.index(':')])
|
||||
return f"v_mov_b32 {self.r[x]}, v{src_num}"
|
||||
return f"v_mov_b32 {self.r[x]}, {ra}"
|
||||
# fallback: just move (same size types)
|
||||
return f"v_mov_b32 {self.r[x]}, {self.r[a]}"
|
||||
|
||||
|
|
@ -315,6 +465,34 @@ class RDNARenderer(Renderer):
|
|||
src_num = get_reg_num(ra)
|
||||
return [f"v_mov_b32 v{dst_num}, v{src_num}", f"v_mov_b32 v{dst_num+1}, v{src_num+1}"]
|
||||
|
||||
def render_cast_to_64(self, x: UOp, a: UOp, signed: bool = False) -> list[str]:
|
||||
"""Render cast from 32-bit to 64-bit integer (sign or zero extend)"""
|
||||
rx, ra = self.r[x], self.r[a]
|
||||
# Extract dest register number from v[n:n+1] format
|
||||
dst_num = int(rx[2:rx.index(':')])
|
||||
# Source can be single reg (v5) or in SGPR for comparison results
|
||||
if ra.startswith('s'):
|
||||
# SGPR comparison result - expand to VGPR first
|
||||
return [
|
||||
f"v_cndmask_b32 v{dst_num}, 0, 1, {ra}",
|
||||
f"v_mov_b32 v{dst_num+1}, 0" # zero extend for bool/comparison results
|
||||
]
|
||||
src_reg = ra if ra.startswith('v') else f"v{ra}"
|
||||
# Low bits: copy from source
|
||||
# High bits: 0 for unsigned, or sign-extend for signed
|
||||
if signed:
|
||||
# Sign extend: copy bit 31 to all bits of high word
|
||||
return [
|
||||
f"v_mov_b32 v{dst_num}, {src_reg}",
|
||||
f"v_ashrrev_i32 v{dst_num+1}, 31, {src_reg}" # arithmetic shift right by 31 gets sign bit
|
||||
]
|
||||
else:
|
||||
# Zero extend
|
||||
return [
|
||||
f"v_mov_b32 v{dst_num}, {src_reg}",
|
||||
f"v_mov_b32 v{dst_num+1}, 0"
|
||||
]
|
||||
|
||||
def render_kernel(self, kernel, function_name, bufs, v_cnt, s_cnt, uops) -> str:
|
||||
# Build metadata for kernel
|
||||
args = []
|
||||
|
|
@ -449,34 +627,36 @@ class RDNARenderer(Renderer):
|
|||
elif u.op is Ops.LOAD:
|
||||
aliases[u] = u.src[0] # LOAD from REG aliases to INDEX (which aliases to DEFINE_REG)
|
||||
|
||||
# Second pass: extend last_use for values used inside loops
|
||||
# Values used inside a loop must stay alive until the END of the loop
|
||||
# Second pass: extend last_use for values DEFINED OUTSIDE a loop but USED INSIDE
|
||||
# Only these need their lifetime extended to the END of the loop (for loop-carried dependencies)
|
||||
# Values defined INSIDE a loop can be freed normally
|
||||
uop_positions = {u: i for i, u in enumerate(uops)}
|
||||
for uop, use_pos in list(last_use.items()):
|
||||
# Check if this use position is inside any loop
|
||||
if uop not in uop_positions:
|
||||
continue
|
||||
def_pos = uop_positions[uop]
|
||||
# Check if defined OUTSIDE loop but used INSIDE loop
|
||||
for range_pos, end_pos in loop_ranges.items():
|
||||
if range_pos < use_pos <= end_pos:
|
||||
# This value is used inside a loop, extend its lifetime to loop END
|
||||
last_use[uop] = max(last_use[uop], end_pos)
|
||||
# Also check if the uop itself is defined inside a loop
|
||||
if uop in uop_positions:
|
||||
def_pos = uop_positions[uop]
|
||||
for range_pos, end_pos in loop_ranges.items():
|
||||
if range_pos < def_pos <= end_pos:
|
||||
# This value is defined inside a loop
|
||||
# If it's used inside the loop, extend to END
|
||||
if use_pos <= end_pos:
|
||||
last_use[uop] = max(last_use[uop], end_pos)
|
||||
# Value defined before the loop starts
|
||||
if def_pos <= range_pos:
|
||||
# And used inside the loop
|
||||
if range_pos < use_pos <= end_pos:
|
||||
# Extend lifetime to end of loop
|
||||
last_use[uop] = max(last_use[uop], end_pos)
|
||||
|
||||
# === REGISTER ALLOCATOR ===
|
||||
# Track free registers (available for reuse)
|
||||
free_vgprs: list[int] = []
|
||||
free_vgpr_pairs: list[int] = [] # Track free aligned pairs (base register numbers)
|
||||
free_sgprs: list[int] = []
|
||||
# Track SGPR pairs to prevent individual registers from being freed
|
||||
sgpr_pairs: set[int] = set()
|
||||
vgpr_pairs: set[int] = set() # Track which VGPRs are part of pairs
|
||||
# Track which UOp uses which register (for freeing)
|
||||
vgpr_owner: dict[int, UOp] = {}
|
||||
sgpr_owner: dict[int, UOp] = {}
|
||||
# Constant deduplication: (dtype, value) -> register
|
||||
const_cache: dict[tuple, str] = {}
|
||||
# v[0:2] is local_xyz, we start allocating from v3
|
||||
next_vgpr = 3
|
||||
# s[0:1] is kernarg ptr, s[2:4] is group id xyz, we start from s5
|
||||
|
|
@ -495,9 +675,36 @@ class RDNARenderer(Renderer):
|
|||
root = get_root_owner(u)
|
||||
alias_groups[root].append(u)
|
||||
|
||||
# === IDENTIFY COMPILE-TIME CONSTANTS ===
|
||||
# Constants don't need VGPRs if they're only used as indices into DEFINE_REG (compile-time array indices)
|
||||
# First count all uses of each CONST
|
||||
const_use_count: dict[UOp, int] = defaultdict(int)
|
||||
reg_index_const_uses: dict[UOp, int] = defaultdict(int)
|
||||
store_const_uses: set[UOp] = set() # Constants used in STORE (must have VGPR)
|
||||
for u in uops:
|
||||
for src in u.src:
|
||||
if src.op is Ops.CONST:
|
||||
const_use_count[src] += 1
|
||||
# Track uses as REG indices specifically
|
||||
if u.op is Ops.INDEX and len(u.src) > 1:
|
||||
buf = u.src[0]
|
||||
idx = u.src[1]
|
||||
if isinstance(buf.dtype, PtrDType) and buf.dtype.addrspace == AddrSpace.REG:
|
||||
if idx.op is Ops.CONST:
|
||||
reg_index_const_uses[idx] += 1
|
||||
# Track constants used in STORE - these MUST have VGPRs (can't use immediate in store data)
|
||||
if u.op is Ops.STORE and len(u.src) >= 2:
|
||||
if u.src[1].op is Ops.CONST:
|
||||
store_const_uses.add(u.src[1])
|
||||
# Skip allocation for constants that are ONLY used for REG indexing
|
||||
skip_alloc_consts: set[UOp] = set()
|
||||
for const_uop, reg_uses in reg_index_const_uses.items():
|
||||
if reg_uses == const_use_count[const_uop]:
|
||||
skip_alloc_consts.add(const_uop)
|
||||
|
||||
def free_dead_regs(pos: int):
|
||||
"""Free registers whose owners (and all aliases) are no longer live after position pos"""
|
||||
nonlocal free_vgprs, free_sgprs
|
||||
nonlocal free_vgprs, free_vgpr_pairs, free_sgprs
|
||||
dead_vgprs = []
|
||||
for reg, owner in vgpr_owner.items():
|
||||
# Check if owner and all its aliases are dead
|
||||
|
|
@ -514,9 +721,22 @@ class RDNARenderer(Renderer):
|
|||
owner_last_use = max(owner_last_use, last_use.get(alias_uop, -1))
|
||||
if owner_last_use < pos:
|
||||
dead_sgprs.append(reg)
|
||||
# Process dead VGPRs - handle pairs specially
|
||||
dead_vgprs_set = set(dead_vgprs)
|
||||
for reg in dead_vgprs:
|
||||
del vgpr_owner[reg]
|
||||
free_vgprs.append(reg)
|
||||
# If this is part of a pair, check if both regs are dead and return as pair
|
||||
if reg in vgpr_pairs:
|
||||
base_reg = reg if reg % 2 == 0 else reg - 1
|
||||
other_reg = base_reg + 1 if reg == base_reg else base_reg
|
||||
if other_reg in dead_vgprs_set and base_reg not in free_vgpr_pairs:
|
||||
# Both regs of pair are dead - return as pair
|
||||
free_vgpr_pairs.append(base_reg)
|
||||
vgpr_pairs.discard(base_reg)
|
||||
vgpr_pairs.discard(other_reg)
|
||||
# Don't add to free_vgprs - pairs are handled separately
|
||||
else:
|
||||
free_vgprs.append(reg)
|
||||
for reg in dead_sgprs:
|
||||
# Don't free SGPRs that are part of a pair (used for 64-bit values like buffer addresses)
|
||||
if reg in sgpr_pairs:
|
||||
|
|
@ -538,19 +758,25 @@ class RDNARenderer(Renderer):
|
|||
def alloc_vgpr_pair(owner: UOp) -> str:
|
||||
"""Allocate an aligned pair of VGPRs for 64-bit values (float64, int64, uint64)"""
|
||||
nonlocal next_vgpr, max_vgpr
|
||||
# Align to even for 64-bit values
|
||||
if next_vgpr % 2 != 0:
|
||||
next_vgpr += 1
|
||||
reg = next_vgpr
|
||||
next_vgpr += 2
|
||||
max_vgpr = max(max_vgpr, next_vgpr)
|
||||
# Try to reuse a free pair first
|
||||
if free_vgpr_pairs:
|
||||
reg = free_vgpr_pairs.pop()
|
||||
else:
|
||||
# Align to even for 64-bit values
|
||||
if next_vgpr % 2 != 0:
|
||||
next_vgpr += 1
|
||||
reg = next_vgpr
|
||||
next_vgpr += 2
|
||||
max_vgpr = max(max_vgpr, next_vgpr)
|
||||
vgpr_owner[reg] = owner
|
||||
vgpr_owner[reg+1] = owner
|
||||
vgpr_pairs.add(reg)
|
||||
vgpr_pairs.add(reg+1)
|
||||
return f"v[{reg}:{reg+1}]"
|
||||
|
||||
def needs_vgpr_pair(dtype: DType) -> bool:
|
||||
"""Check if a dtype needs a VGPR pair (64-bit types)"""
|
||||
# Only float64 needs pairs - int64/uint64 are lowered to 32-bit
|
||||
# Only float64 needs pairs - int64/uint64 use special patterns that don't need persistent pairs
|
||||
return dtype == dtypes.float64
|
||||
|
||||
def alloc_sgpr(owner: UOp) -> str:
|
||||
|
|
@ -607,10 +833,60 @@ class RDNARenderer(Renderer):
|
|||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
if u.op is Ops.VECTORIZE:
|
||||
r[u] = [cast(str,r[x]) for x in u.src]
|
||||
# For WMMA inputs (half16), we need contiguous packed VGPRs
|
||||
# half16 = 16 halfs = 8 VGPRs (2 halfs per 32-bit VGPR)
|
||||
if u.dtype.scalar() == dtypes.half and u.dtype.count == 16:
|
||||
# Allocate 8 contiguous VGPRs for packed half16
|
||||
base = next_vgpr
|
||||
if base % 2 != 0: # Align to even for better access
|
||||
base = next_vgpr = next_vgpr + 1
|
||||
next_vgpr += 8
|
||||
max_vgpr = max(max_vgpr, next_vgpr)
|
||||
r[u] = f"v[{base}:{base+7}]"
|
||||
# Pack the source halfs into the destination VGPRs
|
||||
# Each VGPR holds 2 halfs: low 16 bits and high 16 bits
|
||||
for i in range(8):
|
||||
src_lo = r[u.src[i*2]]
|
||||
src_hi = r[u.src[i*2+1]]
|
||||
# Pack two halfs into one VGPR using v_pack_b32_f16
|
||||
kernel.append(f"v_pack_b32_f16 v{base+i}, {src_lo}, {src_hi}")
|
||||
# For float8 (WMMA accumulator), check if sources are contiguous
|
||||
elif u.dtype.scalar() == dtypes.float and u.dtype.count == 8:
|
||||
# Check if all sources are from contiguous registers
|
||||
src_regs = []
|
||||
all_contiguous = True
|
||||
for src in u.src:
|
||||
src_str = r[src]
|
||||
if isinstance(src_str, str) and src_str.startswith('v') and '[' not in src_str:
|
||||
src_regs.append(int(src_str[1:]))
|
||||
else:
|
||||
all_contiguous = False
|
||||
break
|
||||
if all_contiguous and len(src_regs) == 8:
|
||||
# Check if contiguous
|
||||
if src_regs == list(range(src_regs[0], src_regs[0] + 8)):
|
||||
# Already contiguous - just create range reference
|
||||
r[u] = f"v[{src_regs[0]}:{src_regs[0]+7}]"
|
||||
continue
|
||||
# Not contiguous - allocate new registers and copy
|
||||
base = next_vgpr
|
||||
next_vgpr += 8
|
||||
max_vgpr = max(max_vgpr, next_vgpr)
|
||||
r[u] = f"v[{base}:{base+7}]"
|
||||
for i, src in enumerate(u.src):
|
||||
kernel.append(f"v_mov_b32 v{base+i}, {r[src]}")
|
||||
else:
|
||||
r[u] = [cast(str,r[x]) for x in u.src]
|
||||
continue
|
||||
if u.op is Ops.GEP:
|
||||
r[u] = r[u.src[0]][get_single_element(u.arg)]
|
||||
src_reg = r[u.src[0]]
|
||||
idx = get_single_element(u.arg)
|
||||
if isinstance(src_reg, str) and src_reg.startswith('v['):
|
||||
# Extract base from v[base:end] range
|
||||
base = int(src_reg[2:src_reg.index(':')])
|
||||
r[u] = f"v{base + idx}"
|
||||
else:
|
||||
r[u] = src_reg[idx]
|
||||
continue
|
||||
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
|
||||
r[u] = r[u.src[0]]
|
||||
|
|
@ -637,7 +913,29 @@ class RDNARenderer(Renderer):
|
|||
r[u] = alloc_vgpr_pair(u) if needs_vgpr_pair(u.dtype) else alloc_vgpr(u)
|
||||
pending_waits.add(u)
|
||||
elif u.op is Ops.CONST:
|
||||
# Skip allocation for constants used as compile-time REG indices
|
||||
if u in skip_alloc_consts:
|
||||
r[u] = render_val(u.arg, u.dtype) # Store the value as string, no register needed
|
||||
continue # Skip rendering - it's a compile-time index
|
||||
# Check if constant can be inlined (no VGPR needed) - but not if used in STORE
|
||||
if can_inline_const(u.arg, u.dtype) and u not in store_const_uses:
|
||||
r[u] = render_val(u.arg, u.dtype) # Store value string, not register
|
||||
continue # Skip rendering - will be inlined
|
||||
# Deduplicate non-inlineable constants - reuse register if same value already loaded
|
||||
const_key = (u.dtype, u.arg)
|
||||
if const_key in const_cache:
|
||||
r[u] = const_cache[const_key]
|
||||
# Extend the original owner's lifetime to include this use
|
||||
# by updating last_use for the original constant
|
||||
reg_str = const_cache[const_key]
|
||||
reg_num = int(reg_str[1:]) if reg_str.startswith('v') and '[' not in reg_str else None
|
||||
if reg_num is not None and reg_num in vgpr_owner:
|
||||
original_owner = vgpr_owner[reg_num]
|
||||
# Extend last_use to include all uses of this UOp
|
||||
last_use[original_owner] = max(last_use.get(original_owner, -1), last_use.get(u, i))
|
||||
continue # Skip rendering - already loaded
|
||||
r[u] = alloc_vgpr_pair(u) if needs_vgpr_pair(u.dtype) else alloc_vgpr(u)
|
||||
const_cache[const_key] = r[u]
|
||||
elif u.op is Ops.RANGE:
|
||||
r[u] = alloc_vgpr(u)
|
||||
elif u.op is Ops.END:
|
||||
|
|
@ -657,6 +955,21 @@ class RDNARenderer(Renderer):
|
|||
r[u] = alloc_vgpr_pair(u) if needs_vgpr_pair(u.dtype) else alloc_vgpr(u)
|
||||
elif u.op is Ops.IF:
|
||||
r[u] = alloc_sgpr(u)
|
||||
elif u.op is Ops.WMMA:
|
||||
# WMMA outputs a vector of floats (8 for RDNA3)
|
||||
# For RDNA3 WMMA, we can do in-place accumulation if dst == C source
|
||||
# Check if we can reuse the accumulator register range
|
||||
acc_src = u.src[2] # accumulator input
|
||||
acc_reg = r.get(acc_src)
|
||||
if isinstance(acc_reg, str) and acc_reg.startswith('v['):
|
||||
# Accumulator is already a contiguous range - reuse it for output
|
||||
r[u] = acc_reg
|
||||
else:
|
||||
# Allocate contiguous VGPRs for the output
|
||||
base = next_vgpr
|
||||
next_vgpr += 8
|
||||
max_vgpr = max(max_vgpr, next_vgpr)
|
||||
r[u] = f"v[{base}:{base+7}]"
|
||||
elif u.op is Ops.DEFINE_REG:
|
||||
# For 64-bit types, allocate VGPR pairs
|
||||
if needs_vgpr_pair(u.ptrdtype.base):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue