This commit is contained in:
George Hotz 2025-03-21 16:28:08 +08:00
commit f66b03f0a6
4 changed files with 15 additions and 5 deletions

View file

@ -58,7 +58,7 @@ if __name__ == "__main__":
return None
return {"input": img.numpy()}
quantize_static(model_fp32, fn, ImagenetReader(), quant_format=QuantFormat.QDQ, per_channel=False,
activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8,
extra_options={"ActivationSymmetric": False})
run_onnx_jit, input_specs = load_onnx_model(fetch(fn))

View file

@ -726,7 +726,7 @@ def get_onnx_ops():
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
if out_dtype == dtypes.uchar:
if out_dtype == dtypes.uchar and False:
# this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff
ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype)
else:

View file

@ -116,6 +116,14 @@ if __name__ == "__main__":
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
k.apply_opt(Opt(OptOps.UPCAST, 1, 24))
k.apply_opt(Opt(OptOps.UPCAST, 0, 16))
elif knum == 11:
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
k.apply_opt(Opt(OptOps.UPCAST, 1, 144))
#k.apply_opt(Opt(OptOps.UPCAST, 0, 8))
elif knum == 14:
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
k.apply_opt(Opt(OptOps.UPCAST, 1, 192))
k.apply_opt(Opt(OptOps.UPCAST, 0, 2))
elif knum == 37:
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
k.apply_opt(Opt(OptOps.UPCAST, 1, 384))

View file

@ -36,11 +36,13 @@ def multi_mul(a0, a1, b0, b1, c0, c1, d0, d1, acc=None):
m1 = UOp(Ops.CAT, dt2, src=(a1.src[0],b1.src[0],c1.src[0],d1.src[0])).gep(swizzle)
simp_m1 = m1.simplify()
if simp_m1.op is Ops.GEP and simp_m1.arg == simp_m1.arg[0:4]*32:
scalar_m1 = simp_m1.src[0].gep(simp_m1.arg[0:4]).bitcast(dtypes.uint)
# Vx32.w+=vrmpy(Vu32.ub,Rt32.b) -> __builtin_HEXAGON_V6_vrmpybus_acc
# Vx32.uw+=vrmpy(Vu32.ub,Rt32.ub) -> __builtin_HEXAGON_V6_vrmpyub_acc
scalar_m1 = simp_m1.src[0].gep(simp_m1.arg[0:4])
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, scalar_m1), "__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpybus_acc_128B({0}, {1}, {2})")
else:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, scalar_m1), "__builtin_HEXAGON_V6_vrmpybus_128B({0}, {1})")
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpybus_128B({0}, {1})")
if acc is not None:
return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpybusv_acc_128B({0}, {1}, {2})")
else: