mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
simpler idxs_to_idx (#3071)
This commit is contained in:
parent
2495ca95c7
commit
023f5df0e9
1 changed files with 3 additions and 5 deletions
|
|
@ -49,11 +49,9 @@ def simplify(views:Tuple[View, ...]) -> Tuple[View, ...]:
|
|||
@functools.lru_cache(maxsize=None)
|
||||
def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
|
||||
assert len(idxs) == len(shape), "need an idx for all dimensions"
|
||||
acc, ret = 1, []
|
||||
for tidx,d in zip(reversed(idxs), reversed(shape)):
|
||||
ret.append(tidx * acc)
|
||||
acc *= d
|
||||
return Node.sum(ret)
|
||||
# idxs[-1] * 1 + idxs[-2] * shape[-1] + idxs[-3] * shape[-1] * shape[-2] + ...
|
||||
accs = itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1)
|
||||
return Node.sum([idx * acc for idx, acc in zip(reversed(idxs), accs)])
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue