mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz profiler (#8287)
* only hcq * fix get_metadata * linter * oops * tiny * linter * time * print pm * hmm * nits
This commit is contained in:
parent
0794af97db
commit
af87e4b53c
11 changed files with 500 additions and 301 deletions
|
|
@ -300,8 +300,6 @@ class TestHCQ(unittest.TestCase):
|
|||
|
||||
# Test profile api
|
||||
def test_speed_exec_time(self):
|
||||
TestHCQ.d0._prof_setup()
|
||||
|
||||
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
|
||||
TestHCQ.d0.hw_compute_queue_t().timestamp(sig_st) \
|
||||
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
|
|
@ -311,7 +309,7 @@ class TestHCQ(unittest.TestCase):
|
|||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True)
|
||||
et = float(sig_en.timestamp - sig_st.timestamp)
|
||||
|
||||
print(f"exec kernel time: {et:.2f} us")
|
||||
assert 0.1 <= et <= (7000 if CI else 100)
|
||||
|
|
@ -319,8 +317,6 @@ class TestHCQ(unittest.TestCase):
|
|||
def test_speed_copy_bandwidth(self):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
TestHCQ.d0._prof_setup()
|
||||
|
||||
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
|
||||
SZ = 200_000_000
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
|
|
@ -335,7 +331,7 @@ class TestHCQ(unittest.TestCase):
|
|||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True)
|
||||
et = float(sig_en.timestamp - sig_st.timestamp)
|
||||
et_ms = et / 1e3
|
||||
|
||||
gb_s = ((SZ / 1e9) / et_ms) * 1e3
|
||||
|
|
@ -348,8 +344,6 @@ class TestHCQ(unittest.TestCase):
|
|||
try: _ = Device[f"{Device.DEFAULT}:1"]
|
||||
except Exception: self.skipTest("no multidevice, test skipped")
|
||||
|
||||
TestHCQ.d0._prof_setup()
|
||||
|
||||
SZ = 200_000_000
|
||||
b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
|
|
@ -364,7 +358,7 @@ class TestHCQ(unittest.TestCase):
|
|||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True)
|
||||
et = float(sig_en.timestamp - sig_st.timestamp)
|
||||
et_ms = et / 1e3
|
||||
|
||||
gb_s = ((SZ / 1e9) / et_ms) * 1e3
|
||||
|
|
|
|||
|
|
@ -1,73 +1,30 @@
|
|||
import unittest, struct, contextlib, tempfile, pathlib, json, time, atexit, random
|
||||
import unittest, struct, contextlib, statistics
|
||||
from tinygrad import Device, Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import CI, getenv, Context
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.runtime.support.hcq import ProfileLogger, HCQCompiled
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileRangeEvent, ProfileDeviceEvent, ProfileGraphEvent
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
MOCKGPU = getenv("MOCKGPU")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def helper_collect_profile(*devs, random_setup_delay=False):
|
||||
ProfileLogger.mjson, ProfileLogger.actors = [], {}
|
||||
def helper_collect_profile(*devs):
|
||||
Compiled.profile_events = []
|
||||
|
||||
if random_setup_delay:
|
||||
devs = list(devs)
|
||||
for dev in devs: dev.synchronize()
|
||||
random.shuffle(devs)
|
||||
for dev in devs:
|
||||
dev._prof_setup()
|
||||
time.sleep(random.randint(1, 1000) / 1000)
|
||||
else:
|
||||
for dev in devs: dev._prof_setup()
|
||||
|
||||
profile_dict = {}
|
||||
_, tmp = tempfile.mkstemp()
|
||||
with Context(PROFILE=1, PROFILEPATH=tmp):
|
||||
try: yield profile_dict
|
||||
profile_list = []
|
||||
with Context(PROFILE=1):
|
||||
try: yield profile_list
|
||||
finally:
|
||||
for dev in devs:
|
||||
dev.synchronize()
|
||||
dev._prof_finalize()
|
||||
atexit.unregister(dev._prof_finalize)
|
||||
for dev in devs: dev.synchronize()
|
||||
for dev in devs: dev._at_profile_finalize()
|
||||
for x in Compiled.profile_events: profile_list.append(x)
|
||||
|
||||
for k,v in json.loads(pathlib.Path(tmp).read_text()).items(): profile_dict[k] = v
|
||||
pathlib.Path(tmp).unlink()
|
||||
|
||||
def helper_profile_filter_node(profile, **kwargs):
|
||||
assert len(profile) > 0, "Empty profile"
|
||||
assert 'traceEvents' in profile, "traceEvents should present"
|
||||
return [x for x in profile['traceEvents'] if all(x.get(k, None) == v for k,v in kwargs.items())]
|
||||
|
||||
def helper_profile_parse_pids(profile):
|
||||
pids, tids = {}, {}
|
||||
procs = helper_profile_filter_node(profile, name='process_name')
|
||||
for proc in procs: pids[proc['pid']] = proc['args']['name']
|
||||
threads = helper_profile_filter_node(profile, name='thread_name')
|
||||
for th in threads: tids[th['tid']] = th['args']['name']
|
||||
return pids, tids
|
||||
|
||||
def helper_profile_parse_deps(profile):
|
||||
deps = []
|
||||
for s in helper_profile_filter_node(profile, ph="s"):
|
||||
f = helper_profile_filter_node(profile, ph="f", id=s['id'])[0]
|
||||
|
||||
starts, ends = [], []
|
||||
for x in helper_profile_filter_node(profile, ph="X"):
|
||||
if s['pid'] == x['pid'] and s['tid'] == x['tid'] and x['ts'] <= s['ts'] <= x['ts'] + x['dur']: starts.append(x)
|
||||
if f['pid'] == x['pid'] and f['tid'] == x['tid'] and x['ts'] <= f['ts'] <= x['ts'] + x['dur']: ends.append(x)
|
||||
|
||||
assert len(starts) == 1 and len(ends) == 1, "more than one start and end possible, valid?"
|
||||
deps.append((s, f, starts[0], ends[0]))
|
||||
return deps
|
||||
|
||||
def helper_validate_node(node, duration_s=10, ts_age_s=30, profile=None, pid_name=None, tid_name=None):
|
||||
pids, tids = helper_profile_parse_pids(profile)
|
||||
assert abs(node['ts'] - time.perf_counter_ns() / 1e3) < ts_age_s * 1e6, "timestimp is not in 30s range"
|
||||
assert 0 < node['dur'] < duration_s * 1e6, "duration is not in 10s range"
|
||||
assert pid_name is None or pids[node['pid']] == pid_name
|
||||
assert tid_name is None or tids[node['tid']] == tid_name
|
||||
def helper_profile_filter_device(profile, device:str):
|
||||
assert any(getattr(x, "device", None) == device and isinstance(x, ProfileDeviceEvent) for x in profile), f"device {device} is not registred"
|
||||
dev_events = [x for x in profile if getattr(x, "device", None) == device and isinstance(x, ProfileDeviceEvent)]
|
||||
assert len(dev_events) == 1, "only one device registration event is expected"
|
||||
return [x for x in profile if getattr(x, "device", None) == device], dev_events[0]
|
||||
|
||||
@unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "HCQ device required to run")
|
||||
class TestProfiler(unittest.TestCase):
|
||||
|
|
@ -90,8 +47,11 @@ class TestProfiler(unittest.TestCase):
|
|||
with helper_collect_profile(TestProfiler.d0) as profile:
|
||||
TestProfiler.runner([TestProfiler.b.lazydata.buffer, TestProfiler.a.lazydata.buffer], var_vals={})
|
||||
|
||||
kernel_node = helper_profile_filter_node(profile, name=runner_name)[0]
|
||||
helper_validate_node(kernel_node, profile=profile, pid_name=Device.DEFAULT, tid_name="COMPUTE")
|
||||
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
|
||||
kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
|
||||
assert len(kernel_runs) == 1, "one kernel run is expected"
|
||||
assert kernel_runs[0].name == runner_name, "kernel name is not correct"
|
||||
assert not kernel_runs[0].is_copy, "kernel should not be copy"
|
||||
|
||||
def test_profile_copyin(self):
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
|
|
@ -99,8 +59,10 @@ class TestProfiler(unittest.TestCase):
|
|||
with helper_collect_profile(TestProfiler.d0) as profile:
|
||||
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
|
||||
|
||||
copyin_node = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0]
|
||||
helper_validate_node(copyin_node, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
|
||||
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
|
||||
kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
|
||||
assert len(kernel_runs) == 1, "one kernel run is expected"
|
||||
assert kernel_runs[0].is_copy, "kernel should not be copy"
|
||||
|
||||
def test_profile_multiops(self):
|
||||
runner_name = TestProfiler.runner._prg.name
|
||||
|
|
@ -111,19 +73,19 @@ class TestProfiler(unittest.TestCase):
|
|||
TestProfiler.runner([buf1, TestProfiler.a.lazydata.buffer], var_vals={})
|
||||
buf1.as_buffer()
|
||||
|
||||
copyin_node = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0]
|
||||
helper_validate_node(copyin_node, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
|
||||
profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
|
||||
evs = [x for x in profile if isinstance(x, ProfileRangeEvent)]
|
||||
|
||||
kernel_node = helper_profile_filter_node(profile, name=runner_name)[0]
|
||||
helper_validate_node(kernel_node, profile=profile, pid_name=Device.DEFAULT, tid_name="COMPUTE")
|
||||
assert len(evs) == 3, "two kernel runs are expected"
|
||||
assert evs[0].is_copy, "kernel should be copy"
|
||||
assert evs[1].name == runner_name, "kernel name is not correct"
|
||||
assert not evs[1].is_copy, "kernel should not be copy"
|
||||
assert evs[2].is_copy, "kernel should be copy"
|
||||
|
||||
copyout_node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> CPU")[0]
|
||||
helper_validate_node(copyout_node, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
|
||||
for i in range(1, 3):
|
||||
assert evs[i].st > evs[i-1].en, "timestamp not aranged"
|
||||
|
||||
assert copyin_node['ts'] + copyin_node['dur'] < kernel_node['ts'], "timestamp not aranged"
|
||||
assert kernel_node['ts'] + kernel_node['dur'] < copyout_node['ts'], "timestamp not aranged"
|
||||
|
||||
def test_profile_multidev_copyin(self):
|
||||
def test_profile_multidev(self):
|
||||
d1 = Device[f"{Device.DEFAULT}:1"]
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(f"{Device.DEFAULT}:1", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
|
|
@ -132,25 +94,16 @@ class TestProfiler(unittest.TestCase):
|
|||
buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
|
||||
buf2.copyin(memoryview(bytearray(struct.pack("ff", 0, 1))))
|
||||
|
||||
copyin_node_1 = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0]
|
||||
helper_validate_node(copyin_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
|
||||
profile0, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
|
||||
profile1, _ = helper_profile_filter_device(profile, d1.device)
|
||||
|
||||
copyin_node_2 = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}:1")[0]
|
||||
helper_validate_node(copyin_node_2, profile=profile, pid_name=f"{Device.DEFAULT}:1", tid_name="DMA")
|
||||
|
||||
def test_profile_multidev_transfer(self):
|
||||
d1 = Device[f"{Device.DEFAULT}:1"]
|
||||
a = Tensor.randn(1 << 20, device=Device.DEFAULT).realize()
|
||||
with helper_collect_profile(TestProfiler.d0, d1) as profile:
|
||||
y = a.to(f"{Device.DEFAULT}:1")
|
||||
y.realize()
|
||||
|
||||
transfer_node_1 = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[0]
|
||||
helper_validate_node(transfer_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
|
||||
assert 80 < transfer_node_1['dur'] < (20000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}"
|
||||
for p in [profile0, profile1]:
|
||||
evs = [x for x in p if isinstance(x, ProfileRangeEvent)]
|
||||
assert len(evs) == 1, "one kernel runs are expected"
|
||||
assert evs[0].is_copy, "kernel should be copy"
|
||||
|
||||
@unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
|
||||
def test_profile_deps(self):
|
||||
def test_profile_graph(self):
|
||||
d1 = Device[f"{Device.DEFAULT}:1"]
|
||||
|
||||
def f(a):
|
||||
|
|
@ -163,59 +116,40 @@ class TestProfiler(unittest.TestCase):
|
|||
for _ in range(3): jf(a)
|
||||
del jf
|
||||
|
||||
deps = helper_profile_parse_deps(profile)
|
||||
assert len(deps) == 1, "one dep is expected, one launch"
|
||||
graph_evs = [x for x in profile if isinstance(x, ProfileGraphEvent)]
|
||||
|
||||
_, _, l, r = deps[0]
|
||||
assert l['name'].find("->") == -1, "should be kernel"
|
||||
assert r['name'] == f"{Device.DEFAULT} -> {Device.DEFAULT}:1", "should be copy"
|
||||
_, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
|
||||
_, _ = helper_profile_filter_device(profile, d1.device)
|
||||
|
||||
@unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
|
||||
def test_profile_copy_args(self):
|
||||
d1 = Device[f"{Device.DEFAULT}:1"]
|
||||
|
||||
def f(a):
|
||||
x = (a + 1).realize()
|
||||
return x, x.to(d1.device).realize()
|
||||
|
||||
a = Tensor.randn(10, 10, device=TestProfiler.d0.device).realize()
|
||||
with helper_collect_profile(TestProfiler.d0, d1) as profile:
|
||||
jf = TinyJit(f)
|
||||
for _ in range(3):
|
||||
TestProfiler.d0.raw_prof_records, TestProfiler.d0.sig_prof_records = [], [] # reset to collect only graph logs
|
||||
d1.raw_prof_records, d1.sig_prof_records = [], []
|
||||
jf(a)
|
||||
del jf
|
||||
|
||||
node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[-1]
|
||||
assert node['args']['Size'] == "400.00 B"
|
||||
assert abs(float(node['args']['GB/S']) - ((10 * 10 * 4) / 1e3) / (node['dur'])) < 0.01
|
||||
assert len(graph_evs) == 1, "one graph event is expected"
|
||||
assert len(graph_evs[0].ents) == 2, "two entities are expected"
|
||||
|
||||
@unittest.skipIf(CI, "skip CI")
|
||||
def test_profile_sync(self):
|
||||
mv = memoryview(bytearray(struct.pack("ff", 0, 1)))
|
||||
expected_diff = 100000 # sleep in us
|
||||
def test_dev_jitter_matrix(self):
|
||||
dev_cnt = 6
|
||||
devs = [Device[f"{Device.DEFAULT}:{i}"] for i in range(dev_cnt)]
|
||||
for dev in devs: dev.synchronize()
|
||||
for dev in devs: dev._at_profile_finalize()
|
||||
|
||||
devs = [Device[f"{Device.DEFAULT}:{i}"] for i in range(6)]
|
||||
bufs = [Buffer(f"{Device.DEFAULT}:{i}", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated() for i in range(6)]
|
||||
def _sync_d2d(d1:HCQCompiled, d2:HCQCompiled):
|
||||
d1.hw_compute_queue_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
|
||||
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
|
||||
d2.hw_compute_queue_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
|
||||
.timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
|
||||
d1.timeline_value += 2
|
||||
d2.timeline_value += 2
|
||||
d1.timeline_signal.wait(d1.timeline_value - 1)
|
||||
d2.timeline_signal.wait(d2.timeline_value - 1)
|
||||
return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
|
||||
|
||||
# enqueue ops on different queues to check the timer sync
|
||||
cpu_time = []
|
||||
with helper_collect_profile(*devs, random_setup_delay=True) as profile:
|
||||
for i in range(6):
|
||||
x = time.perf_counter_ns()
|
||||
time.sleep(expected_diff / 1e6)
|
||||
bufs[i].copyin(mv)
|
||||
cpu_time.append(((time.perf_counter_ns() - x) / 1000) - expected_diff)
|
||||
|
||||
nodes = [helper_profile_filter_node(profile, name=f"CPU -> {Device.canonicalize(f'{Device.DEFAULT}:{i}')}")[-1] for i in range(6)]
|
||||
avg_diff = []
|
||||
for i in range(1, 6):
|
||||
diff = nodes[i]['ts'] - nodes[i-1]['ts'] - cpu_time[i]
|
||||
avg_diff.append(diff - expected_diff)
|
||||
assert expected_diff * 0.998 < diff < expected_diff * 1.002, "more that 0.2% diff"
|
||||
|
||||
print(f"total avg delay is {sum(avg_diff) / len(avg_diff)} us")
|
||||
# then test it by timing the GPU to GPU times
|
||||
jitter_matrix = [[float('nan')] * len(devs) for _ in range(len(devs))]
|
||||
pairs = [(p1, p2) for p1 in enumerate(devs) for p2 in enumerate(devs) if p1 != p2]
|
||||
for (i1, d1), (i2, d2) in pairs:
|
||||
cpu_diff = d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff
|
||||
jitter_matrix[i1][i2] = statistics.median(_sync_d2d(d1, d2) - _sync_d2d(d2, d1) for _ in range(20)) / 2 - cpu_diff
|
||||
assert abs(jitter_matrix[i1][i2]) < 0.5, "jitter should be less than 0.5ms"
|
||||
print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Dict, List, Optional
|
||||
import unittest
|
||||
import unittest, decimal, json
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs)
|
||||
|
|
@ -112,5 +113,77 @@ class TestViz(unittest.TestCase):
|
|||
self.assertEqual(len(ret), 1)
|
||||
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
|
||||
|
||||
class TextVizProfiler(unittest.TestCase):
|
||||
def test_perfetto_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
|
||||
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
|
||||
# Device regs always first
|
||||
self.assertEqual(j['traceEvents'][0]['name'], 'process_name')
|
||||
self.assertEqual(j['traceEvents'][0]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
|
||||
|
||||
self.assertEqual(j['traceEvents'][1]['name'], 'thread_name')
|
||||
self.assertEqual(j['traceEvents'][1]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][1]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][1]['tid'], 0)
|
||||
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
|
||||
|
||||
self.assertEqual(j['traceEvents'][2]['name'], 'thread_name')
|
||||
self.assertEqual(j['traceEvents'][2]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][2]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][2]['tid'], 1)
|
||||
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
|
||||
|
||||
self.assertEqual(j['traceEvents'][3]['name'], 'E_2')
|
||||
self.assertEqual(j['traceEvents'][3]['ts'], 0)
|
||||
self.assertEqual(j['traceEvents'][3]['dur'], 10)
|
||||
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
|
||||
self.assertEqual(j['traceEvents'][3]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][3]['tid'], 0)
|
||||
|
||||
def test_perfetto_copy_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
|
||||
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
|
||||
self.assertEqual(j['traceEvents'][3]['name'], 'COPYxx')
|
||||
self.assertEqual(j['traceEvents'][3]['ts'], 900) # diff clock
|
||||
self.assertEqual(j['traceEvents'][3]['dur'], 10)
|
||||
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
|
||||
self.assertEqual(j['traceEvents'][3]['tid'], 1)
|
||||
|
||||
def test_perfetto_graph(self):
|
||||
prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
|
||||
ProfileDeviceEvent(device='NV:1', comp_tdiff=decimal.Decimal(-500), copy_tdiff=decimal.Decimal(-50)),
|
||||
ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1, is_copy=False),
|
||||
ProfileGraphEntry(device='NV:1', name='NV -> NV:1', st_id=2, en_id=3, is_copy=True)],
|
||||
deps=[[], [0]],
|
||||
sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
|
||||
# Device regs always first
|
||||
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
|
||||
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
|
||||
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
|
||||
self.assertEqual(j['traceEvents'][3]['args']['name'], 'NV:1')
|
||||
self.assertEqual(j['traceEvents'][4]['args']['name'], 'COMPUTE')
|
||||
self.assertEqual(j['traceEvents'][5]['args']['name'], 'COPY')
|
||||
|
||||
self.assertEqual(j['traceEvents'][6]['name'], 'E_25_4n2')
|
||||
self.assertEqual(j['traceEvents'][6]['ts'], 0)
|
||||
self.assertEqual(j['traceEvents'][6]['dur'], 2)
|
||||
self.assertEqual(j['traceEvents'][6]['pid'], j['traceEvents'][0]['pid'])
|
||||
|
||||
self.assertEqual(j['traceEvents'][7]['name'], 'NV -> NV:1')
|
||||
self.assertEqual(j['traceEvents'][7]['ts'], 954)
|
||||
self.assertEqual(j['traceEvents'][7]['dur'], 4)
|
||||
self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import annotations
|
||||
from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Dict, Tuple, Any, Iterator
|
||||
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys, re
|
||||
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
|
||||
from typing import Optional, Dict, Tuple, Any, Iterator, List, Set
|
||||
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys, re, atexit, pickle, decimal
|
||||
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.ops import UOp, buffers
|
||||
|
|
@ -13,6 +13,7 @@ from tinygrad.ops import UOp, buffers
|
|||
class _Device:
|
||||
def __init__(self) -> None:
|
||||
self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
|
||||
self._opened_devices:Set[str] = set()
|
||||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
|
||||
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
|
||||
|
|
@ -26,6 +27,7 @@ class _Device:
|
|||
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
|
||||
if (cname.lower() == x.lower() + "device")][0](ix)
|
||||
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
|
||||
self._opened_devices.add(ix)
|
||||
return ret
|
||||
@property
|
||||
def default(self) -> Compiled: return self[self.DEFAULT]
|
||||
|
|
@ -42,6 +44,23 @@ class _Device:
|
|||
except StopIteration as exc: raise RuntimeError("no usable devices") from exc
|
||||
Device = _Device()
|
||||
|
||||
# **************** Profile ****************
|
||||
|
||||
class ProfileEvent: pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfileDeviceEvent(ProfileEvent):
|
||||
device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:bool # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProfileGraphEvent(ProfileEvent): ents:List[ProfileGraphEntry]; deps:List[List[int]]; sigs:List[decimal.Decimal] # noqa: E702
|
||||
|
||||
# **************** Buffer + Allocators ****************
|
||||
|
||||
|
||||
|
|
@ -202,6 +221,8 @@ class Compiler:
|
|||
def disassemble(self, lib:bytes): pass
|
||||
|
||||
class Compiled:
|
||||
profile_events:List[ProfileEvent] = []
|
||||
|
||||
def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
|
||||
self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
|
||||
self.renderer = renderer or Renderer()
|
||||
|
|
@ -212,6 +233,11 @@ class Compiled:
|
|||
This method ensures that all previously queued operations on the device have been completed before proceeding.
|
||||
"""
|
||||
# override this in your device implementation
|
||||
def _at_profile_finalize(self):
|
||||
"""
|
||||
Called at the end of profiling to allow the device to finalize any profiling.
|
||||
"""
|
||||
# override this in your device implementation
|
||||
|
||||
# TODO: move this to each Device
|
||||
def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
|
||||
|
|
@ -232,3 +258,15 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
|
|||
if device == "PYTHON": return sys.version_info >= (3, 12)
|
||||
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
|
||||
return True
|
||||
|
||||
if PROFILE:
|
||||
@atexit.register
|
||||
def finlize_profile():
|
||||
devs = [Device[d] for d in Device._opened_devices]
|
||||
for dev in devs: dev.synchronize()
|
||||
for dev in devs: dev._at_profile_finalize()
|
||||
|
||||
with open(temp("profile.pkl"), "wb") as f: pickle.dump(Compiled.profile_events, f)
|
||||
|
||||
from tinygrad.ops import launch_viz
|
||||
launch_viz("PROFILE", temp("profile.pkl"))
|
||||
|
|
|
|||
|
|
@ -97,8 +97,7 @@ class ContextVar:
|
|||
def __lt__(self, x): return self.value < x
|
||||
|
||||
DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
|
||||
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
|
||||
PROFILE, PROFILEPATH = ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
|
||||
WINO, CAPTURING, TRACEMETA, PROFILE = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("PROFILE", 0)
|
||||
USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
|
||||
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
|
||||
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
|
||||
|
|
|
|||
|
|
@ -891,9 +891,7 @@ if TRACK_MATCH_STATS:
|
|||
with open(fn:=temp("rewrites.pkl"), "wb") as f:
|
||||
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
|
||||
pickle.dump((tracked_keys, tracked_ctxs), f)
|
||||
if getenv("VIZ"):
|
||||
os.environ["VIZ"] = "0"
|
||||
os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")])
|
||||
launch_viz("VIZ", temp("rewrites.pkl"))
|
||||
if getenv("PRINT_MATCH_STATS", 1):
|
||||
ret = [0,0,0.0,0.0]
|
||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
|
||||
|
|
@ -902,6 +900,14 @@ if TRACK_MATCH_STATS:
|
|||
ret = [x+y for x,y in zip(ret, v)]
|
||||
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
|
||||
|
||||
def launch_viz(env_str:str, data:str):
|
||||
os.environ[env_str] = "0"
|
||||
os.environ[f"{env_str}_DATA"] = data
|
||||
if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
|
||||
args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
|
||||
args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
|
||||
os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
class RewriteContext:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import collections, time
|
||||
from typing import List, Any, Dict, cast, Optional, Tuple, Set
|
||||
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
|
||||
from tinygrad.helpers import round_up, PROFILE
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, Variable
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
||||
|
|
@ -51,8 +51,11 @@ class HCQGraph(MultiGraphRunner):
|
|||
self.kickoff_value: int = 0
|
||||
self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
|
||||
|
||||
# When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1.
|
||||
# TODO: This logic might allocate a few extra signals...
|
||||
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
|
||||
self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
|
||||
self.prog_graph_deps: List[List[int]] = []
|
||||
self.prof_graph_entries: List[ProfileGraphEntry] = []
|
||||
|
||||
last_j: Dict[HWQueue, Optional[int]] = collections.defaultdict(lambda: None)
|
||||
queue_access: Dict[HWQueue, Dict[HWQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
|
||||
|
|
@ -102,18 +105,20 @@ class HCQGraph(MultiGraphRunner):
|
|||
|
||||
# Collect profile information if profiling is enabled.
|
||||
if PROFILE:
|
||||
# When execution are chained, we can reuse the end timestamp from the previous command as the start timestamp for the current command.
|
||||
sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2
|
||||
|
||||
# Description based on the command.
|
||||
prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
|
||||
|
||||
sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
|
||||
if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
|
||||
|
||||
if is_exec_prg: prof_args = None
|
||||
else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
|
||||
|
||||
self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
|
||||
self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
|
||||
self.prog_graph_deps.append([d - 1 for _, d in rdeps])
|
||||
|
||||
last_j[enqueue_queue] = j
|
||||
|
||||
# Check which signals are used in the profile graph.
|
||||
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]
|
||||
|
||||
# Build hardware queues.
|
||||
self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
|
||||
|
||||
|
|
@ -132,7 +137,7 @@ class HCQGraph(MultiGraphRunner):
|
|||
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
|
||||
|
||||
# Encode waits and start profile timestamp (if needed).
|
||||
if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])
|
||||
if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])
|
||||
|
||||
# Encode main commands based on ji type.
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
|
|
@ -145,7 +150,7 @@ class HCQGraph(MultiGraphRunner):
|
|||
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
|
||||
|
||||
# Encode finish profile timestamp (if needed).
|
||||
if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
|
||||
if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])
|
||||
|
||||
if signal_val is not None: enqueue_queue.signal(signal, signal_val)
|
||||
|
||||
|
|
@ -189,14 +194,8 @@ class HCQGraph(MultiGraphRunner):
|
|||
return None
|
||||
|
||||
def collect_timestamps(self):
|
||||
timestamps = [s.timestamp for s in self.prof_signals]
|
||||
|
||||
for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
|
||||
dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
|
||||
|
||||
for x in deps:
|
||||
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
|
||||
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
|
||||
# NOTE: Append to any device is fine...
|
||||
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]
|
||||
|
||||
def __del__(self):
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import annotations
|
||||
from typing import List, Optional, Dict, Tuple, cast, Type, Union, TypeVar, Generic, Any
|
||||
import contextlib, decimal, statistics, random, json, atexit, time, ctypes, array
|
||||
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv, round_up
|
||||
from typing import List, Optional, Dict, Tuple, cast, Type, TypeVar, Generic, Any
|
||||
import contextlib, decimal, statistics, time, ctypes, array
|
||||
from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
|
||||
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
|
||||
from tinygrad.ops import sym_infer, sint, Variable
|
||||
|
||||
# **************** for HCQ Compatible Devices ****************
|
||||
|
|
@ -294,51 +294,11 @@ class HCQProgram(Generic[DeviceType]):
|
|||
if wait: self.dev.synchronize()
|
||||
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
|
||||
|
||||
class ProfileLogger:
|
||||
writers: int = 0
|
||||
mjson: List[Dict] = []
|
||||
actors: Dict[Union[str, Tuple[str, str]], int] = {}
|
||||
|
||||
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
|
||||
|
||||
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
|
||||
|
||||
def _ensure_actor(self, actor_name, subactor_name):
|
||||
if actor_name not in self.actors:
|
||||
self.actors[actor_name] = (pid:=len(self.actors))
|
||||
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
||||
|
||||
if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
|
||||
self.actors[subactor_key] = (tid:=len(self.actors))
|
||||
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
|
||||
|
||||
return self.actors[actor_name], self.actors.get(subactor_key, -1)
|
||||
|
||||
def __del__(self):
|
||||
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
||||
for name, st, et, actor_name, subactor_name, args in self.events:
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
|
||||
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
|
||||
|
||||
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
|
||||
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
|
||||
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
|
||||
|
||||
ProfileLogger.writers -= 1
|
||||
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
|
||||
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
|
||||
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
||||
|
||||
class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
"""
|
||||
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
|
||||
"""
|
||||
devices: List[HCQCompiled] = []
|
||||
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
|
||||
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
|
||||
|
||||
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
|
||||
comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]):
|
||||
|
|
@ -350,7 +310,6 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
|||
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
|
||||
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
|
||||
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
|
||||
if PROFILE: self._prof_setup()
|
||||
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
|
||||
|
|
@ -367,13 +326,11 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
|||
|
||||
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
|
||||
if PROFILE:
|
||||
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
|
||||
Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
|
||||
self.sig_prof_records = []
|
||||
|
||||
def _ensure_shared_time_base(self):
|
||||
if not self.gpu2cpu_compute_time_diff.is_nan(): return
|
||||
|
||||
def _sync_cpu_queue(d:HCQCompiled, q_t:Type[HWQueue]):
|
||||
def _at_profile_finalize(self):
|
||||
def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
|
||||
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
|
||||
d.timeline_value += 1
|
||||
st = time.perf_counter_ns()
|
||||
|
|
@ -381,65 +338,10 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
|||
et = time.perf_counter_ns()
|
||||
return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
|
||||
|
||||
# randomly sample the timing from GPU to CPU
|
||||
choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
|
||||
choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
|
||||
for _ in range(100*len(self.devices)):
|
||||
d,q,l = random.choice(choices)
|
||||
l.append(_sync_cpu_queue(d,q))
|
||||
for d,q,l in choices:
|
||||
if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
|
||||
if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
|
||||
|
||||
def _sync_gpu_to_gpu_queue(d1:HCQCompiled, d2:HCQCompiled, q1_t:Type[HWQueue], q2_t:Type[HWQueue]):
|
||||
q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
|
||||
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
|
||||
q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
|
||||
.timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
|
||||
d1.timeline_value += 2
|
||||
d2.timeline_value += 2
|
||||
d1.timeline_signal.wait(d1.timeline_value - 1)
|
||||
d2.timeline_signal.wait(d2.timeline_value - 1)
|
||||
return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
|
||||
|
||||
# then test it by timing the GPU to GPU times
|
||||
jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
|
||||
for i1, d1 in enumerate(self.devices):
|
||||
for i2, d2 in enumerate(self.devices):
|
||||
if d1 == d2: continue
|
||||
d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
|
||||
_sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
|
||||
jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
|
||||
print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
|
||||
|
||||
def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
|
||||
"""
|
||||
Translates local gpu time (timestamp) into global cpu time.
|
||||
"""
|
||||
self._ensure_shared_time_base()
|
||||
return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
|
||||
|
||||
def _prof_setup(self):
|
||||
if hasattr(self, 'profile_logger'): return
|
||||
atexit.register(self._prof_finalize)
|
||||
self.profile_logger = ProfileLogger()
|
||||
|
||||
def _prof_finalize(self):
|
||||
qname = ["COMPUTE", "DMA"]
|
||||
|
||||
# Sync to be sure all events on the device are recorded.
|
||||
self.synchronize()
|
||||
|
||||
for st, en, name, is_cp, args in self.raw_prof_records:
|
||||
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.device, qname[is_cp], args)]
|
||||
for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
|
||||
# Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
|
||||
a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
|
||||
self.profile_logger.deps += [(a_tm, b_tm, a_dev.device, qname[a_is_copy], b_dev.device, qname[b_is_copy])]
|
||||
self.raw_prof_records, self.dep_prof_records = [], []
|
||||
|
||||
# Remove the logger, this flushes all data written by the device.
|
||||
del self.profile_logger
|
||||
gpu2cpu_compute_time_diff = statistics.median([_sync(self, self.hw_compute_queue_t) for _ in range(40)])
|
||||
if self.hw_copy_queue_t is None: gpu2cpu_copy_time_diff = decimal.Decimal(0)
|
||||
else: gpu2cpu_copy_time_diff = statistics.median([_sync(self, self.hw_copy_queue_t) for _ in range(40)])
|
||||
Compiled.profile_events += [ProfileDeviceEvent(self.device, gpu2cpu_compute_time_diff, gpu2cpu_copy_time_diff)]
|
||||
|
||||
def _wrap_timeline_signal(self):
|
||||
self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
|
||||
|
|
|
|||
|
|
@ -130,19 +130,46 @@
|
|||
#metadata-resize-handle {
|
||||
left: 0;
|
||||
}
|
||||
.collapse-btn {
|
||||
.floating-container {
|
||||
position: fixed;
|
||||
top: 10px;
|
||||
left: 20px;
|
||||
z-index: 4;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
gap: 8px;
|
||||
}
|
||||
.nav-btn {
|
||||
background-color: #1a1b26;
|
||||
border: 1px solid #4a4b56;
|
||||
color: #f0f0f5;
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
padding: 6px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
z-index: 4;
|
||||
text-decoration: none;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0 6px;
|
||||
font-weight: bold;
|
||||
}
|
||||
.collapse-btn {
|
||||
width: 32px;
|
||||
padding: 6px;
|
||||
}
|
||||
.btn {
|
||||
height: 32px;
|
||||
background-color: #1a1b26;
|
||||
border: 1px solid #4a4b56;
|
||||
color: #f0f0f5;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
transition-duration: .5s;
|
||||
}
|
||||
.btn:hover {
|
||||
background-color: #2a2b36;
|
||||
border-color: #5a5b66;
|
||||
color: #ffffff;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
.collapsed .kernel-list, .collapsed .metadata {
|
||||
width: 0;
|
||||
|
|
@ -170,9 +197,12 @@
|
|||
</head>
|
||||
<body>
|
||||
<div class="main-container">
|
||||
<button class="collapse-btn">
|
||||
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M15 19l-7-7 7-7"/></svg>
|
||||
</button>
|
||||
<div class="floating-container">
|
||||
<button class="btn collapse-btn">
|
||||
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M15 19l-7-7 7-7"/></svg>
|
||||
</button>
|
||||
<a class="btn nav-btn" href="/profiler">Profiler</a>
|
||||
</div>
|
||||
<div class="container kernel-list-parent"><div class="container kernel-list"></div></div>
|
||||
<div class="graph">
|
||||
<svg id="graph-svg">
|
||||
|
|
|
|||
178
tinygrad/viz/perfetto.html
Normal file
178
tinygrad/viz/perfetto.html
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en-us">
|
||||
<head>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body, html {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
overflow: hidden;
|
||||
font-family: "Noto Sans", sans-serif;
|
||||
font-optical-sizing: auto;
|
||||
font-weight: 400;
|
||||
font-style: normal;
|
||||
font-variation-settings: "wdth" 100;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
header {
|
||||
width: 100%;
|
||||
height: 48px;
|
||||
background: #0f1018;
|
||||
}
|
||||
|
||||
#perfetto_frame {
|
||||
width: 100vw;
|
||||
height: calc(100vh - 40px);
|
||||
border: none;
|
||||
}
|
||||
|
||||
#loading_screen {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100vw;
|
||||
height: 100vh;
|
||||
background: #0f1018;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
color: white;
|
||||
font-family: Arial, sans-serif;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.loader {
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
border: 5px solid #FFF;
|
||||
border-bottom-color: transparent;
|
||||
border-radius: 50%;
|
||||
margin-bottom: 20px;
|
||||
animation: rotation 1s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes rotation {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
#loading_text {
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.floating-container {
|
||||
position: fixed;
|
||||
top: 8px;
|
||||
left: 16px;
|
||||
z-index: 4;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
gap: 8px;
|
||||
}
|
||||
.nav-btn {
|
||||
background-color: #1a1b26;
|
||||
border: 1px solid #4a4b56;
|
||||
color: #f0f0f5;
|
||||
height: 32px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
text-decoration: none;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0 6px;
|
||||
font-weight: bold;
|
||||
}
|
||||
.btn {
|
||||
height: 32px;
|
||||
background-color: #1a1b26;
|
||||
border: 1px solid #4a4b56;
|
||||
color: #f0f0f5;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
transition-duration: .5s;
|
||||
}
|
||||
.btn:hover {
|
||||
background-color: #2a2b36;
|
||||
border-color: #5a5b66;
|
||||
color: #ffffff;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header></header>
|
||||
<div class="floating-container">
|
||||
<a class="btn nav-btn" href="/">UOps</a>
|
||||
</div>
|
||||
|
||||
<div id="loading_screen">
|
||||
<div class="loader" id="spinner"></div>
|
||||
<div id="loading_text">Loading trace data...</div>
|
||||
</div>
|
||||
|
||||
<iframe id="perfetto_frame" src="https://ui.perfetto.dev" allow="clipboard-write"></iframe>
|
||||
|
||||
<script type="text/javascript">
|
||||
const ORIGIN = 'https://ui.perfetto.dev';
|
||||
const API_ENDPOINT = '/get_profile';
|
||||
const iframe = document.getElementById('perfetto_frame');
|
||||
const loadingScreen = document.getElementById('loading_screen');
|
||||
|
||||
async function fetchFromApi() {
|
||||
try {
|
||||
const response = await fetch(API_ENDPOINT);
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const jsonData = await response.json();
|
||||
|
||||
// Convert JSON to string and then to blob
|
||||
const jsonString = JSON.stringify(jsonData);
|
||||
const blob = new Blob([jsonString], { type: 'application/json' });
|
||||
const arrayBuffer = await blob.arrayBuffer();
|
||||
|
||||
openTrace(arrayBuffer);
|
||||
} catch (error) {
|
||||
document.getElementById('spinner').remove();
|
||||
document.getElementById('loading_text').innerHTML = 'Error loading trace data.<br>Please ensure PROFILE=1 is set and the VIZ server is running.';
|
||||
console.error('Error loading trace:', error);
|
||||
}
|
||||
}
|
||||
|
||||
function openTrace(arrayBuffer) {
|
||||
const timer = setInterval(() => iframe.contentWindow.postMessage('PING', ORIGIN), 50);
|
||||
|
||||
const onMessageHandler = (evt) => {
|
||||
if (evt.data !== 'PONG') return;
|
||||
loadingScreen.style.transition = 'opacity 0.5s';
|
||||
loadingScreen.style.opacity = '0';
|
||||
setTimeout(() => {
|
||||
loadingScreen.style.display = 'none';
|
||||
}, 500);
|
||||
|
||||
window.clearInterval(timer);
|
||||
window.removeEventListener('message', onMessageHandler);
|
||||
|
||||
iframe.contentWindow.postMessage({
|
||||
perfetto: {
|
||||
buffer: arrayBuffer,
|
||||
title: 'Profile Viewer',
|
||||
url: location.href,
|
||||
}
|
||||
}, ORIGIN);
|
||||
};
|
||||
|
||||
window.addEventListener('message', onMessageHandler);
|
||||
}
|
||||
|
||||
window.onload = fetchFromApi;
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket
|
||||
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from dataclasses import asdict, dataclass
|
||||
|
|
@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple, Optional
|
|||
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
|
||||
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
|
||||
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
|
|
@ -49,7 +50,7 @@ def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
|
|||
|
||||
def get_metadata(keys:List[Any], contexts:List[List[TrackedGraphRewrite]]) -> List[List[Tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]:
|
||||
kernels: Dict[str, List[Tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {}
|
||||
for k,ctxs in zip(keys, contexts):
|
||||
for k,ctxs in tqdm(zip(keys, contexts), desc="preparing kernels"):
|
||||
name = to_function_name(k.name) if isinstance(k, Kernel) else str(k)
|
||||
for ctx in ctxs:
|
||||
if pickle.loads(ctx.sink).op is Ops.CONST: continue
|
||||
|
|
@ -99,6 +100,35 @@ def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -
|
|||
g.graphs.append(sink:=new_sink)
|
||||
return g
|
||||
|
||||
# Profiler API
|
||||
devices:Dict[str, Tuple[decimal.Decimal, decimal.Decimal, int]] = {}
|
||||
def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy])
|
||||
def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)}
|
||||
def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent):
|
||||
devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, len(devices))
|
||||
return [{"name": "process_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "args": {"name": ev.device}},
|
||||
{"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 0, "args": {"name": "COMPUTE"}},
|
||||
{"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 1, "args": {"name": "COPY"}}]
|
||||
def range_ev_to_perfetto_json(ev:ProfileRangeEvent):
|
||||
return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}]
|
||||
def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt):
|
||||
ret = []
|
||||
for i,e in enumerate(ev.ents):
|
||||
st, en = ev.sigs[e.st_id], ev.sigs[e.en_id]
|
||||
ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}]
|
||||
for dep in ev.deps[i]:
|
||||
d = ev.ents[dep]
|
||||
ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}]
|
||||
ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}]
|
||||
return ret
|
||||
def to_perfetto(profile:List[ProfileEvent]):
|
||||
# Start json with devices.
|
||||
prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)]
|
||||
for ev in tqdm(profile, desc="preparing profile"):
|
||||
if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev)
|
||||
elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json))
|
||||
return json.dumps({"traceEvents": prof_json}).encode() if len(prof_json) > 0 else None
|
||||
|
||||
# ** HTTP server
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
|
|
@ -107,6 +137,8 @@ class Handler(BaseHTTPRequestHandler):
|
|||
|
||||
if (url:=urlparse(self.path)).path == "/":
|
||||
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
|
||||
elif (url:=urlparse(self.path)).path == "/profiler":
|
||||
with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
|
||||
elif self.path.startswith("/assets/") and '/..' not in self.path:
|
||||
try:
|
||||
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
|
||||
|
|
@ -120,6 +152,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
jret: Any = {**asdict(g), "graphs": [uop_to_json(x) for x in g.graphs], "uops": [pcall(str,x) for x in g.graphs]}
|
||||
else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels]
|
||||
ret, content_type = json.dumps(jret).encode(), "application/json"
|
||||
elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
|
||||
else: status_code = 404
|
||||
|
||||
# send response
|
||||
|
|
@ -139,7 +172,16 @@ def reloader():
|
|||
os.execv(sys.executable, [sys.executable] + sys.argv)
|
||||
time.sleep(0.1)
|
||||
|
||||
def load_pickle(path:str):
|
||||
if path is None or not os.path.exists(path): return None
|
||||
with open(path, "rb") as f: return pickle.load(f)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--kernels', type=str, help='Path to kernels', default=None)
|
||||
parser.add_argument('--profile', type=str, help='Path profile', default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
|
||||
raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
|
||||
|
|
@ -147,19 +189,23 @@ if __name__ == "__main__":
|
|||
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
|
||||
st = time.perf_counter()
|
||||
print("*** viz is starting")
|
||||
with open(sys.argv[1], "rb") as f: contexts: Tuple[List[Any], List[List[TrackedGraphRewrite]]] = pickle.load(f)
|
||||
print("*** unpickled saved rewrites")
|
||||
kernels = get_metadata(*contexts)
|
||||
|
||||
contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
|
||||
|
||||
kernels = get_metadata(*contexts) if contexts is not None else []
|
||||
|
||||
if getenv("FUZZ_VIZ"):
|
||||
ret = [get_details(*args) for v in tqdm(kernels) for args in v]
|
||||
print(f"fuzzed {len(ret)} rewrite details")
|
||||
print("*** loaded kernels")
|
||||
|
||||
perfetto_profile = to_perfetto(profile) if profile is not None else None
|
||||
|
||||
server = HTTPServer(('', PORT), Handler)
|
||||
reloader_thread = threading.Thread(target=reloader)
|
||||
reloader_thread.start()
|
||||
print(f"*** started viz on {HOST}:{PORT}")
|
||||
print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
|
||||
if getenv("BROWSER", 0): webbrowser.open(f"{HOST}:{PORT}")
|
||||
if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}")
|
||||
try: server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("*** viz is shutting down...")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue