typedef bf16 amd (#7850)

This commit is contained in:
ignaciosica 2024-11-22 16:29:01 -03:00 committed by GitHub
commit fb10ea563e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -412,7 +412,7 @@ class AMDRenderer(CStyleLanguage):
prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
used_dtypes = uops_to_dtypes(uops)
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper