Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
721ad48dc6
Merge branch 'master' into sched_lin 2025-11-14 18:31:09 -08:00
George Hotz
038f8a6c2d use linearizer in schedule 2025-11-10 23:42:33 -08:00
3 changed files with 20 additions and 4 deletions

View file

@ -4,7 +4,7 @@ from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
def linearize(sink:UOp) -> list[UOp]:
def linearize(sink:UOp, tuple_order=TUPLE_ORDER) -> list[UOp]:
# this is a toposort with priority
lst = list(sink.toposort())
consumers: defaultdict[UOp, list[UOp]] = defaultdict(list)
@ -39,7 +39,7 @@ def linearize(sink:UOp) -> list[UOp]:
priorities[u] = (run_count, priority, extra)
# number the uops in "ideal" order
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER else ())))}
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if tuple_order else ())))}
# then force them to be toposorted in as close to the ideal order as possible
heap = [(-nkey[sink], sink)]

View file

@ -1,9 +1,10 @@
from typing import cast
from dataclasses import dataclass, field
from collections import deque, defaultdict
from tinygrad.uop.ops import UOp, Ops, buffers
from tinygrad.uop.ops import UOp, Ops, buffers, print_uops
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.helpers import Metadata, all_same
from tinygrad.codegen.late.linearizer import linearize
# **** ScheduleItem return type
@ -17,6 +18,19 @@ class ScheduleItem:
# **** schedule linearizer
def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
lst = linearize(sched_sink, tuple_order=False)
print_uops(lst)
schedule: list[ScheduleItem] = []
var_vals: dict[str, int] = {}
for k in lst:
if k.op is Ops.KERNEL:
ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
# ONE -> ONE
schedule.append(ScheduleItem(k.arg.ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata))
pass
"""
# construct the KERNEL children graph based on assigns
children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree: dict[UOp, int] = {}
@ -79,5 +93,6 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
for x in children[k]:
in_degree[x] -= 1
if in_degree[x] == 0: queues[_heuristic(x)].append(x)
"""
return schedule, var_vals

View file

@ -860,7 +860,8 @@ def print_uops(uops:list[UOp]):
uops_index = {u:i for i,u in enumerate(uops)}
for i,u in enumerate(uops):
formatted_srcs = [(uops_index[x] if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
formatted_arg = str(u.arg)[0:30].replace("\n", "")
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {formatted_arg}")
# ***** pattern matcher *****