This commit is contained in:
George Hotz 2026-06-02 13:19:04 -07:00
commit c7b6ee0c7d

View file

@ -244,7 +244,7 @@ class ClangRenderer(CStyleLanguage):
kernel_typedef = "__attribute__((ms_abi)) void"
def render_vector_prefix(self, dt:DType) -> str:
# round (down) to power of two (this is actually the default clang behavior)
alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
alignment = 2**int(math.log2(dt.itemsize*dt.count)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),ext_vector_type({dt.count})));"
def _render_defines(self, uops) -> list[str]:
@ -433,7 +433,8 @@ class CUDARenderer(CStyleLanguage):
def render_vector_prefix(self, dt:DType) -> str:
vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
return f"struct __align__({dt.itemsize*dt.count}) {vec} {{ {scal} {elems}; }};" + \
f"__device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
@ -450,7 +451,8 @@ class CUDARenderer(CStyleLanguage):
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
# 4 => CUDA reg size in bytes
n_operands = [size*dtype.itemsize*dtype.count//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
operands = [f"%{i}" for i in range(sum(n_operands))]
# mma operands => {c}, {a}, {b}, {c}