minor to_indexed_uops cleanup [pr] (#8063)

This commit is contained in:
chenyu 2024-12-05 17:15:03 -05:00 committed by GitHub
commit 05dba6e4ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,7 +1,7 @@
from __future__ import annotations
import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast
from typing import Tuple, List, Optional, Dict, Set, cast, Sequence
from tinygrad.dtype import dtypes
from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
@ -95,10 +95,11 @@ class View:
for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
def __lt__(self, o:View): return self.t < o.t
def to_indexed_uops(self:View, _idxs:Optional[List[UOp]|Tuple[UOp, ...]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
"""(idx, valid)"""
if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
iexpr = sint_to_uop(self.offset)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
if m is not None:
if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0])