mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix mem_estimate for dtype.itemsize
This commit is contained in:
parent
fe8c05b96f
commit
fd65edf595
2 changed files with 2 additions and 2 deletions
|
|
@ -341,4 +341,4 @@ class GPUCodegen(ASTKernel):
|
|||
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
|
||||
list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1],
|
||||
(self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None,
|
||||
op_estimate=self.info.flops, mem_estimate=sum(4*prod(x._base_shape) for x in self.bufs if x is not None))
|
||||
op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs if x is not None))
|
||||
|
|
|
|||
|
|
@ -212,4 +212,4 @@ class LLVMCodegen(ASTKernel):
|
|||
loop_entry[-1].branch(loop_exit[-1]._block)
|
||||
loop_exit[0].ret_void()
|
||||
|
||||
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(4*prod(x._base_shape) for x in self.bufs))
|
||||
return ASTRunner('exec', str(module), op_estimate=self.info.flops, mem_estimate=sum(x.dtype.itemsize*prod(x._base_shape) for x in self.bufs))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue