mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
gep pushing
This commit is contained in:
parent
45010f7eff
commit
94d578aec5
2 changed files with 17 additions and 19 deletions
|
|
@ -170,6 +170,20 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
|||
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
|
||||
return rem//(c//gcd)+quo
|
||||
|
||||
gep_pushing = PatternMatcher([
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
# push all GEPs through ALUs (fix arange stuff)
|
||||
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** COMMUTATIVE flipping (only for ints) **
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
|
|
@ -230,18 +244,7 @@ symbolic = symbolic_simple+PatternMatcher([
|
|||
# ** mod **
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
# push all GEPs through ALUs (fix arange stuff)
|
||||
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) else None),
|
||||
])
|
||||
])+gep_pushing
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
# ** combine terms (opinionated) **
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ assert sys.platform != 'win32'
|
|||
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.codegen.symbolic import gep_pushing
|
||||
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.runtime.autogen import libc, qcom_dsp
|
||||
|
|
@ -11,7 +12,7 @@ if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-i
|
|||
|
||||
from tinygrad.ops import PatternMatcher, UPat
|
||||
|
||||
dsp_pm = PatternMatcher([
|
||||
dsp_pm = gep_pushing+PatternMatcher([
|
||||
(((UPat.var('x').maximum(0) ^ -1).maximum(-256) ^ -1).cast(dtypes.uchar.vec(128)),
|
||||
lambda x: UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=tuple(x.gep(tuple(range(i, i+32))) for i in range(0, 128, 32)),
|
||||
arg="__builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B({3}, {2}), __builtin_HEXAGON_V6_vpackwh_sat_128B({1}, {0}))")),
|
||||
|
|
@ -27,11 +28,6 @@ dsp_pm_late = PatternMatcher([
|
|||
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
|
||||
])
|
||||
|
||||
# NOTE: this just increases readability of the generated code
|
||||
dsp_string = PatternMatcher([
|
||||
(UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
|
||||
])
|
||||
|
||||
class DSPRenderer(ClangRenderer):
|
||||
device = "DSP"
|
||||
supports_float4 = True
|
||||
|
|
@ -39,7 +35,6 @@ class DSPRenderer(ClangRenderer):
|
|||
kernel_prefix = "__attribute__((noinline)) "
|
||||
pre_matcher = dsp_pm
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
|
||||
string_rewrite = dsp_string+ClangRenderer.string_rewrite
|
||||
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
|
||||
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
|
||||
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue