mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Linearizer -> Lowerer (#4957)
* st to uops function * lowerer * uops reduce * uops reduce * acc_number correct * reduce unroll * complete unroll * do upcasts * handle multioutput * define_accs * fix valid * get grouped dims * revert lin * minor * fixup_ast * group for reduce * group works now * all forwards pass * all ops tests pass * fix clang * mypy * lil cleanups, no image yet * ugh, variables everywhere * bugfix * counters and name fix * use symbolic, not uops * cleanups * Fix tests * linearizer tests * expands * float4 expand load * tests pass * woooo, float4 test * test ops works again * one more lin test * more lin tests * bypass * fix tests * something like this * const in defineacc * uops get_reduce_acc * move around * allow consts in the LOAD/STORE * each axis should only appear once, 21 failures * 16 failures * fix some image * optional float4 * onnx tests * gate the stores * add reorder * fix terrible skip function * tc work * opt add/mul merge * fix float4 tests * tiny tweak, 9 failing * 7 test failures * start tc, but i don't think this will work * progress on tensorcores * note * fix ops tests * closer on tc * weeee...one tensor core works * still works, more generic * large WMMA works * tc test passes * use WMMA as accumulator * basic tc tests passing * small gemm padded works * 4 failures * 3 tests failing * super barrier * now two tests failing * one test failing * cleanpus, add reduce to UopGraph * remove the linearizer * remove unused * lil cleanups * Lowerer everywhere * remove test that doesn't exist now * image indexing * llvm fix * fix metal * fix image * fix images * might fix ptx * fix image type mismatch * more tests pass * CAST -> VECTORIZE * forgot that one * fix TestOps.test_flip_eye_crash * locals shouldn't be image dtype * change less files * test fix * fix recursive expands * touches * MULACC support in python * delete unneeded * alu before contract * bug fixes * tests * no var multireduce * simpler tc * metal works in new style * working on AMD and METAL * fix amd * shot in the dark, fix amd * something for CUDA * CUDA WORKS from the docs * comment * correct merge * cleanups + ptx fix + get_reduce_acc * local alias isn't used anymore * add store sanity check * fix for AMD * cleanups and single expand pass * more correct with acc_cache * tests should pass * block on WMMA * tests pass * merge contract and reduce * contractor fixes issue * multicontract * pre expand wmma (same as a reduce) * expand wmma and only take one * all expands * comments and whitespace
This commit is contained in:
parent
204b6169ca
commit
6972a2569f
9 changed files with 587 additions and 630 deletions
|
|
@ -135,9 +135,9 @@ class PythonProgram:
|
|||
elif uop is UOps.WMMA:
|
||||
# here are the models for the WMMA instruction on the different hardware
|
||||
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
|
||||
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
|
||||
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
|
||||
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
|
||||
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
|
||||
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
|
||||
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
|
||||
assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
|
||||
assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
|
||||
assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue