This commit is contained in:
George Hotz 2025-12-18 16:49:26 +00:00
commit c6681d63bb
3 changed files with 774 additions and 34 deletions

299
error Normal file
View file

@ -0,0 +1,299 @@
opened device PYTHON from pid:55710
scheduled 11 kernels in 67.30 ms | CACHE MISS e9f86918 | 1438 uops in cache
loading libc from /lib/x86_64-linux-gnu/libc.so.6
loading hsa from /opt/rocm/lib/libhsa-runtime64.so
loading comgr from /opt/rocm/lib/libamd_comgr.so
loading comgr_3 from /opt/rocm/lib/libamd_comgr.so
loading llvm from /lib/x86_64-linux-gnu/libLLVM-21.so
loading libusb from /lib/x86_64-linux-gnu/libusb-1.0.so.0
am 0000:03:00.0: AM_GFX initialized
am 0000:03:00.0: AM_SDMA initialized
am 0000:03:00.0: boot done
AMDDevice: opening 0 with target (11, 0, 0) arch gfx1100
opened device AMD from pid:55710
*** AMD 1 copy 4, AMD <- PYTHON  arg 2 mem 0.00 GB tm 4294.04us/ 4.29ms ( 0 GFLOPS 0|0 GB/s)
*** AMD 2 copy 8, AMD <- PYTHON  arg 2 mem 0.00 GB tm 73.79us/ 4.37ms ( 0 GFLOPS 0|0 GB/s)
()
.text
.global E
.type E,@function
.p2align 8
E:
s_load_b64 s[6:7], s[0:1], 0
s_waitcnt lgkmcnt(0)
s_load_b64 s[8:9], s[0:1], 8
s_waitcnt lgkmcnt(0)
v_mov_b32 v3, 0
global_load_b32 v4, v3, s[8:9]
s_waitcnt vmcnt(0) lgkmcnt(0)
v_add_nc_u32 v3, 1280, v4
v_mov_b32 v4, 0
global_store_b32 v4, v3, s[6:7]
s_waitcnt vmcnt(0) lgkmcnt(0)
s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
s_endpgm
s_code_end
.size E, .-E
.rodata
.global E.kd
.type E.kd,STT_OBJECT
.align 0x10
.amdhsa_kernel E
.amdhsa_group_segment_fixed_size 0
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 16
.amdhsa_next_free_vgpr 5
.amdhsa_reserve_vcc 0
.amdhsa_reserve_xnack_mask 0
.amdhsa_next_free_sgpr 10
.amdhsa_float_round_mode_32 0
.amdhsa_float_round_mode_16_64 0
.amdhsa_float_denorm_mode_32 3
.amdhsa_float_denorm_mode_16_64 3
.amdhsa_dx10_clamp 1
.amdhsa_ieee_mode 1
.amdhsa_fp16_overflow 0
.amdhsa_workgroup_processor_mode 1
.amdhsa_memory_ordered 1
.amdhsa_forward_progress 0
.amdhsa_enable_private_segment 0
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 1
.amdhsa_system_sgpr_workgroup_info 0
.amdhsa_system_vgpr_workitem_id 2
.amdhsa_exception_fp_ieee_invalid_op 0
.amdhsa_exception_fp_denorm_src 0
.amdhsa_exception_fp_ieee_div_zero 0
.amdhsa_exception_fp_ieee_overflow 0
.amdhsa_exception_fp_ieee_underflow 0
.amdhsa_exception_fp_ieee_inexact 0
.amdhsa_exception_int_div_zero 0
.amdhsa_user_sgpr_dispatch_ptr 0
.amdhsa_user_sgpr_queue_ptr 0
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_dispatch_id 0
.amdhsa_user_sgpr_private_segment_size 0
.amdhsa_wavefront_size32 1
.amdhsa_uses_dynamic_stack 0
.end_amdhsa_kernel
.amdgpu_metadata
amdhsa.kernels:
- .args:
- .address_space: global
.name: buf_0
.offset: 0
.size: 8
.type_name: void*
.value_kind: global_buffer
- .address_space: global
.name: buf_1
.offset: 8
.size: 8
.type_name: void*
.value_kind: global_buffer
.group_segment_fixed_size: 0
.kernarg_segment_align: 8
.kernarg_segment_size: 16
.language: OpenCL C
.language_version:
- 1
- 2
.max_flat_workgroup_size: 256
.name: E
.private_segment_fixed_size: 0
.sgpr_count: 10
.sgpr_spill_count: 0
.symbol: E.kd
.uses_dynamic_stack: false
.vgpr_count: 5
.vgpr_spill_count: 0
.wavefront_size: 32
amdhsa.target: amdgcn-amd-amdhsa--gfx1100
amdhsa.version:
- 1
- 2
.end_amdgpu_metadata
*** AMD 3 E arg 2 mem 0.00 GB tm 3.08us/ 4.37ms ( 0 GFLOPS 0|0 GB/s) ['uniform']
more upcast axis : [(0, 1, 0, 4)]
(Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=32))
.text
.global E_784_32_4
.type E_784_32_4,@function
.p2align 8
E_784_32_4:
s_load_b64 s[6:7], s[0:1], 0
s_waitcnt lgkmcnt(0)
s_load_b64 s[8:9], s[0:1], 8
s_waitcnt lgkmcnt(0)
s_load_b64 s[10:11], s[0:1], 16
s_waitcnt lgkmcnt(0)
v_mov_b32 v3, 392
v_mov_b32 v4, 784
v_mov_b32 v5, 466688986
v_mov_b32 v6, 1065353216
v_mov_b32 v7, 0
global_load_b32 v8, v7, s[8:9]
v_mov_b32 v7, 0
global_load_b32 v9, v7, s[10:11]
v_mov_b32 v7, 4
global_load_b32 v10, v7, s[10:11]
v_mov_b32 v7, s2
v_and_b32 v4, 0x3ff, v0
v_lshlrev_b32 v11, 7, v7
v_lshlrev_b32 v12, 2, v4
v_add_nc_u32 v4, v11, v12
v_add_nc_u32 v13, 0xFFFF3C01, v4
v_mov_b32 v13, v13
v_add_nc_u32 v14, 0xFFFF3C02, v4
v_mov_b32 v14, v14
v_add_nc_u32 v15, 0xFFFF3C03, v4
v_mov_b32 v15, v15
v_add_nc_u32 v16, 0xFFFF3C04, v4
v_mov_b32 v16, v16
v_add_nc_u32 v17, 1, v4
v_mov_b32 v17, v17
v_add_nc_u32 v18, 2, v4
v_mov_b32 v18, v18
v_add_nc_u32 v19, 3, v4
v_mov_b32 v19, v19
v_add_nc_u32 v20, 4, v4
v_mov_b32 v20, v20
s_waitcnt vmcnt(0) lgkmcnt(0)
v_add_nc_u32 v21, v8, v13
v_add_nc_u32 v22, v8, v17
v_add_nc_u32 v23, 0xFFFE77FF, v21
v_add_nc_u32 v24, 0xFFFF3BFF, v21
v_add_nc_u32 v21, 0xFFFE77FF, v22
v_add_nc_u32 v25, 0xFFFF3BFF, v22
v_add_nc_u32 v22, v23, v9
v_add_nc_u32 v23, v24, v10
v_add_nc_u32 v24, v21, v9
v_add_nc_u32 v21, v25, v10
v_add_nc_u32 v25, v22, v23
v_add_nc_u32 v22, v24, v21
v_lshlrev_b32 v24, 13, v23
v_lshrrev_b32 v26, 19, v23
v_add_nc_u32 v23, v24, v26
v_xor_b32 v26, v25, v23
v_add_nc_u32 v23, v25, v26
v_lshlrev_b32 v25, 13, v21
v_lshrrev_b32 v24, 19, v21
v_add_nc_u32 v21, v25, v24
v_xor_b32 v24, v22, v21
v_add_nc_u32 v21, v22, v24
v_lshlrev_b32 v22, 15, v26
v_lshrrev_b32 v25, 17, v26
v_add_nc_u32 v26, v22, v25
v_xor_b32 v25, v23, v26
v_add_nc_u32 v26, v23, v25
v_lshlrev_b32 v23, 15, v24
v_lshrrev_b32 v22, 17, v24
v_add_nc_u32 v24, v23, v22
v_xor_b32 v22, v21, v24
v_add_nc_u32 v24, v21, v22
v_lshlrev_b32 v21, 26, v25
v_lshrrev_b32 v23, 6, v25
v_add_nc_u32 v25, v21, v23
v_xor_b32 v23, v26, v25
v_add_nc_u32 v25, v26, v23
v_lshlrev_b32 v26, 26, v22
v_lshrrev_b32 v21, 6, v22
v_add_nc_u32 v22, v26, v21
v_xor_b32 v21, v24, v22
v_add_nc_u32 v22, v24, v21
v_add_nc_u32 v24, v25, v10
v_add_nc_u32 v26, v22, v10
v_lshlrev_b32 v27, 6, v23
v_lshrrev_b32 v28, 26, v23
v_add_nc_u32 v23, v27, v28
v_xor_b32 v28, v9, v10
v_xor_b32 v27, v25, v23
v_xor_b32 v23, v28, v5
v_add_nc_u32 v28, v27, v23
v_add_nc_u32 v27, 1, v28
v_add_nc_u32 v28, v24, v27
v_lshlrev_b32 v24, 6, v21
v_lshrrev_b32 v5, 26, v21
v_add_nc_u32 v21, v24, v5
v_xor_b32 v5, v22, v21
v_add_nc_u32 v21, v5, v23
v_add_nc_u32 v5, 1, v21
v_add_nc_u32 v21, v26, v5
v_lshlrev_b32 v26, 17, v27
v_lshrrev_b32 v22, 15, v27
v_add_nc_u32 v27, v26, v22
v_xor_b32 v22, v28, v27
v_add_nc_u32 v27, v28, v22
v_lshlrev_b32 v28, 17, v5
v_lshrrev_b32 v26, 15, v5
v_add_nc_u32 v5, v28, v26
v_xor_b32 v26, v21, v5
v_add_nc_u32 v5, v21, v26
v_lshlrev_b32 v21, 29, v22
v_lshrrev_b32 v28, 3, v22
v_add_nc_u32 v22, v21, v28
v_xor_b32 v28, v27, v22
v_add_nc_u32 v22, v27, v28
v_lshlrev_b32 v27, 29, v26
v_lshrrev_b32 v21, 3, v26
v_add_nc_u32 v26, v27, v21
v_xor_b32 v21, v5, v26
v_add_nc_u32 v26, v5, v21
v_lshlrev_b32 v5, 16, v28
v_lshrrev_b32 v27, 16, v28
v_add_nc_u32 v28, v5, v27
v_xor_b32 v27, v22, v28
v_add_nc_u32 v28, v22, v27
v_lshlrev_b32 v22, 16, v21
v_lshrrev_b32 v5, 16, v21
v_add_nc_u32 v21, v22, v5
v_xor_b32 v5, v26, v21
v_add_nc_u32 v21, v26, v5
v_add_nc_u32 v26, v28, v23
v_add_nc_u32 v22, v21, v23
v_lshlrev_b32 v24, 24, v27
v_lshrrev_b32 v25, 8, v27
v_add_nc_u32 v27, v24, v25
v_xor_b32 v25, v28, v27
v_add_nc_u32 v27, v25, v9
v_add_nc_u32 v25, 1, v27
v_add_nc_u32 v27, 1, v25
v_add_nc_u32 v25, v26, v27
v_lshlrev_b32 v26, 24, v5
v_lshrrev_b32 v28, 8, v5
v_add_nc_u32 v5, v26, v28
v_xor_b32 v28, v21, v5
v_add_nc_u32 v5, v28, v9
v_add_nc_u32 v28, 1, v5
v_add_nc_u32 v5, 1, v28
v_add_nc_u32 v28, v22, v5
v_lshlrev_b32 v22, 13, v27
v_lshrrev_b32 v21, 19, v27
v_add_nc_u32 v27, v22, v21
v_xor_b32 v21, v25, v27
v_add_nc_u32 v27, v25, v21
v_lshlrev_b32 v25, 13, v5
v_lshrrev_b32 v22, 19, v5
v_add_nc_u32 v5, v25, v22
v_xor_b32 v22, v28, v5
v_add_nc_u32 v5, v28, v22
v_lshlrev_b32 v28, 15, v21
v_lshrrev_b32 v25, 17, v21
v_add_nc_u32 v21, v28, v25
v_xor_b32 v25, v27, v21
v_add_nc_u32 v21, v27, v25
v_lshlrev_b32 v27, 15, v22
v_lshrrev_b32 v28, 17, v22
v_add_nc_u32 v22, v27, v28
v_xor_b32 v28, v5, v22
v_add_nc_u32 v22, v5, v28
v_lshlrev_b32 v5, 26, v25
v_lshrrev_b32 v27, 6, v25
v_add_nc_u32 v25, v5, v27
v_xor_b32 v27, v21, v25
v_add_nc_u32 v25, v21, v27
v_lshlrev_b32 v21, 26, v28
v_lshrrev_b32 v5, 6, v28
v_add_nc_u32 v

299
grep Normal file
View file

@ -0,0 +1,299 @@
opened device PYTHON from pid:55710
scheduled 11 kernels in 67.30 ms | CACHE MISS e9f86918 | 1438 uops in cache
loading libc from /lib/x86_64-linux-gnu/libc.so.6
loading hsa from /opt/rocm/lib/libhsa-runtime64.so
loading comgr from /opt/rocm/lib/libamd_comgr.so
loading comgr_3 from /opt/rocm/lib/libamd_comgr.so
loading llvm from /lib/x86_64-linux-gnu/libLLVM-21.so
loading libusb from /lib/x86_64-linux-gnu/libusb-1.0.so.0
am 0000:03:00.0: AM_GFX initialized
am 0000:03:00.0: AM_SDMA initialized
am 0000:03:00.0: boot done
AMDDevice: opening 0 with target (11, 0, 0) arch gfx1100
opened device AMD from pid:55710
*** AMD 1 copy 4, AMD <- PYTHON  arg 2 mem 0.00 GB tm 4294.04us/ 4.29ms ( 0 GFLOPS 0|0 GB/s)
*** AMD 2 copy 8, AMD <- PYTHON  arg 2 mem 0.00 GB tm 73.79us/ 4.37ms ( 0 GFLOPS 0|0 GB/s)
()
.text
.global E
.type E,@function
.p2align 8
E:
s_load_b64 s[6:7], s[0:1], 0
s_waitcnt lgkmcnt(0)
s_load_b64 s[8:9], s[0:1], 8
s_waitcnt lgkmcnt(0)
v_mov_b32 v3, 0
global_load_b32 v4, v3, s[8:9]
s_waitcnt vmcnt(0) lgkmcnt(0)
v_add_nc_u32 v3, 1280, v4
v_mov_b32 v4, 0
global_store_b32 v4, v3, s[6:7]
s_waitcnt vmcnt(0) lgkmcnt(0)
s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
s_endpgm
s_code_end
.size E, .-E
.rodata
.global E.kd
.type E.kd,STT_OBJECT
.align 0x10
.amdhsa_kernel E
.amdhsa_group_segment_fixed_size 0
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 16
.amdhsa_next_free_vgpr 5
.amdhsa_reserve_vcc 0
.amdhsa_reserve_xnack_mask 0
.amdhsa_next_free_sgpr 10
.amdhsa_float_round_mode_32 0
.amdhsa_float_round_mode_16_64 0
.amdhsa_float_denorm_mode_32 3
.amdhsa_float_denorm_mode_16_64 3
.amdhsa_dx10_clamp 1
.amdhsa_ieee_mode 1
.amdhsa_fp16_overflow 0
.amdhsa_workgroup_processor_mode 1
.amdhsa_memory_ordered 1
.amdhsa_forward_progress 0
.amdhsa_enable_private_segment 0
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 1
.amdhsa_system_sgpr_workgroup_info 0
.amdhsa_system_vgpr_workitem_id 2
.amdhsa_exception_fp_ieee_invalid_op 0
.amdhsa_exception_fp_denorm_src 0
.amdhsa_exception_fp_ieee_div_zero 0
.amdhsa_exception_fp_ieee_overflow 0
.amdhsa_exception_fp_ieee_underflow 0
.amdhsa_exception_fp_ieee_inexact 0
.amdhsa_exception_int_div_zero 0
.amdhsa_user_sgpr_dispatch_ptr 0
.amdhsa_user_sgpr_queue_ptr 0
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_dispatch_id 0
.amdhsa_user_sgpr_private_segment_size 0
.amdhsa_wavefront_size32 1
.amdhsa_uses_dynamic_stack 0
.end_amdhsa_kernel
.amdgpu_metadata
amdhsa.kernels:
- .args:
- .address_space: global
.name: buf_0
.offset: 0
.size: 8
.type_name: void*
.value_kind: global_buffer
- .address_space: global
.name: buf_1
.offset: 8
.size: 8
.type_name: void*
.value_kind: global_buffer
.group_segment_fixed_size: 0
.kernarg_segment_align: 8
.kernarg_segment_size: 16
.language: OpenCL C
.language_version:
- 1
- 2
.max_flat_workgroup_size: 256
.name: E
.private_segment_fixed_size: 0
.sgpr_count: 10
.sgpr_spill_count: 0
.symbol: E.kd
.uses_dynamic_stack: false
.vgpr_count: 5
.vgpr_spill_count: 0
.wavefront_size: 32
amdhsa.target: amdgcn-amd-amdhsa--gfx1100
amdhsa.version:
- 1
- 2
.end_amdgpu_metadata
*** AMD 3 E arg 2 mem 0.00 GB tm 3.08us/ 4.37ms ( 0 GFLOPS 0|0 GB/s) ['uniform']
more upcast axis : [(0, 1, 0, 4)]
(Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=32))
.text
.global E_784_32_4
.type E_784_32_4,@function
.p2align 8
E_784_32_4:
s_load_b64 s[6:7], s[0:1], 0
s_waitcnt lgkmcnt(0)
s_load_b64 s[8:9], s[0:1], 8
s_waitcnt lgkmcnt(0)
s_load_b64 s[10:11], s[0:1], 16
s_waitcnt lgkmcnt(0)
v_mov_b32 v3, 392
v_mov_b32 v4, 784
v_mov_b32 v5, 466688986
v_mov_b32 v6, 1065353216
v_mov_b32 v7, 0
global_load_b32 v8, v7, s[8:9]
v_mov_b32 v7, 0
global_load_b32 v9, v7, s[10:11]
v_mov_b32 v7, 4
global_load_b32 v10, v7, s[10:11]
v_mov_b32 v7, s2
v_and_b32 v4, 0x3ff, v0
v_lshlrev_b32 v11, 7, v7
v_lshlrev_b32 v12, 2, v4
v_add_nc_u32 v4, v11, v12
v_add_nc_u32 v13, 0xFFFF3C01, v4
v_mov_b32 v13, v13
v_add_nc_u32 v14, 0xFFFF3C02, v4
v_mov_b32 v14, v14
v_add_nc_u32 v15, 0xFFFF3C03, v4
v_mov_b32 v15, v15
v_add_nc_u32 v16, 0xFFFF3C04, v4
v_mov_b32 v16, v16
v_add_nc_u32 v17, 1, v4
v_mov_b32 v17, v17
v_add_nc_u32 v18, 2, v4
v_mov_b32 v18, v18
v_add_nc_u32 v19, 3, v4
v_mov_b32 v19, v19
v_add_nc_u32 v20, 4, v4
v_mov_b32 v20, v20
s_waitcnt vmcnt(0) lgkmcnt(0)
v_add_nc_u32 v21, v8, v13
v_add_nc_u32 v22, v8, v17
v_add_nc_u32 v23, 0xFFFE77FF, v21
v_add_nc_u32 v24, 0xFFFF3BFF, v21
v_add_nc_u32 v21, 0xFFFE77FF, v22
v_add_nc_u32 v25, 0xFFFF3BFF, v22
v_add_nc_u32 v22, v23, v9
v_add_nc_u32 v23, v24, v10
v_add_nc_u32 v24, v21, v9
v_add_nc_u32 v21, v25, v10
v_add_nc_u32 v25, v22, v23
v_add_nc_u32 v22, v24, v21
v_lshlrev_b32 v24, 13, v23
v_lshrrev_b32 v26, 19, v23
v_add_nc_u32 v23, v24, v26
v_xor_b32 v26, v25, v23
v_add_nc_u32 v23, v25, v26
v_lshlrev_b32 v25, 13, v21
v_lshrrev_b32 v24, 19, v21
v_add_nc_u32 v21, v25, v24
v_xor_b32 v24, v22, v21
v_add_nc_u32 v21, v22, v24
v_lshlrev_b32 v22, 15, v26
v_lshrrev_b32 v25, 17, v26
v_add_nc_u32 v26, v22, v25
v_xor_b32 v25, v23, v26
v_add_nc_u32 v26, v23, v25
v_lshlrev_b32 v23, 15, v24
v_lshrrev_b32 v22, 17, v24
v_add_nc_u32 v24, v23, v22
v_xor_b32 v22, v21, v24
v_add_nc_u32 v24, v21, v22
v_lshlrev_b32 v21, 26, v25
v_lshrrev_b32 v23, 6, v25
v_add_nc_u32 v25, v21, v23
v_xor_b32 v23, v26, v25
v_add_nc_u32 v25, v26, v23
v_lshlrev_b32 v26, 26, v22
v_lshrrev_b32 v21, 6, v22
v_add_nc_u32 v22, v26, v21
v_xor_b32 v21, v24, v22
v_add_nc_u32 v22, v24, v21
v_add_nc_u32 v24, v25, v10
v_add_nc_u32 v26, v22, v10
v_lshlrev_b32 v27, 6, v23
v_lshrrev_b32 v28, 26, v23
v_add_nc_u32 v23, v27, v28
v_xor_b32 v28, v9, v10
v_xor_b32 v27, v25, v23
v_xor_b32 v23, v28, v5
v_add_nc_u32 v28, v27, v23
v_add_nc_u32 v27, 1, v28
v_add_nc_u32 v28, v24, v27
v_lshlrev_b32 v24, 6, v21
v_lshrrev_b32 v5, 26, v21
v_add_nc_u32 v21, v24, v5
v_xor_b32 v5, v22, v21
v_add_nc_u32 v21, v5, v23
v_add_nc_u32 v5, 1, v21
v_add_nc_u32 v21, v26, v5
v_lshlrev_b32 v26, 17, v27
v_lshrrev_b32 v22, 15, v27
v_add_nc_u32 v27, v26, v22
v_xor_b32 v22, v28, v27
v_add_nc_u32 v27, v28, v22
v_lshlrev_b32 v28, 17, v5
v_lshrrev_b32 v26, 15, v5
v_add_nc_u32 v5, v28, v26
v_xor_b32 v26, v21, v5
v_add_nc_u32 v5, v21, v26
v_lshlrev_b32 v21, 29, v22
v_lshrrev_b32 v28, 3, v22
v_add_nc_u32 v22, v21, v28
v_xor_b32 v28, v27, v22
v_add_nc_u32 v22, v27, v28
v_lshlrev_b32 v27, 29, v26
v_lshrrev_b32 v21, 3, v26
v_add_nc_u32 v26, v27, v21
v_xor_b32 v21, v5, v26
v_add_nc_u32 v26, v5, v21
v_lshlrev_b32 v5, 16, v28
v_lshrrev_b32 v27, 16, v28
v_add_nc_u32 v28, v5, v27
v_xor_b32 v27, v22, v28
v_add_nc_u32 v28, v22, v27
v_lshlrev_b32 v22, 16, v21
v_lshrrev_b32 v5, 16, v21
v_add_nc_u32 v21, v22, v5
v_xor_b32 v5, v26, v21
v_add_nc_u32 v21, v26, v5
v_add_nc_u32 v26, v28, v23
v_add_nc_u32 v22, v21, v23
v_lshlrev_b32 v24, 24, v27
v_lshrrev_b32 v25, 8, v27
v_add_nc_u32 v27, v24, v25
v_xor_b32 v25, v28, v27
v_add_nc_u32 v27, v25, v9
v_add_nc_u32 v25, 1, v27
v_add_nc_u32 v27, 1, v25
v_add_nc_u32 v25, v26, v27
v_lshlrev_b32 v26, 24, v5
v_lshrrev_b32 v28, 8, v5
v_add_nc_u32 v5, v26, v28
v_xor_b32 v28, v21, v5
v_add_nc_u32 v5, v28, v9
v_add_nc_u32 v28, 1, v5
v_add_nc_u32 v5, 1, v28
v_add_nc_u32 v28, v22, v5
v_lshlrev_b32 v22, 13, v27
v_lshrrev_b32 v21, 19, v27
v_add_nc_u32 v27, v22, v21
v_xor_b32 v21, v25, v27
v_add_nc_u32 v27, v25, v21
v_lshlrev_b32 v25, 13, v5
v_lshrrev_b32 v22, 19, v5
v_add_nc_u32 v5, v25, v22
v_xor_b32 v22, v28, v5
v_add_nc_u32 v5, v28, v22
v_lshlrev_b32 v28, 15, v21
v_lshrrev_b32 v25, 17, v21
v_add_nc_u32 v21, v28, v25
v_xor_b32 v25, v27, v21
v_add_nc_u32 v21, v27, v25
v_lshlrev_b32 v27, 15, v22
v_lshrrev_b32 v28, 17, v22
v_add_nc_u32 v22, v27, v28
v_xor_b32 v28, v5, v22
v_add_nc_u32 v22, v5, v28
v_lshlrev_b32 v5, 26, v25
v_lshrrev_b32 v27, 6, v25
v_add_nc_u32 v25, v5, v27
v_xor_b32 v27, v21, v25
v_add_nc_u32 v25, v21, v27
v_lshlrev_b32 v21, 26, v28
v_lshrrev_b32 v5, 6, v28
v_add_nc_u32 v

View file

@ -107,6 +107,82 @@ def global_load(dest:str, addr:str, base:str, dt:DType) -> str:
if dt.itemsize == 16: return f"global_load_b128 {dest}, {addr}, {base}"
raise RuntimeError(f"Unsupported load dtype size: {dt.itemsize}")
def vgpr_mov(dest:str, src:str, dt:DType) -> list[str]:
"""Generate v_mov_b32 instructions for moving registers, handling vector types."""
def parse_reg(r:str) -> tuple[int, int]:
if '[' in r:
base = int(r[2:r.index(':')])
end = int(r[r.index(':')+1:-1])
return base, end - base + 1
else:
return int(r[1:]), 1
dest_base, dest_count = parse_reg(dest)
src_base, src_count = parse_reg(src)
vgpr_count = (dt.itemsize + 3) // 4
count = max(dest_count, src_count, vgpr_count)
return [f"v_mov_b32 v{dest_base+i}, v{src_base+i}" for i in range(count)]
def gated_load(ctx, x, idx, alt, gate, buf, index_op) -> list[str]:
"""Generate gated load using v_cndmask for address selection.
Instead of exec masking (which can still fault on invalid addresses in masked lanes),
we use v_cndmask to select a safe address (0) when gate is false, then unconditionally
load, and finally select between loaded value and alt based on gate."""
ctx.scratch_sgpr_used = True
result = []
# Get gate comparison result in vcc_lo, save to scratch_sgpr (vcc_lo may be clobbered)
if ctx.r[gate].startswith('v'):
result.append(f"v_cmp_ne_u32 vcc_lo, {ctx.r[gate]}, 0")
else:
result.append(f"s_and_b32 vcc_lo, exec_lo, {ctx.r[gate]}")
result.append(f"s_mov_b32 s{ctx.gated_sgpr}, vcc_lo") # save mask
# Select address: use computed address if gate is true, else use 0 (safe address)
addr_reg = ctx.r[index_op]
result.append(f"v_cndmask_b32 {addr_reg}, 0, {addr_reg}, vcc_lo")
# Unconditionally load (address is always valid now - 0 when gate is false)
result.append(global_load(ctx.r[x], addr_reg, ctx.r[buf], x.dtype))
# Wait for load to complete, then select result
result.append("s_waitcnt vmcnt(0)")
# v_cndmask_b32 only works on 32-bit registers - handle vector types component-wise
dest_reg = ctx.r[x]
alt_reg = ctx.r[alt]
if '[' in dest_reg:
# Vector register: v[n:m] -> extract base and count, do component-wise cndmask
base = int(dest_reg[2:dest_reg.index(':')])
end = int(dest_reg[dest_reg.index(':')+1:-1])
alt_base = int(alt_reg[2:alt_reg.index(':')]) if '[' in alt_reg else int(alt_reg[1:])
for i in range(end - base + 1):
result.append(f"v_cndmask_b32 v{base+i}, v{alt_base+i}, v{base+i}, s{ctx.gated_sgpr}")
else:
result.append(f"v_cndmask_b32 {dest_reg}, {alt_reg}, {dest_reg}, s{ctx.gated_sgpr}")
return result
def ds_read(dest:str, addr:str, dt:DType) -> str:
"""Generate LDS read instruction based on dtype size."""
if dt.itemsize == 1: return f"ds_read_u8 {dest}, {addr}"
if dt.itemsize == 2: return f"ds_read_u16 {dest}, {addr}"
if dt.itemsize == 4: return f"ds_read_b32 {dest}, {addr}"
if dt.itemsize == 8: return f"ds_read_b64 {dest}, {addr}"
if dt.itemsize == 16: return f"ds_read_b128 {dest}, {addr}"
raise RuntimeError(f"Unsupported LDS read dtype size: {dt.itemsize}")
def ds_write(addr:str, data:str, dt:DType) -> str:
"""Generate LDS write instruction based on dtype size."""
if dt.itemsize == 1: return f"ds_write_b8 {addr}, {data}"
if dt.itemsize == 2: return f"ds_write_b16 {addr}, {data}"
if dt.itemsize == 4: return f"ds_write_b32 {addr}, {data}"
if dt.itemsize == 8: return f"ds_write_b64 {addr}, {data}"
if dt.itemsize == 16: return f"ds_write_b128 {addr}, {data}"
raise RuntimeError(f"Unsupported LDS write dtype size: {dt.itemsize}")
def render_define_var(ctx, x):
"""Render DEFINE_VAR - load from kernarg buffer. Uses SGPR if available, else VGPR via scratch."""
if ctx.r[x].startswith('s'):
return f"s_load_b32 {ctx.r[x]}, s[0:1], {ctx.kernarg_offset[x]}"
# VGPR fallback - use gated_sgpr for the load, then move to VGPR
ctx.scratch_sgpr_used = True
return [f"s_load_b32 s{ctx.gated_sgpr}, s[0:1], {ctx.kernarg_offset[x]}",
f"v_mov_b32 {ctx.r[x]}, s{ctx.gated_sgpr}"]
def render_const_64(ctx, x):
"""Render 64-bit constant as two v_mov_b32 instructions"""
reg = ctx.r[x]
@ -120,6 +196,37 @@ def render_const_64(ctx, x):
hi = (bits >> 32) & 0xFFFFFFFF
return [f"v_mov_b32 v{reg_num}, 0x{lo:08X}", f"v_mov_b32 v{reg_num+1}, 0x{hi:08X}"]
def render_comparison(ctx, x, src0):
"""Render comparison op. If dest is SGPR, use directly. If VGPR (fallback), use vcc_lo + v_cndmask_b32."""
dest = ctx.r[x]
srcs = [ctx.r[v] for v in x.src]
dtype, typename = src0.dtype, ctx.types[src0.dtype]
cmp_instr = ctx.code_for_op[x.op]("vcc_lo" if dest.startswith('v') else dest, *srcs, dtype, typename)
if dest.startswith('v'):
# VGPR fallback: compare to vcc_lo, then convert to 0/1 in VGPR
return [cmp_instr, f"v_cndmask_b32 {dest}, 0, 1, vcc_lo"]
return cmp_instr
def render_if(ctx, x):
"""Render IF with stack-based exec save for proper nesting."""
ctx.scratch_sgpr_used = True
# Push a new SGPR for this IF level
save_sgpr = ctx.if_sgpr_base + len(ctx.if_save_stack)
ctx.if_save_stack.append(save_sgpr)
ctx.max_if_depth = max(ctx.max_if_depth, len(ctx.if_save_stack))
return [
f"s_and_b32 vcc_lo, exec_lo, {ctx.r[x.src[0]]}",
f"s_and_saveexec_b32 s{save_sgpr}, vcc_lo",
f"s_cbranch_execz IF_END_{ctx.uops.index(x)}"]
def render_endif(ctx, x):
"""Render ENDIF by popping the exec save stack."""
# Pop the SGPR for this IF level
save_sgpr = ctx.if_save_stack.pop()
return [
f"IF_END_{ctx.uops.index(x.src[0])}:",
f"s_mov_b32 exec_lo, s{save_sgpr}"]
def render_64bit_mul(ctx, x):
"""Render 64-bit integer multiplication using scratch registers.
For pattern (a * magic_const) used in division-by-multiplication.
@ -263,17 +370,18 @@ string_rewrite = PatternMatcher([
(UPat.cvar("x", dtypes.bool), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {1 if x.arg else 0}"),
# 64-bit float constants need two mov instructions
(UPat.cvar("x", dtypes.float64), render_const_64),
# 64-bit integer constants: just use low 32 bits (sufficient for most patterns)
(UPat.cvar("x", dtypes.long), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.int32)}"),
(UPat.cvar("x", dtypes.ulong), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.uint32)}"),
# 64-bit integer constants: use render_const_64 for pairs, single mov for scalar
(UPat.cvar("x", dtypes.long), lambda ctx, x: render_const_64(ctx, x) if '[' in ctx.r[x] else f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.int32)}"),
(UPat.cvar("x", dtypes.ulong), lambda ctx, x: render_const_64(ctx, x) if '[' in ctx.r[x] else f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.uint32)}"),
(UPat.cvar("x"), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, x.dtype)}"),
# special registers
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: ctx.render_special(x)),
# define global
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"s_load_b64 {ctx.r[x]}, s[0:1], {x.arg*8}"),
# comparison ops
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"s_load_b64 {ctx.r[x]}, s[0:1], {ctx.kernarg_offset[x]}"),
# define var - load variable from kernarg buffer
(UPat(Ops.DEFINE_VAR, name="x"), render_define_var),
# comparison ops - uses SGPR if available, falls back to VGPR with vcc_lo + v_cndmask_b32
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)), render_comparison),
# WHERE: if condition is SGPR, use directly; if VGPR, compare to 0 first to get VCC
(UPat(Ops.WHERE, name="x", src=(UPat.var("cond"), UPat.var("true_val"), UPat.var("false_val"))),
lambda ctx, x, cond, true_val, false_val: f"v_cndmask_b32 {ctx.r[x]}, {ctx.r[false_val]}, {ctx.r[true_val]}, {ctx.r[cond]}"
@ -336,25 +444,19 @@ string_rewrite = PatternMatcher([
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_GLOBAL, name="buf"), UPat.var("idx")), name="index_op", allow_any_len=True), UPat.var("var"))),
lambda ctx, idx, var, buf, index_op: global_store(ctx.r[index_op], ctx.r[var], ctx.r[buf], var.dtype)),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_GLOBAL, name="buf"), UPat.var("idx"), UPat.var("gate")), name="index_op"), UPat.var("alt")), allow_any_len=True),
lambda ctx, x, idx, alt, gate, buf, index_op: [
f"v_mov_b32 {ctx.r[x]}, {ctx.r[alt]}",
# If gate is in VGPR, compare to get SGPR mask; if in SGPR, use directly
f"v_cmp_ne_u32 vcc_lo, {ctx.r[gate]}, 0" if ctx.r[gate].startswith('v') else f"s_and_b32 vcc_lo, exec_lo, {ctx.r[gate]}",
f"s_and_saveexec_b32 s{ctx.scratch_sgpr}, vcc_lo",
global_load(ctx.r[x], ctx.r[index_op], ctx.r[buf], buf.dtype.base),
f"s_mov_b32 exec_lo, s{ctx.scratch_sgpr}"]),
lambda ctx, x, idx, alt, gate, buf, index_op: gated_load(ctx, x, idx, alt, gate, buf, index_op)),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_GLOBAL, name="buf"), UPat.var("idx")), name="index_op"),), allow_any_len=True),
lambda ctx, x, idx, buf, index_op: global_load(ctx.r[x], ctx.r[index_op], ctx.r[buf], x.dtype)),
# store / load for local memory (LDS) - DEFINE_LOCAL directly
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_LOCAL), UPat.var("idx")), name="index_op", allow_any_len=True), UPat.var("var"))),
lambda ctx, idx, var, index_op: f"ds_write_b32 {ctx.r[index_op]}, {ctx.r[var]}"),
lambda ctx, idx, var, index_op: ds_write(ctx.r[index_op], ctx.r[var], var.dtype)),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_LOCAL), UPat.var("idx")), name="index_op"),), allow_any_len=True),
lambda ctx, x, idx, index_op: f"ds_read_b32 {ctx.r[x]}, {ctx.r[index_op]}"),
lambda ctx, x, idx, index_op: ds_read(ctx.r[x], ctx.r[index_op], x.dtype)),
# store / load for local memory (LDS) - DEFINE_LOCAL wrapped in AFTER
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL),), allow_any_len=True), UPat.var("idx")), name="index_op", allow_any_len=True), UPat.var("var"))),
lambda ctx, idx, var, index_op: f"ds_write_b32 {ctx.r[index_op]}, {ctx.r[var]}"),
lambda ctx, idx, var, index_op: ds_write(ctx.r[index_op], ctx.r[var], var.dtype)),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL),), allow_any_len=True), UPat.var("idx")), name="index_op"),), allow_any_len=True),
lambda ctx, x, idx, index_op: f"ds_read_b32 {ctx.r[x]}, {ctx.r[index_op]}"),
lambda ctx, x, idx, index_op: ds_read(ctx.r[x], ctx.r[index_op], x.dtype)),
# simple
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
(UPat(Ops.RANGE, name="r"), lambda ctx, r: [
@ -368,13 +470,8 @@ string_rewrite = PatternMatcher([
f"s_cbranch_vccnz LOOP_{ctx.r[r][1:]}"]),
(UPat(Ops.DEFINE_LOCAL, name="x"),
lambda ctx, x: []), # local memory is handled differently in RDNA
(UPat(Ops.IF, name="x"), lambda ctx, x: [
f"s_and_b32 vcc_lo, exec_lo, {ctx.r[x.src[0]]}",
f"s_and_saveexec_b32 s{ctx.scratch_sgpr}, vcc_lo",
f"s_cbranch_execz IF_END_{ctx.uops.index(x)}"]),
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: [
f"IF_END_{ctx.uops.index(x.src[0])}:",
f"s_mov_b32 exec_lo, s{ctx.scratch_sgpr}"]),
(UPat(Ops.IF, name="x"), render_if),
(UPat(Ops.ENDIF, name="x"), render_endif),
(UPat(Ops.BARRIER, name="x"), lambda ctx, x: "s_barrier"),
])
@ -612,11 +709,20 @@ class RDNARenderer(Renderer):
def render(self, uops:list[UOp]) -> str:
kernel:list[str] = []
bufs = []
kernarg_offset: dict[UOp, int] = {} # Track kernarg offset for each DEFINE_GLOBAL/DEFINE_VAR
current_kernarg_offset = 0
r: dict[UOp, str|list[str]] = {}
self.r = r
self.kernarg_offset = kernarg_offset
self.uops = uops
self.scratch_sgpr = 100 # scratch register for exec manipulation
# Separate SGPRs for different exec save contexts to avoid collisions
self.gated_sgpr = 100 # for gated_load (atomic save/restore, doesn't nest)
self.if_sgpr_base = 101 # for IF/ENDIF (can nest, uses stack)
self.if_save_stack: list[int] = [] # stack of IF save SGPRs
self.max_if_depth = 0 # track max nesting depth for SGPR count
self.scratch_sgpr_used = False # track if any scratch SGPRs are used
MAX_SGPR = 100 # RDNA3 limit ~106, reserve some for scratch
self.lds_size = 0 # track local memory (LDS) usage
# Scratch VGPR will be allocated after we know how many VGPRs the kernel uses
self.scratch_vgpr = -1 # will be set after register allocation
@ -656,6 +762,9 @@ class RDNARenderer(Renderer):
aliases[u] = u.src[0]
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
aliases[u] = u.src[0]
# GEP on vector types aliases to source - element extraction shares the source's registers
if u.op is Ops.GEP and isinstance(u.src[0].dtype, DType) and u.src[0].dtype.count > 1:
aliases[u] = u.src[0]
# LOAD/STORE/INDEX on REG addrspace alias to the DEFINE_REG
if u.op in {Ops.INDEX, Ops.LOAD, Ops.STORE} and len(u.src) > 0:
if isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.addrspace == AddrSpace.REG:
@ -681,6 +790,15 @@ class RDNARenderer(Renderer):
# Extend lifetime to end of loop
last_use[uop] = max(last_use[uop], end_pos)
# Third pass: extend SPECIAL (thread ID) lifetimes to end of kernel
# Thread IDs are used to compute addresses throughout the kernel, including at the very end.
# Their derived values may be used after the SPECIAL itself is last referenced, so we must
# keep SPECIAL registers alive for the entire kernel to prevent register reuse corruption.
max_uop_pos = len(uops) - 1
for u in uops:
if u.op is Ops.SPECIAL:
last_use[u] = max_uop_pos
# === REGISTER ALLOCATOR ===
# Track free registers (available for reuse)
free_vgprs: list[int] = []
@ -759,6 +877,7 @@ class RDNARenderer(Renderer):
reg_index_const_uses: dict[UOp, int] = defaultdict(int)
add_mul_const_uses: dict[UOp, int] = defaultdict(int) # Constants used in ADD/MUL (can use literals)
store_const_uses: set[UOp] = set() # Constants used in STORE (must have VGPR)
vectorize_const_uses: set[UOp] = set() # Constants used in VECTORIZE sources (must have VGPR)
for u in uops:
for src in u.src:
if src.op is Ops.CONST:
@ -779,6 +898,11 @@ class RDNARenderer(Renderer):
if u.op is Ops.STORE and len(u.src) >= 2:
if u.src[1].op is Ops.CONST:
store_const_uses.add(u.src[1])
# Track constants used in VECTORIZE sources - these MUST have VGPRs for v_mov_b32
if u.op is Ops.VECTORIZE:
for src in u.src:
if src.op is Ops.CONST:
vectorize_const_uses.add(src)
# Skip allocation for constants that are ONLY used in literal-allowed contexts
skip_alloc_consts: set[UOp] = set()
for const_uop, reg_uses in reg_index_const_uses.items():
@ -786,7 +910,7 @@ class RDNARenderer(Renderer):
skip_alloc_consts.add(const_uop)
# Also skip constants only used in ADD/MUL (use literals instead)
for const_uop, add_mul_uses in add_mul_const_uses.items():
if add_mul_uses == const_use_count[const_uop] and const_uop not in store_const_uses:
if add_mul_uses == const_use_count[const_uop] and const_uop not in store_const_uses and const_uop not in vectorize_const_uses:
skip_alloc_consts.add(const_uop)
def free_dead_regs(pos: int):
@ -869,14 +993,16 @@ class RDNARenderer(Renderer):
# Only float64 needs pairs - int64/uint64 use special split hi/lo patterns
return dtype == dtypes.float64
def alloc_sgpr(owner: UOp) -> str:
def alloc_sgpr(owner: UOp) -> str|None:
nonlocal next_sgpr, max_sgpr
if free_sgprs:
reg = free_sgprs.pop()
else:
elif next_sgpr < MAX_SGPR:
reg = next_sgpr
next_sgpr += 1
max_sgpr = max(max_sgpr, next_sgpr)
else:
return None # SGPR exhausted, caller should fall back to VGPR
sgpr_owner[reg] = owner
schedule_sgpr_death(reg, owner)
return f"s{reg}"
@ -1211,6 +1337,8 @@ class RDNARenderer(Renderer):
# Register allocation with liveness-based reuse
if u.op is Ops.DEFINE_GLOBAL:
r[u] = alloc_sgpr_pair(u)
kernarg_offset[u] = current_kernarg_offset
current_kernarg_offset += 8 # Pointers are 8 bytes
bufs.append((f"data{u.arg}", u.dtype))
elif u.op is Ops.DEFINE_LOCAL:
# Local memory - DEFINE_LOCAL.dtype contains the LDS size in the ptr size
@ -1219,7 +1347,10 @@ class RDNARenderer(Renderer):
r[u] = u.arg # Store the offset (arg is the offset in bytes)
continue
elif u.op is Ops.DEFINE_VAR:
r[u] = alloc_sgpr(u)
sgpr = alloc_sgpr(u)
kernarg_offset[u] = current_kernarg_offset
current_kernarg_offset += 4 # Variables are 4 bytes (int32)
r[u] = sgpr if sgpr else alloc_vgpr(u) # Fall back to VGPR if SGPRs exhausted
bufs.append((u.arg[0], u.dtype))
elif u.op is Ops.SPECIAL:
r[u] = alloc_vgpr(u)
@ -1280,7 +1411,9 @@ class RDNARenderer(Renderer):
# Extend last_use to include all uses of this UOp
last_use[original_owner] = max(last_use.get(original_owner, -1), last_use.get(u, i))
continue # Skip rendering - already loaded
r[u] = alloc_vgpr_pair(u) if needs_vgpr_pair(u.dtype) else alloc_vgpr(u)
# For 64-bit types used in STORE, need a pair for global_store_b64
needs_pair = needs_vgpr_pair(u.dtype) or (u in store_const_uses and u.dtype.itemsize == 8)
r[u] = alloc_vgpr_pair(u) if needs_pair else alloc_vgpr(u)
const_cache[const_key] = r[u]
elif u.op is Ops.RANGE:
r[u] = alloc_vgpr(u)
@ -1317,7 +1450,8 @@ class RDNARenderer(Renderer):
# Only direct comparison ops go to SGPR; other bool ops stay in VGPR
# because their inputs might be in VGPR (from memory loads)
if u.dtype == dtypes.bool and u.op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}:
r[u] = alloc_sgpr(u) # comparison results go in SGPR
sgpr = alloc_sgpr(u)
r[u] = sgpr if sgpr is not None else alloc_vgpr(u) # fall back to VGPR if SGPRs exhausted
elif isinstance(u.dtype, DType) and u.dtype.count > 1:
# Vector types need contiguous register ranges - size based on total bytes
vgpr_count = (u.dtype.itemsize + 3) // 4 # Round up to 32-bit chunks
@ -1412,9 +1546,10 @@ class RDNARenderer(Renderer):
if deferred_store_addr_vgpr is None:
deferred_store_addr_vgpr = alloc_vgpr(index_op) # Only allocate once
# Compute the byte offset - detect pattern SHL(ADD(base, const), shift) for inline recompute
# Only recompute if idx is marked as RECOMPUTE_AT_STORE - otherwise it has a valid register
if idx.op is Ops.CONST and idx.arg == 0:
kernel.append(f"v_mov_b32 {deferred_store_addr_vgpr}, 0")
elif idx.op is Ops.SHL and idx.src[0].op is Ops.ADD and idx.src[1].op is Ops.CONST:
elif r.get(idx) == "RECOMPUTE_AT_STORE" and idx.op is Ops.SHL and idx.src[0].op is Ops.ADD and idx.src[1].op is Ops.CONST:
# Pattern: SHL(ADD(base, offset), shift) - recompute inline to avoid holding SHL result
add_op = idx.src[0]
shift_val = idx.src[1].arg
@ -1496,7 +1631,7 @@ class RDNARenderer(Renderer):
free_vgprs.append(reg_num)
# Add wait after loads
if u.op is Ops.DEFINE_GLOBAL:
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}:
kernel.append("s_waitcnt lgkmcnt(0)")
# Final waitcnt and end program - always wait to ensure stores complete
@ -1514,4 +1649,11 @@ class RDNARenderer(Renderer):
for m in re.finditer(r'v\[(\d+):(\d+)\]', line):
actual_max_vgpr = max(actual_max_vgpr, int(m.group(2)) + 1)
# If scratch SGPRs were used, update max_sgpr to include them
if self.scratch_sgpr_used:
# gated_sgpr (s100) + IF stack SGPRs (s101, s102, ...)
max_sgpr = max(max_sgpr, self.gated_sgpr + 1)
if self.max_if_depth > 0:
max_sgpr = max(max_sgpr, self.if_sgpr_base + self.max_if_depth)
return self.render_kernel(kernel, name, bufs, actual_max_vgpr, max_sgpr, uops)