mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix graphs
This commit is contained in:
parent
16a7edc775
commit
473bbd3e35
6 changed files with 30 additions and 11 deletions
15
test/external_test_assign.py
Normal file
15
test/external_test_assign.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.graph import nm
|
||||
|
||||
if __name__ == "__main__":
|
||||
GlobalCounters.cache = []
|
||||
a = Tensor.ones(4,4)
|
||||
b = Tensor.ones(4,4)
|
||||
a += b
|
||||
print(a.numpy())
|
||||
runner, args = GlobalCounters.cache[0]
|
||||
b0, b1, b2 = args
|
||||
print(nm(b0), b0)
|
||||
print(nm(b1), b1)
|
||||
print(nm(b2), b2)
|
||||
|
|
@ -49,7 +49,7 @@ class ASTKernel:
|
|||
|
||||
# create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
|
||||
self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
|
||||
self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs
|
||||
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs
|
||||
|
||||
# TODO: should be optional if it's hitting a function cache
|
||||
self.processed = False
|
||||
|
|
|
|||
|
|
@ -34,6 +34,13 @@ if GRAPH:
|
|||
atexit.register(save_graph_exit)
|
||||
|
||||
global_num_max = 0
|
||||
def nm(x):
|
||||
global global_num_max
|
||||
if not hasattr(x, 'global_num'):
|
||||
setattr(x, 'global_num', global_num_max)
|
||||
global_num_max += 1
|
||||
return f"<{x.global_num}>"
|
||||
|
||||
def log_op(ret : DeviceBuffer, ast : LazyOp):
|
||||
if not DEBUG and not GRAPH: return
|
||||
op : List[Op] = [x.op for x in get_lazyops(ast)]
|
||||
|
|
@ -43,15 +50,8 @@ def log_op(ret : DeviceBuffer, ast : LazyOp):
|
|||
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
|
||||
cnts[optype] += 1
|
||||
if DEBUG >= 3:
|
||||
print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")
|
||||
print(f"{op} : {', '.join([f'{x.shape}-{nm(x)}' for x in inp])} -> {ret.shape}-{nm(ret)}")
|
||||
if GRAPH:
|
||||
def nm(x):
|
||||
global global_num_max
|
||||
if not hasattr(x, 'global_num'):
|
||||
setattr(x, 'global_num', global_num_max)
|
||||
global_num_max += 1
|
||||
return f"<<< {x.global_num} >>>"
|
||||
|
||||
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", ProcessingOps: "#ff8080"}
|
||||
dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,8 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
|
|||
for x in real_srcs.keys():
|
||||
if real_srcs[x] is None:
|
||||
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
|
||||
return LazyOp(MovementOps.RESHAPE, (map_buffers(real_srcs, self.op), ), self.shape)
|
||||
ast = map_buffers(real_srcs, self.op)
|
||||
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
|
||||
|
||||
# **** lazy operations ****
|
||||
|
||||
|
|
@ -139,7 +140,8 @@ class LazyBuffer:
|
|||
|
||||
# run the ast if we still have to, and log the op
|
||||
if self.realized is None:
|
||||
self.realized = self.dbuffer.exec_ast(map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast))
|
||||
ast = map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast)
|
||||
self.realized = self.dbuffer.exec_ast(ast)
|
||||
log_op(self.realized, ast)
|
||||
|
||||
assert self.realized.shape == self.shape, f"shape mismatch on realize {self.realized.shape} vs {self.shape}"
|
||||
|
|
|
|||
|
|
@ -345,6 +345,7 @@ class GPUBuffer(ExplicitExecAST):
|
|||
self._backing = None
|
||||
return self._buf._cl
|
||||
|
||||
# TODO: we don't always need a hostbuf
|
||||
def __repr__(self): return f"GPUBuffer(shape={self.st}, hostbuf=GPUBuffer(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.float32)))" if self._backing else ", force_create=True))")
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class CL:
|
|||
|
||||
class CLBuffer:
|
||||
def __init__(self, size):
|
||||
if DEBUG >= 4: print(f"allocate GPU Buffer {size}")
|
||||
if len(CL.BUFFER_CACHE[size]) > 0:
|
||||
self._cl = CL.BUFFER_CACHE[size].pop()
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue