mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
164 lines
8.9 KiB
Python
164 lines
8.9 KiB
Python
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
|
import itertools, random, math, time, multiprocessing, traceback, signal
|
|
from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner
|
|
from tinygrad.ops import MemBuffer, LazyOp
|
|
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
|
|
from tinygrad.dtype import ImageDType
|
|
from tinygrad.codegen.linearizer import Linearizer
|
|
from collections import defaultdict
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.shape.symbolic import sym_infer
|
|
|
|
from tinygrad.codegen.kernel import Opt, OptOps
|
|
actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)])
|
|
actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
|
|
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
|
|
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
|
|
actions += flatten([[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)])
|
|
actions += [
|
|
Opt(op=OptOps.LOCAL, axis=0, amt=32),
|
|
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
|
|
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
|
|
]
|
|
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
|
|
|
|
def _get_test_global_size(global_size, max_global_size, var_vals):
|
|
test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
|
|
while prod(test_global_size) > max_global_size:
|
|
for j in range(len(global_size)-1,-1,-1):
|
|
if test_global_size[j] > 16:
|
|
test_global_size[j] //= 2
|
|
factor *= 2
|
|
break
|
|
return test_global_size, factor
|
|
|
|
def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): # noqa: E501
|
|
factor = 1
|
|
if global_size is not None and max_global_size is not None:
|
|
global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
|
|
car = CompiledASTRunner(ast, name, "", lib, global_size, local_size).build(rdev.runtime)
|
|
tms = []
|
|
for _ in range(cnt):
|
|
if clear_l2:
|
|
with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
|
|
tms.append(car(rawbufs, var_vals, wait=True, do_update_stats=False)*factor)
|
|
if early_stop is not None and early_stop < tms[-1]: break
|
|
return tms
|
|
|
|
def _compile_linearizer(rdev:Compiled, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]:
|
|
lin.linearize()
|
|
src = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
|
|
return rdev.compiler(src), lin.global_size, lin.local_size
|
|
|
|
def _try_compile_linearized_w_idx(x):
|
|
try: return (x[0], _compile_linearizer(cast(Compiled, Device[Device.DEFAULT]), x[1], "test"))
|
|
except Exception:
|
|
if DEBUG >= 4: traceback.print_exc()
|
|
return (x[0], None)
|
|
|
|
# workers should ignore ctrl c
|
|
def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
|
|
|
|
# *** external API ***
|
|
|
|
# get (scrap) buffers for timing the linearizer
|
|
def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
|
|
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
|
|
for x in lin.membufs: bufsts[x.idx].append(x)
|
|
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
|
|
for k,lx in bufsts.items():
|
|
rawbufs[k] = Buffer(Device.DEFAULT, prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
|
|
assert all(r is not None for r in rawbufs)
|
|
return cast(List[Buffer], rawbufs)
|
|
|
|
# get dictionary of all possible actions
|
|
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
|
|
acted_lins = {0:lin} if include_0 else {}
|
|
for i,a in enumerate(actions):
|
|
if a.axis is not None and a.axis >= lin.shape_len: continue
|
|
if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
|
|
lin2 = lin.copy()
|
|
try:
|
|
lin2.apply_opt(a)
|
|
up, lcl = 1, 1
|
|
for s,c in zip(lin2.full_shape, lin2.colors()):
|
|
if c in {"magenta", "yellow"}: up *= s
|
|
if c in {"cyan", "green", "white"}: lcl *= s
|
|
if up > 256 or lcl > 256: continue
|
|
acted_lins[i+1] = lin2
|
|
except Exception:
|
|
pass
|
|
return acted_lins
|
|
|
|
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
|
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
|
|
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
|
ret = lin.copy()
|
|
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
|
return ret
|
|
|
|
beam: List[Tuple[Linearizer, float]] = []
|
|
seen_libs = set()
|
|
|
|
default_parallel = 1 if Device.DEFAULT in {"CUDA", "HIP"} else 0
|
|
pool = multiprocessing.Pool(multiprocessing.cpu_count(), _init_worker) if getenv("PARALLEL", default_parallel) else None
|
|
|
|
try:
|
|
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
|
exiting, st = False, time.perf_counter()
|
|
dev = Device[Device.DEFAULT]
|
|
assert isinstance(dev, Compiled)
|
|
while not exiting:
|
|
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
|
|
timed_lins: List[Tuple[Linearizer, float]] = []
|
|
for i,proc in (pool.imap_unordered(_try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(_try_compile_linearized_w_idx, enumerate(acted_lins))): # noqa: E501
|
|
if proc is None: continue
|
|
lib, global_size, local_size = proc
|
|
if lib in seen_libs: continue
|
|
seen_libs.add(lib)
|
|
tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
|
|
timed_lins.append((acted_lins[i], min(tms)))
|
|
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
|
|
|
# done
|
|
opts = sorted(timed_lins, key=lambda x: x[1])
|
|
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
|
|
if not exiting: beam = opts[:amt]
|
|
assert len(beam) > 0, "no BEAM items succeeded?!?"
|
|
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
|
|
if pool is not None: pool.close() # the pool is closed
|
|
except KeyboardInterrupt as e:
|
|
if pool is not None: pool.terminate()
|
|
raise e
|
|
|
|
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
|
if DEBUG >= 3: print(beam[0][0].applied_opts)
|
|
return beam[0][0]
|
|
|
|
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
|
|
test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
|
|
MAX_WORKGROUP = 1024
|
|
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
|
|
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
|
|
def try_exec(local_size):
|
|
try:
|
|
return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
|
|
except Exception:
|
|
return float('inf')
|
|
ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
|
|
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
|
|
return ret[1]
|
|
|
|
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
|
|
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
|
|
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
|
|
|
|
dev = Device[Device.DEFAULT]
|
|
assert isinstance(dev, Compiled)
|
|
|
|
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
|
lib, global_size, local_size = _compile_linearizer(dev, lin)
|
|
tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
|
|
|
|
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
|
|
return min(tms)
|