Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
a5afdcc79b add lds double buffering to amd_asm_matmul 2026-01-17 16:26:07 +09:00

View file

@ -47,8 +47,9 @@ V_GLOBAL_B_ADDR = 178 # global memory B prefetch address
# LDS tile register destinations - SEPARATE from DATA to avoid overlap
# A on banks 2-3, B on banks 0-1 to avoid bank conflicts in VOPD
V_A_TILE_REGS = [130, 134, 138, 142] # A tile: banks 2,2,2,2 (130%4=2, etc.)
V_B_TILE_REGS = [132, 136, 140, 144, 148, 152, 156, 160] # B tile: banks 0,0,0,0,0,0,0,0
# Double-buffered: buffer 0 (v130-v161), buffer 1 (v194-v225)
V_A_TILE_REGS = [[130, 134, 138, 142], [194, 198, 202, 206]] # A tile: banks 2,2,2,2 (mod 4 = 2)
V_B_TILE_REGS = [[132, 136, 140, 144, 148, 152, 156, 160], [192, 196, 200, 204, 208, 212, 216, 220]] # B tile: banks 0,0,0,0 (mod 4 = 0)
# =============================================================================
# Named register assignments (SGPRs)
@ -98,10 +99,8 @@ FMAC_PAIR_ORDER = [
(0,4),(0,5),(1,5),(1,4), (2,4),(2,5),(3,5),(3,6), (0,6),(0,7),(1,7),(1,6), (2,6),(2,7),(3,7),(3,0),
]
def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None):
def derive_fmac_pattern(acc_grid, a_tile_regs, b_tile_regs):
"""Generate 64 dual FMAC ops from accumulator grid with optimized iteration order."""
if a_tile_regs is None: a_tile_regs = V_A_TILE_REGS
if b_tile_regs is None: b_tile_regs = V_B_TILE_REGS
pattern = []
for idx, (a_pair, b_pair) in enumerate(FMAC_PAIR_ORDER):
a_even, a_odd = a_pair * 2, a_pair * 2 + 1
@ -119,8 +118,8 @@ def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None):
a_base+1, b_base, a_base, b_base+1))
return pattern
# Derived: 64 dual FMAC operations
FMAC_PATTERN = derive_fmac_pattern(ACC_GRID)
# Derived: 64 dual FMAC operations for each buffer
FMAC_PATTERN = [derive_fmac_pattern(ACC_GRID, V_A_TILE_REGS[i], V_B_TILE_REGS[i]) for i in range(2)]
def derive_permute_swaps(acc_grid, out_regs):
"""Derive swap sequence to permute accumulators from FMAC layout to output order.
@ -218,7 +217,7 @@ class Kernel:
('user_sgpr_kernarg_segment_ptr', 1), ('user_sgpr_dispatch_id', 0), ('user_sgpr_private_segment_size', 0),
('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0),
('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0),
('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 192),
('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 222),
('next_free_sgpr', 16), ('float_round_mode_32', 0), ('float_round_mode_16_64', 0),
('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3), ('dx10_clamp', 1), ('ieee_mode', 1),
('fp16_overflow', 0), ('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0),
@ -236,7 +235,7 @@ class Kernel:
f' .group_segment_fixed_size: {lds_size}', ' .kernarg_segment_align: 8',
' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 128', ' .name: kernel',
' .private_segment_fixed_size: 0', ' .sgpr_count: 60', ' .symbol: kernel.kd',
' .vgpr_count: 192', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
' .vgpr_count: 222', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata'])
@ -426,7 +425,7 @@ def build_kernel(arch='gfx1100'):
# MAIN GEMM LOOP
# ===========================================================================
NO_DS, NO_GLOBAL = getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
k.label('LOOP_INC')
k.emit(s_add_i32(s[S_LOOP_CTR], s[S_LOOP_CTR], 8))
@ -440,9 +439,10 @@ def build_kernel(arch='gfx1100'):
if not NO_GLOBAL:
# Advance prefetch pointers (VGPR)
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
"""
# Advance prefetch pointers (SGPRs, 64-bit adds)
k.emit(s_clause(simm16=31))
for i in range(8):
@ -451,6 +451,7 @@ def build_kernel(arch='gfx1100'):
for i in range(8):
k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_PREFETCH_A+i*2], 0x20))
k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_PREFETCH_A+i*2+1], 0))
"""
# do the fetch
for vdst, saddr_lo in INIT_PREFETCH:
@ -463,18 +464,38 @@ def build_kernel(arch='gfx1100'):
k.waitcnt(lgkm=0)
k.emit(s_barrier())
# Load initial tiles for iter=0 into buffer 0
if not NO_DS:
a_tile_regs = V_A_TILE_REGS[0]
b_tile_regs = V_B_TILE_REGS[0]
k.emit(s_clause(simm16=len(a_tile_regs) + len(b_tile_regs) - 1))
for i, vdst in enumerate(a_tile_regs):
a_off = (i & 1) * 8 + (i >> 1) * 64 # iter=0
k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_A_BASE], offset0=a_off & 0xFF, offset1=a_off >> 8))
for i, vdst in enumerate(b_tile_regs):
b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 # iter=0
k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_B_BASE], offset0=b_off & 0xFF, offset1=b_off >> 8))
# 8 inner loop iterations
# Double-buffered inner loop: load next iteration's tiles while computing current
# Buffer 0 used for even iterations, buffer 1 for odd iterations
for iter in range(8):
# Load A tile (4 pairs) and B tile (8 pairs) from LDS
if not NO_DS:
k.emit(s_clause(simm16=len(V_A_TILE_REGS) + len(V_B_TILE_REGS) - 1)) # 12 loads total: 4 A + 8 B
buf = iter & 1 # current compute buffer
next_buf = 1 - buf # next load buffer
# Load tiles for NEXT iteration into next_buf (except on last iteration)
if not NO_DS and iter < 7:
next_iter = iter + 1
a_tile_regs = V_A_TILE_REGS[next_buf]
b_tile_regs = V_B_TILE_REGS[next_buf]
k.emit(s_clause(simm16=len(a_tile_regs) + len(b_tile_regs) - 1)) # 12 loads total: 4 A + 8 B
# A tile: 4 ds_load_b64
for i, vdst in enumerate(V_A_TILE_REGS):
a_off = (i & 1) * 8 + (i >> 1) * 64 + iter * LDS_A_STRIDE
for i, vdst in enumerate(a_tile_regs):
a_off = (i & 1) * 8 + (i >> 1) * 64 + next_iter * LDS_A_STRIDE
k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_A_BASE], offset0=a_off & 0xFF, offset1=a_off >> 8))
# B tile: 8 ds_load_b64
for i, vdst in enumerate(V_B_TILE_REGS):
b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 + iter * LDS_B_STRIDE
for i, vdst in enumerate(b_tile_regs):
b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 + next_iter * LDS_B_STRIDE
k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_B_BASE], offset0=b_off & 0xFF, offset1=b_off >> 8))
# Issue global prefetch (first 6 iterations only)
@ -483,12 +504,18 @@ def build_kernel(arch='gfx1100'):
k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1]))
k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1]))
# 64 dual FMACs
k.waitcnt(lgkm=0)
k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
# Wait for current buffer's loads to complete
# iter 0-6: 12 loads for next iteration are in flight, wait for the other 12 (lgkm=12)
# iter 7: no loads in flight, wait for all (lgkm=0)
k.waitcnt(lgkm=0 if iter == 7 else 12)
# 64 dual FMACs using current buffer
if not NO_ALU:
fmac_pattern = FMAC_PATTERN[buf]
k.emit(s_clause(simm16=len(fmac_pattern)-1))
for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(fmac_pattern):
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
# wait for all global stores to finish
# then sync the warp so it's safe to store local