Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
a07c9da26b
Merge branch 'master' into remove_programspec 2025-12-23 15:21:46 -05:00
George Hotz
816a359a3c do programspec removal 2025-12-23 15:21:02 -05:00
25 changed files with 292 additions and 193 deletions

View file

@ -1,12 +1,11 @@
from typing import Tuple, Dict, List, Optional
from tinygrad.dtype import DType
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Device, Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import Context, to_mv
from tinygrad.helpers import Context, to_mv, to_function_name
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops
from tinygrad.uop.ops import Ops, UOp
import json
from collections import OrderedDict
@ -15,8 +14,13 @@ EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "CL"]
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for ji in run.jit_cache:
fxn: ProgramSpec = ji.prg.p
functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same
prg: UOp = ji.prg.p
name = prg.src[0].arg.name
function_name = to_function_name(name)
src = prg.src[3].arg
global_size, local_size = prg.sizes
prg_vars = prg.variables()
functions[function_name] = src # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(ji.bufs):
key = id(arg)
@ -28,8 +32,8 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
cargs += [var for var in fxn.vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR?
statements.append((fxn.function_name, cargs, fxn.global_size, fxn.local_size))
cargs += [var for var in prg_vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR?
statements.append((function_name, cargs, list(global_size) if global_size else None, list(local_size) if local_size else None))
return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save

View file

@ -4,7 +4,8 @@ import triton.language as tl
from triton.compiler import AttrsDescriptor, ASTSource, compile as triton_compile
import numpy as np
from tinygrad import Tensor, dtypes, Device
from tinygrad.engine.realize import CompiledRunner, ExecItem, ProgramSpec
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.helpers import getenv
np.set_printoptions(suppress=True)
@ -85,9 +86,12 @@ if __name__ == "__main__":
# remove debug sections
src = src.split("\t.file")[0]
assert '.extern .shared' not in src
prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
# Create linearized uops with SPECIAL for global/local sizes
global_size = [M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1]
local_size = [32*compiled.metadata.num_warps, 1, 1]
uops = [UOp(Ops.SPECIAL, arg=('g', i), src=(UOp.const(dtypes.int, global_size[i]),)) for i in range(3)]
uops += [UOp(Ops.SPECIAL, arg=('l', i), src=(UOp.const(dtypes.int, local_size[i]),)) for i in range(3)]
prg = UOp.new_program("matmul_kernel", src, Device.DEFAULT, si.ast, uops)
ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
tflops = []
for i in range(5):

View file

@ -1,10 +1,10 @@
import numpy as np
import unittest
import subprocess, struct, math
from tinygrad import Tensor, dtypes, Device, UOp
from tinygrad import Tensor, dtypes, Device
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner
def get_output(asm:str, n_threads:int=1):
@ -22,7 +22,9 @@ def get_output(asm:str, n_threads:int=1):
*(data0_1+l) = res;
}}"""
t = Tensor.zeros(n_threads, dtype=dtypes.uint32).contiguous().realize()
prg = ProgramSpec("test", src, Device.DEFAULT, UOp.sink(t), global_size=[1, 1, 1], local_size=[n_threads, 1, 1])
# Create linearized uops with SPECIAL for local size
uops = [UOp(Ops.SPECIAL, arg=('l', 0), src=(UOp.const(dtypes.int, n_threads),))]
prg = UOp.new_program("test", src, Device.DEFAULT, UOp.sink(t.uop), uops)
car = CompiledRunner(prg)
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
car([t.uop.buffer], {}, wait=True)

View file

@ -1,5 +1,4 @@
# ruff: noqa: E501 E712 F401
from dataclasses import replace
from tinygrad import dtypes, Device
from tinygrad.uop.ops import UOp, AxisType, Ops, KernelInfo
from tinygrad.codegen.opt import Opt, OptOps # pylint: disable=unused-import
@ -89,7 +88,9 @@ renderer = Device.default.renderer
allocator = Device.default.allocator
ps = get_program(ast, renderer)
cr = CompiledRunner(replace(ps, device=Device.DEFAULT))
# update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
ps = ps.replace(src=(ps.src[0], UOp(Ops.DEVICE, arg=Device.DEFAULT), *ps.src[2:]))
cr = CompiledRunner(ps)
gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.DEFINE_GLOBAL]), key=lambda u: u.arg)
# print(len(gs))

View file

@ -9,7 +9,7 @@ if not int(os.getenv("ASSERT_PROCESS_REPLAY", "1")): ASSERT_DIFF = 0
try:
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.renderer import Renderer
from tinygrad.engine.realize import get_program
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.codegen.opt import Opt
@ -51,15 +51,20 @@ def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str,
return "\n".join([f"{len(asts)} kernels", *asts])
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
def replay_get_program(p:UOp, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
# the ast.arg is non None if we are inside of search.py
sink_arg = ast.arg or KernelInfo(opts_to_apply=tuple(opts) if opts is not None else p.applied_opts if BEAM>=1 else None)
input_ast = ast.replace(arg=replace(sink_arg, name=p.name))
# p is a PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
p_name = p.src[0].arg.name
p_applied_opts = p.src[0].arg.applied_opts
sink_arg = ast.arg or KernelInfo(opts_to_apply=tuple(opts) if opts is not None else p_applied_opts if BEAM>=1 else None)
input_ast = ast.replace(arg=replace(sink_arg, name=p_name))
p2 = get_program(input_ast, renderer=renderer)
def to_str(ret:ProgramSpec) -> str:
def to_str(ret:UOp) -> str:
# PYTHON renderer pickles UOps, first unpickle and decode here
if p.device.startswith("PYTHON"): return "\n".join([str(x) for x in pickle.loads(base64.b64decode(ret.src))])
return ret.src
ret_src = ret.src[3].arg
ret_device = ret.device
if ret_device.startswith("PYTHON"): return "\n".join([str(x) for x in pickle.loads(base64.b64decode(ret_src))])
return ret_src
# properly color the name arg
ast_repr = codecs.decode(str(input_ast), "unicode_escape")
return to_str(p2), to_str(p), (ast_repr, renderer)

View file

@ -63,7 +63,7 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
prg = get_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer())
prog = PythonProgram("run", PythonCompiler().compile(prg.src))
prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg)) # source code is in src[3].arg
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
return out_buf.cast(uop.dtype.fmt or "").tolist()[0]

View file

@ -1,10 +1,9 @@
import numpy as np
import unittest
from dataclasses import replace
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.uop.ops import Ops
from tinygrad.uop.ops import Ops, UOp
from tinygrad.dtype import DType
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import AMX, AMD_LLVM, CPU_LLVM, Context
@ -44,7 +43,10 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
if dtype_in == dtypes.bfloat16: r = r.float()
realized_ast, bufs = helper_realized_ast(r)
opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))]
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts), device=Device.DEFAULT))
p = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts)
# update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
p = p.replace(src=(p.src[0], UOp(Ops.DEVICE, arg=Device.DEFAULT), *p.src[2:]))
prg = CompiledRunner(p)
if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg.exec(bufs)

View file

@ -28,7 +28,7 @@ class TestFusionOp(unittest.TestCase):
sched = a.schedule()
sched[-1].lower()
self.assertLess(time.perf_counter()-st, 2.0)
assert len(sched[-1].prg.p.src.splitlines()) < 250
assert len(sched[-1].prg.p.src[3].arg.splitlines()) < 250 # source code is in src[3].arg
def test_recursive_add_cmp(self):
st = time.perf_counter()

View file

@ -1,6 +1,5 @@
import numpy as np
import unittest
from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims
@ -168,8 +167,8 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
def test_upcast_with_locals_cpu(self):
out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
prg = get_program(out.schedule()[-1].ast, opts=[Opt(OptOps.LOCAL, axis=0, arg=4)]).uops
self.assertEqual(len(prg.src.split("for")), 5)
prg = get_program(out.schedule()[-1].ast, opts=[Opt(OptOps.LOCAL, axis=0, arg=4)])
self.assertEqual(len(prg.src[3].arg.split("for")), 5) # source code is in src[3].arg
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@ -517,7 +516,11 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
device = real_bufs[0].device
wanna_output = [np.array(x).flatten() for x in wanna_output]
def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts), device=device))
def get_prg(opts):
prg = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts)
# update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
prg = prg.replace(src=(prg.src[0], UOp(Ops.DEVICE, arg=device), *prg.src[2:]))
return CompiledRunner(prg)
def check_opt(opts):
prg = get_prg(opts=opts)

View file

@ -14,7 +14,7 @@ class TestOpts(unittest.TestCase):
self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP:
prg = get_program(s[-1].ast, renderer=Device[Device.DEFAULT].renderer)
self.assertIn('float4', prg.src)
self.assertIn('float4', prg.src[3].arg) # source code is in src[3].arg
if __name__ == '__main__':
unittest.main()

View file

@ -1,6 +1,5 @@
import unittest
import numpy as np
from dataclasses import replace
from tinygrad.device import Buffer, Device, is_dtype_supported
from tinygrad.dtype import dtypes, ConstType
from tinygrad.engine.realize import CompiledRunner, get_program
@ -18,8 +17,8 @@ def _test_uop_result(inputs:list[Tensor], prg, local_size=None):
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [x.uop.base.buffer for x in inputs]
prg = replace(prg, device=Device.DEFAULT)
if local_size is not None: prg = replace(prg, local_size=local_size)
# update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
prg = prg.replace(src=(prg.src[0], UOp(Ops.DEVICE, arg=Device.DEFAULT), *prg.src[2:]))
ei = CompiledRunner(prg)
ei.exec(outbufs+inbufs)
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
@ -72,7 +71,7 @@ class TestCStyleFailures(unittest.TestCase):
schedule = ret.schedule()
assert len(schedule) == 1
schedule[0].lower()
src = schedule[0].prg.p.src
src = schedule[0].prg.p.src[3].arg # source code is in src[3].arg
self.assertEqual("("*5 not in src, should_strip_paren)
def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD)

View file

@ -15,7 +15,6 @@ from tinygrad.device import is_dtype_supported
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.renderer.ptx import PTXRenderer
from test.helpers import get_uops
from dataclasses import replace
def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
sink = UOp.group(*u)
@ -27,7 +26,9 @@ def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
def _uops_to_prg(uops_list):
prg = get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer)
return CompiledRunner(replace(prg, device=Device.DEFAULT))
# update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
prg = prg.replace(src=(prg.src[0], UOp(Ops.DEVICE, arg=Device.DEFAULT), *prg.src[2:]))
return CompiledRunner(prg)
def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(src), arg))

View file

@ -2,7 +2,6 @@ import unittest
from tinygrad import Tensor
from tinygrad.helpers import getenv, GlobalCounters, EMULATE
from tinygrad.engine.realize import get_program
from tinygrad.renderer import ProgramSpec
from tinygrad.renderer import Estimates
from tinygrad.uop.ops import Ops, UOp
from tinygrad.dtype import dtypes
@ -166,24 +165,26 @@ class TestStatsOptimized(unittest.TestCase):
cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule()[-1].ast
def check_gemm(self, p:ProgramSpec, extra_flops=0):
def check_gemm(self, p:UOp, extra_flops=0):
#p.uops.print()
#print(p.src)
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
self.assertEqual(p.estimates.ops, 2*N*N*N + extra_flops) # N**3 mulaccs
self.assertEqual(p.estimates.mem, 3*N*N*4) # 3 NxN mats with floats
#print(p.src[3].arg)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
print(p.src[0].arg.name, estimates.ops, estimates.mem, estimates.lds)
self.assertEqual(estimates.ops, 2*N*N*N + extra_flops) # N**3 mulaccs
self.assertEqual(estimates.mem, 3*N*N*4) # 3 NxN mats with floats
def test_gemm(self):
p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[])
self.check_gemm(p)
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, 2*N*N*N*4 + 4*N*N)
def test_gemm_tc_unroll(self):
try:
p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)])
except KernelOptError:
raise unittest.SkipTest("no tensor cores")
print(p.src)
print(p.src[3].arg)
self.check_gemm(p)
# this is a good lesson about why UPCASTing is a good idea
@ -191,13 +192,15 @@ class TestStatsOptimized(unittest.TestCase):
def test_gemm_one_upcasted(self):
p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.UPCAST, 0, 4)])
self.check_gemm(p)
self.assertEqual(p.estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N)
def test_gemm_upcasted(self):
p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer,
opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)])
self.check_gemm(p)
self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, 2*N*N*N*4//4 + 4*N*N)
def test_gemm_upcasted_locals(self):
try:
@ -206,7 +209,8 @@ class TestStatsOptimized(unittest.TestCase):
except KernelOptError:
raise unittest.SkipTest("no locals")
self.check_gemm(p)
self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, 2*N*N*N*4//4 + 4*N*N)
def test_gemm_group(self):
try:
@ -216,13 +220,15 @@ class TestStatsOptimized(unittest.TestCase):
SZ = N*N*4
# NOTE: these are sort of wrong. they aren't honoring the IF statement
self.check_gemm(p, extra_flops=SZ*4)
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
def test_reduce(self):
p = get_program(self.ast_reduce, renderer=Device[Device.DEFAULT].renderer, opts=[])
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
self.assertEqual(p.estimates.ops, N*N)
self.assertEqual(p.estimates.mem, N*N*4 + 4)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
print(p.src[0].arg.name, estimates.ops, estimates.mem, estimates.lds)
self.assertEqual(estimates.ops, N*N)
self.assertEqual(estimates.mem, N*N*4 + 4)
def test_reduce_group(self):
try:
@ -230,7 +236,8 @@ class TestStatsOptimized(unittest.TestCase):
except KernelOptError:
raise unittest.SkipTest("no locals")
# NOTE: these are wrong, they don't respect the if statement
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
print(p.src[0].arg.name, estimates.ops, estimates.mem, estimates.lds)
if __name__ == '__main__':
unittest.main(verbosity=2)

View file

@ -3,7 +3,6 @@ import textwrap
from tinygrad import Device, Tensor
from tinygrad.uop.ops import UOp, Ops, track_rewrites
from tinygrad.renderer import ProgramSpec
from tinygrad.helpers import TracingKey
from tinygrad.engine.realize import ExecItem, CompiledRunner
@ -51,9 +50,9 @@ amdhsa.kernels:
.end_amdgpu_metadata
"""
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
def run_asm(name:str, src:str) -> ProgramSpec:
prg = ProgramSpec(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK))
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.src[0].arg.name, ret=ret))
def run_asm(name:str, src:str) -> UOp:
prg = UOp.new_program(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK), [])
ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
ei.run()
return prg

View file

@ -133,11 +133,11 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
pm_to_program = PatternMatcher([
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
])
def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
sink = UOp(Ops.PROGRAM, src=(full_sink,))
sink = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=ren.device)))
return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")

View file

@ -1,13 +1,11 @@
import functools, math, time, multiprocessing, traceback, signal, atexit
from dataclasses import replace
from tinygrad.uop.ops import sym_infer, AxisType, pyrender
from tinygrad.uop.ops import sym_infer, AxisType, pyrender, UOp
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str, unwrap
from tinygrad.helpers import IGNORE_BEAM_CACHE
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner, get_program, Estimates
from tinygrad.codegen.opt.postrange import Scheduler
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
@ -34,19 +32,22 @@ def get_test_global_size(global_size, max_global_size, var_vals):
break
return test_global_size, input_size / prod(test_global_size)
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
def _time_program(p:UOp, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
factor = 1
if allow_test_size and p.global_size is not None and max_global_size is not None:
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
p = replace(p, global_size=global_size)
global_size, local_size = p.sizes
if allow_test_size and global_size is not None and max_global_size is not None:
test_global_size, factor = get_test_global_size(global_size, max_global_size, var_vals)
# NOTE: we can't modify p.sizes directly, but optimize_local_size doesn't run in BEAM so this is ok
try: car = CompiledRunner(p, precompiled=lib)
except AssertionError: return [math.inf] * cnt
tms = []
input_bufs = [rawbufs[i] for i in car.p.globals]
for _ in range(cnt):
if clear_l2:
if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
p_device = p.device
assert isinstance(p_device, str), f"PROGRAM device must be a string, not {type(p_device)}"
if hasattr(dev:=Device[p_device], 'invalidate_caches'): dev.invalidate_caches()
else:
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
tms.append(unwrap(car(input_bufs, var_vals, wait=True))*factor)
@ -58,7 +59,7 @@ def timeout_handler(signum, frame):
if DEBUG >= 2: print("*** BEAM COMPILE TIMEOUT")
raise TimeoutException()
def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]:
def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[UOp, bytes, float]|None]:
if hasattr(signal, "alarm"):
signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
# set timeout
@ -66,12 +67,12 @@ def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[
ret = None
try:
p = get_program(x[1].copy().get_optimized_ast(name_override="test"), x[1].ren)
assert p.uops is not None, "uop list wasn't generated?"
if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
uops = list(p.src[2].src) # LINEAR is src[2]
if len(uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(uops)=}, {uops_max=}")
raise RuntimeError("too many uops")
st = time.perf_counter()
prog = compiler.compile(p.src)
prog = compiler.compile(p.src[3].arg) # SOURCE is src[3]
et = time.perf_counter() - st
ret = (p, prog, et)
except RuntimeError:
@ -154,7 +155,8 @@ def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True
p, lib, compile_et = proc
if lib in seen_libs: continue
# filter out kernels that use 1000x more compute than the smallest
least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
least_compute_ops = min(this_compute_ops:=sym_infer(estimates.ops, var_vals), least_compute_ops)
if least_compute_ops*1000 < this_compute_ops:
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too much compute. {this_compute_ops} when least is {least_compute_ops}")
continue
@ -167,7 +169,7 @@ def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True
raise
timed.append((candidates[i], min(tms)))
if BEAM_DEBUG > 1:
print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(unwrap(p.uops)):5d} uops",
print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(p.src[2].src):5d} uops",
f"{time_to_str(compile_et, w=12)} compile/{time_to_str(timed[-1][1], w=12)} run",
f" {len(timed):4d}/{len(candidates):4d} {timed[-1][0].colored_shape()}")
elif DEBUG >= 2:

View file

@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops, sint
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
@ -78,13 +78,18 @@ class GraphRunner(Runner):
self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
self.launch_dims_base:dict[int, tuple[tuple[sint, ...], tuple[sint, ...]]] = {}
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
def is_sym_dim(dim) -> bool: return dim is not None and not all(isinstance(d, (int, float)) for d in dim)
self.vars = sorted(var_vals.keys())
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
sym_dims: list[tuple[sint, ...]] = []
for ji in jit_cache:
if isinstance(ji.prg, CompiledRunner):
global_size, local_size = ji.prg.p.sizes
if local_size is not None and is_sym_dim(local_size): sym_dims.append(tuple(local_size))
if global_size is not None and is_sym_dim(global_size): sym_dims.append(tuple(global_size))
self.symbolic_dims = dedup(sym_dims)
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
estimates = Estimates()
@ -92,13 +97,15 @@ class GraphRunner(Runner):
assert ji.prg is not None
estimates += ji.prg.estimates
if isinstance(ji.prg, CompiledRunner):
if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars]
prg_vars = ji.prg.p.variables()
if prg_vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(prg_vars) if v.expr not in ji.fixedvars]
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
global_size, local_size = ji.prg.p.sizes
global_dim_idx, local_dim_idx = find_symbolic_dim(global_size), find_symbolic_dim(local_size)
if global_dim_idx is not None or local_dim_idx is not None:
self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))
assert global_size is not None and local_size is not None
self.launch_dims_base[j] = (tuple(global_size), tuple(local_size))
# used in MultiGraphRunner. the ints are id() of _bufs
self.w_dependency_map: dict[int, Any] = {}

View file

@ -1,28 +1,29 @@
from typing import cast, Callable
import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from dataclasses import dataclass, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.helpers import unwrap
from tinygrad.helpers import unwrap, to_function_name
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.renderer import Renderer, Estimates
from tinygrad.codegen import full_rewrite_to_program
from tinygrad.codegen.opt import Opt
# **************** Program Creation ****************
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.src[0].arg.name, (to_function_name(ret.src[0].arg.name), ret.src[0]), ret=ret),
replay=True)
def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> UOp:
"""
Transform an AST into a ProgramSpec. May trigger BEAM search.
Transform an AST into a PROGRAM UOp. May trigger BEAM search.
Args:
ast: The Ops.SINK rooted AST
renderer: The renderer used to generate the code
Returns:
The ProgramSpec of the program.
The PROGRAM UOp with structure (SINK, DEVICE, LINEAR, SOURCE).
"""
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
@ -35,15 +36,11 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
prg = full_rewrite_to_program(ast, renderer)
# SINK/LINEAR/SOURCE
sink, linear, source = prg.src
# print
if DEBUG >= 6: print_uops(list(linear.src))
if DEBUG >= 6: print_uops(list(prg.src[2].src)) # LINEAR is src[2]
return ProgramSpec(sink.arg.name, source.arg, renderer.device, sink, list(linear.src),
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
local_size=[1,1,1] if renderer.has_local else None)
return prg
# **************** Runners ****************
@ -72,28 +69,36 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
return ret[1]
class CompiledRunner(Runner):
def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
if DEBUG >= 3: print(p.applied_opts)
if DEBUG >= 4: print(p.src)
self.p:ProgramSpec = p
def __init__(self, p:UOp, precompiled:bytes|None=None, prg=None):
assert p.op is Ops.PROGRAM, f"CompiledRunner requires PROGRAM UOp, not {p.op}"
self.p:UOp = p
dev = p.device
assert isinstance(dev, str), f"PROGRAM device must be a string, not {type(dev)}"
name = p.src[0].arg.name
src = p.src[3].arg
function_name = to_function_name(name)
if DEBUG >= 3: print(p.src[0].arg.applied_opts)
if DEBUG >= 4: print(src)
if precompiled is not None: self.lib = precompiled
else:
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
self.lib = Device[p.device].compiler.compile_cached(p.src)
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
super().__init__(p.name, p.device, p.estimates)
with cpu_profile(TracingKey(f"compile {name}", (function_name,)), "TINY"):
self.lib = Device[dev].compiler.compile_cached(src)
if DEBUG >= 7: Device[dev].compiler.disassemble(self.lib)
self._prg = Device[dev].runtime(function_name, self.lib) if prg is None else prg
super().__init__(name, dev, Estimates.from_uops(list(p.src[2].src), ignore_indexing=True))
def __reduce__(self): return self.__class__, (self.p, self.lib)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
if var_vals is None: var_vals = {}
has_local = Device[self.p.device].renderer.has_local
dev = self.p.device
assert isinstance(dev, str), f"PROGRAM device must be a string, not {type(dev)}"
has_local = Device[dev].renderer.has_local
global_size, local_size = self.p.launch_dims(var_vals)
if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
sym_global_size, sym_local_size = self.p.sizes
if has_local and global_size is not None and local_size is None and sym_global_size is not None and all_int(sym_global_size):
local_size = optimize_local_size(self._prg, global_size, rawbufs)
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
self.p = replace(self.p, global_size=global_size, local_size=local_size)
global_size = [g//l if g%l == 0 else int(g/l) for g,l in zip(global_size, local_size)]
lra = {}
if global_size:
lra['global_size'] = tuple(global_size)
@ -101,7 +106,7 @@ class CompiledRunner(Runner):
if local_size:
lra['local_size'] = tuple(local_size)
assert len(local_size) == 3, "local size must have len 3"
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait)
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.variables()), wait=wait)
class ViewOp(Runner):
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
@ -158,10 +163,14 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
# update device in PROGRAM UOp
new_p = bret.p.replace(src=(bret.p.src[0], UOp(Ops.DEVICE, arg=device), *bret.p.src[2:]))
method_cache[ckey] = ret = CompiledRunner(new_p, bret.lib)
else:
prg: ProgramSpec = get_program(ast, Device[device].renderer)
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
prg = get_program(ast, Device[device].renderer)
# update device in PROGRAM UOp to match the actual device
prg = prg.replace(src=(prg.src[0], UOp(Ops.DEVICE, arg=device), *prg.src[2:]))
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg)
return ret
# **************** lowering functions ****************

View file

@ -1,12 +1,10 @@
from __future__ import annotations
from typing import Callable, cast
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
from dataclasses import dataclass
from tinygrad.helpers import prod
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.codegen.opt import Opt
@dataclass(frozen=True)
class Estimates:
@ -57,62 +55,6 @@ class Estimates:
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))
@dataclass
class ProgramSpec:
name:str
src:str
device:str
ast:UOp # save the base ast (this is method cache key)
uops:list[UOp]|None=None
# filled in from uops (if we have uops)
global_size:list[int]|None=None
local_size:list[int]|None=None
vars:list[Variable]=field(default_factory=list)
globals:list[int]=field(default_factory=list)
outs:list[int]=field(default_factory=list)
ins:list[int]=field(default_factory=list)
_ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program
def __post_init__(self):
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (buf:=idx.src[0]).op is Ops.DEFINE_GLOBAL: (self.outs if u.op is Ops.STORE else self.ins).append(buf.arg)
# TODO: can else happen?
if u.op is Ops.SPECIAL:
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
if u.arg[0] == 'i': self.local_size = None
special_size = self.local_size if u.arg[0] == 'l' else self.global_size
# TODO: this cast is wrong, u.src[0].ssimplify() can be sint
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
self.vars = sorted(self.vars, key=lambda v: v.arg)
self.outs = sorted(dedup(self.outs))
self.ins = sorted(dedup(self.ins))
self._ran_post_init = True
@functools.cached_property
def estimates(self) -> Estimates:
return Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True)
@functools.cached_property
def function_name(self) -> str: return to_function_name(self.name)
@property
def applied_opts(self) -> tuple[Opt, ...]|None:
if self.uops is None: return None
assert self.uops[-1].op is Ops.SINK, self.uops[-1].op
return self.uops[-1].arg.applied_opts
def launch_dims(self, var_vals:dict[str, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size
class Renderer:
device: str = ""
suffix: str = ""

View file

@ -22,12 +22,14 @@ class CUDAGraph(MultiGraphRunner):
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, CompiledRunner):
global_size, local_size = ji.prg.p.launch_dims(var_vals)
assert global_size is not None and local_size is not None
new_node = cuda.CUgraphNode()
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
prg_vars = ji.prg.p.variables()
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in prg_vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

View file

@ -39,7 +39,7 @@ class HCQGraph(MultiGraphRunner):
if not isinstance(ji.prg, CompiledRunner): continue
argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, argsbuf)
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.prog_vars(), argsbuf)
# Schedule Dependencies.
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
@ -159,7 +159,8 @@ class HCQGraph(MultiGraphRunner):
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
global_size, local_size = ji.prg.p.sizes
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(global_size or (1,1,1)), tuple(local_size or (1,1,1)))
elif isinstance(ji.prg, (BufferXfer, BufferCopy)):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
for bufid, src in enumerate(cast(list[Buffer], ji.bufs)):

View file

@ -42,9 +42,11 @@ class MetalGraph(GraphRunner):
if b is not None and b not in input_rawbuffers:
icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
for i,v in enumerate(prg.p.variables()):
icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
global_size, local_size = prg.p.launch_dims(var_vals)
assert global_size is not None and local_size is not None
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
icb_command.setBarrier()

View file

@ -609,11 +609,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@staticmethod
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
@staticmethod
def new_program(name:str, src:str, device:str, ast:UOp, uops:list[UOp]):
"""Create a PROGRAM UOp from raw components."""
sink = ast.replace(arg=KernelInfo(name=name)) if ast.arg is None else ast.replace(arg=ast.arg.replace(name=name))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=tuple(uops)), UOp(Ops.SOURCE, arg=src)))
@property
def device(self) -> str|tuple[str, ...]: return unwrap(self._device)
@recursive_property
def _device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.PROGRAM: return self.src[1].arg # PROGRAM src[1] is DEVICE
if self.op is Ops.BUFFERIZE: return self.arg.device
if self.op is Ops.AFTER: return self.src[0]._device
if self.op is Ops.MSELECT:
@ -624,6 +630,104 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
for x in self.src:
if x._device is not None: return x._device
return None
# *** PROGRAM UOp properties ***
@property
def uops(self) -> list[UOp]:
"""Linearized uops list. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"uops only valid for PROGRAM, not {self.op}"
return list(self.src[2].src)
@property
def name(self) -> str:
"""Kernel name. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"name only valid for PROGRAM, not {self.op}"
return self.src[0].arg.name
@property
def applied_opts(self):
"""Applied optimizations. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"applied_opts only valid for PROGRAM, not {self.op}"
return self.src[0].arg.applied_opts
@functools.cached_property
def estimates(self):
"""Estimates for this program. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"estimates only valid for PROGRAM, not {self.op}"
from tinygrad.renderer import Estimates
return Estimates.from_uops(self.uops, ignore_indexing=True)
@property
def globals(self) -> list[int]:
"""DEFINE_GLOBAL arg indices from linearized uops. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"globals only valid for PROGRAM, not {self.op}"
return [u.arg for u in self.src[2].src if u.op is Ops.DEFINE_GLOBAL]
@property
def outs(self) -> list[int]:
"""Buffer indices written to (STORE). Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"outs only valid for PROGRAM, not {self.op}"
ret = []
for u in self.src[2].src:
if u.op is Ops.STORE:
idx = u.src[0]
if idx.op is Ops.CAST: idx = idx.src[0]
if idx.op is Ops.INDEX and idx.src[0].op is Ops.DEFINE_GLOBAL: ret.append(idx.src[0].arg)
return sorted(set(ret))
@property
def ins(self) -> list[int]:
"""Buffer indices read from (LOAD). Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"ins only valid for PROGRAM, not {self.op}"
ret = []
for u in self.src[2].src:
if u.op is Ops.LOAD:
idx = u.src[0]
if idx.op is Ops.CAST: idx = idx.src[0]
if idx.op is Ops.INDEX and idx.src[0].op is Ops.DEFINE_GLOBAL: ret.append(idx.src[0].arg)
return sorted(set(ret))
@functools.cached_property
def sizes(self) -> tuple[list[sint]|None, list[sint]|None]:
"""Get (global_size, local_size) which may contain symbolic values. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"sizes only valid for PROGRAM, not {self.op}"
from tinygrad.device import Device
dev = self.device
assert isinstance(dev, str), f"PROGRAM device must be a string, not {type(dev)}"
ren = Device[dev].renderer
global_size:list[sint]|None = [1,1,1] if ren.has_local or ren.has_threads else None
local_size:list[sint]|None = [1,1,1] if ren.has_local else None
for u in self.src[2].src:
if u.op is Ops.SPECIAL:
if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size
if special_size is not None: special_size[int(u.arg[-1])] = cast(sint, u.src[0].ssimplify())
return global_size, local_size
def launch_dims(self, var_vals:dict[str, int]) -> tuple[list[int]|None, list[int]|None]:
"""Resolve global/local sizes to concrete ints. Only valid for PROGRAM."""
global_size, local_size = self.sizes
global_ret = [sym_infer(sz, var_vals) for sz in global_size] if global_size is not None else None
local_ret = [sym_infer(sz, var_vals) for sz in local_size] if local_size is not None else None
return global_ret, local_ret
@property
def global_size(self) -> list[sint]|None:
"""Global size (may be symbolic). Only valid for PROGRAM."""
return self.sizes[0]
@property
def local_size(self) -> list[sint]|None:
"""Local size (may be symbolic). Only valid for PROGRAM."""
return self.sizes[1]
def prog_vars(self) -> list:
"""Variables list for this program. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"prog_vars only valid for PROGRAM, not {self.op}"
# Get variables from the linearized uops
linear_uops = self.src[2]
return linear_uops.variables()
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self

View file

@ -249,10 +249,10 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR, SOURCE)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
# codegen: standalone LINEAR/SOURCE
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),

View file

@ -10,7 +10,7 @@ from tinygrad.helpers import printable, TCPServerWithReuse, HTTPRequestHandler
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent
from tinygrad.renderer import ProgramSpec
from tinygrad.renderer import Estimates
from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
@ -39,7 +39,7 @@ def get_rewrites(t:RewriteTrace) -> list[dict]:
for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)]
if isinstance(k.ret, ProgramSpec):
if isinstance(k.ret, UOp) and k.ret.op is Ops.PROGRAM:
steps.append(create_step("View UOp List", ("/uops", i, len(steps)), k.ret))
steps.append(create_step("View Program", ("/code", i, len(steps)), k.ret))
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), k.ret))
@ -161,9 +161,10 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
name, fmt, key = e.name, [], None
if (ref:=ref_map.get(name)) is not None:
name = ctxs[ref]["name"]
if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
flops = sym_infer(p.estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
membw, ldsbw = sym_infer(p.estimates.mem, var_vals)/t, sym_infer(p.estimates.lds, var_vals)/t
if isinstance(p:=trace.keys[ref].ret, UOp) and p.op is Ops.PROGRAM and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
flops = sym_infer(estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
membw, ldsbw = sym_infer(estimates.mem, var_vals)/t, sym_infer(estimates.lds, var_vals)/t
fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS",
(f"{membw*1e-9:.0f} GB/s" if membw < 1e13 else f"{membw*1e-12:.0f} TB/s")+" mem",
(f"{ldsbw*1e-9:.0f} GB/s" if ldsbw < 1e15 else f"{ldsbw*1e-12:.0f} TB/s")+" lds"]
@ -425,10 +426,12 @@ def get_render(i:int, j:int, fmt:str) -> dict:
data = ctxs[i]["steps"][j]["data"]
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
if fmt == "code": return {"src":data.src, "lang":"cpp"}
# PROGRAM UOp: src[3].arg is source code
source_code = data.src[3].arg
if fmt == "code": return {"src":source_code, "lang":"cpp"}
if fmt == "asm":
compiler = Device[data.device].compiler
disasm_str = get_stdout(lambda: compiler.disassemble(compiler.compile(data.src)))
disasm_str = get_stdout(lambda: compiler.disassemble(compiler.compile(source_code)))
ret:dict = {"src":disasm_str}
if data.device.startswith("AMD"):
with soft_err(lambda err: ret.update(err)):