mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add smin/smax (#7253)
* add smin/smax
* don't create var with var
* better test errors
* add failing test
* enable shape simplification
* fix tests
* Update view.py
* simpler and simplify
This commit is contained in:
parent
de7b9d7c42
commit
532b7b018c
6 changed files with 36 additions and 12 deletions
|
|
@ -21,7 +21,7 @@ def assert_jit_cache_len(fxn, expected_len):
|
|||
return
|
||||
# until we have a better way of typing the prg in ExecItem
|
||||
if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
|
||||
assert len(fxn.jit_cache) == expected_len, len(fxn.jit_cache)
|
||||
assert len(fxn.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache)}"
|
||||
else:
|
||||
assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
|
||||
# until we have a better way of typing the prg in ExecItem
|
||||
|
|
|
|||
|
|
@ -208,7 +208,7 @@ class TestSymbolicExpand(unittest.TestCase):
|
|||
vi = Variable("i", 1, 5).bind(i)
|
||||
a = Tensor.rand(3, i).reshape(3, vi)
|
||||
a = a + 1
|
||||
assert a.shape == (3, vi)
|
||||
self.assertTupleEqual(a.shape, (3, vi))
|
||||
|
||||
class TestSymbolicShrink(unittest.TestCase):
|
||||
def test_shrink_symbols(self):
|
||||
|
|
|
|||
|
|
@ -55,6 +55,14 @@ class TestTensorVariable(unittest.TestCase):
|
|||
ret = t.var().item()
|
||||
assert ret == 0
|
||||
|
||||
def test_symbolic_pad2d(self):
  # pad a 2x2 tensor of ones on all four sides by a symbolic Variable bound to 2, then take the mean
  pad = Variable("a", 1, 10).bind(2)
  base = Tensor.ones(2, 2).contiguous()
  mean = base.pad2d([pad, pad, pad, pad]).mean()
  # 4 original ones; the zero count is spelled out as in the original (presumably the border of a 6x6 result)
  ones = 4
  zeros = 6+6+4+4+6+6
  self.assertAlmostEqual(mean.item(), ones/(ones+zeros))
|
||||
|
||||
@unittest.skip("symbolic arange isn't supported")
|
||||
def test_symbolic_arange(self):
|
||||
vv = Variable("a", 1, 10).bind(2)
|
||||
|
|
|
|||
|
|
@ -164,7 +164,15 @@ def resolve(x, default:bool=True):
|
|||
assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
|
||||
# NOTE: generating the text for the exception is expensive, so we do this
|
||||
return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
|
||||
def smax(lst):
  """Return the largest entry of lst, ranking ints by value and anything else by its `vmax` attribute."""
  def _rank(item): return item if isinstance(item, int) else item.vmax
  return max(lst, key=_rank)
|
||||
|
||||
# smax/smin are replacements for max/min that preserve symbolic
|
||||
def _suop(lst, uop_fxn, python_fxn):
  """Shared implementation behind smax/smin.

  Entries of lst are split by `isinstance(x, UOp)`. With no UOp entries the result is simply
  python_fxn over the remaining entries. Otherwise the UOp entries are folded pairwise with uop_fxn
  (together with python_fxn of the non-UOp entries, when there are any) and the result is ssimplify'd.
  """
  # NOTE(review): assumes partition returns (matching, non_matching) as lists -- it is defined outside this view
  uops, nums = partition(lst, lambda x: isinstance(x, UOp))
  if not len(uops): return python_fxn(nums)
  # collapse all plain numbers into a single operand before folding with the UOps
  operands = (uops + [python_fxn(nums)]) if len(nums) else uops
  return functools.reduce(uop_fxn, operands).ssimplify()
|
||||
def smax(*lst):
  """max() replacement that goes through _suop with UOp.max; accepts either smax(a, b, ...) or smax([a, b, ...])."""
  # a single tuple/list first argument is treated as the collection itself
  items = lst[0] if isinstance(lst[0], (tuple, list)) else lst
  return _suop(items, UOp.max, max)
|
||||
def smin(*lst):
  """min() replacement that goes through _suop with UOp.min; accepts either smin(a, b, ...) or smin([a, b, ...])."""
  # a single tuple/list first argument is treated as the collection itself
  items = lst[0] if isinstance(lst[0], (tuple, list)) else lst
  return _suop(items, UOp.min, min)
|
||||
|
||||
def ssimplify(uop):
  """Call .ssimplify() on a UOp; return any other value unchanged."""
  if isinstance(uop, UOp): return uop.ssimplify()
  return uop
|
||||
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int:
  """Evaluate uop via its own sym_infer(var_vals); a non-UOp argument is returned as is."""
  if not isinstance(uop, UOp): return uop
  return uop.sym_infer(var_vals)
|
||||
|
||||
|
|
@ -307,6 +315,7 @@ class UOp(MathTrait):
|
|||
|
||||
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
  """Build a DEFINE_VAR UOp named `name` whose arg carries (name, min_val, max_val).

  Neither bound may itself be a UOp; this is enforced with an assert.
  """
  has_uop_bound = isinstance(min_val, UOp) or isinstance(max_val, UOp)
  assert not has_uop_bound, f"can't create Variable {name} with {min_val}/{max_val}"
  var_arg = (name, min_val, max_val)
  return UOp(UOps.DEFINE_VAR, dtype, arg=var_arg)
|
||||
@property
|
||||
def expr(self):
|
||||
|
|
@ -1056,6 +1065,10 @@ renderer = PatternMatcher([
|
|||
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
|
||||
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
|
||||
(UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(UOps.NOOP, arg=f"(-{x.src[0].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(UOps.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.MULACC),
|
||||
lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.WHERE),
|
||||
lambda x: UOp(UOps.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import annotations
|
||||
import functools, operator, itertools, math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast, Union
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer
|
||||
from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin
|
||||
from tinygrad.helpers import prod, all_int, argsort
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
|
|
@ -125,9 +125,10 @@ class View:
|
|||
offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
|
||||
strides = tuple(0 if e else st for st,e in zip(strides, elim))
|
||||
# simplify as we go
|
||||
if isinstance(offset, UOp): offset = cast(Union[UOp, int], offset.ssimplify())
|
||||
if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
|
||||
shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
|
||||
# TODO: enabling stride simplification breaks it
|
||||
"""
|
||||
shape = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in shape)
|
||||
strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
|
||||
if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
|
||||
"""
|
||||
|
|
@ -174,6 +175,7 @@ class View:
|
|||
|
||||
# Merge dimensions in vm2 if required.
|
||||
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
|
||||
if not all_int(vm1.shape): return None
|
||||
idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
|
||||
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
|
||||
extents: List[Tuple[sint, UOp]] = []
|
||||
|
|
@ -232,9 +234,9 @@ class View:
|
|||
offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
|
||||
if self.mask:
|
||||
# move the old mask
|
||||
nmask = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
|
||||
nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
|
||||
# merge the masks if we have two
|
||||
mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
|
||||
mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
|
||||
shape = [y-x for x,y in arg]
|
||||
if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
|
||||
return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
|
|||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.ops import MetaOps, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
|
||||
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable
|
||||
from tinygrad.device import Device, Buffer, BufferOptions
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
|
|
@ -380,6 +380,7 @@ class Tensor:
|
|||
if y.op is UOps.ALU:
|
||||
if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
|
||||
if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
|
||||
if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
|
||||
raise RuntimeError(f"unhandled UOp {y}")
|
||||
|
||||
# ***** creation entrypoint *****
|
||||
|
|
@ -1364,9 +1365,9 @@ class Tensor:
|
|||
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
|
||||
```
|
||||
"""
|
||||
pads = tuple((max(p0, 0), max(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1]
|
||||
pads = tuple((smax(p0, 0), smax(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1]
|
||||
padded = self.pad((None,) * (self.ndim - len(padding) // 2) + tuple(pads), value=value)
|
||||
shrink = tuple((-min(p0, 0), min(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1]
|
||||
shrink = tuple((-smin(p0, 0), smin(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1]
|
||||
return padded.shrink((None,) * (self.ndim - len(padding) // 2) + shrink)
|
||||
|
||||
@property
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue