mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
update hipkittens (#16544)
This commit is contained in:
parent
4e2e2e9956
commit
5ef30005fa
17 changed files with 1010 additions and 126 deletions
|
|
@ -133,7 +133,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
constexpr int neg_inf_v = 29;
|
||||
// Move -inf to VGPR neg_inf_v
|
||||
kittens::macros::clobber_gpr<neg_inf_v>();
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000);
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000);
|
||||
|
||||
art<float, DOT_SLICE_QO, WARP_SIZE_KV, col_l, rt_16x16_s, P_ranges> P_ij; // 16 registers
|
||||
art<float, DOT_SLICE_QO, WARP_SIZE_KV, col_l, rt_16x16_s, dP_ranges> dP_ij; // 16 registers
|
||||
|
|
@ -330,7 +330,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 0
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -588,7 +588,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 1
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -845,7 +845,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 2
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -1101,7 +1101,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 3
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -1371,7 +1371,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 0
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -1632,7 +1632,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 1
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -1889,7 +1889,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 2
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -2145,7 +2145,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 3
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -2410,7 +2410,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 0
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -2671,7 +2671,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 1
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -2927,7 +2927,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 2
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
@ -3183,7 +3183,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
|
|||
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
|
||||
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
|
||||
// Dot slice 3
|
||||
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
|
||||
// If the query position is less than the key position, set P_ij to -inf
|
||||
if (q_pos < k_pos) {
|
||||
mov<neg_inf_v>(P_ij);
|
||||
|
|
|
|||
|
|
@ -216,6 +216,59 @@ template<> __device__ inline bf16_2 relu::op<bf16_2>(const bf16_2 &x) { return _
|
|||
template<> __device__ inline half relu::op<half> (const half &x ) { return __hmax(x, base_types::constants<half>::zero()); }
|
||||
template<> __device__ inline half_2 relu::op<half_2>(const half_2 &x) { return half_2{__hmax(x.x, base_types::constants<half>::zero()),
|
||||
__hmax(x.y, base_types::constants<half>::zero())}; }
|
||||
|
||||
|
||||
constexpr float SQRT_2_OVER_PI = 0.7978845608028654f;
|
||||
constexpr float GELU_COEFF = 0.044715f;
|
||||
constexpr float GELU_INNER_COEFF = GELU_COEFF * SQRT_2_OVER_PI;
|
||||
constexpr float DGELU_COEFF = 3.0f * GELU_COEFF * SQRT_2_OVER_PI;
|
||||
|
||||
static __device__ inline float fast_tanh(float x) {
|
||||
x = fmaxf(fminf(x, 20.f), -20.f);
|
||||
float e2x = __builtin_amdgcn_exp2f(x * 2.8853900817779268f);
|
||||
return (e2x - 1.0f) * __frcp_rn(e2x + 1.0f);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gaussian Error Linear Unit (GELU) activation.
|
||||
*
|
||||
* Computes the GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The GELU activation applied to the input.
|
||||
*/
|
||||
struct gelu {
|
||||
template<typename T> static __device__ inline T op(const T &x);
|
||||
};
|
||||
template<> __device__ inline float gelu::op<float>(const float &x) {
|
||||
return x * (0.5f + 0.5f * fast_tanh(x * (SQRT_2_OVER_PI + GELU_INNER_COEFF * x * x)));
|
||||
}
|
||||
template<> __device__ inline float2 gelu::op<float2>(const float2 &x) {
|
||||
return float2{gelu::op<float>(x.x), gelu::op<float>(x.y)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Derivative of the GELU activation.
|
||||
*
|
||||
* Computes the derivative of the GELU approximation with respect to the input.
|
||||
*
|
||||
* @tparam T The data type of the input and output values.
|
||||
* @param x[in] The input value.
|
||||
* @return The derivative of GELU evaluated at the input.
|
||||
*/
|
||||
struct dgelu {
|
||||
template<typename T> static __device__ inline T op(const T &x);
|
||||
};
|
||||
template<> __device__ inline float dgelu::op<float>(const float &x) {
|
||||
float tanh_out = fast_tanh(SQRT_2_OVER_PI * x * (1.f + GELU_COEFF * x * x));
|
||||
return 0.5f * x * ((1.f - tanh_out * tanh_out) * (SQRT_2_OVER_PI + DGELU_COEFF * x * x)) +
|
||||
0.5f * (1.f + tanh_out);
|
||||
}
|
||||
template<> __device__ inline float2 dgelu::op<float2>(const float2 &x) {
|
||||
return float2{dgelu::op<float>(x.x), dgelu::op<float>(x.y)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Copy operation.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -10,14 +10,16 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <hip_bf16.h>
|
||||
#include <hip_fp16.h>
|
||||
#include <hip_fp8.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
#include <hip/hip_fp4.h>
|
||||
#include <hip/amd_detail/amd_hip_ocp_types.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <string>
|
||||
#include <bit>
|
||||
|
||||
typedef uint32_t __amd_fp8x4_storage_t;
|
||||
|
||||
namespace kittens {
|
||||
|
||||
|
|
@ -37,7 +39,6 @@ using bf16_2 = __hip_bfloat162;
|
|||
* @brief Packed word of two half-precision floating-point values.
|
||||
*/
|
||||
using half_2 = __half2;
|
||||
#ifdef KITTENS_CDNA4
|
||||
/**
|
||||
* @brief float8 floating-point type.
|
||||
*/
|
||||
|
|
@ -50,20 +51,30 @@ using fp8e4m3_2 = __hip_fp8x2_e4m3;
|
|||
* @brief Packed word of four float8 floating-point values.
|
||||
*/
|
||||
using fp8e4m3_4 = __hip_fp8x4_e4m3;
|
||||
#else
|
||||
/**
|
||||
* @brief float8 floating-point type.
|
||||
* @brief 8-bit exponent-only block-scaling scale type.
|
||||
*/
|
||||
using fp8e4m3 = __hip_fp8_e4m3_fnuz;
|
||||
using fp8e8m0 = __amd_scale_t;
|
||||
/**
|
||||
* @brief Packed word of two float8 floating-point values.
|
||||
* @brief Packed word of two 8-bit exponent-only block-scaling scale values.
|
||||
*/
|
||||
using fp8e4m3_2 = __hip_fp8x2_e4m3_fnuz;
|
||||
using fp8e8m0_2 = __amd_fp8x2_storage_t;
|
||||
/**
|
||||
* @brief Packed word of four float8 floating-point values.
|
||||
* @brief Packed word of four 8-bit exponent-only block-scaling scale values.
|
||||
*/
|
||||
using fp8e4m3_4 = __hip_fp8x4_e4m3_fnuz;
|
||||
#endif
|
||||
using fp8e8m0_4 = __amd_fp8x4_storage_t;
|
||||
/**
|
||||
* @brief FP4 E2M1 floating-point type.
|
||||
*/
|
||||
using fp4e2m1 = __hip_fp4_e2m1;
|
||||
/**
|
||||
* @brief Packed word of two FP4 E2M1 floating-point values.
|
||||
*/
|
||||
using fp4e2m1_2 = __hip_fp4x2_e2m1;
|
||||
/**
|
||||
* @brief Packed word of four FP4 E2M1 floating-point values.
|
||||
*/
|
||||
using fp4e2m1_4 = __hip_fp4x4_e2m1;
|
||||
|
||||
namespace ducks {
|
||||
/**
|
||||
|
|
@ -74,9 +85,11 @@ namespace ducks {
|
|||
namespace base_types {
|
||||
|
||||
template<typename T>
|
||||
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4>;
|
||||
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4>
|
||||
|| std::is_same_v<T, fp4e2m1_4>;
|
||||
template<typename T>
|
||||
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3>;
|
||||
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3>
|
||||
|| std::is_same_v<T, fp4e2m1>;
|
||||
|
||||
} // namespace base_types
|
||||
} // namespace ducks
|
||||
|
|
@ -157,6 +170,26 @@ template<> struct constants<fp8e4m3_4> {
|
|||
static __device__ inline constexpr fp8e4m3_4 zero() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x00000000)); }
|
||||
static __device__ inline constexpr fp8e4m3_4 one() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x38383838)); }
|
||||
};
|
||||
template<> struct constants<fp8e8m0> {
|
||||
static __device__ inline constexpr fp8e8m0 zero() { return std::bit_cast<fp8e8m0>(uint8_t(0x00)); } // not actually 0
|
||||
static __device__ inline constexpr fp8e8m0 one() { return std::bit_cast<fp8e8m0>(uint8_t(0x7F)); }
|
||||
};
|
||||
template<> struct constants<fp8e8m0_2> {
|
||||
static __device__ inline constexpr fp8e8m0_2 zero() { return std::bit_cast<fp8e8m0_2>(uint16_t(0x0000)); } // not actually 0
|
||||
static __device__ inline constexpr fp8e8m0_2 one() { return std::bit_cast<fp8e8m0_2>(uint16_t(0x7F7F)); }
|
||||
};
|
||||
template<> struct constants<fp8e8m0_4> {
|
||||
static __device__ inline constexpr fp8e8m0_4 zero() { return std::bit_cast<fp8e8m0_4>(uint32_t(0x00000000)); } // not actually 0
|
||||
static __device__ inline constexpr fp8e8m0_4 one() { return std::bit_cast<fp8e8m0_4>(uint32_t(0x7F7F7F7F)); }
|
||||
};
|
||||
template<> struct constants<fp4e2m1> {
|
||||
static __device__ inline constexpr fp4e2m1 zero() { return std::bit_cast<fp4e2m1>(uint8_t(0x00)); }
|
||||
static __device__ inline constexpr fp4e2m1 one() { return std::bit_cast<fp4e2m1>(uint8_t(0x02)); }
|
||||
};
|
||||
template<> struct constants<fp4e2m1_4> {
|
||||
static __device__ inline constexpr fp4e2m1_4 zero() { return std::bit_cast<fp4e2m1_4>(uint16_t(0x0000)); }
|
||||
static __device__ inline constexpr fp4e2m1_4 one() { return std::bit_cast<fp4e2m1_4>(uint16_t(0x2222)); }
|
||||
};
|
||||
template<> struct constants<int> {
|
||||
static __device__ inline constexpr int zero() { return 0; }
|
||||
static __device__ inline constexpr int ones() { return 1; }
|
||||
|
|
@ -250,6 +283,26 @@ template<> struct packing<fp8e4m3_4> {
|
|||
using unpacked_type = fp8e4m3;
|
||||
using packed_type = fp8e4m3_4;
|
||||
};
|
||||
template<> struct packing<fp8e8m0> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = fp8e8m0;
|
||||
using packed_type = fp8e8m0_4;
|
||||
};
|
||||
template<> struct packing<fp8e8m0_4> {
|
||||
static __host__ __device__ inline constexpr int num() { return 4; }
|
||||
using unpacked_type = fp8e8m0;
|
||||
using packed_type = fp8e8m0_4;
|
||||
};
|
||||
template<> struct packing<fp4e2m1> {
|
||||
static __host__ __device__ inline constexpr int num() { return 1; }
|
||||
using unpacked_type = fp4e2m1;
|
||||
using packed_type = fp4e2m1_4;
|
||||
};
|
||||
template<> struct packing<fp4e2m1_4> {
|
||||
static __host__ __device__ inline constexpr int num() { return 4; }
|
||||
using unpacked_type = fp4e2m1;
|
||||
using packed_type = fp4e2m1_4;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Provides templated functionality to convert between different types.
|
||||
|
|
@ -377,5 +430,25 @@ template<> struct convertor<float, fp8e4m3> {
|
|||
return float(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<fp4e2m1, float> {
|
||||
static __host__ __device__ inline fp4e2m1 convert(const float & u) {
|
||||
return fp4e2m1(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float, fp4e2m1> {
|
||||
static __host__ __device__ inline float convert(const fp4e2m1 & u) {
|
||||
return float(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<fp4e2m1_4, float4> {
|
||||
static __host__ __device__ inline fp4e2m1_4 convert(const float4& u) {
|
||||
return fp4e2m1_4(u);
|
||||
}
|
||||
};
|
||||
template<> struct convertor<float4, fp4e2m1_4> {
|
||||
static __host__ __device__ inline float4 convert(const fp4e2m1_4& u) {
|
||||
return float4(u);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -158,152 +158,405 @@ __device__ __forceinline__ void clobber_gpr() {
|
|||
#undef CLOBBER_AREG_CASE
|
||||
#undef CLOBBER_VREG_CASE
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_read_b128(const uint32_t smem_ptr, const int offset) {
|
||||
__device__ __forceinline__ constexpr uint32_t max_ds_inst_offset()
|
||||
{
|
||||
// DS ops contain 2 8-bits instruction offset.
|
||||
// For non-pk2 instructions like ds_read_b32, the 2 fields are regarded as 1.
|
||||
// For pk2 instructions like ds_read2_b32, max offset is limited by 8 bits.
|
||||
return (1u << 16) - 1;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ constexpr uint32_t max_ds_pk2_inst_offset()
|
||||
{
|
||||
// DS ops contain 2 8-bits instruction offset.
|
||||
// For non-pk2 instructions like ds_read_b32, the 2 fields are regarded as a whole.
|
||||
// For pk2 instructions like ds_read2_b32, max offset is limited by 8 bits.
|
||||
return (1u << 8) - 1;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ constexpr uint32_t max_mubuf_inst_offset()
|
||||
{
|
||||
// MUBUF ops contain 1 12-bits instruction offset.
|
||||
return (1u << 12) - 1;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_read_b32(const uint32_t smem_ptr, const int i_offset) {
|
||||
// AGPRS
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_read_b32 a[%0], %1 offset:%2"
|
||||
:
|
||||
: "n"(GPR_START - 256), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("ds_read_b32 v[%0], %1 offset:%2"
|
||||
:
|
||||
: "n"(GPR_START), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ds_read_b32(T& dst, const uint32_t smem_ptr, const int i_offset) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("ds_read_b32 %0, %1 offset:%2"
|
||||
: "=v"(dst)
|
||||
: "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template <typename T = u32x2>
|
||||
__device__ __forceinline__ T ds_read_b64(const uint32_t smem_ptr, const int i_offset) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t) * 2);
|
||||
T result;
|
||||
asm volatile("ds_read_b64 %0, %1 offset:%2"
|
||||
: "=v"(result)
|
||||
: "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_read_b64(const uint32_t smem_ptr, const int i_offset) {
|
||||
constexpr int GPR_END = GPR_START + 1;
|
||||
// AGPRS
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_read_b64 a[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("ds_read_b64 v[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_read_b64_tr_b4(const uint32_t smem_ptr, const int i_offset) {
|
||||
constexpr int GPR_END = GPR_START + 1;
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_read_b64_tr_b4 a[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("ds_read_b64_tr_b4 v[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_read_b64_tr_b8(const uint32_t smem_ptr, const int i_offset) {
|
||||
constexpr int GPR_END = GPR_START + 1;
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_read_b64_tr_b8 a[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("ds_read_b64_tr_b8 v[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_read_b64_tr_b16(const uint32_t smem_ptr, const int i_offset) {
|
||||
constexpr int GPR_END = GPR_START + 1;
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_read_b64_tr_b16 a[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("ds_read_b64_tr_b16 v[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T = u32x4>
|
||||
__device__ __forceinline__ T ds_read_b128(const uint32_t smem_ptr, const int i_offset) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t) * 4);
|
||||
T result;
|
||||
asm volatile("ds_read_b128 %0, %1 offset:%2"
|
||||
: "=v"(result)
|
||||
: "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_read_b128(const uint32_t smem_ptr, const int i_offset) {
|
||||
constexpr int GPR_END = GPR_START + 3;
|
||||
// AGPRS
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_read_b128 a[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(offset)
|
||||
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("ds_read_b128 v[%0:%1], %2 offset:%3"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(offset)
|
||||
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_read_b64_tr_b16(const uint32_t smem_ptr, const int offset) {
|
||||
constexpr int GPR_END = GPR_START + 1;
|
||||
|
||||
__device__ __forceinline__ void ds_write_b32(const uint32_t smem_ptr, const int i_offset) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_read_b64_tr_b16 a[%0:%1], %2 offset:%3"
|
||||
asm volatile("ds_write_b32 %0, a[%1], offset:%2"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(offset)
|
||||
: "v"(smem_ptr), "n"(GPR_START - 256), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("ds_read_b64_tr_b16 v[%0:%1], %2 offset:%3"
|
||||
asm volatile("ds_write_b32 %0, v[%1], offset:%2"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(offset)
|
||||
: "v"(smem_ptr), "n"(GPR_START), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ds_write_b32(const T& val, const uint32_t smem_ptr, const int i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("ds_write_b32 %0, %1 offset:%2"
|
||||
:
|
||||
: "v"(smem_ptr), "v"(val), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_write_b64(const uint32_t smem_ptr, const int offset) {
|
||||
__device__ __forceinline__ void ds_write_b64(const uint32_t smem_ptr, const int i_offset) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_write_b64 %0, a[%1:%2], offset:%3"
|
||||
:
|
||||
: "v"(smem_ptr), "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "i"(offset)
|
||||
: "v"(smem_ptr), "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("ds_write_b64 %0, v[%1:%2], offset:%3"
|
||||
:
|
||||
: "v"(smem_ptr), "n"(GPR_START), "n"(GPR_START + 1), "i"(offset)
|
||||
: "v"(smem_ptr), "n"(GPR_START), "n"(GPR_START + 1), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ds_write_b64(const T& val, const uint32_t smem_ptr, const int i_offset = 0) {
|
||||
static_assert(sizeof(T) == 2 * sizeof(uint32_t));
|
||||
asm volatile("ds_write_b64 %0, %1 offset:%2"
|
||||
:
|
||||
: "v"(smem_ptr), "v"(val), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void ds_write_b128(const uint32_t smem_ptr, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("ds_write_b128 %0, a[%1:%2], offset:%3"
|
||||
:
|
||||
: "v"(smem_ptr), "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("ds_write_b128 %0, v[%1:%2], offset:%3"
|
||||
:
|
||||
: "v"(smem_ptr), "n"(GPR_START), "n"(GPR_START + 3), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ void ds_write_b128(const T& value, const uint32_t smem_ptr, const int i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(u32x4));
|
||||
asm volatile("ds_write_b128 %0, %1 offset:%2"
|
||||
:
|
||||
: "v"(smem_ptr), "v"(value), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_load_dword a[%0], %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR_START - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("buffer_load_dword v[%0], %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR_START), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T = uint32_t>
|
||||
__device__ __forceinline__ T buffer_load_dword(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T result;
|
||||
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx2(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_load_dwordx2 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("buffer_load_dwordx2 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 1), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T = u32x2>
|
||||
__device__ __forceinline__ T buffer_load_dwordx2(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t) * 2);
|
||||
T result;
|
||||
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_load_dwordx4 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("buffer_load_dwordx4 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 3), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T = u32x4>
|
||||
__device__ __forceinline__ T buffer_load_dwordx4(
|
||||
const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t) * 4);
|
||||
T result;
|
||||
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
|
||||
: "=v"(result)
|
||||
: "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
return result;
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void buffer_store_dword(buffer_resource& br, const uint32_t byte_offset) {
|
||||
|
||||
__device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
if constexpr (GPR >= 256) {
|
||||
asm volatile("buffer_store_dword a[%0], %1, %2, 0 offen"
|
||||
asm volatile("buffer_store_dword a[%0], %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR - 256), "v"(byte_offset), "s"(*(i32x4*)&br)
|
||||
: "n"(GPR - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("buffer_store_dword v[%0], %1, %2, 0 offen"
|
||||
asm volatile("buffer_store_dword v[%0], %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR), "v"(byte_offset), "s"(*(i32x4*)&br)
|
||||
: "n"(GPR), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx2(buffer_resource& br, const uint32_t byte_offset) {
|
||||
template<typename T = u32x2>
|
||||
__device__ __forceinline__ void buffer_store_dword(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t) * 2);
|
||||
asm volatile("buffer_store_dword %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx2(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_store_dwordx2 a[%0:%1], %2, %3, 0 offen"
|
||||
asm volatile("buffer_store_dwordx2 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "v"(byte_offset), "s"(*(i32x4*)&br)
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("buffer_store_dwordx2 v[%0:%1], %2, %3, 0 offen"
|
||||
asm volatile("buffer_store_dwordx2 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 1), "v"(byte_offset), "s"(*(i32x4*)&br)
|
||||
: "n"(GPR_START), "n"(GPR_START + 1), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx4(buffer_resource& br, const uint32_t byte_offset) {
|
||||
template<typename T = u32x2>
|
||||
__device__ __forceinline__ void buffer_store_dwordx2(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t) * 2);
|
||||
asm volatile("buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
// AGPRS
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_store_dwordx4 a[%0:%1], %2, %3, 0 offen"
|
||||
asm volatile("buffer_store_dwordx4 a[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "v"(byte_offset), "s"(*(i32x4*)&br)
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
// VGPRS
|
||||
} else {
|
||||
asm volatile("buffer_store_dwordx4 v[%0:%1], %2, %3, 0 offen"
|
||||
asm volatile("buffer_store_dwordx4 v[%0:%1], %2, %3, %4 offen offset:%5"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 3), "v"(byte_offset), "s"(*(i32x4*)&br)
|
||||
: "n"(GPR_START), "n"(GPR_START + 3), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx4(buffer_resource& br, const uint32_t byte_offset) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_load_dwordx4 a[%0:%1], %2, %3, 0 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "v"(byte_offset), "s"(*(i32x4*)&br), "i"(0)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("buffer_load_dwordx4 v[%0:%1], %2, %3, 0 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 3), "v"(byte_offset), "s"(*(i32x4*)&br), "i"(0)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START>
|
||||
__device__ __forceinline__ void buffer_load_dwordx2(buffer_resource& br, const uint32_t byte_offset) {
|
||||
if constexpr (GPR_START >= 256) {
|
||||
asm volatile("buffer_load_dwordx2 a[%0:%1], %2, %3, 0 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "v"(byte_offset), "s"(*(i32x4*)&br), "i"(0)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("buffer_load_dwordx2 v[%0:%1], %2, %3, 0 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR_START), "n"(GPR_START + 1), "v"(byte_offset), "s"(*(i32x4*)&br), "i"(0)
|
||||
: "memory");
|
||||
}
|
||||
template<typename T = u32x4>
|
||||
__device__ __forceinline__ void buffer_store_dwordx4(
|
||||
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t) * 4);
|
||||
asm volatile("buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(buffer_resource& br, const uint32_t byte_offset) {
|
||||
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
|
||||
if constexpr (GPR >= 256) {
|
||||
asm volatile("buffer_atomic_pk_add_bf16 a[%0], %1, %2, 0 offen"
|
||||
asm volatile("buffer_atomic_pk_add_bf16 a[%0], %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR - 256), "v"(byte_offset), "s"(*(i32x4*)&br)
|
||||
: "n"(GPR - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
} else {
|
||||
asm volatile("buffer_atomic_pk_add_bf16 v[%0], %1, %2, 0 offen"
|
||||
asm volatile("buffer_atomic_pk_add_bf16 v[%0], %1, %2, %3 offen offset:%4"
|
||||
:
|
||||
: "n"(GPR), "v"(byte_offset), "s"(*(i32x4*)&br)
|
||||
: "n"(GPR), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
}
|
||||
|
|
@ -468,6 +721,75 @@ __device__ __forceinline__ void mfma_f32_32x32x16_bf16() {
|
|||
}
|
||||
}
|
||||
|
||||
template<int GPR_START_A, int GPR_START_B, int GPR_START_C, int GPR_START_D>
|
||||
__device__ __forceinline__ void mfma_f32_16x16x32_fp8_fp8() {
|
||||
if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B >= 256 && GPR_START_C >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], a[%4:%5], a[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B >= 256 && GPR_START_C < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], a[%4:%5], v[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C), "n"(GPR_START_C + 3));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B < 256 && GPR_START_C >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], v[%4:%5], a[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B >= 256 && GPR_START_C >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], a[%4:%5], a[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B >= 256 && GPR_START_C >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], a[%4:%5], a[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B >= 256 && GPR_START_C < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], a[%4:%5], v[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C), "n"(GPR_START_C + 3));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B < 256 && GPR_START_C >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], v[%4:%5], a[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A < 256 && GPR_START_B >= 256 && GPR_START_C >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], a[%4:%5], a[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B >= 256 && GPR_START_C < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], a[%4:%5], v[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C), "n"(GPR_START_C + 3));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B < 256 && GPR_START_C >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], v[%4:%5], a[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B < 256 && GPR_START_C < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], v[%4:%5], v[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C), "n"(GPR_START_C + 3));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B < 256 && GPR_START_C < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], v[%4:%5], v[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C), "n"(GPR_START_C + 3));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A < 256 && GPR_START_B >= 256 && GPR_START_C < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], a[%4:%5], v[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C), "n"(GPR_START_C + 3));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A < 256 && GPR_START_B < 256 && GPR_START_C >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], v[%4:%5], a[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B < 256 && GPR_START_C < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], v[%4:%5], v[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C), "n"(GPR_START_C + 3));
|
||||
} else {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], v[%4:%5], v[%6:%7]"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C), "n"(GPR_START_C + 3));
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR_START_A, int GPR_START_B, int GPR_START_D>
|
||||
__device__ __forceinline__ void mfma_f32_16x16x32_bf16_zero_accum() {
|
||||
if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B >= 256) {
|
||||
|
|
@ -542,6 +864,43 @@ __device__ __forceinline__ void mfma_f32_32x32x16_bf16_zero_accum() {
|
|||
}
|
||||
}
|
||||
|
||||
template<int GPR_START_A, int GPR_START_B, int GPR_START_D>
|
||||
__device__ __forceinline__ void mfma_f32_16x16x32_fp8_fp8_zero_accum() {
|
||||
if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], a[%4:%5], 0"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], a[%4:%5], 0"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], a[%4:%5], 0"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], v[%4:%5], 0"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], v[%4:%5], 0"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1));
|
||||
} else if constexpr (GPR_START_D < 256 && GPR_START_A < 256 && GPR_START_B >= 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], a[%4:%5], 0"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256));
|
||||
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B < 256) {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], v[%4:%5], 0"
|
||||
:
|
||||
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1));
|
||||
} else {
|
||||
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], v[%4:%5], 0"
|
||||
:
|
||||
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1));
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR0_START, int GPR1_START, int GPR>
|
||||
__device__ __forceinline__ void v_subrev_f32_dpp() {
|
||||
|
||||
|
|
@ -592,11 +951,29 @@ __device__ __forceinline__ void v_accvgpr_read_b32() {
|
|||
: "n"(GPR0), "n"(GPR1 - 256));
|
||||
}
|
||||
|
||||
template<int GPR>
|
||||
__device__ __forceinline__ void v_mov_b32(const uint32_t value) {
|
||||
template<int GPR, typename T>
|
||||
__device__ __forceinline__ void v_mov_b32_up2p(const T value) {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
asm volatile("v_mov_b32 v[%0], %1"
|
||||
:
|
||||
: "n"(GPR), "i"(value));
|
||||
: "n"(GPR), "v"(value));
|
||||
}
|
||||
|
||||
template <int GPR, typename T = uint32_t>
|
||||
__device__ __forceinline__ T v_mov_b32_p2up() {
|
||||
static_assert(sizeof(T) == sizeof(uint32_t));
|
||||
T r;
|
||||
if constexpr (GPR < 256) {
|
||||
asm volatile("v_mov_b32 %0, v[%1]"
|
||||
: "=v"(r)
|
||||
: "n"(GPR));
|
||||
}
|
||||
else {
|
||||
asm volatile("v_accvgpr_read_b32 %0, a[%1]"
|
||||
: "=v"(r)
|
||||
: "n"(GPR - 256));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template<int GPR0, int GPR1>
|
||||
|
|
@ -612,8 +989,9 @@ __device__ __forceinline__ void v_cndmask_b32_e64(uint64_t mask) {
|
|||
:
|
||||
: "n"(GPR0), "n"(GPR1), "n"(GPR2), "s"(mask));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Multiplication operation on explicit registers.
|
||||
* @brief Multiplication operation on explicit registers and immediate operand.
|
||||
*/
|
||||
struct mul {
|
||||
template<int GPR0, int GPR1>
|
||||
|
|
@ -628,6 +1006,12 @@ struct mul {
|
|||
}
|
||||
}
|
||||
|
||||
template<int GPR0, int GPR1>
|
||||
static __device__ inline void op_pk2(const float ¶m) {
|
||||
op<GPR0, GPR1>(param);
|
||||
op<GPR0 + 1, GPR1 + 1>(param);
|
||||
}
|
||||
|
||||
template<int GPR0, int GPR1, int GPR2>
|
||||
static __device__ inline void op() {
|
||||
if constexpr (GPR0 < 256 && GPR1 < 256 && GPR2 < 256) {
|
||||
|
|
@ -638,8 +1022,44 @@ struct mul {
|
|||
static_assert(false, "Invalid operand for instruction: v_mul_f32_e32");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR0, int GPR1, int GPR2>
|
||||
static __device__ inline void op_pk2() {
|
||||
if constexpr (GPR0 < (256 - 1) && GPR1 < (256 - 1) && GPR2 < (256 - 1)) {
|
||||
asm volatile("v_pk_mul_f32 v[%0:%1], v[%4:%5], v[%2:%3]"
|
||||
:
|
||||
: "n"(GPR0), "n"(GPR0 + 1), "n"(GPR1), "n"(GPR1 + 1), "n"(GPR2), "n"(GPR2 + 1));
|
||||
} else {
|
||||
static_assert(false, "Invalid operand for instruction: v_pk_mul_f32");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct mul_vgpr {
|
||||
template<int GPR0, int GPR1>
|
||||
static __device__ inline void op(const float ¶m) {
|
||||
if constexpr (GPR0 < 256 && GPR1 < 256) {
|
||||
asm volatile("v_mul_f32_e32 v[%0], %2, v[%1]"
|
||||
:
|
||||
: "n"(GPR0), "n"(GPR1), "v"(param));
|
||||
} else {
|
||||
static_assert(false, "Invalid operand for instruction: v_mul_f32_e32");
|
||||
}
|
||||
}
|
||||
|
||||
template<int GPR0, int GPR1>
|
||||
static __device__ inline void op_pk2(const float ¶m) {
|
||||
if constexpr (GPR0 < (256 - 1) && GPR1 < (256 - 1)) {
|
||||
const float2 param2 = {param, param};
|
||||
asm volatile("v_pk_mul_f32 v[%0:%1], %4, v[%2:%3]"
|
||||
:
|
||||
: "n"(GPR0), "n"(GPR0 + 1), "n"(GPR1), "n"(GPR1 + 1), "v"(param2));
|
||||
} else {
|
||||
static_assert(false, "Invalid operand for instruction: v_pk_mul_f32");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct exp2 {
|
||||
template<int GPR0, int GPR1>
|
||||
static __device__ inline void op() {
|
||||
|
|
@ -669,4 +1089,4 @@ struct zero {
|
|||
};
|
||||
|
||||
} // namespace macros
|
||||
} // namespace kittens
|
||||
} // namespace kittens
|
||||
|
|
|
|||
|
|
@ -50,7 +50,11 @@ __device__ __forceinline__ int warpid() { return threadIdx.x >> 6; }
|
|||
*/
|
||||
__device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }
|
||||
|
||||
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
|
||||
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
|
||||
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
|
||||
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
struct buffer_resource {
|
||||
uint64_t ptr;
|
||||
uint32_t range;
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@
|
|||
* @param idx[in] The index of the tile to load data from.
|
||||
*/
|
||||
|
||||
template<int axis, ducks::art::all RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<RT>>
|
||||
template<int axis, int elem_offset=0, ducks::art::all RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<RT>>
|
||||
__device__ inline static void load(RT &dst, const GL &src, const COORD &idx, const COORD &warp_idx) {
|
||||
using T2 = RT::dtype;
|
||||
constexpr int packing = base_types::packing<typename RT::dtype>::num();
|
||||
|
|
@ -42,22 +42,48 @@
|
|||
buffer_resource br = make_buffer_resource(as_u64, buffer_size, 0x00020000);
|
||||
|
||||
int warp_offset = src.idx(warp_idx.template unit_coord<axis, 3>());
|
||||
int thr_offset = (row_offset * row_stride + col_offset + warp_offset) * sizeof(U);
|
||||
|
||||
// Compile-time loop to load data into the tile
|
||||
auto perform_load_at = [&]<int N, int M, int K>() {
|
||||
using tile_range = ducks::art::get_nth_range_t<typename RT::register_ranges, N * RT::width + M>;
|
||||
const int register_offset = K * RT::registers_per_stride;
|
||||
|
||||
constexpr int col = RT::base_tile_cols*M + K * RT::base_tile_elements_per_stride_group;
|
||||
constexpr int row = RT::base_tile_rows*N;
|
||||
const int k_row_offset = row * row_stride * sizeof(U);
|
||||
|
||||
const int col = RT::base_tile_cols*M + col_offset + K * RT::base_tile_elements_per_stride_group;
|
||||
const int row = RT::base_tile_rows*N + row_offset;
|
||||
const int offset = (row*row_stride + col + warp_offset) * sizeof(U);
|
||||
|
||||
if constexpr (std::is_same_v<U2, bf16_2>) {
|
||||
if constexpr (RT::base_tile_stride == 8) {
|
||||
macros::buffer_load_dwordx4<tile_range::lo + register_offset>(br, offset);
|
||||
} else if constexpr (RT::base_tile_stride == 4) {
|
||||
macros::buffer_load_dwordx2<tile_range::lo + register_offset>(br, offset);
|
||||
}
|
||||
constexpr int stride_in_bytes = RT::base_tile_stride * sizeof(U);
|
||||
constexpr int offset_in_bytes = (elem_offset + col) * sizeof(U);
|
||||
constexpr int start_gpr = tile_range::lo + register_offset;
|
||||
|
||||
if constexpr (offset_in_bytes <= macros::max_mubuf_inst_offset()) {
|
||||
if constexpr (stride_in_bytes == (sizeof(int32_t) * 4)) {
|
||||
macros::buffer_load_dwordx4<start_gpr>(br, thr_offset + k_row_offset, 0, offset_in_bytes);
|
||||
}
|
||||
else if constexpr (stride_in_bytes == (sizeof(int32_t) * 2)) {
|
||||
macros::buffer_load_dwordx2<start_gpr>(br, thr_offset + k_row_offset, 0, offset_in_bytes);
|
||||
}
|
||||
else if constexpr (stride_in_bytes == sizeof(int32_t)) {
|
||||
macros::buffer_load_dword<start_gpr>(br, thr_offset + k_row_offset, 0, offset_in_bytes);
|
||||
}
|
||||
else {
|
||||
static_assert(false, "Encounter unsupported format in ops/warp/memory/tile/assembly/global_to_register.cuh\n");
|
||||
}
|
||||
}
|
||||
else {
|
||||
if constexpr (stride_in_bytes == (sizeof(int32_t) * 4)) {
|
||||
macros::buffer_load_dwordx4<start_gpr>(br, thr_offset + offset_in_bytes + k_row_offset, 0, 0);
|
||||
}
|
||||
else if constexpr (stride_in_bytes == (sizeof(int32_t) * 2)) {
|
||||
macros::buffer_load_dwordx2<start_gpr>(br, thr_offset + offset_in_bytes + k_row_offset, 0, 0);
|
||||
}
|
||||
else if constexpr (stride_in_bytes == sizeof(int32_t)) {
|
||||
macros::buffer_load_dword<start_gpr>(br, thr_offset + offset_in_bytes + k_row_offset, 0, 0);
|
||||
}
|
||||
else {
|
||||
static_assert(false, "Encounter unsupported format in ops/warp/memory/tile/assembly/global_to_register.cuh\n");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -74,12 +100,11 @@
|
|||
}(std::make_index_sequence<RT::width>{});
|
||||
}.template operator()<Ns>(), ...);
|
||||
}(std::make_index_sequence<RT::height>{});
|
||||
|
||||
}
|
||||
|
||||
template<ducks::art::all RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<RT>>
|
||||
__device__ inline static void load(RT &dst, const GL &src, const COORD &idx, const COORD &warp_idx) {
|
||||
load<2, RT, GL>(dst, src, idx, warp_idx);
|
||||
load<2, 0, RT, GL>(dst, src, idx, warp_idx);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -226,7 +226,8 @@ __device__ inline void load(ST& dst, const GL& src, const COORD& idx, const uint
|
|||
|
||||
if (warpid < leftover_warps) {
|
||||
|
||||
uintptr_t lds_addr = lds_base + (memcpy_per_tile * num_warps * bytes_per_warp);
|
||||
const T* lds_elem_ptr = lds_base + (memcpy_per_tile * num_warps * elements_per_warp);
|
||||
uintptr_t lds_addr = reinterpret_cast<uintptr_t>(lds_elem_ptr);
|
||||
as3_uint32_ptr lds_ptr = (as3_uint32_ptr)(lds_addr);
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
|
|
@ -414,4 +415,4 @@ template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST
|
|||
__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) {
|
||||
store<2, false, ST, GL, COORD, WARP_THREADS>(dst, src, idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -134,7 +134,17 @@ __device__ static inline void bin_map(T0 &dst, const T1 &src, const typename bas
|
|||
|
||||
[&]<std::size_t... Rs>(std::index_sequence<Rs...>) {
|
||||
([&]<std::size_t R>() {
|
||||
op::template op<ducks::art::get_nth_range_t<registers_T0, R>::lo, ducks::art::get_nth_range_t<registers_T1, R>::lo>(param);
|
||||
constexpr int GPR0 = ducks::art::get_nth_range_t<registers_T0, R>::lo;
|
||||
constexpr int GPR1 = ducks::art::get_nth_range_t<registers_T1, R>::lo;
|
||||
if constexpr ((R % 2) == 0) {
|
||||
if constexpr ((R + 1) < registers_T0::size) {
|
||||
op::template op_pk2<GPR0, GPR1>(param);
|
||||
}
|
||||
else {
|
||||
op::template op<GPR0, GPR1>(param);
|
||||
}
|
||||
}
|
||||
// Odd indices are skipped because they're processed by op_pk2 in the previous even iteration
|
||||
}.template operator()<Rs>(), ...);
|
||||
}(std::make_index_sequence<registers_T0::size>{});
|
||||
}
|
||||
|
|
@ -156,7 +166,17 @@ __device__ static inline void bin_map(T0 &dst, const T1 &src, const typename bas
|
|||
|
||||
[&]<std::size_t... Rs>(std::index_sequence<Rs...>) {
|
||||
([&]<std::size_t R>() {
|
||||
op::template op<ducks::art::get_nth_range_t<registers_T0, R>::lo, ducks::art::get_nth_range_t<registers_T1, R>::lo>(param);
|
||||
constexpr int GPR0 = ducks::art::get_nth_range_t<registers_T0, R>::lo;
|
||||
constexpr int GPR1 = ducks::art::get_nth_range_t<registers_T1, R>::lo;
|
||||
if constexpr ((R % 2) == 0) {
|
||||
if constexpr ((R + 1) < registers_T0::size) {
|
||||
op::template op_pk2<GPR0, GPR1>(param);
|
||||
}
|
||||
else {
|
||||
op::template op<GPR0, GPR1>(param);
|
||||
}
|
||||
}
|
||||
// Odd indices are skipped because they're processed by op_pk2 in the previous even iteration
|
||||
}.template operator()<Rs>(), ...);
|
||||
}(std::make_index_sequence<registers_T0::size>{});
|
||||
};
|
||||
|
|
@ -205,7 +225,18 @@ __device__ static inline void bin_map(T0 &dst, const T1 &lhs, const T2 &rhs) {
|
|||
|
||||
[&]<std::size_t... Rs>(std::index_sequence<Rs...>) {
|
||||
([&]<std::size_t R>() {
|
||||
op::template op<ducks::art::get_nth_range_t<registers_T0, R>::lo, ducks::art::get_nth_range_t<registers_T1, R>::lo, ducks::art::get_nth_range_t<registers_T2, R>::lo>();
|
||||
constexpr int GPR0 = ducks::art::get_nth_range_t<registers_T0, R>::lo;
|
||||
constexpr int GPR1 = ducks::art::get_nth_range_t<registers_T1, R>::lo;
|
||||
constexpr int GPR2 = ducks::art::get_nth_range_t<registers_T2, R>::lo;
|
||||
if constexpr ((R % 2) == 0) {
|
||||
if constexpr ((R + 1) < registers_T0::size) {
|
||||
op::template op_pk2<GPR0, GPR1, GPR2>();
|
||||
}
|
||||
else {
|
||||
op::template op<GPR0, GPR1, GPR2>();
|
||||
}
|
||||
}
|
||||
// Odd indices are skipped because they're processed by op_pk2 in the previous even iteration
|
||||
}.template operator()<Rs>(), ...);
|
||||
}(std::make_index_sequence<registers_T0::size>{});
|
||||
}
|
||||
|
|
@ -234,7 +265,18 @@ __device__ static inline void bin_map(T0 &dst, const T1 &lhs, const T2 &rhs) {
|
|||
|
||||
[&]<std::size_t... Rs>(std::index_sequence<Rs...>) {
|
||||
([&]<std::size_t R>() {
|
||||
op::template op<ducks::art::get_nth_range_t<registers_T0, R>::lo, ducks::art::get_nth_range_t<registers_T1, R>::lo, ducks::art::get_nth_range_t<registers_T2, R>::lo>();
|
||||
constexpr int GPR0 = ducks::art::get_nth_range_t<registers_T0, R>::lo;
|
||||
constexpr int GPR1 = ducks::art::get_nth_range_t<registers_T1, R>::lo;
|
||||
constexpr int GPR2 = ducks::art::get_nth_range_t<registers_T2, R>::lo;
|
||||
if constexpr ((R % 2) == 0) {
|
||||
if constexpr ((R + 1) < registers_T0::size) {
|
||||
op::template op_pk2<GPR0, GPR1, GPR2>();
|
||||
}
|
||||
else {
|
||||
op::template op<GPR0, GPR1, GPR2>();
|
||||
}
|
||||
}
|
||||
// Odd indices are skipped because they're processed by op_pk2 in the previous even iteration
|
||||
}.template operator()<Rs>(), ...);
|
||||
}(std::make_index_sequence<registers_T0::size>{});
|
||||
};
|
||||
|
|
@ -364,6 +406,16 @@ __device__ static inline void mul(T0 &dst, const T1 &lhs, const U &rhs) {
|
|||
bin_map<macros::mul, T0, T1>(dst, lhs, rhs);
|
||||
}
|
||||
|
||||
template<ducks::art::all T0, ducks::art::all T1, typename U>
|
||||
__device__ static inline void mul_vgpr(T0 &dst, const T1 &lhs, const U &rhs) {
|
||||
bin_map<macros::mul_vgpr, T0, T1>(dst, lhs, rhs);
|
||||
}
|
||||
|
||||
template<int N, int M, ducks::art::all T0, ducks::art::all T1, typename U>
|
||||
__device__ static inline void mul_vgpr(T0 &dst, const T1 &lhs, const U &rhs) {
|
||||
bin_map<N, M, macros::mul_vgpr, T0, T1>(dst, lhs, rhs);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Subtracts row values from each row of a tile.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -20,14 +20,44 @@ namespace kittens {
|
|||
* @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
|
||||
* @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
|
||||
*/
|
||||
template<typename RegisterRangeA, typename RegisterRangeB, typename RegisterRangeC, typename RegisterRangeD>
|
||||
template<typename AccumulatorShape, typename InputType, typename RegisterRangeA, typename RegisterRangeB, typename RegisterRangeC, typename RegisterRangeD>
|
||||
__device__ static inline void mma_ABt_base() {
|
||||
macros::mfma_f32_16x16x32_bf16<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeC::lo, RegisterRangeD::lo>();
|
||||
if constexpr (std::is_same_v<AccumulatorShape, ducks::rt_shape::rt_16x16>)
|
||||
{
|
||||
if constexpr (std::is_same_v<InputType, fp8e4m3>)
|
||||
{
|
||||
macros::mfma_f32_16x16x32_fp8_fp8<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeC::lo, RegisterRangeD::lo>();
|
||||
}
|
||||
else
|
||||
{
|
||||
macros::mfma_f32_16x16x32_bf16<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeC::lo, RegisterRangeD::lo>();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
macros::mfma_f32_16x16x32_bf16<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeC::lo, RegisterRangeD::lo>();
|
||||
}
|
||||
}
|
||||
template<typename RegisterRangeA, typename RegisterRangeB, typename RegisterRangeD>
|
||||
|
||||
template<typename AccumulatorShape, typename InputType, typename RegisterRangeA, typename RegisterRangeB, typename RegisterRangeD>
|
||||
__device__ static inline void mma_ABt_base_zero_accum() {
|
||||
macros::mfma_f32_16x16x32_bf16_zero_accum<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeD::lo>();
|
||||
if constexpr (std::is_same_v<AccumulatorShape, ducks::rt_shape::rt_16x16>)
|
||||
{
|
||||
if constexpr (std::is_same_v<InputType, fp8e4m3>)
|
||||
{
|
||||
macros::mfma_f32_16x16x32_fp8_fp8_zero_accum<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeD::lo>();
|
||||
}
|
||||
else
|
||||
{
|
||||
macros::mfma_f32_16x16x32_bf16_zero_accum<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeD::lo>();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
macros::mfma_f32_16x16x32_bf16_zero_accum<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeD::lo>();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Base matrix multiply-accumulate operation for row layout with transposed A.
|
||||
*
|
||||
|
|
@ -87,7 +117,9 @@ __device__ static inline void mma_ABt(D &d,
|
|||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
|
||||
std::is_same_v<typename B::T, bf16>) ||
|
||||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
|
||||
std::is_same_v<typename B::T, half>)
|
||||
std::is_same_v<typename B::T, half>) ||
|
||||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
|
||||
std::is_same_v<typename B::T, fp8e4m3>)
|
||||
);
|
||||
|
||||
// Helper function template for compile-time MMA operations
|
||||
|
|
@ -95,7 +127,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, M * B::width + K>;
|
||||
using range_type_C = ducks::art::get_nth_range_t<typename C::register_ranges, N * C::width + M>;
|
||||
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
|
||||
mma_ABt_base<range_type_A, range_type_B, range_type_C, range_type_D>();
|
||||
mma_ABt_base<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_C, range_type_D>();
|
||||
}
|
||||
|
||||
template<ducks::art::all D, ducks::art::all A, ducks::art::all B, ducks::art::all C>
|
||||
|
|
@ -117,7 +149,9 @@ __device__ static inline void mma_ABt(D &d,
|
|||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
|
||||
std::is_same_v<typename B::T, bf16>) ||
|
||||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
|
||||
std::is_same_v<typename B::T, half>)
|
||||
std::is_same_v<typename B::T, half>) ||
|
||||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
|
||||
std::is_same_v<typename B::T, fp8e4m3>)
|
||||
);
|
||||
|
||||
// Helper function template for compile-time MMA operations
|
||||
|
|
@ -127,7 +161,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, M * B::width>;
|
||||
using range_type_C = ducks::art::get_nth_range_t<typename C::register_ranges, N * C::width + M>;
|
||||
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
|
||||
mma_ABt_base<range_type_A, range_type_B, range_type_C, range_type_D>();
|
||||
mma_ABt_base<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_C, range_type_D>();
|
||||
|
||||
// Subsequent MMA operations for k=1 to A::width-1
|
||||
[&]<std::size_t... Ks>(std::index_sequence<Ks...>) {
|
||||
|
|
@ -138,7 +172,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, k + M * B::width>;
|
||||
using range_type_C = ducks::art::get_nth_range_t<typename C::register_ranges, N * C::width + M>;
|
||||
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
|
||||
mma_ABt_base<range_type_A, range_type_B, range_type_C, range_type_D>();
|
||||
mma_ABt_base<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_C, range_type_D>();
|
||||
}
|
||||
}(), ...);
|
||||
}(std::make_index_sequence<A::width>{});
|
||||
|
|
@ -172,7 +206,9 @@ __device__ static inline void mma_ABt(D &d,
|
|||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
|
||||
std::is_same_v<typename B::T, bf16>) ||
|
||||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
|
||||
std::is_same_v<typename B::T, half>)
|
||||
std::is_same_v<typename B::T, half>) ||
|
||||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
|
||||
std::is_same_v<typename B::T, fp8e4m3>)
|
||||
);
|
||||
|
||||
// Helper function template for compile-time MMA operations
|
||||
|
|
@ -180,7 +216,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
using range_type_A = ducks::art::get_nth_range_t<typename A::register_ranges, N * A::width + K>;
|
||||
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, M * B::width + K>;
|
||||
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
|
||||
mma_ABt_base_zero_accum<range_type_A, range_type_B, range_type_D>();
|
||||
mma_ABt_base_zero_accum<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_D>();
|
||||
}
|
||||
|
||||
template<ducks::art::all D, ducks::art::all A, ducks::art::all B>
|
||||
|
|
@ -199,7 +235,9 @@ __device__ static inline void mma_ABt(D &d,
|
|||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
|
||||
std::is_same_v<typename B::T, bf16>) ||
|
||||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
|
||||
std::is_same_v<typename B::T, half>)
|
||||
std::is_same_v<typename B::T, half>) ||
|
||||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
|
||||
std::is_same_v<typename B::T, fp8e4m3>)
|
||||
);
|
||||
|
||||
// Helper function template for compile-time MMA operations
|
||||
|
|
@ -208,7 +246,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
using range_type_A = ducks::art::get_nth_range_t<typename A::register_ranges, N * A::width>;
|
||||
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, M * B::width>;
|
||||
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
|
||||
mma_ABt_base_zero_accum<range_type_A, range_type_B, range_type_D>();
|
||||
mma_ABt_base_zero_accum<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_D>();
|
||||
|
||||
// Subsequent MMA operations for k=1 to A::width-1
|
||||
[&]<std::size_t... Ks>(std::index_sequence<Ks...>) {
|
||||
|
|
@ -218,7 +256,7 @@ __device__ static inline void mma_ABt(D &d,
|
|||
using range_type_A = ducks::art::get_nth_range_t<typename A::register_ranges, k + N * A::width>;
|
||||
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, k + M * B::width>;
|
||||
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
|
||||
mma_ABt_base<range_type_A, range_type_B, range_type_D, range_type_D>();
|
||||
mma_ABt_base<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_D, range_type_D>();
|
||||
}
|
||||
}(), ...);
|
||||
}(std::make_index_sequence<A::width>{});
|
||||
|
|
|
|||
|
|
@ -472,6 +472,28 @@ template<ducks::rt::all T>
|
|||
__device__ static inline void relu(T &dst, const T &src) {
|
||||
unary_map<base_ops::relu, T>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the GELU function (tanh approximation) to each element of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the GELU function on.
|
||||
*/
|
||||
template<ducks::rt::all T>
|
||||
__device__ static inline void gelu(T &dst, const T &src) {
|
||||
unary_map<base_ops::gelu, T>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the GELU derivative to each element of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the GELU derivative on.
|
||||
*/
|
||||
template<ducks::rt::all T>
|
||||
__device__ static inline void dgelu(T &dst, const T &src) {
|
||||
unary_map<base_ops::dgelu, T>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Copies the elements from one tile to another.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -124,6 +124,28 @@ __device__ static inline void mfma1616128( float2 (&D)[2],
|
|||
)};
|
||||
}
|
||||
|
||||
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0>
|
||||
__device__ static inline void mfma1616128_scaled( float2 (&D)[2],
|
||||
const fp8e4m3_4 (&A)[8],
|
||||
const fp8e4m3_4 (&B)[8],
|
||||
const float2 (&C)[2],
|
||||
const fp8e8m0_4 *scale_a,
|
||||
const fp8e8m0_4 *scale_b) {
|
||||
typedef __attribute__((__vector_size__(8 * sizeof(int)))) int intx8_t;
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(float)))) float floatx4_t;
|
||||
|
||||
*(floatx4_t*)D = {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
*(intx8_t*)A,
|
||||
*(intx8_t*)B,
|
||||
*(floatx4_t*)C,
|
||||
cbsz, // cbsz: 0=fp8(e4m3) A, 1=bf8(e5m2) A
|
||||
blgp, // blgp: 0=fp8(e4m3) B, 1=bf8(e5m2) B
|
||||
opsel_a, // opsel_a
|
||||
*scale_a, // scale_a
|
||||
opsel_b, // opsel_b
|
||||
*scale_b // scale_b
|
||||
)};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Base matrix multiply-accumulate operation for row layout.
|
||||
|
|
@ -221,6 +243,46 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Base dot product operation for row layout.
|
||||
*
|
||||
* This function performs the base dot product operation
|
||||
* for block-scaled matrices in row layout.
|
||||
*
|
||||
* @param[out] d The output rt_base<float, col_layout> accumulator.
|
||||
* @param[in] a The first input rt_base<Operand_T, row_layout> matrix.
|
||||
* @param[in] b The second input rt_base<Operand_T, row_layout> matrix.
|
||||
* @param[in] c The input rt_base<float, col_layout> accumulator matrix.
|
||||
*/
|
||||
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
|
||||
__device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
|
||||
const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b,
|
||||
const rt_base<float, ducks::rt_layout::col, C_shape> &c,
|
||||
const fp8e8m0_4 *scale_a,
|
||||
const fp8e8m0_4 *scale_b) {
|
||||
|
||||
static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape");
|
||||
|
||||
constexpr int A_rows = A_shape::rows;
|
||||
constexpr int A_cols = A_shape::cols;
|
||||
constexpr int B_rows = B_shape::rows;
|
||||
constexpr int B_cols = B_shape::cols;
|
||||
|
||||
constexpr int A_stride = A_shape::stride;
|
||||
constexpr int B_stride = B_shape::stride;
|
||||
static_assert(A_stride == B_stride, "A and B must have the same stride");
|
||||
|
||||
if constexpr (std::is_same_v<D_shape, typename ducks::rt_shape::rt_16x16> &&
|
||||
A_rows == 16 && A_cols == 128 &&
|
||||
B_rows == 16 && B_cols == 128 &&
|
||||
std::is_same_v<C_shape, typename ducks::rt_shape::rt_16x16>) {
|
||||
mfma1616128_scaled<opsel_a, opsel_b, cbsz, blgp>(d.data, a.data, b.data, c.data, scale_a, scale_b);
|
||||
} else {
|
||||
static_assert(false, "Unsupported shape combination");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Base matrix multiply-accumulate operation for row layout with transposed A.
|
||||
*
|
||||
|
|
@ -420,6 +482,62 @@ __device__ static inline void mma_ABt(D &d,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Block scaled dot product operation for row layout.
|
||||
*
|
||||
* This function performs the dot product operation
|
||||
* for block-scaled matrices in row layout.
|
||||
*
|
||||
* @tparam N The number of row tiles.
|
||||
* @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
|
||||
* @tparam M The number of column tiles for the B matrix.
|
||||
* @param[out] d The output rt_fl<N, M, col_layout> accumulator.
|
||||
* @param[in] a The first input rt_bf<N, K, row_layout> matrix.
|
||||
* @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
|
||||
* @param[in] c The input rt_fl<N, M, col_layout> accumulator matrix.
|
||||
* @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix.
|
||||
* @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix.
|
||||
*/
|
||||
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
|
||||
__device__ static inline void mma_ABt_scaled(D &d,
|
||||
const A &a,
|
||||
const B &b,
|
||||
const C &c,
|
||||
const fp8e8m0_4 *scale_a,
|
||||
const fp8e8m0_4 *scale_b) {
|
||||
|
||||
static_assert(D::rows == A::rows && D::cols == B::rows);
|
||||
static_assert(A::cols == B::cols);
|
||||
static_assert(D::rows == C::rows && D::cols == C::cols);
|
||||
|
||||
static_assert(
|
||||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
|
||||
std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
|
||||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
|
||||
std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
|
||||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
|
||||
std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
|
||||
);
|
||||
|
||||
[&]<std::size_t... Ns>(std::index_sequence<Ns...>) {
|
||||
([&]<std::size_t N>() {
|
||||
[&]<std::size_t... Ms>(std::index_sequence<Ms...>) {
|
||||
([&]<std::size_t M>() {
|
||||
mma_ABt_base_scaled<N, M, cbsz, blgp>(
|
||||
d.tiles[N][M],
|
||||
a.tiles[N][0],
|
||||
b.tiles[M][0],
|
||||
c.tiles[N][M],
|
||||
scale_a,
|
||||
scale_b
|
||||
);
|
||||
}.template operator()<Ms>(), ...);
|
||||
}(std::make_index_sequence<D::width>{});
|
||||
}.template operator()<Ns>(), ...);
|
||||
}(std::make_index_sequence<D::height>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Matrix multiply-accumulate operation with transposed A.
|
||||
*
|
||||
|
|
|
|||
74
extra/thunder/amd/include/ops/warp/register/tile/scales.cuh
Normal file
74
extra/thunder/amd/include/ops/warp/register/tile/scales.cuh
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* @file
|
||||
* @brief MXFP8 block scale loading and packing utilities.
|
||||
*
|
||||
* Provides functions for staging E8M0 block scales in LDS and packing them
|
||||
* into fp8e8m0_4 registers for use with scaled MFMA instructions.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../../../common/common.cuh"
|
||||
|
||||
namespace kittens {
|
||||
|
||||
/**
|
||||
* @brief Load iteration-major packed E8M0 scales from global memory into LDS.
|
||||
*
|
||||
* First 256 threads each load one uint32 (4 packed E8M0 bytes) for A and B.
|
||||
* A scales are placed at smem[0..1023], B scales at smem[1024..2047].
|
||||
*
|
||||
* @param smem_scales LDS buffer, must be >= 2048 bytes.
|
||||
* @param scale_A_iter Iteration-major A scales: [k_iter * M + row] as uint32.
|
||||
* @param scale_B_iter Iteration-major B scales: [k_iter * N + row] as uint32.
|
||||
* @param block_m Starting row offset for A within the current block.
|
||||
* @param block_n Starting row offset for B within the current block.
|
||||
* @param k_iter Current K iteration index.
|
||||
* @param M_dim M dimension of the matrix.
|
||||
* @param N_dim N dimension of the matrix.
|
||||
*/
|
||||
__device__ __forceinline__ void load_scales_to_lds(
|
||||
uint8_t *smem_scales,
|
||||
const uint32_t *__restrict__ scale_A_iter,
|
||||
const uint32_t *__restrict__ scale_B_iter,
|
||||
int block_m, int block_n, int k_iter, int M_dim, int N_dim) {
|
||||
int tid = threadIdx.x;
|
||||
if (tid < 256) {
|
||||
uint32_t sa = scale_A_iter[k_iter * M_dim + block_m + tid];
|
||||
uint32_t sb = scale_B_iter[k_iter * N_dim + block_n + tid];
|
||||
*(uint32_t *)&smem_scales[tid * 4] = sa;
|
||||
*(uint32_t *)&smem_scales[1024 + tid * 4] = sb;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
|
||||
*
|
||||
* Each lane (r16 = laneid%16, k_sub = laneid/16) loads 4 dwords from
|
||||
* consecutive 16-row groups, then uses v_perm_b32 to extract byte k_sub
|
||||
* from each, producing the packed scale register for scaled MFMA.
|
||||
*
|
||||
* @param smem_scales LDS pointer to scale region.
|
||||
* @param lds_base Byte offset within smem_scales (0 for A, 1024 for B).
|
||||
* @param row_offset Starting row within the scale region (warp's tile offset).
|
||||
* @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
|
||||
*/
|
||||
__device__ __forceinline__ fp8e8m0_4 pack_scales(
|
||||
const uint8_t *smem_scales, int lds_base, int row_offset) {
|
||||
int lid = laneid();
|
||||
int r16 = lid % 16;
|
||||
int k_sub = lid / 16;
|
||||
|
||||
const uint32_t *s32 = (const uint32_t *)(smem_scales + lds_base);
|
||||
uint32_t w0 = s32[row_offset + 0 * 16 + r16];
|
||||
uint32_t w1 = s32[row_offset + 1 * 16 + r16];
|
||||
uint32_t w2 = s32[row_offset + 2 * 16 + r16];
|
||||
uint32_t w3 = s32[row_offset + 3 * 16 + r16];
|
||||
|
||||
uint32_t sel = 0x0C0C0000u | (k_sub << 8) | (4u + k_sub);
|
||||
uint32_t lo = __builtin_amdgcn_perm(w0, w1, sel);
|
||||
uint32_t hi = __builtin_amdgcn_perm(w2, w3, sel);
|
||||
|
||||
return (fp8e8m0_4)(lo | (hi << 16));
|
||||
}
|
||||
} // namespace kittens
|
||||
|
|
@ -9,5 +9,6 @@
|
|||
#include "maps.cuh"
|
||||
#include "reductions.cuh"
|
||||
#include "mma.cuh"
|
||||
#include "scales.cuh"
|
||||
|
||||
#include "assembly/tile.cuh"
|
||||
|
|
|
|||
|
|
@ -68,11 +68,11 @@ struct gl {
|
|||
}
|
||||
__host__ __device__ inline gl(const gl &other) :
|
||||
raw_ptr(other.raw_ptr), batch_internal(other.batch_internal), depth_internal(other.depth_internal), rows_internal(other.rows_internal), cols_internal(other.cols_internal), tma_descs(other.tma_descs) {}
|
||||
__device__ inline T& operator[](const coord<ducks::default_type> &idx) const { // yes I am abusing the const qualifier here a bit.
|
||||
return raw_ptr[((idx.b*depth() + idx.d)*rows() + idx.r)*cols() + idx.c];
|
||||
__device__ inline T& operator[](const coord<ducks::default_type> &idx) const {
|
||||
return raw_ptr[((int64_t(idx.b)*depth() + idx.d)*rows() + idx.r)*cols() + idx.c];
|
||||
}
|
||||
__device__ inline int idx(const coord<ducks::default_type> &idx) const {
|
||||
return ((idx.b*depth() + idx.d)*rows() + idx.r)*cols() + idx.c;
|
||||
__device__ inline int64_t idx(const coord<ducks::default_type> &idx) const {
|
||||
return ((int64_t(idx.b)*depth() + idx.d)*rows() + idx.r)*cols() + idx.c;
|
||||
}
|
||||
template<int axis> __device__ inline size_t shape() const {
|
||||
static_assert(axis==0 || axis==1 || axis==2 || axis==3, "Axis must be 0, 1, 2, or 3.");
|
||||
|
|
|
|||
|
|
@ -56,7 +56,9 @@
|
|||
using register_range = _register_range; ///< Register range for this tile.
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<dtype, bf16_2> || std::is_same_v<dtype, float2> || std::is_same_v<dtype, half_2>,
|
||||
std::is_same_v<dtype, bf16_2> || std::is_same_v<dtype, float2> || std::is_same_v<dtype, half_2> ||
|
||||
std::is_same_v<dtype, __hip_fp8x4_e4m3> || std::is_same_v<dtype, __hip_fp8x4_e4m3_fnuz> ||
|
||||
std::is_same_v<dtype, __hip_fp8x4_e5m2> || std::is_same_v<dtype, __hip_fp8x4_e5m2_fnuz>,
|
||||
"art_base was provided an unsupported type."
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -37,12 +37,12 @@ using rt_16x32_4 = rt_shape<16, 32, 4>;
|
|||
using rt_16x128 = rt_shape<16, 128, 16>;
|
||||
|
||||
template<typename T>
|
||||
concept all = std::is_same_v<T, rt_16x16> ||
|
||||
std::is_same_v<T, rt_32x32> ||
|
||||
std::is_same_v<T, rt_32x32_8> ||
|
||||
std::is_same_v<T, rt_16x32> ||
|
||||
std::is_same_v<T, rt_32x16> ||
|
||||
std::is_same_v<T, rt_32x16_4> ||
|
||||
concept all = std::is_same_v<T, rt_16x16> ||
|
||||
std::is_same_v<T, rt_32x32> ||
|
||||
std::is_same_v<T, rt_32x32_8> ||
|
||||
std::is_same_v<T, rt_16x32> ||
|
||||
std::is_same_v<T, rt_32x16> ||
|
||||
std::is_same_v<T, rt_32x16_4> ||
|
||||
std::is_same_v<T, rt_16x32_4> ||
|
||||
std::is_same_v<T, rt_16x128>;
|
||||
|
||||
|
|
@ -59,4 +59,4 @@ concept all = std::is_same_v<T, rt_16x16> ||
|
|||
template<> struct transpose<rt_16x32_4> { using type = rt_32x16_4; };
|
||||
} // namespace rt_shape
|
||||
} // namespace ducks
|
||||
} // namespace kittens
|
||||
} // namespace kittens
|
||||
|
|
|
|||
|
|
@ -90,5 +90,6 @@ concept all = requires {
|
|||
template<size_t _length> using sv_bf = sv<bf16, _length>;
|
||||
template<size_t _length> using sv_hf = sv<half, _length>;
|
||||
template<size_t _length> using sv_fl = sv<float, _length>;
|
||||
template<size_t _length> using sv_fp8e4m3 = sv<fp8e4m3, _length>;
|
||||
|
||||
} // namespace kittens
|
||||
Loading…
Add table
Add a link
Reference in a new issue