clean up tensor cores [run_process_replay] (#5736)

* clean up tensor cores [run_process_replay]

* remove tuple(wmma_sz), self.opts.device

* remove tls, leave DEVICE
This commit is contained in:
George Hotz 2024-07-26 13:21:23 -07:00 committed by GitHub
commit 4df46eac67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 22 additions and 24 deletions

View file

@@ -150,13 +150,13 @@ class PythonProgram:
return out
# TODO: refactor these to a shared TensorCoreLayout in kernel.py
-if arg[5] == "METAL":
+if arg[4] == "METAL":
# A (2 elements on 32 threads): row major
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-elif arg[5] == "AMD":
+elif arg[4] == "AMD":
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
def a_elem(x, i, j, goff):
assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
@@ -165,7 +165,7 @@ class PythonProgram:
def b_elem(x, i, j, goff): return a_elem(x, j, i, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-elif arg[5] == "CUDA":
+elif arg[4] == "CUDA":
# A (8 elements on 32 threads)
def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
# B (4 elements on 32 threads)