mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
DEFINE_ACC takes UOps.CONST in vin instead of arg (#4975)
* Change DEFINE_ACC to receive UOps.CONST in vin * Use localtype instead of acc dtype * Fix idp * Fix copy list * Fix warp * Fix error * Fix merge * Fix testing * Fix merge * Use deepcopy * Change to copy of inp * Fix lint * Move const to first place * Fix issue upat * Fix upat patterns * Change to list, to test permutations * Add condition * Change pm * Revert change pm * Remove unused rule * Fix * Change of float4 DEFINE_ACC values * Cast on PM to correct dtype * Improve assert message * Move IFs --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
d84beaa6dd
commit
dfa562dbc1
6 changed files with 18 additions and 16 deletions
|
|
@ -41,7 +41,7 @@ class PythonProgram:
|
|||
while i < len(self.uops):
|
||||
uop, dtype, idp, arg = self.uops[i]
|
||||
void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
|
||||
if uop is UOps.DEFINE_ACC: idp.clear()
|
||||
if uop is UOps.DEFINE_ACC: idp = [idp[0]]
|
||||
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
|
||||
|
|
@ -90,7 +90,7 @@ class PythonProgram:
|
|||
elif uop is UOps.CONST:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
|
||||
ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
|
||||
elif uop is UOps.RANGE:
|
||||
if i not in ul: ul[i] = [inp[0][0]] * warp_size
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue