delete extra

This commit is contained in:
George Hotz 2025-03-31 14:35:18 +08:00
commit a640292aed
3 changed files with 2 additions and 19 deletions

View file

@ -725,12 +725,8 @@ def get_onnx_ops():
ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype)
else:
ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype)
# you need both NHWC=1 DONT_GROUP_REDUCES=1 for this to work
#if ret.shape[0] == 1 and ret.shape[2:] == (14,14):
# ret = ret.pad(((0,0), (0,0), (0,0), (0,2)))
# print("padding", ret.shape)
if getenv("NHWC") and len(ret.shape) == 4:
in_chans = ret.shape[1]
if ret.shape[1] == 3 or in_chans%32 != 0:

View file

@ -267,10 +267,7 @@ symbolic_flat = symbolic+PatternMatcher([
# ** combine terms (opinionated) **
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
# TODO: this is an issue with DSP mul const (or is it still with the reduce lowerer change?)
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c if x.dtype.count == 1 else None),
# factorize
#((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")), lambda x,c1,y,c2: (x*(c2//c1) + y)*c1 if c2.arg%c1.arg == 0 else None),
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])
# ******** we take a small aside to "simplify_valid" to rewrite valids ********

View file

@ -117,7 +117,6 @@ conv_pm = PatternMatcher([
(UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
])
#dsp_pm = conv_pm+PatternMatcher([
dsp_pm = PatternMatcher([
# convert load char32 to load char128
(UPat(Ops.LOAD, (dtypes.uchar.vec(96), dtypes.uchar.vec(64), dtypes.uchar.vec(32)), src=(UPat.var("buf").cast(),), name="load"),
@ -305,19 +304,10 @@ pretty_render = PatternMatcher([
lambda v: UOp(Ops.VECTORIZE, v.dtype, src=tuple(UOp(Ops.CUSTOMI, x.dtype, src=(UOp.const(dtypes.int, x.arg),), arg="{0}") for x in v.src))),
])
vmemu_support = """
__attribute__ ((always_inline)) unsigned_char128 vmemu(unsigned_char128 *addr) {
unsigned_char128 out;
__asm__ __volatile__( "%0 = vmem(%1);" : "=v" (out) : "r"(addr) : "memory");
return out;
}
"""
class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
buffer_suffix = " restrict __attribute__((align_value(128)))"
#kernel_prefix = vmemu_support + "__attribute__((noinline)) "
kernel_prefix = "__attribute__((noinline)) "
pre_matcher = dsp_pm
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher+pretty_render