viz profiler (#8287)

* only hcq

* fix get_metadata

* linter

* oops

* tiny

* linter

* time

* print pm

* hmm

* nits
This commit is contained in:
nimlgen 2024-12-17 20:00:53 +03:00 committed by GitHub
commit af87e4b53c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 500 additions and 301 deletions

View file

@ -300,8 +300,6 @@ class TestHCQ(unittest.TestCase):
# Test profile api
def test_speed_exec_time(self):
TestHCQ.d0._prof_setup()
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
TestHCQ.d0.hw_compute_queue_t().timestamp(sig_st) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
@ -311,7 +309,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True)
et = float(sig_en.timestamp - sig_st.timestamp)
print(f"exec kernel time: {et:.2f} us")
assert 0.1 <= et <= (7000 if CI else 100)
@ -319,8 +317,6 @@ class TestHCQ(unittest.TestCase):
def test_speed_copy_bandwidth(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
TestHCQ.d0._prof_setup()
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
SZ = 200_000_000
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
@ -335,7 +331,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True)
et = float(sig_en.timestamp - sig_st.timestamp)
et_ms = et / 1e3
gb_s = ((SZ / 1e9) / et_ms) * 1e3
@ -348,8 +344,6 @@ class TestHCQ(unittest.TestCase):
try: _ = Device[f"{Device.DEFAULT}:1"]
except Exception: self.skipTest("no multidevice, test skipped")
TestHCQ.d0._prof_setup()
SZ = 200_000_000
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
@ -364,7 +358,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True)
et = float(sig_en.timestamp - sig_st.timestamp)
et_ms = et / 1e3
gb_s = ((SZ / 1e9) / et_ms) * 1e3

View file

@ -1,73 +1,30 @@
import unittest, struct, contextlib, tempfile, pathlib, json, time, atexit, random
import unittest, struct, contextlib, statistics
from tinygrad import Device, Tensor, dtypes, TinyJit
from tinygrad.helpers import CI, getenv, Context
from tinygrad.device import Buffer, BufferSpec
from tinygrad.runtime.support.hcq import ProfileLogger, HCQCompiled
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileRangeEvent, ProfileDeviceEvent, ProfileGraphEvent
from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner
MOCKGPU = getenv("MOCKGPU")
@contextlib.contextmanager
def helper_collect_profile(*devs, random_setup_delay=False):
ProfileLogger.mjson, ProfileLogger.actors = [], {}
def helper_collect_profile(*devs):
Compiled.profile_events = []
if random_setup_delay:
devs = list(devs)
for dev in devs: dev.synchronize()
random.shuffle(devs)
for dev in devs:
dev._prof_setup()
time.sleep(random.randint(1, 1000) / 1000)
else:
for dev in devs: dev._prof_setup()
profile_dict = {}
_, tmp = tempfile.mkstemp()
with Context(PROFILE=1, PROFILEPATH=tmp):
try: yield profile_dict
profile_list = []
with Context(PROFILE=1):
try: yield profile_list
finally:
for dev in devs:
dev.synchronize()
dev._prof_finalize()
atexit.unregister(dev._prof_finalize)
for dev in devs: dev.synchronize()
for dev in devs: dev._at_profile_finalize()
for x in Compiled.profile_events: profile_list.append(x)
for k,v in json.loads(pathlib.Path(tmp).read_text()).items(): profile_dict[k] = v
pathlib.Path(tmp).unlink()
def helper_profile_filter_node(profile, **kwargs):
  """Return the trace events of `profile` that match every keyword filter.

  `profile` is a mapping holding a 'traceEvents' list of dict events. An event
  is kept when, for each `key=value` in `kwargs`, `event.get(key)` equals
  `value`; a key the event does not have is read as None. With no filters every
  event is returned.

  Raises AssertionError if `profile` is empty or has no 'traceEvents' entry.
  """
  assert len(profile) > 0, "Empty profile"
  assert 'traceEvents' in profile, "traceEvents should be present"
  return [x for x in profile['traceEvents'] if all(x.get(k, None) == v for k,v in kwargs.items())]
def helper_profile_parse_pids(profile):
  """Build id -> name lookup tables from the metadata events of a trace.

  Returns a `(pids, tids)` pair: `pids` maps each 'process_name' event's pid to
  its `args.name`, and `tids` maps each 'thread_name' event's tid to its
  `args.name`. When an id appears more than once the last event wins.
  """
  pid_names = {ev['pid']: ev['args']['name'] for ev in helper_profile_filter_node(profile, name='process_name')}
  tid_names = {ev['tid']: ev['args']['name'] for ev in helper_profile_filter_node(profile, name='thread_name')}
  return pid_names, tid_names
# Resolve the dependency arrows of a trace into the duration events they connect.
# For every flow-start event (ph="s") this finds the flow-finish event (ph="f") with the same id,
# then the single duration event (ph="X") containing each of them.
# Returns a list of (start_flow, finish_flow, source_event, destination_event) tuples.
# Raises AssertionError unless exactly one duration event contains each end of a flow.
def helper_profile_parse_deps(profile):
deps = []
for s in helper_profile_filter_node(profile, ph="s"):
# the finish event is paired with its start through the shared id; the first match is taken
f = helper_profile_filter_node(profile, ph="f", id=s['id'])[0]
starts, ends = [], []
for x in helper_profile_filter_node(profile, ph="X"):
# a duration event "contains" a flow event when pid and tid match and the flow ts lies in [ts, ts + dur]
if s['pid'] == x['pid'] and s['tid'] == x['tid'] and x['ts'] <= s['ts'] <= x['ts'] + x['dur']: starts.append(x)
if f['pid'] == x['pid'] and f['tid'] == x['tid'] and x['ts'] <= f['ts'] <= x['ts'] + x['dur']: ends.append(x)
# zero or several containing events on either side make the dependency ambiguous
assert len(starts) == 1 and len(ends) == 1, "more than one start and end possible, valid?"
deps.append((s, f, starts[0], ends[0]))
return deps
def helper_validate_node(node, duration_s=10, ts_age_s=30, profile=None, pid_name=None, tid_name=None):
  """Sanity-check one duration event of a trace.

  Args:
    node: event dict with 'ts' and 'dur' in microseconds (plus 'pid'/'tid' when names are checked).
    duration_s: exclusive upper bound for the event duration, in seconds.
    ts_age_s: maximum allowed distance between 'ts' and the current `time.perf_counter_ns()` clock, in seconds.
    profile: trace used to resolve pid/tid names; required only when `pid_name` or `tid_name` is given.
    pid_name: expected process name of the event, or None to skip the check.
    tid_name: expected thread name of the event, or None to skip the check.

  Raises AssertionError when any check fails.
  """
  if profile is None:
    # without a trace there is nothing to resolve names against; the timing checks still run
    assert pid_name is None and tid_name is None, "profile is required to check pid_name/tid_name"
    pids, tids = {}, {}
  else: pids, tids = helper_profile_parse_pids(profile)
  # both sides are in microseconds: perf_counter_ns / 1e3 and seconds * 1e6
  assert abs(node['ts'] - time.perf_counter_ns() / 1e3) < ts_age_s * 1e6, f"timestamp is not in {ts_age_s}s range"
  assert 0 < node['dur'] < duration_s * 1e6, f"duration is not in {duration_s}s range"
  assert pid_name is None or pids[node['pid']] == pid_name
  assert tid_name is None or tids[node['tid']] == tid_name
def helper_profile_filter_device(profile, device:str):
  """Split out the events recorded for `device` together with its registration event.

  Args:
    profile: list of profile events; events without a `device` attribute are ignored.
    device: device name to select.

  Returns:
    `(events, registration)`, where `events` is every event whose `device` attribute
    equals `device` (the registration event included) and `registration` is the single
    ProfileDeviceEvent among them.

  Raises AssertionError if the device has no registration event or more than one.
  """
  dev_profile = [x for x in profile if getattr(x, "device", None) == device]
  # registration events are looked up once and reused for both checks
  dev_events = [x for x in dev_profile if isinstance(x, ProfileDeviceEvent)]
  assert len(dev_events) > 0, f"device {device} is not registered"
  assert len(dev_events) == 1, "only one device registration event is expected"
  return dev_profile, dev_events[0]
@unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "HCQ device required to run")
class TestProfiler(unittest.TestCase):
@ -90,8 +47,11 @@ class TestProfiler(unittest.TestCase):
with helper_collect_profile(TestProfiler.d0) as profile:
TestProfiler.runner([TestProfiler.b.lazydata.buffer, TestProfiler.a.lazydata.buffer], var_vals={})
kernel_node = helper_profile_filter_node(profile, name=runner_name)[0]
helper_validate_node(kernel_node, profile=profile, pid_name=Device.DEFAULT, tid_name="COMPUTE")
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
assert len(kernel_runs) == 1, "one kernel run is expected"
assert kernel_runs[0].name == runner_name, "kernel name is not correct"
assert not kernel_runs[0].is_copy, "kernel should not be copy"
def test_profile_copyin(self):
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
@ -99,8 +59,10 @@ class TestProfiler(unittest.TestCase):
with helper_collect_profile(TestProfiler.d0) as profile:
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
copyin_node = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0]
helper_validate_node(copyin_node, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
assert len(kernel_runs) == 1, "one kernel run is expected"
assert kernel_runs[0].is_copy, "kernel should not be copy"
def test_profile_multiops(self):
runner_name = TestProfiler.runner._prg.name
@ -111,19 +73,19 @@ class TestProfiler(unittest.TestCase):
TestProfiler.runner([buf1, TestProfiler.a.lazydata.buffer], var_vals={})
buf1.as_buffer()
copyin_node = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0]
helper_validate_node(copyin_node, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
evs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
kernel_node = helper_profile_filter_node(profile, name=runner_name)[0]
helper_validate_node(kernel_node, profile=profile, pid_name=Device.DEFAULT, tid_name="COMPUTE")
assert len(evs) == 3, "two kernel runs are expected"
assert evs[0].is_copy, "kernel should be copy"
assert evs[1].name == runner_name, "kernel name is not correct"
assert not evs[1].is_copy, "kernel should not be copy"
assert evs[2].is_copy, "kernel should be copy"
copyout_node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> CPU")[0]
helper_validate_node(copyout_node, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
for i in range(1, 3):
assert evs[i].st > evs[i-1].en, "timestamp not aranged"
assert copyin_node['ts'] + copyin_node['dur'] < kernel_node['ts'], "timestamp not aranged"
assert kernel_node['ts'] + kernel_node['dur'] < copyout_node['ts'], "timestamp not aranged"
def test_profile_multidev_copyin(self):
def test_profile_multidev(self):
d1 = Device[f"{Device.DEFAULT}:1"]
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(f"{Device.DEFAULT}:1", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
@ -132,25 +94,16 @@ class TestProfiler(unittest.TestCase):
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
buf2.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
copyin_node_1 = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0]
helper_validate_node(copyin_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
profile0, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
profile1, _ = helper_profile_filter_device(profile, d1.device)
copyin_node_2 = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}:1")[0]
helper_validate_node(copyin_node_2, profile=profile, pid_name=f"{Device.DEFAULT}:1", tid_name="DMA")
def test_profile_multidev_transfer(self):
d1 = Device[f"{Device.DEFAULT}:1"]
a = Tensor.randn(1 << 20, device=Device.DEFAULT).realize()
with helper_collect_profile(TestProfiler.d0, d1) as profile:
y = a.to(f"{Device.DEFAULT}:1")
y.realize()
transfer_node_1 = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[0]
helper_validate_node(transfer_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
assert 80 < transfer_node_1['dur'] < (20000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}"
for p in [profile0, profile1]:
evs = [x for x in p if isinstance(x, ProfileRangeEvent)]
assert len(evs) == 1, "one kernel runs are expected"
assert evs[0].is_copy, "kernel should be copy"
@unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
def test_profile_deps(self):
def test_profile_graph(self):
d1 = Device[f"{Device.DEFAULT}:1"]
def f(a):
@ -163,59 +116,40 @@ class TestProfiler(unittest.TestCase):
for _ in range(3): jf(a)
del jf
deps = helper_profile_parse_deps(profile)
assert len(deps) == 1, "one dep is expected, one launch"
graph_evs = [x for x in profile if isinstance(x, ProfileGraphEvent)]
_, _, l, r = deps[0]
assert l['name'].find("->") == -1, "should be kernel"
assert r['name'] == f"{Device.DEFAULT} -> {Device.DEFAULT}:1", "should be copy"
_, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
_, _ = helper_profile_filter_device(profile, d1.device)
@unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
def test_profile_copy_args(self):
d1 = Device[f"{Device.DEFAULT}:1"]
def f(a):
x = (a + 1).realize()
return x, x.to(d1.device).realize()
a = Tensor.randn(10, 10, device=TestProfiler.d0.device).realize()
with helper_collect_profile(TestProfiler.d0, d1) as profile:
jf = TinyJit(f)
for _ in range(3):
TestProfiler.d0.raw_prof_records, TestProfiler.d0.sig_prof_records = [], [] # reset to collect only graph logs
d1.raw_prof_records, d1.sig_prof_records = [], []
jf(a)
del jf
node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[-1]
assert node['args']['Size'] == "400.00 B"
assert abs(float(node['args']['GB/S']) - ((10 * 10 * 4) / 1e3) / (node['dur'])) < 0.01
assert len(graph_evs) == 1, "one graph event is expected"
assert len(graph_evs[0].ents) == 2, "two entities are expected"
@unittest.skipIf(CI, "skip CI")
def test_profile_sync(self):
mv = memoryview(bytearray(struct.pack("ff", 0, 1)))
expected_diff = 100000 # sleep in us
def test_dev_jitter_matrix(self):
dev_cnt = 6
devs = [Device[f"{Device.DEFAULT}:{i}"] for i in range(dev_cnt)]
for dev in devs: dev.synchronize()
for dev in devs: dev._at_profile_finalize()
devs = [Device[f"{Device.DEFAULT}:{i}"] for i in range(6)]
bufs = [Buffer(f"{Device.DEFAULT}:{i}", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated() for i in range(6)]
def _sync_d2d(d1:HCQCompiled, d2:HCQCompiled):
d1.hw_compute_queue_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
d2.hw_compute_queue_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
.timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
d1.timeline_value += 2
d2.timeline_value += 2
d1.timeline_signal.wait(d1.timeline_value - 1)
d2.timeline_signal.wait(d2.timeline_value - 1)
return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
# enqueue ops on different queues to check the timer sync
cpu_time = []
with helper_collect_profile(*devs, random_setup_delay=True) as profile:
for i in range(6):
x = time.perf_counter_ns()
time.sleep(expected_diff / 1e6)
bufs[i].copyin(mv)
cpu_time.append(((time.perf_counter_ns() - x) / 1000) - expected_diff)
nodes = [helper_profile_filter_node(profile, name=f"CPU -> {Device.canonicalize(f'{Device.DEFAULT}:{i}')}")[-1] for i in range(6)]
avg_diff = []
for i in range(1, 6):
diff = nodes[i]['ts'] - nodes[i-1]['ts'] - cpu_time[i]
avg_diff.append(diff - expected_diff)
assert expected_diff * 0.998 < diff < expected_diff * 1.002, "more that 0.2% diff"
print(f"total avg delay is {sum(avg_diff) / len(avg_diff)} us")
# then test it by timing the GPU to GPU times
jitter_matrix = [[float('nan')] * len(devs) for _ in range(len(devs))]
pairs = [(p1, p2) for p1 in enumerate(devs) for p2 in enumerate(devs) if p1 != p2]
for (i1, d1), (i2, d2) in pairs:
cpu_diff = d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff
jitter_matrix[i1][i2] = statistics.median(_sync_d2d(d1, d2) - _sync_d2d(d2, d1) for _ in range(20)) / 2 - cpu_diff
assert abs(jitter_matrix[i1][i2]) < 0.5, "jitter should be less than 0.5ms"
print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
if __name__ == "__main__":
unittest.main()
unittest.main()

View file

@ -1,9 +1,10 @@
from typing import Dict, List, Optional
import unittest
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
@track_rewrites(named=True)
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs)
@ -112,5 +113,77 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(ret), 1)
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
# Checks the JSON produced by to_perfetto for hand-built profile event lists.
# Each test parses the returned string with json.loads and inspects the 'traceEvents' list by position.
class TextVizProfiler(unittest.TestCase):
# A compute range event: expects three metadata events for the device, then the range on tid 0.
def test_perfetto_node(self):
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
j = json.loads(to_perfetto(prof))
# Device regs always first
# (the device event is listed after the range event in `prof`, yet its metadata is expected at the front)
self.assertEqual(j['traceEvents'][0]['name'], 'process_name')
self.assertEqual(j['traceEvents'][0]['ph'], 'M')
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
self.assertEqual(j['traceEvents'][1]['name'], 'thread_name')
self.assertEqual(j['traceEvents'][1]['ph'], 'M')
self.assertEqual(j['traceEvents'][1]['pid'], j['traceEvents'][0]['pid'])
self.assertEqual(j['traceEvents'][1]['tid'], 0)
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
self.assertEqual(j['traceEvents'][2]['name'], 'thread_name')
self.assertEqual(j['traceEvents'][2]['ph'], 'M')
self.assertEqual(j['traceEvents'][2]['pid'], j['traceEvents'][0]['pid'])
self.assertEqual(j['traceEvents'][2]['tid'], 1)
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
# expected ts is st + comp_tdiff = 1000 - 1000, expected dur is en - st
self.assertEqual(j['traceEvents'][3]['name'], 'E_2')
self.assertEqual(j['traceEvents'][3]['ts'], 0)
self.assertEqual(j['traceEvents'][3]['dur'], 10)
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
self.assertEqual(j['traceEvents'][3]['pid'], j['traceEvents'][0]['pid'])
self.assertEqual(j['traceEvents'][3]['tid'], 0)
# A copy range event: expects the copy offset to be applied and the event to land on tid 1.
def test_perfetto_copy_node(self):
prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
j = json.loads(to_perfetto(prof))
self.assertEqual(j['traceEvents'][3]['name'], 'COPYxx')
self.assertEqual(j['traceEvents'][3]['ts'], 900) # diff clock
self.assertEqual(j['traceEvents'][3]['dur'], 10)
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
self.assertEqual(j['traceEvents'][3]['tid'], 1)
# A graph event spanning two devices: entries refer to timestamps in `sigs` through st_id/en_id.
def test_perfetto_graph(self):
prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
ProfileDeviceEvent(device='NV:1', comp_tdiff=decimal.Decimal(-500), copy_tdiff=decimal.Decimal(-50)),
ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1, is_copy=False),
ProfileGraphEntry(device='NV:1', name='NV -> NV:1', st_id=2, en_id=3, is_copy=True)],
deps=[[], [0]],
sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
j = json.loads(to_perfetto(prof))
# Device regs always first
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
self.assertEqual(j['traceEvents'][3]['args']['name'], 'NV:1')
self.assertEqual(j['traceEvents'][4]['args']['name'], 'COMPUTE')
self.assertEqual(j['traceEvents'][5]['args']['name'], 'COPY')
# compute entry on NV: sigs[0] + comp_tdiff = 1000 - 1000, dur = sigs[1] - sigs[0]
self.assertEqual(j['traceEvents'][6]['name'], 'E_25_4n2')
self.assertEqual(j['traceEvents'][6]['ts'], 0)
self.assertEqual(j['traceEvents'][6]['dur'], 2)
self.assertEqual(j['traceEvents'][6]['pid'], j['traceEvents'][0]['pid'])
# copy entry on NV:1: sigs[2] + that device's copy_tdiff = 1004 - 50, dur = sigs[3] - sigs[2]
self.assertEqual(j['traceEvents'][7]['name'], 'NV -> NV:1')
self.assertEqual(j['traceEvents'][7]['ts'], 954)
self.assertEqual(j['traceEvents'][7]['dur'], 4)
self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])
if __name__ == "__main__":
unittest.main()

View file

@ -1,9 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Optional, Dict, Tuple, Any, Iterator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys, re
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
from typing import Optional, Dict, Tuple, Any, Iterator, List, Set
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys, re, atexit, pickle, decimal
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
from tinygrad.renderer import Renderer
from tinygrad.ops import UOp, buffers
@ -13,6 +13,7 @@ from tinygrad.ops import UOp, buffers
class _Device:
def __init__(self) -> None:
self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
self._opened_devices:Set[str] = set()
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
@ -26,6 +27,7 @@ class _Device:
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
if (cname.lower() == x.lower() + "device")][0](ix)
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
self._opened_devices.add(ix)
return ret
@property
def default(self) -> Compiled: return self[self.DEFAULT]
@ -42,6 +44,23 @@ class _Device:
except StopIteration as exc: raise RuntimeError("no usable devices") from exc
Device = _Device()
# **************** Profile ****************
# Common base class of the records appended to Compiled.profile_events.
class ProfileEvent: pass
# Registers a device in the profile. comp_tdiff / copy_tdiff are offsets for that device's compute and copy
# timestamps; both default to 0. NOTE(review): the viz tests expect them to be added to event timestamps — confirm in to_perfetto.
@dataclass(frozen=True)
class ProfileDeviceEvent(ProfileEvent):
device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702
# One named operation on a device, running from timestamp st to timestamp en; is_copy separates copies from kernels.
@dataclass(frozen=True)
class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702
# One item of a graph. st_id / en_id are indices of its start and end timestamps in ProfileGraphEvent.sigs.
@dataclass(frozen=True)
class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:bool # noqa: E702
# One execution of a graph: its entries, one dependency list per entry, and the timestamps the entries index into.
# NOTE(review): deps[i] presumably holds indices into ents that entry i depends on — confirm against the producer.
@dataclass(frozen=True)
class ProfileGraphEvent(ProfileEvent): ents:List[ProfileGraphEntry]; deps:List[List[int]]; sigs:List[decimal.Decimal] # noqa: E702
# **************** Buffer + Allocators ****************
@ -202,6 +221,8 @@ class Compiler:
def disassemble(self, lib:bytes): pass
class Compiled:
profile_events:List[ProfileEvent] = []
def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
self.renderer = renderer or Renderer()
@ -212,6 +233,11 @@ class Compiled:
This method ensures that all previously queued operations on the device have been completed before proceeding.
"""
# override this in your device implementation
def _at_profile_finalize(self):
"""
Called at the end of profiling to allow the device to finalize any profiling.

The base implementation does nothing. In this file it is invoked from the atexit profile hook,
after `synchronize()` has been called on every opened device.
"""
# override this in your device implementation
# TODO: move this to each Device
def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
@ -232,3 +258,15 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
if device == "PYTHON": return sys.version_info >= (3, 12)
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
return True
# Only when PROFILE is set at the time this statement runs: register an exit hook that dumps the collected profile.
if PROFILE:
@atexit.register
def finlize_profile():
# Exit hook: flush all opened devices, let each one finalize its profiling, pickle the collected events
# to temp("profile.pkl") and hand that path to launch_viz.
# _opened_devices holds the names recorded when devices were opened
devs = [Device[d] for d in Device._opened_devices]
# synchronize every device before any of them finalizes
for dev in devs: dev.synchronize()
for dev in devs: dev._at_profile_finalize()
with open(temp("profile.pkl"), "wb") as f: pickle.dump(Compiled.profile_events, f)
# imported here rather than at the top of the module
from tinygrad.ops import launch_viz
launch_viz("PROFILE", temp("profile.pkl"))

View file

@ -97,8 +97,7 @@ class ContextVar:
def __lt__(self, x): return self.value < x
DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
PROFILE, PROFILEPATH = ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
WINO, CAPTURING, TRACEMETA, PROFILE = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("PROFILE", 0)
USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)

View file

@ -891,9 +891,7 @@ if TRACK_MATCH_STATS:
with open(fn:=temp("rewrites.pkl"), "wb") as f:
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
pickle.dump((tracked_keys, tracked_ctxs), f)
if getenv("VIZ"):
os.environ["VIZ"] = "0"
os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")])
launch_viz("VIZ", temp("rewrites.pkl"))
if getenv("PRINT_MATCH_STATS", 1):
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
@ -902,6 +900,14 @@ if TRACK_MATCH_STATS:
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
def launch_viz(env_str:str, data:str):
  """Record the data file of one viz producer and start the viz server once no producer is pending.

  Args:
    env_str: name of the producer's environment flag ("VIZ" or "PROFILE"); it is reset to "0".
    data: path of the file this producer wrote; stored in the `<env_str>_DATA` environment variable.

  If both VIZ and PROFILE are now 0 (or unset), the current process is replaced via `os.execv` by
  `viz/serve.py`, with `--kernels <VIZ_DATA>` and/or `--profile <PROFILE_DATA>` for every data path that
  is set; in that case this function does not return. Otherwise it returns None and leaves the launch
  to the producer that is still pending.
  """
  # mark this producer as finished and remember where its data lives
  os.environ[env_str] = "0"
  os.environ[f"{env_str}_DATA"] = data
  # the other producer has not reported yet, it will launch the server from its own call
  if int(os.environ.get("VIZ", "0")) or int(os.environ.get("PROFILE", "0")): return
  args = []
  for flag, key in (("--kernels", "VIZ_DATA"), ("--profile", "PROFILE_DATA")):
    # read straight from os.environ, where the paths were written above, instead of going through the
    # project getenv helper. NOTE(review): that helper is believed to memoize lookups — confirm.
    if (path:=os.environ.get(key, "")): args += [flag, path]
  os.execv(sys.executable, [sys.executable, os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
# *** simple graph rewrite engine ***
class RewriteContext:

View file

@ -1,8 +1,8 @@
import collections, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
from tinygrad.helpers import round_up, PROFILE
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
from tinygrad.device import Buffer, BufferSpec, Compiled, Device
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
@ -51,8 +51,11 @@ class HCQGraph(MultiGraphRunner):
self.kickoff_value: int = 0
self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
# When profiling, allocate 2 signals for each jit item to measure speed. The jth jit item has signals at 2*j and 2*j+1.
# TODO: This logic might allocate a few extra signals...
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
self.prog_graph_deps: List[List[int]] = []
self.prof_graph_entries: List[ProfileGraphEntry] = []
last_j: Dict[HWQueue, Optional[int]] = collections.defaultdict(lambda: None)
queue_access: Dict[HWQueue, Dict[HWQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
@ -102,18 +105,20 @@ class HCQGraph(MultiGraphRunner):
# Collect profile information if profiling is enabled.
if PROFILE:
# When execution are chained, we can reuse the end timestamp from the previous command as the start timestamp for the current command.
sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2
# Description based on the command.
prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
if is_exec_prg: prof_args = None
else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
self.prog_graph_deps.append([d - 1 for _, d in rdeps])
last_j[enqueue_queue] = j
# Check which signals are used in the profile graph.
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]
# Build hardware queues.
self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
@ -132,7 +137,7 @@ class HCQGraph(MultiGraphRunner):
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
# Encode waits and start profile timestamp (if needed).
if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])
if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
@ -145,7 +150,7 @@ class HCQGraph(MultiGraphRunner):
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
# Encode finish profile timestamp (if needed).
if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])
if signal_val is not None: enqueue_queue.signal(signal, signal_val)
@ -189,14 +194,8 @@ class HCQGraph(MultiGraphRunner):
return None
def collect_timestamps(self):
timestamps = [s.timestamp for s in self.prof_signals]
for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
for x in deps:
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
# NOTE: Append to any device is fine...
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]
def __del__(self):
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

View file

@ -1,9 +1,9 @@
from __future__ import annotations
from typing import List, Optional, Dict, Tuple, cast, Type, Union, TypeVar, Generic, Any
import contextlib, decimal, statistics, random, json, atexit, time, ctypes, array
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv, round_up
from typing import List, Optional, Dict, Tuple, cast, Type, TypeVar, Generic, Any
import contextlib, decimal, statistics, time, ctypes, array
from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
from tinygrad.ops import sym_infer, sint, Variable
# **************** for HCQ Compatible Devices ****************
@ -294,51 +294,11 @@ class HCQProgram(Generic[DeviceType]):
if wait: self.dev.synchronize()
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
class ProfileLogger:
  """
  Collects profiling events and dependencies and turns them into perfetto trace records.

  `writers`, `mjson` and `actors` are class attributes shared by every logger. Deleting a logger converts its
  events and deps into `mjson` records; once the writer count drops to zero and there is something to save,
  the records are written to `PROFILEPATH`.
  """
  writers: int = 0  # incremented by __init__, decremented by __del__
  mjson: List[Dict] = []  # perfetto records produced so far
  actors: Dict[Union[str, Tuple[str, str]], int] = {}  # actor name -> pid, (actor name, subactor name) -> tid

  def __init__(self):
    self.events, self.deps = [], []
    ProfileLogger.writers += 1

  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None):
    """Queue one event; it is converted to a perfetto record when the logger is deleted."""
    self.events.append((ev_name, ev_start, ev_end, actor, subactor, args))

  def _ensure_actor(self, actor_name, subactor_name):
    """Return the (pid, tid) of an actor/subactor pair, registering whichever of the two is not known yet."""
    registry, subactor_key = self.actors, (actor_name, subactor_name)
    if actor_name not in registry:
      # ids are handed out sequentially from the shared registry size
      pid = registry[actor_name] = len(registry)
      self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
    if subactor_key not in registry:
      tid = registry[subactor_key] = len(registry)
      self.mjson.append({"name": "thread_name", "ph": "M", "pid": registry[actor_name], "tid":tid, "args": {"name": subactor_name}})
    return registry[actor_name], registry.get(subactor_key, -1)

  def __del__(self):
    # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
    for name, st, et, actor_name, subactor_name, args in self.events:
      pid, tid = self._ensure_actor(actor_name, subactor_name)
      if args is not None:
        # string args are kept as is, any other arg is called with the event duration
        args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()}
      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})

    for en, st, dep_actor_name, dep_subactor_name, actor_name, subactor_name in self.deps:
      dep_pid, dep_tid = self._ensure_actor(dep_actor_name, dep_subactor_name)
      pid, tid = self._ensure_actor(actor_name, subactor_name)
      # both halves of a flow share one id: the position of its "s" record in mjson
      flow_id = len(self.mjson)
      self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": flow_id, "ts": en, "bp": "e"})
      self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": flow_id, "ts": st, "bp": "e"})

    ProfileLogger.writers -= 1
    if ProfileLogger.writers == 0 and len(self.mjson) > 0:
      with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
      print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
class HCQCompiled(Compiled, Generic[SignalType]):
"""
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
"""
devices: List[HCQCompiled] = []
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]):
@ -350,7 +310,6 @@ class HCQCompiled(Compiled, Generic[SignalType]):
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
if PROFILE: self._prof_setup()
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
@ -367,13 +326,11 @@ class HCQCompiled(Compiled, Generic[SignalType]):
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
if PROFILE:
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
self.sig_prof_records = []
def _ensure_shared_time_base(self):
if not self.gpu2cpu_compute_time_diff.is_nan(): return
def _sync_cpu_queue(d:HCQCompiled, q_t:Type[HWQueue]):
def _at_profile_finalize(self):
def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
d.timeline_value += 1
st = time.perf_counter_ns()
@ -381,65 +338,10 @@ class HCQCompiled(Compiled, Generic[SignalType]):
et = time.perf_counter_ns()
return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
# randomly sample the timing from GPU to CPU
choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
for _ in range(100*len(self.devices)):
d,q,l = random.choice(choices)
l.append(_sync_cpu_queue(d,q))
for d,q,l in choices:
if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
def _sync_gpu_to_gpu_queue(d1:HCQCompiled, d2:HCQCompiled, q1_t:Type[HWQueue], q2_t:Type[HWQueue]):
q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
.timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
d1.timeline_value += 2
d2.timeline_value += 2
d1.timeline_signal.wait(d1.timeline_value - 1)
d2.timeline_signal.wait(d2.timeline_value - 1)
return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
# then test it by timing the GPU to GPU times
jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
for i1, d1 in enumerate(self.devices):
for i2, d2 in enumerate(self.devices):
if d1 == d2: continue
d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
_sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
  """
  Translates local gpu time (timestamp) into global cpu time.

  `_ensure_shared_time_base` is called first, then the copy or compute time diff (picked by `is_copy`) is added to `gpu_time`.
  """
  self._ensure_shared_time_base()
  if is_copy: offset = self.gpu2cpu_copy_time_diff
  else: offset = self.gpu2cpu_compute_time_diff
  return float(gpu_time + offset)
def _prof_setup(self):
  # Only the first call does anything: it registers `_prof_finalize` to run at interpreter exit
  # and creates the logger whose presence marks the setup as done.
  if not hasattr(self, 'profile_logger'):
    atexit.register(self._prof_finalize)
    self.profile_logger = ProfileLogger()
def _prof_finalize(self):
qname = ["COMPUTE", "DMA"]
# Sync to be sure all events on the device are recorded.
self.synchronize()
for st, en, name, is_cp, args in self.raw_prof_records:
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.device, qname[is_cp], args)]
for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
# Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
self.profile_logger.deps += [(a_tm, b_tm, a_dev.device, qname[a_is_copy], b_dev.device, qname[b_is_copy])]
self.raw_prof_records, self.dep_prof_records = [], []
# Remove the logger, this flushes all data written by the device.
del self.profile_logger
gpu2cpu_compute_time_diff = statistics.median([_sync(self, self.hw_compute_queue_t) for _ in range(40)])
if self.hw_copy_queue_t is None: gpu2cpu_copy_time_diff = decimal.Decimal(0)
else: gpu2cpu_copy_time_diff = statistics.median([_sync(self, self.hw_copy_queue_t) for _ in range(40)])
Compiled.profile_events += [ProfileDeviceEvent(self.device, gpu2cpu_compute_time_diff, gpu2cpu_copy_time_diff)]
def _wrap_timeline_signal(self):
  # Swap the active and shadow timeline signals and restart the timeline value at 1.
  active, shadow = self.timeline_signal, self._shadow_timeline_signal
  self.timeline_signal = shadow
  self._shadow_timeline_signal = active
  self.timeline_value = 1

View file

@ -130,19 +130,46 @@
#metadata-resize-handle {
left: 0;
}
.collapse-btn {
.floating-container {
position: fixed;
top: 10px;
left: 20px;
z-index: 4;
display: flex;
flex-direction: row;
gap: 8px;
}
.nav-btn {
background-color: #1a1b26;
border: 1px solid #4a4b56;
color: #f0f0f5;
width: 32px;
height: 32px;
padding: 6px;
border-radius: 8px;
cursor: pointer;
z-index: 4;
text-decoration: none;
display: flex;
align-items: center;
padding: 0 6px;
font-weight: bold;
}
.collapse-btn {
width: 32px;
padding: 6px;
}
.btn {
height: 32px;
background-color: #1a1b26;
border: 1px solid #4a4b56;
color: #f0f0f5;
border-radius: 8px;
cursor: pointer;
transition-duration: .5s;
}
.btn:hover {
background-color: #2a2b36;
border-color: #5a5b66;
color: #ffffff;
transform: translateY(-1px);
}
.collapsed .kernel-list, .collapsed .metadata {
width: 0;
@ -170,9 +197,12 @@
</head>
<body>
<div class="main-container">
<button class="collapse-btn">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M15 19l-7-7 7-7"/></svg>
</button>
<div class="floating-container">
<button class="btn collapse-btn">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M15 19l-7-7 7-7"/></svg>
</button>
<a class="btn nav-btn" href="/profiler">Profiler</a>
</div>
<div class="container kernel-list-parent"><div class="container kernel-list"></div></div>
<div class="graph">
<svg id="graph-svg">

178
tinygrad/viz/perfetto.html Normal file
View file

@ -0,0 +1,178 @@
<!DOCTYPE html>
<html lang="en-us">
<head>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body, html {
width: 100%;
height: 100%;
overflow: hidden;
font-family: "Noto Sans", sans-serif;
font-optical-sizing: auto;
font-weight: 400;
font-style: normal;
font-variation-settings: "wdth" 100;
font-size: 14px;
}
header {
width: 100%;
height: 48px;
background: #0f1018;
}
#perfetto_frame {
  width: 100vw;
  /* The header above the frame is 48px tall. Subtract exactly that so header + frame fill the viewport;
     with a smaller value the bottom of the frame is cut off, because html/body use overflow: hidden. */
  height: calc(100vh - 48px);
  border: none;
}
#loading_screen {
position: fixed;
top: 0;
left: 0;
width: 100vw;
height: 100vh;
background: #0f1018;
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
color: white;
font-family: Arial, sans-serif;
z-index: 1;
}
.loader {
width: 48px;
height: 48px;
border: 5px solid #FFF;
border-bottom-color: transparent;
border-radius: 50%;
margin-bottom: 20px;
animation: rotation 1s linear infinite;
}
@keyframes rotation {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
#loading_text {
font-size: 18px;
}
.floating-container {
position: fixed;
top: 8px;
left: 16px;
z-index: 4;
display: flex;
flex-direction: row;
gap: 8px;
}
.nav-btn {
background-color: #1a1b26;
border: 1px solid #4a4b56;
color: #f0f0f5;
height: 32px;
border-radius: 8px;
cursor: pointer;
text-decoration: none;
display: flex;
align-items: center;
padding: 0 6px;
font-weight: bold;
}
.btn {
height: 32px;
background-color: #1a1b26;
border: 1px solid #4a4b56;
color: #f0f0f5;
border-radius: 8px;
cursor: pointer;
transition-duration: .5s;
}
.btn:hover {
background-color: #2a2b36;
border-color: #5a5b66;
color: #ffffff;
transform: translateY(-1px);
}
</style>
</head>
<body>
<header></header>
<div class="floating-container">
<a class="btn nav-btn" href="/">UOps</a>
</div>
<div id="loading_screen">
<div class="loader" id="spinner"></div>
<div id="loading_text">Loading trace data...</div>
</div>
<iframe id="perfetto_frame" src="https://ui.perfetto.dev" allow="clipboard-write"></iframe>
<script type="text/javascript">
const ORIGIN = 'https://ui.perfetto.dev';
const API_ENDPOINT = '/get_profile';
const iframe = document.getElementById('perfetto_frame');
const loadingScreen = document.getElementById('loading_screen');
// Download the profile from the VIZ server and hand it to the embedded Perfetto UI.
// On any failure the spinner is removed and an error message is shown instead.
async function fetchFromApi() {
  try {
    const resp = await fetch(API_ENDPOINT);
    if (!resp.ok) throw new Error(`HTTP error! status: ${resp.status}`);

    // Parse the body as JSON, serialize it again and wrap the text in a Blob to get its bytes as an ArrayBuffer.
    const traceText = JSON.stringify(await resp.json());
    const traceBytes = await new Blob([traceText], { type: 'application/json' }).arrayBuffer();
    openTrace(traceBytes);
  } catch (err) {
    document.getElementById('spinner').remove();
    document.getElementById('loading_text').innerHTML = 'Error loading trace data.<br>Please ensure PROFILE=1 is set and the VIZ server is running.';
    console.error('Error loading trace:', err);
  }
}
// Send the trace in `arrayBuffer` to the Perfetto UI running inside the iframe.
// The frame is pinged every 50ms; once it answers 'PONG' the loading screen is faded out
// and the trace is posted to it.
function openTrace(arrayBuffer) {
  const timer = setInterval(() => iframe.contentWindow.postMessage('PING', ORIGIN), 50);

  const onMessageHandler = (evt) => {
    // Any window may post messages here, so only a reply coming from the Perfetto origin counts as the handshake.
    if (evt.origin !== ORIGIN) return;
    if (evt.data !== 'PONG') return;

    // Fade the loading screen out, then take it out of the layout once the transition is over.
    loadingScreen.style.transition = 'opacity 0.5s';
    loadingScreen.style.opacity = '0';
    setTimeout(() => {
      loadingScreen.style.display = 'none';
    }, 500);

    // The handshake is done: stop pinging and stop listening before sending the trace.
    window.clearInterval(timer);
    window.removeEventListener('message', onMessageHandler);

    iframe.contentWindow.postMessage({
      perfetto: {
        buffer: arrayBuffer,
        title: 'Profile Viewer',
        url: location.href,
      }
    }, ORIGIN);
  };

  window.addEventListener('message', onMessageHandler);
}
window.onload = fetchFromApi;
</script>
</body>
</html>

View file

@ -1,5 +1,5 @@
#!/usr/bin/env python3
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
@ -49,7 +50,7 @@ def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
def get_metadata(keys:List[Any], contexts:List[List[TrackedGraphRewrite]]) -> List[List[Tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]:
kernels: Dict[str, List[Tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {}
for k,ctxs in zip(keys, contexts):
for k,ctxs in tqdm(zip(keys, contexts), desc="preparing kernels"):
name = to_function_name(k.name) if isinstance(k, Kernel) else str(k)
for ctx in ctxs:
if pickle.loads(ctx.sink).op is Ops.CONST: continue
@ -99,6 +100,35 @@ def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -
g.graphs.append(sink:=new_sink)
return g
# Profiler API
# Registry of profiled devices, keyed by device name: (compute time diff, copy time diff, perfetto pid).
devices:Dict[str, Tuple[decimal.Decimal, decimal.Decimal, int]] = {}

def prep_ts(device:str, ts:decimal.Decimal, is_copy):
  """Add the time diff registered for `device` to `ts` and truncate the sum to an int."""
  # slots 0 and 1 of a registry entry hold the compute and copy diffs, so `is_copy` is used as the index
  shifted = decimal.Decimal(ts) + devices[device][is_copy]
  return int(shifted)

def dev_to_pid(device:str, is_copy=False):
  """Return the perfetto ids for `device`: its registered pid and `int(is_copy)` as the tid."""
  pid = devices[device][2]
  return {"pid": pid, "tid": int(is_copy)}
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
  """
  Register the device described by `ev` in `devices` and return its perfetto metadata events.

  The pid is the number of entries in `devices` at the time of the call. When `ev.copy_tdiff` is None the
  compute diff is stored for copies as well. The result names the process after the device and names
  thread 0 "COMPUTE" and thread 1 "COPY".
  """
  copy_tdiff = ev.comp_tdiff if ev.copy_tdiff is None else ev.copy_tdiff
  devices[ev.device] = (ev.comp_tdiff, copy_tdiff, len(devices))
  pid = dev_to_pid(ev.device)['pid']
  meta = [{"name": "process_name", "ph": "M", "pid": pid, "args": {"name": ev.device}}]
  for tid, label in enumerate(("COMPUTE", "COPY")):
    meta.append({"name": "thread_name", "ph": "M", "pid": pid, "tid": tid, "args": {"name": label}})
  return meta
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
  """Convert a range event into a one-element list holding a perfetto "X" event with its start and duration."""
  start_ts = prep_ts(ev.device, ev.st, ev.is_copy)
  duration = float(ev.en-ev.st)
  return [{"name": ev.name, "ph": "X", "ts": start_ts, "dur": duration, **dev_to_pid(ev.device, ev.is_copy)}]
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
  """
  Convert a graph event into perfetto events: one "X" event per entry plus an "s"/"f" flow pair per dependency.

  `ev.sigs` is indexed with each entry's `st_id`/`en_id` to get its start and end timestamps, and `ev.deps[i]`
  holds indices into `ev.ents` that entry `i` depends on. `reccnt` is added to positions inside the returned
  list to form the flow ids.
  """
  events = []
  for idx, ent in enumerate(ev.ents):
    start, end = ev.sigs[ent.st_id], ev.sigs[ent.en_id]
    events.append({"name": ent.name, "ph": "X", "ts": prep_ts(ent.device, start, ent.is_copy), "dur": float(end-start),
                   **dev_to_pid(ent.device, ent.is_copy)})
    for dep_idx in ev.deps[idx]:
      src = ev.ents[dep_idx]
      # both halves of a flow carry the same id: the offset position of its "s" record
      flow_id = reccnt+len(events)
      # the flow starts at the end of the dependency and finishes at the start of the dependent entry
      events.append({"ph": "s", **dev_to_pid(src.device, src.is_copy), "id": flow_id,
                     "ts": prep_ts(src.device, ev.sigs[src.en_id], src.is_copy), "bp": "e"})
      events.append({"ph": "f", **dev_to_pid(ent.device, ent.is_copy), "id": flow_id,
                     "ts": prep_ts(ent.device, start, ent.is_copy), "bp": "e"})
  return events
def to_perfetto(profile:List[ProfileEvent]):
  """Serialize `profile` into encoded perfetto trace json; returns None when no trace events were produced."""
  # Start json with devices.
  trace = []
  for ev in profile:
    if isinstance(ev, ProfileDeviceEvent): trace.extend(dev_ev_to_perfetto_json(ev))

  for ev in tqdm(profile, desc="preparing profile"):
    if isinstance(ev, ProfileRangeEvent): trace.extend(range_ev_to_perfetto_json(ev))
    # the current trace length is passed along as the base for the flow ids of the graph
    elif isinstance(ev, ProfileGraphEvent): trace.extend(graph_ev_to_perfetto_json(ev, reccnt=len(trace)))

  if len(trace) == 0: return None
  return json.dumps({"traceEvents": trace}).encode()
# ** HTTP server
class Handler(BaseHTTPRequestHandler):
@ -107,6 +137,8 @@ class Handler(BaseHTTPRequestHandler):
if (url:=urlparse(self.path)).path == "/":
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
elif (url:=urlparse(self.path)).path == "/profiler":
with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
elif self.path.startswith("/assets/") and '/..' not in self.path:
try:
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
@ -120,6 +152,7 @@ class Handler(BaseHTTPRequestHandler):
jret: Any = {**asdict(g), "graphs": [uop_to_json(x) for x in g.graphs], "uops": [pcall(str,x) for x in g.graphs]}
else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels]
ret, content_type = json.dumps(jret).encode(), "application/json"
elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
else: status_code = 404
# send response
@ -139,7 +172,16 @@ def reloader():
os.execv(sys.executable, [sys.executable] + sys.argv)
time.sleep(0.1)
def load_pickle(path:str):
  """Unpickle and return the object stored at `path`; returns None when `path` is None or does not exist."""
  if path is not None and os.path.exists(path):
    # NOTE: unpickling can execute arbitrary code, only pass files that come from a trusted source
    with open(path, "rb") as f: return pickle.load(f)
  return None
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
parser.add_argument('--profile', type=str, help='Path profile', default=None)
args = parser.parse_args()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
@ -147,19 +189,23 @@ if __name__ == "__main__":
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
st = time.perf_counter()
print("*** viz is starting")
with open(sys.argv[1], "rb") as f: contexts: Tuple[List[Any], List[List[TrackedGraphRewrite]]] = pickle.load(f)
print("*** unpickled saved rewrites")
kernels = get_metadata(*contexts)
contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
kernels = get_metadata(*contexts) if contexts is not None else []
if getenv("FUZZ_VIZ"):
ret = [get_details(*args) for v in tqdm(kernels) for args in v]
print(f"fuzzed {len(ret)} rewrite details")
print("*** loaded kernels")
perfetto_profile = to_perfetto(profile) if profile is not None else None
server = HTTPServer(('', PORT), Handler)
reloader_thread = threading.Thread(target=reloader)
reloader_thread.start()
print(f"*** started viz on {HOST}:{PORT}")
print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
if getenv("BROWSER", 0): webbrowser.open(f"{HOST}:{PORT}")
if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
try: server.serve_forever()
except KeyboardInterrupt:
print("*** viz is shutting down...")