Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
e051a3536f
Merge branch 'master' into check_dims 2025-07-30 12:05:58 -07:00
George Hotz
c5573c989d check tc dims 2025-07-30 12:05:29 -07:00
George Hotz
91dc19c713 check elements_per_thread in tensorcore [pr] 2025-07-30 11:35:26 -07:00

View file

@ -37,6 +37,10 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
assert 2**upcast_axes == self.elements_per_thread[2], \
f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}"
# check dims match opts
assert self.dims[0] == 2**len(gd:=[x for x in self.opts if x[1] == '0']), f"opts wrong on dims[0], {self.dims[0]} vs {gd}"
assert self.dims[1] == 2**len(gd:=[x for x in self.opts if x[1] == '1']), f"opts wrong on dims[1], {self.dims[1]} vs {gd}"
# NOTE: the K opts is implictly set by the dim
# check swizzle
assert len(self.swizzle[0]) == 3 and len(self.swizzle[1]) == 3, "swizzle has wrong part count"
assert len(self.swizzle[0][0]) == len(self.swizzle[1][0]) == local_axes, "local swizzle size is wrong"
@ -52,6 +56,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
if o[1] == '1': zero_stride_1.append(o[0] + str(un if o[0] == 'u' else ln))
if o[0] == 'u': un += 1
if o[0] == 'l': ln += 1
# NOTE: all the zero_stride dims can be placed in any order in the swizzle
upcasted_0 = [x for x in (self.swizzle[0][1] + self.swizzle[0][2]) if x not in zero_stride_0 and x[0] != 'l']
upcasted_1 = [x for x in (self.swizzle[1][1] + self.swizzle[1][2]) if x not in zero_stride_1 and x[0] != 'l']
assert 2**len(upcasted_0) == self.elements_per_thread[0], f"mismatch in elements_per_thread[0], {upcasted_0} vs {self.elements_per_thread[0]}"