mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
double reduce
This commit is contained in:
parent
0416b0998d
commit
dc1469a188
2 changed files with 4 additions and 1 deletions
|
|
@ -240,6 +240,9 @@ pm_quant = symbolic+PatternMatcher([
|
|||
#(UPat(Ops.REDUCE_AXIS, name="r")+UPat.var("x"), lambda r,x: r.replace(src=(r.src[0], (r.src[1]+x) if len(r.src) == 2 else x))),
|
||||
# distribute on casted MUL
|
||||
#((UPat(Ops.CAST, name="v1")+UPat.cvar("c")) * UPat(Ops.CAST, name="v2"), lambda v1,v2,c: (v1*v2)+(c*v2)),
|
||||
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.cvar("c")) * UPat(Ops.CAST, name="v2",), name="r"),
|
||||
lambda v1,v2,c,r: r.replace(src=(v1*v2,)) + r.replace(src=(c*v2,))),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class Estimates:
|
|||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
elif u.op in {Ops.CUSTOM, Ops.CUSTOMI} and u not in dont_count:
|
||||
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpybus"): flops += 32*mults*(8 if 'acc' in u.arg else 7)
|
||||
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpy"): flops += 32*mults*(8 if 'acc' in u.arg else 7)
|
||||
return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue