mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move reduce to lowerer [pr] (#7462)
* move reduce to lowerer [pr] * simpler
This commit is contained in:
parent
2cfca230b5
commit
a7ba3d2d91
4 changed files with 21 additions and 18 deletions
|
|
@ -34,7 +34,7 @@ def assert_jit_cache_len(fxn, expected_len):
|
|||
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
|
||||
if dtype == dtypes.bfloat16:
|
||||
# NOTE: this requires bf16 buffer support
|
||||
return device in {"AMD"} or (device in {"CUDA", "NV", "METAL"} and not CI and not getenv("PTX"))
|
||||
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
|
||||
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
|
||||
# for CI GPU and OSX, cl_khr_fp16 isn't supported
|
||||
# for CI LLVM, it segfaults because it can't link to the casting function
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ class TestArithmeticSimplifications(unittest.TestCase):
|
|||
|
||||
|
||||
class TestFoldingAndReduction(unittest.TestCase):
|
||||
@unittest.skip("reduce is removed now")
|
||||
def test_full_graph_rewrite_constant_reduction_folding(self):
|
||||
const1 = UOp.const(dtypes.int32, 5)
|
||||
const2 = UOp.const(dtypes.int32, 10)
|
||||
|
|
@ -59,6 +60,7 @@ class TestFoldingAndReduction(unittest.TestCase):
|
|||
expected_sum = 5 + 10 + 20
|
||||
self.assertEqual(optimized_sink.arg, expected_sum)
|
||||
|
||||
@unittest.skip("reduce is removed now")
|
||||
def test_full_graph_rewrite_reduction_with_unused_range(self):
|
||||
const1 = UOp.const(dtypes.int32, 15)
|
||||
const2 = UOp.const(dtypes.int32, 25)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional
|
|||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import variable_to_uop
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, sint
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, sint, identity_element
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten
|
||||
|
||||
|
|
@ -133,4 +133,19 @@ pm_lowerer = PatternMatcher([
|
|||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
def do_reduce(ctx:List[int], root:UOp):
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype,
|
||||
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
|
||||
ctx[0] += 1
|
||||
return acc.assign(acc.alu(root.arg, root.src[0]))
|
||||
|
||||
just_reduce = PatternMatcher([
|
||||
# do reduce
|
||||
(UPat(UOps.REDUCE, name="root"), do_reduce),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
|
||||
sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
|
||||
sink = graph_rewrite(sink, just_reduce, ctx=[0])
|
||||
return sink
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import functools, itertools, operator
|
|||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple
|
||||
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
|
||||
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
|
||||
|
|
@ -381,12 +381,6 @@ def do_expand(root:UOp):
|
|||
nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
|
||||
return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
|
||||
|
||||
def do_reduce(ctx:List[int], root:UOp):
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype,
|
||||
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
|
||||
ctx[0] += 1
|
||||
return acc.assign(acc.alu(root.arg, root.src[0]))
|
||||
|
||||
def do_contract(con:UOp):
|
||||
ex = con.src[0]
|
||||
# CONTRACT without EXPAND repeats the element VECTORIZED
|
||||
|
|
@ -451,11 +445,6 @@ def no_vectorized_acc(acc:UOp):
|
|||
tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
|
||||
return UOp(UOps.VECTORIZE, acc.dtype, alus)
|
||||
|
||||
just_reduce = PatternMatcher([
|
||||
# do reduce
|
||||
(UPat(UOps.REDUCE, name="root"), do_reduce),
|
||||
])
|
||||
|
||||
devectorize = PatternMatcher([
|
||||
# no ALU on vectorized dtypes
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.INDEX), name="alu"), no_vectorized_alu),
|
||||
|
|
@ -520,9 +509,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
|||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
|
||||
sink = graph_rewrite(sink, just_reduce, ctx=[0])
|
||||
|
||||
# initial symbolic + migrate indexing (remove this) + transcendental
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue