/** * @file * @brief Basic operations on generic types. */ #pragma once #include "base_types.metal" #include namespace mittens { /** * @namespace base_ops * * @brief A namespace for operations on basic data types. */ namespace base_ops { #define TEMPLATE_OPS_SINGLE(func_contents) \ template static METAL_FUNC T op(device const T &x) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &x) { func_contents } \ template static METAL_FUNC T op(thread const T &x) { func_contents } #define TEMPLATE_OPS_OVERRIDE_SINGLE(T, op_name, func_contents) \ template<> METAL_FUNC T op_name::op(device const T &x) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &x) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &x) { func_contents } #define TEMPLATE_OPS_DOUBLE(func_contents) \ template static METAL_FUNC T op(device const T &a, device const T &b) { func_contents } \ template static METAL_FUNC T op(device const T &a, threadgroup const T &b) { func_contents } \ template static METAL_FUNC T op(device const T &a, thread const T &b) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, device const T &b) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, thread const T &b) { func_contents } \ template static METAL_FUNC T op(thread const T &a, device const T &b) { func_contents } \ template static METAL_FUNC T op(thread const T &a, threadgroup const T &b) { func_contents } \ template static METAL_FUNC T op(thread const T &a, thread const T &b) { func_contents } #define TEMPLATE_OPS_OVERRIDE_DOUBLE(T, op_name, func_contents) \ template<> METAL_FUNC T op_name::op(device const T &a, device const T &b) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, threadgroup const T &b) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, thread const T &b) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, device const T &b) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, threadgroup const T &b) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, thread const T &b) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, device const T &b) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, threadgroup const T &b) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, thread const T &b) { func_contents } #define TEMPLATE_OPS_TRIPLE(func_contents) \ template static METAL_FUNC T op(device const T &a, device const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(device const T &a, device const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(device const T &a, device const T &b, thread const T &c) { func_contents } \ template static METAL_FUNC T op(device const T &a, threadgroup const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(device const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(device const T &a, threadgroup const T &b, thread const T &c) { func_contents } \ template static METAL_FUNC T op(device const T &a, thread const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(device const T &a, thread const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(device const T &a, thread const T &b, thread const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, device const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, device const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, device const T &b, thread const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, thread const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, thread const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, thread const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(threadgroup const T &a, thread const T &b, thread const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, device const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, device const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, device const T &b, thread const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, threadgroup const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, threadgroup const T &b, thread const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, thread const T &b, device const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, thread const T &b, threadgroup const T &c) { func_contents } \ template static METAL_FUNC T op(thread const T &a, thread const T &b, thread const T &c) { func_contents } #define TEMPLATE_OPS_OVERRIDE_TRIPLE(T, op_name, func_contents) \ template<> METAL_FUNC T op_name::op(device const T &a, device const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, device const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, device const T &b, thread const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, threadgroup const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, threadgroup const T &b, thread const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, thread const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, thread const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(device const T &a, thread const T &b, thread const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, device const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, device const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, device const T &b, thread const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, threadgroup const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, threadgroup const T &b, thread const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, thread const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, thread const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(threadgroup const T &a, thread const T &b, thread const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, device const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, device const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, device const T &b, thread const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, threadgroup const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, threadgroup const T &b, thread const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, thread const T &b, device const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, thread const T &b, threadgroup const T &c) { func_contents } \ template<> METAL_FUNC T op_name::op(thread const T &a, thread const T &b, thread const T &c) { func_contents } /* ---------- CONST OPS ---------- */ /** * @brief Represents the zero constant operation. * * This operation returns the zero value of the specified type. * * @tparam T The data type for which to return the zero value. * @return The zero value of type T. */ struct zero { template static METAL_FUNC constexpr T op(args... _) { return base_types::constants::zero(); } }; /** * @brief Represents the one constant operation. * * This operation returns the one value of the specified type. * * @tparam T The data type for which to return the one value. * @return The one value of type T. */ struct one { template static METAL_FUNC constexpr T op(args... _) { return base_types::constants::one(); } }; /** * @brief Represents the positive infinity constant operation. * * This operation returns the positive infinity value of the specified type. * * @tparam T The data type for which to return the positive infinity value. * @return The positive infinity value of type T. */ struct pos_infty { template static METAL_FUNC constexpr T op(args... _) { return base_types::constants::pos_infty(); } }; /** * @brief Represents the negative infinity constant operation. * * This operation returns the negative infinity value of the specified type. * * @tparam T The data type for which to return the negative infinity value. * @return The negative infinity value of type T. */ struct neg_infty { template static METAL_FUNC constexpr T op(args... _) { return base_types::constants::neg_infty(); } }; /* ---------- UNARY OPS ---------- */ /** * @brief Exponential function operation. * * This operation calculates the exponential of the input value. * * @tparam T The data type of the input and output values. * @param x[in] The input value. * @return The exponential of the input value. */ struct exp { TEMPLATE_OPS_SINGLE(return metal::exp(x);) }; TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, exp, return bf16(metal::exp((float)x));) TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, exp, return bf16_2(metal::exp(float2(x)));) /** * @brief Exponential function operation, in base 2 * * This operation calculates the exponential of the input value, in base 2. * * @tparam T The data type of the input and output values. * @param x[in] The input value. * @return The exponential of the input value. */ struct exp2 { template static METAL_FUNC T op(device const T &x) { return metal::exp2(x); } \ template static METAL_FUNC T op(threadgroup const T &x) { return metal::exp2(x); } \ template static METAL_FUNC T op(thread const T &x) { return metal::exp2(x); } }; //template<> METAL_FUNC bf16 exp2::op(device const bf16 &x) { return bf16(metal::exp2(x)); } \ //template<> METAL_FUNC bf16 exp2::op(threadgroup const bf16 &x) { return bf16(metal::exp2(x)); } \ //template<> METAL_FUNC bf16 exp2::op(thread const bf16 &x) { return bf16(metal::exp2(x)); } TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, exp2, return bf16(metal::exp2(x));) TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, exp2, return bf16_2(metal::exp2((float2)x));) /** * @brief Natural log function operation. * * This operation calculates the natural logarithm of the input value. * * @tparam T The data type of the input and output values. * @param x[in] The input value. * @return The natural logarithm of the input value. */ struct log { TEMPLATE_OPS_SINGLE(return metal::log(x);) }; TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, log, return bf16(metal::log(x));) TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, log, return bf16_2(metal::log((float2)x));) /** * @brief Absolute value operation. * * This operation calculates the absolute value of the input. * * @tparam T The data type of the input and output values. * @param x[in] The input value. * @return The absolute value of the input. */ struct abs { TEMPLATE_OPS_SINGLE(return metal::abs(x);) }; TEMPLATE_OPS_OVERRIDE_SINGLE(bf16 , abs, return bf16(metal::abs((float)x));) TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, abs, return bf16_2(metal::abs((float2)x));) /** * @brief Rectified Linear Unit (ReLU) operation. * * This operation applies the ReLU function to the input, which is the * maximum of zero and the input value. * * @tparam T The data type of the input and output values. * @param x[in] The input value. * @return The result of ReLU function applied to the input. */ struct relu { TEMPLATE_OPS_SINGLE(return max(x, base_types::constants::zero());) }; TEMPLATE_OPS_OVERRIDE_SINGLE(bf16 , relu, return bf16(metal::max((float)x, base_types::constants::zero()));) TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, relu, return bf16_2(metal::max((float2)x, base_types::constants::zero()));) /** * @brief Copy operation. * * This operation returns the input value unchanged. * * @tparam T The data type of the input and output values. * @param a[in] The input value. * @return The same value as the input. */ struct copy { // for non-compile-time setters. TEMPLATE_OPS_SINGLE(return x;) }; /* ---------- BINARY OPS ---------- */ /** * @brief Copy2 operation. * * This operation returns the second input value unchanged. * * @tparam T The data type of the input and output values. * @param a[in] The first input value (ignored). * @param b[in] The second input value. * @return The same value as the second input. */ struct copy2 { // this turns out to be a slightly hacky op that makes some code cleaner :/ TEMPLATE_OPS_DOUBLE(return b;) }; /** * @brief Sum operation. * * This operation calculates the sum of two input values. * * @tparam T The data type of the input and output values. * @param a[in] The first input value. * @param b[in] The second input value. * @return The sum of the input values. */ struct sum { TEMPLATE_OPS_DOUBLE(return a+b;) }; /** * @brief Subtraction operation. * * This operation calculates the difference between two input values. * * @tparam T The data type of the input and output values. * @param a[in] The first input value. * @param b[in] The second input value. * @return The difference between the input values. */ struct sub { TEMPLATE_OPS_DOUBLE(return a-b;) }; /** * @brief Multiplication operation. * * This operation calculates the product of two input values. * * @tparam T The data type of the input and output values. * @param a[in] The first input value. * @param b[in] The second input value. * @return The product of the input values. */ struct mul { TEMPLATE_OPS_DOUBLE(return a*b;) }; /** * @brief Division operation. * * This operation calculates the quotient of two input values. * * @tparam T The data type of the input and output values. * @param a[in] The first input value. * @param b[in] The second input value. * @return The quotient of the input values. */ struct div { TEMPLATE_OPS_DOUBLE(return a/b;) }; /** * @brief Maximum operation. * * This operation calculates the maximum of two input values. * * @tparam T The data type of the input and output values. * @param a[in] The first input value. * @param b[in] The second input value. * @return The maximum of the input values. */ struct max { TEMPLATE_OPS_DOUBLE(return metal::max(a,b);) }; TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16 , max, return (bf16)metal::max((float)a, (float)b);) TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16_2, max, return (bf16_2)metal::max((float2)a, (float2)b);) /** * @brief Minimum operation. * * This operation calculates the minimum of two input values. * * @tparam T The data type of the input and output values. * @param a[in] The first input value. * @param b[in] The second input value. * @return The minimum of the input values. */ struct min { TEMPLATE_OPS_DOUBLE(return metal::min(a,b);) }; TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16 , min, return (bf16)metal::min((float)a, (float)b);) TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16_2, min, return (bf16_2)metal::min((float2)a, (float2)b);) /* ---------- TERNARY OPS ---------- */ /** * @brief Fused multiply-add operation A * B + C. * * This operation performs a fused multiply-add, computing (A * B) + C with only one rounding. * * @tparam T The data type of the input and output values. * @param a[in] The first input value. * @param b[in] The second input value. * @param c[in] The third input value to be added. * @return The result of the fused multiply-add operation. */ struct fma_AxBtC { TEMPLATE_OPS_TRIPLE(return sum::op(mul::op(a, b), c);) }; /** * @brief Fused multiply-add operation A * C + B. * * This operation performs a fused multiply-add, computing (A * C) + B with only one rounding. * This is particularly useful for attention mechanisms in neural networks. * * @tparam T The data type of the input and output values. * @param a[in] The first input value. * @param b[in] The third input value to be added. * @param c[in] The second input value. * @return The result of the fused multiply-add operation. */ struct fma_AxCtB { // this is the one needed for attention TEMPLATE_OPS_TRIPLE(return sum::op(mul::op(a, c), b);) }; #undef TEMPLATE_OPS_SINGLE #undef TEMPLATE_OPS_OVERRIDE_SINGLE #undef TEMPLATE_OPS_DOUBLE #undef TEMPLATE_OPS_OVERRIDE_DOUBLE #undef TEMPLATE_OPS_TRIPLE #undef TEMPLATE_OPS_OVERRIDE_TRIPLE } // base_ops } // mittens