mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
check_dims
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e051a3536f |
||
|
|
c5573c989d | ||
|
|
91dc19c713 |
1 changed files with 5 additions and 0 deletions
|
|
@ -37,6 +37,10 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
|
||||||
assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
|
assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
|
||||||
assert 2**upcast_axes == self.elements_per_thread[2], \
|
assert 2**upcast_axes == self.elements_per_thread[2], \
|
||||||
f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}"
|
f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}"
|
||||||
|
# check dims match opts
|
||||||
|
assert self.dims[0] == 2**len(gd:=[x for x in self.opts if x[1] == '0']), f"opts wrong on dims[0], {self.dims[0]} vs {gd}"
|
||||||
|
assert self.dims[1] == 2**len(gd:=[x for x in self.opts if x[1] == '1']), f"opts wrong on dims[1], {self.dims[1]} vs {gd}"
|
||||||
|
# NOTE: the K opts are implicitly set by the dim
|
||||||
# check swizzle
|
# check swizzle
|
||||||
assert len(self.swizzle[0]) == 3 and len(self.swizzle[1]) == 3, "swizzle has wrong part count"
|
assert len(self.swizzle[0]) == 3 and len(self.swizzle[1]) == 3, "swizzle has wrong part count"
|
||||||
assert len(self.swizzle[0][0]) == len(self.swizzle[1][0]) == local_axes, "local swizzle size is wrong"
|
assert len(self.swizzle[0][0]) == len(self.swizzle[1][0]) == local_axes, "local swizzle size is wrong"
|
||||||
|
|
@ -52,6 +56,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
|
||||||
if o[1] == '1': zero_stride_1.append(o[0] + str(un if o[0] == 'u' else ln))
|
if o[1] == '1': zero_stride_1.append(o[0] + str(un if o[0] == 'u' else ln))
|
||||||
if o[0] == 'u': un += 1
|
if o[0] == 'u': un += 1
|
||||||
if o[0] == 'l': ln += 1
|
if o[0] == 'l': ln += 1
|
||||||
|
# NOTE: all the zero_stride dims can be placed in any order in the swizzle
|
||||||
upcasted_0 = [x for x in (self.swizzle[0][1] + self.swizzle[0][2]) if x not in zero_stride_0 and x[0] != 'l']
|
upcasted_0 = [x for x in (self.swizzle[0][1] + self.swizzle[0][2]) if x not in zero_stride_0 and x[0] != 'l']
|
||||||
upcasted_1 = [x for x in (self.swizzle[1][1] + self.swizzle[1][2]) if x not in zero_stride_1 and x[0] != 'l']
|
upcasted_1 = [x for x in (self.swizzle[1][1] + self.swizzle[1][2]) if x not in zero_stride_1 and x[0] != 'l']
|
||||||
assert 2**len(upcasted_0) == self.elements_per_thread[0], f"mismatch in elements_per_thread[0], {upcasted_0} vs {self.elements_per_thread[0]}"
|
assert 2**len(upcasted_0) == self.elements_per_thread[0], f"mismatch in elements_per_thread[0], {upcasted_0} vs {self.elements_per_thread[0]}"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue