mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
more
This commit is contained in:
parent
09c4f61aed
commit
81e9ea2bec
2 changed files with 118 additions and 26 deletions
|
|
@ -12,8 +12,8 @@ from extra.assembly.rdna3.autogen import (
|
|||
v, s, VGPR, SGPR, VCC_LO, EXEC_LO, NULL,
|
||||
# VOP1
|
||||
v_mov_b32_e32, v_cvt_f32_i32_e32, v_cvt_i32_f32_e32, v_cvt_f32_u32_e32, v_cvt_u32_f32_e32,
|
||||
v_cvt_f16_f32_e32, v_cvt_f32_f16_e32, v_rcp_f32_e32, v_sqrt_f32_e32,
|
||||
v_exp_f32_e32, v_log_f32_e32, v_trunc_f32_e32, v_sin_f32_e32,
|
||||
v_cvt_f16_f32_e32, v_cvt_f32_f16_e32, v_rcp_f32_e32, v_rcp_f64_e32, v_sqrt_f32_e32,
|
||||
v_exp_f32_e32, v_log_f32_e32, v_trunc_f32_e32, v_sin_f32_e32, v_fract_f32_e32,
|
||||
v_cvt_f64_f32_e32, v_cvt_f32_f64_e32, v_cvt_f64_i32_e32, v_cvt_f64_u32_e32,
|
||||
v_cvt_i32_f64_e32, v_cvt_u32_f64_e32, v_trunc_f64_e32, v_floor_f64_e32,
|
||||
# VOP2
|
||||
|
|
@ -25,6 +25,7 @@ from extra.assembly.rdna3.autogen import (
|
|||
v_mul_lo_u32, v_mul_hi_u32, v_bfe_u32, v_bfe_i32,
|
||||
v_add_co_u32, v_add_co_ci_u32_e32, v_cndmask_b32_e64, v_add_f64, v_mul_f64, v_sub_co_u32, v_sub_co_ci_u32_e32,
|
||||
v_cmp_lt_f32_e32, v_cmp_eq_f32_e32, v_cmp_neq_f32_e32, v_cmp_gt_f32_e32,
|
||||
v_cmp_lt_f64_e32, v_cmp_eq_f64_e32, v_cmp_neq_f64_e32, v_cmp_gt_f64_e32,
|
||||
v_cmp_lt_i32_e32, v_cmp_eq_i32_e32, v_cmp_ne_i32_e32, v_cmp_gt_i32_e32,
|
||||
v_cmp_lt_u32_e32, v_cmp_eq_u32_e32, v_cmp_ne_u32_e32, v_cmp_gt_u32_e32,
|
||||
# SOPP/SOP
|
||||
|
|
@ -36,6 +37,8 @@ from extra.assembly.rdna3.autogen import (
|
|||
global_store_b32, global_store_b64, global_store_b128, global_store_b16, global_store_b8,
|
||||
# DS (local memory)
|
||||
ds_load_b32, ds_load_b64, ds_load_b128, ds_store_b32, ds_store_b64, ds_store_b128,
|
||||
# WMMA (wave matrix multiply-accumulate)
|
||||
v_wmma_f32_16x16x16_f16, v_wmma_f32_16x16x16_bf16, v_wmma_f16_16x16x16_f16, v_wmma_bf16_16x16x16_bf16,
|
||||
)
|
||||
|
||||
# Helper for VOP2: src0 can be constant/literal, vsrc1 must be VGPR - swap for commutative ops
|
||||
|
|
@ -43,6 +46,19 @@ def _sw(ctx, a, b):
|
|||
ar, br = ctx.get_reg(a), ctx.get_reg(b)
|
||||
return (br, ar) if isinstance(br, (int, float)) and not isinstance(ar, (int, float)) else (ar, br)
|
||||
|
||||
# Helper for 64-bit bitwise operations: apply op to both low and high 32-bit parts
|
||||
def _bitwise64(ctx, a, b, op):
  """Emit a 64-bit bitwise operation as two 32-bit instructions (low half, then high half).

  Args:
    ctx: render context; provides `get_reg` to resolve operands and `dst`, the destination
      register. The high half of the result is written to `v[ctx.dst.idx+1]`.
    a, b: operands; each resolves through `ctx.get_reg` to either a register or a Python int immediate.
    op: 32-bit VOP2 instruction builder called as `op(dst, src0, vsrc1)` (v_and_b32_e32, v_or_b32_e32
      or v_xor_b32_e32 at the call sites). It must be commutative, because operands may be swapped.

  Returns:
    A two-element list: [instruction for the low 32 bits, instruction for the high 32 bits].
  """
  ar, br = ctx.get_reg(a), ctx.get_reg(b)
  # VOP2 only takes a constant/literal in src0; vsrc1 must be a VGPR. The ops routed here are
  # commutative, so move a lone immediate into the first operand (same rule the 32-bit `_sw` helper applies).
  if isinstance(br, int) and not isinstance(ar, int): ar, br = br, ar

  def _halves(r):
    # Immediate: split into low/high 32-bit words (masking also handles negative Python ints).
    if isinstance(r, int): return r & 0xFFFFFFFF, (r >> 32) & 0xFFFFFFFF
    # Register: the high word lives in the next consecutive VGPR.
    return r, v[r.idx+1]

  (a_lo, a_hi), (b_lo, b_hi) = _halves(ar), _halves(br)
  # NOTE(review): if both operands are immediates the second one still lands in vsrc1; presumably
  # such expressions are constant-folded before rendering - confirm.
  return [op(ctx.dst, a_lo, b_lo), op(v[ctx.dst.idx+1], a_hi, b_hi)]
|
||||
|
||||
# Module-level PatternMatcher for simple ALU and CAST operations
|
||||
render_ops = PatternMatcher([
|
||||
# CAST: float32 <-> int32/uint32
|
||||
|
|
@ -229,10 +245,20 @@ render_ops = PatternMatcher([
|
|||
(UPat(Ops.MUL, dtype=dtypes.floats, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_mul_f32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
(UPat(Ops.MUL, dtype=(dtypes.int32, dtypes.uint32, dtypes.int16, dtypes.uint16, dtypes.int8, dtypes.uint8),
|
||||
src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_mul_lo_u32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
# Bitwise: int only
|
||||
(UPat(Ops.AND, dtype=dtypes.ints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_and_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
(UPat(Ops.OR, dtype=dtypes.ints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_or_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
(UPat(Ops.XOR, dtype=dtypes.ints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_xor_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
# Bitwise: 64-bit (need to operate on both low and high 32-bit parts)
|
||||
(UPat(Ops.AND, dtype=(dtypes.int64, dtypes.uint64), src=(UPat.var("a"), UPat.var("b")), name="x"),
|
||||
lambda ctx,x,a,b: _bitwise64(ctx, a, b, v_and_b32_e32)),
|
||||
(UPat(Ops.OR, dtype=(dtypes.int64, dtypes.uint64), src=(UPat.var("a"), UPat.var("b")), name="x"),
|
||||
lambda ctx,x,a,b: _bitwise64(ctx, a, b, v_or_b32_e32)),
|
||||
(UPat(Ops.XOR, dtype=(dtypes.int64, dtypes.uint64), src=(UPat.var("a"), UPat.var("b")), name="x"),
|
||||
lambda ctx,x,a,b: _bitwise64(ctx, a, b, v_xor_b32_e32)),
|
||||
# Bitwise: 32-bit and smaller ints, and bool
|
||||
(UPat(Ops.AND, dtype=(dtypes.int32, dtypes.uint32, dtypes.int16, dtypes.uint16, dtypes.int8, dtypes.uint8, dtypes.bool),
|
||||
src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_and_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
(UPat(Ops.OR, dtype=(dtypes.int32, dtypes.uint32, dtypes.int16, dtypes.uint16, dtypes.int8, dtypes.uint8, dtypes.bool),
|
||||
src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_or_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
(UPat(Ops.XOR, dtype=(dtypes.int32, dtypes.uint32, dtypes.int16, dtypes.uint16, dtypes.int8, dtypes.uint8, dtypes.bool),
|
||||
src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_xor_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
# SHL: int64, default to i32
|
||||
(UPat(Ops.SHL, dtype=(dtypes.int64, dtypes.uint64), src=(UPat.var("a"), UPat.var("b")), name="x"),
|
||||
lambda ctx,x,a,b: [v_lshlrev_b64(ctx.dst, ctx.get_reg(b), ctx.get_reg(a))]),
|
||||
|
|
@ -242,23 +268,24 @@ render_ops = PatternMatcher([
|
|||
lambda ctx,x,a,b: [v_ashrrev_i64(ctx.dst, ctx.get_reg(b), ctx.get_reg(a))]),
|
||||
(UPat(Ops.SHR, dtype=dtypes.uint64, src=(UPat.var("a"), UPat.var("b")), name="x"),
|
||||
lambda ctx,x,a,b: [v_lshrrev_b64(ctx.dst, ctx.get_reg(b), ctx.get_reg(a))]),
|
||||
# MAX: floats, signed ints, unsigned ints
|
||||
# MAX: floats, signed ints, unsigned ints, bool (bool uses OR since max(True, False) = True)
|
||||
(UPat(Ops.MAX, dtype=dtypes.floats, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_max_f32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
(UPat(Ops.MAX, dtype=dtypes.sints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_max_i32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
(UPat(Ops.MAX, dtype=dtypes.uints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_max_u32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
(UPat(Ops.MAX, dtype=dtypes.bool, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_or_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
|
||||
# MULACC (FMA): float64, floats
|
||||
(UPat(Ops.MULACC, dtype=dtypes.float64, src=(UPat.var("a"), UPat.var("b"), UPat.var("d")), name="x"),
|
||||
lambda ctx,x,a,b,d: [v_fma_f64(ctx.dst, ctx.get_reg(a), ctx.get_reg(b), ctx.get_reg(d))]),
|
||||
(UPat(Ops.MULACC, dtype=dtypes.floats, src=(UPat.var("a"), UPat.var("b"), UPat.var("d")), name="x"),
|
||||
lambda ctx,x,a,b,d: [v_fma_f32(ctx.dst, ctx.get_reg(a), ctx.get_reg(b), ctx.get_reg(d))]),
|
||||
# Transcendental: float only
|
||||
# Transcendental: float64 first (for precision), then float32
|
||||
(UPat(Ops.RECIPROCAL, dtype=dtypes.float64, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_rcp_f64_e32(ctx.dst, ctx.get_reg(a))]),
|
||||
(UPat(Ops.RECIPROCAL, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_rcp_f32_e32(ctx.dst, ctx.get_reg(a))]),
|
||||
(UPat(Ops.SQRT, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_sqrt_f32_e32(ctx.dst, ctx.get_reg(a))]),
|
||||
(UPat(Ops.EXP2, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_exp_f32_e32(ctx.dst, ctx.get_reg(a))]),
|
||||
(UPat(Ops.LOG2, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_log_f32_e32(ctx.dst, ctx.get_reg(a))]),
|
||||
(UPat(Ops.TRUNC, dtype=dtypes.float64, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_trunc_f64_e32(ctx.dst, ctx.get_reg(a))]),
|
||||
(UPat(Ops.TRUNC, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_trunc_f32_e32(ctx.dst, ctx.get_reg(a))]),
|
||||
# SIN: input should already be normalized by 1/(2π) via rdna_uops.py
|
||||
(UPat(Ops.SIN, dtype=dtypes.float32, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_sin_f32_e32(ctx.dst, ctx.get_reg(a))]),
|
||||
# NEG: floats vs ints
|
||||
(UPat(Ops.NEG, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_mul_f32_e32(ctx.dst, -1.0, ctx.get_reg(a))]),
|
||||
(UPat(Ops.NEG, dtype=dtypes.ints, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_sub_nc_u32_e32(ctx.dst, 0, ctx.get_reg(a))]),
|
||||
|
|
@ -286,14 +313,15 @@ class RDNARenderer(Renderer):
|
|||
extra_matcher = rdna_matcher + create_non_native_float_pats((dtypes.bfloat16,)) + rdna_bf16_cast
|
||||
tensor_cores = tc.amd_rdna3
|
||||
# Declare hardware-supported ops (enables decomposition patterns like fast_idiv for constant division)
|
||||
# NOTE: SIN removed - use software implementations for precision with large values
|
||||
code_for_op = {
|
||||
# Transcendental ops
|
||||
Ops.EXP2: lambda: None, Ops.LOG2: lambda: None, Ops.SIN: lambda: None, Ops.SQRT: lambda: None, Ops.RECIPROCAL: lambda: None,
|
||||
# Transcendental ops (SIN uses software for precision with large values, RECIPROCAL is needed by software SIN)
|
||||
Ops.EXP2: lambda: None, Ops.LOG2: lambda: None, Ops.SQRT: lambda: None, Ops.TRUNC: lambda: None, Ops.RECIPROCAL: lambda: None,
|
||||
# Bitwise ops
|
||||
Ops.AND: lambda: None, Ops.OR: lambda: None, Ops.XOR: lambda: None, Ops.SHL: lambda: None, Ops.SHR: lambda: None,
|
||||
# Arithmetic ops
|
||||
# Arithmetic ops (IDIV/MOD handled directly for proper 64-bit support)
|
||||
Ops.ADD: lambda: None, Ops.SUB: lambda: None, Ops.MUL: lambda: None, Ops.NEG: lambda: None,
|
||||
Ops.IDIV: lambda: None, Ops.MOD: lambda: None, Ops.TRUNC: lambda: None,
|
||||
Ops.IDIV: lambda: None, Ops.MOD: lambda: None,
|
||||
# Comparison ops
|
||||
Ops.CMPLT: lambda: None, Ops.CMPEQ: lambda: None, Ops.CMPNE: lambda: None, Ops.WHERE: lambda: None,
|
||||
# Max (used in various patterns)
|
||||
|
|
@ -398,12 +426,15 @@ class RDNARenderer(Renderer):
|
|||
(Ops.CMPLT, dtypes.float32): v_cmp_lt_f32_e32, (Ops.CMPLT, dtypes.int32): v_cmp_lt_i32_e32, (Ops.CMPLT, dtypes.uint32): v_cmp_lt_u32_e32,
|
||||
(Ops.CMPEQ, dtypes.float32): v_cmp_eq_f32_e32, (Ops.CMPEQ, dtypes.int32): v_cmp_eq_i32_e32, (Ops.CMPEQ, dtypes.uint32): v_cmp_eq_u32_e32,
|
||||
(Ops.CMPNE, dtypes.float32): v_cmp_neq_f32_e32, (Ops.CMPNE, dtypes.int32): v_cmp_ne_i32_e32, (Ops.CMPNE, dtypes.uint32): v_cmp_ne_u32_e32,
|
||||
(Ops.CMPLT, dtypes.float64): v_cmp_lt_f64_e32, (Ops.CMPEQ, dtypes.float64): v_cmp_eq_f64_e32, (Ops.CMPNE, dtypes.float64): v_cmp_neq_f64_e32,
|
||||
}
|
||||
# GT versions for swapping CMPLT: a < b ⇔ b > a
|
||||
cmp_gt_map = {
|
||||
(Ops.CMPLT, dtypes.float32): v_cmp_gt_f32_e32, (Ops.CMPLT, dtypes.int32): v_cmp_gt_i32_e32, (Ops.CMPLT, dtypes.uint32): v_cmp_gt_u32_e32,
|
||||
(Ops.CMPLT, dtypes.float64): v_cmp_gt_f64_e32,
|
||||
}
|
||||
base_dtype = dtypes.float32 if dtypes.is_float(dtype) else dtypes.int32 if dtype in (dtypes.int8, dtypes.int16, dtypes.int32) else dtypes.uint32
|
||||
base_dtype = dtypes.float64 if dtype == dtypes.float64 else dtypes.float32 if dtypes.is_float(dtype) else \
|
||||
dtypes.int32 if dtype in (dtypes.int8, dtypes.int16, dtypes.int32) else dtypes.uint32
|
||||
def is_const(x): return isinstance(x, (int, float))
|
||||
# For CMPLT with constant in vsrc1 position, swap and use GT
|
||||
if op is Ops.CMPLT and is_const(b) and not is_const(a):
|
||||
|
|
@ -464,9 +495,44 @@ class RDNARenderer(Renderer):
|
|||
elif op is Ops.IDIV:
|
||||
# Integer division using floating-point approximation
|
||||
# quotient = trunc(float(a) * rcp(float(b)))
|
||||
if dtype in (dtypes.int64, dtypes.uint64):
|
||||
# 64-bit division via float64 (53 bits precision, sufficient for most cases)
|
||||
# For values > 53 bits, may have off-by-one errors but acceptable for Payne-Hanek
|
||||
s = ra.get_scratch_vgpr(8)
|
||||
a_reg, b_reg = (a if isinstance(a, VGPR) else v[a.idx] if hasattr(a, 'idx') else a,
|
||||
b if isinstance(b, VGPR) else v[b.idx] if hasattr(b, 'idx') else b)
|
||||
# Convert a to float64: low + high*2^32
|
||||
code.append(v_cvt_f64_u32_e32(v[s:s+2], a_reg)) # low part
|
||||
a_hi = v[a_reg.idx+1] if isinstance(a_reg, VGPR) else 0
|
||||
code.append(v_cvt_f64_u32_e32(v[s+2:s+4], a_hi)) # high part
|
||||
code.append(v_mov_b32_e32(v[s+4], 0))
|
||||
code.append(v_mov_b32_e32(v[s+5], 0x41F00000)) # 2^32 in float64
|
||||
code.append(v_mul_f64(v[s+2:s+4], v[s+4:s+6], v[s+2:s+4])) # high * 2^32
|
||||
code.append(v_add_f64(v[s:s+2], v[s:s+2], v[s+2:s+4])) # a as float64
|
||||
# Convert b to float64: low + high*2^32
|
||||
code.append(v_cvt_f64_u32_e32(v[s+2:s+4], b_reg)) # low part
|
||||
b_hi = v[b_reg.idx+1] if isinstance(b_reg, VGPR) else 0
|
||||
code.append(v_cvt_f64_u32_e32(v[s+4:s+6], b_hi)) # high part
|
||||
code.append(v_mov_b32_e32(v[s+6], 0))
|
||||
code.append(v_mov_b32_e32(v[s+7], 0x41F00000)) # 2^32 in float64
|
||||
code.append(v_mul_f64(v[s+4:s+6], v[s+6:s+8], v[s+4:s+6])) # high * 2^32
|
||||
code.append(v_add_f64(v[s+2:s+4], v[s+2:s+4], v[s+4:s+6])) # b as float64
|
||||
# Compute a/b via reciprocal
|
||||
code.append(v_rcp_f64_e32(v[s+4:s+6], v[s+2:s+4])) # 1/b
|
||||
code.append(v_mul_f64(v[s:s+2], v[s:s+2], v[s+4:s+6])) # a/b
|
||||
code.append(v_trunc_f64_e32(v[s:s+2], v[s:s+2])) # floor(a/b)
|
||||
# Convert back to uint64: for most cases, result fits in low 32 bits, high is 0
|
||||
code.append(v_cvt_u32_f64_e32(v[dst.idx], v[s:s+2])) # low part
|
||||
# For high part: (result - low) / 2^32
|
||||
code.append(v_cvt_f64_u32_e32(v[s+2:s+4], v[dst.idx])) # low as float64
|
||||
code.append(v_mul_f64(v[s+2:s+4], -1.0, v[s+2:s+4])) # -low
|
||||
code.append(v_add_f64(v[s:s+2], v[s:s+2], v[s+2:s+4])) # result - low
|
||||
code.append(v_mov_b32_e32(v[s+4], 0))
|
||||
code.append(v_mov_b32_e32(v[s+5], 0x3DF00000)) # 2^-32 in float64
|
||||
code.append(v_mul_f64(v[s:s+2], v[s:s+2], v[s+4:s+6])) # (result - low) * 2^-32
|
||||
code.append(v_cvt_u32_f64_e32(v[dst.idx+1], v[s:s+2])) # high part
|
||||
# For signed: handle signs, do unsigned div, restore sign
|
||||
is_signed = dtype in (dtypes.int32, dtypes.int16, dtypes.int8)
|
||||
if is_signed:
|
||||
elif dtype in (dtypes.int32, dtypes.int16, dtypes.int8):
|
||||
# Compute absolute values and track signs
|
||||
tmp_abs_a = ra.alloc_vgpr(u) # |a|
|
||||
tmp_abs_b = ra.alloc_vgpr(u) # |b|
|
||||
|
|
@ -640,7 +706,7 @@ class RDNARenderer(Renderer):
|
|||
|
||||
elif u.op in (Ops.ADD, Ops.SUB, Ops.MUL, Ops.AND, Ops.OR, Ops.XOR, Ops.SHL, Ops.SHR,
|
||||
Ops.MAX, Ops.MULACC, Ops.RECIPROCAL, Ops.SQRT, Ops.EXP2, Ops.LOG2,
|
||||
Ops.TRUNC, Ops.NEG, Ops.CMPLT, Ops.CMPEQ, Ops.CMPNE, Ops.WHERE, Ops.SIN,
|
||||
Ops.TRUNC, Ops.NEG, Ops.CMPLT, Ops.CMPEQ, Ops.CMPNE, Ops.WHERE,
|
||||
Ops.IDIV, Ops.MOD):
|
||||
maybe_wait(u.src) # Wait for any pending loads used by this operation
|
||||
dst = ra.alloc_vgpr_pair(u) if RDNARegAlloc.needs_vgpr_pair(u.dtype) else ra.alloc_vgpr(u)
|
||||
|
|
@ -827,10 +893,19 @@ class RDNARenderer(Renderer):
|
|||
if isinstance(addr, (int, float)):
|
||||
# addr is the element index, not byte offset
|
||||
reg_offset = int(addr)
|
||||
code.append(v_mov_b32_e32(v[buf_reg.idx + reg_offset], val))
|
||||
if itemsize == 8:
|
||||
# 64-bit: move both low and high 32-bit parts
|
||||
code.append(v_mov_b32_e32(v[buf_reg.idx + reg_offset * 2], v[val.idx] if isinstance(val, VGPR) else val))
|
||||
code.append(v_mov_b32_e32(v[buf_reg.idx + reg_offset * 2 + 1], v[val.idx + 1] if isinstance(val, VGPR) else 0))
|
||||
else:
|
||||
code.append(v_mov_b32_e32(v[buf_reg.idx + reg_offset], val))
|
||||
else:
|
||||
# Variable offset - use first register as fallback (TODO: proper indirect)
|
||||
code.append(v_mov_b32_e32(v[buf_reg.idx], val))
|
||||
if itemsize == 8:
|
||||
code.append(v_mov_b32_e32(v[buf_reg.idx], v[val.idx] if isinstance(val, VGPR) else val))
|
||||
code.append(v_mov_b32_e32(v[buf_reg.idx + 1], v[val.idx + 1] if isinstance(val, VGPR) else 0))
|
||||
else:
|
||||
code.append(v_mov_b32_e32(v[buf_reg.idx], val))
|
||||
else:
|
||||
# Global memory store: use buffer SGPR pair as saddr
|
||||
buf_result = r.get(buf_uop) if buf_uop in r else get_reg(buf_uop)
|
||||
|
|
@ -889,6 +964,28 @@ class RDNARenderer(Renderer):
|
|||
elif u.op is Ops.BARRIER:
|
||||
code.append(s_barrier())
|
||||
|
||||
elif u.op is Ops.WMMA:
|
||||
maybe_wait(u.src)
|
||||
# WMMA: wave matrix multiply-accumulate
|
||||
# arg: (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
|
||||
# src: (A_vec, B_vec, C_acc)
|
||||
dtype_in, dtype_out = u.arg[2], u.dtype.scalar()
|
||||
a_reg, b_reg, c_reg = get_reg(u.src[0]), get_reg(u.src[1]), get_reg(u.src[2])
|
||||
# Output is 8 floats (or 8 halves) = 8 VGPRs
|
||||
dst = ra.alloc_vgpr_range(u, 8)
|
||||
# Select the right WMMA instruction based on input/output types
|
||||
if dtype_in == dtypes.half and dtype_out == dtypes.float:
|
||||
code.append(v_wmma_f32_16x16x16_f16(dst, a_reg, b_reg, c_reg))
|
||||
elif dtype_in == dtypes.bfloat16 and dtype_out == dtypes.float:
|
||||
code.append(v_wmma_f32_16x16x16_bf16(dst, a_reg, b_reg, c_reg))
|
||||
elif dtype_in == dtypes.half and dtype_out == dtypes.half:
|
||||
code.append(v_wmma_f16_16x16x16_f16(dst, a_reg, b_reg, c_reg))
|
||||
elif dtype_in == dtypes.bfloat16 and dtype_out == dtypes.bfloat16:
|
||||
code.append(v_wmma_bf16_16x16x16_bf16(dst, a_reg, b_reg, c_reg))
|
||||
else:
|
||||
raise NotImplementedError(f"WMMA not implemented for {dtype_in} -> {dtype_out}")
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.AFTER:
|
||||
# AFTER ensures previous operations complete, then returns the buffer
|
||||
# src[0] is the buffer, src[1] is the operation that must complete
|
||||
|
|
|
|||
|
|
@ -240,10 +240,6 @@ rdna_matcher = PatternMatcher([
|
|||
lambda x, a, b: UOp(Ops.CMPNE, dtypes.bool, (a.cast(dtypes.float32), b.cast(dtypes.float32)))),
|
||||
# devectorize ALU operations - RDNA doesn't have vector float ALU
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
|
||||
# SIN: normalize input by 1/(2π) for v_sin_f32 (expects [0,1) -> [0,2π))
|
||||
(UPat(Ops.SIN, dtype=dtypes.float32, src=(UPat.var("x"),), name="u"),
|
||||
lambda u, x: None if u.tag == "normalized" else # skip already normalized
|
||||
UOp(Ops.SIN, dtypes.float32, (x * UOp.const(dtypes.float32, 0.15915494309189535),)).rtag("normalized")),
|
||||
# Fix fast_idiv output when shift >= 32 (needs 64-bit multiply)
|
||||
# Pattern: (x * const) >> shift for unsigned
|
||||
(UPat(Ops.SHR, src=(UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c"))), UPat.cvar("shift"))), _fix_fast_idiv_unsigned),
|
||||
|
|
@ -292,13 +288,12 @@ rdna_matcher = PatternMatcher([
|
|||
(UPat(Ops.SUB, dtype=_small_floats, src=(UPat.var("a"), UPat.var("b")), name="x"), _lower_f16_sub),
|
||||
(UPat(Ops.MUL, dtype=_small_floats, src=(UPat.var("a"), UPat.var("b")), name="x"), _lower_f16_mul),
|
||||
(UPat(Ops.MAX, dtype=_small_floats, src=(UPat.var("a"), UPat.var("b")), name="x"), _lower_f16_max),
|
||||
# Unary ops: RECIPROCAL, SQRT, EXP2, LOG2, TRUNC, SIN, NEG
|
||||
# Unary ops: RECIPROCAL, SQRT, EXP2, LOG2, TRUNC, NEG (SIN uses software impl for precision)
|
||||
(UPat(Ops.RECIPROCAL, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_reciprocal),
|
||||
(UPat(Ops.SQRT, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_sqrt),
|
||||
(UPat(Ops.EXP2, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_exp2),
|
||||
(UPat(Ops.LOG2, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_log2),
|
||||
(UPat(Ops.TRUNC, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_trunc),
|
||||
(UPat(Ops.SIN, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_sin),
|
||||
(UPat(Ops.NEG, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_neg),
|
||||
# WHERE for float16
|
||||
(UPat(Ops.WHERE, dtype=_small_floats, src=(UPat.var("cond"), UPat.var("a"), UPat.var("b")), name="x"), _lower_f16_where),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue