add smin/smax (#7253)

* add smin/smax

* don't create var with var

* better test errors

* add failing test

* enable shape simplification

* fix tests

* Update view.py

* simpler and simplify
This commit is contained in:
George Hotz 2024-10-24 15:10:49 +07:00 committed by GitHub
commit 532b7b018c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 36 additions and 12 deletions

View file

@ -21,7 +21,7 @@ def assert_jit_cache_len(fxn, expected_len):
return
# until we have a better way of typing the prg in ExecItem
if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
assert len(fxn.jit_cache) == expected_len, len(fxn.jit_cache)
assert len(fxn.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache)}"
else:
assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
# until we have a better way of typing the prg in ExecItem

View file

@ -208,7 +208,7 @@ class TestSymbolicExpand(unittest.TestCase):
vi = Variable("i", 1, 5).bind(i)
a = Tensor.rand(3, i).reshape(3, vi)
a = a + 1
assert a.shape == (3, vi)
self.assertTupleEqual(a.shape, (3, vi))
class TestSymbolicShrink(unittest.TestCase):
def test_shrink_symbols(self):

View file

@ -55,6 +55,14 @@ class TestTensorVariable(unittest.TestCase):
ret = t.var().item()
assert ret == 0
def test_symbolic_pad2d(self):
  # Pad a 2x2 tensor of ones on all four sides by a Variable bound to 2, then check the mean.
  pad = Variable("a", 1, 10).bind(2)
  padded_mean = Tensor.ones(2, 2).contiguous().pad2d([pad, pad, pad, pad]).mean()
  # 4 original ones; the remaining 32 elements of the padded result are expected to be zero
  ones = 4
  zeros = 6+6+4+4+6+6
  self.assertAlmostEqual(padded_mean.item(), ones/(ones+zeros))
@unittest.skip("symbolic arange isn't supported")
def test_symbolic_arange(self):
vv = Variable("a", 1, 10).bind(2)

View file

@ -164,7 +164,15 @@ def resolve(x, default:bool=True):
assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
# NOTE: generating the text for the exception is expensive, so we do this
return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
def smax(lst):
  """Return the element of `lst` that ranks highest: ints rank by value, anything else by its `.vmax`."""
  def _rank(item):
    if isinstance(item, int): return item
    return item.vmax
  return max(lst, key=_rank)
# smax/smin are replacements for max/min that preserve symbolic values
def _suop(lst, uop_fxn, python_fxn):
  """
  Shared implementation of a max/min-style reduction over a mix of UOps and plain values.

  `lst` is split into UOp and non-UOp members. The non-UOp members are collapsed with
  `python_fxn` (called once on the whole list); if any UOps are present, they are folded
  together with that collapsed value using `uop_fxn`, and the result is `.ssimplify()`-ed.
  With no UOps at all, the plain `python_fxn` result is returned directly.
  """
  symbolic, concrete = partition(lst, lambda x: isinstance(x, UOp))
  # no symbolic operands: this is just the ordinary python reduction
  if not symbolic: return python_fxn(concrete)
  # only append the collapsed concrete value when there actually are concrete operands
  operands = (symbolic + [python_fxn(concrete)]) if concrete else symbolic
  return functools.reduce(uop_fxn, operands).ssimplify()
def smax(*lst):
  """max() replacement that also accepts UOp operands; takes either separate arguments or a single tuple/list."""
  # a lone tuple/list first argument is treated as the collection of operands
  operands = lst[0] if isinstance(lst[0], (tuple, list)) else lst
  return _suop(operands, UOp.max, max)
def smin(*lst):
  """min() replacement that also accepts UOp operands; takes either separate arguments or a single tuple/list."""
  # a lone tuple/list first argument is treated as the collection of operands
  operands = lst[0] if isinstance(lst[0], (tuple, list)) else lst
  return _suop(operands, UOp.min, min)
def ssimplify(uop):
  """Call `.ssimplify()` on a UOp; hand back any non-UOp value untouched."""
  if not isinstance(uop, UOp): return uop
  return uop.ssimplify()
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int:
  """Delegate to `uop.sym_infer(var_vals)` when `uop` is a UOp; otherwise return `uop` unchanged."""
  if isinstance(uop, UOp): return uop.sym_infer(var_vals)
  return uop
@ -307,6 +315,7 @@ class UOp(MathTrait):
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
  """
  Build a DEFINE_VAR UOp of type `dtype` whose arg is `(name, min_val, max_val)`.

  Neither bound may itself be a UOp; this is enforced with an assert.
  """
  has_uop_bound = isinstance(min_val, UOp) or isinstance(max_val, UOp)
  assert not has_uop_bound, f"can't create Variable {name} with {min_val}/{max_val}"
  return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property
def expr(self):
@ -1056,6 +1065,10 @@ renderer = PatternMatcher([
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
(UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(UOps.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(UOps.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.MULACC),
lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(UOps.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),

View file

@ -1,9 +1,9 @@
from __future__ import annotations
import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast, Union
from typing import Tuple, List, Optional, Dict, Set, cast
from tinygrad.dtype import dtypes
from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer
from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin
from tinygrad.helpers import prod, all_int, argsort
@functools.lru_cache(maxsize=None)
@ -125,9 +125,10 @@ class View:
offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
strides = tuple(0 if e else st for st,e in zip(strides, elim))
# simplify as we go
if isinstance(offset, UOp): offset = cast(Union[UOp, int], offset.ssimplify())
if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
# TODO: enabling stride simplification breaks it
"""
shape = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in shape)
strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
"""
@ -174,6 +175,7 @@ class View:
# Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
if not all_int(vm1.shape): return None
idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
extents: List[Tuple[sint, UOp]] = []
@ -232,9 +234,9 @@ class View:
offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
if self.mask:
# move the old mask
nmask = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
# merge the masks if we have two
mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
shape = [y-x for x,y in arg]
if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)

View file

@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule
@ -380,6 +380,7 @@ class Tensor:
if y.op is UOps.ALU:
if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
raise RuntimeError(f"unhandled UOp {y}")
# ***** creation entrypoint *****
@ -1364,9 +1365,9 @@ class Tensor:
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
```
"""
pads = tuple((max(p0, 0), max(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1]
pads = tuple((smax(p0, 0), smax(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1]
padded = self.pad((None,) * (self.ndim - len(padding) // 2) + tuple(pads), value=value)
shrink = tuple((-min(p0, 0), min(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1]
shrink = tuple((-smin(p0, 0), smin(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1]
return padded.shrink((None,) * (self.ndim - len(padding) // 2) + shrink)
@property