double reduce

This commit is contained in:
George Hotz 2025-03-21 17:33:48 +08:00
commit dc1469a188
2 changed files with 4 additions and 1 deletions

View file

@ -240,6 +240,9 @@ pm_quant = symbolic+PatternMatcher([
#(UPat(Ops.REDUCE_AXIS, name="r")+UPat.var("x"), lambda r,x: r.replace(src=(r.src[0], (r.src[1]+x) if len(r.src) == 2 else x))),
# distribute on casted MUL
#((UPat(Ops.CAST, name="v1")+UPat.cvar("c")) * UPat(Ops.CAST, name="v2"), lambda v1,v2,c: (v1*v2)+(c*v2)),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.cvar("c")) * UPat(Ops.CAST, name="v2",), name="r"),
lambda v1,v2,c,r: r.replace(src=(v1*v2,)) + r.replace(src=(c*v2,))),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:

View file

@ -77,7 +77,7 @@ class Estimates:
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
elif u.op in {Ops.CUSTOM, Ops.CUSTOMI} and u not in dont_count:
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpybus"): flops += 32*mults*(8 if 'acc' in u.arg else 7)
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpy"): flops += 32*mults*(8 if 'acc' in u.arg else 7)
return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
@dataclass