add constant folding for WHERE in uops (#3584)

* add constant folding for WHERE in uops

* prereqs for generic constant folding

* fix test

* disable slow overflow logic

* make that test faster
This commit is contained in:
George Hotz 2024-03-02 10:37:14 -08:00 committed by GitHub
commit aa9b013d79
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 56 additions and 29 deletions

View file

@ -2,36 +2,14 @@
# works to test the tensor cores, and all the uops in general
# this is the (living) definition of uops
from typing import Tuple, List, Optional, Any, Dict
import pickle, base64, itertools, time, math, struct
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Allocator, Compiler
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.codegen.uops import UOp, UOps, exec_alu
from tinygrad.ops import BinaryOps, TernaryOps
from tinygrad.codegen.kernel import LinearizerOptions
def exec_alu(arg, dtype, p):
# TODO: make this complete and correctly honor the dtypes
# TODO: use this for constant folding
if arg == TernaryOps.WHERE: return p[1] if p[0] else p[2]
if arg == UnaryOps.LOG2: return math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
if arg == UnaryOps.EXP2:
try: return math.exp(p[0]*math.log(2))
except OverflowError: return math.inf
if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] >= 0 else math.nan
if arg == UnaryOps.SIN: return math.sin(p[0])
if arg == UnaryOps.NEG: return -p[0]
if arg == BinaryOps.MUL: return p[0]*p[1]
if arg == BinaryOps.ADD: return p[0]+p[1]
if arg == BinaryOps.SUB: return p[0]-p[1]
if arg == BinaryOps.XOR: return p[0]^p[1]
if arg == BinaryOps.MAX: return max(p[0], p[1])
if arg == BinaryOps.CMPEQ: return p[0] == p[1]
if arg == BinaryOps.CMPLT: return p[0] < p[1]
if arg == BinaryOps.DIV: return p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
if arg == BinaryOps.MOD: return p[0]%p[1]
raise NotImplementedError(f"no support for {arg}")
def _load(m, i):
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
return m[i]