This commit is contained in:
George Hotz 2025-12-26 23:56:58 +00:00
commit 81e9ea2bec
2 changed files with 118 additions and 26 deletions

View file

@ -12,8 +12,8 @@ from extra.assembly.rdna3.autogen import (
v, s, VGPR, SGPR, VCC_LO, EXEC_LO, NULL,
# VOP1
v_mov_b32_e32, v_cvt_f32_i32_e32, v_cvt_i32_f32_e32, v_cvt_f32_u32_e32, v_cvt_u32_f32_e32,
v_cvt_f16_f32_e32, v_cvt_f32_f16_e32, v_rcp_f32_e32, v_sqrt_f32_e32,
v_exp_f32_e32, v_log_f32_e32, v_trunc_f32_e32, v_sin_f32_e32,
v_cvt_f16_f32_e32, v_cvt_f32_f16_e32, v_rcp_f32_e32, v_rcp_f64_e32, v_sqrt_f32_e32,
v_exp_f32_e32, v_log_f32_e32, v_trunc_f32_e32, v_sin_f32_e32, v_fract_f32_e32,
v_cvt_f64_f32_e32, v_cvt_f32_f64_e32, v_cvt_f64_i32_e32, v_cvt_f64_u32_e32,
v_cvt_i32_f64_e32, v_cvt_u32_f64_e32, v_trunc_f64_e32, v_floor_f64_e32,
# VOP2
@ -25,6 +25,7 @@ from extra.assembly.rdna3.autogen import (
v_mul_lo_u32, v_mul_hi_u32, v_bfe_u32, v_bfe_i32,
v_add_co_u32, v_add_co_ci_u32_e32, v_cndmask_b32_e64, v_add_f64, v_mul_f64, v_sub_co_u32, v_sub_co_ci_u32_e32,
v_cmp_lt_f32_e32, v_cmp_eq_f32_e32, v_cmp_neq_f32_e32, v_cmp_gt_f32_e32,
v_cmp_lt_f64_e32, v_cmp_eq_f64_e32, v_cmp_neq_f64_e32, v_cmp_gt_f64_e32,
v_cmp_lt_i32_e32, v_cmp_eq_i32_e32, v_cmp_ne_i32_e32, v_cmp_gt_i32_e32,
v_cmp_lt_u32_e32, v_cmp_eq_u32_e32, v_cmp_ne_u32_e32, v_cmp_gt_u32_e32,
# SOPP/SOP
@ -36,6 +37,8 @@ from extra.assembly.rdna3.autogen import (
global_store_b32, global_store_b64, global_store_b128, global_store_b16, global_store_b8,
# DS (local memory)
ds_load_b32, ds_load_b64, ds_load_b128, ds_store_b32, ds_store_b64, ds_store_b128,
# WMMA (wave matrix multiply-accumulate)
v_wmma_f32_16x16x16_f16, v_wmma_f32_16x16x16_bf16, v_wmma_f16_16x16x16_f16, v_wmma_bf16_16x16x16_bf16,
)
# Helper for VOP2: src0 can be constant/literal, vsrc1 must be VGPR - swap for commutative ops
@ -43,6 +46,19 @@ def _sw(ctx, a, b):
ar, br = ctx.get_reg(a), ctx.get_reg(b)
return (br, ar) if isinstance(br, (int, float)) and not isinstance(ar, (int, float)) else (ar, br)
# Helper for 64-bit bitwise operations: apply op to both low and high 32-bit parts
def _bitwise64(ctx, a, b, op):
  """Emit a 64-bit bitwise operation as two 32-bit instructions (low dword, then high dword).

  Args:
    ctx: render context; `ctx.get_reg` resolves an operand and `ctx.dst` is the destination register
         (the high dword is written to `v[ctx.dst.idx+1]`).
    a, b: the two source operands of the 64-bit AND/OR/XOR.
    op: 32-bit instruction constructor called as `op(dst, src0, src1)`; it must be commutative,
        because immediates are always moved into the src0 slot.

  Returns:
    A two-element list: [instruction for the low 32 bits, instruction for the high 32 bits].
  """
  ar, br = ctx.get_reg(a), ctx.get_reg(b)
  dst_hi = v[ctx.dst.idx+1]
  # Immediate in b: split it into low/high dwords.
  if isinstance(br, int):
    b_lo, b_hi = br & 0xFFFFFFFF, (br >> 32) & 0xFFFFFFFF
    # NOTE(review): a non-VGPR `a` is treated as having a zero high dword, as before - confirm this is intended.
    a_hi = v[ar.idx+1] if isinstance(ar, VGPR) else 0
    # VOP2 allows a constant only in src0 (vsrc1 must be a VGPR), so the immediate goes first,
    # matching what _sw does for the 32-bit patterns.
    return [op(ctx.dst, b_lo, ar), op(dst_hi, b_hi, a_hi)]
  # Immediate in a: it already sits in the src0 slot.
  if isinstance(ar, int):
    a_lo, a_hi = ar & 0xFFFFFFFF, (ar >> 32) & 0xFFFFFFFF
    return [op(ctx.dst, a_lo, br), op(dst_hi, a_hi, v[br.idx+1])]
  # Both operands are registers: pair up the low dwords and the high dwords.
  return [op(ctx.dst, ar, br), op(dst_hi, v[ar.idx+1], v[br.idx+1])]
# Module-level PatternMatcher for simple ALU and CAST operations
render_ops = PatternMatcher([
# CAST: float32 <-> int32/uint32
@ -229,10 +245,20 @@ render_ops = PatternMatcher([
(UPat(Ops.MUL, dtype=dtypes.floats, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_mul_f32_e32(ctx.dst, *_sw(ctx,a,b))]),
(UPat(Ops.MUL, dtype=(dtypes.int32, dtypes.uint32, dtypes.int16, dtypes.uint16, dtypes.int8, dtypes.uint8),
src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_mul_lo_u32(ctx.dst, *_sw(ctx,a,b))]),
# Bitwise: int only
(UPat(Ops.AND, dtype=dtypes.ints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_and_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
(UPat(Ops.OR, dtype=dtypes.ints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_or_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
(UPat(Ops.XOR, dtype=dtypes.ints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_xor_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
# Bitwise: 64-bit (need to operate on both low and high 32-bit parts)
(UPat(Ops.AND, dtype=(dtypes.int64, dtypes.uint64), src=(UPat.var("a"), UPat.var("b")), name="x"),
lambda ctx,x,a,b: _bitwise64(ctx, a, b, v_and_b32_e32)),
(UPat(Ops.OR, dtype=(dtypes.int64, dtypes.uint64), src=(UPat.var("a"), UPat.var("b")), name="x"),
lambda ctx,x,a,b: _bitwise64(ctx, a, b, v_or_b32_e32)),
(UPat(Ops.XOR, dtype=(dtypes.int64, dtypes.uint64), src=(UPat.var("a"), UPat.var("b")), name="x"),
lambda ctx,x,a,b: _bitwise64(ctx, a, b, v_xor_b32_e32)),
# Bitwise: 32-bit and smaller ints, and bool
(UPat(Ops.AND, dtype=(dtypes.int32, dtypes.uint32, dtypes.int16, dtypes.uint16, dtypes.int8, dtypes.uint8, dtypes.bool),
src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_and_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
(UPat(Ops.OR, dtype=(dtypes.int32, dtypes.uint32, dtypes.int16, dtypes.uint16, dtypes.int8, dtypes.uint8, dtypes.bool),
src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_or_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
(UPat(Ops.XOR, dtype=(dtypes.int32, dtypes.uint32, dtypes.int16, dtypes.uint16, dtypes.int8, dtypes.uint8, dtypes.bool),
src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_xor_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
# SHL: int64, default to i32
(UPat(Ops.SHL, dtype=(dtypes.int64, dtypes.uint64), src=(UPat.var("a"), UPat.var("b")), name="x"),
lambda ctx,x,a,b: [v_lshlrev_b64(ctx.dst, ctx.get_reg(b), ctx.get_reg(a))]),
@ -242,23 +268,24 @@ render_ops = PatternMatcher([
lambda ctx,x,a,b: [v_ashrrev_i64(ctx.dst, ctx.get_reg(b), ctx.get_reg(a))]),
(UPat(Ops.SHR, dtype=dtypes.uint64, src=(UPat.var("a"), UPat.var("b")), name="x"),
lambda ctx,x,a,b: [v_lshrrev_b64(ctx.dst, ctx.get_reg(b), ctx.get_reg(a))]),
# MAX: floats, signed ints, unsigned ints
# MAX: floats, signed ints, unsigned ints, bool (bool uses OR since max(True, False) = True)
(UPat(Ops.MAX, dtype=dtypes.floats, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_max_f32_e32(ctx.dst, *_sw(ctx,a,b))]),
(UPat(Ops.MAX, dtype=dtypes.sints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_max_i32_e32(ctx.dst, *_sw(ctx,a,b))]),
(UPat(Ops.MAX, dtype=dtypes.uints, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_max_u32_e32(ctx.dst, *_sw(ctx,a,b))]),
(UPat(Ops.MAX, dtype=dtypes.bool, src=(UPat.var("a"), UPat.var("b")), name="x"), lambda ctx,x,a,b: [v_or_b32_e32(ctx.dst, *_sw(ctx,a,b))]),
# MULACC (FMA): float64, floats
(UPat(Ops.MULACC, dtype=dtypes.float64, src=(UPat.var("a"), UPat.var("b"), UPat.var("d")), name="x"),
lambda ctx,x,a,b,d: [v_fma_f64(ctx.dst, ctx.get_reg(a), ctx.get_reg(b), ctx.get_reg(d))]),
(UPat(Ops.MULACC, dtype=dtypes.floats, src=(UPat.var("a"), UPat.var("b"), UPat.var("d")), name="x"),
lambda ctx,x,a,b,d: [v_fma_f32(ctx.dst, ctx.get_reg(a), ctx.get_reg(b), ctx.get_reg(d))]),
# Transcendental: float only
# Transcendental: float64 first (for precision), then float32
(UPat(Ops.RECIPROCAL, dtype=dtypes.float64, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_rcp_f64_e32(ctx.dst, ctx.get_reg(a))]),
(UPat(Ops.RECIPROCAL, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_rcp_f32_e32(ctx.dst, ctx.get_reg(a))]),
(UPat(Ops.SQRT, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_sqrt_f32_e32(ctx.dst, ctx.get_reg(a))]),
(UPat(Ops.EXP2, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_exp_f32_e32(ctx.dst, ctx.get_reg(a))]),
(UPat(Ops.LOG2, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_log_f32_e32(ctx.dst, ctx.get_reg(a))]),
(UPat(Ops.TRUNC, dtype=dtypes.float64, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_trunc_f64_e32(ctx.dst, ctx.get_reg(a))]),
(UPat(Ops.TRUNC, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_trunc_f32_e32(ctx.dst, ctx.get_reg(a))]),
# SIN: input should already be normalized by 1/(2π) via rdna_uops.py
(UPat(Ops.SIN, dtype=dtypes.float32, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_sin_f32_e32(ctx.dst, ctx.get_reg(a))]),
# NEG: floats vs ints
(UPat(Ops.NEG, dtype=dtypes.floats, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_mul_f32_e32(ctx.dst, -1.0, ctx.get_reg(a))]),
(UPat(Ops.NEG, dtype=dtypes.ints, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: [v_sub_nc_u32_e32(ctx.dst, 0, ctx.get_reg(a))]),
@ -286,14 +313,15 @@ class RDNARenderer(Renderer):
extra_matcher = rdna_matcher + create_non_native_float_pats((dtypes.bfloat16,)) + rdna_bf16_cast
tensor_cores = tc.amd_rdna3
# Declare hardware-supported ops (enables decomposition patterns like fast_idiv for constant division)
# NOTE: SIN removed - use software implementations for precision with large values
code_for_op = {
# Transcendental ops
Ops.EXP2: lambda: None, Ops.LOG2: lambda: None, Ops.SIN: lambda: None, Ops.SQRT: lambda: None, Ops.RECIPROCAL: lambda: None,
# Transcendental ops (SIN uses software for precision with large values, RECIPROCAL is needed by software SIN)
Ops.EXP2: lambda: None, Ops.LOG2: lambda: None, Ops.SQRT: lambda: None, Ops.TRUNC: lambda: None, Ops.RECIPROCAL: lambda: None,
# Bitwise ops
Ops.AND: lambda: None, Ops.OR: lambda: None, Ops.XOR: lambda: None, Ops.SHL: lambda: None, Ops.SHR: lambda: None,
# Arithmetic ops
# Arithmetic ops (IDIV/MOD handled directly for proper 64-bit support)
Ops.ADD: lambda: None, Ops.SUB: lambda: None, Ops.MUL: lambda: None, Ops.NEG: lambda: None,
Ops.IDIV: lambda: None, Ops.MOD: lambda: None, Ops.TRUNC: lambda: None,
Ops.IDIV: lambda: None, Ops.MOD: lambda: None,
# Comparison ops
Ops.CMPLT: lambda: None, Ops.CMPEQ: lambda: None, Ops.CMPNE: lambda: None, Ops.WHERE: lambda: None,
# Max (used in various patterns)
@ -398,12 +426,15 @@ class RDNARenderer(Renderer):
(Ops.CMPLT, dtypes.float32): v_cmp_lt_f32_e32, (Ops.CMPLT, dtypes.int32): v_cmp_lt_i32_e32, (Ops.CMPLT, dtypes.uint32): v_cmp_lt_u32_e32,
(Ops.CMPEQ, dtypes.float32): v_cmp_eq_f32_e32, (Ops.CMPEQ, dtypes.int32): v_cmp_eq_i32_e32, (Ops.CMPEQ, dtypes.uint32): v_cmp_eq_u32_e32,
(Ops.CMPNE, dtypes.float32): v_cmp_neq_f32_e32, (Ops.CMPNE, dtypes.int32): v_cmp_ne_i32_e32, (Ops.CMPNE, dtypes.uint32): v_cmp_ne_u32_e32,
(Ops.CMPLT, dtypes.float64): v_cmp_lt_f64_e32, (Ops.CMPEQ, dtypes.float64): v_cmp_eq_f64_e32, (Ops.CMPNE, dtypes.float64): v_cmp_neq_f64_e32,
}
# GT versions for swapping CMPLT: a < b ⇔ b > a
cmp_gt_map = {
(Ops.CMPLT, dtypes.float32): v_cmp_gt_f32_e32, (Ops.CMPLT, dtypes.int32): v_cmp_gt_i32_e32, (Ops.CMPLT, dtypes.uint32): v_cmp_gt_u32_e32,
(Ops.CMPLT, dtypes.float64): v_cmp_gt_f64_e32,
}
base_dtype = dtypes.float32 if dtypes.is_float(dtype) else dtypes.int32 if dtype in (dtypes.int8, dtypes.int16, dtypes.int32) else dtypes.uint32
base_dtype = dtypes.float64 if dtype == dtypes.float64 else dtypes.float32 if dtypes.is_float(dtype) else \
dtypes.int32 if dtype in (dtypes.int8, dtypes.int16, dtypes.int32) else dtypes.uint32
def is_const(x): return isinstance(x, (int, float))
# For CMPLT with constant in vsrc1 position, swap and use GT
if op is Ops.CMPLT and is_const(b) and not is_const(a):
@ -464,9 +495,44 @@ class RDNARenderer(Renderer):
elif op is Ops.IDIV:
# Integer division using floating-point approximation
# quotient = trunc(float(a) * rcp(float(b)))
if dtype in (dtypes.int64, dtypes.uint64):
# 64-bit division via float64 (53 bits precision, sufficient for most cases)
# For values > 53 bits, may have off-by-one errors but acceptable for Payne-Hanek
s = ra.get_scratch_vgpr(8)
a_reg, b_reg = (a if isinstance(a, VGPR) else v[a.idx] if hasattr(a, 'idx') else a,
b if isinstance(b, VGPR) else v[b.idx] if hasattr(b, 'idx') else b)
# Convert a to float64: low + high*2^32
code.append(v_cvt_f64_u32_e32(v[s:s+2], a_reg)) # low part
a_hi = v[a_reg.idx+1] if isinstance(a_reg, VGPR) else 0
code.append(v_cvt_f64_u32_e32(v[s+2:s+4], a_hi)) # high part
code.append(v_mov_b32_e32(v[s+4], 0))
code.append(v_mov_b32_e32(v[s+5], 0x41F00000)) # 2^32 in float64
code.append(v_mul_f64(v[s+2:s+4], v[s+4:s+6], v[s+2:s+4])) # high * 2^32
code.append(v_add_f64(v[s:s+2], v[s:s+2], v[s+2:s+4])) # a as float64
# Convert b to float64: low + high*2^32
code.append(v_cvt_f64_u32_e32(v[s+2:s+4], b_reg)) # low part
b_hi = v[b_reg.idx+1] if isinstance(b_reg, VGPR) else 0
code.append(v_cvt_f64_u32_e32(v[s+4:s+6], b_hi)) # high part
code.append(v_mov_b32_e32(v[s+6], 0))
code.append(v_mov_b32_e32(v[s+7], 0x41F00000)) # 2^32 in float64
code.append(v_mul_f64(v[s+4:s+6], v[s+6:s+8], v[s+4:s+6])) # high * 2^32
code.append(v_add_f64(v[s+2:s+4], v[s+2:s+4], v[s+4:s+6])) # b as float64
# Compute a/b via reciprocal
code.append(v_rcp_f64_e32(v[s+4:s+6], v[s+2:s+4])) # 1/b
code.append(v_mul_f64(v[s:s+2], v[s:s+2], v[s+4:s+6])) # a/b
code.append(v_trunc_f64_e32(v[s:s+2], v[s:s+2])) # floor(a/b)
# Convert back to uint64: for most cases, result fits in low 32 bits, high is 0
code.append(v_cvt_u32_f64_e32(v[dst.idx], v[s:s+2])) # low part
# For high part: (result - low) / 2^32
code.append(v_cvt_f64_u32_e32(v[s+2:s+4], v[dst.idx])) # low as float64
code.append(v_mul_f64(v[s+2:s+4], -1.0, v[s+2:s+4])) # -low
code.append(v_add_f64(v[s:s+2], v[s:s+2], v[s+2:s+4])) # result - low
code.append(v_mov_b32_e32(v[s+4], 0))
code.append(v_mov_b32_e32(v[s+5], 0x3DF00000)) # 2^-32 in float64
code.append(v_mul_f64(v[s:s+2], v[s:s+2], v[s+4:s+6])) # (result - low) * 2^-32
code.append(v_cvt_u32_f64_e32(v[dst.idx+1], v[s:s+2])) # high part
# For signed: handle signs, do unsigned div, restore sign
is_signed = dtype in (dtypes.int32, dtypes.int16, dtypes.int8)
if is_signed:
elif dtype in (dtypes.int32, dtypes.int16, dtypes.int8):
# Compute absolute values and track signs
tmp_abs_a = ra.alloc_vgpr(u) # |a|
tmp_abs_b = ra.alloc_vgpr(u) # |b|
@ -640,7 +706,7 @@ class RDNARenderer(Renderer):
elif u.op in (Ops.ADD, Ops.SUB, Ops.MUL, Ops.AND, Ops.OR, Ops.XOR, Ops.SHL, Ops.SHR,
Ops.MAX, Ops.MULACC, Ops.RECIPROCAL, Ops.SQRT, Ops.EXP2, Ops.LOG2,
Ops.TRUNC, Ops.NEG, Ops.CMPLT, Ops.CMPEQ, Ops.CMPNE, Ops.WHERE, Ops.SIN,
Ops.TRUNC, Ops.NEG, Ops.CMPLT, Ops.CMPEQ, Ops.CMPNE, Ops.WHERE,
Ops.IDIV, Ops.MOD):
maybe_wait(u.src) # Wait for any pending loads used by this operation
dst = ra.alloc_vgpr_pair(u) if RDNARegAlloc.needs_vgpr_pair(u.dtype) else ra.alloc_vgpr(u)
@ -827,10 +893,19 @@ class RDNARenderer(Renderer):
if isinstance(addr, (int, float)):
# addr is the element index, not byte offset
reg_offset = int(addr)
code.append(v_mov_b32_e32(v[buf_reg.idx + reg_offset], val))
if itemsize == 8:
# 64-bit: move both low and high 32-bit parts
code.append(v_mov_b32_e32(v[buf_reg.idx + reg_offset * 2], v[val.idx] if isinstance(val, VGPR) else val))
code.append(v_mov_b32_e32(v[buf_reg.idx + reg_offset * 2 + 1], v[val.idx + 1] if isinstance(val, VGPR) else 0))
else:
code.append(v_mov_b32_e32(v[buf_reg.idx + reg_offset], val))
else:
# Variable offset - use first register as fallback (TODO: proper indirect)
code.append(v_mov_b32_e32(v[buf_reg.idx], val))
if itemsize == 8:
code.append(v_mov_b32_e32(v[buf_reg.idx], v[val.idx] if isinstance(val, VGPR) else val))
code.append(v_mov_b32_e32(v[buf_reg.idx + 1], v[val.idx + 1] if isinstance(val, VGPR) else 0))
else:
code.append(v_mov_b32_e32(v[buf_reg.idx], val))
else:
# Global memory store: use buffer SGPR pair as saddr
buf_result = r.get(buf_uop) if buf_uop in r else get_reg(buf_uop)
@ -889,6 +964,28 @@ class RDNARenderer(Renderer):
elif u.op is Ops.BARRIER:
code.append(s_barrier())
elif u.op is Ops.WMMA:
maybe_wait(u.src)
# WMMA: wave matrix multiply-accumulate
# arg: (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
# src: (A_vec, B_vec, C_acc)
dtype_in, dtype_out = u.arg[2], u.dtype.scalar()
a_reg, b_reg, c_reg = get_reg(u.src[0]), get_reg(u.src[1]), get_reg(u.src[2])
# Output is 8 floats (or 8 halves) = 8 VGPRs
dst = ra.alloc_vgpr_range(u, 8)
# Select the right WMMA instruction based on input/output types
if dtype_in == dtypes.half and dtype_out == dtypes.float:
code.append(v_wmma_f32_16x16x16_f16(dst, a_reg, b_reg, c_reg))
elif dtype_in == dtypes.bfloat16 and dtype_out == dtypes.float:
code.append(v_wmma_f32_16x16x16_bf16(dst, a_reg, b_reg, c_reg))
elif dtype_in == dtypes.half and dtype_out == dtypes.half:
code.append(v_wmma_f16_16x16x16_f16(dst, a_reg, b_reg, c_reg))
elif dtype_in == dtypes.bfloat16 and dtype_out == dtypes.bfloat16:
code.append(v_wmma_bf16_16x16x16_bf16(dst, a_reg, b_reg, c_reg))
else:
raise NotImplementedError(f"WMMA not implemented for {dtype_in} -> {dtype_out}")
r[u] = dst
elif u.op is Ops.AFTER:
# AFTER ensures previous operations complete, then returns the buffer
# src[0] is the buffer, src[1] is the operation that must complete

View file

@ -240,10 +240,6 @@ rdna_matcher = PatternMatcher([
lambda x, a, b: UOp(Ops.CMPNE, dtypes.bool, (a.cast(dtypes.float32), b.cast(dtypes.float32)))),
# devectorize ALU operations - RDNA doesn't have vector float ALU
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
# SIN: normalize input by 1/(2π) for v_sin_f32 (expects [0,1) -> [0,2π))
(UPat(Ops.SIN, dtype=dtypes.float32, src=(UPat.var("x"),), name="u"),
lambda u, x: None if u.tag == "normalized" else # skip already normalized
UOp(Ops.SIN, dtypes.float32, (x * UOp.const(dtypes.float32, 0.15915494309189535),)).rtag("normalized")),
# Fix fast_idiv output when shift >= 32 (needs 64-bit multiply)
# Pattern: (x * const) >> shift for unsigned
(UPat(Ops.SHR, src=(UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c"))), UPat.cvar("shift"))), _fix_fast_idiv_unsigned),
@ -292,13 +288,12 @@ rdna_matcher = PatternMatcher([
(UPat(Ops.SUB, dtype=_small_floats, src=(UPat.var("a"), UPat.var("b")), name="x"), _lower_f16_sub),
(UPat(Ops.MUL, dtype=_small_floats, src=(UPat.var("a"), UPat.var("b")), name="x"), _lower_f16_mul),
(UPat(Ops.MAX, dtype=_small_floats, src=(UPat.var("a"), UPat.var("b")), name="x"), _lower_f16_max),
# Unary ops: RECIPROCAL, SQRT, EXP2, LOG2, TRUNC, SIN, NEG
# Unary ops: RECIPROCAL, SQRT, EXP2, LOG2, TRUNC, NEG (SIN uses software impl for precision)
(UPat(Ops.RECIPROCAL, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_reciprocal),
(UPat(Ops.SQRT, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_sqrt),
(UPat(Ops.EXP2, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_exp2),
(UPat(Ops.LOG2, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_log2),
(UPat(Ops.TRUNC, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_trunc),
(UPat(Ops.SIN, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_sin),
(UPat(Ops.NEG, dtype=_small_floats, src=(UPat.var("a"),), name="x"), _lower_f16_neg),
# WHERE for float16
(UPat(Ops.WHERE, dtype=_small_floats, src=(UPat.var("cond"), UPat.var("a"), UPat.var("b")), name="x"), _lower_f16_where),