mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
early dtype decomp (#16718)
* early dtype decomp
* simplify
* cleanup
* that goes there
* doing too much
* stupid symbolic rules
This commit is contained in:
parent
116045cc8e
commit
5a2b3b7b06
3 changed files with 21 additions and 10 deletions
|
|
@ -13,7 +13,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
|||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps, get_simplifying_rewrite_patterns
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
||||
|
|
@ -113,18 +113,20 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# optional pre matcher
|
||||
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
||||
|
||||
# decompositions
|
||||
# floordiv+mod / dtype decomp (early)
|
||||
supported_ops = tuple(ren.code_for_op.keys())
|
||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
|
||||
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
|
||||
pm_decomp = symbolic_simple+get_simplifying_rewrite_patterns(supported_ops)
|
||||
sink = graph_rewrite(sink, pm_decomp, name="early decompositions")
|
||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
||||
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
|
||||
|
||||
# GEP/STACK stuff
|
||||
# instruction selection decompositions
|
||||
pm_decomp = pm_decomp+\
|
||||
get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\
|
||||
get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="late decompositions")
|
||||
|
||||
# this is new style (TODO: this should all be removed)
|
||||
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
|
||||
|
||||
# this is new style
|
||||
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
|
||||
|
||||
|
|
|
|||
|
|
@ -454,7 +454,8 @@ def floormod_to_mod(a:UOp, b:UOp) -> UOp:
|
|||
|
||||
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
|
||||
@functools.cache
|
||||
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
|
||||
def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
|
||||
# these are rewrites that make things simpler
|
||||
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
|
||||
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
|
||||
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
|
||||
|
|
@ -463,6 +464,11 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
|
|||
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
|
||||
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
|
||||
return PatternMatcher(pat)
|
||||
|
||||
@functools.cache
|
||||
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
|
||||
pat: list[tuple[UPat, Callable]] = []
|
||||
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
|
||||
lambda x,y: (x | y).logical_not())]
|
||||
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
|
||||
|
|
|
|||
|
|
@ -121,6 +121,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
# TODO: combine this with "# rules for threefry" below
|
||||
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
|
||||
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
|
||||
((UPat.var("x") & UPat.cvar("mask")) // UPat.cvar("c"),
|
||||
lambda x,mask,c: x // c.arg if c.arg > 0 and c.arg & (c.arg-1) == 0 and mask.arg | (c.arg-1) == -1 else None),
|
||||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
|
||||
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
||||
# ** constant folding **
|
||||
|
|
@ -160,6 +162,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
||||
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
||||
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
|
||||
# ** simple where folding **
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue