Merge branch 'master' into dsp_search

This commit is contained in:
George Hotz 2025-03-26 10:50:00 +08:00 committed by GitHub
commit 60cbfe4222
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 39604 additions and 39308 deletions

View file

@ -574,7 +574,7 @@ jobs:
run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/test_hcq.py test/external/external_test_am.py --durations=20
- name: Run pytest (amd with llvm backend)
if: matrix.backend=='amd'
run: python -m pytest -n=auto test/test_amd_llvm.py --durations=20
run: AMD_LLVM=1 python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/test_hcq.py test/external/external_test_am.py test/test_amd_llvm.py --durations=20
- name: Run TRANSCENDENTAL math
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
- name: Run process replay tests
@ -630,8 +630,10 @@ jobs:
env:
MOCKGPU: 1
AMD: 1
AMD_LLVM: 1
FORWARD_ONLY: 1
run: |
python -m pytest -n=auto test/test_amd_llvm.py --durations=20
python -m pytest -n=auto test/test_hcq.py test/test_tiny.py test/test_amd_llvm.py --durations=20
- name: Run pytest (ptx)
env:
MOCKGPU: 1

View file

@ -297,14 +297,21 @@ generate_am() {
extra/amdpci/headers/amdgpu_vm.h \
extra/amdpci/headers/discovery.h \
extra/amdpci/headers/amdgpu_ucode.h \
extra/amdpci/headers/soc21_enum.h \
extra/amdpci/headers/psp_gfx_if.h \
extra/amdpci/headers/amdgpu_psp.h \
extra/amdpci/headers/amdgpu_irq.h \
extra/amdpci/headers/amdgpu_doorbell.h \
extra/amdpci/headers/soc15_ih_clientid.h \
--clang-args="-include stdint.h" \
-o $BASE/am/am.py
fixup $BASE/am/am.py
sed -i "s\(int64_t)\ \g" $BASE/am/am.py
sed -i "s\AMDGPU_PTE_MTYPE_VG10(2)\AMDGPU_PTE_MTYPE_VG10(0, 2)\g" $BASE/am/am.py # incorrect parsing (TODO: remove when clang2py is gone).
clang2py -k cdefstum \
extra/amdpci/headers/soc21_enum.h \
-o $BASE/am/soc21.py
fixup $BASE/am/soc21.py
clang2py -k cdefstum \
extra/amdpci/headers/mp_13_0_0_offset.h \

View file

@ -79,6 +79,8 @@ class AMSMI(AMDev):
self.psp:AM_PSP = AM_PSP(self)
self.smu:AM_SMU = AM_SMU(self)
for ip in [self.soc, self.gmc, self.ih, self.psp, self.smu]: ip.init_sw()
def read_pci_state(self):
  """Return the text of this device's sysfs ``power_state`` file with surrounding whitespace removed.

  The file is looked up under ``/sys/bus/pci/devices/`` using ``self.pcibus``.
  """
  state_path = f"/sys/bus/pci/devices/{self.pcibus}/power_state"
  with open(state_path, "r") as state_file:
    # strip() already trims both ends, so a trailing rstrip() adds nothing
    return state_file.read().strip()

View file

@ -33,21 +33,46 @@ torch._register_device_module("tiny", TinyBackend())
torch.utils.generate_methods_for_privateuse1_backend()
aten = torch.ops.aten
# track view relationships for in place operations
def is_view(tensor: Tensor): return hasattr(tensor, "_view_base")
def canonical_base(view: Tensor): return getattr(view, "_view_base", view)
def derived_views(base: Tensor): return [t for tref in getattr(base, "_views", set()) if (t:=tref()) is not None]
def wrap_view_op(fn):
  """Wrap a view-producing function so its result is tracked as a view of its base.

  The returned callable unwraps every torch.Tensor found in the positional and
  keyword arguments, calls ``fn``, records the canonical base of the first
  positional argument on the result as ``_view_base``, adds a weak reference to
  the result in the base's ``_views`` set, and returns the wrapped result.
  """
  def _wrap(*args, **kwargs):
    def _to_tiny(value): return unwrap(value) if isinstance(value, torch.Tensor) else value
    tiny_args = [_to_tiny(a) for a in args]
    tiny_kwargs = {name: _to_tiny(val) for name, val in kwargs.items()}
    view = fn(*tiny_args, **tiny_kwargs)
    # views of views all point at the same root
    root = canonical_base(tiny_args[0])
    view._view_base = root
    if not hasattr(root, "_views"): root._views = set()
    # weak reference: the base does not keep its views alive
    root._views.add(weakref.ref(view))
    return wrap(view)
  return _wrap
# aten ops that are registered through wrap_view_op rather than as plain functions.
# Keys are written "aten.<op>"; the loop below rewrites the prefix to "aten::<op>" for registration.
view_ops = {
"aten.view": Tensor.reshape,
"aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
"aten.view.dtype": lambda self,dtype: self.bitcast(_from_torch_dtype(dtype)),
"aten.expand": Tensor.expand,
"aten.t": Tensor.transpose,
"aten.squeeze.dim": Tensor.squeeze,
"aten.unsqueeze": Tensor.unsqueeze,
"aten.repeat": Tensor.repeat,
"aten.detach": Tensor.detach,
}
# register every entry for the "privateuseone" dispatch key, wrapped with wrap_view_op
for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))
# in place operations with views
def is_view(self: torch.Tensor) -> bool: return getattr(self, "_base", None) is not None
def realize_with_views(self: torch.Tensor, views: list[torch.Tensor]):
assert self.is_tiny
self = unwrap(self)
def realize_with_views(self: Tensor, views: Tensor):
if not self.lazydata.st.contiguous: raise ValueError("base of view must be contiguous") # TODO: support?
self.replace(self.clone().realize())
for v in views:
v = unwrap(v)
ret = self
st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right?
for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
v.replace(ret)
def maybe_realize_storage(self: torch.Tensor) -> bool:
if realize:=is_view(self): realize_with_views(self._base, [self]) # TODO: other views could exist
def maybe_realize_storage(self: Tensor) -> bool:
if realize:=is_view(self): realize_with_views((base:=canonical_base(self)), derived_views(base))
return realize
def inplace_fn(outvars: str|list[str]):
if type(outvars) is str: outvars = [outvars]
@ -56,9 +81,10 @@ def inplace_fn(outvars: str|list[str]):
def wrapper(*args, **kwargs):
bound = sig.bind(*args, **kwargs)
outs = [kwargs.get(v, bound.arguments.get(v)) for v in outvars]
outs = [unwrap(o) if isinstance(o, torch.Tensor) else o for o in outs]
realize = any(maybe_realize_storage(o) for o in outs)
ret = fn(*args, **kwargs)
if realize: Tensor.realize(*(unwrap(o) for o in outs))
if realize: Tensor.realize(*(o for o in outs))
return ret
return wrapper
return decorator
@ -137,14 +163,15 @@ def cached_to_movement_ops(shape, st) -> list:
from tinygrad.shape.shapetracker import ShapeTracker, View
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
@wrap_view_op
def as_strided(tensor:Tensor, size, stride, storage_offset=None):
# TODO: this is heavyweight
st = ShapeTracker((View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)))
ret = unwrap(tensor)
if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size))
ret = tensor
if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
for mo in cached_to_movement_ops(tuple(tensor.shape), st): ret = apply_mop(ret, mo)
return wrap(ret)
return ret
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
@ -236,6 +263,7 @@ for i,pre in enumerate(["", "bi", "tri"]):
torch.library.impl(f"aten::_upsample_nearest_exact{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest-exact"))
@torch.library.impl("aten::scatter_add.out", "privateuseone")
@inplace_fn("out")
def scatter_add(self, dim, index, src, out):
self, index, src, out = unwrap(self), unwrap(index), unwrap(src), unwrap(out)
if self.shape == (): return wrap(out.assign(src))
@ -243,7 +271,7 @@ def scatter_add(self, dim, index, src, out):
@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src: torch.Tensor, dest, non_blocking=False):
realize = dest.is_tiny and maybe_realize_storage(dest)
realize = dest.is_tiny and maybe_realize_storage(unwrap(dest))
cast_dtype = _from_torch_dtype(dest.dtype)
if src.is_tiny and dest.is_tiny:
to_device = _from_torch_device(dest.device)
@ -413,6 +441,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
# we add the "out" here
def wrap_out(f):
@inplace_fn("out")
def _wrap_out(*args, **kwargs):
out = kwargs.pop('out')
assigned = f(*args, **kwargs)
@ -425,8 +454,6 @@ def wrap_out(f):
return _wrap_out
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.view": Tensor.reshape,
"aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
"aten.remainder.Scalar_Tensor": lambda x,y: x%y,
"aten.floor_divide": lambda x,y: x//y,
"aten.floor_divide_.Tensor": inplace_fn("x")(lambda x,y: x.assign(x//y)),
@ -486,25 +513,19 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.fill_.Tensor": Tensor.full, # TODO: looks wrong
"aten.flip": Tensor.flip,
"aten.scatter_reduce.two": Tensor.scatter_reduce,
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True),
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True), # TODO: inplace view op, here?
"aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other,
"aten.linspace": lambda start, stop, steps, dtype=None, **kwargs:
Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})),
"aten::view.dtype": lambda self, dtype: self.bitcast(_from_torch_dtype(dtype)),
"aten.topk": Tensor.topk,
"aten.constant_pad_nd": lambda self, padding, value=0.0: self.pad(padding, mode="constant", value=value),
"aten.logsumexp": lambda self, axis, keepdim=False: self.logsumexp(axis[0], keepdim=keepdim),
"aten.squeeze.dim": Tensor.squeeze,
"aten.unsqueeze": Tensor.unsqueeze,
"aten.roll": Tensor.roll,
"aten.logcumsumexp": Tensor.logcumsumexp,
"aten.repeat": Tensor.repeat,
"aten.lerp.Tensor": Tensor.lerp,
"aten.expand": Tensor.expand,
"aten.ones_like": lambda self, dtype=None, device=None, **kwargs:
self.ones_like(**{k: v for k, v in {"dtype": _from_torch_dtype(dtype) if dtype else None,
"device": _from_torch_device(device) if device else None}.items() if v is not None}),
"aten.t": Tensor.transpose,
"aten.detach": Tensor.detach,
"aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64))
}}
@ -519,9 +540,7 @@ def wrap_fxn(k,f):
if isinstance(out, Tensor): return wrap(out)
elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
else: raise RuntimeError(f"unknown output type {type(out)}")
def nf2(*args, **kwargs):
return inplace_fn("out")(nf)(*args, **kwargs) if "out" in kwargs else nf(*args, **kwargs)
return nf2
return nf
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))

View file

@ -61,5 +61,11 @@ class TestTorchBackendInplace(unittest.TestCase):
b[1:3,:] += 1
np.testing.assert_equal(a.cpu().numpy(), [[0]*4,[1]*4,[1]*4,[0]*4])
def test_detach(self):
  # an in-place add on the detached tensor must be visible through the original tensor
  original = torch.zeros(4)
  detached = original.detach()
  detached += torch.arange(4)
  expected = torch.arange(4).cpu()
  np.testing.assert_array_equal(original.cpu(), expected)
if __name__ == "__main__":
unittest.main()

View file

@ -23,7 +23,7 @@ dtypes_bool = (dtypes.bool,)
binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq]
# TODO: LLVM comparing with nan is incorrect
if Device.DEFAULT == "LLVM":
if Device.DEFAULT == "LLVM" or getenv("AMD_LLVM", 0):
binary_operations.remove(operator.lt)
integer_binary_operations = binary_operations + [(Tensor.bitwise_xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),

View file

@ -79,5 +79,35 @@ class TestGC(unittest.TestCase):
print(inspect.getclosurevars(UOp.toposort.fget))
raise AssertionError(f"never gced {[x for x in gc.get_objects() if isinstance(x, Buffer)]}")
# Follows one Tensor through create -> realize -> delete and checks, at each step,
# the Buffer's lb_refcount and the number of buffers allocated since the start of the test.
def test_buffer_refcount(self):
init = bufs_allocated()
a = Tensor.empty(10)
# creating the Tensor alone is expected to allocate nothing
self.assertEqual(bufs_allocated()-init, 0)
a.realize()
real_buf = a.lazydata.buffer
# after the Tensor UOp is deleted there shouldn't be any references on the Buffer
self.assertEqual(real_buf.lb_refcount, 1)
self.assertEqual(bufs_allocated()-init, 1)
del a.lazydata
self.assertEqual(real_buf.lb_refcount, 0)
self.assertEqual(bufs_allocated()-init, 1) # keep the buffer alive
# dropping the last Python reference (real_buf) is expected to release the buffer
del real_buf
self.assertEqual(bufs_allocated()-init, 0)
# Checks that assigning into a realized Tensor reuses the same Buffer and leaves its
# lb_refcount unchanged, and that the count drops to 0 once the Tensor is deleted.
def test_assign_refcount(self):
init = bufs_allocated()
a = Tensor.full((4,), 1.).contiguous()
a.realize()
real_buf = a.lazydata.buffer
self.assertEqual(real_buf.lb_refcount, 1)
a.assign(Tensor.full((4,), 2.))
# the first source of the tensor's UOp after assign is expected to hold the original buffer
self.assertIs(a.lazydata.src[0].buffer, real_buf)
# NOTE: this is still 1, we don't count the ASSIGN
self.assertEqual(real_buf.lb_refcount, 1)
a.realize()
del a
self.assertEqual(real_buf.lb_refcount, 0) # no UOps for this Buffer
self.assertEqual(bufs_allocated()-init, 1) # Buffer is alive
if __name__ == '__main__':
unittest.main()

View file

@ -277,6 +277,38 @@ class TestJit(unittest.TestCase):
assert len(res3) == 5, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"
# Runs the same random-using function five times without and five times with TinyJit,
# reseeding with the same value before each series, and expects both series to yield the same set of values.
# Currently marked as an expected failure.
@unittest.expectedFailure # TODO: fix
def test_jit_v_nojit_random_regen(self):
# f draws three random tensors; the order of the randn calls is part of what is being compared
def f(a, b):
rn = Tensor.randn(*a.shape)
rn = rn * a
rn2 = Tensor.randn(*a.shape)
rn2 = rn2 * b
rn = rn + rn2
rn2 = rn2 + Tensor.randn(*a.shape)
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
Tensor.manual_seed(0)
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
b = Tensor.randn(10, 10).realize()
# reference series: plain (non-jitted) calls
Tensor.manual_seed(1234)
without_jit = set()
for _ in range(5):
o1, o2 = f(a, b)
without_jit.add(o1.numpy()[0][0])
without_jit.add(o2.numpy()[0][0])
# 5 iterations x 2 outputs, only the [0][0] element of each output is sampled
assert len(without_jit) == 10, "All values should be different."
# jitted series, started from the same seed as the reference series
Tensor.manual_seed(1234)
jf = TinyJit(f)
with_jit = set()
for _ in range(5):
o1, o2 = jf(a, b)
with_jit.add(o1.numpy()[0][0])
with_jit.add(o2.numpy()[0][0])
assert len(with_jit) == 10, "All values should be different."
assert with_jit == without_jit, "Jit rand produced different values from no jit."
def test_jit_multiple_random_regen(self):
def f(a, b):
rn = Tensor.randn(*a.shape)

View file

@ -534,6 +534,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y)
helper_test_op([(), ()], lambda x,y: x/y)
@unittest.skipIf(getenv("AMD_LLVM", 0), "AMD with LLVM backend generate rcp in FP division causes trunc/floor errors")
def test_div_rounding_mode(self):
for denominator in [-10, -5, -3, -2, -1, 1, 2, 3, 5, 10]:
# int numerator

