update hipkittens (#16544)

This commit is contained in:
wozeparrot 2026-06-08 21:53:25 -04:00 committed by GitHub
commit 5ef30005fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1010 additions and 126 deletions

View file

@ -133,7 +133,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
constexpr int neg_inf_v = 29;
// Move -inf to VGPR neg_inf_v
kittens::macros::clobber_gpr<neg_inf_v>();
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000);
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000);
art<float, DOT_SLICE_QO, WARP_SIZE_KV, col_l, rt_16x16_s, P_ranges> P_ij; // 16 registers
art<float, DOT_SLICE_QO, WARP_SIZE_KV, col_l, rt_16x16_s, dP_ranges> dP_ij; // 16 registers
@ -330,7 +330,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 0
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -588,7 +588,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 1
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -845,7 +845,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 2
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -1101,7 +1101,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 3
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -1371,7 +1371,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 0
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -1632,7 +1632,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 1
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -1889,7 +1889,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 2
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -2145,7 +2145,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 3
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -2410,7 +2410,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 0
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -2671,7 +2671,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 1
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -2927,7 +2927,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 2
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);
@ -3183,7 +3183,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
load<0, 3>(dO_i_col, subtile_inplace<DOT_SLICE_QO, D>(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr);
mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij);
// Dot slice 3
kittens::macros::v_mov_b32<neg_inf_v>(0xff800000); if constexpr (causal) {
kittens::macros::v_mov_b32_up2p<neg_inf_v>(0xff800000); if constexpr (causal) {
// If the query position is less than the key position, set P_ij to -inf
if (q_pos < k_pos) {
mov<neg_inf_v>(P_ij);

View file

@ -216,6 +216,59 @@ template<> __device__ inline bf16_2 relu::op<bf16_2>(const bf16_2 &x) { return _
template<> __device__ inline half relu::op<half> (const half &x ) { return __hmax(x, base_types::constants<half>::zero()); }
template<> __device__ inline half_2 relu::op<half_2>(const half_2 &x) { return half_2{__hmax(x.x, base_types::constants<half>::zero()),
__hmax(x.y, base_types::constants<half>::zero())}; }
// Constants for the tanh-based GELU approximation used by gelu / dgelu below.
// sqrt(2/pi): the scale applied to the tanh argument.
constexpr float SQRT_2_OVER_PI = 0.7978845608028654f;
// Coefficient of the cubic term in x + 0.044715 * x^3.
constexpr float GELU_COEFF = 0.044715f;
// Cubic coefficient pre-multiplied by sqrt(2/pi), so gelu can write the tanh argument as x * (a + b * x^2).
constexpr float GELU_INNER_COEFF = GELU_COEFF * SQRT_2_OVER_PI;
// 3 * 0.044715 * sqrt(2/pi): the x^2 coefficient of d/dx [sqrt(2/pi) * (x + 0.044715 * x^3)], used by dgelu.
constexpr float DGELU_COEFF = 3.0f * GELU_COEFF * SQRT_2_OVER_PI;
/**
 * @brief Approximate tanh(x) built on the hardware exp2 builtin.
 *
 * The input is clamped to [-20, 20], e^(2x) is obtained as exp2(x * 2/ln 2),
 * and the result is (e^(2x) - 1) * rcp(e^(2x) + 1).
 *
 * @param x[in] The input value.
 * @return The approximate hyperbolic tangent of x.
 */
static __device__ inline float fast_tanh(float x) {
    // Clamp first so exp2 neither overflows nor underflows.
    const float clamped = fmaxf(fminf(x, 20.f), -20.f);
    // 2.8853900817779268 == 2 / ln(2), so exp2(clamped * k) == e^(2 * clamped).
    const float exp_2x = __builtin_amdgcn_exp2f(clamped * 2.8853900817779268f);
    const float numerator = exp_2x - 1.0f;
    const float inv_denominator = __frcp_rn(exp_2x + 1.0f);
    return numerator * inv_denominator;
}
/**
 * @brief Gaussian Error Linear Unit (GELU) activation, tanh approximation.
 *
 * Evaluates x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with the
 * tanh supplied by fast_tanh. Specialized for float and float2 (element-wise).
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The GELU activation applied to the input.
 */
struct gelu {
    template<typename T> static __device__ inline T op(const T &x);
};
template<> __device__ inline float gelu::op<float>(const float &x) {
    // sqrt(2/pi) * (x + c * x^3), factored as x * (sqrt(2/pi) + c * sqrt(2/pi) * x^2).
    const float tanh_arg = x * (SQRT_2_OVER_PI + GELU_INNER_COEFF * x * x);
    const float gate = 0.5f + 0.5f * fast_tanh(tanh_arg);
    return x * gate;
}
template<> __device__ inline float2 gelu::op<float2>(const float2 &x) {
    const float first  = gelu::op<float>(x.x);
    const float second = gelu::op<float>(x.y);
    return float2{first, second};
}
/**
 * @brief Derivative of the tanh-approximated GELU activation.
 *
 * With t = tanh(sqrt(2/pi) * (x + 0.044715 * x^3)), returns
 * 0.5 * x * (1 - t^2) * (sqrt(2/pi) + 3 * 0.044715 * sqrt(2/pi) * x^2) + 0.5 * (1 + t).
 * Specialized for float and float2 (element-wise).
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The derivative of GELU evaluated at the input.
 */
struct dgelu {
    template<typename T> static __device__ inline T op(const T &x);
};
template<> __device__ inline float dgelu::op<float>(const float &x) {
    const float t = fast_tanh(SQRT_2_OVER_PI * x * (1.f + GELU_COEFF * x * x));
    // Derivative of tanh at the inner argument, and of the inner argument itself.
    const float one_minus_t_sq = 1.f - t * t;
    const float inner_slope = SQRT_2_OVER_PI + DGELU_COEFF * x * x;
    const float tanh_term = 0.5f * x * (one_minus_t_sq * inner_slope);
    const float gate_term = 0.5f * (1.f + t);
    return tanh_term + gate_term;
}
template<> __device__ inline float2 dgelu::op<float2>(const float2 &x) {
    const float first  = dgelu::op<float>(x.x);
    const float second = dgelu::op<float>(x.y);
    return float2{first, second};
}
/**
* @brief Copy operation.
*

View file

@ -10,14 +10,16 @@
#pragma once
#include <hip_bf16.h>
#include <hip_fp16.h>
#include <hip_fp8.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_fp4.h>
#include <hip/amd_detail/amd_hip_ocp_types.h>
#include <hip/hip_runtime.h>
#include <string>
#include <bit>
typedef uint32_t __amd_fp8x4_storage_t;
namespace kittens {
@ -37,7 +39,6 @@ using bf16_2 = __hip_bfloat162;
* @brief Packed word of two half-precision floating-point values.
*/
using half_2 = __half2;
#ifdef KITTENS_CDNA4
/**
* @brief float8 floating-point type.
*/
@ -50,20 +51,30 @@ using fp8e4m3_2 = __hip_fp8x2_e4m3;
* @brief Packed word of four float8 floating-point values.
*/
using fp8e4m3_4 = __hip_fp8x4_e4m3;
#else
/**
* @brief float8 floating-point type.
* @brief 8-bit exponent-only block-scaling scale type.
*/
using fp8e4m3 = __hip_fp8_e4m3_fnuz;
using fp8e8m0 = __amd_scale_t;
/**
* @brief Packed word of two float8 floating-point values.
* @brief Packed word of two 8-bit exponent-only block-scaling scale values.
*/
using fp8e4m3_2 = __hip_fp8x2_e4m3_fnuz;
using fp8e8m0_2 = __amd_fp8x2_storage_t;
/**
* @brief Packed word of four float8 floating-point values.
* @brief Packed word of four 8-bit exponent-only block-scaling scale values.
*/
using fp8e4m3_4 = __hip_fp8x4_e4m3_fnuz;
#endif
using fp8e8m0_4 = __amd_fp8x4_storage_t;
/**
* @brief FP4 E2M1 floating-point type.
*/
using fp4e2m1 = __hip_fp4_e2m1;
/**
* @brief Packed word of two FP4 E2M1 floating-point values.
*/
using fp4e2m1_2 = __hip_fp4x2_e2m1;
/**
* @brief Packed word of four FP4 E2M1 floating-point values.
*/
using fp4e2m1_4 = __hip_fp4x4_e2m1;
namespace ducks {
/**
@ -74,9 +85,11 @@ namespace ducks {
namespace base_types {
template<typename T>
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4>;
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4>
|| std::is_same_v<T, fp4e2m1_4>;
template<typename T>
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3>;
concept T1 = std::is_same_v<T, float> || std::is_same_v<T, bf16 > || std::is_same_v<T, half> || std::is_same_v<T, fp8e4m3>
|| std::is_same_v<T, fp4e2m1>;
} // namespace base_types
} // namespace ducks
@ -157,6 +170,26 @@ template<> struct constants<fp8e4m3_4> {
static __device__ inline constexpr fp8e4m3_4 zero() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x00000000)); }
static __device__ inline constexpr fp8e4m3_4 one() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x38383838)); }
};
/**
 * @brief zero()/one() bit patterns for the E8M0 scale types and FP4 E2M1 types.
 *
 * E8M0 has no encoding for zero, so zero() returns the all-zero bit pattern,
 * which is not actually the value 0. one() uses exponent byte 0x7F.
 * For FP4 E2M1 the nibble 0x2 is used for one().
 */
template<> struct constants<fp8e8m0> {
    // All-zero bits; not actually 0.
    static __device__ inline constexpr fp8e8m0 zero() {
        return std::bit_cast<fp8e8m0>(static_cast<uint8_t>(0x00));
    }
    static __device__ inline constexpr fp8e8m0 one() {
        return std::bit_cast<fp8e8m0>(static_cast<uint8_t>(0x7F));
    }
};
template<> struct constants<fp8e8m0_2> {
    // All-zero bits; not actually 0.
    static __device__ inline constexpr fp8e8m0_2 zero() {
        return std::bit_cast<fp8e8m0_2>(static_cast<uint16_t>(0x0000));
    }
    static __device__ inline constexpr fp8e8m0_2 one() {
        return std::bit_cast<fp8e8m0_2>(static_cast<uint16_t>(0x7F7F));
    }
};
template<> struct constants<fp8e8m0_4> {
    // All-zero bits; not actually 0.
    static __device__ inline constexpr fp8e8m0_4 zero() {
        return std::bit_cast<fp8e8m0_4>(static_cast<uint32_t>(0x00000000));
    }
    static __device__ inline constexpr fp8e8m0_4 one() {
        return std::bit_cast<fp8e8m0_4>(static_cast<uint32_t>(0x7F7F7F7F));
    }
};
template<> struct constants<fp4e2m1> {
    static __device__ inline constexpr fp4e2m1 zero() {
        return std::bit_cast<fp4e2m1>(static_cast<uint8_t>(0x00));
    }
    static __device__ inline constexpr fp4e2m1 one() {
        return std::bit_cast<fp4e2m1>(static_cast<uint8_t>(0x02));
    }
};
template<> struct constants<fp4e2m1_4> {
    static __device__ inline constexpr fp4e2m1_4 zero() {
        return std::bit_cast<fp4e2m1_4>(static_cast<uint16_t>(0x0000));
    }
    static __device__ inline constexpr fp4e2m1_4 one() {
        return std::bit_cast<fp4e2m1_4>(static_cast<uint16_t>(0x2222));
    }
};
template<> struct constants<int> {
static __device__ inline constexpr int zero() { return 0; }
static __device__ inline constexpr int ones() { return 1; }
@ -250,6 +283,26 @@ template<> struct packing<fp8e4m3_4> {
using unpacked_type = fp8e4m3;
using packed_type = fp8e4m3_4;
};
/**
 * @brief Packing traits for the E8M0 scale and FP4 E2M1 types.
 *
 * Scalar types report num() == 1 and the four-wide types report num() == 4;
 * each pair shares the same unpacked_type / packed_type aliases.
 */
template<> struct packing<fp8e8m0> {
    using unpacked_type = fp8e8m0;
    using packed_type   = fp8e8m0_4;
    static __host__ __device__ inline constexpr int num() { return 1; }
};
template<> struct packing<fp8e8m0_4> {
    using unpacked_type = fp8e8m0;
    using packed_type   = fp8e8m0_4;
    static __host__ __device__ inline constexpr int num() { return 4; }
};
template<> struct packing<fp4e2m1> {
    using unpacked_type = fp4e2m1;
    using packed_type   = fp4e2m1_4;
    static __host__ __device__ inline constexpr int num() { return 1; }
};
template<> struct packing<fp4e2m1_4> {
    using unpacked_type = fp4e2m1;
    using packed_type   = fp4e2m1_4;
    static __host__ __device__ inline constexpr int num() { return 4; }
};
/**
* @brief Provides templated functionality to convert between different types.
@ -377,5 +430,25 @@ template<> struct convertor<float, fp8e4m3> {
return float(u);
}
};
/**
 * @brief Conversions between float / float4 and FP4 E2M1 / packed FP4x4.
 *
 * Each specialization forwards to the conversion the HIP FP4 type itself
 * provides; no rounding or saturation logic lives here.
 */
template<> struct convertor<fp4e2m1, float> {
    static __host__ __device__ inline fp4e2m1 convert(const float & u) {
        return static_cast<fp4e2m1>(u);
    }
};
template<> struct convertor<float, fp4e2m1> {
    static __host__ __device__ inline float convert(const fp4e2m1 & u) {
        return static_cast<float>(u);
    }
};
template<> struct convertor<fp4e2m1_4, float4> {
    static __host__ __device__ inline fp4e2m1_4 convert(const float4& u) {
        return static_cast<fp4e2m1_4>(u);
    }
};
template<> struct convertor<float4, fp4e2m1_4> {
    static __host__ __device__ inline float4 convert(const fp4e2m1_4& u) {
        return static_cast<float4>(u);
    }
};
}
}

View file

@ -158,152 +158,405 @@ __device__ __forceinline__ void clobber_gpr() {
#undef CLOBBER_AREG_CASE
#undef CLOBBER_VREG_CASE
template<int GPR_START>
__device__ __forceinline__ void ds_read_b128(const uint32_t smem_ptr, const int offset) {
/**
 * @brief Largest immediate offset accepted by a non-paired DS instruction.
 *
 * DS encodings carry two 8-bit instruction-offset fields. Non-paired
 * instructions such as ds_read_b32 treat the two fields as one 16-bit offset;
 * paired instructions such as ds_read2_b32 are limited to 8 bits per offset
 * (see max_ds_pk2_inst_offset).
 *
 * @return 2^16 - 1.
 */
__device__ __forceinline__ constexpr uint32_t max_ds_inst_offset()
{
    constexpr uint32_t offset_bits = 16;
    return (1u << offset_bits) - 1;
}
/**
 * @brief Largest immediate offset accepted by a paired (pk2) DS instruction.
 *
 * DS encodings carry two 8-bit instruction-offset fields. Paired instructions
 * such as ds_read2_b32 use one field per access, so each offset is limited to
 * 8 bits; non-paired instructions such as ds_read_b32 treat both fields as a
 * single offset (see max_ds_inst_offset).
 *
 * @return 2^8 - 1.
 */
__device__ __forceinline__ constexpr uint32_t max_ds_pk2_inst_offset()
{
    constexpr uint32_t offset_bits = 8;
    return (1u << offset_bits) - 1;
}
/**
 * @brief Largest immediate offset accepted by a MUBUF (buffer_*) instruction.
 *
 * MUBUF encodings carry a single 12-bit instruction-offset field.
 *
 * @return 2^12 - 1.
 */
__device__ __forceinline__ constexpr uint32_t max_mubuf_inst_offset()
{
    constexpr uint32_t offset_bits = 12;
    return (1u << offset_bits) - 1;
}
/**
 * @brief Emit ds_read_b32 into an explicitly numbered register.
 *
 * GPR_START in [0, 256) names VGPR v[GPR_START]; GPR_START >= 256 names AGPR
 * a[GPR_START - 256]. The destination appears only in the asm text, not as an
 * output operand, so the compiler is not told that the register is written.
 * NOTE(review): presumably the caller reserves the register beforehand (e.g.
 * with clobber_gpr) and issues its own s_waitcnt — confirm.
 *
 * @tparam GPR_START Destination register index (see above).
 * @param smem_ptr[in] Shared-memory address operand, passed in a VGPR.
 * @param i_offset[in] Immediate offset; bound with the "i" constraint, so it
 *                     must fold to a compile-time constant at the call site.
 */
template<int GPR_START>
__device__ __forceinline__ void ds_read_b32(const uint32_t smem_ptr, const int i_offset) {
    if constexpr (GPR_START < 256) {
        // VGPR destination.
        asm volatile("ds_read_b32 v[%0], %1 offset:%2"
                     :
                     : "n"(GPR_START), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    } else {
        // AGPR destination.
        asm volatile("ds_read_b32 a[%0], %1 offset:%2"
                     :
                     : "n"(GPR_START - 256), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    }
}
// Emits ds_read_b32 from shared-memory address smem_ptr (plus immediate i_offset) into dst.
// T must be exactly 32 bits wide; the compiler chooses the destination VGPR via the "=v" output constraint.
// i_offset is bound with the "i" constraint, so it must fold to a compile-time constant at the call site.
// NOTE(review): the asm contains no s_waitcnt; presumably the caller must wait for the read to complete
// before consuming dst — confirm.
template <typename T>
__device__ __forceinline__ void ds_read_b32(T& dst, const uint32_t smem_ptr, const int i_offset) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("ds_read_b32 %0, %1 offset:%2"
: "=v"(dst)
: "v"(smem_ptr), "i"(i_offset)
: "memory");
}
/**
 * @brief Emit ds_read_b64 and return the loaded 64-bit value.
 *
 * The compiler allocates the destination VGPRs ("=v" output constraint).
 * NOTE(review): the asm contains no s_waitcnt; presumably the caller must
 * wait for the read to complete before consuming the result — confirm.
 *
 * @tparam T Result type; must be exactly 8 bytes.
 * @param smem_ptr[in] Shared-memory address operand, passed in a VGPR.
 * @param i_offset[in] Immediate offset; must fold to a compile-time constant.
 * @return The value read.
 */
template <typename T = u32x2>
__device__ __forceinline__ T ds_read_b64(const uint32_t smem_ptr, const int i_offset) {
    static_assert(2 * sizeof(uint32_t) == sizeof(T));
    T loaded;
    asm volatile("ds_read_b64 %0, %1 offset:%2"
                 : "=v"(loaded)
                 : "v"(smem_ptr), "i"(i_offset)
                 : "memory");
    return loaded;
}
/**
 * @brief Emit ds_read_b64 into an explicitly numbered register pair.
 *
 * GPR_START in [0, 256) targets v[GPR_START:GPR_START+1]; GPR_START >= 256
 * targets a[GPR_START-256:GPR_START-255]. The destination is named only in
 * the asm text, so the compiler is not told that those registers are written.
 *
 * @tparam GPR_START Index of the first destination register.
 * @param smem_ptr[in] Shared-memory address operand, passed in a VGPR.
 * @param i_offset[in] Immediate offset; must fold to a compile-time constant.
 */
template<int GPR_START>
__device__ __forceinline__ void ds_read_b64(const uint32_t smem_ptr, const int i_offset) {
    if constexpr (GPR_START < 256) {
        // VGPR pair.
        constexpr int FIRST = GPR_START;
        constexpr int LAST  = GPR_START + 1;
        asm volatile("ds_read_b64 v[%0:%1], %2 offset:%3"
                     :
                     : "n"(FIRST), "n"(LAST), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    } else {
        // AGPR pair.
        constexpr int FIRST = GPR_START - 256;
        constexpr int LAST  = GPR_START + 1 - 256;
        asm volatile("ds_read_b64 a[%0:%1], %2 offset:%3"
                     :
                     : "n"(FIRST), "n"(LAST), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    }
}
/**
 * @brief Emit ds_read_b64_tr_b4 into an explicitly numbered register pair.
 *
 * NOTE(review): presumably the transposing 64-bit LDS read for 4-bit
 * elements — confirm against the ISA guide.
 * GPR_START in [0, 256) targets v[GPR_START:GPR_START+1]; GPR_START >= 256
 * targets the AGPR pair starting at a[GPR_START-256].
 *
 * @tparam GPR_START Index of the first destination register.
 * @param smem_ptr[in] Shared-memory address operand, passed in a VGPR.
 * @param i_offset[in] Immediate offset; must fold to a compile-time constant.
 */
template<int GPR_START>
__device__ __forceinline__ void ds_read_b64_tr_b4(const uint32_t smem_ptr, const int i_offset) {
    if constexpr (GPR_START < 256) {
        constexpr int FIRST = GPR_START;
        constexpr int LAST  = GPR_START + 1;
        asm volatile("ds_read_b64_tr_b4 v[%0:%1], %2 offset:%3"
                     :
                     : "n"(FIRST), "n"(LAST), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    } else {
        constexpr int FIRST = GPR_START - 256;
        constexpr int LAST  = GPR_START + 1 - 256;
        asm volatile("ds_read_b64_tr_b4 a[%0:%1], %2 offset:%3"
                     :
                     : "n"(FIRST), "n"(LAST), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    }
}
/**
 * @brief Emit ds_read_b64_tr_b8 into an explicitly numbered register pair.
 *
 * NOTE(review): presumably the transposing 64-bit LDS read for 8-bit
 * elements — confirm against the ISA guide.
 * GPR_START in [0, 256) targets v[GPR_START:GPR_START+1]; GPR_START >= 256
 * targets the AGPR pair starting at a[GPR_START-256].
 *
 * @tparam GPR_START Index of the first destination register.
 * @param smem_ptr[in] Shared-memory address operand, passed in a VGPR.
 * @param i_offset[in] Immediate offset; must fold to a compile-time constant.
 */
template<int GPR_START>
__device__ __forceinline__ void ds_read_b64_tr_b8(const uint32_t smem_ptr, const int i_offset) {
    if constexpr (GPR_START < 256) {
        constexpr int FIRST = GPR_START;
        constexpr int LAST  = GPR_START + 1;
        asm volatile("ds_read_b64_tr_b8 v[%0:%1], %2 offset:%3"
                     :
                     : "n"(FIRST), "n"(LAST), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    } else {
        constexpr int FIRST = GPR_START - 256;
        constexpr int LAST  = GPR_START + 1 - 256;
        asm volatile("ds_read_b64_tr_b8 a[%0:%1], %2 offset:%3"
                     :
                     : "n"(FIRST), "n"(LAST), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    }
}
/**
 * @brief Emit ds_read_b64_tr_b16 into an explicitly numbered register pair.
 *
 * NOTE(review): presumably the transposing 64-bit LDS read for 16-bit
 * elements — confirm against the ISA guide.
 * GPR_START in [0, 256) targets v[GPR_START:GPR_START+1]; GPR_START >= 256
 * targets the AGPR pair starting at a[GPR_START-256].
 *
 * @tparam GPR_START Index of the first destination register.
 * @param smem_ptr[in] Shared-memory address operand, passed in a VGPR.
 * @param i_offset[in] Immediate offset; must fold to a compile-time constant.
 */
template<int GPR_START>
__device__ __forceinline__ void ds_read_b64_tr_b16(const uint32_t smem_ptr, const int i_offset) {
    if constexpr (GPR_START < 256) {
        constexpr int FIRST = GPR_START;
        constexpr int LAST  = GPR_START + 1;
        asm volatile("ds_read_b64_tr_b16 v[%0:%1], %2 offset:%3"
                     :
                     : "n"(FIRST), "n"(LAST), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    } else {
        constexpr int FIRST = GPR_START - 256;
        constexpr int LAST  = GPR_START + 1 - 256;
        asm volatile("ds_read_b64_tr_b16 a[%0:%1], %2 offset:%3"
                     :
                     : "n"(FIRST), "n"(LAST), "v"(smem_ptr), "i"(i_offset)
                     : "memory");
    }
}
/**
 * @brief Emit ds_read_b128 and return the loaded 128-bit value.
 *
 * The compiler allocates the destination VGPRs ("=v" output constraint).
 * NOTE(review): the asm contains no s_waitcnt; presumably the caller must
 * wait for the read to complete before consuming the result — confirm.
 *
 * @tparam T Result type; must be exactly 16 bytes.
 * @param smem_ptr[in] Shared-memory address operand, passed in a VGPR.
 * @param i_offset[in] Immediate offset; must fold to a compile-time constant.
 * @return The value read.
 */
template <typename T = u32x4>
__device__ __forceinline__ T ds_read_b128(const uint32_t smem_ptr, const int i_offset) {
    static_assert(4 * sizeof(uint32_t) == sizeof(T));
    T loaded;
    asm volatile("ds_read_b128 %0, %1 offset:%2"
                 : "=v"(loaded)
                 : "v"(smem_ptr), "i"(i_offset)
                 : "memory");
    return loaded;
}
template<int GPR_START>
__device__ __forceinline__ void ds_read_b128(const uint32_t smem_ptr, const int i_offset) {
constexpr int GPR_END = GPR_START + 3;
// AGPRS
if constexpr (GPR_START >= 256) {
asm volatile("ds_read_b128 a[%0:%1], %2 offset:%3"
:
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(offset)
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(i_offset)
: "memory");
// VGPRS
} else {
asm volatile("ds_read_b128 v[%0:%1], %2 offset:%3"
:
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(offset)
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(i_offset)
: "memory");
}
}
template<int GPR_START>
__device__ __forceinline__ void ds_read_b64_tr_b16(const uint32_t smem_ptr, const int offset) {
constexpr int GPR_END = GPR_START + 1;
__device__ __forceinline__ void ds_write_b32(const uint32_t smem_ptr, const int i_offset) {
if constexpr (GPR_START >= 256) {
asm volatile("ds_read_b64_tr_b16 a[%0:%1], %2 offset:%3"
asm volatile("ds_write_b32 %0, a[%1], offset:%2"
:
: "n"(GPR_START - 256), "n"(GPR_END - 256), "v"(smem_ptr), "i"(offset)
: "v"(smem_ptr), "n"(GPR_START - 256), "i"(i_offset)
: "memory");
} else {
asm volatile("ds_read_b64_tr_b16 v[%0:%1], %2 offset:%3"
asm volatile("ds_write_b32 %0, v[%1], offset:%2"
:
: "n"(GPR_START), "n"(GPR_END), "v"(smem_ptr), "i"(offset)
: "v"(smem_ptr), "n"(GPR_START), "i"(i_offset)
: "memory");
}
}
// Emits ds_write_b32, storing the 32-bit value val at shared-memory address smem_ptr (plus immediate i_offset).
// T must be exactly 32 bits wide; val and smem_ptr are both passed in compiler-chosen VGPRs ("v" constraints).
// i_offset defaults to 0 and is bound with the "i" constraint, so it must fold to a compile-time constant.
template <typename T>
__device__ __forceinline__ void ds_write_b32(const T& val, const uint32_t smem_ptr, const int i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("ds_write_b32 %0, %1 offset:%2"
:
: "v"(smem_ptr), "v"(val), "i"(i_offset)
: "memory");
}
template<int GPR_START>
__device__ __forceinline__ void ds_write_b64(const uint32_t smem_ptr, const int offset) {
__device__ __forceinline__ void ds_write_b64(const uint32_t smem_ptr, const int i_offset) {
if constexpr (GPR_START >= 256) {
asm volatile("ds_write_b64 %0, a[%1:%2], offset:%3"
:
: "v"(smem_ptr), "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "i"(offset)
: "v"(smem_ptr), "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "i"(i_offset)
: "memory");
} else {
asm volatile("ds_write_b64 %0, v[%1:%2], offset:%3"
:
: "v"(smem_ptr), "n"(GPR_START), "n"(GPR_START + 1), "i"(offset)
: "v"(smem_ptr), "n"(GPR_START), "n"(GPR_START + 1), "i"(i_offset)
: "memory");
}
}
// Emits ds_write_b64, storing the 64-bit value val at shared-memory address smem_ptr (plus immediate i_offset).
// T must be exactly 8 bytes; val and smem_ptr are both passed in compiler-chosen VGPRs ("v" constraints).
// i_offset defaults to 0 and is bound with the "i" constraint, so it must fold to a compile-time constant.
template <typename T>
__device__ __forceinline__ void ds_write_b64(const T& val, const uint32_t smem_ptr, const int i_offset = 0) {
static_assert(sizeof(T) == 2 * sizeof(uint32_t));
asm volatile("ds_write_b64 %0, %1 offset:%2"
:
: "v"(smem_ptr), "v"(val), "i"(i_offset)
: "memory");
}
/**
 * @brief Emit ds_write_b128 from four explicitly numbered registers.
 *
 * GPR_START in [0, 256) sources v[GPR_START:GPR_START+3]; GPR_START >= 256
 * sources the AGPR quad starting at a[GPR_START-256]. The source registers
 * are named only in the asm text, not passed as operands.
 *
 * @tparam GPR_START Index of the first source register.
 * @param smem_ptr[in] Shared-memory address operand, passed in a VGPR.
 * @param i_offset[in] Immediate offset (default 0); must fold to a
 *                     compile-time constant.
 */
template<int GPR_START>
__device__ __forceinline__ void ds_write_b128(const uint32_t smem_ptr, const int i_offset = 0) {
    if constexpr (GPR_START < 256) {
        // VGPR source.
        asm volatile("ds_write_b128 %0, v[%1:%2], offset:%3"
                     :
                     : "v"(smem_ptr), "n"(GPR_START), "n"(GPR_START + 3), "i"(i_offset)
                     : "memory");
    } else {
        // AGPR source.
        asm volatile("ds_write_b128 %0, a[%1:%2], offset:%3"
                     :
                     : "v"(smem_ptr), "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "i"(i_offset)
                     : "memory");
    }
}
// Emits ds_write_b128, storing the 128-bit value at shared-memory address smem_ptr (plus immediate i_offset).
// T must have the same size as u32x4; value and smem_ptr are passed in compiler-chosen VGPRs ("v" constraints).
// i_offset defaults to 0 and is bound with the "i" constraint, so it must fold to a compile-time constant.
template<typename T>
__device__ __forceinline__ void ds_write_b128(const T& value, const uint32_t smem_ptr, const int i_offset = 0) {
static_assert(sizeof(T) == sizeof(u32x4));
asm volatile("ds_write_b128 %0, %1 offset:%2"
:
: "v"(smem_ptr), "v"(value), "i"(i_offset)
: "memory");
}
/**
 * @brief Emit buffer_load_dword (offen) into an explicitly numbered register.
 *
 * GPR_START in [0, 256) targets v[GPR_START]; GPR_START >= 256 targets
 * a[GPR_START-256]. The destination is named only in the asm text, so the
 * compiler is not told that the register is written.
 *
 * @tparam GPR_START Destination register index.
 * @param br[in] Buffer resource; reinterpreted as i32x4 and passed in SGPRs.
 * @param v_offset[in] Per-lane offset, passed in a VGPR.
 * @param s_offset[in] Scalar offset, passed in an SGPR (default 0).
 * @param i_offset[in] Immediate offset (default 0); must fold to a
 *                     compile-time constant.
 */
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
    if constexpr (GPR_START < 256) {
        // VGPR destination.
        asm volatile("buffer_load_dword v[%0], %1, %2, %3 offen offset:%4"
                     :
                     : "n"(GPR_START), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                     : "memory");
    } else {
        // AGPR destination.
        asm volatile("buffer_load_dword a[%0], %1, %2, %3 offen offset:%4"
                     :
                     : "n"(GPR_START - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                     : "memory");
    }
}
/**
 * @brief Emit buffer_load_dword (offen) and return the loaded 32-bit value.
 *
 * The compiler allocates the destination VGPR ("=v" output constraint).
 * NOTE(review): the asm contains no s_waitcnt; presumably the caller must
 * wait for the load to complete before consuming the result — confirm.
 *
 * @tparam T Result type; must be exactly 4 bytes.
 * @param br[in] Buffer resource; reinterpreted as i32x4 and passed in SGPRs.
 * @param v_offset[in] Per-lane offset, passed in a VGPR.
 * @param s_offset[in] Scalar offset, passed in an SGPR (default 0).
 * @param i_offset[in] Immediate offset (default 0); must fold to a
 *                     compile-time constant.
 * @return The value loaded.
 */
template<typename T = uint32_t>
__device__ __forceinline__ T buffer_load_dword(
    const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
    static_assert(sizeof(uint32_t) == sizeof(T));
    T loaded;
    asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
                 : "=v"(loaded)
                 : "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                 : "memory");
    return loaded;
}
/**
 * @brief Emit buffer_load_dwordx2 (offen) into an explicitly numbered register pair.
 *
 * GPR_START in [0, 256) targets v[GPR_START:GPR_START+1]; GPR_START >= 256
 * targets the AGPR pair starting at a[GPR_START-256]. The destination is
 * named only in the asm text, so the compiler is not told that those
 * registers are written.
 *
 * @tparam GPR_START Index of the first destination register.
 * @param br[in] Buffer resource; reinterpreted as i32x4 and passed in SGPRs.
 * @param v_offset[in] Per-lane offset, passed in a VGPR.
 * @param s_offset[in] Scalar offset, passed in an SGPR (default 0).
 * @param i_offset[in] Immediate offset (default 0); must fold to a
 *                     compile-time constant.
 */
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx2(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
    if constexpr (GPR_START < 256) {
        // VGPR destination.
        asm volatile("buffer_load_dwordx2 v[%0:%1], %2, %3, %4 offen offset:%5"
                     :
                     : "n"(GPR_START), "n"(GPR_START + 1), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                     : "memory");
    } else {
        // AGPR destination.
        asm volatile("buffer_load_dwordx2 a[%0:%1], %2, %3, %4 offen offset:%5"
                     :
                     : "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                     : "memory");
    }
}
/**
 * @brief Emit buffer_load_dwordx2 (offen) and return the loaded 64-bit value.
 *
 * The compiler allocates the destination VGPRs ("=v" output constraint).
 * NOTE(review): the asm contains no s_waitcnt; presumably the caller must
 * wait for the load to complete before consuming the result — confirm.
 *
 * @tparam T Result type; must be exactly 8 bytes.
 * @param br[in] Buffer resource; reinterpreted as i32x4 and passed in SGPRs.
 * @param v_offset[in] Per-lane offset, passed in a VGPR.
 * @param s_offset[in] Scalar offset, passed in an SGPR (default 0).
 * @param i_offset[in] Immediate offset (default 0); must fold to a
 *                     compile-time constant.
 * @return The value loaded.
 */
template<typename T = u32x2>
__device__ __forceinline__ T buffer_load_dwordx2(
    const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
    static_assert(2 * sizeof(uint32_t) == sizeof(T));
    T loaded;
    asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
                 : "=v"(loaded)
                 : "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                 : "memory");
    return loaded;
}
/**
 * @brief Emit buffer_load_dwordx4 (offen) into four explicitly numbered registers.
 *
 * GPR_START in [0, 256) targets v[GPR_START:GPR_START+3]; GPR_START >= 256
 * targets the AGPR quad starting at a[GPR_START-256]. The destination is
 * named only in the asm text, so the compiler is not told that those
 * registers are written.
 *
 * @tparam GPR_START Index of the first destination register.
 * @param br[in] Buffer resource; reinterpreted as i32x4 and passed in SGPRs.
 * @param v_offset[in] Per-lane offset, passed in a VGPR.
 * @param s_offset[in] Scalar offset, passed in an SGPR (default 0).
 * @param i_offset[in] Immediate offset (default 0); must fold to a
 *                     compile-time constant.
 */
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
    if constexpr (GPR_START < 256) {
        // VGPR destination.
        asm volatile("buffer_load_dwordx4 v[%0:%1], %2, %3, %4 offen offset:%5"
                     :
                     : "n"(GPR_START), "n"(GPR_START + 3), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                     : "memory");
    } else {
        // AGPR destination.
        asm volatile("buffer_load_dwordx4 a[%0:%1], %2, %3, %4 offen offset:%5"
                     :
                     : "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                     : "memory");
    }
}
/**
 * @brief Emit buffer_load_dwordx4 (offen) and return the loaded 128-bit value.
 *
 * The compiler allocates the destination VGPRs ("=v" output constraint).
 * NOTE(review): the asm contains no s_waitcnt; presumably the caller must
 * wait for the load to complete before consuming the result — confirm.
 *
 * @tparam T Result type; must be exactly 16 bytes.
 * @param br[in] Buffer resource; reinterpreted as i32x4 and passed in SGPRs.
 * @param v_offset[in] Per-lane offset, passed in a VGPR.
 * @param s_offset[in] Scalar offset, passed in an SGPR (default 0).
 * @param i_offset[in] Immediate offset (default 0); must fold to a
 *                     compile-time constant.
 * @return The value loaded.
 */
template<typename T = u32x4>
__device__ __forceinline__ T buffer_load_dwordx4(
    const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
    static_assert(4 * sizeof(uint32_t) == sizeof(T));
    T loaded;
    asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
                 : "=v"(loaded)
                 : "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                 : "memory");
    return loaded;
}
template<int GPR>
__device__ __forceinline__ void buffer_store_dword(buffer_resource& br, const uint32_t byte_offset) {
__device__ __forceinline__ void buffer_store_dword(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
if constexpr (GPR >= 256) {
asm volatile("buffer_store_dword a[%0], %1, %2, 0 offen"
asm volatile("buffer_store_dword a[%0], %1, %2, %3 offen offset:%4"
:
: "n"(GPR - 256), "v"(byte_offset), "s"(*(i32x4*)&br)
: "n"(GPR - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
// VGPRS
} else {
asm volatile("buffer_store_dword v[%0], %1, %2, 0 offen"
asm volatile("buffer_store_dword v[%0], %1, %2, %3 offen offset:%4"
:
: "n"(GPR), "v"(byte_offset), "s"(*(i32x4*)&br)
: "n"(GPR), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx2(buffer_resource& br, const uint32_t byte_offset) {
/**
 * @brief Emit buffer_store_dword (offen), storing one 32-bit value held in a
 *        compiler-allocated VGPR.
 *
 * buffer_store_dword writes a single dword, so the value operand must be
 * exactly 4 bytes. The size check and default type now match the typed
 * buffer_load_dword overload; the previous 8-byte requirement (copied from
 * the dwordx2 variant) rejected every valid 32-bit argument.
 *
 * @tparam T Value type; must be exactly 4 bytes.
 * @param value[in] The value to store, passed in a VGPR.
 * @param br[in] Buffer resource; reinterpreted as i32x4 and passed in SGPRs.
 * @param v_offset[in] Per-lane offset, passed in a VGPR.
 * @param s_offset[in] Scalar offset, passed in an SGPR (default 0).
 * @param i_offset[in] Immediate offset (default 0); bound with the "i"
 *                     constraint, so it must fold to a compile-time constant.
 */
template<typename T = uint32_t>
__device__ __forceinline__ void buffer_store_dword(
    const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
    static_assert(sizeof(T) == sizeof(uint32_t));
    asm volatile("buffer_store_dword %0, %1, %2, %3 offen offset:%4"
                 :
                 : "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
                 : "memory");
}
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx2(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
if constexpr (GPR_START >= 256) {
asm volatile("buffer_store_dwordx2 a[%0:%1], %2, %3, 0 offen"
asm volatile("buffer_store_dwordx2 a[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "v"(byte_offset), "s"(*(i32x4*)&br)
: "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
// VGPRS
} else {
asm volatile("buffer_store_dwordx2 v[%0:%1], %2, %3, 0 offen"
asm volatile("buffer_store_dwordx2 v[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START), "n"(GPR_START + 1), "v"(byte_offset), "s"(*(i32x4*)&br)
: "n"(GPR_START), "n"(GPR_START + 1), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx4(buffer_resource& br, const uint32_t byte_offset) {
// Emits buffer_store_dwordx2 (offen), storing the 64-bit value held in compiler-chosen VGPRs.
// T must be exactly 8 bytes. br is reinterpreted as i32x4 and passed in SGPRs; v_offset is passed in a VGPR,
// s_offset in an SGPR (default 0), and i_offset as an immediate (default 0) that must fold to a compile-time constant.
template<typename T = u32x2>
__device__ __forceinline__ void buffer_store_dwordx2(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t) * 2);
asm volatile("buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR_START>
__device__ __forceinline__ void buffer_store_dwordx4(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
// AGPRS
if constexpr (GPR_START >= 256) {
asm volatile("buffer_store_dwordx4 a[%0:%1], %2, %3, 0 offen"
asm volatile("buffer_store_dwordx4 a[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "v"(byte_offset), "s"(*(i32x4*)&br)
: "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
// VGPRS
} else {
asm volatile("buffer_store_dwordx4 v[%0:%1], %2, %3, 0 offen"
asm volatile("buffer_store_dwordx4 v[%0:%1], %2, %3, %4 offen offset:%5"
:
: "n"(GPR_START), "n"(GPR_START + 3), "v"(byte_offset), "s"(*(i32x4*)&br)
: "n"(GPR_START), "n"(GPR_START + 3), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx4(buffer_resource& br, const uint32_t byte_offset) {
if constexpr (GPR_START >= 256) {
asm volatile("buffer_load_dwordx4 a[%0:%1], %2, %3, 0 offen offset:%4"
:
: "n"(GPR_START - 256), "n"(GPR_START + 3 - 256), "v"(byte_offset), "s"(*(i32x4*)&br), "i"(0)
: "memory");
} else {
asm volatile("buffer_load_dwordx4 v[%0:%1], %2, %3, 0 offen offset:%4"
:
: "n"(GPR_START), "n"(GPR_START + 3), "v"(byte_offset), "s"(*(i32x4*)&br), "i"(0)
: "memory");
}
}
template<int GPR_START>
__device__ __forceinline__ void buffer_load_dwordx2(buffer_resource& br, const uint32_t byte_offset) {
if constexpr (GPR_START >= 256) {
asm volatile("buffer_load_dwordx2 a[%0:%1], %2, %3, 0 offen offset:%4"
:
: "n"(GPR_START - 256), "n"(GPR_START + 1 - 256), "v"(byte_offset), "s"(*(i32x4*)&br), "i"(0)
: "memory");
} else {
asm volatile("buffer_load_dwordx2 v[%0:%1], %2, %3, 0 offen offset:%4"
:
: "n"(GPR_START), "n"(GPR_START + 1), "v"(byte_offset), "s"(*(i32x4*)&br), "i"(0)
: "memory");
}
// Emits buffer_store_dwordx4 (offen), storing the 128-bit value held in compiler-chosen VGPRs.
// T must be exactly 16 bytes. br is reinterpreted as i32x4 and passed in SGPRs; v_offset is passed in a VGPR,
// s_offset in an SGPR (default 0), and i_offset as an immediate (default 0) that must fold to a compile-time constant.
template<typename T = u32x4>
__device__ __forceinline__ void buffer_store_dwordx4(
const T& value, const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const uint32_t i_offset = 0) {
static_assert(sizeof(T) == sizeof(uint32_t) * 4);
asm volatile("buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
:
: "v"(value), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
template<int GPR>
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(buffer_resource& br, const uint32_t byte_offset) {
__device__ __forceinline__ void buffer_atomic_pk_add_bf16(const buffer_resource& br, const uint32_t v_offset, const uint32_t s_offset = 0, const int i_offset = 0) {
if constexpr (GPR >= 256) {
asm volatile("buffer_atomic_pk_add_bf16 a[%0], %1, %2, 0 offen"
asm volatile("buffer_atomic_pk_add_bf16 a[%0], %1, %2, %3 offen offset:%4"
:
: "n"(GPR - 256), "v"(byte_offset), "s"(*(i32x4*)&br)
: "n"(GPR - 256), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
} else {
asm volatile("buffer_atomic_pk_add_bf16 v[%0], %1, %2, 0 offen"
asm volatile("buffer_atomic_pk_add_bf16 v[%0], %1, %2, %3 offen offset:%4"
:
: "n"(GPR), "v"(byte_offset), "s"(*(i32x4*)&br)
: "n"(GPR), "v"(v_offset), "s"(*(const i32x4*)&br), "s"(s_offset), "i"(i_offset)
: "memory");
}
}
@ -468,6 +721,75 @@ __device__ __forceinline__ void mfma_f32_32x32x16_bf16() {
}
}
/**
 * @brief Emits `v_mfma_f32_16x16x32_fp8_fp8` on explicitly numbered registers.
 *
 * Register numbers below 256 are printed as VGPRs (`v[...]`); numbers of 256 and above are
 * printed as AGPRs (`a[...]`) after subtracting 256. The register-file letter is part of the
 * asm string, so each of the 16 VGPR/AGPR combinations of (D, A, B, C) gets its own branch.
 *
 * The asm statements declare no outputs and no clobbers.
 * NOTE(review): this presumably relies on the caller having reserved these registers — confirm.
 *
 * @tparam GPR_START_A First register of operand A; the range [A, A + 1] is used.
 * @tparam GPR_START_B First register of operand B; the range [B, B + 1] is used.
 * @tparam GPR_START_C First register of the accumulator input C; the range [C, C + 3] is used.
 * @tparam GPR_START_D First register of the destination D; the range [D, D + 3] is used.
 */
template<int GPR_START_A, int GPR_START_B, int GPR_START_C, int GPR_START_D>
__device__ __forceinline__ void mfma_f32_16x16x32_fp8_fp8() {
if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B >= 256 && GPR_START_C >= 256) {
// D=a, A=a, B=a, C=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], a[%4:%5], a[%6:%7]"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B >= 256 && GPR_START_C < 256) {
// D=a, A=a, B=a, C=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], a[%4:%5], v[%6:%7]"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C), "n"(GPR_START_C + 3));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B < 256 && GPR_START_C >= 256) {
// D=a, A=a, B=v, C=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], v[%4:%5], a[%6:%7]"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B >= 256 && GPR_START_C >= 256) {
// D=a, A=v, B=a, C=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], a[%4:%5], a[%6:%7]"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B >= 256 && GPR_START_C >= 256) {
// D=v, A=a, B=a, C=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], a[%4:%5], a[%6:%7]"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B >= 256 && GPR_START_C < 256) {
// D=v, A=a, B=a, C=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], a[%4:%5], v[%6:%7]"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C), "n"(GPR_START_C + 3));
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B < 256 && GPR_START_C >= 256) {
// D=v, A=a, B=v, C=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], v[%4:%5], a[%6:%7]"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
} else if constexpr (GPR_START_D < 256 && GPR_START_A < 256 && GPR_START_B >= 256 && GPR_START_C >= 256) {
// D=v, A=v, B=a, C=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], a[%4:%5], a[%6:%7]"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B >= 256 && GPR_START_C < 256) {
// D=a, A=v, B=a, C=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], a[%4:%5], v[%6:%7]"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C), "n"(GPR_START_C + 3));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B < 256 && GPR_START_C >= 256) {
// D=a, A=v, B=v, C=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], v[%4:%5], a[%6:%7]"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B < 256 && GPR_START_C < 256) {
// D=a, A=a, B=v, C=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], v[%4:%5], v[%6:%7]"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C), "n"(GPR_START_C + 3));
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B < 256 && GPR_START_C < 256) {
// D=v, A=a, B=v, C=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], v[%4:%5], v[%6:%7]"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C), "n"(GPR_START_C + 3));
} else if constexpr (GPR_START_D < 256 && GPR_START_A < 256 && GPR_START_B >= 256 && GPR_START_C < 256) {
// D=v, A=v, B=a, C=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], a[%4:%5], v[%6:%7]"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256), "n"(GPR_START_C), "n"(GPR_START_C + 3));
} else if constexpr (GPR_START_D < 256 && GPR_START_A < 256 && GPR_START_B < 256 && GPR_START_C >= 256) {
// D=v, A=v, B=v, C=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], v[%4:%5], a[%6:%7]"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C - 256), "n"(GPR_START_C + 3 - 256));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B < 256 && GPR_START_C < 256) {
// D=a, A=v, B=v, C=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], v[%4:%5], v[%6:%7]"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C), "n"(GPR_START_C + 3));
} else {
// D=v, A=v, B=v, C=v (the only remaining combination)
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], v[%4:%5], v[%6:%7]"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1), "n"(GPR_START_C), "n"(GPR_START_C + 3));
}
}
template<int GPR_START_A, int GPR_START_B, int GPR_START_D>
__device__ __forceinline__ void mfma_f32_16x16x32_bf16_zero_accum() {
if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B >= 256) {
@ -542,6 +864,43 @@ __device__ __forceinline__ void mfma_f32_32x32x16_bf16_zero_accum() {
}
}
/**
 * @brief Emits `v_mfma_f32_16x16x32_fp8_fp8` with the literal `0` as the accumulator input.
 *
 * Same register-numbering convention as the four-operand variant: numbers below 256 are
 * printed as VGPRs (`v[...]`), numbers of 256 and above as AGPRs (`a[...]`) after subtracting
 * 256. One branch exists for each of the 8 VGPR/AGPR combinations of (D, A, B).
 *
 * The asm statements declare no outputs and no clobbers.
 * NOTE(review): this presumably relies on the caller having reserved these registers — confirm.
 *
 * @tparam GPR_START_A First register of operand A; the range [A, A + 1] is used.
 * @tparam GPR_START_B First register of operand B; the range [B, B + 1] is used.
 * @tparam GPR_START_D First register of the destination D; the range [D, D + 3] is used.
 */
template<int GPR_START_A, int GPR_START_B, int GPR_START_D>
__device__ __forceinline__ void mfma_f32_16x16x32_fp8_fp8_zero_accum() {
if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B >= 256) {
// D=a, A=a, B=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], a[%4:%5], 0"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256));
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B >= 256) {
// D=v, A=a, B=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], a[%4:%5], 0"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B >= 256) {
// D=a, A=v, B=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], a[%4:%5], 0"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A >= 256 && GPR_START_B < 256) {
// D=a, A=a, B=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], a[%2:%3], v[%4:%5], 0"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1));
} else if constexpr (GPR_START_D < 256 && GPR_START_A >= 256 && GPR_START_B < 256) {
// D=v, A=a, B=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], a[%2:%3], v[%4:%5], 0"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A - 256), "n"(GPR_START_A + 1 - 256), "n"(GPR_START_B), "n"(GPR_START_B + 1));
} else if constexpr (GPR_START_D < 256 && GPR_START_A < 256 && GPR_START_B >= 256) {
// D=v, A=v, B=a
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], a[%4:%5], 0"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B - 256), "n"(GPR_START_B + 1 - 256));
} else if constexpr (GPR_START_D >= 256 && GPR_START_A < 256 && GPR_START_B < 256) {
// D=a, A=v, B=v
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 a[%0:%1], v[%2:%3], v[%4:%5], 0"
:
: "n"(GPR_START_D - 256), "n"(GPR_START_D + 3 - 256), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1));
} else {
// D=v, A=v, B=v (the only remaining combination)
asm volatile("v_mfma_f32_16x16x32_fp8_fp8 v[%0:%1], v[%2:%3], v[%4:%5], 0"
:
: "n"(GPR_START_D), "n"(GPR_START_D + 3), "n"(GPR_START_A), "n"(GPR_START_A + 1), "n"(GPR_START_B), "n"(GPR_START_B + 1));
}
}
template<int GPR0_START, int GPR1_START, int GPR>
__device__ __forceinline__ void v_subrev_f32_dpp() {
@ -592,11 +951,29 @@ __device__ __forceinline__ void v_accvgpr_read_b32() {
: "n"(GPR0), "n"(GPR1 - 256));
}
template<int GPR>
__device__ __forceinline__ void v_mov_b32(const uint32_t value) {
template<int GPR, typename T>
__device__ __forceinline__ void v_mov_b32_up2p(const T value) {
static_assert(sizeof(T) == sizeof(uint32_t));
asm volatile("v_mov_b32 v[%0], %1"
:
: "n"(GPR), "i"(value));
: "n"(GPR), "v"(value));
}
/**
 * @brief Copies one explicitly numbered register into a compiler-allocated VGPR value.
 *
 * Register numbers of 256 and above are read as AGPR `a[GPR - 256]` via
 * `v_accvgpr_read_b32`; smaller numbers are read as VGPR `v[GPR]` via `v_mov_b32`.
 *
 * @tparam GPR Number of the register to read.
 * @tparam T   Result type; must have the size of a uint32_t (checked at compile time).
 * @return The value the instruction wrote into the output operand.
 */
template <int GPR, typename T = uint32_t>
__device__ __forceinline__ T v_mov_b32_p2up() {
    static_assert(sizeof(T) == sizeof(uint32_t));
    T result;
    if constexpr (GPR >= 256) {
        // Source is an AGPR.
        asm volatile("v_accvgpr_read_b32 %0, a[%1]"
                     : "=v"(result)
                     : "n"(GPR - 256));
    } else {
        // Source is a VGPR.
        asm volatile("v_mov_b32 %0, v[%1]"
                     : "=v"(result)
                     : "n"(GPR));
    }
    return result;
}
template<int GPR0, int GPR1>
@ -612,8 +989,9 @@ __device__ __forceinline__ void v_cndmask_b32_e64(uint64_t mask) {
:
: "n"(GPR0), "n"(GPR1), "n"(GPR2), "s"(mask));
}
/**
* @brief Multiplication operation on explicit registers.
* @brief Multiplication operation on explicit registers and immediate operand.
*/
struct mul {
template<int GPR0, int GPR1>
@ -628,6 +1006,12 @@ struct mul {
}
}
// Two-register form of op(param) for the register pairs (GPR0, GPR0 + 1) and (GPR1, GPR1 + 1).
// No packed instruction is emitted here: the single-register overload is applied
// to (GPR0, GPR1) and then to (GPR0 + 1, GPR1 + 1) with the same param.
template<int GPR0, int GPR1>
static __device__ inline void op_pk2(const float &param) {
op<GPR0, GPR1>(param);
op<GPR0 + 1, GPR1 + 1>(param);
}
template<int GPR0, int GPR1, int GPR2>
static __device__ inline void op() {
if constexpr (GPR0 < 256 && GPR1 < 256 && GPR2 < 256) {
@ -638,8 +1022,44 @@ struct mul {
static_assert(false, "Invalid operand for instruction: v_mul_f32_e32");
}
}
/**
 * @brief Register-register packed multiply: emits
 *        `v_pk_mul_f32 v[GPR0:GPR0+1], v[GPR2:GPR2+1], v[GPR1:GPR1+1]`.
 *
 * Every operand is a pair of consecutive VGPRs, so each start index has to be
 * below 255; any other instantiation is rejected at compile time.
 *
 * @tparam GPR0 First register of the destination pair.
 * @tparam GPR1 First register of the pair printed as the last source operand.
 * @tparam GPR2 First register of the pair printed as the first source operand.
 */
template<int GPR0, int GPR1, int GPR2>
static __device__ inline void op_pk2() {
    if constexpr (GPR0 >= (256 - 1) || GPR1 >= (256 - 1) || GPR2 >= (256 - 1)) {
        static_assert(false, "Invalid operand for instruction: v_pk_mul_f32");
    } else {
        asm volatile("v_pk_mul_f32 v[%0:%1], v[%4:%5], v[%2:%3]"
                     :
                     : "n"(GPR0), "n"(GPR0 + 1),
                       "n"(GPR1), "n"(GPR1 + 1),
                       "n"(GPR2), "n"(GPR2 + 1));
    }
}
};
/**
 * @brief Multiplication of explicit registers by a scalar that is handed to the
 *        instruction in a compiler-allocated VGPR ("v" constraint).
 */
struct mul_vgpr {
    /**
     * @brief Emits `v_mul_f32_e32 v[GPR0], <param>, v[GPR1]`.
     *
     * Both register numbers must be below 256; anything else fails to compile.
     */
    template<int GPR0, int GPR1>
    static __device__ inline void op(const float &param) {
        if constexpr (GPR0 >= 256 || GPR1 >= 256) {
            static_assert(false, "Invalid operand for instruction: v_mul_f32_e32");
        } else {
            asm volatile("v_mul_f32_e32 v[%0], %2, v[%1]"
                         :
                         : "n"(GPR0), "n"(GPR1), "v"(param));
        }
    }

    /**
     * @brief Emits `v_pk_mul_f32 v[GPR0:GPR0+1], <param, param>, v[GPR1:GPR1+1]`.
     *
     * The scalar is duplicated into a float2 so that both halves of the packed
     * multiply see the same factor. Both start indices must be below 255 because
     * each operand is a pair of consecutive registers.
     */
    template<int GPR0, int GPR1>
    static __device__ inline void op_pk2(const float &param) {
        if constexpr (GPR0 >= (256 - 1) || GPR1 >= (256 - 1)) {
            static_assert(false, "Invalid operand for instruction: v_pk_mul_f32");
        } else {
            const float2 broadcast = {param, param};
            asm volatile("v_pk_mul_f32 v[%0:%1], %4, v[%2:%3]"
                         :
                         : "n"(GPR0), "n"(GPR0 + 1),
                           "n"(GPR1), "n"(GPR1 + 1),
                           "v"(broadcast));
        }
    }
};
struct exp2 {
template<int GPR0, int GPR1>
static __device__ inline void op() {
@ -669,4 +1089,4 @@ struct zero {
};
} // namespace macros
} // namespace kittens
} // namespace kittens

View file

@ -50,7 +50,11 @@ __device__ __forceinline__ int warpid() { return threadIdx.x >> 6; }
*/
__device__ __forceinline__ int laneid() { return threadIdx.x & 0x3f; }
using i32x2 = int32_t __attribute__((ext_vector_type(2)));
using u32x2 = uint32_t __attribute__((ext_vector_type(2)));
using i32x4 = int32_t __attribute__((ext_vector_type(4)));
using u32x4 = uint32_t __attribute__((ext_vector_type(4)));
struct buffer_resource {
uint64_t ptr;
uint32_t range;

View file

@ -21,7 +21,7 @@
* @param idx[in] The index of the tile to load data from.
*/
template<int axis, ducks::art::all RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<RT>>
template<int axis, int elem_offset=0, ducks::art::all RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<RT>>
__device__ inline static void load(RT &dst, const GL &src, const COORD &idx, const COORD &warp_idx) {
using T2 = RT::dtype;
constexpr int packing = base_types::packing<typename RT::dtype>::num();
@ -42,22 +42,48 @@
buffer_resource br = make_buffer_resource(as_u64, buffer_size, 0x00020000);
int warp_offset = src.idx(warp_idx.template unit_coord<axis, 3>());
int thr_offset = (row_offset * row_stride + col_offset + warp_offset) * sizeof(U);
// Compile-time loop to load data into the tile
auto perform_load_at = [&]<int N, int M, int K>() {
using tile_range = ducks::art::get_nth_range_t<typename RT::register_ranges, N * RT::width + M>;
const int register_offset = K * RT::registers_per_stride;
constexpr int col = RT::base_tile_cols*M + K * RT::base_tile_elements_per_stride_group;
constexpr int row = RT::base_tile_rows*N;
const int k_row_offset = row * row_stride * sizeof(U);
const int col = RT::base_tile_cols*M + col_offset + K * RT::base_tile_elements_per_stride_group;
const int row = RT::base_tile_rows*N + row_offset;
const int offset = (row*row_stride + col + warp_offset) * sizeof(U);
if constexpr (std::is_same_v<U2, bf16_2>) {
if constexpr (RT::base_tile_stride == 8) {
macros::buffer_load_dwordx4<tile_range::lo + register_offset>(br, offset);
} else if constexpr (RT::base_tile_stride == 4) {
macros::buffer_load_dwordx2<tile_range::lo + register_offset>(br, offset);
}
constexpr int stride_in_bytes = RT::base_tile_stride * sizeof(U);
constexpr int offset_in_bytes = (elem_offset + col) * sizeof(U);
constexpr int start_gpr = tile_range::lo + register_offset;
if constexpr (offset_in_bytes <= macros::max_mubuf_inst_offset()) {
if constexpr (stride_in_bytes == (sizeof(int32_t) * 4)) {
macros::buffer_load_dwordx4<start_gpr>(br, thr_offset + k_row_offset, 0, offset_in_bytes);
}
else if constexpr (stride_in_bytes == (sizeof(int32_t) * 2)) {
macros::buffer_load_dwordx2<start_gpr>(br, thr_offset + k_row_offset, 0, offset_in_bytes);
}
else if constexpr (stride_in_bytes == sizeof(int32_t)) {
macros::buffer_load_dword<start_gpr>(br, thr_offset + k_row_offset, 0, offset_in_bytes);
}
else {
static_assert(false, "Encounter unsupported format in ops/warp/memory/tile/assembly/global_to_register.cuh\n");
}
}
else {
if constexpr (stride_in_bytes == (sizeof(int32_t) * 4)) {
macros::buffer_load_dwordx4<start_gpr>(br, thr_offset + offset_in_bytes + k_row_offset, 0, 0);
}
else if constexpr (stride_in_bytes == (sizeof(int32_t) * 2)) {
macros::buffer_load_dwordx2<start_gpr>(br, thr_offset + offset_in_bytes + k_row_offset, 0, 0);
}
else if constexpr (stride_in_bytes == sizeof(int32_t)) {
macros::buffer_load_dword<start_gpr>(br, thr_offset + offset_in_bytes + k_row_offset, 0, 0);
}
else {
static_assert(false, "Encounter unsupported format in ops/warp/memory/tile/assembly/global_to_register.cuh\n");
}
}
};
@ -74,12 +100,11 @@
}(std::make_index_sequence<RT::width>{});
}.template operator()<Ns>(), ...);
}(std::make_index_sequence<RT::height>{});
}
template<ducks::art::all RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<RT>>
__device__ inline static void load(RT &dst, const GL &src, const COORD &idx, const COORD &warp_idx) {
load<2, RT, GL>(dst, src, idx, warp_idx);
load<2, 0, RT, GL>(dst, src, idx, warp_idx);
}
/**

View file

@ -226,7 +226,8 @@ __device__ inline void load(ST& dst, const GL& src, const COORD& idx, const uint
if (warpid < leftover_warps) {
uintptr_t lds_addr = lds_base + (memcpy_per_tile * num_warps * bytes_per_warp);
const T* lds_elem_ptr = lds_base + (memcpy_per_tile * num_warps * elements_per_warp);
uintptr_t lds_addr = reinterpret_cast<uintptr_t>(lds_elem_ptr);
as3_uint32_ptr lds_ptr = (as3_uint32_ptr)(lds_addr);
llvm_amdgcn_raw_buffer_load_lds(
@ -414,4 +415,4 @@ template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST
__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) {
store<2, false, ST, GL, COORD, WARP_THREADS>(dst, src, idx);
}
}
}

View file

@ -134,7 +134,17 @@ __device__ static inline void bin_map(T0 &dst, const T1 &src, const typename bas
[&]<std::size_t... Rs>(std::index_sequence<Rs...>) {
([&]<std::size_t R>() {
op::template op<ducks::art::get_nth_range_t<registers_T0, R>::lo, ducks::art::get_nth_range_t<registers_T1, R>::lo>(param);
constexpr int GPR0 = ducks::art::get_nth_range_t<registers_T0, R>::lo;
constexpr int GPR1 = ducks::art::get_nth_range_t<registers_T1, R>::lo;
if constexpr ((R % 2) == 0) {
if constexpr ((R + 1) < registers_T0::size) {
op::template op_pk2<GPR0, GPR1>(param);
}
else {
op::template op<GPR0, GPR1>(param);
}
}
// Odd indices are skipped because they're processed by op_pk2 in the previous even iteration
}.template operator()<Rs>(), ...);
}(std::make_index_sequence<registers_T0::size>{});
}
@ -156,7 +166,17 @@ __device__ static inline void bin_map(T0 &dst, const T1 &src, const typename bas
[&]<std::size_t... Rs>(std::index_sequence<Rs...>) {
([&]<std::size_t R>() {
op::template op<ducks::art::get_nth_range_t<registers_T0, R>::lo, ducks::art::get_nth_range_t<registers_T1, R>::lo>(param);
constexpr int GPR0 = ducks::art::get_nth_range_t<registers_T0, R>::lo;
constexpr int GPR1 = ducks::art::get_nth_range_t<registers_T1, R>::lo;
if constexpr ((R % 2) == 0) {
if constexpr ((R + 1) < registers_T0::size) {
op::template op_pk2<GPR0, GPR1>(param);
}
else {
op::template op<GPR0, GPR1>(param);
}
}
// Odd indices are skipped because they're processed by op_pk2 in the previous even iteration
}.template operator()<Rs>(), ...);
}(std::make_index_sequence<registers_T0::size>{});
};
@ -205,7 +225,18 @@ __device__ static inline void bin_map(T0 &dst, const T1 &lhs, const T2 &rhs) {
[&]<std::size_t... Rs>(std::index_sequence<Rs...>) {
([&]<std::size_t R>() {
op::template op<ducks::art::get_nth_range_t<registers_T0, R>::lo, ducks::art::get_nth_range_t<registers_T1, R>::lo, ducks::art::get_nth_range_t<registers_T2, R>::lo>();
constexpr int GPR0 = ducks::art::get_nth_range_t<registers_T0, R>::lo;
constexpr int GPR1 = ducks::art::get_nth_range_t<registers_T1, R>::lo;
constexpr int GPR2 = ducks::art::get_nth_range_t<registers_T2, R>::lo;
if constexpr ((R % 2) == 0) {
if constexpr ((R + 1) < registers_T0::size) {
op::template op_pk2<GPR0, GPR1, GPR2>();
}
else {
op::template op<GPR0, GPR1, GPR2>();
}
}
// Odd indices are skipped because they're processed by op_pk2 in the previous even iteration
}.template operator()<Rs>(), ...);
}(std::make_index_sequence<registers_T0::size>{});
}
@ -234,7 +265,18 @@ __device__ static inline void bin_map(T0 &dst, const T1 &lhs, const T2 &rhs) {
[&]<std::size_t... Rs>(std::index_sequence<Rs...>) {
([&]<std::size_t R>() {
op::template op<ducks::art::get_nth_range_t<registers_T0, R>::lo, ducks::art::get_nth_range_t<registers_T1, R>::lo, ducks::art::get_nth_range_t<registers_T2, R>::lo>();
constexpr int GPR0 = ducks::art::get_nth_range_t<registers_T0, R>::lo;
constexpr int GPR1 = ducks::art::get_nth_range_t<registers_T1, R>::lo;
constexpr int GPR2 = ducks::art::get_nth_range_t<registers_T2, R>::lo;
if constexpr ((R % 2) == 0) {
if constexpr ((R + 1) < registers_T0::size) {
op::template op_pk2<GPR0, GPR1, GPR2>();
}
else {
op::template op<GPR0, GPR1, GPR2>();
}
}
// Odd indices are skipped because they're processed by op_pk2 in the previous even iteration
}.template operator()<Rs>(), ...);
}(std::make_index_sequence<registers_T0::size>{});
};
@ -364,6 +406,16 @@ __device__ static inline void mul(T0 &dst, const T1 &lhs, const U &rhs) {
bin_map<macros::mul, T0, T1>(dst, lhs, rhs);
}
/**
 * @brief Forwards to bin_map with macros::mul_vgpr as the operation.
 *
 * @param dst Destination tile, passed through to bin_map.
 * @param lhs Source tile, passed through to bin_map.
 * @param rhs Parameter passed through to bin_map.
 *            NOTE(review): presumably the scalar multiplier — confirm against bin_map.
 */
template<ducks::art::all T0, ducks::art::all T1, typename U>
__device__ static inline void mul_vgpr(T0 &dst, const T1 &lhs, const U &rhs) {
bin_map<macros::mul_vgpr, T0, T1>(dst, lhs, rhs);
}
/**
 * @brief Forwards to the bin_map overload that takes two leading integer template
 *        arguments, with macros::mul_vgpr as the operation.
 *
 * @tparam N First integer argument forwarded to bin_map.
 * @tparam M Second integer argument forwarded to bin_map.
 *           NOTE(review): N and M presumably select a sub-tile — confirm against bin_map.
 * @param dst Destination tile, passed through to bin_map.
 * @param lhs Source tile, passed through to bin_map.
 * @param rhs Parameter passed through to bin_map.
 */
template<int N, int M, ducks::art::all T0, ducks::art::all T1, typename U>
__device__ static inline void mul_vgpr(T0 &dst, const T1 &lhs, const U &rhs) {
bin_map<N, M, macros::mul_vgpr, T0, T1>(dst, lhs, rhs);
}
/**
* @brief Subtracts row values from each row of a tile.
*

View file

@ -20,14 +20,44 @@ namespace kittens {
* @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
* @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
*/
template<typename RegisterRangeA, typename RegisterRangeB, typename RegisterRangeC, typename RegisterRangeD>
template<typename AccumulatorShape, typename InputType, typename RegisterRangeA, typename RegisterRangeB, typename RegisterRangeC, typename RegisterRangeD>
__device__ static inline void mma_ABt_base() {
macros::mfma_f32_16x16x32_bf16<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeC::lo, RegisterRangeD::lo>();
if constexpr (std::is_same_v<AccumulatorShape, ducks::rt_shape::rt_16x16>)
{
if constexpr (std::is_same_v<InputType, fp8e4m3>)
{
macros::mfma_f32_16x16x32_fp8_fp8<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeC::lo, RegisterRangeD::lo>();
}
else
{
macros::mfma_f32_16x16x32_bf16<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeC::lo, RegisterRangeD::lo>();
}
}
else
{
macros::mfma_f32_16x16x32_bf16<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeC::lo, RegisterRangeD::lo>();
}
}
template<typename RegisterRangeA, typename RegisterRangeB, typename RegisterRangeD>
template<typename AccumulatorShape, typename InputType, typename RegisterRangeA, typename RegisterRangeB, typename RegisterRangeD>
__device__ static inline void mma_ABt_base_zero_accum() {
macros::mfma_f32_16x16x32_bf16_zero_accum<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeD::lo>();
if constexpr (std::is_same_v<AccumulatorShape, ducks::rt_shape::rt_16x16>)
{
if constexpr (std::is_same_v<InputType, fp8e4m3>)
{
macros::mfma_f32_16x16x32_fp8_fp8_zero_accum<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeD::lo>();
}
else
{
macros::mfma_f32_16x16x32_bf16_zero_accum<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeD::lo>();
}
}
else
{
macros::mfma_f32_16x16x32_bf16_zero_accum<RegisterRangeA::lo, RegisterRangeB::lo, RegisterRangeD::lo>();
}
}
/**
* @brief Base matrix multiply-accumulate operation for row layout with transposed A.
*
@ -87,7 +117,9 @@ __device__ static inline void mma_ABt(D &d,
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
std::is_same_v<typename B::T, bf16>) ||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
std::is_same_v<typename B::T, half>)
std::is_same_v<typename B::T, half>) ||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
std::is_same_v<typename B::T, fp8e4m3>)
);
// Helper function template for compile-time MMA operations
@ -95,7 +127,7 @@ __device__ static inline void mma_ABt(D &d,
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, M * B::width + K>;
using range_type_C = ducks::art::get_nth_range_t<typename C::register_ranges, N * C::width + M>;
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
mma_ABt_base<range_type_A, range_type_B, range_type_C, range_type_D>();
mma_ABt_base<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_C, range_type_D>();
}
template<ducks::art::all D, ducks::art::all A, ducks::art::all B, ducks::art::all C>
@ -117,7 +149,9 @@ __device__ static inline void mma_ABt(D &d,
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
std::is_same_v<typename B::T, bf16>) ||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
std::is_same_v<typename B::T, half>)
std::is_same_v<typename B::T, half>) ||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
std::is_same_v<typename B::T, fp8e4m3>)
);
// Helper function template for compile-time MMA operations
@ -127,7 +161,7 @@ __device__ static inline void mma_ABt(D &d,
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, M * B::width>;
using range_type_C = ducks::art::get_nth_range_t<typename C::register_ranges, N * C::width + M>;
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
mma_ABt_base<range_type_A, range_type_B, range_type_C, range_type_D>();
mma_ABt_base<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_C, range_type_D>();
// Subsequent MMA operations for k=1 to A::width-1
[&]<std::size_t... Ks>(std::index_sequence<Ks...>) {
@ -138,7 +172,7 @@ __device__ static inline void mma_ABt(D &d,
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, k + M * B::width>;
using range_type_C = ducks::art::get_nth_range_t<typename C::register_ranges, N * C::width + M>;
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
mma_ABt_base<range_type_A, range_type_B, range_type_C, range_type_D>();
mma_ABt_base<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_C, range_type_D>();
}
}(), ...);
}(std::make_index_sequence<A::width>{});
@ -172,7 +206,9 @@ __device__ static inline void mma_ABt(D &d,
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
std::is_same_v<typename B::T, bf16>) ||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
std::is_same_v<typename B::T, half>)
std::is_same_v<typename B::T, half>) ||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
std::is_same_v<typename B::T, fp8e4m3>)
);
// Helper function template for compile-time MMA operations
@ -180,7 +216,7 @@ __device__ static inline void mma_ABt(D &d,
using range_type_A = ducks::art::get_nth_range_t<typename A::register_ranges, N * A::width + K>;
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, M * B::width + K>;
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
mma_ABt_base_zero_accum<range_type_A, range_type_B, range_type_D>();
mma_ABt_base_zero_accum<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_D>();
}
template<ducks::art::all D, ducks::art::all A, ducks::art::all B>
@ -199,7 +235,9 @@ __device__ static inline void mma_ABt(D &d,
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
std::is_same_v<typename B::T, bf16>) ||
(std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
std::is_same_v<typename B::T, half>)
std::is_same_v<typename B::T, half>) ||
(std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
std::is_same_v<typename B::T, fp8e4m3>)
);
// Helper function template for compile-time MMA operations
@ -208,7 +246,7 @@ __device__ static inline void mma_ABt(D &d,
using range_type_A = ducks::art::get_nth_range_t<typename A::register_ranges, N * A::width>;
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, M * B::width>;
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
mma_ABt_base_zero_accum<range_type_A, range_type_B, range_type_D>();
mma_ABt_base_zero_accum<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_D>();
// Subsequent MMA operations for k=1 to A::width-1
[&]<std::size_t... Ks>(std::index_sequence<Ks...>) {
@ -218,7 +256,7 @@ __device__ static inline void mma_ABt(D &d,
using range_type_A = ducks::art::get_nth_range_t<typename A::register_ranges, k + N * A::width>;
using range_type_B = ducks::art::get_nth_range_t<typename B::register_ranges, k + M * B::width>;
using range_type_D = ducks::art::get_nth_range_t<typename D::register_ranges, N * D::width + M>;
mma_ABt_base<range_type_A, range_type_B, range_type_D, range_type_D>();
mma_ABt_base<typename D::shape, typename A::T, range_type_A, range_type_B, range_type_D, range_type_D>();
}
}(), ...);
}(std::make_index_sequence<A::width>{});

View file

@ -472,6 +472,28 @@ template<ducks::rt::all T>
__device__ static inline void relu(T &dst, const T &src) {
unary_map<base_ops::relu, T>(dst, src);
}
/**
 * @brief Applies the GELU function (tanh approximation) to each element of a tile.
 *
 * Implemented as unary_map with base_ops::gelu; the exact formula is whatever
 * base_ops::gelu defines.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the GELU function on.
 */
template<ducks::rt::all T>
__device__ static inline void gelu(T &dst, const T &src) {
unary_map<base_ops::gelu, T>(dst, src);
}
/**
 * @brief Applies the GELU derivative to each element of a tile.
 *
 * Implemented as unary_map with base_ops::dgelu; the exact formula is whatever
 * base_ops::dgelu defines.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the GELU derivative on.
 */
template<ducks::rt::all T>
__device__ static inline void dgelu(T &dst, const T &src) {
unary_map<base_ops::dgelu, T>(dst, src);
}
/**
* @brief Copies the elements from one tile to another.
*

View file

@ -124,6 +124,28 @@ __device__ static inline void mfma1616128( float2 (&D)[2],
)};
}
/**
 * @brief Thin wrapper around `__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4`.
 *
 * The operand arrays are reinterpreted in place as vector types: A and B are read
 * through an 8-int vector type, C is read and D is written through a 4-float vector type.
 * NOTE(review): this assumes fp8e4m3_4 is 4 bytes and float2 is 8 bytes — confirm.
 *
 * @tparam opsel_a Passed to the builtin as its opsel_a argument.
 * @tparam opsel_b Passed to the builtin as its opsel_b argument.
 * @tparam cbsz    Passed to the builtin as cbsz (defaults to 0; see the inline comment).
 * @tparam blgp    Passed to the builtin as blgp (defaults to 0; see the inline comment).
 * @param D[out]      Result of the builtin.
 * @param A[in]       First operand.
 * @param B[in]       Second operand.
 * @param C[in]       Accumulator input.
 * @param scale_a[in] Pointer to the scale for A; it is dereferenced and passed by value.
 * @param scale_b[in] Pointer to the scale for B; it is dereferenced and passed by value.
 */
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0>
__device__ static inline void mfma1616128_scaled( float2 (&D)[2],
const fp8e4m3_4 (&A)[8],
const fp8e4m3_4 (&B)[8],
const float2 (&C)[2],
const fp8e8m0_4 *scale_a,
const fp8e8m0_4 *scale_b) {
// Local vector types matching the operand widths used in the casts below.
typedef __attribute__((__vector_size__(8 * sizeof(int)))) int intx8_t;
typedef __attribute__((__vector_size__(4 * sizeof(float)))) float floatx4_t;
*(floatx4_t*)D = {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
*(intx8_t*)A,
*(intx8_t*)B,
*(floatx4_t*)C,
cbsz, // cbsz: 0=fp8(e4m3) A, 1=bf8(e5m2) A
blgp, // blgp: 0=fp8(e4m3) B, 1=bf8(e5m2) B
opsel_a, // opsel_a
*scale_a, // scale_a
opsel_b, // opsel_b
*scale_b // scale_b
)};
}
/**
* @brief Base matrix multiply-accumulate operation for row layout.
@ -221,6 +243,46 @@ __device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::col,
}
}
/**
 * @brief Base block-scaled multiply-accumulate on a single base tile (ABt form).
 *
 * Checks the tile shapes at compile time and forwards to mfma1616128_scaled. The only
 * accepted combination is a 16x16 accumulator (D and C) with 16x128 A and B tiles;
 * every other instantiation fails to compile.
 *
 * @tparam opsel_a Forwarded unchanged to mfma1616128_scaled.
 * @tparam opsel_b Forwarded unchanged to mfma1616128_scaled.
 * @tparam cbsz    Forwarded unchanged to mfma1616128_scaled (defaults to 0).
 * @tparam blgp    Forwarded unchanged to mfma1616128_scaled (defaults to 0).
 * @param[out] d The output rt_base<float, col_layout> accumulator.
 * @param[in] a The first input rt_base<MM_Operand_T, row_layout> matrix.
 * @param[in] b The second input rt_base<MM_Operand_T, row_layout> matrix.
 * @param[in] c The input rt_base<float, col_layout> accumulator matrix.
 * @param[in] scale_a Pointer to the scale for A, forwarded to mfma1616128_scaled.
 * @param[in] scale_b Pointer to the scale for B, forwarded to mfma1616128_scaled.
 */
template<int opsel_a, int opsel_b, int cbsz = 0, int blgp = 0, ducks::rt_shape::all D_shape, ducks::rt_shape::all A_shape, ducks::rt_shape::all B_shape, ducks::rt_shape::all C_shape, typename MM_Operand_T>
__device__ static inline void mma_ABt_base_scaled(rt_base<float, ducks::rt_layout::col, D_shape> &d,
                                                  const rt_base<MM_Operand_T, ducks::rt_layout::row, A_shape> &a,
                                                  const rt_base<MM_Operand_T, ducks::rt_layout::row, B_shape> &b,
                                                  const rt_base<float, ducks::rt_layout::col, C_shape> &c,
                                                  const fp8e8m0_4 *scale_a,
                                                  const fp8e8m0_4 *scale_b) {
    static_assert(std::is_same_v<D_shape, C_shape>, "D and C must have the same shape");
    static_assert(A_shape::stride == B_shape::stride, "A and B must have the same stride");

    // Shape predicate for the one supported instruction: 16x16 accumulators, 16x128 operands.
    constexpr bool accumulators_are_16x16 =
        std::is_same_v<D_shape, typename ducks::rt_shape::rt_16x16> &&
        std::is_same_v<C_shape, typename ducks::rt_shape::rt_16x16>;
    constexpr bool a_is_16x128 = (A_shape::rows == 16) && (A_shape::cols == 128);
    constexpr bool b_is_16x128 = (B_shape::rows == 16) && (B_shape::cols == 128);

    if constexpr (!(accumulators_are_16x16 && a_is_16x128 && b_is_16x128)) {
        static_assert(false, "Unsupported shape combination");
    } else {
        mfma1616128_scaled<opsel_a, opsel_b, cbsz, blgp>(d.data, a.data, b.data, c.data, scale_a, scale_b);
    }
}
/**
* @brief Base matrix multiply-accumulate operation for row layout with transposed A.
*
@ -420,6 +482,62 @@ __device__ static inline void mma_ABt(D &d,
}
}
}
/**
 * @brief Block-scaled multiply-accumulate over register tiles (A * B^T form).
 *
 * Walks every (row, col) base tile of the accumulator in row-major order and
 * issues one mma_ABt_base_scaled call per tile. The row tile index is passed
 * as opsel_a and the column tile index as opsel_b, while the same scale_a /
 * scale_b pointers are handed to every call.
 *
 * Only K-tile index 0 of each operand row is read (a.tiles[row][0] and
 * b.tiles[col][0]); there is no loop over the K dimension here.
 *
 * Accepted element-type combinations (D, A, B, C), enforced by static_assert:
 * (float, bf16, bf16, float), (half, half, half, half) and
 * (float, fp8e4m3, fp8e4m3, float).
 *
 * @tparam cbsz Forwarded unchanged to mma_ABt_base_scaled.
 * @tparam blgp Forwarded unchanged to mma_ABt_base_scaled.
 * @param[out] d The output accumulator tile (column layout).
 * @param[in] a The first input tile (row layout).
 * @param[in] b The second input tile (row layout); used as B^T.
 * @param[in] c The input accumulator tile (column layout).
 * @param[in] scale_a Pointer to the packed E8M0 scale for the A matrix.
 * @param[in] scale_b Pointer to the packed E8M0 scale for the B matrix.
 */
template<int cbsz = 0, int blgp = 0, ducks::rt::col_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::col_layout C>
__device__ static inline void mma_ABt_scaled(D &d,
                                             const A &a,
                                             const B &b,
                                             const C &c,
                                             const fp8e8m0_4 *scale_a,
                                             const fp8e8m0_4 *scale_b) {
    // Dimension compatibility: D = A * B^T accumulated onto C.
    static_assert(D::rows == A::rows && D::cols == B::rows);
    static_assert(A::cols == B::cols);
    static_assert(D::rows == C::rows && D::cols == C::cols);

    // Element-type compatibility.
    using d_elem = typename D::T;
    using a_elem = typename A::T;
    using b_elem = typename B::T;
    using c_elem = typename C::T;
    static_assert(
        (std::is_same_v<d_elem, float> && std::is_same_v<a_elem, bf16> &&
         std::is_same_v<b_elem, bf16> && std::is_same_v<c_elem, float>) ||
        (std::is_same_v<d_elem, half> && std::is_same_v<a_elem, half> &&
         std::is_same_v<b_elem, half> && std::is_same_v<c_elem, half>) ||
        (std::is_same_v<d_elem, float> && std::is_same_v<a_elem, fp8e4m3> &&
         std::is_same_v<b_elem, fp8e4m3> && std::is_same_v<c_elem, float>)
    );

    // One scaled base-tile operation; the tile coordinates double as the opsel values.
    auto issue_tile = [&]<std::size_t I, std::size_t J>() {
        mma_ABt_base_scaled<I, J, cbsz, blgp>(
            d.tiles[I][J],
            a.tiles[I][0],
            b.tiles[J][0],
            c.tiles[I][J],
            scale_a,
            scale_b
        );
    };

    // Compile-time unrolled double loop: rows outermost, columns innermost.
    [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        ([&]<std::size_t I>() {
            [&]<std::size_t... Js>(std::index_sequence<Js...>) {
                (issue_tile.template operator()<I, Js>(), ...);
            }(std::make_index_sequence<D::width>{});
        }.template operator()<Is>(), ...);
    }(std::make_index_sequence<D::height>{});
}
/**
* @brief Matrix multiply-accumulate operation with transposed A.
*

View file

@ -0,0 +1,74 @@
/**
* @file
* @brief MXFP8 block scale loading and packing utilities.
*
* Provides functions for staging E8M0 block scales in LDS and packing them
* into fp8e8m0_4 registers for use with scaled MFMA instructions.
*/
#pragma once
#include "../../../../common/common.cuh"
namespace kittens {
/**
 * @brief Load iteration-major packed E8M0 scales from global memory into LDS.
 *
 * Threads with threadIdx.x < 256 each load one uint32 (4 packed E8M0 bytes)
 * for A and one for B; all other threads do nothing.
 * A scales are placed at smem[0..1023], B scales at smem[1024..2047].
 *
 * The global element index is computed in 64-bit arithmetic so that
 * k_iter * M_dim (or k_iter * N_dim) cannot overflow a 32-bit int.
 *
 * This function performs no barrier and no bounds check: the caller must
 * synchronize before other threads read the staged scales, and must ensure
 * rows block_m..block_m+255 / block_n..block_n+255 are valid for k_iter.
 *
 * @param smem_scales LDS buffer, must be >= 2048 bytes (accessed as uint32, so 4-byte aligned).
 * @param scale_A_iter Iteration-major A scales: [k_iter * M + row] as uint32.
 * @param scale_B_iter Iteration-major B scales: [k_iter * N + row] as uint32.
 * @param block_m Starting row offset for A within the current block.
 * @param block_n Starting row offset for B within the current block.
 * @param k_iter Current K iteration index.
 * @param M_dim M dimension of the matrix.
 * @param N_dim N dimension of the matrix.
 */
__device__ __forceinline__ void load_scales_to_lds(
    uint8_t *smem_scales,
    const uint32_t *__restrict__ scale_A_iter,
    const uint32_t *__restrict__ scale_B_iter,
    int block_m, int block_n, int k_iter, int M_dim, int N_dim) {
    const int tid = threadIdx.x;
    if (tid < 256) {
        // Widen before multiplying: the product is formed in int64_t, not int.
        const int64_t a_idx = int64_t(k_iter) * M_dim + block_m + tid;
        const int64_t b_idx = int64_t(k_iter) * N_dim + block_n + tid;
        const uint32_t sa = scale_A_iter[a_idx];
        const uint32_t sb = scale_B_iter[b_idx];
        // One dword per thread; the B region starts 1024 bytes after the A region.
        *(uint32_t *)&smem_scales[tid * 4] = sa;
        *(uint32_t *)&smem_scales[1024 + tid * 4] = sb;
    }
}
/**
 * @brief Pack 4 E8M0 scale bytes from LDS into one fp8e8m0_4 register.
 *
 * Each lane derives r16 = laneid % 16 and k_sub = laneid / 16, loads 4 dwords
 * spaced 16 rows apart (rows row_offset + r16 + {0, 16, 32, 48}), and combines
 * them with two perm operations plus a shift/or. Per the selector built below,
 * the intent is to take byte k_sub of each dword, giving result bytes
 * [w0, w1, w2, w3] from low to high (see v_perm_b32 in the AMD ISA guide).
 *
 * NOTE(review): the selector only addresses a valid byte when k_sub is in
 * 0..3, i.e. presumably a 64-lane wavefront; confirm against laneid().
 *
 * @param smem_scales LDS pointer to scale region.
 * @param lds_base Byte offset within smem_scales (0 for A, 1024 for B).
 * @param row_offset Starting row (dword index) within the scale region (warp's tile offset).
 * @return fp8e8m0_4 with 4 scale bytes packed for MFMA opsel.
 */
__device__ __forceinline__ fp8e8m0_4 pack_scales(
    const uint8_t *smem_scales, int lds_base, int row_offset) {
    const int lane = laneid();
    const int row_in_group = lane % 16;
    const int byte_idx = lane / 16;

    // Point at this lane's row in the first 16-row group; later groups are 16 dwords apart.
    const uint32_t *lane_row = (const uint32_t *)(smem_scales + lds_base) + row_offset + row_in_group;
    const uint32_t group0 = lane_row[0 * 16];
    const uint32_t group1 = lane_row[1 * 16];
    const uint32_t group2 = lane_row[2 * 16];
    const uint32_t group3 = lane_row[3 * 16];

    // Selector: low byte picks byte_idx of the first operand, next byte picks
    // byte_idx of the second operand, upper two bytes use the 0x0C selector.
    const uint32_t selector = 0x0C0C0000u | (byte_idx << 8) | (4u + byte_idx);
    const uint32_t low_pair = __builtin_amdgcn_perm(group0, group1, selector);
    const uint32_t high_pair = __builtin_amdgcn_perm(group2, group3, selector);

    // Groups 0/1 occupy the low half-word, groups 2/3 the high half-word.
    return (fp8e8m0_4)(low_pair | (high_pair << 16));
}
} // namespace kittens

View file

@ -9,5 +9,6 @@
#include "maps.cuh"
#include "reductions.cuh"
#include "mma.cuh"
#include "scales.cuh"
#include "assembly/tile.cuh"

View file

@ -68,11 +68,11 @@ struct gl {
}
__host__ __device__ inline gl(const gl &other) :
raw_ptr(other.raw_ptr), batch_internal(other.batch_internal), depth_internal(other.depth_internal), rows_internal(other.rows_internal), cols_internal(other.cols_internal), tma_descs(other.tma_descs) {}
__device__ inline T& operator[](const coord<ducks::default_type> &idx) const { // yes I am abusing the const qualifier here a bit.
return raw_ptr[((idx.b*depth() + idx.d)*rows() + idx.r)*cols() + idx.c];
__device__ inline T& operator[](const coord<ducks::default_type> &idx) const {
return raw_ptr[((int64_t(idx.b)*depth() + idx.d)*rows() + idx.r)*cols() + idx.c];
}
__device__ inline int idx(const coord<ducks::default_type> &idx) const {
return ((idx.b*depth() + idx.d)*rows() + idx.r)*cols() + idx.c;
__device__ inline int64_t idx(const coord<ducks::default_type> &idx) const {
return ((int64_t(idx.b)*depth() + idx.d)*rows() + idx.r)*cols() + idx.c;
}
template<int axis> __device__ inline size_t shape() const {
static_assert(axis==0 || axis==1 || axis==2 || axis==3, "Axis must be 0, 1, 2, or 3.");

View file

@ -56,7 +56,9 @@
using register_range = _register_range; ///< Register range for this tile.
static_assert(
std::is_same_v<dtype, bf16_2> || std::is_same_v<dtype, float2> || std::is_same_v<dtype, half_2>,
std::is_same_v<dtype, bf16_2> || std::is_same_v<dtype, float2> || std::is_same_v<dtype, half_2> ||
std::is_same_v<dtype, __hip_fp8x4_e4m3> || std::is_same_v<dtype, __hip_fp8x4_e4m3_fnuz> ||
std::is_same_v<dtype, __hip_fp8x4_e5m2> || std::is_same_v<dtype, __hip_fp8x4_e5m2_fnuz>,
"art_base was provided an unsupported type."
);

View file

@ -37,12 +37,12 @@ using rt_16x32_4 = rt_shape<16, 32, 4>;
using rt_16x128 = rt_shape<16, 128, 16>;
template<typename T>
concept all = std::is_same_v<T, rt_16x16> ||
std::is_same_v<T, rt_32x32> ||
std::is_same_v<T, rt_32x32_8> ||
std::is_same_v<T, rt_16x32> ||
std::is_same_v<T, rt_32x16> ||
std::is_same_v<T, rt_32x16_4> ||
concept all = std::is_same_v<T, rt_16x16> ||
std::is_same_v<T, rt_32x32> ||
std::is_same_v<T, rt_32x32_8> ||
std::is_same_v<T, rt_16x32> ||
std::is_same_v<T, rt_32x16> ||
std::is_same_v<T, rt_32x16_4> ||
std::is_same_v<T, rt_16x32_4> ||
std::is_same_v<T, rt_16x128>;
@ -59,4 +59,4 @@ concept all = std::is_same_v<T, rt_16x16> ||
template<> struct transpose<rt_16x32_4> { using type = rt_32x16_4; };
} // namespace rt_shape
} // namespace ducks
} // namespace kittens
} // namespace kittens

View file

@ -90,5 +90,6 @@ concept all = requires {
template<size_t _length> using sv_bf = sv<bf16, _length>;
template<size_t _length> using sv_hf = sv<half, _length>;
template<size_t _length> using sv_fl = sv<float, _length>;
template<size_t _length> using sv_fp8e4m3 = sv<fp8e4m3, _length>;
} // namespace kittens