Linearizer -> Lowerer (#4957)

* st to uops function

* lowerer

* uops reduce

* uops reduce

* acc_number correct

* reduce unroll

* complete unroll

* do upcasts

* handle multioutput

* define_accs

* fix valid

* get grouped dims

* revert lin

* minor

* fixup_ast

* group for reduce

* group works now

* all forwards pass

* all ops tests pass

* fix clang

* mypy

* lil cleanups, no image yet

* ugh, variables everywhere

* bugfix

* counters and name fix

* use symbolic, not uops

* cleanups

* Fix tests

* linearizer tests

* expands

* float4 expand load

* tests pass

* woooo, float4 test

* test ops works again

* one more lin test

* more lin tests

* bypass

* fix tests

* something like this

* const in defineacc

* uops get_reduce_acc

* move around

* allow consts in the LOAD/STORE

* each axis should only appear once, 21 failures

* 16 failures

* fix some image

* optional float4

* onnx tests

* gate the stores

* add reorder

* fix terrible skip function

* tc work

* opt add/mul merge

* fix float4 tests

* tiny tweak, 9 failing

* 7 test failures

* start tc, but i don't think this will work

* progress on tensorcores

* note

* fix ops tests

* closer on tc

* weeee...one tensor core works

* still works, more generic

* large WMMA works

* tc test passes

* use WMMA as accumulator

* basic tc tests passing

* small gemm padded works

* 4 failures

* 3 tests failing

* super barrier

* now two tests failing

* one test failing

* cleanpus, add reduce to UopGraph

* remove the linearizer

* remove unused

* lil cleanups

* Lowerer everywhere

* remove test that doesn't exist now

* image indexing

* llvm fix

* fix metal

* fix image

* fix images

* might fix ptx

* fix image type mismatch

* more tests pass

* CAST -> VECTORIZE

* forgot that one

* fix TestOps.test_flip_eye_crash

* locals shouldn't be image dtype

* change less files

* test fix

* fix recursive expands

* touches

* MULACC support in python

* delete unneeded

* alu before contract

* bug fixes

* tests

* no var multireduce

* simpler tc

* metal works in new style

* working on AMD and METAL

* fix amd

* shot in the dark, fix amd

* something for CUDA

* CUDA WORKS from the docs

* comment

* correct merge

* cleanups + ptx fix + get_reduce_acc

* local alias isn't used anymore

* add store sanity check

* fix for AMD

* cleanups and single expand pass

* more correct with acc_cache

* tests should pass

* block on WMMA

* tests pass

* merge contract and reduce

* contractor fixes issue

* multicontract

* pre expand wmma (same as a reduce)

* expand wmma and only take one

* all expands

* comments and whitespace
This commit is contained in:
George Hotz 2024-07-10 15:07:42 -07:00 committed by GitHub
commit 6972a2569f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 587 additions and 630 deletions

View file

@ -135,9 +135,9 @@ class PythonProgram:
elif uop is UOps.WMMA:
# here are the models for the WMMA instruction on the different hardware
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"