mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into dsp_search
This commit is contained in:
commit
60cbfe4222
27 changed files with 39604 additions and 39308 deletions
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
|
|
@ -574,7 +574,7 @@ jobs:
|
|||
run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/test_hcq.py test/external/external_test_am.py --durations=20
|
||||
- name: Run pytest (amd with llvm backend)
|
||||
if: matrix.backend=='amd'
|
||||
run: python -m pytest -n=auto test/test_amd_llvm.py --durations=20
|
||||
run: AMD_LLVM=1 python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/test_hcq.py test/external/external_test_am.py test/test_amd_llvm.py --durations=20
|
||||
- name: Run TRANSCENDENTAL math
|
||||
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
|
||||
- name: Run process replay tests
|
||||
|
|
@ -630,8 +630,10 @@ jobs:
|
|||
env:
|
||||
MOCKGPU: 1
|
||||
AMD: 1
|
||||
AMD_LLVM: 1
|
||||
FORWARD_ONLY: 1
|
||||
run: |
|
||||
python -m pytest -n=auto test/test_amd_llvm.py --durations=20
|
||||
python -m pytest -n=auto test/test_hcq.py test/test_tiny.py test/test_amd_llvm.py --durations=20
|
||||
- name: Run pytest (ptx)
|
||||
env:
|
||||
MOCKGPU: 1
|
||||
|
|
|
|||
|
|
@ -297,14 +297,21 @@ generate_am() {
|
|||
extra/amdpci/headers/amdgpu_vm.h \
|
||||
extra/amdpci/headers/discovery.h \
|
||||
extra/amdpci/headers/amdgpu_ucode.h \
|
||||
extra/amdpci/headers/soc21_enum.h \
|
||||
extra/amdpci/headers/psp_gfx_if.h \
|
||||
extra/amdpci/headers/amdgpu_psp.h \
|
||||
extra/amdpci/headers/amdgpu_irq.h \
|
||||
extra/amdpci/headers/amdgpu_doorbell.h \
|
||||
extra/amdpci/headers/soc15_ih_clientid.h \
|
||||
--clang-args="-include stdint.h" \
|
||||
-o $BASE/am/am.py
|
||||
fixup $BASE/am/am.py
|
||||
sed -i "s\(int64_t)\ \g" $BASE/am/am.py
|
||||
sed -i "s\AMDGPU_PTE_MTYPE_VG10(2)\AMDGPU_PTE_MTYPE_VG10(0, 2)\g" $BASE/am/am.py # incorrect parsing (TODO: remove when clang2py is gone).
|
||||
|
||||
clang2py -k cdefstum \
|
||||
extra/amdpci/headers/soc21_enum.h \
|
||||
-o $BASE/am/soc21.py
|
||||
fixup $BASE/am/soc21.py
|
||||
|
||||
clang2py -k cdefstum \
|
||||
extra/amdpci/headers/mp_13_0_0_offset.h \
|
||||
|
|
|
|||
|
|
@ -79,6 +79,8 @@ class AMSMI(AMDev):
|
|||
self.psp:AM_PSP = AM_PSP(self)
|
||||
self.smu:AM_SMU = AM_SMU(self)
|
||||
|
||||
for ip in [self.soc, self.gmc, self.ih, self.psp, self.smu]: ip.init_sw()
|
||||
|
||||
def read_pci_state(self):
  """Return the text of this device's sysfs `power_state` file with surrounding whitespace removed.

  The path is built from `self.pcibus`; any error raised by `open` propagates to the caller.
  """
  state_path = f"/sys/bus/pci/devices/{self.pcibus}/power_state"
  with open(state_path, "r") as state_file:
    raw = state_file.read()
  # one strip() of both ends yields the same string as the strip-then-rstrip chain
  return raw.strip()
|
||||
|
||||
|
|
|
|||
|
|
@ -33,21 +33,46 @@ torch._register_device_module("tiny", TinyBackend())
|
|||
torch.utils.generate_methods_for_privateuse1_backend()
|
||||
aten = torch.ops.aten
|
||||
|
||||
# track view relationships for in place operations
|
||||
def is_view(tensor: Tensor):
  """Return True when `tensor` carries a `_view_base` attribute, False otherwise."""
  try:
    tensor._view_base
  except AttributeError:
    return False
  return True
|
||||
def canonical_base(view: Tensor):
  """Return `view._view_base` when that attribute exists, otherwise `view` itself."""
  try:
    return view._view_base
  except AttributeError:
    # no recorded base: the object is treated as its own base
    return view
|
||||
def derived_views(base: Tensor): return [t for tref in getattr(base, "_views", set()) if (t:=tref()) is not None]
|
||||
def wrap_view_op(fn):
  """Build a wrapper around `fn` that records which tensor each result is a view of.

  The returned function passes every `torch.Tensor` argument (positional or keyword) through
  `unwrap`, calls `fn`, stores the canonical base of the first positional argument on the result
  as `_view_base`, adds a weak reference to the result to that base's `_views` set, and returns
  `wrap(result)`.
  """
  def _wrap(*args, **kwargs):
    raw_args = [unwrap(a) if isinstance(a, torch.Tensor) else a for a in args]
    raw_kwargs = {}
    for key, val in kwargs.items():
      raw_kwargs[key] = unwrap(val) if isinstance(val, torch.Tensor) else val
    result = fn(*raw_args, **raw_kwargs)
    # the first positional argument is the tensor being viewed
    root = canonical_base(raw_args[0])
    result._view_base = root
    if not hasattr(root, "_views"):
      root._views = set()
    # weak reference, so the base's bookkeeping does not keep the view alive
    root._views.add(weakref.ref(result))
    return wrap(result)
  return _wrap
|
||||
|
||||
view_ops = {
|
||||
"aten.view": Tensor.reshape,
|
||||
"aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
|
||||
"aten.view.dtype": lambda self,dtype: self.bitcast(_from_torch_dtype(dtype)),
|
||||
"aten.expand": Tensor.expand,
|
||||
"aten.t": Tensor.transpose,
|
||||
"aten.squeeze.dim": Tensor.squeeze,
|
||||
"aten.unsqueeze": Tensor.unsqueeze,
|
||||
"aten.repeat": Tensor.repeat,
|
||||
"aten.detach": Tensor.detach,
|
||||
}
|
||||
|
||||
for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))
|
||||
|
||||
# in place operations with views
|
||||
def is_view(self: torch.Tensor) -> bool: return getattr(self, "_base", None) is not None
|
||||
def realize_with_views(self: torch.Tensor, views: list[torch.Tensor]):
|
||||
assert self.is_tiny
|
||||
self = unwrap(self)
|
||||
def realize_with_views(self: Tensor, views: Tensor):
|
||||
if not self.lazydata.st.contiguous: raise ValueError("base of view must be contiguous") # TODO: support?
|
||||
self.replace(self.clone().realize())
|
||||
for v in views:
|
||||
v = unwrap(v)
|
||||
ret = self
|
||||
st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right?
|
||||
for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
|
||||
v.replace(ret)
|
||||
def maybe_realize_storage(self: torch.Tensor) -> bool:
|
||||
if realize:=is_view(self): realize_with_views(self._base, [self]) # TODO: other views could exist
|
||||
def maybe_realize_storage(self: Tensor) -> bool:
|
||||
if realize:=is_view(self): realize_with_views((base:=canonical_base(self)), derived_views(base))
|
||||
return realize
|
||||
def inplace_fn(outvars: str|list[str]):
|
||||
if type(outvars) is str: outvars = [outvars]
|
||||
|
|
@ -56,9 +81,10 @@ def inplace_fn(outvars: str|list[str]):
|
|||
def wrapper(*args, **kwargs):
|
||||
bound = sig.bind(*args, **kwargs)
|
||||
outs = [kwargs.get(v, bound.arguments.get(v)) for v in outvars]
|
||||
outs = [unwrap(o) if isinstance(o, torch.Tensor) else o for o in outs]
|
||||
realize = any(maybe_realize_storage(o) for o in outs)
|
||||
ret = fn(*args, **kwargs)
|
||||
if realize: Tensor.realize(*(unwrap(o) for o in outs))
|
||||
if realize: Tensor.realize(*(o for o in outs))
|
||||
return ret
|
||||
return wrapper
|
||||
return decorator
|
||||
|
|
@ -137,14 +163,15 @@ def cached_to_movement_ops(shape, st) -> list:
|
|||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
|
||||
@torch.library.impl("aten::as_strided", "privateuseone")
|
||||
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
|
||||
@wrap_view_op
|
||||
def as_strided(tensor:Tensor, size, stride, storage_offset=None):
|
||||
# TODO: this is heavyweight
|
||||
st = ShapeTracker((View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)))
|
||||
ret = unwrap(tensor)
|
||||
if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size))
|
||||
ret = tensor
|
||||
if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
|
||||
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
|
||||
for mo in cached_to_movement_ops(tuple(tensor.shape), st): ret = apply_mop(ret, mo)
|
||||
return wrap(ret)
|
||||
return ret
|
||||
|
||||
@torch.library.impl("aten::empty_strided", "privateuseone")
|
||||
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
|
||||
|
|
@ -236,6 +263,7 @@ for i,pre in enumerate(["", "bi", "tri"]):
|
|||
torch.library.impl(f"aten::_upsample_nearest_exact{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest-exact"))
|
||||
|
||||
@torch.library.impl("aten::scatter_add.out", "privateuseone")
|
||||
@inplace_fn("out")
|
||||
def scatter_add(self, dim, index, src, out):
|
||||
self, index, src, out = unwrap(self), unwrap(index), unwrap(src), unwrap(out)
|
||||
if self.shape == (): return wrap(out.assign(src))
|
||||
|
|
@ -243,7 +271,7 @@ def scatter_add(self, dim, index, src, out):
|
|||
|
||||
@torch.library.impl("aten::_copy_from", "privateuseone")
|
||||
def _copy_from(src: torch.Tensor, dest, non_blocking=False):
|
||||
realize = dest.is_tiny and maybe_realize_storage(dest)
|
||||
realize = dest.is_tiny and maybe_realize_storage(unwrap(dest))
|
||||
cast_dtype = _from_torch_dtype(dest.dtype)
|
||||
if src.is_tiny and dest.is_tiny:
|
||||
to_device = _from_torch_device(dest.device)
|
||||
|
|
@ -413,6 +441,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
|||
|
||||
# we add the "out" here
|
||||
def wrap_out(f):
|
||||
@inplace_fn("out")
|
||||
def _wrap_out(*args, **kwargs):
|
||||
out = kwargs.pop('out')
|
||||
assigned = f(*args, **kwargs)
|
||||
|
|
@ -425,8 +454,6 @@ def wrap_out(f):
|
|||
return _wrap_out
|
||||
|
||||
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
"aten.view": Tensor.reshape,
|
||||
"aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
|
||||
"aten.remainder.Scalar_Tensor": lambda x,y: x%y,
|
||||
"aten.floor_divide": lambda x,y: x//y,
|
||||
"aten.floor_divide_.Tensor": inplace_fn("x")(lambda x,y: x.assign(x//y)),
|
||||
|
|
@ -486,25 +513,19 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
|||
"aten.fill_.Tensor": Tensor.full, # TODO: looks wrong
|
||||
"aten.flip": Tensor.flip,
|
||||
"aten.scatter_reduce.two": Tensor.scatter_reduce,
|
||||
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True),
|
||||
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True), # TODO: inplace view op, here?
|
||||
"aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other,
|
||||
"aten.linspace": lambda start, stop, steps, dtype=None, **kwargs:
|
||||
Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})),
|
||||
"aten::view.dtype": lambda self, dtype: self.bitcast(_from_torch_dtype(dtype)),
|
||||
"aten.topk": Tensor.topk,
|
||||
"aten.constant_pad_nd": lambda self, padding, value=0.0: self.pad(padding, mode="constant", value=value),
|
||||
"aten.logsumexp": lambda self, axis, keepdim=False: self.logsumexp(axis[0], keepdim=keepdim),
|
||||
"aten.squeeze.dim": Tensor.squeeze,
|
||||
"aten.unsqueeze": Tensor.unsqueeze,
|
||||
"aten.roll": Tensor.roll,
|
||||
"aten.logcumsumexp": Tensor.logcumsumexp,
|
||||
"aten.repeat": Tensor.repeat,
|
||||
"aten.lerp.Tensor": Tensor.lerp,
|
||||
"aten.expand": Tensor.expand,
|
||||
"aten.ones_like": lambda self, dtype=None, device=None, **kwargs:
|
||||
self.ones_like(**{k: v for k, v in {"dtype": _from_torch_dtype(dtype) if dtype else None,
|
||||
"device": _from_torch_device(device) if device else None}.items() if v is not None}),
|
||||
"aten.t": Tensor.transpose,
|
||||
"aten.detach": Tensor.detach,
|
||||
"aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64))
|
||||
}}
|
||||
|
||||
|
|
@ -519,9 +540,7 @@ def wrap_fxn(k,f):
|
|||
if isinstance(out, Tensor): return wrap(out)
|
||||
elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
|
||||
else: raise RuntimeError(f"unknown output type {type(out)}")
|
||||
def nf2(*args, **kwargs):
|
||||
return inplace_fn("out")(nf)(*args, **kwargs) if "out" in kwargs else nf(*args, **kwargs)
|
||||
return nf2
|
||||
return nf
|
||||
|
||||
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
|
||||
|
||||
|
|
|
|||
|
|
@ -61,5 +61,11 @@ class TestTorchBackendInplace(unittest.TestCase):
|
|||
b[1:3,:] += 1
|
||||
np.testing.assert_equal(a.cpu().numpy(), [[0]*4,[1]*4,[1]*4,[0]*4])
|
||||
|
||||
def test_detach(self):
  """An in-place add on a detached tensor must show up in the tensor it was detached from."""
  base = torch.zeros(4)
  detached = base.detach()
  detached += torch.arange(4)
  actual = base.cpu()
  expected = torch.arange(4).cpu()
  np.testing.assert_array_equal(actual, expected)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ dtypes_bool = (dtypes.bool,)
|
|||
binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq]
|
||||
|
||||
# TODO: LLVM comparing with nan is incorrect
|
||||
if Device.DEFAULT == "LLVM":
|
||||
if Device.DEFAULT == "LLVM" or getenv("AMD_LLVM", 0):
|
||||
binary_operations.remove(operator.lt)
|
||||
|
||||
integer_binary_operations = binary_operations + [(Tensor.bitwise_xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),
|
||||
|
|
|
|||
|
|
@ -79,5 +79,35 @@ class TestGC(unittest.TestCase):
|
|||
print(inspect.getclosurevars(UOp.toposort.fget))
|
||||
raise AssertionError(f"never gced {[x for x in gc.get_objects() if isinstance(x, Buffer)]}")
|
||||
|
||||
# Follows one Buffer through its lifetime: checks `lb_refcount` and the change in
# `bufs_allocated()` after realize, after deleting the tensor's `lazydata`, and after
# dropping the last Python reference to the Buffer.
# NOTE(review): `bufs_allocated` is defined elsewhere in this test module (not shown here).
def test_buffer_refcount(self):
|
||||
init = bufs_allocated()
|
||||
a = Tensor.empty(10)
|
||||
# nothing is counted as allocated before realize
self.assertEqual(bufs_allocated()-init, 0)
|
||||
a.realize()
|
||||
real_buf = a.lazydata.buffer
|
||||
# after the Tensor UOp is deleted there shouldn't be any references on the Buffer
|
||||
self.assertEqual(real_buf.lb_refcount, 1)
|
||||
self.assertEqual(bufs_allocated()-init, 1)
|
||||
del a.lazydata
|
||||
self.assertEqual(real_buf.lb_refcount, 0)
|
||||
self.assertEqual(bufs_allocated()-init, 1) # keep the buffer alive
|
||||
# `real_buf` is the last reference held by this test; deleting it must release the Buffer
del real_buf
|
||||
self.assertEqual(bufs_allocated()-init, 0)
|
||||
|
||||
# Checks that assigning into a realized tensor keeps using the same Buffer, that the assign
# does not change `lb_refcount`, and that deleting the tensor drops the count to 0 while the
# Buffer object itself stays alive through `real_buf`.
# NOTE(review): `bufs_allocated` is defined elsewhere in this test module (not shown here).
def test_assign_refcount(self):
|
||||
init = bufs_allocated()
|
||||
a = Tensor.full((4,), 1.).contiguous()
|
||||
a.realize()
|
||||
real_buf = a.lazydata.buffer
|
||||
self.assertEqual(real_buf.lb_refcount, 1)
|
||||
a.assign(Tensor.full((4,), 2.))
|
||||
# the first source of the tensor's UOp must still be backed by the original Buffer
self.assertIs(a.lazydata.src[0].buffer, real_buf)
|
||||
# NOTE: this is still 1, we don't count the ASSIGN
|
||||
self.assertEqual(real_buf.lb_refcount, 1)
|
||||
a.realize()
|
||||
del a
|
||||
self.assertEqual(real_buf.lb_refcount, 0) # no UOps for this Buffer
|
||||
self.assertEqual(bufs_allocated()-init, 1) # Buffer is alive
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -277,6 +277,38 @@ class TestJit(unittest.TestCase):
|
|||
assert len(res3) == 5, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
# Runs the same random-drawing function `f` five times eagerly and five times under TinyJit,
# both after `Tensor.manual_seed(1234)`, and requires the two sets of sampled outputs to match.
# Marked expectedFailure: the final equality is a known open issue (see the TODO).
@unittest.expectedFailure # TODO: fix
|
||||
def test_jit_v_nojit_random_regen(self):
|
||||
# f draws three randn tensors per call and combines them with its two inputs
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
rn = rn * a
|
||||
rn2 = Tensor.randn(*a.shape)
|
||||
rn2 = rn2 * b
|
||||
rn = rn + rn2
|
||||
rn2 = rn2 + Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
# reference run without the JIT: 5 calls x 2 outputs, one element sampled from each
without_jit = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = f(a, b)
|
||||
without_jit.add(o1.numpy()[0][0])
|
||||
without_jit.add(o2.numpy()[0][0])
|
||||
assert len(without_jit) == 10, "All values should be different."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
# same seed and same function, this time wrapped in TinyJit
jf = TinyJit(f)
|
||||
with_jit = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf(a, b)
|
||||
with_jit.add(o1.numpy()[0][0])
|
||||
with_jit.add(o2.numpy()[0][0])
|
||||
assert len(with_jit) == 10, "All values should be different."
|
||||
assert with_jit == without_jit, "Jit rand produced different values from no jit."
|
||||
|
||||
def test_jit_multiple_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
|
|
|
|||
|
|
@ -534,6 +534,7 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y)
|
||||
helper_test_op([(), ()], lambda x,y: x/y)
|
||||
|
||||
@unittest.skipIf(getenv("AMD_LLVM", 0), "AMD with LLVM backend generate rcp in FP division causes trunc/floor errors")
|
||||
def test_div_rounding_mode(self):
|
||||
for denominator in [-10, -5, -3, -2, -1, 1, 2, 3, 5, 10]:
|
||||
# int numerator
|
||||
|
|
|
|||
|
|
@ -120,9 +120,8 @@ class TestRandomness(unittest.TestCase):
|
|||
0.3108327388763428, 0.09639489650726318, 0.004686474800109863, 0.8435229063034058, 0.824237585067749,
|
||||
0.5873836278915405, 0.4232727289199829, 0.2530076503753662, 0.40300023555755615, 0.03966474533081055,
|
||||
0.27904558181762695, 0.9150195121765137, 0.48057758808135986, 0.23821306228637695, 0.7676635980606079], dtype=np.float32)
|
||||
|
||||
r = Tensor.rand(20).numpy()
|
||||
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
|
||||
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# next 20, np.arange(20, 40, dtype=np.uint32)
|
||||
jr = np.array([0.7444133758544922, 0.7713677883148193, 0.8233780860900879, 0.43871235847473145, 0.517757773399353,
|
||||
|
|
@ -130,7 +129,7 @@ class TestRandomness(unittest.TestCase):
|
|||
0.28920769691467285, 0.017063498497009277, 0.2627382278442383, 0.9525482654571533, 0.9351049661636353,
|
||||
0.43904995918273926, 0.043945908546447754, 0.6616791486740112, 0.6667773723602295, 0.5228077173233032], dtype=np.float32)
|
||||
r = Tensor.rand(20).numpy()
|
||||
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
|
||||
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# next 10, np.arange(40, 50, dtype=np.uint32)
|
||||
jr = np.array([0.9614430665969849, 0.059279561042785645, 0.01909029483795166, 0.47882091999053955, 0.9677121639251709,
|
||||
|
|
@ -138,7 +137,7 @@ class TestRandomness(unittest.TestCase):
|
|||
r = Tensor.rand(10).numpy()
|
||||
# TODO: this failed because increment happened before _threefry_random_bits
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
|
||||
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
|
||||
def test_threefry_tensors_cnt(self):
|
||||
|
|
|
|||
|
|
@ -329,11 +329,11 @@ def simplify_valid(valid:UOp) -> UOp|None:
|
|||
# ***** threefry *****
|
||||
|
||||
def threefry2x32(x: UOp, key: UOp):
|
||||
# split x into two uint32, since x in a uint64
|
||||
# split x and key from uint64 to two uint32
|
||||
x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
|
||||
key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
|
||||
|
||||
rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
|
||||
key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
|
||||
ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
|
||||
xr = [x0 + ks[-1], x1 + ks[0]]
|
||||
for i in range(5):
|
||||
|
|
|
|||
|
|
@ -453,6 +453,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
|||
|
||||
# display the final graph
|
||||
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
|
||||
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
|
||||
|
||||
# final toposort (bfs)
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
|
|
@ -473,8 +474,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
|||
# TODO: move this to create_kernels
|
||||
k = fix_kernel_ast(u.src[1], var_vals)
|
||||
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
|
||||
# increment the refcount of the target buf (this is required by the JIT and memory planner) TODO: this does not belong here
|
||||
k.src[0].buffer.ref(1)
|
||||
for x in children.get(u, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
|
|
|||
|
|
@ -536,6 +536,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
from tinygrad.device import Buffer
|
||||
assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
|
||||
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
|
||||
ret.ref(1)
|
||||
return ret
|
||||
@property
def realized(self) -> Optional[Buffer]:
  """Return `self.buffer` when this UOp is an `Ops.BUFFER` whose buffer reports `is_allocated()`, else None."""
  if self.op is Ops.BUFFER and self.buffer.is_allocated():
    return self.buffer
  return None
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
from typing import cast
|
||||
import math, struct, sys
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.renderer.cstyle import ClangRenderer, AMDRenderer
|
||||
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
|
||||
from tinygrad.helpers import prod, AMX
|
||||
|
||||
def ldt(dt:DType):
|
||||
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
|
||||
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
|
||||
return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
|
||||
if isinstance(dt, PtrDType): return ldt(dt.base) + (" addrspace(3)*" if dt.local else "*")
|
||||
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
|
||||
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
|
||||
dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
|
||||
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
|
||||
|
||||
def lconst(x, dtype:DType):
|
||||
if dtype in dtypes.floats:
|
||||
|
|
@ -63,7 +63,8 @@ base_rewrite = PatternMatcher([
|
|||
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
|
||||
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
|
||||
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
|
||||
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
|
||||
|
||||
# GEP/VECTORIZE/CAST for float4 support
|
||||
|
|
@ -113,7 +114,7 @@ class LLVMRenderer(Renderer):
|
|||
supports_float4 = True
|
||||
has_local = False
|
||||
has_shared = False
|
||||
global_max = None
|
||||
global_max: tuple[int, ...] | None = None
|
||||
string_rewrite = base_rewrite
|
||||
if AMX: tensor_cores = ClangRenderer.amx_tc
|
||||
|
||||
|
|
@ -126,6 +127,12 @@ class LLVMRenderer(Renderer):
|
|||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
# rewrite bf16 CAST(LOAD) to CAST(BITCAST)
|
||||
(UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
|
||||
# copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
|
||||
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
|
||||
# copied from cstyle.py, add float intermediate casting
|
||||
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
|
||||
(UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
|
||||
])
|
||||
|
||||
def render(self, uops: list[UOp]) -> str:
|
||||
|
|
@ -135,6 +142,7 @@ class LLVMRenderer(Renderer):
|
|||
end_lines: dict[str, None] = {}
|
||||
vc = -1
|
||||
|
||||
local_args: list[str] = []
|
||||
acc_to_assign: dict[UOp, UOp] = {}
|
||||
for u in uops:
|
||||
if u.op is Ops.ASSIGN: # prealloc all assigns
|
||||
|
|
@ -158,6 +166,10 @@ class LLVMRenderer(Renderer):
|
|||
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
|
||||
# NOTE: MallocAllocator promises 0x20 alignment
|
||||
args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
|
||||
elif u.op == Ops.DEFINE_LOCAL:
|
||||
r[u] = f"@local_{u.arg}"
|
||||
assert isinstance(u.dtype, PtrDType)
|
||||
local_args.append(f"{r[u]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
|
||||
elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
|
||||
elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
|
||||
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
|
||||
|
|
@ -182,7 +194,7 @@ class LLVMRenderer(Renderer):
|
|||
r[x] = f"%acc{vc}"
|
||||
|
||||
# output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
|
||||
return f'''\
|
||||
prg = f'''\
|
||||
define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
|
||||
{chr(10).join(kernel)}
|
||||
ret void
|
||||
|
|
@ -190,3 +202,19 @@ define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(ar
|
|||
{chr(10).join(end_lines.keys())}
|
||||
attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
|
||||
'''
|
||||
return prg if len(local_args) == 0 else "\n".join(local_args)+f"\n{prg}"
|
||||
|
||||
# IR text emitted for a barrier: a workgroup-scope release fence, a call to
# @llvm.amdgcn.s.barrier, then a workgroup-scope acquire fence.
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
|
||||
# Builders for index intrinsic calls, keyed by "g" (workgroup.id) and "l" (workitem.id).
# chr(120+int(x)) turns an axis number 0/1/2 into the suffix x/y/z.
code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
|
||||
"l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
|
||||
# LLVMRenderer variant for the "AMD" device: turns on local and shared support, takes its
# shared/global limits from AMDRenderer, uses the "amdgpu_kernel" ABI and adds string rewrite
# rules for SPECIAL and BARRIER on top of base_rewrite.
class AMDLLVMRenderer(LLVMRenderer):
|
||||
device = "AMD"
|
||||
has_local = True
|
||||
has_shared = True
|
||||
# limits are reused from the C-style AMD renderer rather than redefined here
shared_max = AMDRenderer.shared_max
|
||||
global_max = AMDRenderer.global_max
|
||||
abi = "amdgpu_kernel"
|
||||
string_rewrite = base_rewrite + PatternMatcher([
|
||||
# SPECIAL: the first character of x.arg[0] picks the "g"/"l" builder, the last character is the axis
(UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = " + f"{ code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
|
||||
# BARRIER: emit the module-level `barrier` snippet unchanged
(UPat(Ops.BARRIER), lambda ctx: barrier),
|
||||
])
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
39083
tinygrad/runtime/autogen/am/soc21.py
Normal file
39083
tinygrad/runtime/autogen/am/soc21.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -8,9 +8,10 @@ from tinygrad.ops import sint
|
|||
from tinygrad.device import Compiled, ProfileEvent, BufferSpec, CPUProgram, PROFILE
|
||||
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, mv_address, DEBUG, OSX
|
||||
from tinygrad.renderer.cstyle import AMDRenderer
|
||||
from tinygrad.renderer.llvmir import AMDLLVMRenderer
|
||||
from tinygrad.runtime.autogen import kfd, hsa, amd_gpu, libc, pci, vfio, sqtt
|
||||
from tinygrad.runtime.autogen.am import am, gc_11_0_0
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.runtime.support.am.amdev import AMDev, AMMapping
|
||||
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
|
@ -706,7 +707,8 @@ class AMDDevice(HCQCompiled):
|
|||
|
||||
self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x800000)
|
||||
|
||||
super().__init__(device, AMDAllocator(self), AMDRenderer(self.arch), HIPCompiler(self.arch), functools.partial(AMDProgram, self),
|
||||
super().__init__(device, AMDAllocator(self), AMDLLVMRenderer() if getenv("AMD_LLVM", 0) else AMDRenderer(self.arch),
|
||||
AMDLLVMCompiler(self.arch) if getenv("AMD_LLVM", 0) else HIPCompiler(self.arch), functools.partial(AMDProgram, self),
|
||||
AMDSignal, AMDComputeQueue, AMDCopyQueue)
|
||||
|
||||
# Scratch setup
|
||||
|
|
|
|||
17
tinygrad/runtime/ops_null.py
Normal file
17
tinygrad/runtime/ops_null.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from tinygrad.device import Compiled, Compiler, Renderer, Allocator
|
||||
|
||||
class NullRenderer(Renderer):
  """Renderer whose `render` produces an empty source string for any input."""
  def render(self, uops:list) -> str:
    # the uop list is ignored; no code is generated
    return ""
|
||||
|
||||
class NullProgram:
  """Program stand-in that runs nothing.

  Construction discards `name` and `lib`; calling the instance ignores all buffers and launch
  parameters and returns the constant 1e-4 (presumably the elapsed time callers expect — confirm).
  """
  def __init__(self, name:str, lib:bytes):
    # nothing to load or keep
    return None
  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
    fixed_result = 1e-4
    return fixed_result
|
||||
|
||||
class NullAllocator(Allocator):
  """Allocator that returns a placeholder handle and performs no copies."""
  def _alloc(self, size, options):
    # every request gets the same string handle, whatever the size or options
    return "null"
  def _copyin(self, dest, src:memoryview):
    # deliberate no-op: nothing is written to `dest`
    return None
  def _copyout(self, dest:memoryview, src):
    # deliberate no-op: `dest` is left untouched
    return None
|
||||
|
||||
class NullDevice(Compiled):
  """Compiled device assembled from NullAllocator, NullRenderer, a base `Compiler` and NullProgram."""
  def __init__(self, device:str):
    # built in the same order as they are passed to the base constructor
    allocator = NullAllocator()
    renderer = NullRenderer()
    compiler = Compiler()
    super().__init__(device, allocator, renderer, compiler, NullProgram)
|
||||
|
|
@ -1,35 +1,26 @@
|
|||
from __future__ import annotations
|
||||
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, importlib
|
||||
import ctypes, collections, time, dataclasses, functools, pathlib, fcntl, os, importlib
|
||||
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
|
||||
from tinygrad.runtime.autogen.am import am, mp_11_0
|
||||
from tinygrad.runtime.support.amd import AMDRegBase, collect_registers
|
||||
from tinygrad.runtime.support.allocator import TLSFAllocator
|
||||
from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
|
||||
|
||||
AM_DEBUG = getenv("AM_DEBUG", 0)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class AMRegister:
|
||||
adev:AMDev; reg_off:int; reg_fields:dict[str, tuple[int, int]] # noqa: E702
|
||||
class AMRegister(AMDRegBase):
|
||||
adev:AMDev; hwip:int # noqa: E702
|
||||
|
||||
def _parse_kwargs(self, **kwargs):
|
||||
mask, values = 0xffffffff, 0
|
||||
for k, v in kwargs.items():
|
||||
if k not in self.reg_fields: raise ValueError(f"Unknown register field: {k}. {self.reg_fields.keys()}")
|
||||
m, s = self.reg_fields[k]
|
||||
if v & (m>>s) != v: raise ValueError(f"Value {v} for {k} is out of range {m=} {s=}")
|
||||
mask &= ~m
|
||||
values |= v << s
|
||||
return mask, values
|
||||
@property
|
||||
def addr(self): return self.adev.regs_offset[self.hwip][0][self.segment] + self.offset
|
||||
|
||||
def build(self, **kwargs) -> int: return self._parse_kwargs(**kwargs)[1]
|
||||
def read(self): return self.adev.rreg(self.addr)
|
||||
def read_bitfields(self) -> dict[str, int]: return self.decode(self.read())
|
||||
|
||||
def update(self, **kwargs): self.write(value=self.read(), **kwargs)
|
||||
def write(self, _am_val:int=0, **kwargs): self.adev.wreg(self.addr, _am_val | self.encode(**kwargs))
|
||||
|
||||
def write(self, value=0, **kwargs):
|
||||
mask, values = self._parse_kwargs(**kwargs)
|
||||
self.adev.wreg(self.reg_off, (value & mask) | values)
|
||||
|
||||
def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
|
||||
def update(self, **kwargs): self.write(self.encode(**{**self.read_bitfields(), **kwargs}))
|
||||
|
||||
class AMFirmware:
|
||||
def __init__(self, adev):
|
||||
|
|
@ -327,8 +318,6 @@ class AMDev:
|
|||
def paddr2cpu(self, paddr:int) -> int:
  """Return `mv_address(self.vram) + paddr`.

  NOTE(review): presumably the CPU address of byte `paddr` inside the mapped VRAM — confirm.
  """
  vram_base = mv_address(self.vram)
  return vram_base + paddr
|
||||
def paddr2mc(self, paddr:int) -> int:
  """Return `paddr` offset by `self.gmc.mc_base`."""
  mc_base = self.gmc.mc_base
  return mc_base + paddr
|
||||
|
||||
def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
|
||||
|
||||
def reg(self, reg:str) -> AMRegister:
  """Return the object stored in this instance's `__dict__` under the name `reg`.

  Raises KeyError when the instance has no attribute of that name.
  """
  instance_attrs = self.__dict__
  return instance_attrs[reg]
|
||||
|
||||
def rreg(self, reg:int) -> int:
|
||||
|
|
@ -358,7 +347,7 @@ class AMDev:
|
|||
for _ in range(timeout):
|
||||
if ((rval:=reg.read()) & mask) == value: return rval
|
||||
time.sleep(0.001)
|
||||
raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
|
||||
raise RuntimeError(f'wait_reg timeout reg=0x{reg.addr:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
|
||||
|
||||
def _run_discovery(self):
|
||||
# NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
|
||||
|
|
@ -403,14 +392,5 @@ class AMDev:
|
|||
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP)),
|
||||
("HDP", self._ip_module("hdp", am.HDP_HWIP))]
|
||||
|
||||
for base, module in mods:
|
||||
rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
|
||||
reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
|
||||
reg_fields: dict[str, dict[str, tuple]] = collections.defaultdict(dict)
|
||||
for k, val in module.__dict__.items():
|
||||
if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names):
|
||||
reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1))
|
||||
|
||||
for k, regval in module.__dict__.items():
|
||||
if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None:
|
||||
setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {})))
|
||||
for ip, module in mods:
|
||||
self.__dict__.update(collect_registers(module, cls=functools.partial(AMRegister, adev=self, hwip=getattr(am, f"{ip}_HWIP"))))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import ctypes, time, contextlib
|
||||
import ctypes, time, contextlib, importlib
|
||||
from typing import Literal
|
||||
from tinygrad.runtime.autogen.am import am
|
||||
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
|
||||
|
|
@ -11,6 +11,8 @@ class AM_IP:
|
|||
def set_clockgating_state(self): pass # Set clockgating state for this IP
|
||||
|
||||
class AM_SOC(AM_IP):
|
||||
def init_sw(self): self.module = importlib.import_module("tinygrad.runtime.autogen.am.soc21")
|
||||
|
||||
def init_hw(self):
|
||||
self.adev.regRCC_DEV0_EPF2_STRAP2.update(strap_no_soft_reset_dev0_f2=0x0)
|
||||
self.adev.regRCC_DEV0_EPF0_RCC_DOORBELL_APER_EN.write(0x1)
|
||||
|
|
@ -89,7 +91,7 @@ class AM_GMC(AM_IP):
|
|||
|
||||
# Init TLB and cache
|
||||
self.adev.reg(f"reg{ip}MC_VM_MX_L1_TLB_CNTL").update(enable_l1_tlb=1, system_access_mode=3, enable_advanced_driver_model=1,
|
||||
system_aperture_unmapped_access=0, eco_bits=0, mtype=am.MTYPE_UC)
|
||||
system_aperture_unmapped_access=0, eco_bits=0, mtype=self.adev.soc.module.MTYPE_UC)
|
||||
|
||||
self.adev.reg(f"reg{ip}VM_L2_CNTL").update(enable_l2_cache=1, enable_l2_fragment_processing=0, enable_default_page_out_to_system_memory=1,
|
||||
l2_pde0_cache_tag_generation_mode=0, pde_fault_classification=0, context1_identity_access_mode=1, identity_mode_fragment_size=0)
|
||||
|
|
@ -110,7 +112,7 @@ class AM_GMC(AM_IP):
|
|||
|
||||
def get_pte_flags(self, pte_lv, is_table, frag, uncached, system, snooped, valid, extra=0):
|
||||
extra |= (am.AMDGPU_PTE_SYSTEM * system) | (am.AMDGPU_PTE_SNOOPED * snooped) | (am.AMDGPU_PTE_VALID * valid) | am.AMDGPU_PTE_FRAG(frag)
|
||||
extra |= am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC if uncached else 0)
|
||||
extra |= am.AMDGPU_PTE_MTYPE_NV10(0, self.adev.soc.module.MTYPE_UC if uncached else 0)
|
||||
if not is_table: extra |= (am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE)
|
||||
return extra | (am.AMDGPU_PDE_PTE if not is_table and pte_lv != am.AMDGPU_VM_PTB else 0)
|
||||
def is_pte_huge_page(self, pte): return pte & am.AMDGPU_PDE_PTE
|
||||
|
|
@ -172,7 +174,7 @@ class AM_SMU(AM_IP):
|
|||
class AM_GFX(AM_IP):
|
||||
def init_hw(self):
|
||||
# Wait for RLC autoload to complete
|
||||
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read(bootload_complete=1) != 0: pass
|
||||
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass
|
||||
|
||||
self._config_gfx_rs64()
|
||||
self.adev.gmc.init_hub("GC")
|
||||
|
|
@ -187,8 +189,8 @@ class AM_GFX(AM_IP):
|
|||
self.adev.regGRBM_CNTL.update(read_timeout=0xff)
|
||||
for i in range(0, 16):
|
||||
self._grbm_select(vmid=i)
|
||||
self.adev.regSH_MEM_CONFIG.write(address_mode=am.SH_MEM_ADDRESS_MODE_64, alignment_mode=am.SH_MEM_ALIGNMENT_MODE_UNALIGNED,
|
||||
initial_inst_prefetch=3)
|
||||
self.adev.regSH_MEM_CONFIG.write(address_mode=self.adev.soc.module.SH_MEM_ADDRESS_MODE_64,
|
||||
alignment_mode=self.adev.soc.module.SH_MEM_ALIGNMENT_MODE_UNALIGNED, initial_inst_prefetch=3)
|
||||
|
||||
# Configure apertures:
|
||||
# LDS: 0x10000000'00000000 - 0x10000001'00000000 (4GB)
|
||||
|
|
@ -218,17 +220,17 @@ class AM_GFX(AM_IP):
|
|||
mqd = self.adev.mm.valloc(0x1000, uncached=True, contigous=True)
|
||||
|
||||
mqd_struct = am.struct_v11_compute_mqd(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
|
||||
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.build(preload_size=0x55, preload_req=1),
|
||||
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.encode(preload_size=0x55, preload_req=1),
|
||||
cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
|
||||
cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8),
|
||||
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
|
||||
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
|
||||
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1),
|
||||
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
|
||||
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
|
||||
cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0,
|
||||
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1),
|
||||
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
|
||||
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
|
||||
cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0,
|
||||
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
|
||||
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.build(eop_size=(eop_size//4).bit_length()-2))
|
||||
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2))
|
||||
|
||||
# Copy mqd into memory
|
||||
ctypes.memmove(self.adev.paddr2cpu(mqd.paddrs[0][0]), ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct))
|
||||
|
|
@ -237,7 +239,7 @@ class AM_GFX(AM_IP):
|
|||
self._grbm_select(me=1, pipe=pipe, queue=queue)
|
||||
|
||||
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.reg_off, self.adev.regCP_HQD_PQ_WPTR_HI.reg_off + 1)):
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr, self.adev.regCP_HQD_PQ_WPTR_HI.addr + 1)):
|
||||
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
|
||||
self.adev.regCP_HQD_ACTIVE.write(0x1)
|
||||
|
||||
|
|
@ -313,7 +315,7 @@ class AM_IH(AM_IP):
|
|||
_, rwptr_vm, suf, _ = self.rings[0]
|
||||
wptr = to_mv(self.adev.paddr2cpu(rwptr_vm), 8).cast('Q')[0]
|
||||
|
||||
if self.adev.reg(f"regIH_RB_WPTR{suf}").read(rb_overflow=1):
|
||||
if self.adev.reg(f"regIH_RB_WPTR{suf}").read_bitfields()['rb_overflow']:
|
||||
self.adev.reg(f"regIH_RB_WPTR{suf}").update(rb_overflow=0)
|
||||
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=1)
|
||||
self.adev.reg(f"regIH_RB_CNTL{suf}").update(wptr_overflow_clear=0)
|
||||
|
|
|
|||
26
tinygrad/runtime/support/amd.py
Normal file
26
tinygrad/runtime/support/amd.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import functools
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from math import log2
|
||||
from tinygrad.helpers import getbits
|
||||
|
||||
@dataclass(frozen=True)
class AMDRegBase:
  """Immutable description of one register: its name, offset, segment and named bitfields."""
  name: str
  offset: int
  segment: int
  # field name -> (start, end) bit positions; encode shifts by `start`, decode passes both to getbits
  fields: dict[str, tuple[int, int]]

  def encode(self, **kwargs) -> int:
    """OR together every keyword value after shifting it to the start bit of the field it names.

    Raises KeyError for a keyword that is not a known field. Values are not range-checked.
    """
    shifted = (value << self.fields[field][0] for field, value in kwargs.items())
    return functools.reduce(int.__or__, shifted, 0)

  def decode(self, val: int) -> dict:
    """Split a raw register value into a {field name: field value} dict with one entry per known field."""
    bitfields = {}
    for field, (start, end) in self.fields.items():
      bitfields[field] = getbits(val, start, end)
    return bitfields
|
||||
|
||||
def collect_registers(module, cls=AMDRegBase) -> dict[str, AMDRegBase]:
  """Build register descriptors from the constants of a register definition module.

  An attribute is treated as a register when the characters before its first uppercase letter are exactly
  `reg` or `mm`. Its value is the offset and the matching `<attr>_BASE_IDX` attribute is the segment.
  Attributes named `<REG>__<FIELD>_MASK` describe bitfields of the register `<REG>` (name without the prefix).

  Returns {attribute name (prefix included): cls(name=, offset=, segment=, fields=)}, where `fields` maps the
  lowercased field name to (lowest set bit, highest set bit) of its mask. Registers without a `_BASE_IDX` and
  masks that are not positive ints are skipped.
  """
  def _split_name(name): return name[:(pos:=next((i for i,c in enumerate(name) if c.isupper()), len(name)))], name[pos:]

  # one pass over the module: offsets keep the full attribute name, bases are keyed with the suffix removed
  offsets, bases = {}, {}
  for k,v in module.__dict__.items():
    if _split_name(k)[0] not in {'reg', 'mm'}: continue
    if k.endswith('_BASE_IDX'): bases[k[:-len('_BASE_IDX')]] = v
    else: offsets[k] = v

  fields: defaultdict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
  for field_name,field_mask in module.__dict__.items():
    if not ('__' in field_name and field_name.endswith('_MASK')): continue
    # a mask with no set bits (or a non-int value) has no bit range to report
    if not isinstance(field_mask, int) or field_mask <= 0: continue
    # split on the first '__' only, so a field name that itself contains '__' does not break the unpack
    reg_name, sep, reg_field_name = field_name[:-len('_MASK')].partition('__')
    if not sep: continue
    # integer bit arithmetic is exact for any width; log2 goes through a float and rounds up for wide masks
    fields[reg_name][reg_field_name.lower()] = ((field_mask & -field_mask).bit_length() - 1, field_mask.bit_length() - 1)

  # NOTE: Some registers like regGFX_IMU_FUSESTRAP in gc_11_0_0 are missing base idx, just skip them
  return {reg:cls(name=reg, offset=off, segment=bases[reg], fields=fields[_split_name(reg)[1]]) for reg,off in offsets.items() if reg in bases}
|
||||
|
|
@ -501,15 +501,12 @@ class Tensor(SimpleMathTrait):
|
|||
if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
||||
_device = device = Device.canonicalize(device)
|
||||
device = Device.canonicalize(device)
|
||||
|
||||
# if shape has 0, return zero tensor
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
|
||||
num = ceildiv(numel * dtype.itemsize, 4)
|
||||
|
||||
# when using MOCKGPU and NV generate rand on CPU
|
||||
if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"
|
||||
|
||||
# generate per device seeds and rng counter if we haven't seen this device yet
|
||||
if device not in Tensor._device_seeds:
|
||||
Tensor._device_seeds[device] = Tensor(
|
||||
|
|
@ -532,12 +529,7 @@ class Tensor(SimpleMathTrait):
|
|||
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
|
||||
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
|
||||
# bitcast back to the original dtype and reshape
|
||||
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
|
||||
|
||||
# move back to the original device if we were using MOCKGPU
|
||||
if getenv("MOCKGPU") and _device: out = out.to(_device)
|
||||
|
||||
out.requires_grad = kwargs.get("requires_grad")
|
||||
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad"))
|
||||
return out.contiguous() if contiguous else out
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ GRAPH=1
|
|||
JITGRAPH=1 (this restricts the graph...no need if we can select the schedules)
|
||||
GRAPHUOPS=1
|
||||
most uses of DEBUG >= 3
|
||||
https://tiny-tools-client.vercel.app/
|
||||
tiny-tools
|
||||
|
||||
and a viewer for:
|
||||
SAVE_SCHEDULE=1
|
||||
|
|
|
|||
|
|
@ -228,6 +228,7 @@
|
|||
<g id="render">
|
||||
<g id="edges"></g>
|
||||
<g id="nodes"></g>
|
||||
<g id="bars"></g>
|
||||
</g>
|
||||
</svg>
|
||||
<button class="btn" id="zoom-to-fit-btn">Zoom to fit</button>
|
||||
|
|
@ -377,7 +378,7 @@
|
|||
};
|
||||
}
|
||||
if (ret.length === 0) return;
|
||||
renderGraph(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || []);
|
||||
renderGraph(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], kernel.name);
|
||||
// ***** RHS metadata
|
||||
const metadata = document.querySelector(".container.metadata");
|
||||
metadata.innerHTML = "";
|
||||
|
|
|
|||
|
|
@ -9,13 +9,18 @@ function intersectRect(r1, r2) {
|
|||
}
|
||||
|
||||
const allWorkers = [];
|
||||
window.renderGraph = function(graph, additions) {
|
||||
window.renderGraph = function(graph, additions, name) {
|
||||
while (allWorkers.length) {
|
||||
const { worker, timeout } = allWorkers.pop();
|
||||
worker.terminate();
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
|
||||
if (name === "View Memory Graph") {
|
||||
return renderMemoryGraph(graph);
|
||||
}
|
||||
d3.select("#bars").html("");
|
||||
|
||||
// ** start calculating the new layout (non-blocking)
|
||||
worker = new Worker("/lib/worker.js");
|
||||
const progressMessage = document.querySelector(".progress-message");
|
||||
|
|
@ -58,3 +63,130 @@ window.renderGraph = function(graph, additions) {
|
|||
.attr("markerWidth", 6).attr("markerHeight", 6).attr("orient", "auto").append("path").attr("d", "M0,-5L10,0L0,5").attr("fill", "#4a4b57");
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
// Size in bytes of one element of each scalar dtype, keyed by dtype name.
// Declared with const: a bare assignment creates an implicit global and throws in strict mode.
const DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4,
                    "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8};
|
||||
// Parse a buffer node's newline-separated label into {nbytes, dtype, device, num}.
// Line 0 of the label is ignored; lines 1-4 are read as size, dtype, device and buffer number.
function getBuffer(e) {
  const parts = e.label.split("\n");
  const size = parts[1], dtype = parts[2], device = parts[3], num = parts[4];
  // element size is looked up by the text that follows "dtypes." in the dtype line
  const itemsize = DTYPE_SIZE[dtype.split("dtypes.")[1]];
  return {
    nbytes: size*itemsize,
    dtype,
    device: device.split(" ")[1],
    num: parseInt(num.split(" ")[1]),
  };
}
|
||||
// Format a count with its noun: the singular `name` for exactly 1, otherwise `alt` when given or `name` + "s".
function pluralize(num, name, alt=null) {
  if (num === 1) return `${num} ${name}`;
  const plural = alt ?? name+'s';
  return `${num} ${plural}`;
}
|
||||
|
||||
// Draw the memory view of a kernel graph into #bars: one polygon per buffer, x is the timestep, y is bytes in use.
// `graph` maps node id -> {label, src, ...}. ASSIGN nodes are treated as writes to src[0], and the srcs of
// their src[1] node as reads. Clears #nodes and #edges and hides the progress message when done.
function renderMemoryGraph(graph) {
  // ** construct alloc/free traces
  // we can map reads/writes from the kernel graph
  const actions = [];
  const children = new Map(); // {buffer: [...assign]}
  for (const v of Object.values(graph)) {
    if (!v.label.startsWith("ASSIGN")) continue;
    actions.push({ op: "write", buffer: v.src[0] });
    for (const ks of graph[v.src[1]].src) {
      const node = graph[ks];
      // reading the output of another ASSIGN means reading the buffer that ASSIGN writes
      const s = node.label.startsWith("ASSIGN") ? node.src[0] : ks;
      if (!children.has(s)) children.set(s, []);
      children.get(s).push(v);
      if (s !== v.src[0]) actions.push({ op: "read", buffer: s });
    }
  }
  // last action touching each buffer, so the free check below is a lookup instead of a rescan of `actions`
  const lastAction = new Map();
  for (const a of actions) lastAction.set(a.buffer, a);
  const prealloc = new Set();
  const allocated = new Set(); // buffers that already have an "alloc" trace
  const traces = [];
  for (const a of actions) {
    // a buffer is allocated immediately before the first write
    // TODO: we don't know the buffer is preallocated if there's only an assign in the graph
    if (a.op === "write") {
      traces.push({ type: "alloc", buffer: a.buffer });
      allocated.add(a.buffer);
    }
    // read before any write: the buffer existed before this graph ran
    else if (!allocated.has(a.buffer)) prealloc.add(a.buffer);
    // the buffer is freed right after its last use
    else if (a === lastAction.get(a.buffer)) traces.push({ type: "free", buffer: a.buffer });
  }
  // ** get coordinates and layout for each buffer
  const ret = {};
  let timestep = 0; // x
  let memUsed = 0; // y
  for (const id of prealloc) {
    const buf = getBuffer(graph[id]);
    ret[id] = { x: [timestep], y: [memUsed], buf, id };
    memUsed += buf.nbytes;
  }
  let peak = memUsed;
  const liveBufs = [...prealloc];
  for (const t of traces) {
    const buf = getBuffer(graph[t.buffer]);
    const idx = liveBufs.findLastIndex(b => t.buffer === b);
    // NOTE(review): the branch is chosen by whether the buffer is currently live, not by t.type, so a second
    // "alloc" trace for a live buffer is handled as a free. Confirm this is intended.
    if (idx === -1) {
      // alloc
      liveBufs.push(t.buffer);
      ret[t.buffer] = { x: [timestep], y: [memUsed], buf, id: t.buffer };
      memUsed += buf.nbytes;
      peak = Math.max(memUsed, peak);
      timestep += 1;
    }
    else {
      // free
      memUsed -= buf.nbytes;
      timestep += 1;
      const removed = ret[liveBufs.splice(idx, 1)[0]];
      removed.x.push(timestep);
      removed.y.push(removed.y.at(-1));
      // every buffer stacked above the freed one drops down by its size at this timestep
      for (let j=idx; j<liveBufs.length; j++) {
        const b = ret[liveBufs[j]];
        b.x.push(timestep, timestep);
        b.y.push(b.y.at(-1), b.y.at(-1)-buf.nbytes);
      }
    }
  }
  // buffers still live at the end extend to the final timestep
  for (const id of liveBufs) {
    const b = ret[id];
    b.x.push(timestep);
    b.y.push(b.y.at(-1));
  }
  // ** render traces
  // start from an empty #bars so repeated renders do not stack axes and polygons on top of each other
  const render = d3.select("#bars").html("");
  const yscale = d3.scaleLinear().domain([0, peak]).range([576, 0]);
  const xscale = d3.scaleLinear().domain([0, timestep]).range([0, 1024]);
  const axesGroup = render.append("g").attr("id", "axes");
  const nbytes_format = (d) => d3.format(".3~s")(d)+"B";
  axesGroup.append("g").call(d3.axisLeft(yscale).tickFormat(nbytes_format));
  axesGroup.append("g").attr("transform", `translate(0, ${yscale.range()[0]})`).call(d3.axisBottom(xscale).tickFormat(() => ""));
  const polygonGroup = render.append("g").attr("id", "polygons");
  const colors = ["7aa2f7", "ff9e64", "f7768e", "2ac3de", "7dcfff", "1abc9c", "9ece6a", "e0af68", "bb9af7", "9d7cd8", "ff007c"];
  polygonGroup.selectAll("polygon").data(Object.values(ret)).join("polygon").attr("points", (d) => {
    // outline: bottom edge left to right, then top edge (bottom + nbytes) right to left
    const xs = d.x.map(t => xscale(t));
    const y1 = d.y.map(t => yscale(t));
    const y2 = d.y.map(t => yscale(t+d.buf.nbytes));
    const p0 = xs.map((x, i) => `${x},${y1[i]}`);
    const p1 = xs.map((x, i) => `${x},${y2[i]}`).reverse();
    return `${p0.join(' ')} ${p1.join(' ')}`;
  }).attr("fill", d => `#${colors[d.buf.num % colors.length]}`).on("mouseover", (e, { id, buf, x }) => {
    d3.select(e.currentTarget).attr("stroke", "rgba(26, 27, 38, 0.8)").attr("stroke-width", 0.8);
    const metadata = document.querySelector(".container.metadata");
    document.getElementById("current-buf")?.remove();
    const { num, dtype, nbytes, ...rest } = buf;
    let label = `<BUFFER n${num} ${dtype} ${nbytes_format(nbytes)}>\nalive for ${pluralize(x[x.length-1]-x[0], 'timestep')}`;
    label += '\n'+Object.entries(rest).map(([k, v]) => `${k}=${v}`).join('\n');
    const buf_children = children.get(id);
    if (buf_children) {
      label += `\n${pluralize(buf_children.length, 'child', 'children')}\n`;
      label += buf_children.map((c,i) => `[${i+1}] `+graph[c.src[1]].label.split("\n")[1]).join("\n");
    }
    metadata.appendChild(Object.assign(document.createElement("pre"), { innerText: label, id: "current-buf", className: "wrap" }));
  }).on("mouseout", (e, _) => {
    d3.select(e.currentTarget).attr("stroke", null).attr("stroke-width", null);
    document.getElementById("current-buf")?.remove()
  });
  // TODO: add the toposort graph here
  document.querySelector(".progress-message").style.display = "none";
  d3.select("#nodes").html("");
  d3.select("#edges").html("");
}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ onmessage = (e) => {
|
|||
const g = new dagre.graphlib.Graph({ compound: true });
|
||||
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
if (additions.length !== 0) g.setNode("addition", {label: "", style: "fill: rgba(26, 27, 38, 0.5); stroke: none;", padding:0});
|
||||
for (const [k, [label, src, color]] of Object.entries(graph)) {
|
||||
for (const [k, {label, src, color}] of Object.entries(graph)) {
|
||||
// adjust node dims by label size + add padding
|
||||
const [labelWidth, labelHeight] = getTextDims(label);
|
||||
g.setNode(k, {label, color, width:labelWidth+NODE_PADDING*2, height:labelHeight+NODE_PADDING*2, padding:NODE_PADDING});
|
||||
|
|
|
|||
|
|
@ -49,10 +49,9 @@ class GraphRewriteDetails(TypedDict):
|
|||
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
|
||||
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
||||
|
||||
def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
|
||||
def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
assert isinstance(x, UOp)
|
||||
# NOTE: this is [id, [label, src_ids, color]]
|
||||
graph: dict[int, tuple[str, list[int], str]] = {}
|
||||
graph: dict[int, dict] = {}
|
||||
excluded: set[UOp] = set()
|
||||
for u in (toposort:=x.toposort):
|
||||
# always exclude DEVICE/CONST/UNIQUE
|
||||
|
|
@ -72,7 +71,7 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
|
|||
if x in excluded:
|
||||
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
|
||||
else: label += f"\n{x.op.name}{idx} {x.arg}"
|
||||
graph[id(u)] = (label, [id(x) for x in u.src if x not in excluded], uops_colors.get(u.op, "#ffffff"))
|
||||
graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")}
|
||||
return graph
|
||||
|
||||
def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue