mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
tests
This commit is contained in:
parent
3bed227c14
commit
c6681d63bb
3 changed files with 774 additions and 34 deletions
299
error
Normal file
299
error
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
opened device PYTHON from pid:55710
|
||||
scheduled 11 kernels in 67.30 ms | CACHE MISS e9f86918 | 1438 uops in cache
|
||||
loading libc from /lib/x86_64-linux-gnu/libc.so.6
|
||||
loading hsa from /opt/rocm/lib/libhsa-runtime64.so
|
||||
loading comgr from /opt/rocm/lib/libamd_comgr.so
|
||||
loading comgr_3 from /opt/rocm/lib/libamd_comgr.so
|
||||
loading llvm from /lib/x86_64-linux-gnu/libLLVM-21.so
|
||||
loading libusb from /lib/x86_64-linux-gnu/libusb-1.0.so.0
|
||||
am 0000:03:00.0: AM_GFX initialized
|
||||
am 0000:03:00.0: AM_SDMA initialized
|
||||
am 0000:03:00.0: boot done
|
||||
AMDDevice: opening 0 with target (11, 0, 0) arch gfx1100
|
||||
opened device AMD from pid:55710
|
||||
[32m*** AMD 1[0m [33mcopy 4, AMD <- PYTHON [0m arg 2 mem 0.00 GB tm 4294.04us/ 4.29ms ( 0 GFLOPS 0|0 GB/s)
|
||||
[32m*** AMD 2[0m [33mcopy 8, AMD <- PYTHON [0m arg 2 mem 0.00 GB tm 73.79us/ 4.37ms ( 0 GFLOPS 0|0 GB/s)
|
||||
()
|
||||
.text
|
||||
.global E
|
||||
.type E,@function
|
||||
.p2align 8
|
||||
E:
|
||||
s_load_b64 s[6:7], s[0:1], 0
|
||||
s_waitcnt lgkmcnt(0)
|
||||
s_load_b64 s[8:9], s[0:1], 8
|
||||
s_waitcnt lgkmcnt(0)
|
||||
v_mov_b32 v3, 0
|
||||
global_load_b32 v4, v3, s[8:9]
|
||||
s_waitcnt vmcnt(0) lgkmcnt(0)
|
||||
v_add_nc_u32 v3, 1280, v4
|
||||
v_mov_b32 v4, 0
|
||||
global_store_b32 v4, v3, s[6:7]
|
||||
s_waitcnt vmcnt(0) lgkmcnt(0)
|
||||
s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
|
||||
s_endpgm
|
||||
s_code_end
|
||||
.size E, .-E
|
||||
|
||||
.rodata
|
||||
.global E.kd
|
||||
.type E.kd,STT_OBJECT
|
||||
.align 0x10
|
||||
.amdhsa_kernel E
|
||||
.amdhsa_group_segment_fixed_size 0
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.amdhsa_kernarg_size 16
|
||||
.amdhsa_next_free_vgpr 5
|
||||
.amdhsa_reserve_vcc 0
|
||||
.amdhsa_reserve_xnack_mask 0
|
||||
.amdhsa_next_free_sgpr 10
|
||||
.amdhsa_float_round_mode_32 0
|
||||
.amdhsa_float_round_mode_16_64 0
|
||||
.amdhsa_float_denorm_mode_32 3
|
||||
.amdhsa_float_denorm_mode_16_64 3
|
||||
.amdhsa_dx10_clamp 1
|
||||
.amdhsa_ieee_mode 1
|
||||
.amdhsa_fp16_overflow 0
|
||||
.amdhsa_workgroup_processor_mode 1
|
||||
.amdhsa_memory_ordered 1
|
||||
.amdhsa_forward_progress 0
|
||||
.amdhsa_enable_private_segment 0
|
||||
.amdhsa_system_sgpr_workgroup_id_x 1
|
||||
.amdhsa_system_sgpr_workgroup_id_y 1
|
||||
.amdhsa_system_sgpr_workgroup_id_z 1
|
||||
.amdhsa_system_sgpr_workgroup_info 0
|
||||
.amdhsa_system_vgpr_workitem_id 2
|
||||
.amdhsa_exception_fp_ieee_invalid_op 0
|
||||
.amdhsa_exception_fp_denorm_src 0
|
||||
.amdhsa_exception_fp_ieee_div_zero 0
|
||||
.amdhsa_exception_fp_ieee_overflow 0
|
||||
.amdhsa_exception_fp_ieee_underflow 0
|
||||
.amdhsa_exception_fp_ieee_inexact 0
|
||||
.amdhsa_exception_int_div_zero 0
|
||||
.amdhsa_user_sgpr_dispatch_ptr 0
|
||||
.amdhsa_user_sgpr_queue_ptr 0
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_user_sgpr_dispatch_id 0
|
||||
.amdhsa_user_sgpr_private_segment_size 0
|
||||
.amdhsa_wavefront_size32 1
|
||||
.amdhsa_uses_dynamic_stack 0
|
||||
.end_amdhsa_kernel
|
||||
.amdgpu_metadata
|
||||
amdhsa.kernels:
|
||||
- .args:
|
||||
- .address_space: global
|
||||
.name: buf_0
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.type_name: void*
|
||||
.value_kind: global_buffer
|
||||
- .address_space: global
|
||||
.name: buf_1
|
||||
.offset: 8
|
||||
.size: 8
|
||||
.type_name: void*
|
||||
.value_kind: global_buffer
|
||||
.group_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 16
|
||||
.language: OpenCL C
|
||||
.language_version:
|
||||
- 1
|
||||
- 2
|
||||
.max_flat_workgroup_size: 256
|
||||
.name: E
|
||||
.private_segment_fixed_size: 0
|
||||
.sgpr_count: 10
|
||||
.sgpr_spill_count: 0
|
||||
.symbol: E.kd
|
||||
.uses_dynamic_stack: false
|
||||
.vgpr_count: 5
|
||||
.vgpr_spill_count: 0
|
||||
.wavefront_size: 32
|
||||
amdhsa.target: amdgcn-amd-amdhsa--gfx1100
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 2
|
||||
.end_amdgpu_metadata
|
||||
[32m*** AMD 3[0m E[90m[0m arg 2 mem 0.00 GB tm 3.08us/ 4.37ms ( 0 GFLOPS 0|0 GB/s) ['uniform']
|
||||
more upcast axis : [(0, 1, 0, 4)]
|
||||
(Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=32))
|
||||
.text
|
||||
.global E_784_32_4
|
||||
.type E_784_32_4,@function
|
||||
.p2align 8
|
||||
E_784_32_4:
|
||||
s_load_b64 s[6:7], s[0:1], 0
|
||||
s_waitcnt lgkmcnt(0)
|
||||
s_load_b64 s[8:9], s[0:1], 8
|
||||
s_waitcnt lgkmcnt(0)
|
||||
s_load_b64 s[10:11], s[0:1], 16
|
||||
s_waitcnt lgkmcnt(0)
|
||||
v_mov_b32 v3, 392
|
||||
v_mov_b32 v4, 784
|
||||
v_mov_b32 v5, 466688986
|
||||
v_mov_b32 v6, 1065353216
|
||||
v_mov_b32 v7, 0
|
||||
global_load_b32 v8, v7, s[8:9]
|
||||
v_mov_b32 v7, 0
|
||||
global_load_b32 v9, v7, s[10:11]
|
||||
v_mov_b32 v7, 4
|
||||
global_load_b32 v10, v7, s[10:11]
|
||||
v_mov_b32 v7, s2
|
||||
v_and_b32 v4, 0x3ff, v0
|
||||
v_lshlrev_b32 v11, 7, v7
|
||||
v_lshlrev_b32 v12, 2, v4
|
||||
v_add_nc_u32 v4, v11, v12
|
||||
v_add_nc_u32 v13, 0xFFFF3C01, v4
|
||||
v_mov_b32 v13, v13
|
||||
v_add_nc_u32 v14, 0xFFFF3C02, v4
|
||||
v_mov_b32 v14, v14
|
||||
v_add_nc_u32 v15, 0xFFFF3C03, v4
|
||||
v_mov_b32 v15, v15
|
||||
v_add_nc_u32 v16, 0xFFFF3C04, v4
|
||||
v_mov_b32 v16, v16
|
||||
v_add_nc_u32 v17, 1, v4
|
||||
v_mov_b32 v17, v17
|
||||
v_add_nc_u32 v18, 2, v4
|
||||
v_mov_b32 v18, v18
|
||||
v_add_nc_u32 v19, 3, v4
|
||||
v_mov_b32 v19, v19
|
||||
v_add_nc_u32 v20, 4, v4
|
||||
v_mov_b32 v20, v20
|
||||
s_waitcnt vmcnt(0) lgkmcnt(0)
|
||||
v_add_nc_u32 v21, v8, v13
|
||||
v_add_nc_u32 v22, v8, v17
|
||||
v_add_nc_u32 v23, 0xFFFE77FF, v21
|
||||
v_add_nc_u32 v24, 0xFFFF3BFF, v21
|
||||
v_add_nc_u32 v21, 0xFFFE77FF, v22
|
||||
v_add_nc_u32 v25, 0xFFFF3BFF, v22
|
||||
v_add_nc_u32 v22, v23, v9
|
||||
v_add_nc_u32 v23, v24, v10
|
||||
v_add_nc_u32 v24, v21, v9
|
||||
v_add_nc_u32 v21, v25, v10
|
||||
v_add_nc_u32 v25, v22, v23
|
||||
v_add_nc_u32 v22, v24, v21
|
||||
v_lshlrev_b32 v24, 13, v23
|
||||
v_lshrrev_b32 v26, 19, v23
|
||||
v_add_nc_u32 v23, v24, v26
|
||||
v_xor_b32 v26, v25, v23
|
||||
v_add_nc_u32 v23, v25, v26
|
||||
v_lshlrev_b32 v25, 13, v21
|
||||
v_lshrrev_b32 v24, 19, v21
|
||||
v_add_nc_u32 v21, v25, v24
|
||||
v_xor_b32 v24, v22, v21
|
||||
v_add_nc_u32 v21, v22, v24
|
||||
v_lshlrev_b32 v22, 15, v26
|
||||
v_lshrrev_b32 v25, 17, v26
|
||||
v_add_nc_u32 v26, v22, v25
|
||||
v_xor_b32 v25, v23, v26
|
||||
v_add_nc_u32 v26, v23, v25
|
||||
v_lshlrev_b32 v23, 15, v24
|
||||
v_lshrrev_b32 v22, 17, v24
|
||||
v_add_nc_u32 v24, v23, v22
|
||||
v_xor_b32 v22, v21, v24
|
||||
v_add_nc_u32 v24, v21, v22
|
||||
v_lshlrev_b32 v21, 26, v25
|
||||
v_lshrrev_b32 v23, 6, v25
|
||||
v_add_nc_u32 v25, v21, v23
|
||||
v_xor_b32 v23, v26, v25
|
||||
v_add_nc_u32 v25, v26, v23
|
||||
v_lshlrev_b32 v26, 26, v22
|
||||
v_lshrrev_b32 v21, 6, v22
|
||||
v_add_nc_u32 v22, v26, v21
|
||||
v_xor_b32 v21, v24, v22
|
||||
v_add_nc_u32 v22, v24, v21
|
||||
v_add_nc_u32 v24, v25, v10
|
||||
v_add_nc_u32 v26, v22, v10
|
||||
v_lshlrev_b32 v27, 6, v23
|
||||
v_lshrrev_b32 v28, 26, v23
|
||||
v_add_nc_u32 v23, v27, v28
|
||||
v_xor_b32 v28, v9, v10
|
||||
v_xor_b32 v27, v25, v23
|
||||
v_xor_b32 v23, v28, v5
|
||||
v_add_nc_u32 v28, v27, v23
|
||||
v_add_nc_u32 v27, 1, v28
|
||||
v_add_nc_u32 v28, v24, v27
|
||||
v_lshlrev_b32 v24, 6, v21
|
||||
v_lshrrev_b32 v5, 26, v21
|
||||
v_add_nc_u32 v21, v24, v5
|
||||
v_xor_b32 v5, v22, v21
|
||||
v_add_nc_u32 v21, v5, v23
|
||||
v_add_nc_u32 v5, 1, v21
|
||||
v_add_nc_u32 v21, v26, v5
|
||||
v_lshlrev_b32 v26, 17, v27
|
||||
v_lshrrev_b32 v22, 15, v27
|
||||
v_add_nc_u32 v27, v26, v22
|
||||
v_xor_b32 v22, v28, v27
|
||||
v_add_nc_u32 v27, v28, v22
|
||||
v_lshlrev_b32 v28, 17, v5
|
||||
v_lshrrev_b32 v26, 15, v5
|
||||
v_add_nc_u32 v5, v28, v26
|
||||
v_xor_b32 v26, v21, v5
|
||||
v_add_nc_u32 v5, v21, v26
|
||||
v_lshlrev_b32 v21, 29, v22
|
||||
v_lshrrev_b32 v28, 3, v22
|
||||
v_add_nc_u32 v22, v21, v28
|
||||
v_xor_b32 v28, v27, v22
|
||||
v_add_nc_u32 v22, v27, v28
|
||||
v_lshlrev_b32 v27, 29, v26
|
||||
v_lshrrev_b32 v21, 3, v26
|
||||
v_add_nc_u32 v26, v27, v21
|
||||
v_xor_b32 v21, v5, v26
|
||||
v_add_nc_u32 v26, v5, v21
|
||||
v_lshlrev_b32 v5, 16, v28
|
||||
v_lshrrev_b32 v27, 16, v28
|
||||
v_add_nc_u32 v28, v5, v27
|
||||
v_xor_b32 v27, v22, v28
|
||||
v_add_nc_u32 v28, v22, v27
|
||||
v_lshlrev_b32 v22, 16, v21
|
||||
v_lshrrev_b32 v5, 16, v21
|
||||
v_add_nc_u32 v21, v22, v5
|
||||
v_xor_b32 v5, v26, v21
|
||||
v_add_nc_u32 v21, v26, v5
|
||||
v_add_nc_u32 v26, v28, v23
|
||||
v_add_nc_u32 v22, v21, v23
|
||||
v_lshlrev_b32 v24, 24, v27
|
||||
v_lshrrev_b32 v25, 8, v27
|
||||
v_add_nc_u32 v27, v24, v25
|
||||
v_xor_b32 v25, v28, v27
|
||||
v_add_nc_u32 v27, v25, v9
|
||||
v_add_nc_u32 v25, 1, v27
|
||||
v_add_nc_u32 v27, 1, v25
|
||||
v_add_nc_u32 v25, v26, v27
|
||||
v_lshlrev_b32 v26, 24, v5
|
||||
v_lshrrev_b32 v28, 8, v5
|
||||
v_add_nc_u32 v5, v26, v28
|
||||
v_xor_b32 v28, v21, v5
|
||||
v_add_nc_u32 v5, v28, v9
|
||||
v_add_nc_u32 v28, 1, v5
|
||||
v_add_nc_u32 v5, 1, v28
|
||||
v_add_nc_u32 v28, v22, v5
|
||||
v_lshlrev_b32 v22, 13, v27
|
||||
v_lshrrev_b32 v21, 19, v27
|
||||
v_add_nc_u32 v27, v22, v21
|
||||
v_xor_b32 v21, v25, v27
|
||||
v_add_nc_u32 v27, v25, v21
|
||||
v_lshlrev_b32 v25, 13, v5
|
||||
v_lshrrev_b32 v22, 19, v5
|
||||
v_add_nc_u32 v5, v25, v22
|
||||
v_xor_b32 v22, v28, v5
|
||||
v_add_nc_u32 v5, v28, v22
|
||||
v_lshlrev_b32 v28, 15, v21
|
||||
v_lshrrev_b32 v25, 17, v21
|
||||
v_add_nc_u32 v21, v28, v25
|
||||
v_xor_b32 v25, v27, v21
|
||||
v_add_nc_u32 v21, v27, v25
|
||||
v_lshlrev_b32 v27, 15, v22
|
||||
v_lshrrev_b32 v28, 17, v22
|
||||
v_add_nc_u32 v22, v27, v28
|
||||
v_xor_b32 v28, v5, v22
|
||||
v_add_nc_u32 v22, v5, v28
|
||||
v_lshlrev_b32 v5, 26, v25
|
||||
v_lshrrev_b32 v27, 6, v25
|
||||
v_add_nc_u32 v25, v5, v27
|
||||
v_xor_b32 v27, v21, v25
|
||||
v_add_nc_u32 v25, v21, v27
|
||||
v_lshlrev_b32 v21, 26, v28
|
||||
v_lshrrev_b32 v5, 6, v28
|
||||
v_add_nc_u32 v
|
||||
299
grep
Normal file
299
grep
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
opened device PYTHON from pid:55710
|
||||
scheduled 11 kernels in 67.30 ms | CACHE MISS e9f86918 | 1438 uops in cache
|
||||
loading libc from /lib/x86_64-linux-gnu/libc.so.6
|
||||
loading hsa from /opt/rocm/lib/libhsa-runtime64.so
|
||||
loading comgr from /opt/rocm/lib/libamd_comgr.so
|
||||
loading comgr_3 from /opt/rocm/lib/libamd_comgr.so
|
||||
loading llvm from /lib/x86_64-linux-gnu/libLLVM-21.so
|
||||
loading libusb from /lib/x86_64-linux-gnu/libusb-1.0.so.0
|
||||
am 0000:03:00.0: AM_GFX initialized
|
||||
am 0000:03:00.0: AM_SDMA initialized
|
||||
am 0000:03:00.0: boot done
|
||||
AMDDevice: opening 0 with target (11, 0, 0) arch gfx1100
|
||||
opened device AMD from pid:55710
|
||||
[32m*** AMD 1[0m [33mcopy 4, AMD <- PYTHON [0m arg 2 mem 0.00 GB tm 4294.04us/ 4.29ms ( 0 GFLOPS 0|0 GB/s)
|
||||
[32m*** AMD 2[0m [33mcopy 8, AMD <- PYTHON [0m arg 2 mem 0.00 GB tm 73.79us/ 4.37ms ( 0 GFLOPS 0|0 GB/s)
|
||||
()
|
||||
.text
|
||||
.global E
|
||||
.type E,@function
|
||||
.p2align 8
|
||||
E:
|
||||
s_load_b64 s[6:7], s[0:1], 0
|
||||
s_waitcnt lgkmcnt(0)
|
||||
s_load_b64 s[8:9], s[0:1], 8
|
||||
s_waitcnt lgkmcnt(0)
|
||||
v_mov_b32 v3, 0
|
||||
global_load_b32 v4, v3, s[8:9]
|
||||
s_waitcnt vmcnt(0) lgkmcnt(0)
|
||||
v_add_nc_u32 v3, 1280, v4
|
||||
v_mov_b32 v4, 0
|
||||
global_store_b32 v4, v3, s[6:7]
|
||||
s_waitcnt vmcnt(0) lgkmcnt(0)
|
||||
s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
|
||||
s_endpgm
|
||||
s_code_end
|
||||
.size E, .-E
|
||||
|
||||
.rodata
|
||||
.global E.kd
|
||||
.type E.kd,STT_OBJECT
|
||||
.align 0x10
|
||||
.amdhsa_kernel E
|
||||
.amdhsa_group_segment_fixed_size 0
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.amdhsa_kernarg_size 16
|
||||
.amdhsa_next_free_vgpr 5
|
||||
.amdhsa_reserve_vcc 0
|
||||
.amdhsa_reserve_xnack_mask 0
|
||||
.amdhsa_next_free_sgpr 10
|
||||
.amdhsa_float_round_mode_32 0
|
||||
.amdhsa_float_round_mode_16_64 0
|
||||
.amdhsa_float_denorm_mode_32 3
|
||||
.amdhsa_float_denorm_mode_16_64 3
|
||||
.amdhsa_dx10_clamp 1
|
||||
.amdhsa_ieee_mode 1
|
||||
.amdhsa_fp16_overflow 0
|
||||
.amdhsa_workgroup_processor_mode 1
|
||||
.amdhsa_memory_ordered 1
|
||||
.amdhsa_forward_progress 0
|
||||
.amdhsa_enable_private_segment 0
|
||||
.amdhsa_system_sgpr_workgroup_id_x 1
|
||||
.amdhsa_system_sgpr_workgroup_id_y 1
|
||||
.amdhsa_system_sgpr_workgroup_id_z 1
|
||||
.amdhsa_system_sgpr_workgroup_info 0
|
||||
.amdhsa_system_vgpr_workitem_id 2
|
||||
.amdhsa_exception_fp_ieee_invalid_op 0
|
||||
.amdhsa_exception_fp_denorm_src 0
|
||||
.amdhsa_exception_fp_ieee_div_zero 0
|
||||
.amdhsa_exception_fp_ieee_overflow 0
|
||||
.amdhsa_exception_fp_ieee_underflow 0
|
||||
.amdhsa_exception_fp_ieee_inexact 0
|
||||
.amdhsa_exception_int_div_zero 0
|
||||
.amdhsa_user_sgpr_dispatch_ptr 0
|
||||
.amdhsa_user_sgpr_queue_ptr 0
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_user_sgpr_dispatch_id 0
|
||||
.amdhsa_user_sgpr_private_segment_size 0
|
||||
.amdhsa_wavefront_size32 1
|
||||
.amdhsa_uses_dynamic_stack 0
|
||||
.end_amdhsa_kernel
|
||||
.amdgpu_metadata
|
||||
amdhsa.kernels:
|
||||
- .args:
|
||||
- .address_space: global
|
||||
.name: buf_0
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.type_name: void*
|
||||
.value_kind: global_buffer
|
||||
- .address_space: global
|
||||
.name: buf_1
|
||||
.offset: 8
|
||||
.size: 8
|
||||
.type_name: void*
|
||||
.value_kind: global_buffer
|
||||
.group_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 16
|
||||
.language: OpenCL C
|
||||
.language_version:
|
||||
- 1
|
||||
- 2
|
||||
.max_flat_workgroup_size: 256
|
||||
.name: E
|
||||
.private_segment_fixed_size: 0
|
||||
.sgpr_count: 10
|
||||
.sgpr_spill_count: 0
|
||||
.symbol: E.kd
|
||||
.uses_dynamic_stack: false
|
||||
.vgpr_count: 5
|
||||
.vgpr_spill_count: 0
|
||||
.wavefront_size: 32
|
||||
amdhsa.target: amdgcn-amd-amdhsa--gfx1100
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 2
|
||||
.end_amdgpu_metadata
|
||||
[32m*** AMD 3[0m E[90m[0m arg 2 mem 0.00 GB tm 3.08us/ 4.37ms ( 0 GFLOPS 0|0 GB/s) ['uniform']
|
||||
more upcast axis : [(0, 1, 0, 4)]
|
||||
(Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=32))
|
||||
.text
|
||||
.global E_784_32_4
|
||||
.type E_784_32_4,@function
|
||||
.p2align 8
|
||||
E_784_32_4:
|
||||
s_load_b64 s[6:7], s[0:1], 0
|
||||
s_waitcnt lgkmcnt(0)
|
||||
s_load_b64 s[8:9], s[0:1], 8
|
||||
s_waitcnt lgkmcnt(0)
|
||||
s_load_b64 s[10:11], s[0:1], 16
|
||||
s_waitcnt lgkmcnt(0)
|
||||
v_mov_b32 v3, 392
|
||||
v_mov_b32 v4, 784
|
||||
v_mov_b32 v5, 466688986
|
||||
v_mov_b32 v6, 1065353216
|
||||
v_mov_b32 v7, 0
|
||||
global_load_b32 v8, v7, s[8:9]
|
||||
v_mov_b32 v7, 0
|
||||
global_load_b32 v9, v7, s[10:11]
|
||||
v_mov_b32 v7, 4
|
||||
global_load_b32 v10, v7, s[10:11]
|
||||
v_mov_b32 v7, s2
|
||||
v_and_b32 v4, 0x3ff, v0
|
||||
v_lshlrev_b32 v11, 7, v7
|
||||
v_lshlrev_b32 v12, 2, v4
|
||||
v_add_nc_u32 v4, v11, v12
|
||||
v_add_nc_u32 v13, 0xFFFF3C01, v4
|
||||
v_mov_b32 v13, v13
|
||||
v_add_nc_u32 v14, 0xFFFF3C02, v4
|
||||
v_mov_b32 v14, v14
|
||||
v_add_nc_u32 v15, 0xFFFF3C03, v4
|
||||
v_mov_b32 v15, v15
|
||||
v_add_nc_u32 v16, 0xFFFF3C04, v4
|
||||
v_mov_b32 v16, v16
|
||||
v_add_nc_u32 v17, 1, v4
|
||||
v_mov_b32 v17, v17
|
||||
v_add_nc_u32 v18, 2, v4
|
||||
v_mov_b32 v18, v18
|
||||
v_add_nc_u32 v19, 3, v4
|
||||
v_mov_b32 v19, v19
|
||||
v_add_nc_u32 v20, 4, v4
|
||||
v_mov_b32 v20, v20
|
||||
s_waitcnt vmcnt(0) lgkmcnt(0)
|
||||
v_add_nc_u32 v21, v8, v13
|
||||
v_add_nc_u32 v22, v8, v17
|
||||
v_add_nc_u32 v23, 0xFFFE77FF, v21
|
||||
v_add_nc_u32 v24, 0xFFFF3BFF, v21
|
||||
v_add_nc_u32 v21, 0xFFFE77FF, v22
|
||||
v_add_nc_u32 v25, 0xFFFF3BFF, v22
|
||||
v_add_nc_u32 v22, v23, v9
|
||||
v_add_nc_u32 v23, v24, v10
|
||||
v_add_nc_u32 v24, v21, v9
|
||||
v_add_nc_u32 v21, v25, v10
|
||||
v_add_nc_u32 v25, v22, v23
|
||||
v_add_nc_u32 v22, v24, v21
|
||||
v_lshlrev_b32 v24, 13, v23
|
||||
v_lshrrev_b32 v26, 19, v23
|
||||
v_add_nc_u32 v23, v24, v26
|
||||
v_xor_b32 v26, v25, v23
|
||||
v_add_nc_u32 v23, v25, v26
|
||||
v_lshlrev_b32 v25, 13, v21
|
||||
v_lshrrev_b32 v24, 19, v21
|
||||
v_add_nc_u32 v21, v25, v24
|
||||
v_xor_b32 v24, v22, v21
|
||||
v_add_nc_u32 v21, v22, v24
|
||||
v_lshlrev_b32 v22, 15, v26
|
||||
v_lshrrev_b32 v25, 17, v26
|
||||
v_add_nc_u32 v26, v22, v25
|
||||
v_xor_b32 v25, v23, v26
|
||||
v_add_nc_u32 v26, v23, v25
|
||||
v_lshlrev_b32 v23, 15, v24
|
||||
v_lshrrev_b32 v22, 17, v24
|
||||
v_add_nc_u32 v24, v23, v22
|
||||
v_xor_b32 v22, v21, v24
|
||||
v_add_nc_u32 v24, v21, v22
|
||||
v_lshlrev_b32 v21, 26, v25
|
||||
v_lshrrev_b32 v23, 6, v25
|
||||
v_add_nc_u32 v25, v21, v23
|
||||
v_xor_b32 v23, v26, v25
|
||||
v_add_nc_u32 v25, v26, v23
|
||||
v_lshlrev_b32 v26, 26, v22
|
||||
v_lshrrev_b32 v21, 6, v22
|
||||
v_add_nc_u32 v22, v26, v21
|
||||
v_xor_b32 v21, v24, v22
|
||||
v_add_nc_u32 v22, v24, v21
|
||||
v_add_nc_u32 v24, v25, v10
|
||||
v_add_nc_u32 v26, v22, v10
|
||||
v_lshlrev_b32 v27, 6, v23
|
||||
v_lshrrev_b32 v28, 26, v23
|
||||
v_add_nc_u32 v23, v27, v28
|
||||
v_xor_b32 v28, v9, v10
|
||||
v_xor_b32 v27, v25, v23
|
||||
v_xor_b32 v23, v28, v5
|
||||
v_add_nc_u32 v28, v27, v23
|
||||
v_add_nc_u32 v27, 1, v28
|
||||
v_add_nc_u32 v28, v24, v27
|
||||
v_lshlrev_b32 v24, 6, v21
|
||||
v_lshrrev_b32 v5, 26, v21
|
||||
v_add_nc_u32 v21, v24, v5
|
||||
v_xor_b32 v5, v22, v21
|
||||
v_add_nc_u32 v21, v5, v23
|
||||
v_add_nc_u32 v5, 1, v21
|
||||
v_add_nc_u32 v21, v26, v5
|
||||
v_lshlrev_b32 v26, 17, v27
|
||||
v_lshrrev_b32 v22, 15, v27
|
||||
v_add_nc_u32 v27, v26, v22
|
||||
v_xor_b32 v22, v28, v27
|
||||
v_add_nc_u32 v27, v28, v22
|
||||
v_lshlrev_b32 v28, 17, v5
|
||||
v_lshrrev_b32 v26, 15, v5
|
||||
v_add_nc_u32 v5, v28, v26
|
||||
v_xor_b32 v26, v21, v5
|
||||
v_add_nc_u32 v5, v21, v26
|
||||
v_lshlrev_b32 v21, 29, v22
|
||||
v_lshrrev_b32 v28, 3, v22
|
||||
v_add_nc_u32 v22, v21, v28
|
||||
v_xor_b32 v28, v27, v22
|
||||
v_add_nc_u32 v22, v27, v28
|
||||
v_lshlrev_b32 v27, 29, v26
|
||||
v_lshrrev_b32 v21, 3, v26
|
||||
v_add_nc_u32 v26, v27, v21
|
||||
v_xor_b32 v21, v5, v26
|
||||
v_add_nc_u32 v26, v5, v21
|
||||
v_lshlrev_b32 v5, 16, v28
|
||||
v_lshrrev_b32 v27, 16, v28
|
||||
v_add_nc_u32 v28, v5, v27
|
||||
v_xor_b32 v27, v22, v28
|
||||
v_add_nc_u32 v28, v22, v27
|
||||
v_lshlrev_b32 v22, 16, v21
|
||||
v_lshrrev_b32 v5, 16, v21
|
||||
v_add_nc_u32 v21, v22, v5
|
||||
v_xor_b32 v5, v26, v21
|
||||
v_add_nc_u32 v21, v26, v5
|
||||
v_add_nc_u32 v26, v28, v23
|
||||
v_add_nc_u32 v22, v21, v23
|
||||
v_lshlrev_b32 v24, 24, v27
|
||||
v_lshrrev_b32 v25, 8, v27
|
||||
v_add_nc_u32 v27, v24, v25
|
||||
v_xor_b32 v25, v28, v27
|
||||
v_add_nc_u32 v27, v25, v9
|
||||
v_add_nc_u32 v25, 1, v27
|
||||
v_add_nc_u32 v27, 1, v25
|
||||
v_add_nc_u32 v25, v26, v27
|
||||
v_lshlrev_b32 v26, 24, v5
|
||||
v_lshrrev_b32 v28, 8, v5
|
||||
v_add_nc_u32 v5, v26, v28
|
||||
v_xor_b32 v28, v21, v5
|
||||
v_add_nc_u32 v5, v28, v9
|
||||
v_add_nc_u32 v28, 1, v5
|
||||
v_add_nc_u32 v5, 1, v28
|
||||
v_add_nc_u32 v28, v22, v5
|
||||
v_lshlrev_b32 v22, 13, v27
|
||||
v_lshrrev_b32 v21, 19, v27
|
||||
v_add_nc_u32 v27, v22, v21
|
||||
v_xor_b32 v21, v25, v27
|
||||
v_add_nc_u32 v27, v25, v21
|
||||
v_lshlrev_b32 v25, 13, v5
|
||||
v_lshrrev_b32 v22, 19, v5
|
||||
v_add_nc_u32 v5, v25, v22
|
||||
v_xor_b32 v22, v28, v5
|
||||
v_add_nc_u32 v5, v28, v22
|
||||
v_lshlrev_b32 v28, 15, v21
|
||||
v_lshrrev_b32 v25, 17, v21
|
||||
v_add_nc_u32 v21, v28, v25
|
||||
v_xor_b32 v25, v27, v21
|
||||
v_add_nc_u32 v21, v27, v25
|
||||
v_lshlrev_b32 v27, 15, v22
|
||||
v_lshrrev_b32 v28, 17, v22
|
||||
v_add_nc_u32 v22, v27, v28
|
||||
v_xor_b32 v28, v5, v22
|
||||
v_add_nc_u32 v22, v5, v28
|
||||
v_lshlrev_b32 v5, 26, v25
|
||||
v_lshrrev_b32 v27, 6, v25
|
||||
v_add_nc_u32 v25, v5, v27
|
||||
v_xor_b32 v27, v21, v25
|
||||
v_add_nc_u32 v25, v21, v27
|
||||
v_lshlrev_b32 v21, 26, v28
|
||||
v_lshrrev_b32 v5, 6, v28
|
||||
v_add_nc_u32 v
|
||||
|
|
@ -107,6 +107,82 @@ def global_load(dest:str, addr:str, base:str, dt:DType) -> str:
|
|||
if dt.itemsize == 16: return f"global_load_b128 {dest}, {addr}, {base}"
|
||||
raise RuntimeError(f"Unsupported load dtype size: {dt.itemsize}")
|
||||
|
||||
def vgpr_mov(dest:str, src:str, dt:DType) -> list[str]:
  """Emit one v_mov_b32 per 32-bit register needed to copy `src` into `dest`.

  Each operand is either a single VGPR ("v7") or a VGPR range ("v[4:7]"). The number of moves is
  the largest of the destination width, the source width, and the number of 32-bit registers
  required to hold `dt.itemsize` bytes. Move i copies v(src_base+i) into v(dest_base+i).
  """
  def span(reg:str) -> tuple[int, int]:
    # (first register index, number of registers) for "vN" or "v[N:M]"
    if '[' not in reg: return int(reg[1:]), 1
    colon = reg.index(':')
    first, last = int(reg[2:colon]), int(reg[colon+1:-1])
    return first, last - first + 1
  (dst_first, dst_width), (src_first, src_width) = span(dest), span(src)
  # round the byte size up to whole 32-bit registers
  lanes = max(dst_width, src_width, (dt.itemsize + 3) // 4)
  return [f"v_mov_b32 v{dst_first+k}, v{src_first+k}" for k in range(lanes)]
|
||||
|
||||
def gated_load(ctx, x, idx, alt, gate, buf, index_op) -> list[str]:
  """Render a gated global load without touching the exec mask.

  The gate is turned into a lane mask in vcc_lo and copied to the `ctx.gated_sgpr` scalar register.
  The load's offset register is then overwritten with 0 in lanes where the gate is false, the load
  is issued for every lane, and after waiting for it each destination register is blended with the
  matching `alt` register using the saved mask. Sets `ctx.scratch_sgpr_used`.

  Returns the list of instruction strings. `idx` is accepted for the pattern-matcher signature
  but is not read here.
  """
  ctx.scratch_sgpr_used = True
  gate_reg = ctx.r[gate]
  # a VGPR gate is compared against 0; an SGPR gate is already a mask and is ANDed with exec_lo
  if gate_reg.startswith('v'): mask_instr = f"v_cmp_ne_u32 vcc_lo, {gate_reg}, 0"
  else: mask_instr = f"s_and_b32 vcc_lo, exec_lo, {gate_reg}"
  mask_sgpr = f"s{ctx.gated_sgpr}"
  offset_reg = ctx.r[index_op]
  out = [
    mask_instr,
    f"s_mov_b32 {mask_sgpr}, vcc_lo",  # keep a copy of the mask outside vcc_lo
    f"v_cndmask_b32 {offset_reg}, 0, {offset_reg}, vcc_lo",  # gate false -> offset 0
    global_load(ctx.r[x], offset_reg, ctx.r[buf], x.dtype),
    "s_waitcnt vmcnt(0)",
  ]
  loaded, fallback = ctx.r[x], ctx.r[alt]
  if '[' not in loaded:
    out.append(f"v_cndmask_b32 {loaded}, {fallback}, {loaded}, {mask_sgpr}")
    return out
  # v_cndmask_b32 is 32-bit only, so a v[n:m] destination is blended one register at a time
  first = int(loaded[2:loaded.index(':')])
  last = int(loaded[loaded.index(':')+1:-1])
  alt_first = int(fallback[2:fallback.index(':')]) if '[' in fallback else int(fallback[1:])
  out.extend(f"v_cndmask_b32 v{first+k}, v{alt_first+k}, v{first+k}, {mask_sgpr}" for k in range(last - first + 1))
  return out
|
||||
|
||||
def ds_read(dest:str, addr:str, dt:DType) -> str:
  """Return the LDS read instruction whose width matches `dt.itemsize` (1, 2, 4, 8 or 16 bytes).

  Raises RuntimeError for any other item size.
  """
  mnemonic = {1: "ds_read_u8", 2: "ds_read_u16", 4: "ds_read_b32", 8: "ds_read_b64", 16: "ds_read_b128"}.get(dt.itemsize)
  if mnemonic is None: raise RuntimeError(f"Unsupported LDS read dtype size: {dt.itemsize}")
  return f"{mnemonic} {dest}, {addr}"
|
||||
|
||||
def ds_write(addr:str, data:str, dt:DType) -> str:
  """Return the LDS write instruction whose width matches `dt.itemsize` (1, 2, 4, 8 or 16 bytes).

  Raises RuntimeError for any other item size.
  """
  mnemonic = {1: "ds_write_b8", 2: "ds_write_b16", 4: "ds_write_b32", 8: "ds_write_b64", 16: "ds_write_b128"}.get(dt.itemsize)
  if mnemonic is None: raise RuntimeError(f"Unsupported LDS write dtype size: {dt.itemsize}")
  return f"{mnemonic} {addr}, {data}"
|
||||
|
||||
def render_define_var(ctx, x):
  """Render DEFINE_VAR as a 32-bit scalar load from the kernarg pointer in s[0:1].

  When the variable was assigned an SGPR the load targets it directly and a single instruction
  string is returned. Otherwise the value is loaded into the `ctx.gated_sgpr` scalar register and
  copied into the assigned VGPR; that path returns a two-instruction list and sets
  `ctx.scratch_sgpr_used`.
  """
  target = ctx.r[x]
  if not target.startswith('s'):
    # VGPR destination: stage through the scratch SGPR, then move across
    ctx.scratch_sgpr_used = True
    staging = f"s{ctx.gated_sgpr}"
    return [f"s_load_b32 {staging}, s[0:1], {ctx.kernarg_offset[x]}", f"v_mov_b32 {target}, {staging}"]
  return f"s_load_b32 {target}, s[0:1], {ctx.kernarg_offset[x]}"
|
||||
|
||||
def render_const_64(ctx, x):
|
||||
"""Render 64-bit constant as two v_mov_b32 instructions"""
|
||||
reg = ctx.r[x]
|
||||
|
|
@ -120,6 +196,37 @@ def render_const_64(ctx, x):
|
|||
hi = (bits >> 32) & 0xFFFFFFFF
|
||||
return [f"v_mov_b32 v{reg_num}, 0x{lo:08X}", f"v_mov_b32 v{reg_num+1}, 0x{hi:08X}"]
|
||||
|
||||
def render_comparison(ctx, x, src0):
  """Render a comparison op via `ctx.code_for_op[x.op]`.

  The renderer callback receives the destination, the registers of every source of `x`, then the
  dtype of `src0` and its entry in `ctx.types`. If the result register is not a VGPR the compare
  writes it directly and the single instruction is returned. If it is a VGPR the compare targets
  vcc_lo and a v_cndmask_b32 turns the mask into 0/1 in that VGPR; a two-element list is returned.
  """
  dest = ctx.r[x]
  operands = [ctx.r[s] for s in x.src]
  src_dtype = src0.dtype
  type_name = ctx.types[src_dtype]
  to_vgpr = dest.startswith('v')
  cmp_instr = ctx.code_for_op[x.op]("vcc_lo" if to_vgpr else dest, *operands, src_dtype, type_name)
  if not to_vgpr: return cmp_instr
  # VGPR fallback: compare into vcc_lo, then select 1 where the mask is set and 0 elsewhere
  return [cmp_instr, f"v_cndmask_b32 {dest}, 0, 1, vcc_lo"]
|
||||
|
||||
def render_if(ctx, x):
  """Render IF, saving exec into a per-nesting-level SGPR so nested IFs do not clobber each other.

  The save register is `ctx.if_sgpr_base + current nesting depth`; it is pushed on
  `ctx.if_save_stack` (popped again by the matching ENDIF) and `ctx.max_if_depth` is updated.
  The emitted code ANDs the condition `ctx.r[x.src[0]]` with exec_lo into vcc_lo, applies it with
  s_and_saveexec_b32, and branches to `IF_END_<index of x in ctx.uops>` when no lane is left active.

  Raises:
    RuntimeError: if the nesting is so deep that the save register would be past s105. Nothing
      is pushed on the stack in that case.
  """
  ctx.scratch_sgpr_used = True
  save_sgpr = ctx.if_sgpr_base + len(ctx.if_save_stack)
  # s0..s105 are the addressable general SGPRs; scalar operand 106 is vcc_lo, so "s106" and above
  # is not a register. Fail here with a clear message instead of emitting unassemblable code.
  last_sgpr = 105
  if save_sgpr > last_sgpr:
    raise RuntimeError(f"IF nesting too deep: exec save register s{save_sgpr} exceeds s{last_sgpr}")
  ctx.if_save_stack.append(save_sgpr)
  ctx.max_if_depth = max(ctx.max_if_depth, len(ctx.if_save_stack))
  return [
    f"s_and_b32 vcc_lo, exec_lo, {ctx.r[x.src[0]]}",
    f"s_and_saveexec_b32 s{save_sgpr}, vcc_lo",
    f"s_cbranch_execz IF_END_{ctx.uops.index(x)}"]
|
||||
|
||||
def render_endif(ctx, x):
  """Render ENDIF: emit the IF's end label and restore exec from the SGPR saved by that IF.

  Pops the most recent entry of `ctx.if_save_stack`. The label is named after the position in
  `ctx.uops` of `x.src[0]`, which the label naming in render_if implies is the matching IF.
  """
  saved_exec = ctx.if_save_stack.pop()
  end_label = f"IF_END_{ctx.uops.index(x.src[0])}:"
  return [end_label, f"s_mov_b32 exec_lo, s{saved_exec}"]
|
||||
|
||||
def render_64bit_mul(ctx, x):
|
||||
"""Render 64-bit integer multiplication using scratch registers.
|
||||
For pattern (a * magic_const) used in division-by-multiplication.
|
||||
|
|
@ -263,17 +370,18 @@ string_rewrite = PatternMatcher([
|
|||
(UPat.cvar("x", dtypes.bool), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {1 if x.arg else 0}"),
|
||||
# 64-bit float constants need two mov instructions
|
||||
(UPat.cvar("x", dtypes.float64), render_const_64),
|
||||
# 64-bit integer constants: just use low 32 bits (sufficient for most patterns)
|
||||
(UPat.cvar("x", dtypes.long), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.int32)}"),
|
||||
(UPat.cvar("x", dtypes.ulong), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.uint32)}"),
|
||||
# 64-bit integer constants: use render_const_64 for pairs, single mov for scalar
|
||||
(UPat.cvar("x", dtypes.long), lambda ctx, x: render_const_64(ctx, x) if '[' in ctx.r[x] else f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.int32)}"),
|
||||
(UPat.cvar("x", dtypes.ulong), lambda ctx, x: render_const_64(ctx, x) if '[' in ctx.r[x] else f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, dtypes.uint32)}"),
|
||||
(UPat.cvar("x"), lambda ctx, x: f"v_mov_b32 {ctx.r[x]}, {render_val(x.arg, x.dtype)}"),
|
||||
# special registers
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: ctx.render_special(x)),
|
||||
# define global
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"s_load_b64 {ctx.r[x]}, s[0:1], {x.arg*8}"),
|
||||
# comparison ops
|
||||
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
|
||||
lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"s_load_b64 {ctx.r[x]}, s[0:1], {ctx.kernarg_offset[x]}"),
|
||||
# define var - load variable from kernarg buffer
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), render_define_var),
|
||||
# comparison ops - uses SGPR if available, falls back to VGPR with vcc_lo + v_cndmask_b32
|
||||
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)), render_comparison),
|
||||
# WHERE: if condition is SGPR, use directly; if VGPR, compare to 0 first to get VCC
|
||||
(UPat(Ops.WHERE, name="x", src=(UPat.var("cond"), UPat.var("true_val"), UPat.var("false_val"))),
|
||||
lambda ctx, x, cond, true_val, false_val: f"v_cndmask_b32 {ctx.r[x]}, {ctx.r[false_val]}, {ctx.r[true_val]}, {ctx.r[cond]}"
|
||||
|
|
@ -336,25 +444,19 @@ string_rewrite = PatternMatcher([
|
|||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_GLOBAL, name="buf"), UPat.var("idx")), name="index_op", allow_any_len=True), UPat.var("var"))),
|
||||
lambda ctx, idx, var, buf, index_op: global_store(ctx.r[index_op], ctx.r[var], ctx.r[buf], var.dtype)),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_GLOBAL, name="buf"), UPat.var("idx"), UPat.var("gate")), name="index_op"), UPat.var("alt")), allow_any_len=True),
|
||||
lambda ctx, x, idx, alt, gate, buf, index_op: [
|
||||
f"v_mov_b32 {ctx.r[x]}, {ctx.r[alt]}",
|
||||
# If gate is in VGPR, compare to get SGPR mask; if in SGPR, use directly
|
||||
f"v_cmp_ne_u32 vcc_lo, {ctx.r[gate]}, 0" if ctx.r[gate].startswith('v') else f"s_and_b32 vcc_lo, exec_lo, {ctx.r[gate]}",
|
||||
f"s_and_saveexec_b32 s{ctx.scratch_sgpr}, vcc_lo",
|
||||
global_load(ctx.r[x], ctx.r[index_op], ctx.r[buf], buf.dtype.base),
|
||||
f"s_mov_b32 exec_lo, s{ctx.scratch_sgpr}"]),
|
||||
lambda ctx, x, idx, alt, gate, buf, index_op: gated_load(ctx, x, idx, alt, gate, buf, index_op)),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_GLOBAL, name="buf"), UPat.var("idx")), name="index_op"),), allow_any_len=True),
|
||||
lambda ctx, x, idx, buf, index_op: global_load(ctx.r[x], ctx.r[index_op], ctx.r[buf], x.dtype)),
|
||||
# store / load for local memory (LDS) - DEFINE_LOCAL directly
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_LOCAL), UPat.var("idx")), name="index_op", allow_any_len=True), UPat.var("var"))),
|
||||
lambda ctx, idx, var, index_op: f"ds_write_b32 {ctx.r[index_op]}, {ctx.r[var]}"),
|
||||
lambda ctx, idx, var, index_op: ds_write(ctx.r[index_op], ctx.r[var], var.dtype)),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_LOCAL), UPat.var("idx")), name="index_op"),), allow_any_len=True),
|
||||
lambda ctx, x, idx, index_op: f"ds_read_b32 {ctx.r[x]}, {ctx.r[index_op]}"),
|
||||
lambda ctx, x, idx, index_op: ds_read(ctx.r[x], ctx.r[index_op], x.dtype)),
|
||||
# store / load for local memory (LDS) - DEFINE_LOCAL wrapped in AFTER
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL),), allow_any_len=True), UPat.var("idx")), name="index_op", allow_any_len=True), UPat.var("var"))),
|
||||
lambda ctx, idx, var, index_op: f"ds_write_b32 {ctx.r[index_op]}, {ctx.r[var]}"),
|
||||
lambda ctx, idx, var, index_op: ds_write(ctx.r[index_op], ctx.r[var], var.dtype)),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat(Ops.AFTER, src=(UPat(Ops.DEFINE_LOCAL),), allow_any_len=True), UPat.var("idx")), name="index_op"),), allow_any_len=True),
|
||||
lambda ctx, x, idx, index_op: f"ds_read_b32 {ctx.r[x]}, {ctx.r[index_op]}"),
|
||||
lambda ctx, x, idx, index_op: ds_read(ctx.r[x], ctx.r[index_op], x.dtype)),
|
||||
# simple
|
||||
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
|
||||
(UPat(Ops.RANGE, name="r"), lambda ctx, r: [
|
||||
|
|
@ -368,13 +470,8 @@ string_rewrite = PatternMatcher([
|
|||
f"s_cbranch_vccnz LOOP_{ctx.r[r][1:]}"]),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"),
|
||||
lambda ctx, x: []), # local memory is handled differently in RDNA
|
||||
(UPat(Ops.IF, name="x"), lambda ctx, x: [
|
||||
f"s_and_b32 vcc_lo, exec_lo, {ctx.r[x.src[0]]}",
|
||||
f"s_and_saveexec_b32 s{ctx.scratch_sgpr}, vcc_lo",
|
||||
f"s_cbranch_execz IF_END_{ctx.uops.index(x)}"]),
|
||||
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: [
|
||||
f"IF_END_{ctx.uops.index(x.src[0])}:",
|
||||
f"s_mov_b32 exec_lo, s{ctx.scratch_sgpr}"]),
|
||||
(UPat(Ops.IF, name="x"), render_if),
|
||||
(UPat(Ops.ENDIF, name="x"), render_endif),
|
||||
(UPat(Ops.BARRIER, name="x"), lambda ctx, x: "s_barrier"),
|
||||
])
|
||||
|
||||
|
|
@ -612,11 +709,20 @@ class RDNARenderer(Renderer):
|
|||
def render(self, uops:list[UOp]) -> str:
|
||||
kernel:list[str] = []
|
||||
bufs = []
|
||||
kernarg_offset: dict[UOp, int] = {} # Track kernarg offset for each DEFINE_GLOBAL/DEFINE_VAR
|
||||
current_kernarg_offset = 0
|
||||
|
||||
r: dict[UOp, str|list[str]] = {}
|
||||
self.r = r
|
||||
self.kernarg_offset = kernarg_offset
|
||||
self.uops = uops
|
||||
self.scratch_sgpr = 100 # scratch register for exec manipulation
|
||||
# Separate SGPRs for different exec save contexts to avoid collisions
|
||||
self.gated_sgpr = 100 # for gated_load (atomic save/restore, doesn't nest)
|
||||
self.if_sgpr_base = 101 # for IF/ENDIF (can nest, uses stack)
|
||||
self.if_save_stack: list[int] = [] # stack of IF save SGPRs
|
||||
self.max_if_depth = 0 # track max nesting depth for SGPR count
|
||||
self.scratch_sgpr_used = False # track if any scratch SGPRs are used
|
||||
MAX_SGPR = 100 # RDNA3 limit ~106, reserve some for scratch
|
||||
self.lds_size = 0 # track local memory (LDS) usage
|
||||
# Scratch VGPR will be allocated after we know how many VGPRs the kernel uses
|
||||
self.scratch_vgpr = -1 # will be set after register allocation
|
||||
|
|
@ -656,6 +762,9 @@ class RDNARenderer(Renderer):
|
|||
aliases[u] = u.src[0]
|
||||
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
|
||||
aliases[u] = u.src[0]
|
||||
# GEP on vector types aliases to source - element extraction shares the source's registers
|
||||
if u.op is Ops.GEP and isinstance(u.src[0].dtype, DType) and u.src[0].dtype.count > 1:
|
||||
aliases[u] = u.src[0]
|
||||
# LOAD/STORE/INDEX on REG addrspace alias to the DEFINE_REG
|
||||
if u.op in {Ops.INDEX, Ops.LOAD, Ops.STORE} and len(u.src) > 0:
|
||||
if isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.addrspace == AddrSpace.REG:
|
||||
|
|
@ -681,6 +790,15 @@ class RDNARenderer(Renderer):
|
|||
# Extend lifetime to end of loop
|
||||
last_use[uop] = max(last_use[uop], end_pos)
|
||||
|
||||
# Third pass: extend SPECIAL (thread ID) lifetimes to end of kernel
|
||||
# Thread IDs are used to compute addresses throughout the kernel, including at the very end.
|
||||
# Their derived values may be used after the SPECIAL itself is last referenced, so we must
|
||||
# keep SPECIAL registers alive for the entire kernel to prevent register reuse corruption.
|
||||
max_uop_pos = len(uops) - 1
|
||||
for u in uops:
|
||||
if u.op is Ops.SPECIAL:
|
||||
last_use[u] = max_uop_pos
|
||||
|
||||
# === REGISTER ALLOCATOR ===
|
||||
# Track free registers (available for reuse)
|
||||
free_vgprs: list[int] = []
|
||||
|
|
@ -759,6 +877,7 @@ class RDNARenderer(Renderer):
|
|||
reg_index_const_uses: dict[UOp, int] = defaultdict(int)
|
||||
add_mul_const_uses: dict[UOp, int] = defaultdict(int) # Constants used in ADD/MUL (can use literals)
|
||||
store_const_uses: set[UOp] = set() # Constants used in STORE (must have VGPR)
|
||||
vectorize_const_uses: set[UOp] = set() # Constants used in VECTORIZE sources (must have VGPR)
|
||||
for u in uops:
|
||||
for src in u.src:
|
||||
if src.op is Ops.CONST:
|
||||
|
|
@ -779,6 +898,11 @@ class RDNARenderer(Renderer):
|
|||
if u.op is Ops.STORE and len(u.src) >= 2:
|
||||
if u.src[1].op is Ops.CONST:
|
||||
store_const_uses.add(u.src[1])
|
||||
# Track constants used in VECTORIZE sources - these MUST have VGPRs for v_mov_b32
|
||||
if u.op is Ops.VECTORIZE:
|
||||
for src in u.src:
|
||||
if src.op is Ops.CONST:
|
||||
vectorize_const_uses.add(src)
|
||||
# Skip allocation for constants that are ONLY used in literal-allowed contexts
|
||||
skip_alloc_consts: set[UOp] = set()
|
||||
for const_uop, reg_uses in reg_index_const_uses.items():
|
||||
|
|
@ -786,7 +910,7 @@ class RDNARenderer(Renderer):
|
|||
skip_alloc_consts.add(const_uop)
|
||||
# Also skip constants only used in ADD/MUL (use literals instead)
|
||||
for const_uop, add_mul_uses in add_mul_const_uses.items():
|
||||
if add_mul_uses == const_use_count[const_uop] and const_uop not in store_const_uses:
|
||||
if add_mul_uses == const_use_count[const_uop] and const_uop not in store_const_uses and const_uop not in vectorize_const_uses:
|
||||
skip_alloc_consts.add(const_uop)
|
||||
|
||||
def free_dead_regs(pos: int):
|
||||
|
|
@ -869,14 +993,16 @@ class RDNARenderer(Renderer):
|
|||
# Only float64 needs pairs - int64/uint64 use special split hi/lo patterns
|
||||
return dtype == dtypes.float64
|
||||
|
||||
def alloc_sgpr(owner: UOp) -> str:
|
||||
def alloc_sgpr(owner: UOp) -> str|None:
|
||||
nonlocal next_sgpr, max_sgpr
|
||||
if free_sgprs:
|
||||
reg = free_sgprs.pop()
|
||||
else:
|
||||
elif next_sgpr < MAX_SGPR:
|
||||
reg = next_sgpr
|
||||
next_sgpr += 1
|
||||
max_sgpr = max(max_sgpr, next_sgpr)
|
||||
else:
|
||||
return None # SGPR exhausted, caller should fall back to VGPR
|
||||
sgpr_owner[reg] = owner
|
||||
schedule_sgpr_death(reg, owner)
|
||||
return f"s{reg}"
|
||||
|
|
@ -1211,6 +1337,8 @@ class RDNARenderer(Renderer):
|
|||
# Register allocation with liveness-based reuse
|
||||
if u.op is Ops.DEFINE_GLOBAL:
|
||||
r[u] = alloc_sgpr_pair(u)
|
||||
kernarg_offset[u] = current_kernarg_offset
|
||||
current_kernarg_offset += 8 # Pointers are 8 bytes
|
||||
bufs.append((f"data{u.arg}", u.dtype))
|
||||
elif u.op is Ops.DEFINE_LOCAL:
|
||||
# Local memory - DEFINE_LOCAL.dtype contains the LDS size in the ptr size
|
||||
|
|
@ -1219,7 +1347,10 @@ class RDNARenderer(Renderer):
|
|||
r[u] = u.arg # Store the offset (arg is the offset in bytes)
|
||||
continue
|
||||
elif u.op is Ops.DEFINE_VAR:
|
||||
r[u] = alloc_sgpr(u)
|
||||
sgpr = alloc_sgpr(u)
|
||||
kernarg_offset[u] = current_kernarg_offset
|
||||
current_kernarg_offset += 4 # Variables are 4 bytes (int32)
|
||||
r[u] = sgpr if sgpr else alloc_vgpr(u) # Fall back to VGPR if SGPRs exhausted
|
||||
bufs.append((u.arg[0], u.dtype))
|
||||
elif u.op is Ops.SPECIAL:
|
||||
r[u] = alloc_vgpr(u)
|
||||
|
|
@ -1280,7 +1411,9 @@ class RDNARenderer(Renderer):
|
|||
# Extend last_use to include all uses of this UOp
|
||||
last_use[original_owner] = max(last_use.get(original_owner, -1), last_use.get(u, i))
|
||||
continue # Skip rendering - already loaded
|
||||
r[u] = alloc_vgpr_pair(u) if needs_vgpr_pair(u.dtype) else alloc_vgpr(u)
|
||||
# For 64-bit types used in STORE, need a pair for global_store_b64
|
||||
needs_pair = needs_vgpr_pair(u.dtype) or (u in store_const_uses and u.dtype.itemsize == 8)
|
||||
r[u] = alloc_vgpr_pair(u) if needs_pair else alloc_vgpr(u)
|
||||
const_cache[const_key] = r[u]
|
||||
elif u.op is Ops.RANGE:
|
||||
r[u] = alloc_vgpr(u)
|
||||
|
|
@ -1317,7 +1450,8 @@ class RDNARenderer(Renderer):
|
|||
# Only direct comparison ops go to SGPR; other bool ops stay in VGPR
|
||||
# because their inputs might be in VGPR (from memory loads)
|
||||
if u.dtype == dtypes.bool and u.op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}:
|
||||
r[u] = alloc_sgpr(u) # comparison results go in SGPR
|
||||
sgpr = alloc_sgpr(u)
|
||||
r[u] = sgpr if sgpr is not None else alloc_vgpr(u) # fall back to VGPR if SGPRs exhausted
|
||||
elif isinstance(u.dtype, DType) and u.dtype.count > 1:
|
||||
# Vector types need contiguous register ranges - size based on total bytes
|
||||
vgpr_count = (u.dtype.itemsize + 3) // 4 # Round up to 32-bit chunks
|
||||
|
|
@ -1412,9 +1546,10 @@ class RDNARenderer(Renderer):
|
|||
if deferred_store_addr_vgpr is None:
|
||||
deferred_store_addr_vgpr = alloc_vgpr(index_op) # Only allocate once
|
||||
# Compute the byte offset - detect pattern SHL(ADD(base, const), shift) for inline recompute
|
||||
# Only recompute if idx is marked as RECOMPUTE_AT_STORE - otherwise it has a valid register
|
||||
if idx.op is Ops.CONST and idx.arg == 0:
|
||||
kernel.append(f"v_mov_b32 {deferred_store_addr_vgpr}, 0")
|
||||
elif idx.op is Ops.SHL and idx.src[0].op is Ops.ADD and idx.src[1].op is Ops.CONST:
|
||||
elif r.get(idx) == "RECOMPUTE_AT_STORE" and idx.op is Ops.SHL and idx.src[0].op is Ops.ADD and idx.src[1].op is Ops.CONST:
|
||||
# Pattern: SHL(ADD(base, offset), shift) - recompute inline to avoid holding SHL result
|
||||
add_op = idx.src[0]
|
||||
shift_val = idx.src[1].arg
|
||||
|
|
@ -1496,7 +1631,7 @@ class RDNARenderer(Renderer):
|
|||
free_vgprs.append(reg_num)
|
||||
|
||||
# Add wait after loads
|
||||
if u.op is Ops.DEFINE_GLOBAL:
|
||||
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}:
|
||||
kernel.append("s_waitcnt lgkmcnt(0)")
|
||||
|
||||
# Final waitcnt and end program - always wait to ensure stores complete
|
||||
|
|
@ -1514,4 +1649,11 @@ class RDNARenderer(Renderer):
|
|||
for m in re.finditer(r'v\[(\d+):(\d+)\]', line):
|
||||
actual_max_vgpr = max(actual_max_vgpr, int(m.group(2)) + 1)
|
||||
|
||||
# If scratch SGPRs were used, update max_sgpr to include them
|
||||
if self.scratch_sgpr_used:
|
||||
# gated_sgpr (s100) + IF stack SGPRs (s101, s102, ...)
|
||||
max_sgpr = max(max_sgpr, self.gated_sgpr + 1)
|
||||
if self.max_if_depth > 0:
|
||||
max_sgpr = max(max_sgpr, self.if_sgpr_base + self.max_if_depth)
|
||||
|
||||
return self.render_kernel(kernel, name, bufs, actual_max_vgpr, max_sgpr, uops)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue