fp8 gemm cleanup (#16607)

This commit is contained in:
qazal 2026-06-13 12:17:32 +08:00 committed by GitHub
commit 2e77bd01db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -346,20 +346,15 @@ __global__ __launch_bounds__(512, 2) void hk_fp8_gemm(bf16 *C_ptr, fp8e4m3 *A_pt
}
// apply x_scale * w_scale before bf16 store to prevent overflow
#if SCALE_MODE != 0
#if SCALE_MODE == 1
float scale = *x_scale_ptr;
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);
mul(cD, cD, scale);
#elif SCALE_MODE == 2
float scale = *w_scale_ptr;
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);
mul(cD, cD, scale);
#elif SCALE_MODE == 3
float scale = *x_scale_ptr * *w_scale_ptr;
#endif
mul(cA, cA, scale);
mul(cB, cB, scale);
mul(cC, cC, scale);