mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
bf16 support in metal (#3929)
it runs if device gpu supports bfloat. updated ci benchmark too
This commit is contained in:
parent
72d617a37d
commit
ef537672bf
2 changed files with 25 additions and 13 deletions
12
.github/workflows/benchmark.yml
vendored
12
.github/workflows/benchmark.yml
vendored
|
|
@ -56,6 +56,10 @@ jobs:
|
|||
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=97.5 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
|
||||
- name: Run 10 CIFAR training steps
|
||||
run: STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
|
||||
- name: Run 10 CIFAR training steps w HALF
|
||||
run: STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
|
||||
- name: Run 10 CIFAR training steps w BF16
|
||||
run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
|
||||
# TODO: this is flaky too
|
||||
# - name: Run 10 CIFAR training steps w winograd
|
||||
# run: WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
|
||||
|
|
@ -65,9 +69,6 @@ jobs:
|
|||
path: |
|
||||
onnx_inference_speed.csv
|
||||
torch_speed.txt
|
||||
beautiful_mnist.txt
|
||||
train_cifar.txt
|
||||
train_cifar_wino.txt
|
||||
llama_unjitted.txt
|
||||
llama_jitted.txt
|
||||
llama_beam.txt
|
||||
|
|
@ -78,6 +79,11 @@ jobs:
|
|||
matmul.txt
|
||||
matmul_half.txt
|
||||
sd.txt
|
||||
beautiful_mnist.txt
|
||||
train_cifar.txt
|
||||
train_cifar_half.txt
|
||||
train_cifar_bf16.txt
|
||||
train_cifar_wino.txt
|
||||
|
||||
testnvidiabenchmark:
|
||||
name: NVIDIA Benchmark
|
||||
|
|
|
|||
|
|
@ -204,6 +204,14 @@ class MetalLanguage(CStyleLanguage):
|
|||
uses_ptr_arithmetic = True
|
||||
code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
|
||||
type_map = {dtypes.bfloat16: "bfloat"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op,
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
|
||||
|
||||
def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
|
||||
return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
|
||||
|
||||
|
|
@ -242,20 +250,18 @@ class CUDALanguage(CStyleLanguage):
|
|||
"""__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
|
||||
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;}""",
|
||||
]
|
||||
return c;}""",]
|
||||
|
||||
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
|
||||
CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
|
||||
|
||||
code_for_op_hip = {
|
||||
# TODO: MAX with int uses fmax_f32?
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
}
|
||||
code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
|
||||
# TODO: MAX with int uses fmax_f32?
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
|
||||
|
||||
def _make_hip_code_for_op():
|
||||
def wrapper(key, func):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue