correct flops

This commit is contained in:
George Hotz 2025-03-20 21:46:13 +08:00
commit 2ed30f5366

View file

@ -77,7 +77,7 @@ class Estimates:
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
elif u.op in {Ops.CUSTOM, Ops.CUSTOMI} and u not in dont_count:
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpybus"): flops += 128*mults*(2 if 'acc' in u.arg else 1)
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpybus"): flops += 32*mults*(8 if 'acc' in u.arg else 7)
return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
@dataclass