mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
replace networkx with defaultdict
This commit is contained in:
parent
3b9b7eda48
commit
8e22d5ee67
1 changed files with 7 additions and 6 deletions
|
|
@ -7,8 +7,8 @@ import traceback
|
|||
import numpy as np
|
||||
from tinygrad.llops.ops_gpu import CL, CLProgram
|
||||
from tinygrad.helpers import prod
|
||||
from collections import defaultdict
|
||||
import pyopencl as cl
|
||||
import networkx as nx
|
||||
|
||||
DEBUGCL = int(os.getenv("DEBUGCL", 0))
|
||||
FLOAT16 = int(os.getenv("FLOAT16", 0))
|
||||
|
|
@ -19,19 +19,20 @@ class Thneed:
|
|||
self.gobj = 0
|
||||
|
||||
# build graph
|
||||
G = nx.DiGraph()
|
||||
nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []})
|
||||
for _, args in self.cl_cache:
|
||||
# output is always the first parameter
|
||||
for a in args[3:]:
|
||||
G.add_edge(a, args[2])
|
||||
nodes[a]['out_edges'].append(args[2])
|
||||
nodes[args[2]]['in_edges'].append(a)
|
||||
|
||||
# get buffers to save
|
||||
self.buffers_to_save = set()
|
||||
self.outputs = []
|
||||
for n in G.nodes:
|
||||
if len(G.in_edges(n)) == 0:
|
||||
for n in nodes.keys():
|
||||
if len(nodes[n]['in_edges']) == 0:
|
||||
self.buffers_to_save.add(n)
|
||||
if len(G.out_edges(n)) == 0:
|
||||
if len(nodes[n]['out_edges']) == 0:
|
||||
self.outputs.append(n)
|
||||
|
||||
for n in self.inputs.values():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue