replace networkx with defaultdict

This commit is contained in:
George Hotz 2022-10-20 19:36:43 -07:00
commit 8e22d5ee67

View file

@ -7,8 +7,8 @@ import traceback
import numpy as np
from tinygrad.llops.ops_gpu import CL, CLProgram
from tinygrad.helpers import prod
from collections import defaultdict
import pyopencl as cl
import networkx as nx
DEBUGCL = int(os.getenv("DEBUGCL", 0))
FLOAT16 = int(os.getenv("FLOAT16", 0))
@ -19,19 +19,20 @@ class Thneed:
self.gobj = 0
# build graph
G = nx.DiGraph()
nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
for _, args in self.cl_cache:
# output is always the first parameter
for a in args[3:]:
G.add_edge(a, args[2])
nodes[a]['out_edges'].append(args[2])
nodes[args[2]]['in_edges'].append(a)
# get buffers to save
self.buffers_to_save = set()
self.outputs = []
for n in G.nodes:
if len(G.in_edges(n)) == 0:
for n in nodes.keys():
if len(nodes[n]['in_edges']) == 0:
self.buffers_to_save.add(n)
if len(G.out_edges(n)) == 0:
if len(nodes[n]['out_edges']) == 0:
self.outputs.append(n)
for n in self.inputs.values():