move reduce to lowerer [pr] (#7462)

* move reduce to lowerer [pr]

* simpler
This commit is contained in:
George Hotz 2024-11-01 15:39:20 +07:00 committed by GitHub
commit a7ba3d2d91
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 21 additions and 18 deletions

View file

@ -34,7 +34,7 @@ def assert_jit_cache_len(fxn, expected_len):
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV", "METAL"} and not CI and not getenv("PTX"))
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function

View file

@ -51,6 +51,7 @@ class TestArithmeticSimplifications(unittest.TestCase):
class TestFoldingAndReduction(unittest.TestCase):
@unittest.skip("reduce is removed now")
def test_full_graph_rewrite_constant_reduction_folding(self):
const1 = UOp.const(dtypes.int32, 5)
const2 = UOp.const(dtypes.int32, 10)
@ -59,6 +60,7 @@ class TestFoldingAndReduction(unittest.TestCase):
expected_sum = 5 + 10 + 20
self.assertEqual(optimized_sink.arg, expected_sum)
@unittest.skip("reduce is removed now")
def test_full_graph_rewrite_reduction_with_unused_range(self):
const1 = UOp.const(dtypes.int32, 15)
const2 = UOp.const(dtypes.int32, 25)

View file

@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import variable_to_uop
from tinygrad.dtype import dtypes
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, sint
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten
@ -133,4 +133,19 @@ pm_lowerer = PatternMatcher([
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.VIEW)), allow_any_len=True, name="x"), lower_load_store),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
  """Apply the pm_lowerer rewrite rules to `ast`.

  The rewrite context is whatever get_index(ast, opts) returns; the rewritten graph is returned.
  """
  index_ctx = get_index(ast, opts)
  return graph_rewrite(ast, pm_lowerer, ctx=index_ctx)
def do_reduce(ctx:List[int], root:UOp):
  """Rewrite a REDUCE uop into a DEFINE_ACC that is ASSIGNed the accumulated value.

  ctx is a one-element list used as a mutable counter: ctx[0] becomes the arg of the new
  DEFINE_ACC and is then incremented, so each call hands out a distinct number.
  """
  # starting value of the accumulator: identity_element for this op, at root's dtype
  init = root.const_like(identity_element(root.arg, root.dtype.scalar()))
  # everything after src[0] is carried over unchanged as extra sources of the accumulator
  acc = UOp(UOps.DEFINE_ACC, root.dtype, (init, *root.src[1:]), (ctx[0],))
  # bump the counter only after the accumulator has been built
  ctx[0] += 1
  # combine the accumulator with the reduced value (src[0]) using the reduce op
  combined = acc.alu(root.arg, root.src[0])
  return acc.assign(combined)
# single-rule matcher: every REDUCE uop (bound as "root") is handed to do_reduce
just_reduce = PatternMatcher([(UPat(UOps.REDUCE, name="root"), do_reduce)])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
  """Lower `ast` in two rewrite passes and return the resulting graph.

  Pass 1 applies pm_lowerer with get_index(ast, opts) as context.
  Pass 2 applies just_reduce, converting REDUCE to DEFINE_ACC + ASSIGN.
  """
  lowered = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
  # contextual pass: the one-element list is a counter that starts at 0
  return graph_rewrite(lowered, just_reduce, ctx=[0])

View file

@ -4,7 +4,7 @@ import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@ -381,12 +381,6 @@ def do_expand(root:UOp):
nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
def do_reduce(ctx:List[int], root:UOp):
  """Turn a REDUCE uop into DEFINE_ACC + ASSIGN.

  ctx[0] is used as the arg of the new DEFINE_ACC and is incremented afterwards.
  """
  # accumulator starts from identity_element of the reduce op, shaped like root
  start = root.const_like(identity_element(root.arg, root.dtype.scalar()))
  acc_srcs = (start,) + tuple(root.src[1:])
  acc = UOp(UOps.DEFINE_ACC, root.dtype, acc_srcs, (ctx[0],))
  ctx[0] += 1
  # fold src[0] into the accumulator with the reduce op and store it back
  return acc.assign(acc.alu(root.arg, root.src[0]))
def do_contract(con:UOp):
ex = con.src[0]
# CONTRACT without EXPAND repeats the element VECTORIZED
@ -451,11 +445,6 @@ def no_vectorized_acc(acc:UOp):
tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
return UOp(UOps.VECTORIZE, acc.dtype, alus)
# matcher with one rule: REDUCE uops (captured as "root") are rewritten by do_reduce
just_reduce = PatternMatcher([(UPat(UOps.REDUCE, name="root"), do_reduce)])
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.INDEX), name="alu"), no_vectorized_alu),
@ -520,9 +509,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
sink = graph_rewrite(sink, just_reduce, ctx=[0])
# initial symbolic + migrate indexing (remove this) + transcendental
sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))