mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add constant folding for WHERE in uops (#3584)
* add constant folding for WHERE in uops
* prereqs for generic constant folding
* fix test
* disable slow overflow logic
* make that test faster
This commit is contained in:
parent
3b7e3fa2e4
commit
aa9b013d79
5 changed files with 56 additions and 29 deletions
|
|
@ -2,36 +2,14 @@
|
|||
# works to test the tensor cores, and all the uops in general
|
||||
# this is the (living) definition of uops
|
||||
from typing import Tuple, List, Optional, Any, Dict
|
||||
import pickle, base64, itertools, time, math, struct
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.device import Compiled, Allocator, Compiler
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.uops import UOp, UOps, exec_alu
|
||||
from tinygrad.ops import BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
|
||||
def exec_alu(arg, dtype, p):
  """Evaluate a single ALU op on plain Python values.

  arg:   the op to run (compared against TernaryOps / UnaryOps / BinaryOps members).
  dtype: only consulted for DIV, to pick integer vs. float division.
  p:     sequence of operand values, indexed positionally (p[0], p[1], p[2]).

  Returns the computed Python value.
  Raises NotImplementedError if `arg` is not one of the ops handled below.
  """
  # TODO: make this complete and correctly honor the dtypes
  # TODO: use this for constant folding
  if arg == TernaryOps.WHERE: return p[1] if p[0] else p[2]
  # log2 of 0 is -inf and of a negative number is nan, instead of raising
  if arg == UnaryOps.LOG2: return math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
  if arg == UnaryOps.EXP2:
    # math.exp raises OverflowError on large inputs; saturate to +inf instead
    try: return math.exp(p[0]*math.log(2))
    except OverflowError: return math.inf
  if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] >= 0 else math.nan
  # math.sin raises ValueError for +/-inf; return nan so SIN treats out-of-domain input like LOG2 and SQRT do
  if arg == UnaryOps.SIN: return math.nan if math.isinf(p[0]) else math.sin(p[0])
  if arg == UnaryOps.NEG: return -p[0]
  if arg == BinaryOps.MUL: return p[0]*p[1]
  if arg == BinaryOps.ADD: return p[0]+p[1]
  if arg == BinaryOps.SUB: return p[0]-p[1]
  if arg == BinaryOps.XOR: return p[0]^p[1]
  if arg == BinaryOps.MAX: return max(p[0], p[1])
  if arg == BinaryOps.CMPEQ: return p[0] == p[1]
  if arg == BinaryOps.CMPLT: return p[0] < p[1]
  # NOTE(review): integer DIV uses Python floor division (and raises ZeroDivisionError on a zero divisor),
  # and MOD uses Python's sign-of-divisor remainder -- confirm this matches what the real backends produce
  if arg == BinaryOps.DIV: return p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
  if arg == BinaryOps.MOD: return p[0]%p[1]
  raise NotImplementedError(f"no support for {arg}")
|
||||
|
||||
def _load(m, i):
|
||||
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
||||
return m[i]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue