renderer Estimates uses maxel (#16485)

This commit is contained in:
George Hotz 2026-06-03 10:55:00 -07:00 committed by GitHub
commit cee472a0ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -41,8 +41,8 @@ class Estimates:
while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0]
if buf.op is Ops.PARAM:
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
accessed = mem.get((buf, u.op), 0) + u.src[0].max_numel() * u.src[0].dtype.base.scalar().itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.scalar().itemsize)
if u.op is Ops.RANGE:
mult_stack.append(mults)
mults *= cast(sint, u.src[0].ssimplify())
@@ -52,9 +52,9 @@ class Estimates:
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
lds += u.dtype.itemsize * mults
lds += u.max_numel() * u.dtype.scalar().itemsize * mults
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
lds += u.src[1].dtype.itemsize * mults
lds += u.max_numel() * u.src[1].dtype.scalar().itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.max_numel()
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))