DEFINE_ACC takes UOps.CONST in vin instead of arg (#4975)

* Change DEFINE_ACC to receive UOps.CONST in vin

* Use localtype instead of acc dtype

* Fix idp

* Fix copy list

* Fix warp

* Fix error

* Fix merge

* Fix testing

* Fix merge

* Use deepcopy

* Change to copy of inp

* Fix lint

* Move const to first place

* Fix issue upat

* Fix upat patterns

* Change to list, to test permutations

* Add condition

* Change pm

* Revert change pm

* Remove unused rule

* Fix

* Change float4 DEFINE_ACC values

* Cast on PM to correct dtype

* Improve assert message

* Move IFs

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Jhenner Tigreros 2024-06-24 11:25:33 -05:00 committed by GitHub
commit dfa562dbc1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 18 additions and 16 deletions

View file

@@ -41,7 +41,7 @@ class PythonProgram:
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
if uop is UOps.DEFINE_ACC: idp.clear()
if uop is UOps.DEFINE_ACC: idp = [idp[0]]
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
@@ -90,7 +90,7 @@ class PythonProgram:
elif uop is UOps.CONST:
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
elif uop is UOps.DEFINE_ACC:
ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
elif uop is UOps.RANGE:
if i not in ul: ul[i] = [inp[0][0]] * warp_size
else: