fix graphs

This commit is contained in:
George Hotz 2023-02-09 09:40:46 -06:00
commit 473bbd3e35
6 changed files with 30 additions and 11 deletions

View file

@ -0,0 +1,15 @@
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters
from tinygrad.graph import nm
if __name__ == "__main__":
GlobalCounters.cache = []
a = Tensor.ones(4,4)
b = Tensor.ones(4,4)
a += b
print(a.numpy())
runner, args = GlobalCounters.cache[0]
b0, b1, b2 = args
print(nm(b0), b0)
print(nm(b1), b1)
print(nm(b2), b2)

View file

@ -49,7 +49,7 @@ class ASTKernel:
# create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs
# TODO: should be optional if it's hitting a function cache
self.processed = False

View file

@ -34,6 +34,13 @@ if GRAPH:
atexit.register(save_graph_exit)
global_num_max = 0
def nm(x):
global global_num_max
if not hasattr(x, 'global_num'):
setattr(x, 'global_num', global_num_max)
global_num_max += 1
return f"<{x.global_num}>"
def log_op(ret : DeviceBuffer, ast : LazyOp):
if not DEBUG and not GRAPH: return
op : List[Op] = [x.op for x in get_lazyops(ast)]
@ -43,15 +50,8 @@ def log_op(ret : DeviceBuffer, ast : LazyOp):
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if DEBUG >= 3:
print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")
print(f"{op} : {', '.join([f'{x.shape}-{nm(x)}' for x in inp])} -> {ret.shape}-{nm(ret)}")
if GRAPH:
def nm(x):
global global_num_max
if not hasattr(x, 'global_num'):
setattr(x, 'global_num', global_num_max)
global_num_max += 1
return f"<<< {x.global_num} >>>"
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}
dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore

View file

@ -61,7 +61,8 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
for x in real_srcs.keys():
if real_srcs[x] is None:
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
return LazyOp(MovementOps.RESHAPE, (map_buffers(real_srcs, self.op), ), self.shape)
ast = map_buffers(real_srcs, self.op)
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
# **** lazy operations ****
@ -139,7 +140,8 @@ class LazyBuffer:
# run the ast if we still have to, and log the op
if self.realized is None:
self.realized = self.dbuffer.exec_ast(map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast))
ast = map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast)
self.realized = self.dbuffer.exec_ast(ast)
log_op(self.realized, ast)
assert self.realized.shape == self.shape, f"shape mismatch on realize {self.realized.shape} vs {self.shape}"

View file

@ -345,6 +345,7 @@ class GPUBuffer(ExplicitExecAST):
self._backing = None
return self._buf._cl
# TODO: we don't always need a hostbuf
def __repr__(self): return f"GPUBuffer(shape={self.st}, hostbuf=GPUBuffer(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.float32)))" if self._backing else ", force_create=True))")
@staticmethod

View file

@ -34,6 +34,7 @@ class CL:
class CLBuffer:
def __init__(self, size):
if DEBUG >= 4: print(f"allocate GPU Buffer {size}")
if len(CL.BUFFER_CACHE[size]) > 0:
self._cl = CL.BUFFER_CACHE[size].pop()
else: