Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
a07c9da26b
Merge branch 'master' into remove_programspec 2025-12-23 15:21:46 -05:00
George Hotz
816a359a3c do programspec removal 2025-12-23 15:21:02 -05:00
25 changed files with 292 additions and 193 deletions

View file

@ -1,12 +1,11 @@
from typing import Tuple, Dict, List, Optional from typing import Tuple, Dict, List, Optional
from tinygrad.dtype import DType from tinygrad.dtype import DType
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Device, Tensor from tinygrad.tensor import Device, Tensor
from tinygrad.engine.jit import TinyJit from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import get_state_dict from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import Context, to_mv from tinygrad.helpers import Context, to_mv, to_function_name
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops from tinygrad.uop.ops import Ops, UOp
import json import json
from collections import OrderedDict from collections import OrderedDict
@ -15,8 +14,13 @@ EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "CL"]
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]: def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0 functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for ji in run.jit_cache: for ji in run.jit_cache:
fxn: ProgramSpec = ji.prg.p prg: UOp = ji.prg.p
functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same name = prg.src[0].arg.name
function_name = to_function_name(name)
src = prg.src[3].arg
global_size, local_size = prg.sizes
prg_vars = prg.variables()
functions[function_name] = src # NOTE: this assumes all with the same name are the same
cargs = [] cargs = []
for i,arg in enumerate(ji.bufs): for i,arg in enumerate(ji.bufs):
key = id(arg) key = id(arg)
@ -28,8 +32,8 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
bufnum += 1 bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0]) cargs.append(bufs[key][0])
cargs += [var for var in fxn.vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR? cargs += [var for var in prg_vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR?
statements.append((fxn.function_name, cargs, fxn.global_size, fxn.local_size)) statements.append((function_name, cargs, list(global_size) if global_size else None, list(local_size) if local_size else None))
return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save

View file

@ -4,7 +4,8 @@ import triton.language as tl
from triton.compiler import AttrsDescriptor, ASTSource, compile as triton_compile from triton.compiler import AttrsDescriptor, ASTSource, compile as triton_compile
import numpy as np import numpy as np
from tinygrad import Tensor, dtypes, Device from tinygrad import Tensor, dtypes, Device
from tinygrad.engine.realize import CompiledRunner, ExecItem, ProgramSpec from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
np.set_printoptions(suppress=True) np.set_printoptions(suppress=True)
@ -85,9 +86,12 @@ if __name__ == "__main__":
# remove debug sections # remove debug sections
src = src.split("\t.file")[0] src = src.split("\t.file")[0]
assert '.extern .shared' not in src assert '.extern .shared' not in src
prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT, # Create linearized uops with SPECIAL for global/local sizes
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1], global_size = [M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1]
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes()) local_size = [32*compiled.metadata.num_warps, 1, 1]
uops = [UOp(Ops.SPECIAL, arg=('g', i), src=(UOp.const(dtypes.int, global_size[i]),)) for i in range(3)]
uops += [UOp(Ops.SPECIAL, arg=('l', i), src=(UOp.const(dtypes.int, local_size[i]),)) for i in range(3)]
prg = UOp.new_program("matmul_kernel", src, Device.DEFAULT, si.ast, uops)
ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg)) ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
tflops = [] tflops = []
for i in range(5): for i in range(5):

View file

@ -1,10 +1,10 @@
import numpy as np import numpy as np
import unittest import unittest
import subprocess, struct, math import subprocess, struct, math
from tinygrad import Tensor, dtypes, Device, UOp from tinygrad import Tensor, dtypes, Device
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner from tinygrad.engine.realize import CompiledRunner
def get_output(asm:str, n_threads:int=1): def get_output(asm:str, n_threads:int=1):
@ -22,7 +22,9 @@ def get_output(asm:str, n_threads:int=1):
*(data0_1+l) = res; *(data0_1+l) = res;
}}""" }}"""
t = Tensor.zeros(n_threads, dtype=dtypes.uint32).contiguous().realize() t = Tensor.zeros(n_threads, dtype=dtypes.uint32).contiguous().realize()
prg = ProgramSpec("test", src, Device.DEFAULT, UOp.sink(t), global_size=[1, 1, 1], local_size=[n_threads, 1, 1]) # Create linearized uops with SPECIAL for local size
uops = [UOp(Ops.SPECIAL, arg=('l', 0), src=(UOp.const(dtypes.int, n_threads),))]
prg = UOp.new_program("test", src, Device.DEFAULT, UOp.sink(t.uop), uops)
car = CompiledRunner(prg) car = CompiledRunner(prg)
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib) if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
car([t.uop.buffer], {}, wait=True) car([t.uop.buffer], {}, wait=True)

View file

@ -1,5 +1,4 @@
# ruff: noqa: E501 E712 F401 # ruff: noqa: E501 E712 F401
from dataclasses import replace
from tinygrad import dtypes, Device from tinygrad import dtypes, Device
from tinygrad.uop.ops import UOp, AxisType, Ops, KernelInfo from tinygrad.uop.ops import UOp, AxisType, Ops, KernelInfo
from tinygrad.codegen.opt import Opt, OptOps # pylint: disable=unused-import from tinygrad.codegen.opt import Opt, OptOps # pylint: disable=unused-import
@ -89,7 +88,9 @@ renderer = Device.default.renderer
allocator = Device.default.allocator allocator = Device.default.allocator
ps = get_program(ast, renderer) ps = get_program(ast, renderer)
cr = CompiledRunner(replace(ps, device=Device.DEFAULT)) # update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
ps = ps.replace(src=(ps.src[0], UOp(Ops.DEVICE, arg=Device.DEFAULT), *ps.src[2:]))
cr = CompiledRunner(ps)
gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.DEFINE_GLOBAL]), key=lambda u: u.arg) gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.DEFINE_GLOBAL]), key=lambda u: u.arg)
# print(len(gs)) # print(len(gs))

View file

@ -9,7 +9,7 @@ if not int(os.getenv("ASSERT_PROCESS_REPLAY", "1")): ASSERT_DIFF = 0
try: try:
from tinygrad.schedule.rangeify import get_rangeify_map from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.renderer import Renderer, ProgramSpec from tinygrad.renderer import Renderer
from tinygrad.engine.realize import get_program from tinygrad.engine.realize import get_program
from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.codegen.opt import Opt from tinygrad.codegen.opt import Opt
@ -51,15 +51,20 @@ def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str,
return "\n".join([f"{len(asts)} kernels", *asts]) return "\n".join([f"{len(asts)} kernels", *asts])
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,) return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]: def replay_get_program(p:UOp, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
# the ast.arg is non None if we are inside of search.py # the ast.arg is non None if we are inside of search.py
sink_arg = ast.arg or KernelInfo(opts_to_apply=tuple(opts) if opts is not None else p.applied_opts if BEAM>=1 else None) # p is a PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
input_ast = ast.replace(arg=replace(sink_arg, name=p.name)) p_name = p.src[0].arg.name
p_applied_opts = p.src[0].arg.applied_opts
sink_arg = ast.arg or KernelInfo(opts_to_apply=tuple(opts) if opts is not None else p_applied_opts if BEAM>=1 else None)
input_ast = ast.replace(arg=replace(sink_arg, name=p_name))
p2 = get_program(input_ast, renderer=renderer) p2 = get_program(input_ast, renderer=renderer)
def to_str(ret:ProgramSpec) -> str: def to_str(ret:UOp) -> str:
# PYTHON renderer pickles UOps, first unpickle and decode here # PYTHON renderer pickles UOps, first unpickle and decode here
if p.device.startswith("PYTHON"): return "\n".join([str(x) for x in pickle.loads(base64.b64decode(ret.src))]) ret_src = ret.src[3].arg
return ret.src ret_device = ret.device
if ret_device.startswith("PYTHON"): return "\n".join([str(x) for x in pickle.loads(base64.b64decode(ret_src))])
return ret_src
# properly color the name arg # properly color the name arg
ast_repr = codecs.decode(str(input_ast), "unicode_escape") ast_repr = codecs.decode(str(input_ast), "unicode_escape")
return to_str(p2), to_str(p), (ast_repr, renderer) return to_str(p2), to_str(p), (ast_repr, renderer)

View file

@ -63,7 +63,7 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data))) allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=()) g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
prg = get_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer()) prg = get_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer())
prog = PythonProgram("run", PythonCompiler().compile(prg.src)) prog = PythonProgram("run", PythonCompiler().compile(prg.src[3].arg)) # source code is in src[3].arg
prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs) prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
return out_buf.cast(uop.dtype.fmt or "").tolist()[0] return out_buf.cast(uop.dtype.fmt or "").tolist()[0]

View file

@ -1,10 +1,9 @@
import numpy as np import numpy as np
import unittest import unittest
from dataclasses import replace
from tinygrad import Device, Tensor, dtypes from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype from tinygrad.tensor import _to_np_dtype
from tinygrad.uop.ops import Ops from tinygrad.uop.ops import Ops, UOp
from tinygrad.dtype import DType from tinygrad.dtype import DType
from tinygrad.device import is_dtype_supported from tinygrad.device import is_dtype_supported
from tinygrad.helpers import AMX, AMD_LLVM, CPU_LLVM, Context from tinygrad.helpers import AMX, AMD_LLVM, CPU_LLVM, Context
@ -44,7 +43,10 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
if dtype_in == dtypes.bfloat16: r = r.float() if dtype_in == dtypes.bfloat16: r = r.float()
realized_ast, bufs = helper_realized_ast(r) realized_ast, bufs = helper_realized_ast(r)
opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))] opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))]
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts), device=Device.DEFAULT)) p = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts)
# update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
p = p.replace(src=(p.src[0], UOp(Ops.DEVICE, arg=Device.DEFAULT), *p.src[2:]))
prg = CompiledRunner(p)
if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered" if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg.exec(bufs) prg.exec(bufs)

View file

@ -28,7 +28,7 @@ class TestFusionOp(unittest.TestCase):
sched = a.schedule() sched = a.schedule()
sched[-1].lower() sched[-1].lower()
self.assertLess(time.perf_counter()-st, 2.0) self.assertLess(time.perf_counter()-st, 2.0)
assert len(sched[-1].prg.p.src.splitlines()) < 250 assert len(sched[-1].prg.p.src[3].arg.splitlines()) < 250 # source code is in src[3].arg
def test_recursive_add_cmp(self): def test_recursive_add_cmp(self):
st = time.perf_counter() st = time.perf_counter()

View file

@ -1,6 +1,5 @@
import numpy as np import numpy as np
import unittest import unittest
from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims from tinygrad.codegen.gpudims import get_grouped_dims
@ -168,8 +167,8 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU") @unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
def test_upcast_with_locals_cpu(self): def test_upcast_with_locals_cpu(self):
out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous() out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
prg = get_program(out.schedule()[-1].ast, opts=[Opt(OptOps.LOCAL, axis=0, arg=4)]).uops prg = get_program(out.schedule()[-1].ast, opts=[Opt(OptOps.LOCAL, axis=0, arg=4)])
self.assertEqual(len(prg.src.split("for")), 5) self.assertEqual(len(prg.src[3].arg.split("for")), 5) # source code is in src[3].arg
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@ -517,7 +516,11 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
device = real_bufs[0].device device = real_bufs[0].device
wanna_output = [np.array(x).flatten() for x in wanna_output] wanna_output = [np.array(x).flatten() for x in wanna_output]
def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts), device=device)) def get_prg(opts):
prg = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts)
# update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
prg = prg.replace(src=(prg.src[0], UOp(Ops.DEVICE, arg=device), *prg.src[2:]))
return CompiledRunner(prg)
def check_opt(opts): def check_opt(opts):
prg = get_prg(opts=opts) prg = get_prg(opts=opts)

View file

@ -14,7 +14,7 @@ class TestOpts(unittest.TestCase):
self.assertEqual(s[-1].ast.arg.opts_to_apply, opts) self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP: if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP:
prg = get_program(s[-1].ast, renderer=Device[Device.DEFAULT].renderer) prg = get_program(s[-1].ast, renderer=Device[Device.DEFAULT].renderer)
self.assertIn('float4', prg.src) self.assertIn('float4', prg.src[3].arg) # source code is in src[3].arg
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -1,6 +1,5 @@
import unittest import unittest
import numpy as np import numpy as np
from dataclasses import replace
from tinygrad.device import Buffer, Device, is_dtype_supported from tinygrad.device import Buffer, Device, is_dtype_supported
from tinygrad.dtype import dtypes, ConstType from tinygrad.dtype import dtypes, ConstType
from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.engine.realize import CompiledRunner, get_program
@ -18,8 +17,8 @@ def _test_uop_result(inputs:list[Tensor], prg, local_size=None):
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \ outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE] initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [x.uop.base.buffer for x in inputs] inbufs = [x.uop.base.buffer for x in inputs]
prg = replace(prg, device=Device.DEFAULT) # update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
if local_size is not None: prg = replace(prg, local_size=local_size) prg = prg.replace(src=(prg.src[0], UOp(Ops.DEVICE, arg=Device.DEFAULT), *prg.src[2:]))
ei = CompiledRunner(prg) ei = CompiledRunner(prg)
ei.exec(outbufs+inbufs) ei.exec(outbufs+inbufs)
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs] return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
@ -72,7 +71,7 @@ class TestCStyleFailures(unittest.TestCase):
schedule = ret.schedule() schedule = ret.schedule()
assert len(schedule) == 1 assert len(schedule) == 1
schedule[0].lower() schedule[0].lower()
src = schedule[0].prg.p.src src = schedule[0].prg.p.src[3].arg # source code is in src[3].arg
self.assertEqual("("*5 not in src, should_strip_paren) self.assertEqual("("*5 not in src, should_strip_paren)
def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD) def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD)

View file

@ -15,7 +15,6 @@ from tinygrad.device import is_dtype_supported
from tinygrad.codegen.opt import Opt, OptOps from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.ptx import PTXRenderer
from test.helpers import get_uops from test.helpers import get_uops
from dataclasses import replace
def to_uops_list(u:list[UOp], ren=None) -> list[UOp]: def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
sink = UOp.group(*u) sink = UOp.group(*u)
@ -27,7 +26,9 @@ def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
def _uops_to_prg(uops_list): def _uops_to_prg(uops_list):
prg = get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer) prg = get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer)
return CompiledRunner(replace(prg, device=Device.DEFAULT)) # update device in PROGRAM UOp: (SINK, DEVICE, LINEAR, SOURCE)
prg = prg.replace(src=(prg.src[0], UOp(Ops.DEVICE, arg=Device.DEFAULT), *prg.src[2:]))
return CompiledRunner(prg)
def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp: def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(src), arg)) uops.append(UOp(uop, dtype, tuple(src), arg))

View file

@ -2,7 +2,6 @@ import unittest
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.helpers import getenv, GlobalCounters, EMULATE from tinygrad.helpers import getenv, GlobalCounters, EMULATE
from tinygrad.engine.realize import get_program from tinygrad.engine.realize import get_program
from tinygrad.renderer import ProgramSpec
from tinygrad.renderer import Estimates from tinygrad.renderer import Estimates
from tinygrad.uop.ops import Ops, UOp from tinygrad.uop.ops import Ops, UOp
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
@ -166,24 +165,26 @@ class TestStatsOptimized(unittest.TestCase):
cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule()[-1].ast cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule()[-1].ast
def check_gemm(self, p:ProgramSpec, extra_flops=0): def check_gemm(self, p:UOp, extra_flops=0):
#p.uops.print() #p.uops.print()
#print(p.src) #print(p.src[3].arg)
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(p.estimates.ops, 2*N*N*N + extra_flops) # N**3 mulaccs print(p.src[0].arg.name, estimates.ops, estimates.mem, estimates.lds)
self.assertEqual(p.estimates.mem, 3*N*N*4) # 3 NxN mats with floats self.assertEqual(estimates.ops, 2*N*N*N + extra_flops) # N**3 mulaccs
self.assertEqual(estimates.mem, 3*N*N*4) # 3 NxN mats with floats
def test_gemm(self): def test_gemm(self):
p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[]) p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[])
self.check_gemm(p) self.check_gemm(p)
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, 2*N*N*N*4 + 4*N*N)
def test_gemm_tc_unroll(self): def test_gemm_tc_unroll(self):
try: try:
p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)]) p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)])
except KernelOptError: except KernelOptError:
raise unittest.SkipTest("no tensor cores") raise unittest.SkipTest("no tensor cores")
print(p.src) print(p.src[3].arg)
self.check_gemm(p) self.check_gemm(p)
# this is a good lesson about why UPCASTing is a good idea # this is a good lesson about why UPCASTing is a good idea
@ -191,13 +192,15 @@ class TestStatsOptimized(unittest.TestCase):
def test_gemm_one_upcasted(self): def test_gemm_one_upcasted(self):
p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.UPCAST, 0, 4)]) p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.UPCAST, 0, 4)])
self.check_gemm(p) self.check_gemm(p)
self.assertEqual(p.estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N)
def test_gemm_upcasted(self): def test_gemm_upcasted(self):
p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer,
opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)]) opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)])
self.check_gemm(p) self.check_gemm(p)
self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, 2*N*N*N*4//4 + 4*N*N)
def test_gemm_upcasted_locals(self): def test_gemm_upcasted_locals(self):
try: try:
@ -206,7 +209,8 @@ class TestStatsOptimized(unittest.TestCase):
except KernelOptError: except KernelOptError:
raise unittest.SkipTest("no locals") raise unittest.SkipTest("no locals")
self.check_gemm(p) self.check_gemm(p)
self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, 2*N*N*N*4//4 + 4*N*N)
def test_gemm_group(self): def test_gemm_group(self):
try: try:
@ -216,13 +220,15 @@ class TestStatsOptimized(unittest.TestCase):
SZ = N*N*4 SZ = N*N*4
# NOTE: these are sort of wrong. they aren't honoring the IF statement # NOTE: these are sort of wrong. they aren't honoring the IF statement
self.check_gemm(p, extra_flops=SZ*4) self.check_gemm(p, extra_flops=SZ*4)
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
def test_reduce(self): def test_reduce(self):
p = get_program(self.ast_reduce, renderer=Device[Device.DEFAULT].renderer, opts=[]) p = get_program(self.ast_reduce, renderer=Device[Device.DEFAULT].renderer, opts=[])
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
self.assertEqual(p.estimates.ops, N*N) print(p.src[0].arg.name, estimates.ops, estimates.mem, estimates.lds)
self.assertEqual(p.estimates.mem, N*N*4 + 4) self.assertEqual(estimates.ops, N*N)
self.assertEqual(estimates.mem, N*N*4 + 4)
def test_reduce_group(self): def test_reduce_group(self):
try: try:
@ -230,7 +236,8 @@ class TestStatsOptimized(unittest.TestCase):
except KernelOptError: except KernelOptError:
raise unittest.SkipTest("no locals") raise unittest.SkipTest("no locals")
# NOTE: these are wrong, they don't respect the if statement # NOTE: these are wrong, they don't respect the if statement
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
print(p.src[0].arg.name, estimates.ops, estimates.mem, estimates.lds)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)

View file

@ -3,7 +3,6 @@ import textwrap
from tinygrad import Device, Tensor from tinygrad import Device, Tensor
from tinygrad.uop.ops import UOp, Ops, track_rewrites from tinygrad.uop.ops import UOp, Ops, track_rewrites
from tinygrad.renderer import ProgramSpec
from tinygrad.helpers import TracingKey from tinygrad.helpers import TracingKey
from tinygrad.engine.realize import ExecItem, CompiledRunner from tinygrad.engine.realize import ExecItem, CompiledRunner
@ -51,9 +50,9 @@ amdhsa.kernels:
.end_amdgpu_metadata .end_amdgpu_metadata
""" """
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret)) @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.src[0].arg.name, ret=ret))
def run_asm(name:str, src:str) -> ProgramSpec: def run_asm(name:str, src:str) -> UOp:
prg = ProgramSpec(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK)) prg = UOp.new_program(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK), [])
ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg)) ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
ei.run() ei.run()
return prg return prg

View file

@ -133,11 +133,11 @@ def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),)) return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
pm_to_program = PatternMatcher([ pm_to_program = PatternMatcher([
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize), (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render), (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
]) ])
def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp: def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None) full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
sink = UOp(Ops.PROGRAM, src=(full_sink,)) sink = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=ren.device)))
return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render") return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")

View file

@ -1,13 +1,11 @@
import functools, math, time, multiprocessing, traceback, signal, atexit import functools, math, time, multiprocessing, traceback, signal, atexit
from dataclasses import replace from tinygrad.uop.ops import sym_infer, AxisType, pyrender, UOp
from tinygrad.uop.ops import sym_infer, AxisType, pyrender
from tinygrad.device import Device, Buffer, Compiler from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str, unwrap from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str, unwrap
from tinygrad.helpers import IGNORE_BEAM_CACHE from tinygrad.helpers import IGNORE_BEAM_CACHE
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.engine.realize import CompiledRunner, get_program, Estimates
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.opt.postrange import Scheduler from tinygrad.codegen.opt.postrange import Scheduler
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)] actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
@ -34,19 +32,22 @@ def get_test_global_size(global_size, max_global_size, var_vals):
break break
return test_global_size, input_size / prod(test_global_size) return test_global_size, input_size / prod(test_global_size)
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None, def _time_program(p:UOp, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]: allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
factor = 1 factor = 1
if allow_test_size and p.global_size is not None and max_global_size is not None: global_size, local_size = p.sizes
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals) if allow_test_size and global_size is not None and max_global_size is not None:
p = replace(p, global_size=global_size) test_global_size, factor = get_test_global_size(global_size, max_global_size, var_vals)
# NOTE: we can't modify p.sizes directly, but optimize_local_size doesn't run in BEAM so this is ok
try: car = CompiledRunner(p, precompiled=lib) try: car = CompiledRunner(p, precompiled=lib)
except AssertionError: return [math.inf] * cnt except AssertionError: return [math.inf] * cnt
tms = [] tms = []
input_bufs = [rawbufs[i] for i in car.p.globals] input_bufs = [rawbufs[i] for i in car.p.globals]
for _ in range(cnt): for _ in range(cnt):
if clear_l2: if clear_l2:
if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches() p_device = p.device
assert isinstance(p_device, str), f"PROGRAM device must be a string, not {type(p_device)}"
if hasattr(dev:=Device[p_device], 'invalidate_caches'): dev.invalidate_caches()
else: else:
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False) with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
tms.append(unwrap(car(input_bufs, var_vals, wait=True))*factor) tms.append(unwrap(car(input_bufs, var_vals, wait=True))*factor)
@ -58,7 +59,7 @@ def timeout_handler(signum, frame):
if DEBUG >= 2: print("*** BEAM COMPILE TIMEOUT") if DEBUG >= 2: print("*** BEAM COMPILE TIMEOUT")
raise TimeoutException() raise TimeoutException()
def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]: def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[UOp, bytes, float]|None]:
if hasattr(signal, "alarm"): if hasattr(signal, "alarm"):
signal.signal(getattr(signal, 'SIGALRM'), timeout_handler) signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
# set timeout # set timeout
@ -66,12 +67,12 @@ def _try_compile(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[
ret = None ret = None
try: try:
p = get_program(x[1].copy().get_optimized_ast(name_override="test"), x[1].ren) p = get_program(x[1].copy().get_optimized_ast(name_override="test"), x[1].ren)
assert p.uops is not None, "uop list wasn't generated?" uops = list(p.src[2].src) # LINEAR is src[2]
if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0: if len(uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}") if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(uops)=}, {uops_max=}")
raise RuntimeError("too many uops") raise RuntimeError("too many uops")
st = time.perf_counter() st = time.perf_counter()
prog = compiler.compile(p.src) prog = compiler.compile(p.src[3].arg) # SOURCE is src[3]
et = time.perf_counter() - st et = time.perf_counter() - st
ret = (p, prog, et) ret = (p, prog, et)
except RuntimeError: except RuntimeError:
@ -154,7 +155,8 @@ def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True
p, lib, compile_et = proc p, lib, compile_et = proc
if lib in seen_libs: continue if lib in seen_libs: continue
# filter out kernels that use 1000x more compute than the smallest # filter out kernels that use 1000x more compute than the smallest
least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
least_compute_ops = min(this_compute_ops:=sym_infer(estimates.ops, var_vals), least_compute_ops)
if least_compute_ops*1000 < this_compute_ops: if least_compute_ops*1000 < this_compute_ops:
if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too much compute. {this_compute_ops} when least is {least_compute_ops}") if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too much compute. {this_compute_ops} when least is {least_compute_ops}")
continue continue
@ -167,7 +169,7 @@ def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True
raise raise
timed.append((candidates[i], min(tms))) timed.append((candidates[i], min(tms)))
if BEAM_DEBUG > 1: if BEAM_DEBUG > 1:
print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(unwrap(p.uops)):5d} uops", print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(p.src[2].src):5d} uops",
f"{time_to_str(compile_et, w=12)} compile/{time_to_str(timed[-1][1], w=12)} run", f"{time_to_str(compile_et, w=12)} compile/{time_to_str(timed[-1][1], w=12)} run",
f" {len(timed):4d}/{len(candidates):4d} {timed[-1][0].colored_shape()}") f" {len(timed):4d}/{len(candidates):4d} {timed[-1][0].colored_shape()}")
elif DEBUG >= 2: elif DEBUG >= 2:

View file

@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops, sint
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters from tinygrad.nn.state import get_parameters
@ -78,13 +78,18 @@ class GraphRunner(Runner):
self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers) self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {} self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {} self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {} self.launch_dims_base:dict[int, tuple[tuple[sint, ...], tuple[sint, ...]]] = {}
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim) def is_sym_dim(dim) -> bool: return dim is not None and not all(isinstance(d, (int, float)) for d in dim)
self.vars = sorted(var_vals.keys()) self.vars = sorted(var_vals.keys())
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] + sym_dims: list[tuple[sint, ...]] = []
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)]) for ji in jit_cache:
if isinstance(ji.prg, CompiledRunner):
global_size, local_size = ji.prg.p.sizes
if local_size is not None and is_sym_dim(local_size): sym_dims.append(tuple(local_size))
if global_size is not None and is_sym_dim(global_size): sym_dims.append(tuple(global_size))
self.symbolic_dims = dedup(sym_dims)
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
estimates = Estimates() estimates = Estimates()
@ -92,13 +97,15 @@ class GraphRunner(Runner):
assert ji.prg is not None assert ji.prg is not None
estimates += ji.prg.estimates estimates += ji.prg.estimates
if isinstance(ji.prg, CompiledRunner): if isinstance(ji.prg, CompiledRunner):
if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars] prg_vars = ji.prg.p.variables()
if prg_vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(prg_vars) if v.expr not in ji.fixedvars]
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size) global_size, local_size = ji.prg.p.sizes
global_dim_idx, local_dim_idx = find_symbolic_dim(global_size), find_symbolic_dim(local_size)
if global_dim_idx is not None or local_dim_idx is not None: if global_dim_idx is not None or local_dim_idx is not None:
self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx) self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None assert global_size is not None and local_size is not None
self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size)) self.launch_dims_base[j] = (tuple(global_size), tuple(local_size))
# used in MultiGraphRunner. the ints are id() of _bufs # used in MultiGraphRunner. the ints are id() of _bufs
self.w_dependency_map: dict[int, Any] = {} self.w_dependency_map: dict[int, Any] = {}

View file

@ -1,28 +1,29 @@
from typing import cast, Callable from typing import cast, Callable
import time, pprint, random, itertools, math import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field from dataclasses import dataclass, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.helpers import unwrap from tinygrad.helpers import unwrap, to_function_name
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates from tinygrad.renderer import Renderer, Estimates
from tinygrad.codegen import full_rewrite_to_program from tinygrad.codegen import full_rewrite_to_program
from tinygrad.codegen.opt import Opt from tinygrad.codegen.opt import Opt
# **************** Program Creation **************** # **************** Program Creation ****************
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True) @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.src[0].arg.name, (to_function_name(ret.src[0].arg.name), ret.src[0]), ret=ret),
def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec: replay=True)
def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> UOp:
""" """
Transform an AST into a ProgramSpec. May trigger BEAM search. Transform an AST into a PROGRAM UOp. May trigger BEAM search.
Args: Args:
ast: The Ops.SINK rooted AST ast: The Ops.SINK rooted AST
renderer: The renderer used to generate the code renderer: The renderer used to generate the code
Returns: Returns:
The ProgramSpec of the program. The PROGRAM UOp with structure (SINK, DEVICE, LINEAR, SOURCE).
""" """
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
@ -35,15 +36,11 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
if ast.arg is None: ast = ast.replace(arg=KernelInfo()) if ast.arg is None: ast = ast.replace(arg=KernelInfo())
prg = full_rewrite_to_program(ast, renderer) prg = full_rewrite_to_program(ast, renderer)
# SINK/LINEAR/SOURCE
sink, linear, source = prg.src
# print # print
if DEBUG >= 6: print_uops(list(linear.src)) if DEBUG >= 6: print_uops(list(prg.src[2].src)) # LINEAR is src[2]
return ProgramSpec(sink.arg.name, source.arg, renderer.device, sink, list(linear.src), return prg
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
local_size=[1,1,1] if renderer.has_local else None)
# **************** Runners **************** # **************** Runners ****************
@ -72,28 +69,36 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
return ret[1] return ret[1]
class CompiledRunner(Runner): class CompiledRunner(Runner):
def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None): def __init__(self, p:UOp, precompiled:bytes|None=None, prg=None):
if DEBUG >= 3: print(p.applied_opts) assert p.op is Ops.PROGRAM, f"CompiledRunner requires PROGRAM UOp, not {p.op}"
if DEBUG >= 4: print(p.src) self.p:UOp = p
self.p:ProgramSpec = p dev = p.device
assert isinstance(dev, str), f"PROGRAM device must be a string, not {type(dev)}"
name = p.src[0].arg.name
src = p.src[3].arg
function_name = to_function_name(name)
if DEBUG >= 3: print(p.src[0].arg.applied_opts)
if DEBUG >= 4: print(src)
if precompiled is not None: self.lib = precompiled if precompiled is not None: self.lib = precompiled
else: else:
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"): with cpu_profile(TracingKey(f"compile {name}", (function_name,)), "TINY"):
self.lib = Device[p.device].compiler.compile_cached(p.src) self.lib = Device[dev].compiler.compile_cached(src)
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib) if DEBUG >= 7: Device[dev].compiler.disassemble(self.lib)
self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg self._prg = Device[dev].runtime(function_name, self.lib) if prg is None else prg
super().__init__(p.name, p.device, p.estimates) super().__init__(name, dev, Estimates.from_uops(list(p.src[2].src), ignore_indexing=True))
def __reduce__(self): return self.__class__, (self.p, self.lib) def __reduce__(self): return self.__class__, (self.p, self.lib)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None: def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
if var_vals is None: var_vals = {} if var_vals is None: var_vals = {}
has_local = Device[self.p.device].renderer.has_local dev = self.p.device
assert isinstance(dev, str), f"PROGRAM device must be a string, not {type(dev)}"
has_local = Device[dev].renderer.has_local
global_size, local_size = self.p.launch_dims(var_vals) global_size, local_size = self.p.launch_dims(var_vals)
if has_local and global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type] sym_global_size, sym_local_size = self.p.sizes
if has_local and global_size is not None and local_size is None and sym_global_size is not None and all_int(sym_global_size):
local_size = optimize_local_size(self._prg, global_size, rawbufs) local_size = optimize_local_size(self._prg, global_size, rawbufs)
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] global_size = [g//l if g%l == 0 else int(g/l) for g,l in zip(global_size, local_size)]
self.p = replace(self.p, global_size=global_size, local_size=local_size)
lra = {} lra = {}
if global_size: if global_size:
lra['global_size'] = tuple(global_size) lra['global_size'] = tuple(global_size)
@ -101,7 +106,7 @@ class CompiledRunner(Runner):
if local_size: if local_size:
lra['local_size'] = tuple(local_size) lra['local_size'] = tuple(local_size)
assert len(local_size) == 3, "local size must have len 3" assert len(local_size) == 3, "local size must have len 3"
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.vars), wait=wait) return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k.expr] for k in self.p.variables()), wait=wait)
class ViewOp(Runner): class ViewOp(Runner):
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device) def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
@ -158,10 +163,14 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
if cret:=method_cache.get(ckey): return cret if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True) bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
if bret:=method_cache.get(bkey): if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib) # update device in PROGRAM UOp
new_p = bret.p.replace(src=(bret.p.src[0], UOp(Ops.DEVICE, arg=device), *bret.p.src[2:]))
method_cache[ckey] = ret = CompiledRunner(new_p, bret.lib)
else: else:
prg: ProgramSpec = get_program(ast, Device[device].renderer) prg = get_program(ast, Device[device].renderer)
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device)) # update device in PROGRAM UOp to match the actual device
prg = prg.replace(src=(prg.src[0], UOp(Ops.DEVICE, arg=device), *prg.src[2:]))
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg)
return ret return ret
# **************** lowering functions **************** # **************** lowering functions ****************

View file

@ -1,12 +1,10 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, cast from typing import Callable, cast
import functools from dataclasses import dataclass
from dataclasses import dataclass, field from tinygrad.helpers import prod
from tinygrad.helpers import to_function_name, dedup, prod from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, GroupOp, PatternMatcher
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.codegen.opt.tc import TensorCore from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.codegen.opt import Opt
@dataclass(frozen=True) @dataclass(frozen=True)
class Estimates: class Estimates:
@ -57,62 +55,6 @@ class Estimates:
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values())) return Estimates(flops, lds, sum(mem.values()))
@dataclass
class ProgramSpec:
name:str
src:str
device:str
ast:UOp # save the base ast (this is method cache key)
uops:list[UOp]|None=None
# filled in from uops (if we have uops)
global_size:list[int]|None=None
local_size:list[int]|None=None
vars:list[Variable]=field(default_factory=list)
globals:list[int]=field(default_factory=list)
outs:list[int]=field(default_factory=list)
ins:list[int]=field(default_factory=list)
_ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program
def __post_init__(self):
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (buf:=idx.src[0]).op is Ops.DEFINE_GLOBAL: (self.outs if u.op is Ops.STORE else self.ins).append(buf.arg)
# TODO: can else happen?
if u.op is Ops.SPECIAL:
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
if u.arg[0] == 'i': self.local_size = None
special_size = self.local_size if u.arg[0] == 'l' else self.global_size
# TODO: this cast is wrong, u.src[0].ssimplify() can be sint
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
self.vars = sorted(self.vars, key=lambda v: v.arg)
self.outs = sorted(dedup(self.outs))
self.ins = sorted(dedup(self.ins))
self._ran_post_init = True
@functools.cached_property
def estimates(self) -> Estimates:
return Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True)
@functools.cached_property
def function_name(self) -> str: return to_function_name(self.name)
@property
def applied_opts(self) -> tuple[Opt, ...]|None:
if self.uops is None: return None
assert self.uops[-1].op is Ops.SINK, self.uops[-1].op
return self.uops[-1].arg.applied_opts
def launch_dims(self, var_vals:dict[str, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size
class Renderer: class Renderer:
device: str = "" device: str = ""
suffix: str = "" suffix: str = ""

View file

@ -22,12 +22,14 @@ class CUDAGraph(MultiGraphRunner):
for j,ji in enumerate(jit_cache): for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, CompiledRunner): if isinstance(ji.prg, CompiledRunner):
global_size, local_size = ji.prg.p.launch_dims(var_vals) global_size, local_size = ji.prg.p.launch_dims(var_vals)
assert global_size is not None and local_size is not None
new_node = cuda.CUgraphNode() new_node = cuda.CUgraphNode()
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node) deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars]) prg_vars = ji.prg.p.variables()
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in prg_vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs) kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params))) check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

View file

@ -39,7 +39,7 @@ class HCQGraph(MultiGraphRunner):
if not isinstance(ji.prg, CompiledRunner): continue if not isinstance(ji.prg, CompiledRunner): continue
argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16)) argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, argsbuf) self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.prog_vars(), argsbuf)
# Schedule Dependencies. # Schedule Dependencies.
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
@ -159,7 +159,8 @@ class HCQGraph(MultiGraphRunner):
# Encode main commands based on ji type. # Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner): if isinstance(ji.prg, CompiledRunner):
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1))) global_size, local_size = ji.prg.p.sizes
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(global_size or (1,1,1)), tuple(local_size or (1,1,1)))
elif isinstance(ji.prg, (BufferXfer, BufferCopy)): elif isinstance(ji.prg, (BufferXfer, BufferCopy)):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
for bufid, src in enumerate(cast(list[Buffer], ji.bufs)): for bufid, src in enumerate(cast(list[Buffer], ji.bufs)):

View file

@ -42,9 +42,11 @@ class MetalGraph(GraphRunner):
if b is not None and b not in input_rawbuffers: if b is not None and b not in input_rawbuffers:
icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i) icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf) all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i) for i,v in enumerate(prg.p.variables()):
icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
global_size, local_size = prg.p.launch_dims(var_vals) global_size, local_size = prg.p.launch_dims(var_vals)
assert global_size is not None and local_size is not None
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size)) icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
icb_command.setBarrier() icb_command.setBarrier()

View file

@ -609,11 +609,17 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@staticmethod @staticmethod
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None): def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size) return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
@staticmethod
def new_program(name:str, src:str, device:str, ast:UOp, uops:list[UOp]):
"""Create a PROGRAM UOp from raw components."""
sink = ast.replace(arg=KernelInfo(name=name)) if ast.arg is None else ast.replace(arg=ast.arg.replace(name=name))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=tuple(uops)), UOp(Ops.SOURCE, arg=src)))
@property @property
def device(self) -> str|tuple[str, ...]: return unwrap(self._device) def device(self) -> str|tuple[str, ...]: return unwrap(self._device)
@recursive_property @recursive_property
def _device(self) -> str|tuple[str, ...]|None: def _device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.PROGRAM: return self.src[1].arg # PROGRAM src[1] is DEVICE
if self.op is Ops.BUFFERIZE: return self.arg.device if self.op is Ops.BUFFERIZE: return self.arg.device
if self.op is Ops.AFTER: return self.src[0]._device if self.op is Ops.AFTER: return self.src[0]._device
if self.op is Ops.MSELECT: if self.op is Ops.MSELECT:
@ -624,6 +630,104 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
for x in self.src: for x in self.src:
if x._device is not None: return x._device if x._device is not None: return x._device
return None return None
# *** PROGRAM UOp properties ***
@property
def uops(self) -> list[UOp]:
"""Linearized uops list. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"uops only valid for PROGRAM, not {self.op}"
return list(self.src[2].src)
@property
def name(self) -> str:
"""Kernel name. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"name only valid for PROGRAM, not {self.op}"
return self.src[0].arg.name
@property
def applied_opts(self):
"""Applied optimizations. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"applied_opts only valid for PROGRAM, not {self.op}"
return self.src[0].arg.applied_opts
@functools.cached_property
def estimates(self):
"""Estimates for this program. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"estimates only valid for PROGRAM, not {self.op}"
from tinygrad.renderer import Estimates
return Estimates.from_uops(self.uops, ignore_indexing=True)
@property
def globals(self) -> list[int]:
"""DEFINE_GLOBAL arg indices from linearized uops. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"globals only valid for PROGRAM, not {self.op}"
return [u.arg for u in self.src[2].src if u.op is Ops.DEFINE_GLOBAL]
@property
def outs(self) -> list[int]:
"""Buffer indices written to (STORE). Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"outs only valid for PROGRAM, not {self.op}"
ret = []
for u in self.src[2].src:
if u.op is Ops.STORE:
idx = u.src[0]
if idx.op is Ops.CAST: idx = idx.src[0]
if idx.op is Ops.INDEX and idx.src[0].op is Ops.DEFINE_GLOBAL: ret.append(idx.src[0].arg)
return sorted(set(ret))
@property
def ins(self) -> list[int]:
"""Buffer indices read from (LOAD). Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"ins only valid for PROGRAM, not {self.op}"
ret = []
for u in self.src[2].src:
if u.op is Ops.LOAD:
idx = u.src[0]
if idx.op is Ops.CAST: idx = idx.src[0]
if idx.op is Ops.INDEX and idx.src[0].op is Ops.DEFINE_GLOBAL: ret.append(idx.src[0].arg)
return sorted(set(ret))
@functools.cached_property
def sizes(self) -> tuple[list[sint]|None, list[sint]|None]:
"""Get (global_size, local_size) which may contain symbolic values. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"sizes only valid for PROGRAM, not {self.op}"
from tinygrad.device import Device
dev = self.device
assert isinstance(dev, str), f"PROGRAM device must be a string, not {type(dev)}"
ren = Device[dev].renderer
global_size:list[sint]|None = [1,1,1] if ren.has_local or ren.has_threads else None
local_size:list[sint]|None = [1,1,1] if ren.has_local else None
for u in self.src[2].src:
if u.op is Ops.SPECIAL:
if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size
if special_size is not None: special_size[int(u.arg[-1])] = cast(sint, u.src[0].ssimplify())
return global_size, local_size
def launch_dims(self, var_vals:dict[str, int]) -> tuple[list[int]|None, list[int]|None]:
"""Resolve global/local sizes to concrete ints. Only valid for PROGRAM."""
global_size, local_size = self.sizes
global_ret = [sym_infer(sz, var_vals) for sz in global_size] if global_size is not None else None
local_ret = [sym_infer(sz, var_vals) for sz in local_size] if local_size is not None else None
return global_ret, local_ret
@property
def global_size(self) -> list[sint]|None:
"""Global size (may be symbolic). Only valid for PROGRAM."""
return self.sizes[0]
@property
def local_size(self) -> list[sint]|None:
"""Local size (may be symbolic). Only valid for PROGRAM."""
return self.sizes[1]
def prog_vars(self) -> list:
"""Variables list for this program. Only valid for PROGRAM."""
assert self.op is Ops.PROGRAM, f"prog_vars only valid for PROGRAM, not {self.op}"
# Get variables from the linearized uops
linear_uops = self.src[2]
return linear_uops.variables()
@property @property
def buf_uop(self) -> UOp: def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self if self.op is Ops.BUFFER: return self

View file

@ -249,10 +249,10 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device # in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True), (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline # codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR, SOURCE)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True), (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True), (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True), (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
# codegen: standalone LINEAR/SOURCE # codegen: standalone LINEAR/SOURCE
(UPat(Ops.LINEAR, dtypes.void), lambda: True), (UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True), (UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),

View file

@ -10,7 +10,7 @@ from tinygrad.helpers import printable, TCPServerWithReuse, HTTPRequestHandler
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start, multirange_str from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent
from tinygrad.renderer import ProgramSpec from tinygrad.renderer import Estimates
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
@ -39,7 +39,7 @@ def get_rewrites(t:RewriteTrace) -> list[dict]:
for i,(k,v) in enumerate(zip(t.keys, t.rewrites)): for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc), steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)] trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)]
if isinstance(k.ret, ProgramSpec): if isinstance(k.ret, UOp) and k.ret.op is Ops.PROGRAM:
steps.append(create_step("View UOp List", ("/uops", i, len(steps)), k.ret)) steps.append(create_step("View UOp List", ("/uops", i, len(steps)), k.ret))
steps.append(create_step("View Program", ("/code", i, len(steps)), k.ret)) steps.append(create_step("View Program", ("/code", i, len(steps)), k.ret))
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), k.ret)) steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), k.ret))
@ -161,9 +161,10 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
name, fmt, key = e.name, [], None name, fmt, key = e.name, [], None
if (ref:=ref_map.get(name)) is not None: if (ref:=ref_map.get(name)) is not None:
name = ctxs[ref]["name"] name = ctxs[ref]["name"]
if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None: if isinstance(p:=trace.keys[ref].ret, UOp) and p.op is Ops.PROGRAM and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
flops = sym_infer(p.estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6) estimates = Estimates.from_uops(list(p.src[2].src), ignore_indexing=True)
membw, ldsbw = sym_infer(p.estimates.mem, var_vals)/t, sym_infer(p.estimates.lds, var_vals)/t flops = sym_infer(estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
membw, ldsbw = sym_infer(estimates.mem, var_vals)/t, sym_infer(estimates.lds, var_vals)/t
fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS", fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS",
(f"{membw*1e-9:.0f} GB/s" if membw < 1e13 else f"{membw*1e-12:.0f} TB/s")+" mem", (f"{membw*1e-9:.0f} GB/s" if membw < 1e13 else f"{membw*1e-12:.0f} TB/s")+" mem",
(f"{ldsbw*1e-9:.0f} GB/s" if ldsbw < 1e15 else f"{ldsbw*1e-12:.0f} TB/s")+" lds"] (f"{ldsbw*1e-9:.0f} GB/s" if ldsbw < 1e15 else f"{ldsbw*1e-12:.0f} TB/s")+" lds"]
@ -425,10 +426,12 @@ def get_render(i:int, j:int, fmt:str) -> dict:
data = ctxs[i]["steps"][j]["data"] data = ctxs[i]["steps"][j]["data"]
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"} if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"} if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
if fmt == "code": return {"src":data.src, "lang":"cpp"} # PROGRAM UOp: src[3].arg is source code
source_code = data.src[3].arg
if fmt == "code": return {"src":source_code, "lang":"cpp"}
if fmt == "asm": if fmt == "asm":
compiler = Device[data.device].compiler compiler = Device[data.device].compiler
disasm_str = get_stdout(lambda: compiler.disassemble(compiler.compile(data.src))) disasm_str = get_stdout(lambda: compiler.disassemble(compiler.compile(source_code)))
ret:dict = {"src":disasm_str} ret:dict = {"src":disasm_str}
if data.device.startswith("AMD"): if data.device.startswith("AMD"):
with soft_err(lambda err: ret.update(err)): with soft_err(lambda err: ret.update(err)):