View file

@ -120,9 +120,8 @@ class TestRandomness(unittest.TestCase):
0.3108327388763428, 0.09639489650726318, 0.004686474800109863, 0.8435229063034058, 0.824237585067749,
0.5873836278915405, 0.4232727289199829, 0.2530076503753662, 0.40300023555755615, 0.03966474533081055,
0.27904558181762695, 0.9150195121765137, 0.48057758808135986, 0.23821306228637695, 0.7676635980606079], dtype=np.float32)
r = Tensor.rand(20).numpy()
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
# next 20, np.arange(20, 40, dtype=np.uint32)
jr = np.array([0.7444133758544922, 0.7713677883148193, 0.8233780860900879, 0.43871235847473145, 0.517757773399353,
@ -130,7 +129,7 @@ class TestRandomness(unittest.TestCase):
0.28920769691467285, 0.017063498497009277, 0.2627382278442383, 0.9525482654571533, 0.9351049661636353,
0.43904995918273926, 0.043945908546447754, 0.6616791486740112, 0.6667773723602295, 0.5228077173233032], dtype=np.float32)
r = Tensor.rand(20).numpy()
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
# next 10, np.arange(40, 50, dtype=np.uint32)
jr = np.array([0.9614430665969849, 0.059279561042785645, 0.01909029483795166, 0.47882091999053955, 0.9677121639251709,
@ -138,7 +137,7 @@ class TestRandomness(unittest.TestCase):
r = Tensor.rand(10).numpy()
# TODO: this failed because increment happened before _threefry_random_bits
with self.assertRaises(AssertionError):
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
def test_threefry_tensors_cnt(self):

View file

@ -329,11 +329,11 @@ def simplify_valid(valid:UOp) -> UOp|None:
# ***** threefry *****
def threefry2x32(x: UOp, key: UOp):
# split x into two uint32, since x in a uint64
# split x and key from uint64 to two uint32
x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
xr = [x0 + ks[-1], x1 + ks[0]]
for i in range(5):

View file

@ -453,6 +453,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# display the final graph
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
# final toposort (bfs)
children: dict[UOp, list[UOp]] = {}
@ -473,8 +474,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# TODO: move this to create_kernels
k = fix_kernel_ast(u.src[1], var_vals)
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
# increment the refcount of the target buf (this is required by the JIT and memory planner) TODO: this does not belong here
k.src[0].buffer.ref(1)
for x in children.get(u, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)

View file

@ -536,6 +536,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
from tinygrad.device import Buffer
assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
ret.ref(1)
return ret
@property
def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER and self.buffer.is_allocated() else None

View file

@ -1,17 +1,17 @@
from typing import cast
import math, struct, sys
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.renderer.cstyle import ClangRenderer, AMDRenderer
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
from tinygrad.helpers import prod, AMX
def ldt(dt:DType):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
if isinstance(dt, PtrDType): return ldt(dt.base) + (" addrspace(3)*" if dt.local else "*")
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
def lconst(x, dtype:DType):
if dtype in dtypes.floats:
@ -63,7 +63,8 @@ base_rewrite = PatternMatcher([
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
(UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
# GEP/VECTORIZE/CAST for float4 support
@ -113,7 +114,7 @@ class LLVMRenderer(Renderer):
supports_float4 = True
has_local = False
has_shared = False
global_max = None
global_max: tuple[int, ...] | None = None
string_rewrite = base_rewrite
if AMX: tensor_cores = ClangRenderer.amx_tc
@ -126,6 +127,12 @@ class LLVMRenderer(Renderer):
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
# rewrite bf16 CAST(LOAD) to CAST(BITCAST)
(UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
# copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
# copied from cstyle.py, add float intermediate casting
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
])
def render(self, uops: list[UOp]) -> str:
@ -135,6 +142,7 @@ class LLVMRenderer(Renderer):
end_lines: dict[str, None] = {}
vc = -1
local_args: list[str] = []
acc_to_assign: dict[UOp, UOp] = {}
for u in uops:
if u.op is Ops.ASSIGN: # prealloc all assigns
@ -158,6 +166,10 @@ class LLVMRenderer(Renderer):
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
# NOTE: MallocAllocator promises 0x20 alignment
args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
elif u.op == Ops.DEFINE_LOCAL:
r[u] = f"@local_{u.arg}"
assert isinstance(u.dtype, PtrDType)
local_args.append(f"{r[u]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
@ -182,7 +194,7 @@ class LLVMRenderer(Renderer):
r[x] = f"%acc{vc}"
# output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
return f'''\
prg = f'''\
define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
{chr(10).join(kernel)}
ret void
@ -190,3 +202,19 @@ define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(ar
{chr(10).join(end_lines.keys())}
attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
'''
return prg if len(local_args) == 0 else "\n".join(local_args)+f"\n{prg}"
# text emitted for a barrier: a workgroup-scope release fence, the s.barrier intrinsic call, then an acquire fence
barrier = ('fence syncscope("workgroup") release\n'
           'tail call void @llvm.amdgcn.s.barrier()\n'
           'fence syncscope("workgroup") acquire\n')

def _amdgcn_id_call(kind):
  """Build a function mapping an axis index (0, 1, 2) to the text of the amdgcn ``<kind>.id.{x,y,z}`` intrinsic call."""
  # chr(120) is 'x', so indices 0/1/2 select the x/y/z variants
  return lambda x: f"tail call i32 @llvm.amdgcn.{kind}.id.{chr(120+int(x))}()"

# "g" selects the workgroup id intrinsic, "l" the workitem id intrinsic
code_for_workitem = {"g": _amdgcn_id_call("workgroup"), "l": _amdgcn_id_call("workitem")}
class AMDLLVMRenderer(LLVMRenderer):
  """LLVMRenderer variant for the "AMD" device.

  Enables local and shared memory, takes its shared/global size limits from
  AMDRenderer, uses the ``amdgpu_kernel`` ABI, and extends the base string
  rewrite rules with rules for SPECIAL and BARRIER ops.
  """
  device = "AMD"
  abi = "amdgpu_kernel"
  has_local = True
  has_shared = True
  # size limits are shared with the C-style AMD renderer
  shared_max = AMDRenderer.shared_max
  global_max = AMDRenderer.global_max
  string_rewrite = base_rewrite + PatternMatcher([
    # first char of x.arg[0] picks the table entry, last char is passed as the axis
    # (presumably names of the form "gidx0"/"lidx1" -- confirm against where SPECIAL is created)
    (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = {code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
    (UPat(Ops.BARRIER), lambda ctx: barrier),
  ])

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -8,9 +8,10 @@ from tinygrad.ops import sint
from tinygrad.device import Compiled, ProfileEvent, BufferSpec, CPUProgram, PROFILE
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, mv_address, DEBUG, OSX
from tinygrad.renderer.cstyle import AMDRenderer
from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.runtime.autogen import kfd, hsa, amd_gpu, libc, pci, vfio, sqtt
from tinygrad.runtime.autogen.am import am, gc_11_0_0
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.am.amdev import AMDev, AMMapping
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
@ -706,7 +707,8 @@ class AMDDevice(HCQCompiled):
self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x800000)
super().__init__(device, AMDAllocator(self), AMDRenderer(self.arch), HIPCompiler(self.arch), functools.partial(AMDProgram, self),
super().__init__(device, AMDAllocator(self), AMDLLVMRenderer() if getenv("AMD_LLVM", 0) else AMDRenderer(self.arch),
AMDLLVMCompiler(self.arch) if getenv("AMD_LLVM", 0) else HIPCompiler(self.arch), functools.partial(AMDProgram, self),
AMDSignal, AMDComputeQueue, AMDCopyQueue)
# Scratch setup

View file

@ -0,0 +1,17 @@
from tinygrad.device import Compiled, Compiler, Renderer, Allocator
class NullRenderer(Renderer):
  """Renderer that produces no code: ``render`` always returns the empty string."""
  def render(self, uops:list) -> str:
    return ""
class NullProgram:
  """Program stand-in that ignores its name, binary and launch arguments."""
  def __init__(self, name:str, lib:bytes):
    # nothing is stored or loaded
    pass

  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
    # always reports the same fixed value, whatever the arguments
    return 1e-4
class NullAllocator(Allocator):
  """Allocator whose allocation returns a placeholder and whose copies do nothing."""
  def _alloc(self, size, options):
    # the same placeholder string is returned for every request
    return "null"

  def _copyin(self, dest, src:memoryview):
    pass

  def _copyout(self, dest:memoryview, src):
    pass
class NullDevice(Compiled):
  """Compiled device assembled from the Null allocator, renderer and program plus a base Compiler."""
  def __init__(self, device:str):
    allocator, renderer, compiler = NullAllocator(), NullRenderer(), Compiler()
    super().__init__(device, allocator, renderer, compiler, NullProgram)

View file

@ -1,35 +1,26 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, importlib
import ctypes, collections, time, dataclasses, functools, pathlib, fcntl, os, importlib
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0
from tinygrad.runtime.support.amd import AMDRegBase, collect_registers
from tinygrad.runtime.support.allocator import TLSFAllocator
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
AM_DEBUG = getenv("AM_DEBUG", 0)
@dataclasses.dataclass(frozen=True)
class AMRegister:
adev:AMDev; reg_off:int; reg_fields:dict[str, tuple[int, int]] # noqa: E702
class AMRegister(AMDRegBase):
adev:AMDev; hwip:int # noqa: E702
def _parse_kwargs(self, **kwargs):
mask, values = 0xffffffff, 0
for k, v in kwargs.items():
if k not in self.reg_fields: raise ValueError(f"Unknown register field: {k}. {self.reg_fields.keys()}")
m, s = self.reg_fields[k]
if v & (m>>s) != v: raise ValueError(f"Value {v} for {k} is out of range {m=} {s=}")
mask &= ~m
values |= v << s
return mask, values
@property
def addr(self): return self.adev.regs_offset[self.hwip][0][self.segment] + self.offset
def build(self, **kwargs) -> int: return self._parse_kwargs(**kwargs)[1]
def read(self): return self.adev.rreg(self.addr)
def read_bitfields(self) -> dict[str, int]: return self.decode(self.read())
def update(self, **kwargs): self.write(value=self.read(), **kwargs)
def write(self, _am_val:int=0, **kwargs): self.adev.wreg(self.addr, _am_val | self.encode(**kwargs))
def write(self, value=0, **kwargs):
mask, values = self._parse_kwargs(**kwargs)
self.adev.wreg(self.reg_off, (value & mask) | values)
def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
def update(self, **kwargs): self.write(self.encode(**{**self.read_bitfields(), **kwargs}))
class AMFirmware:
def __init__(self, adev):
@ -327,8 +318,6 @@ class AMDev:
def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
def reg(self, reg:str) -> AMRegister: return self.__dict__[reg]
def rreg(self, reg:int) -> int:
@ -358,7 +347,7 @@ class AMDev:
for _ in range(timeout):
if ((rval:=reg.read()) & mask) == value: return rval
time.sleep(0.001)
raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
raise RuntimeError(f'wait_reg timeout reg=0x{reg.addr:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
def _run_discovery(self):
# NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
@ -403,14 +392,5 @@ class AMDev:
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP)),
("HDP", self._ip_module("hdp", am.HDP_HWIP))]
for base, module in mods:
rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
reg_fields: dict[str, dict[str, tuple]] = collections.defaultdict(dict)
for k, val in module.__dict__.items():
if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names):
reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1))
for k, regval in module.__dict__.items():
if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None:
setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {})))
for ip, module in mods:
self.__dict__.update(collect_registers(module, cls=functools.partial(AMRegister, adev=self, hwip=getattr(am, f"{ip}_HWIP"))))

View file

@ -1,4 +1,4 @@
import ctypes, time, contextlib
import ctypes, time, contextlib, importlib
from typing import Literal
from tinygrad.runtime.autogen.am import am
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
@ -11,6 +11,8 @@ class AM_IP:
def set_clockgating_state(self): pass # Set clockgating state for this IP
class AM_SOC(AM_IP):
def init_sw(self): self.module = importlib.import_module("tinygrad.runtime.autogen.am.soc21")
def init_hw(self):
self.adev.regRCC_DEV0_EPF2_STRAP2.update(strap_no_soft_reset_dev0_f2=0x0)
self.adev.regRCC_DEV0_EPF0_RCC_DOORBELL_APER_EN.write(0x1)
@ -89,7 +91,7 @@ class AM_GMC(AM_IP):
# Init TLB and cache
self.adev.reg(f"reg{ip}MC_VM_MX_L1_TLB_CNTL").update(enable_l1_tlb=1, system_access_mode=3, enable_advanced_driver_model=1,
system_aperture_unmapped_access=0, eco_bits=0, mtype=am.MTYPE_UC)
system_aperture_unmapped_access=0, eco_bits=0, mtype=self.adev.soc.module.MTYPE_UC)
self.adev.reg(f"reg{ip}VM_L2_CNTL").update(enable_l2_cache=1, enable_l2_fragment_processing=0, enable_default_page_out_to_system_memory=1,
l2_pde0_cache_tag_generation_mode=0, pde_fault_classification=0, context1_identity_access_mode=1, identity_mode_fragment_size=0)
@ -110,7 +112,7 @@ class AM_GMC(AM_IP):
def get_pte_flags(self, pte_lv, is_table, frag, uncached, system, snooped, valid, extra=0):
extra |= (am.AMDGPU_PTE_SYSTEM * system) | (am.AMDGPU_PTE_SNOOPED * snooped) | (am.AMDGPU_PTE_VALID * valid) | am.AMDGPU_PTE_FRAG(frag)
extra |= am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC if uncached else 0)
extra |= am.AMDGPU_PTE_MTYPE_NV10(0, self.adev.soc.module.MTYPE_UC if uncached else 0)
if not is_table: extra |= (am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE)
return extra | (am.AMDGPU_PDE_PTE if not is_table and pte_lv != am.AMDGPU_VM_PTB else 0)
def is_pte_huge_page(self, pte): return pte & am.AMDGPU_PDE_PTE
@ -172,7 +174,7 @@ class AM_SMU(AM_IP):
class AM_GFX(AM_IP):
def init_hw(self):
# Wait for RLC autoload to complete
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read(bootload_complete=1) != 0: pass
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass
self._config_gfx_rs64()
self.adev.gmc.init_hub("GC")
@ -187,8 +189,8 @@ class AM_GFX(AM_IP):
self.adev.regGRBM_CNTL.update(read_timeout=0xff)
for i in range(0, 16):
self._grbm_select(vmid=i)
self.adev.regSH_MEM_CONFIG.write(address_mode=am.SH_MEM_ADDRESS_MODE_64, alignment_mode=am.SH_MEM_ALIGNMENT_MODE_UNALIGNED,
initial_inst_prefetch=3)
self.adev.regSH_MEM_CONFIG.write(address_mode=self.adev.soc.module.SH_MEM_ADDRESS_MODE_64,
alignment_mode=self.adev.soc.module.SH_MEM_ALIGNMENT_MODE_UNALIGNED, initial_inst_prefetch=3)
# Configure apertures:
# LDS: 0x10000000'00000000 - 0x10000001'00000000 (4GB)
@ -218,17 +220,17 @@ class AM_GFX(AM_IP):
mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True)
mqd_struct = am.struct_v11_compute_mqd(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.build(preload_size=0x55, preload_req=1),
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.encode(preload_size=0x55, preload_req=1),
cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8),
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0,
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0,
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.build(eop_size=(eop_size//4).bit_length()-2))
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2))
# Copy mqd into memory
ctypes.memmove(self.adev.paddr2cpu(mqd.paddrs[0][0]), ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct))
@ -237,7 +239,7 @@ class AM_GFX(AM_IP):
self._grbm_select(me=1, pipe=pipe, queue=queue)
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.reg_off, self.adev.regCP_HQD_PQ_WPTR_HI.reg_off + 1)):
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr, self.adev.regCP_HQD_PQ_WPTR_HI.addr + 1)):
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
self.adev.regCP_HQD_ACTIVE.write(0x1)
@ -313,7 +315,7 @@ class AM_IH(AM_IP):
_, rwptr_vm, suf, _ = self.rings[0]
wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0]
if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
if self.adev.reg(f"regIH_RB_WPTR{suf}").read_bitfields()['rb_overflow']:
self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)

View file

@ -0,0 +1,26 @@
import functools
from collections import defaultdict
from dataclasses import dataclass
from math import log2
from tinygrad.helpers import getbits
@dataclass(frozen=True)
class AMDRegBase:
  """Immutable description of a single register and the bit layout of its fields.

  `fields` maps a field name to a `(start, end)` pair: `encode` shifts each value left by
  `start`, and `decode` passes both numbers to `getbits`.
  """
  name: str
  offset: int
  segment: int
  fields: dict[str, tuple[int, int]]

  def encode(self, **kwargs) -> int:
    """Return the bitwise OR of every keyword value moved to its field's start bit.

    Values are not masked to the field width. An unknown field name raises `KeyError`.
    """
    shifted = [value << self.fields[field][0] for field, value in kwargs.items()]
    return functools.reduce(int.__or__, shifted, 0)

  def decode(self, val: int) -> dict:
    """Return `{field_name: getbits(val, start, end)}` for every field of this register."""
    decoded = {}
    for field, (start, end) in self.fields.items():
      decoded[field] = getbits(val, start, end)
    return decoded
def collect_registers(module, cls=AMDRegBase) -> dict[str, AMDRegBase]:
  """Build a register table from the constants stored in `module.__dict__`.

  Three kinds of names are recognised:
    * `reg<NAME>` / `mm<NAME>`: the offset of a register.
    * `reg<NAME>_BASE_IDX` / `mm<NAME>_BASE_IDX`: the segment of that register.
    * `<NAME>__<FIELD>_MASK`: the bit mask of one field of register `<NAME>`.

  Args:
    module: object whose `__dict__` holds the constants.
    cls: factory called as `cls(name=..., offset=..., segment=..., fields=...)` per register.

  Returns:
    `{register_name: cls(...)}` keyed by the full `reg`/`mm`-prefixed name, for every register
    that has both an offset and a `_BASE_IDX` entry. Field names are lower-cased and map to
    `(lowest_set_bit, highest_set_bit)` of their mask.

  Raises:
    ValueError: if a field mask is not positive, or a mask name does not split into exactly
      one register part and one field part around `__`.
  """
  def _split_name(name: str) -> tuple[str, str]:
    # split at the first upper-case character: "regCP_STAT" -> ("reg", "CP_STAT")
    pos = next((i for i, c in enumerate(name) if c.isupper()), len(name))
    return name[:pos], name[pos:]

  offsets: dict[str, int] = {}
  bases: dict[str, int] = {}
  fields: defaultdict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
  for key, value in module.__dict__.items():
    if _split_name(key)[0] in {'reg', 'mm'}:
      if key.endswith('_BASE_IDX'): bases[key[:-len('_BASE_IDX')]] = value
      else: offsets[key] = value
    if '__' in key and key.endswith('_MASK'):
      reg_name, reg_field_name = key[:-len('_MASK')].split('__')
      if value <= 0: raise ValueError(f"field mask {key} must be a positive integer, got {value!r}")
      # exact integer bit positions: float log2 rounds up for wide masks (e.g. 2**49-1 -> 49 instead of 48)
      fields[reg_name][reg_field_name.lower()] = ((value & -value).bit_length() - 1, value.bit_length() - 1)

  # NOTE: Some registers like regGFX_IMU_FUSESTRAP in gc_11_0_0 are missing base idx, just skip them
  return {reg: cls(name=reg, offset=off, segment=bases[reg], fields=fields[_split_name(reg)[1]])
          for reg, off in offsets.items() if reg in bases}

View file

@ -501,15 +501,12 @@ class Tensor(SimpleMathTrait):
if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
_device = device = Device.canonicalize(device)
device = Device.canonicalize(device)
# if shape has 0, return zero tensor
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
num = ceildiv(numel * dtype.itemsize, 4)
# when using MOCKGPU and NV generate rand on CPU
if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"
# generate per device seeds and rng counter if we haven't seen this device yet
if device not in Tensor._device_seeds:
Tensor._device_seeds[device] = Tensor(
@ -532,12 +529,7 @@ class Tensor(SimpleMathTrait):
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
# bitcast back to the original dtype and reshape
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
# move back to the original device if we were using MOCKGPU
if getenv("MOCKGPU") and _device: out = out.to(_device)
out.requires_grad = kwargs.get("requires_grad")
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad"))
return out.contiguous() if contiguous else out
# ***** creation helper functions *****

View file

@ -3,7 +3,7 @@ GRAPH=1
JITGRAPH=1 (this restricts the graph...no need if we can select the schedules)
GRAPHUOPS=1
most uses of DEBUG >= 3
https://tiny-tools-client.vercel.app/
tiny-tools
and a viewer for:
SAVE_SCHEDULE=1

View file

@ -228,6 +228,7 @@
<g id="render">
<g id="edges"></g>
<g id="nodes"></g>
<g id="bars"></g>
</g>
</svg>
<button class="btn" id="zoom-to-fit-btn">Zoom to fit</button>
@ -377,7 +378,7 @@
};
}
if (ret.length === 0) return;
renderGraph(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || []);
renderGraph(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], kernel.name);
// ***** RHS metadata
const metadata = document.querySelector(".container.metadata");
metadata.innerHTML = "";

View file

@ -9,13 +9,18 @@ function intersectRect(r1, r2) {
}
const allWorkers = [];
window.renderGraph = function(graph, additions) {
window.renderGraph = function(graph, additions, name) {
while (allWorkers.length) {
const { worker, timeout } = allWorkers.pop();
worker.terminate();
clearTimeout(timeout);
}
if (name === "View Memory Graph") {
return renderMemoryGraph(graph);
}
d3.select("#bars").html("");
// ** start calculating the new layout (non-blocking)
worker = new Worker("/lib/worker.js");
const progressMessage = document.querySelector(".progress-message");
@ -58,3 +63,130 @@ window.renderGraph = function(graph, additions) {
.attr("markerWidth", 6).attr("markerHeight", 6).attr("orient", "auto").append("path").attr("d", "M0,-5L10,0L0,5").attr("fill", "#4a4b57");
};
}
// size in bytes of one element, keyed by dtype name
// declared with const: a bare assignment creates an implicit global and throws in strict mode / ES modules
const DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4,
                    "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8};
// Parse the label of a buffer node into { nbytes, dtype, device, num }.
// The label is split on newlines: line 0 is ignored, line 1 is multiplied by the dtype's byte size
// (presumably the element count), line 2 contains "dtypes.<name>", and lines 3 and 4 are
// "<word> <value>" pairs holding the device and the buffer number.
function getBuffer(e) {
  const fields = e.label.split("\n");
  const size = fields[1];
  const dtype = fields[2];
  const device = fields[3];
  const num = fields[4];
  const itemsize = DTYPE_SIZE[dtype.split("dtypes.")[1]];
  const deviceName = device.split(" ")[1];
  const bufferNum = parseInt(num.split(" ")[1]);
  return { nbytes: size*itemsize, dtype, device: deviceName, num: bufferNum };
}
// Format a count with its noun: the singular `name` when num is exactly 1, otherwise `alt`
// if one is given, else `name` with an "s" appended.
function pluralize(num, name, alt=null) {
  if (num === 1) return `${num} ${name}`;
  const plural = alt ?? name+'s';
  return `${num} ${plural}`;
}
// Draw the memory timeline of a kernel graph into the #bars group: one polygon per buffer, with
// x = timestep and y = bytes in use below the buffer.
// `graph` maps a node id to { label, src }. Every node whose label starts with "ASSIGN" is treated
// as a write to the buffer src[0]; the inputs are the src list of node src[1] (presumably the kernel).
function renderMemoryGraph(graph) {
  // ** construct alloc/free traces
  // we can map reads/writes from the kernel graph
  const actions = [];
  const children = new Map(); // {buffer: [...assign]}
  for (const v of Object.values(graph)) {
    if (!v.label.startsWith("ASSIGN")) continue;
    actions.push({ op: "write", buffer: v.src[0] });
    for (const ks of graph[v.src[1]].src) {
      const node = graph[ks];
      // an input produced by another ASSIGN stands for the buffer that ASSIGN writes
      const s = node.label.startsWith("ASSIGN") ? node.src[0] : ks;
      if (!children.has(s)) children.set(s, []);
      children.get(s).push(v);
      if (s !== v.src[0]) actions.push({ op: "read", buffer: s });
    }
  }
  // index of the last action touching each buffer, so "is this the final use?" is a single lookup
  const lastUse = new Map();
  actions.forEach((a, i) => lastUse.set(a.buffer, i));
  const prealloc = new Set();
  const allocated = new Set();
  const traces = [];
  actions.forEach((a, i) => {
    if (a.op === "write") {
      // a buffer is allocated immediately before the first write; later writes reuse that allocation,
      // and a buffer that was read before any write already exists outside this graph
      // TODO: we don't know the buffer is preallocated if there's only an assign in the graph
      if (allocated.has(a.buffer) || prealloc.has(a.buffer)) return;
      allocated.add(a.buffer);
      traces.push({ type: "alloc", buffer: a.buffer });
    }
    else if (!allocated.has(a.buffer)) {
      prealloc.add(a.buffer);
    }
    else if (lastUse.get(a.buffer) === i) {
      // the final action on an allocated buffer is a read: nothing uses it afterwards
      traces.push({ type: "free", buffer: a.buffer });
    }
  });
  // ** get coordinates and layout for each buffer
  const ret = {};
  let timestep = 0; // x
  let memUsed = 0; // y
  for (const id of prealloc) {
    const buf = getBuffer(graph[id]);
    ret[id] = { x: [timestep], y: [memUsed], buf, id };
    memUsed += buf.nbytes;
  }
  let peak = memUsed;
  const liveBufs = [...prealloc];
  for (const t of traces) {
    const buf = getBuffer(graph[t.buffer]);
    if (t.type === "alloc") {
      liveBufs.push(t.buffer);
      ret[t.buffer] = { x: [timestep], y: [memUsed], buf, id: t.buffer };
      memUsed += buf.nbytes;
      peak = Math.max(memUsed, peak);
      timestep += 1;
      continue;
    }
    // free: close the buffer's polygon and slide every buffer stacked above it down by its size
    const idx = liveBufs.lastIndexOf(t.buffer);
    if (idx === -1) continue;
    memUsed -= buf.nbytes;
    timestep += 1;
    const removed = ret[liveBufs.splice(idx, 1)[0]];
    removed.x.push(timestep);
    removed.y.push(removed.y.at(-1));
    for (let j=idx; j<liveBufs.length; j++) {
      const b = ret[liveBufs[j]];
      b.x.push(timestep, timestep);
      b.y.push(b.y.at(-1), b.y.at(-1)-buf.nbytes);
    }
  }
  // buffers that are never freed stay alive until the last timestep
  for (const id of liveBufs) {
    const b = ret[id];
    b.x.push(timestep);
    b.y.push(b.y.at(-1));
  }
  // ** render traces
  const render = d3.select("#bars");
  // drop the previous memory graph, otherwise consecutive renders stack on top of each other
  render.html("");
  const yscale = d3.scaleLinear().domain([0, peak]).range([576, 0]);
  const xscale = d3.scaleLinear().domain([0, timestep]).range([0, 1024]);
  const axesGroup = render.append("g").attr("id", "axes");
  const nbytes_format = (d) => d3.format(".3~s")(d)+"B";
  axesGroup.append("g").call(d3.axisLeft(yscale).tickFormat(nbytes_format));
  axesGroup.append("g").attr("transform", `translate(0, ${yscale.range()[0]})`).call(d3.axisBottom(xscale).tickFormat(() => ""));
  const polygonGroup = render.append("g").attr("id", "polygons");
  const colors = ["7aa2f7", "ff9e64", "f7768e", "2ac3de", "7dcfff", "1abc9c", "9ece6a", "e0af68", "bb9af7", "9d7cd8", "ff007c"];
  polygonGroup.selectAll("polygon").data(Object.values(ret)).join("polygon").attr("points", (d) => {
    // bottom edge left-to-right, then top edge right-to-left
    const xs = d.x.map(t => xscale(t));
    const y1 = d.y.map(t => yscale(t));
    const y2 = d.y.map(t => yscale(t+d.buf.nbytes));
    const p0 = xs.map((x, i) => `${x},${y1[i]}`);
    const p1 = xs.map((x, i) => `${x},${y2[i]}`).reverse();
    return `${p0.join(' ')} ${p1.join(' ')}`;
  }).attr("fill", d => `#${colors[d.buf.num % colors.length]}`).on("mouseover", (e, { id, buf, x }) => {
    // highlight the polygon and show the buffer's details in the metadata panel
    d3.select(e.currentTarget).attr("stroke", "rgba(26, 27, 38, 0.8)").attr("stroke-width", 0.8);
    const metadata = document.querySelector(".container.metadata");
    document.getElementById("current-buf")?.remove();
    const { num, dtype, nbytes, ...rest } = buf;
    let label = `<BUFFER n${num} ${dtype} ${nbytes_format(nbytes)}>\nalive for ${pluralize(x[x.length-1]-x[0], 'timestep')}`;
    label += '\n'+Object.entries(rest).map(([k, v]) => `${k}=${v}`).join('\n');
    const buf_children = children.get(id);
    if (buf_children) {
      label += `\n${pluralize(buf_children.length, 'child', 'children')}\n`;
      label += buf_children.map((c,i) => `[${i+1}] `+graph[c.src[1]].label.split("\n")[1]).join("\n");
    }
    metadata.appendChild(Object.assign(document.createElement("pre"), { innerText: label, id: "current-buf", className: "wrap" }));
  }).on("mouseout", (e, _) => {
    d3.select(e.currentTarget).attr("stroke", null).attr("stroke-width", null);
    document.getElementById("current-buf")?.remove()
  });
  // TODO: add the toposort graph here
  document.querySelector(".progress-message").style.display = "none";
  d3.select("#nodes").html("");
  d3.select("#edges").html("");
}

View file

@ -21,7 +21,7 @@ onmessage = (e) => {
const g = new dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
if (additions.length !== 0) g.setNode("addition", {label: "", style: "fill: rgba(26, 27, 38, 0.5); stroke: none;", padding:0});
for (const [k, [label, src, color]] of Object.entries(graph)) {
for (const [k, {label, src, color}] of Object.entries(graph)) {
// adjust node dims by label size + add padding
const [labelWidth, labelHeight] = getTextDims(label);
g.setNode(k, {label, color, width:labelWidth+NODE_PADDING*2, height:labelHeight+NODE_PADDING*2, padding:NODE_PADDING});

View file

@ -49,10 +49,9 @@ class GraphRewriteDetails(TypedDict):
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
def uop_to_json(x:UOp) -> dict[int, dict]:
assert isinstance(x, UOp)
# NOTE: this is [id, [label, src_ids, color]]
graph: dict[int, tuple[str, list[int], str]] = {}
graph: dict[int, dict] = {}
excluded: set[UOp] = set()
for u in (toposort:=x.toposort):
# always exclude DEVICE/CONST/UNIQUE
@ -72,7 +71,7 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
if x in excluded:
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
else: label += f"\n{x.op.name}{idx} {x.arg}"
graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff"))
graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")}
return graph
def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: