mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
721ad48dc6 |
||
|
|
038f8a6c2d |
3 changed files with 20 additions and 4 deletions
|
|
@ -4,7 +4,7 @@ from collections import defaultdict
|
|||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
||||
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
|
||||
|
||||
def linearize(sink:UOp) -> list[UOp]:
|
||||
def linearize(sink:UOp, tuple_order=TUPLE_ORDER) -> list[UOp]:
|
||||
# this is a toposort with priority
|
||||
lst = list(sink.toposort())
|
||||
consumers: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
|
|
@ -39,7 +39,7 @@ def linearize(sink:UOp) -> list[UOp]:
|
|||
priorities[u] = (run_count, priority, extra)
|
||||
|
||||
# number the uops in "ideal" order
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER else ())))}
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if tuple_order else ())))}
|
||||
|
||||
# then force them to be toposorted in as close to the ideal order as possible
|
||||
heap = [(-nkey[sink], sink)]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import cast
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque, defaultdict
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, print_uops
|
||||
from tinygrad.device import Device, Buffer, MultiBuffer
|
||||
from tinygrad.helpers import Metadata, all_same
|
||||
from tinygrad.codegen.late.linearizer import linearize
|
||||
|
||||
# **** ScheduleItem return type
|
||||
|
||||
|
|
@ -17,6 +18,19 @@ class ScheduleItem:
|
|||
# **** schedule linearizer
|
||||
|
||||
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
|
||||
lst = linearize(sched_sink, tuple_order=False)
|
||||
print_uops(lst)
|
||||
|
||||
schedule: list[ScheduleItem] = []
|
||||
var_vals: dict[str, int] = {}
|
||||
for k in lst:
|
||||
if k.op is Ops.KERNEL:
|
||||
ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
|
||||
# ONE -> ONE
|
||||
schedule.append(ScheduleItem(k.arg.ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
|
||||
pass
|
||||
|
||||
"""
|
||||
# construct the KERNEL children graph based on assigns
|
||||
children: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
in_degree: dict[UOp, int] = {}
|
||||
|
|
@ -79,5 +93,6 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
|||
for x in children[k]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queues[_heuristic(x)].append(x)
|
||||
"""
|
||||
|
||||
return schedule, var_vals
|
||||
|
|
|
|||
|
|
@ -860,7 +860,8 @@ def print_uops(uops:list[UOp]):
|
|||
uops_index = {u:i for i,u in enumerate(uops)}
|
||||
for i,u in enumerate(uops):
|
||||
formatted_srcs = [(uops_index[x] if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
|
||||
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
|
||||
formatted_arg = str(u.arg)[0:30].replace("\n", "")
|
||||
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {formatted_arg}")
|
||||
|
||||
# ***** pattern matcher *****
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue