mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
delete extra
This commit is contained in:
parent
2f48c12441
commit
a640292aed
3 changed files with 2 additions and 19 deletions
|
|
@ -725,12 +725,8 @@ def get_onnx_ops():
|
|||
ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype)
|
||||
else:
|
||||
ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype)
|
||||
|
||||
# you need both NHWC=1 DONT_GROUP_REDUCES=1 for this to work
|
||||
|
||||
#if ret.shape[0] == 1 and ret.shape[2:] == (14,14):
|
||||
# ret = ret.pad(((0,0), (0,0), (0,0), (0,2)))
|
||||
# print("padding", ret.shape)
|
||||
|
||||
if getenv("NHWC") and len(ret.shape) == 4:
|
||||
in_chans = ret.shape[1]
|
||||
if ret.shape[1] == 3 or in_chans%32 != 0:
|
||||
|
|
|
|||
|
|
@ -267,10 +267,7 @@ symbolic_flat = symbolic+PatternMatcher([
|
|||
# ** combine terms (opinionated) **
|
||||
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
||||
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
|
||||
# TODO: this is an issue with DSP mul const (or is it still with the reduce lowerer change?)
|
||||
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c if x.dtype.count == 1 else None),
|
||||
# factorize
|
||||
#((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")), lambda x,c1,y,c2: (x*(c2//c1) + y)*c1 if c2.arg%c1.arg == 0 else None),
|
||||
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
|
||||
])
|
||||
|
||||
# ******** we take a small aside to "simplify_valid" to rewrite valids ********
|
||||
|
|
|
|||
|
|
@ -117,7 +117,6 @@ conv_pm = PatternMatcher([
|
|||
(UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
|
||||
])
|
||||
|
||||
#dsp_pm = conv_pm+PatternMatcher([
|
||||
dsp_pm = PatternMatcher([
|
||||
# convert load char32 to load char128
|
||||
(UPat(Ops.LOAD, (dtypes.uchar.vec(96), dtypes.uchar.vec(64), dtypes.uchar.vec(32)), src=(UPat.var("buf").cast(),), name="load"),
|
||||
|
|
@ -305,19 +304,10 @@ pretty_render = PatternMatcher([
|
|||
lambda v: UOp(Ops.VECTORIZE, v.dtype, src=tuple(UOp(Ops.CUSTOMI, x.dtype, src=(UOp.const(dtypes.int, x.arg),), arg="{0}") for x in v.src))),
|
||||
])
|
||||
|
||||
vmemu_support = """
|
||||
__attribute__ ((always_inline)) unsigned_char128 vmemu(unsigned_char128 *addr) {
|
||||
unsigned_char128 out;
|
||||
__asm__ __volatile__( "%0 = vmem(%1);" : "=v" (out) : "r"(addr) : "memory");
|
||||
return out;
|
||||
}
|
||||
"""
|
||||
|
||||
class DSPRenderer(ClangRenderer):
|
||||
device = "DSP"
|
||||
supports_float4 = True
|
||||
buffer_suffix = " restrict __attribute__((align_value(128)))"
|
||||
#kernel_prefix = vmemu_support + "__attribute__((noinline)) "
|
||||
kernel_prefix = "__attribute__((noinline)) "
|
||||
pre_matcher = dsp_pm
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher+pretty_render
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue