mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
lds_double
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5afdcc79b |
1 changed files with 52 additions and 25 deletions
|
|
@ -47,8 +47,9 @@ V_GLOBAL_B_ADDR = 178 # global memory B prefetch address
|
|||
|
||||
# LDS tile register destinations - SEPARATE from DATA to avoid overlap
|
||||
# A on banks 2-3, B on banks 0-1 to avoid bank conflicts in VOPD
|
||||
V_A_TILE_REGS = [130, 134, 138, 142] # A tile: banks 2,2,2,2 (130%4=2, etc.)
|
||||
V_B_TILE_REGS = [132, 136, 140, 144, 148, 152, 156, 160] # B tile: banks 0,0,0,0,0,0,0,0
|
||||
# Double-buffered: buffer 0 (v130-v161), buffer 1 (v194-v225)
|
||||
V_A_TILE_REGS = [[130, 134, 138, 142], [194, 198, 202, 206]] # A tile: banks 2,2,2,2 (mod 4 = 2)
|
||||
V_B_TILE_REGS = [[132, 136, 140, 144, 148, 152, 156, 160], [192, 196, 200, 204, 208, 212, 216, 220]] # B tile: banks 0,0,0,0 (mod 4 = 0)
|
||||
|
||||
# =============================================================================
|
||||
# Named register assignments (SGPRs)
|
||||
|
|
@ -98,10 +99,8 @@ FMAC_PAIR_ORDER = [
|
|||
(0,4),(0,5),(1,5),(1,4), (2,4),(2,5),(3,5),(3,6), (0,6),(0,7),(1,7),(1,6), (2,6),(2,7),(3,7),(3,0),
|
||||
]
|
||||
|
||||
def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None):
|
||||
def derive_fmac_pattern(acc_grid, a_tile_regs, b_tile_regs):
|
||||
"""Generate 64 dual FMAC ops from accumulator grid with optimized iteration order."""
|
||||
if a_tile_regs is None: a_tile_regs = V_A_TILE_REGS
|
||||
if b_tile_regs is None: b_tile_regs = V_B_TILE_REGS
|
||||
pattern = []
|
||||
for idx, (a_pair, b_pair) in enumerate(FMAC_PAIR_ORDER):
|
||||
a_even, a_odd = a_pair * 2, a_pair * 2 + 1
|
||||
|
|
@ -119,8 +118,8 @@ def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None):
|
|||
a_base+1, b_base, a_base, b_base+1))
|
||||
return pattern
|
||||
|
||||
# Derived: 64 dual FMAC operations
|
||||
FMAC_PATTERN = derive_fmac_pattern(ACC_GRID)
|
||||
# Derived: 64 dual FMAC operations for each buffer
|
||||
FMAC_PATTERN = [derive_fmac_pattern(ACC_GRID, V_A_TILE_REGS[i], V_B_TILE_REGS[i]) for i in range(2)]
|
||||
|
||||
def derive_permute_swaps(acc_grid, out_regs):
|
||||
"""Derive swap sequence to permute accumulators from FMAC layout to output order.
|
||||
|
|
@ -218,7 +217,7 @@ class Kernel:
|
|||
('user_sgpr_kernarg_segment_ptr', 1), ('user_sgpr_dispatch_id', 0), ('user_sgpr_private_segment_size', 0),
|
||||
('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0),
|
||||
('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0),
|
||||
('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 192),
|
||||
('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 222),
|
||||
('next_free_sgpr', 16), ('float_round_mode_32', 0), ('float_round_mode_16_64', 0),
|
||||
('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3), ('dx10_clamp', 1), ('ieee_mode', 1),
|
||||
('fp16_overflow', 0), ('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0),
|
||||
|
|
@ -236,7 +235,7 @@ class Kernel:
|
|||
f' .group_segment_fixed_size: {lds_size}', ' .kernarg_segment_align: 8',
|
||||
' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 128', ' .name: kernel',
|
||||
' .private_segment_fixed_size: 0', ' .sgpr_count: 60', ' .symbol: kernel.kd',
|
||||
' .vgpr_count: 192', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
|
||||
' .vgpr_count: 222', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
|
||||
'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata'])
|
||||
|
||||
|
||||
|
|
@ -426,7 +425,7 @@ def build_kernel(arch='gfx1100'):
|
|||
# MAIN GEMM LOOP
|
||||
# ===========================================================================
|
||||
|
||||
NO_DS, NO_GLOBAL = getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
|
||||
NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
|
||||
|
||||
k.label('LOOP_INC')
|
||||
k.emit(s_add_i32(s[S_LOOP_CTR], s[S_LOOP_CTR], 8))
|
||||
|
|
@ -440,9 +439,10 @@ def build_kernel(arch='gfx1100'):
|
|||
|
||||
if not NO_GLOBAL:
|
||||
# Advance prefetch pointers (VGPR)
|
||||
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
|
||||
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
|
||||
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
|
||||
|
||||
"""
|
||||
# Advance prefetch pointers (SGPRs, 64-bit adds)
|
||||
k.emit(s_clause(simm16=31))
|
||||
for i in range(8):
|
||||
|
|
@ -451,6 +451,7 @@ def build_kernel(arch='gfx1100'):
|
|||
for i in range(8):
|
||||
k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_PREFETCH_A+i*2], 0x20))
|
||||
k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_PREFETCH_A+i*2+1], 0))
|
||||
"""
|
||||
|
||||
# do the fetch
|
||||
for vdst, saddr_lo in INIT_PREFETCH:
|
||||
|
|
@ -463,18 +464,38 @@ def build_kernel(arch='gfx1100'):
|
|||
k.waitcnt(lgkm=0)
|
||||
k.emit(s_barrier())
|
||||
|
||||
# Load initial tiles for iter=0 into buffer 0
|
||||
if not NO_DS:
|
||||
a_tile_regs = V_A_TILE_REGS[0]
|
||||
b_tile_regs = V_B_TILE_REGS[0]
|
||||
k.emit(s_clause(simm16=len(a_tile_regs) + len(b_tile_regs) - 1))
|
||||
for i, vdst in enumerate(a_tile_regs):
|
||||
a_off = (i & 1) * 8 + (i >> 1) * 64 # iter=0
|
||||
k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_A_BASE], offset0=a_off & 0xFF, offset1=a_off >> 8))
|
||||
for i, vdst in enumerate(b_tile_regs):
|
||||
b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 # iter=0
|
||||
k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_B_BASE], offset0=b_off & 0xFF, offset1=b_off >> 8))
|
||||
|
||||
# 8 inner loop iterations
|
||||
# Double-buffered inner loop: load next iteration's tiles while computing current
|
||||
# Buffer 0 used for even iterations, buffer 1 for odd iterations
|
||||
for iter in range(8):
|
||||
# Load A tile (4 pairs) and B tile (8 pairs) from LDS
|
||||
if not NO_DS:
|
||||
k.emit(s_clause(simm16=len(V_A_TILE_REGS) + len(V_B_TILE_REGS) - 1)) # 12 loads total: 4 A + 8 B
|
||||
buf = iter & 1 # current compute buffer
|
||||
next_buf = 1 - buf # next load buffer
|
||||
|
||||
# Load tiles for NEXT iteration into next_buf (except on last iteration)
|
||||
if not NO_DS and iter < 7:
|
||||
next_iter = iter + 1
|
||||
a_tile_regs = V_A_TILE_REGS[next_buf]
|
||||
b_tile_regs = V_B_TILE_REGS[next_buf]
|
||||
k.emit(s_clause(simm16=len(a_tile_regs) + len(b_tile_regs) - 1)) # 12 loads total: 4 A + 8 B
|
||||
# A tile: 4 ds_load_b64
|
||||
for i, vdst in enumerate(V_A_TILE_REGS):
|
||||
a_off = (i & 1) * 8 + (i >> 1) * 64 + iter * LDS_A_STRIDE
|
||||
for i, vdst in enumerate(a_tile_regs):
|
||||
a_off = (i & 1) * 8 + (i >> 1) * 64 + next_iter * LDS_A_STRIDE
|
||||
k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_A_BASE], offset0=a_off & 0xFF, offset1=a_off >> 8))
|
||||
# B tile: 8 ds_load_b64
|
||||
for i, vdst in enumerate(V_B_TILE_REGS):
|
||||
b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 + iter * LDS_B_STRIDE
|
||||
for i, vdst in enumerate(b_tile_regs):
|
||||
b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 + next_iter * LDS_B_STRIDE
|
||||
k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_B_BASE], offset0=b_off & 0xFF, offset1=b_off >> 8))
|
||||
|
||||
# Issue global prefetch (first 6 iterations only)
|
||||
|
|
@ -483,12 +504,18 @@ def build_kernel(arch='gfx1100'):
|
|||
k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1]))
|
||||
k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1]))
|
||||
|
||||
# 64 dual FMACs
|
||||
k.waitcnt(lgkm=0)
|
||||
k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
|
||||
for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
|
||||
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
|
||||
vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
|
||||
# Wait for current buffer's loads to complete
|
||||
# iter 0-6: 12 loads for next iteration are in flight, wait for the other 12 (lgkm=12)
|
||||
# iter 7: no loads in flight, wait for all (lgkm=0)
|
||||
k.waitcnt(lgkm=0 if iter == 7 else 12)
|
||||
|
||||
# 64 dual FMACs using current buffer
|
||||
if not NO_ALU:
|
||||
fmac_pattern = FMAC_PATTERN[buf]
|
||||
k.emit(s_clause(simm16=len(fmac_pattern)-1))
|
||||
for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(fmac_pattern):
|
||||
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
|
||||
vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
|
||||
|
||||
# wait for all global stores to finish
|
||||
# then sync the warp so it's safe to store local
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue