mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
PADTO pads Invalids (#16562)
This commit is contained in:
parent
434a8ffc38
commit
5f1e2d3900
6 changed files with 22 additions and 29 deletions
|
|
@@ -224,7 +224,6 @@ class TestKernelOpts(unittest.TestCase):
|
|||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
# can pad sum reduce axis if there's no unsafe ops prior to sum
|
||||
for axis in (0, 1):
|
||||
helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
|
|
@@ -240,22 +239,16 @@ class TestKernelOpts(unittest.TestCase):
|
|||
helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_sum_not_ok(self):
|
||||
def test_padto_sum(self):
|
||||
N = 18
|
||||
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
|
||||
a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp()
|
||||
# exp is not safe to pad
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
b = a < 1
|
||||
# lt is not safe to pad
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_max(self):
|
||||
N = 18
|
||||
|
|
@@ -271,11 +264,8 @@ class TestKernelOpts(unittest.TestCase):
|
|||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
# cannot pad max kernel on reduce
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_where(self):
|
||||
Tensor.manual_seed(0)
|
||||
|
|
|
|||
|
|
@@ -13,7 +13,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
|||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
||||
|
|
@@ -89,8 +89,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
|
||||
# **** optimizations are done, now we lower to actual code ****
|
||||
|
||||
# add loads
|
||||
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
|
||||
# add loads and remove invalids
|
||||
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
||||
|
||||
# create image buffers
|
||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
||||
|
|
|
|||
|
|
@@ -341,6 +341,9 @@ def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
|
|||
return sink.substitute(subs) if subs else None
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
# invalid -> identity element
|
||||
(UPat(Ops.REDUCE, src=(invalid_gate,), allow_any_len=True, name="red"), lambda red,cond,x,i:
|
||||
red.replace(src=(cond.where(x, identity_element(red.arg[0], x.dtype.scalar())),)+red.src[1:])),
|
||||
# REDUCE -> DEFINE_ACC+ASSIGN, then merge ENDs with same range
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
|
||||
(UPat(Ops.SINK, name="sink"), merge_reduce_ends),
|
||||
|
|
|
|||
|
|
@@ -2,10 +2,10 @@ from __future__ import annotations
|
|||
import math, itertools
|
||||
from collections import defaultdict
|
||||
from typing import cast, Final
|
||||
from tinygrad.uop.ops import Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, remove_all_tags
|
||||
from tinygrad.uop.ops import Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, remove_all_tags
|
||||
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, Invalid
|
||||
from tinygrad.helpers import colored, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
from tinygrad.helpers import ALLOW_TF32, count, Context
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
|
||||
|
|
@@ -188,17 +188,16 @@ class Scheduler:
|
|||
check(rng.src[0].op is Ops.CONST, "only pad const axes")
|
||||
check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
|
||||
check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
|
||||
# ok to pad SUM if all parent ALU ops have f(0) = 0
|
||||
if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
|
||||
check(r.arg[0] is Ops.ADD and not r.op_in_backward_slice_with_self(*GroupOp.UnsafePad), f"cannot pad {r}")
|
||||
new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
|
||||
check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
|
||||
replaced_rng = UOp.range(new_sz, *rng.arg)
|
||||
replaces = {rng:replaced_rng}
|
||||
valid = replaced_rng < rng.vmax+1
|
||||
store_targets = {s.src[0] for s in self.ast.backward_slice_with_self if s.op is Ops.STORE}
|
||||
for b in self.bufs:
|
||||
if rng in (i:=b.src[1].get_idx()).backward_slice_with_self:
|
||||
replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
|
||||
nb = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
|
||||
replaces[b] = nb if b in store_targets else valid.where(nb, UOp.const(b.dtype, Invalid))
|
||||
self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
|
||||
elif opt.op is OptOps.SWAP:
|
||||
try:
|
||||
|
|
|
|||
|
|
@@ -142,7 +142,4 @@ class GroupOp:
|
|||
# These can change the dtype to bool
|
||||
Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
|
||||
|
||||
# do not preserve f(0) = 0
|
||||
UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.CDIV, Ops.POW, Ops.FLOORDIV}
|
||||
|
||||
All = set(Ops)
|
||||
|
|
|
|||
|
|
@@ -86,6 +86,10 @@ propagate_invalid = PatternMatcher([
|
|||
lambda x,i: x.src[1] if len(x.src) > 1 else x.const_like(0)),
|
||||
])
|
||||
|
||||
pm_remove_invalid = PatternMatcher([
|
||||
(UPat(Ops.CONST, arg=Invalid, name="i"), lambda i: i.const_like(0) if i.dtype.scalar() is not dtypes.weakint else None),
|
||||
])
|
||||
|
||||
symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
# ** self folding **
|
||||
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue