Merge branch 'master' into new_x86_backend

This commit is contained in:
ttomsa 2026-02-08 19:15:14 +00:00 committed by GitHub
commit 6ff67781f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
42 changed files with 2672 additions and 293 deletions

View file

@ -525,6 +525,8 @@ jobs:
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Run full CIFAR training steps w 6 GPUS
run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Test full tinyfs load
run: TINYFS_ENDPOINT=10.0.52.11:6767 PYTHONPATH=. python extra/tinyfs/fetch_file.py --hash d734f5e3be9f1e9d863bfaa4fc6c1ef2 --len 175866113 --dest mapping.json --check
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

View file

@ -1,12 +1,18 @@
#!/usr/bin/env python3
import os
from tinygrad.helpers import Context
from tinygrad.runtime.support.system import System, PCIDevice, PCIDevImplBase
from tinygrad.runtime.support.hcq import FileIOInterface
from tinygrad.runtime.support.am.amdev import AMDev
if __name__ == "__main__":
gpus = System.pci_scan_bus(0x1002, [(0xffff, [0x74a1, 0x75a0])])
pcidevs = [PCIDevice(f"reset:{gpu}", gpu, bars=[0, 2, 5]) for gpu in gpus]
for gpu in gpus:
drv_path = f"/sys/bus/pci/devices/{gpu}/driver"
if FileIOInterface.exists(drv_path) and os.path.basename(os.readlink(drv_path)) == "amdgpu":
raise RuntimeError(f"amdgpu is bound to {gpu}. Stopping...")
pcidevs = [PCIDevice("AM", gpu, bars=[0, 2, 5]) for gpu in gpus]
amdevs = []
with Context(DEBUG=2):
for pcidev in pcidevs:

View file

@ -28,7 +28,7 @@ def custom_matmul(output: UOp, inp: UOp, weight: UOp) -> UOp:
return store_op.sink(arg=KernelInfo(name=f"fp8_matmul_{inp.shape}x{weight.shape}"))
def custom_matmul_backward(gradient: UOp, kernel: UOp) -> tuple[UOp, UOp]:
_, input_uop, weight_uop = kernel.src
_, input_uop, weight_uop = kernel.src[1:]
input_tensor = Tensor(input_uop, device=input_uop.device)
grad_tensor = Tensor(gradient, device=gradient.device)
weight_tensor = Tensor(weight_uop, device=weight_uop.device)

View file

@ -62,7 +62,7 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
# ** backward gemm, might use the asm gemm
def custom_gemm_bw(gradient:UOp, kernel:UOp):
out, a, b = kernel.src
out, a, b = kernel.src[1:]
assert all_same([gradient.device, a.device, b.device, out.device])
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
grad_a = (g_t @ b_t.T).uop

View file

@ -103,9 +103,9 @@ class Attention:
def fa_custom_backward(out_q:UOp, out_k:UOp, out_v:UOp, grad:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
return UOp.sink(arg=KernelInfo(name="fa_custom_backward"))
def fa_backward(grad:UOp, kernel:UOp) -> tuple[None, UOp, UOp, UOp]:
grad_q = Tensor.empty_like(q:=Tensor(kernel.src[1]))
grad_k = Tensor.empty_like(k:=Tensor(kernel.src[2]))
grad_v = Tensor.empty_like(v:=Tensor(kernel.src[3]))
grad_q = Tensor.empty_like(q:=Tensor(kernel.src[2]))
grad_k = Tensor.empty_like(k:=Tensor(kernel.src[3]))
grad_v = Tensor.empty_like(v:=Tensor(kernel.src[4]))
ck = Tensor.custom_kernel(grad_q, grad_k, grad_v, Tensor(grad), q, k, v, fxn=fa_custom_backward)[:3]
return (None, ck[0].uop, ck[1].uop, ck[2].uop)
attn = Tensor.empty_like(attn).custom_kernel(xq, keys, values, fxn=fa_custom_forward, grad_fxn=fa_backward)[0]

View file

@ -1,11 +1,36 @@
from tinygrad.tensor import Tensor
import argparse
import argparse, math, hashlib
def _python_hash_1mb(data:bytes|bytearray):
chunks = [data[i:i+4096] for i in range(0, len(data), 4096)]
chunk_hashes = [hashlib.shake_128(chunk).digest(16) for chunk in chunks]
return hashlib.shake_128(b''.join(chunk_hashes)).digest(16)
def hash_file(data: bytes|bytearray):
if len(data) % Tensor.CHUNK_SIZE != 0: data += bytes(Tensor.CHUNK_SIZE - len(data) % Tensor.CHUNK_SIZE)
base_chunks = math.ceil(len(data) / Tensor.CHUNK_SIZE)
tree_depth = math.ceil(math.log(base_chunks, Tensor.CHUNK_SIZE // 16))
for _ in range(tree_depth + 1):
data_chunks = [data[i:i+Tensor.CHUNK_SIZE] for i in range(0, len(data), Tensor.CHUNK_SIZE)]
data_chunk_hashes = [_python_hash_1mb(chunk) for chunk in data_chunks]
data = b''.join(data_chunk_hashes)
if len(data) % Tensor.CHUNK_SIZE != 0: data += bytes(Tensor.CHUNK_SIZE - len(data) % Tensor.CHUNK_SIZE)
return data[:16]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--hash", type=str, required=True, help="file hash to fetch")
parser.add_argument("--len", type=int, required=True, help="file length to fetch")
parser.add_argument("--dest", type=str, required=True, help="destination path to save the file")
parser.add_argument("--check", action="store_true", help="verify the file hash after fetching")
args = parser.parse_args()
Tensor(bytes.fromhex(args.hash), device="CPU").fs_load(args.len).to(f"disk:{args.dest}").realize()
if args.check:
with open(args.dest, "rb") as f:
data = f.read()
assert hash_file(data) == bytes.fromhex(args.hash), "Hash mismatch after fetching file"
print("File hash verified successfully!")

View file

@ -67,7 +67,7 @@ testing_minimal = [
"pytest-timeout",
"pytest-split",
"hypothesis>=6.148.9",
"z3-solver",
"z3-solver<4.15.4", # 4.15.4 has a segfault when creating many z3.Context()
]
testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai", "ggml-python"]
testing = [

View file

@ -1,3 +1,7 @@
# NOTE: z3-solver 4.15.4 segfaults (exit code 139) when creating many z3.Context() with complex expressions.
# Reproduces consistently with seed=74 around iteration 1767. Versions <=4.15.3 are fine.
# Workaround: reuse a single z3.Context, or pin z3-solver<4.15.4 (see pyproject.toml).
# To repro: pip install z3-solver==4.15.4.0 && python test/external/fuzz_symbolic.py 74
import random, operator, sys
import z3
from tinygrad import Variable, dtypes

View file

@ -47,7 +47,7 @@ def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str,
UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
new_sink = big_sink.substitute(get_rangeify_map(big_sink))
def to_str(ret:UOp) -> str:
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.KERNEL]
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.CALL]
return "\n".join([f"{len(asts)} kernels", *asts])
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)

View file

@ -130,15 +130,18 @@ class NVDriver(VirtDriver):
struct.hObjectNew = self._alloc_handle()
self.object_by_handle[struct.hObjectNew] = NVContextShare(self.object_by_handle[struct.hObjectParent])
elif struct.hClass == nv_gpu.AMPERE_CHANNEL_GPFIFO_A:
assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVChannelGroup)
parent = self.object_by_handle.get(struct.hObjectParent)
assert parent is not None and isinstance(parent, (NVChannelGroup, NVGPU))
struct.hObjectNew = self._alloc_handle()
params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS.from_address(params_ptr)
gpu = self.object_by_handle[struct.hObjectParent].device
gpu = parent.device if isinstance(parent, NVChannelGroup) else parent
gpfifo_token = gpu.add_gpfifo(params.gpFifoOffset, params.gpFifoEntries)
self.object_by_handle[struct.hObjectNew] = NVGPFIFO(gpu, gpfifo_token)
elif struct.hClass == nv_gpu.AMPERE_DMA_COPY_B or struct.hClass == nv_gpu.ADA_COMPUTE_A:
elif struct.hClass in (nv_gpu.AMPERE_DMA_COPY_B, nv_gpu.ADA_COMPUTE_A, nv_gpu.NVC9B0_VIDEO_DECODER, nv_gpu.NVCFB0_VIDEO_DECODER):
assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVGPFIFO)
struct.hObjectNew = self._alloc_handle()
gpfifo = self.object_by_handle[struct.hObjectParent]
gpfifo.device.queues[gpfifo.token].bound_engines.add(struct.hClass)
elif struct.hClass == nv_gpu.GT200_DEBUGGER:
struct.hObjectNew = self._alloc_handle()
elif struct.hClass == nv_gpu.MAXWELL_PROFILER_DEVICE:
@ -194,7 +197,7 @@ class NVDriver(VirtDriver):
params = nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN_PARAMS.from_address(params_ptr)
gpu_fifo = self.object_by_handle[struct.hObject]
params.workSubmitToken = gpu_fifo.token
elif struct.cmd == nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE: pass
elif struct.cmd in (nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE, nv_gpu.NVA06F_CTRL_CMD_BIND, nv_gpu.NVA06F_CTRL_CMD_GPFIFO_SCHEDULE): pass
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_PERF_BOOST: pass
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_FB_FLUSH_GPU_CACHE: pass
elif struct.cmd == nv_gpu.NV83DE_CTRL_CMD_DEBUG_READ_ALL_SM_ERROR_STATES:

View file

@ -26,6 +26,7 @@ class GPFIFO:
self.gpfifo = to_mv(self.base, self.entries_cnt * 8).cast("Q")
self.ctrl = nv_gpu.AmpereAControlGPFifo.from_address(self.base + self.entries_cnt * 8)
self.state = {}
self.bound_engines: set[int] = set()
# Buf exec state
self.buf = None
@ -115,9 +116,11 @@ class GPFIFO:
def execute_cmd(self, cmd) -> SchedResult:
if cmd == nv_gpu.NVC56F_SEM_EXECUTE: return self._exec_signal()
elif cmd == nv_gpu.NVC6C0_LAUNCH_DMA: return self._exec_nvc6c0_dma()
elif cmd == nv_gpu.NVC6B5_LAUNCH_DMA: return self._exec_nvc6b5_dma()
elif cmd == nv_gpu.NVC6B5_LAUNCH_DMA: # NOTE: NVC6B5_LAUNCH_DMA == NVC9B0_EXECUTE == 0x300
return self._exec_vid_decode() if self.bound_engines & {nv_gpu.NVC9B0_VIDEO_DECODER, nv_gpu.NVCFB0_VIDEO_DECODER} else self._exec_nvc6b5_dma()
elif cmd == nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B: return self._exec_pcas2()
elif cmd == 0x0320: return self._exec_load_inline_qmd() # NVC6C0_LOAD_INLINE_QMD_DATA
elif cmd == nv_gpu.NVC9B0_SEMAPHORE_D: return self._exec_vid_semaphore()
else: self.state[cmd] = self._next_dword() # just state update
return SchedResult.CONT
@ -136,6 +139,27 @@ class GPFIFO:
else: raise RuntimeError(f"Unsupported type={typ} in exec wait/signal")
return SchedResult.CONT
def _exec_vid_decode(self) -> SchedResult:
self._next_dword() # consume execute flags
# validate that all required decode state was set up correctly
assert self._state(nv_gpu.NVC9B0_SET_APPLICATION_ID) == nv_gpu.NVC9B0_SET_APPLICATION_ID_ID_HEVC
pic_desc_addr = self._state(nv_gpu.NVC9B0_SET_DRV_PIC_SETUP_OFFSET) << 8
pic = nv_gpu.nvdec_hevc_pic_s.from_address(pic_desc_addr)
assert pic.stream_len > 0 and pic.pic_width_in_luma_samples > 0 and pic.pic_height_in_luma_samples > 0
assert self._state(nv_gpu.NVC9B0_SET_IN_BUF_BASE_OFFSET) != 0
assert self._state(nv_gpu.NVC9B0_SET_COLOC_DATA_OFFSET) != 0
assert self._state(nv_gpu.NVC9B0_SET_NVDEC_STATUS_OFFSET) != 0
assert self._state(nv_gpu.NVC9B0_HEVC_SET_FILTER_BUFFER_OFFSET) != 0
return SchedResult.CONT
def _exec_vid_semaphore(self) -> SchedResult:
signal = self._state64(nv_gpu.NVC9B0_SEMAPHORE_A)
val = self._state(nv_gpu.NVC9B0_SEMAPHORE_C)
self._next_dword() # flags
to_mv(signal, 8).cast('Q')[0] = val
to_mv(signal + 8, 8).cast('Q')[0] = int(time.perf_counter() * 1e9)
return SchedResult.CONT
def _exec_load_inline_qmd(self):
qmd_addr = self._state64(nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A) << 8
assert qmd_addr != 0x0, f"invalid qmd address {qmd_addr}"

View file

@ -102,8 +102,8 @@ class TestCompiler(unittest.TestCase):
class TestRunAsModule(unittest.TestCase):
def test_module_runs(self):
out = '\n'.join(enumerate_devices_str())
self.assertIn("CPU", out) # for sanity check
cpu_line = [l for l in enumerate_devices_str() if "CPU" in l][0]
self.assertIn("PASS", cpu_line, f"expected CPU to PASS, got: {cpu_line}")
if __name__ == "__main__":
unittest.main()

View file

@ -88,13 +88,13 @@ def simple_qkv_kernel(O:UOp, Q:UOp, K:UOp, V:UOp) -> UOp:
# **** backward callbacks ****
def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
out, a, b = kernel.src
out, a, b = kernel.src[1:]
grad_a = (Tensor(gradient) @ Tensor(b).T).uop
grad_b = (Tensor(a).T @ Tensor(gradient)).uop
return (None, grad_a, grad_b)
def backward_gemm_custom(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
out, a, b = kernel.src
out, a, b = kernel.src[1:]
grad_a = Tensor.empty_like(Tensor(a)).custom_kernel(Tensor(gradient), Tensor(b).T, fxn=custom_gemm)[0].uop
grad_b = Tensor.empty_like(Tensor(b)).custom_kernel(Tensor(a).T, Tensor(gradient), fxn=custom_gemm)[0].uop
return (None, grad_a, grad_b)
@ -128,6 +128,14 @@ class TestCustomKernel(unittest.TestCase):
out = c.flatten().tolist()
assert all(x == 2 for x in out), "all 2"
def test_sharded_add_one(self):
# PYTHON backend explicitly checks for OOB access for wrong multi shape regression
devs = ("PYTHON:0", "PYTHON:1")
a = Tensor.ones(4, 4).contiguous().shard(devs, axis=0)
c = Tensor(Tensor.empty(2, 4, device=devs).uop.multi(0), device=devs)
c = Tensor.custom_kernel(c, a, fxn=custom_add_one_kernel)[0]
assert (c == 2).all().item()
def test_multioutput(self):
a = Tensor.full((16, 16), 3.).contiguous()
b = Tensor.full((16, 16), 3.).contiguous()

View file

@ -380,9 +380,40 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool
class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
class TestEmulatedBFloat16Type(TestBFloat16Type):
@classmethod
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="bfloat16"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
@classmethod
def tearDownClass(cls): cls.stack.close()
class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3
class TestEmulatedFp8e4m3(TestFp8e4m3):
@classmethod
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e4m3"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
@classmethod
def tearDownClass(cls): cls.stack.close()
class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2
class TestEmulatedFp8e5m2(TestFp8e5m2):
@classmethod
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e5m2"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
@classmethod
def tearDownClass(cls): cls.stack.close()
class TestPtrDType(unittest.TestCase):
def test_vec_double(self):
dt1 = dtypes.float.vec(4).ptr().vec(4)

View file

@ -1,6 +1,6 @@
import unittest, operator, math
from tinygrad import Context, Tensor, dtypes, Device
from tinygrad.dtype import DType, truncate
from tinygrad.dtype import DType, truncate, fp8_to_float
from tinygrad.helpers import CI, EMULATED_DTYPES, getenv
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
@ -59,9 +59,10 @@ def universal_test(a, b, dtype, op):
# lt and max with nan is undefined in tinygrad
if op[0] in (operator.lt, Tensor.maximum) and (math.isnan(a) or math.isnan(b)): return
ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
tensor_value = (op[0](ta, tb)).numpy()
numpy_value = op[1](ta.numpy(), tb.numpy())
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item())
if dtype in dtypes.fp8s and op[0] not in (operator.lt, operator.eq):
tensor_value = fp8_to_float((op[0](ta.realize(), tb.realize())).bitcast(dtypes.uint8).item(), dtype)
numpy_value = truncate[dtype](op[1](ta.numpy(), tb.numpy()).item())
else: tensor_value, numpy_value = (op[0](ta, tb)).numpy(), op[1](ta.numpy(), tb.numpy())
if dtype in dtypes.floats:
if not is_dtype_supported(dtype) or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero
fe, fm = dtypes.finfo(dtype)
@ -76,13 +77,14 @@ def universal_test_unary(a, dtype, op):
# TODO: cos does not match for large input
if op[0] == Tensor.cos and abs(a) > 30: return
if op[0] == Tensor.log and a <= 0: return
out: Tensor = op[0](ta)
tensor_value = out.numpy()
numpy_value = op[1](ta.numpy())
if dtype in dtypes.fp8s:
# normals are zero
if dtype in EMULATED_DTYPES.tolist(dtypes) and abs(ta.numpy().item()) < 0.015625: return
tensor_value = fp8_to_float(op[0](ta.realize()).bitcast(dtypes.uint8).item(), dtype)
numpy_value = truncate[dtype](v:=op[1](ta.numpy()).item())
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
if math.isinf(numpy_value.item()): return
numpy_value = truncate[dtype](numpy_value.item())
if math.isinf(v): return
else: tensor_value, numpy_value = op[0](ta).numpy(), op[1](ta.numpy())
if dtype in dtypes.floats:
atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
@ -128,16 +130,31 @@ class TestDTypeALU(unittest.TestCase):
def test_bfloat16(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
@Context(EMULATED_DTYPES="bfloat16")
def test_emulated_bfloat16(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
def test_fp8e4m3(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
@Context(EMULATED_DTYPES="fp8e4m3")
def test_emulated_fp8e4m3(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
def test_fp8e5m2(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
@Context(EMULATED_DTYPES="fp8e5m2")
def test_emulated_fp8e5m2(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@given(ht.float32, strat.sampled_from(unary_operations))
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
@ -153,18 +170,34 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.bfloat16, strat.sampled_from(unary_operations))
def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
@given(ht.bfloat16, strat.sampled_from(unary_operations))
@Context(EMULATED_DTYPES="bfloat16")
def test_emulated_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
def test_fp8e4m3_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
@Context(EMULATED_DTYPES="fp8e4m3")
def test_emulated_fp8e4m3_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
def test_fp8e5m2_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
@Context(EMULATED_DTYPES="fp8e5m2")
def test_emulated_fp8e5m2_unary(self, a, op):
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op)
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)

View file

@ -212,6 +212,7 @@ class TestProfiler(unittest.TestCase):
for ge in graphs:
self.assertEqual(len(ge.ents), len(graphs))
@unittest.skip("this test is flaky")
def test_trace_metadata(self):
with Context(TRACEMETA=1):
a = Tensor.empty(1)+2

View file

@ -10,7 +10,7 @@ from hypothesis import assume, given, settings, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.realize import CompiledRunner, run_schedule
@ -677,17 +677,6 @@ class TestSchedule(unittest.TestCase):
c = (a.sum(2).contiguous() + b).contiguous()
check_schedule(c, 2)
# TODO: this requires supporting multiple stores in the AST
@unittest.expectedFailure
def test_multioutput_ast(self):
a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
c = Tensor.arange(4).realize().uop
kernel = UOp(Ops.KERNEL, src=(a.base, b.base, c.base), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2)))
run_schedule(check_schedule(UOp.sink(a.assign(kernel), b.assign(kernel)), 1))
self.assertEqual(a.buffer.numpy(), [7])
self.assertEqual(b.buffer.numpy(), [12])
@unittest.skip("no longer supported")
def test_double_from(self):
x = Tensor([1,2,3,4])

View file

@ -1,10 +1,14 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, CI
from tinygrad.helpers import getenv
from extra.gemm.asm.cdna.gemm import asm_gemm
from test.helpers import needs_second_gpu
# On non CDNA4 it will only validate the Tensor.custom_kernel integration
# Use NULL=1 EMULATE=AMD_CDNA4 to also test the assembly
def is_cdna4(): return getattr(Device[Device.DEFAULT].renderer, "arch", "").startswith("gfx950")
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None:
Tensor.manual_seed(0)
a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype)
@ -16,36 +20,47 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:i
a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None)
tst = asm_gemm(a, b)
tst.sum().backward()
with Context(ASM_GEMM=1):
tst = asm_gemm(a, b)
tst.sum().backward()
Tensor.realize(tst, a.grad, b.grad)
a_ref, b_ref = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
if multi: a_ref, b_ref = a_ref.shard(devs, axis=0), b_ref.shard(devs, axis=None)
with Context(ASM_GEMM=0): ref = a_ref @ b_ref
ref.sum().backward()
with Context(ASM_GEMM=0):
ref = asm_gemm(a_ref, b_ref)
ref.sum().backward()
Tensor.realize(ref, a_ref.grad, b_ref.grad)
# no validation on the NULL device
if a_rand.device.startswith("NULL"): return None
with Context(DEBUG=0):
assert (tst - ref).square().max().float().item() < 1e-6, "forward mismatch"
assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
SCALE = 128 if CI else 1
# 128x smaller than usual
# uses the UOp GEMM, runs on non CDNA4 and CI
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
class TestGemm(unittest.TestCase):
def test_simple(self): verify_asm_gemm(1, N:=(getenv("N", 4096)//SCALE), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192//SCALE, 4096//SCALE, 14336//SCALE)
def test_gemm_batched(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE)
def setUp(self):
if is_cdna4(): self.skipTest("shapes are too small for the assembly GEMM")
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 32), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 64, 32, 112)
def test_gemm_batched(self): verify_asm_gemm(2, 64, 32, 32)
@needs_second_gpu
def test_gemm_multi(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE, gpus=2)
def test_gemm_multi(self): verify_asm_gemm(2, 64, 32, 32, gpus=2)
# uses the Asm GEMM on CDNA4 only for speed reasons
class TestGemmLarge(unittest.TestCase):
def setUp(self):
if getattr(Device[Device.DEFAULT].renderer, "arch", "") != "gfx950":
if not is_cdna4():
self.skipTest("very slow on non mi350x")
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336)
def test_gemm_batched(self): verify_asm_gemm(2, 8192, 4096, 4096)
def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
@unittest.skip("disabled, asm in this shape is slower than tinygrad")
def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, dtype=dtypes.bfloat16, gpus=8)

View file

@ -1,8 +1,9 @@
import unittest
from tinygrad import Device
from tinygrad.helpers import fetch
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import fetch, round_up
from extra.hevc.hevc import parse_hevc_file_headers, nv_gpu
from extra.hevc.decode import hevc_decode
class TestHevc(unittest.TestCase):
def test_hevc_parser(self):
@ -61,5 +62,26 @@ class TestHevc(unittest.TestCase):
self.assertEqual(list(frame3.initreflistidxl0), [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
self.assertEqual(list(frame3.initreflistidxl1), [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
self.assertEqual(list(frame3.RefDiffPicOrderCnts), [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
@unittest.skipUnless(Device.DEFAULT == "NV", "NV only")
def test_hevc_decode(self):
url = "https://github.com/haraschax/filedump/raw/09a497959f7fa6fd8dba501a25f2cdb3a41ecb12/comma_video.hevc"
dat = fetch(url, headers={"Range": f"bytes=0-{512<<10}"}).read_bytes()
opaque, frame_info, w, h, luma_w, luma_h, chroma_off = parse_hevc_file_headers(dat)
frame_info = frame_info[:4]
out_image_size = luma_h + (luma_h + 1) // 2, round_up(luma_w, 64)
hevc_tensor = Tensor(dat, device="NV")
opaque_nv = opaque.to("NV").contiguous().realize()
frames = list(hevc_decode(hevc_tensor, opaque_nv, frame_info, luma_h, luma_w))
Device.default.synchronize()
self.assertEqual(len(frames), 4)
for f in frames:
self.assertEqual(f.shape, out_image_size)
self.assertEqual(f.dtype, dtypes.uint8)
self.assertEqual(f.device, "NV")
if __name__ == "__main__":
unittest.main()

View file

@ -1,18 +1,19 @@
from typing import cast
from dataclasses import replace
import itertools
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, TracingKey, Context
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, promo_lattice
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import panic
from tinygrad.codegen.opt import Opt
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_unsupported_dtypes_patterns, get_transcendental_patterns
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_float_decomp, pm_long_decomp
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads
@ -24,7 +25,7 @@ from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_c
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(sink))
if SPEC: type_verify(sink, kernel_spec)
@ -88,10 +89,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV))
pm_unsupported = get_unsupported_dtypes_patterns(ren.device, tuple(EMULATED_DTYPES.tolist(dtypes)))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
sink = graph_rewrite(sink, pm_unsupported, ctx=ren.device, name="unsupported dtypes", bottom_up=True)
if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True)
for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float))
for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]:
sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True)
sink = graph_rewrite(sink, pm_transcendental, ctx=ren.device, name="transcendental")
# final rules for the renderer (without sym)
@ -108,7 +112,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# inject IF/ENDIF. only needed if device doesn't support gated stores
pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError("if not allowed in graph"))),
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))

View file

@ -406,9 +406,9 @@ def enumerate_devices_str() -> Generator[str, None, None]:
if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
set_text = f'({cc_ctrl_var.key}={d._compiler_name(r, c)} to make default)' if cc_ctrl_var is not None else ''
default_text = '(default)' if type(default_compiler) is type(d.compiler) else set_text
compilers_results.append(f"{colored('+', 'green')} {unwrap_class_type(c).__name__} {default_text}")
compilers_results.append(f"{colored('+', 'green')} {d._compiler_name(r, c)} {default_text}")
any_works = True
except Exception as e: compilers_results.append(f"{colored('-', 'yellow')} {unwrap_class_type(c).__name__}: {e}")
except Exception as e: compilers_results.append(f"{colored('-', 'yellow')} {d._compiler_name(r, c)}: {e}")
finally:
# put the defaults back!
d.comp_sets, d.comps_ctrl_var = default_comp_pairs, cc_ctrl_var

View file

@ -1,7 +1,7 @@
import time
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
@ -22,14 +22,14 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
# build kernel dependency graph: edges from producer kernel to consumer kernels
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort():
for u in sched_sink.toposort(gate_kernel_sink):
if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
if u.op is not Ops.AFTER: continue
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
in_degree.setdefault(k, 0)
if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}"
for s in k.src[0].src if k.op is Ops.END else k.src:
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
for s in k.src[0].src[1:] if k.op is Ops.END else k.src[1:]:
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
children.setdefault(s.src[1], []).append(k)
@ -54,13 +54,13 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
while len(queue):
k = rk = queue.popleft()
if k.op is Ops.END: k = k.src[0]
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
if k.op is Ops.RANGE: schedule.append(k)
elif k.op is Ops.KERNEL:
ast = (kernel:=cast(Kernel, k.arg)).ast
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
sched_item[k] = (ast, buf_uops, kernel.metadata, bound_ranges)
elif k.op is Ops.CALL:
ast = k.src[0]
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
schedule.append(k)
if rk.op is Ops.END: schedule.append(rk)
for x in children.get(rk, []):
@ -86,7 +86,7 @@ def unroll_outer_ranges(schedule:list[UOp], sched_item:dict[UOp, ScheduleItem])
sched_ptr = range_ptrs[si.src[1]]
continue
else:
assert si.op is Ops.KERNEL, f"unexpected op in schedule: {si.op}"
assert si.op is Ops.CALL, f"unexpected op in schedule: {si.op}"
ast, buf_uops, metadata, bound_ranges = sched_item[si]
fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))

View file

@ -52,7 +52,6 @@ pm_gradient = PatternMatcher([
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
# NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
# gradient on CALL: use provided grad_fxn or auto-differentiate
(UPat(Ops.CALL, name="k"), call_gradient),
# there's no gradient for bitcast

View file

@ -86,7 +86,9 @@ def word_wrap(x, wrap=80):
while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
return x[:i] + "\n" + word_wrap(x[i:], wrap)
def pad_bytes(b:bytes, align:int) -> bytes: return b + b'\x00' * ((align - (len(b) % align)) % align)
def panic(e:Exception|None=None): raise e if e is not None else RuntimeError("PANIC!")
# NOTE: you must create the exception inside the function where it's raised or you will get a GC cycle!
def panic(e:type[Exception]|None=None, *arg): raise e(*arg) if e is not None else RuntimeError("PANIC!")
@functools.cache
def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]:

View file

@ -57,7 +57,7 @@ def __getattr__(nm):
], args=[
"-include", "{}/src/common/sdk/nvidia/inc/nvtypes.h", "-I{}/src/common/inc", "-I{}/kernel-open/nvidia-uvm", "-I{}/kernel-open/common/inc",
"-I{}/src/common/sdk/nvidia/inc", "-I{}/src/nvidia/arch/nvalloc/unix/include", "-I{}/src/common/sdk/nvidia/inc/ctrl"
], rules=[(r'MW\(([^:]+):(.+)\)',r'(\1, \2)')], tarball=nv_src[nm], anon_names={"{}/kernel-open/common/inc/nvstatus.h:37":"nv_status_codes"})
], rules=[(r'MW\(([^:]+):(.+)\)',r'(\1, \2)'), (r'(\d+):(\d+)', r'(\1, \2)')], tarball=nv_src[nm], anon_names={"{}/kernel-open/common/inc/nvstatus.h:37":"nv_status_codes"})
case "nv": return load("nv", None, [
*[f"{{}}/src/nvidia/inc/kernel/gpu/{s}.h" for s in ["fsp/kern_fsp_cot_payload", "gsp/gsp_init_args"]],
*[f"{{}}/src/nvidia/arch/nvalloc/common/inc/{s}.h" for s in ["gsp/gspifpub", "gsp/gsp_fw_wpr_meta", "gsp/gsp_fw_sr_meta", "rmRiscvUcode",

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -703,13 +703,15 @@ class KFDIface:
# Event to wait for queues completion
self.dev.queue_event = kfd.AMDKFD_IOC_CREATE_EVENT(KFDIface.kfd, event_type=kfd.KFD_IOC_EVENT_SIGNAL, auto_reset=1)
self.dev.queue_event_mailbox_ptr = KFDIface.event_page.va_addr + self.dev.queue_event.event_slot_index * 8
self.queue_event_arr = (kfd.struct_kfd_event_data)(event_id=self.dev.queue_event.event_id)
self.queue_event_arr_ptr = ctypes.addressof(self.queue_event_arr)
# OS events to collect memory and hardware faults
self.mem_fault_event = kfd.AMDKFD_IOC_CREATE_EVENT(KFDIface.kfd, event_type=kfd.KFD_IOC_EVENT_MEMORY)
self.hw_fault_event = kfd.AMDKFD_IOC_CREATE_EVENT(KFDIface.kfd, event_type=kfd.KFD_IOC_EVENT_HW_EXCEPTION)
self.queue_event_arr = (kfd.struct_kfd_event_data * 3)(kfd.struct_kfd_event_data(event_id=self.dev.queue_event.event_id),
kfd.struct_kfd_event_data(event_id=self.mem_fault_event.event_id), kfd.struct_kfd_event_data(event_id=self.hw_fault_event.event_id))
self.queue_event_arr_ptr = ctypes.addressof(self.queue_event_arr)
def alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, cpu_addr=None) -> HCQBuffer:
flags = kfd.KFD_IOC_ALLOC_MEM_FLAGS_WRITABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_EXECUTABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_NO_SUBSTITUTE
@ -778,18 +780,20 @@ class KFDIface:
doorbell=MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q'))
def sleep(self, tm:int):
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=3, wait_for_all=0, timeout=tm)
if self.queue_event_arr[1].memory_exception_data.gpu_id or self.queue_event_arr[2].hw_exception_data.gpu_id: raise RuntimeError("Device fault")
def on_device_hang(self):
def _collect_str(st): return ' '.join(f'{k[0]}={getattr(st, k[0])}' for k in st._real_fields_)
def _str(st): return ' '.join(f'{k[0]}={getattr(st, k[0])}' for k in st._real_fields_)
# try to collect fault info if not already set from sleep().
if not self.queue_event_arr[1].memory_exception_data.gpu_id and not self.queue_event_arr[2].hw_exception_data.gpu_id:
with contextlib.suppress(RuntimeError): self.sleep(tm=1)
report = []
for evnt in [self.mem_fault_event, self.hw_fault_event]:
ev = (kfd.struct_kfd_event_data)(event_id=evnt.event_id)
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=ctypes.addressof(ev), num_events=1, wait_for_all=1)
if evnt == self.mem_fault_event and ev.memory_exception_data.gpu_id:
report += [f"MMU fault: 0x{ev.memory_exception_data.va:X} | {_collect_str(ev.memory_exception_data.failure)}"]
if evnt == self.hw_fault_event and ev.hw_exception_data.gpu_id: report += [f"HW fault: {_collect_str(ev.hw_exception_data)}"]
if self.queue_event_arr[1].memory_exception_data.gpu_id:
report += [f"MMU fault: 0x{self.queue_event_arr[1].memory_exception_data.va:X} | {_str(self.queue_event_arr[1].memory_exception_data.failure)}"]
if self.queue_event_arr[2].hw_exception_data.gpu_id: report += [f"HW fault: {_str(self.queue_event_arr[2].hw_exception_data)}"]
raise RuntimeError("\n".join(report))

View file

@ -36,6 +36,9 @@ def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status
NV_PFAULT_FAULT_TYPE = {dt:name for name,dt in nv_gpu.__dict__.items() if name.startswith("NV_PFAULT_FAULT_TYPE_")}
NV_PFAULT_ACCESS_TYPE = {dt:name.split("_")[-1] for name,dt in nv_gpu.__dict__.items() if name.startswith("NV_PFAULT_ACCESS_TYPE_")}
def nv_flags(reg, **kwargs): return functools.reduce(int.__or__, ((getattr(nv_gpu, f"{reg}_{k}_{v}".upper()) if isinstance(v, str) else v) <<
getattr(nv_gpu, f"{reg}_{k}".upper())[1] for k, v in kwargs.items()), 0)
def nv_iowr(fd:FileIOInterface, nr, args, cmd=None):
ret = fd.ioctl(cmd or ((3 << 30) | (ctypes.sizeof(args) & 0x1FFF) << 16 | (ord('F') & 0xFF) << 8 | (nr & 0xFF)), args)
if ret != 0: raise RuntimeError(f"ioctl returned {ret}")
@ -94,7 +97,8 @@ class NVCommandQueue(HWQueue[HCQSignal, 'NVDevice', 'NVProgram', 'NVArgsState'])
return self
def wait(self, signal:HCQSignal, value:sint=0):
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value), (3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value),
nv_flags("NVC56F_SEM_EXECUTE", operation="acq_circ_geq", payload_size="64bit"))
self.active_qmd = None
return self
@ -125,7 +129,8 @@ class NVCommandQueue(HWQueue[HCQSignal, 'NVDevice', 'NVProgram', 'NVArgsState'])
class NVComputeQueue(NVCommandQueue):
def memory_barrier(self):
self.nvm(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, (1 << 12) | (1 << 4) | (1 << 0))
self.nvm(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI,
nv_flags("NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI", instruction="true", global_data="true", constant="true"))
self.active_qmd:QMD|None = None
return self
@ -169,7 +174,7 @@ class NVComputeQueue(NVCommandQueue):
return self
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
nv_flags("NVC56F_SEM_EXECUTE", operation="release", release_wfi="en", payload_size="64bit", release_timestamp="en"))
self.nvm(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 0x0)
self.active_qmd = None
return self
@ -185,12 +190,13 @@ class NVCopyQueue(NVCommandQueue):
for off in range(0, copy_size, step:=(1 << 31)):
self.nvm(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, *data64(src+off), *data64(dest+off))
self.nvm(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, min(copy_size-off, step))
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA,
nv_flags("NVC6B5_LAUNCH_DMA", data_transfer_type="non_pipelined", src_memory_layout="pitch", dst_memory_layout="pitch"))
return self
def signal(self, signal:HCQSignal, value:sint=0):
self.nvm(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, *data64(signal.value_addr), value)
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, 0x14)
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, nv_flags("NVC6B5_LAUNCH_DMA", flush_enable="true", semaphore_type="release_four_word_semaphore"))
return self
def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.dma_gpfifo)
@ -199,7 +205,8 @@ class NVVideoQueue(NVCommandQueue):
def decode_hevc_chunk(self, pic_desc:HCQBuffer, in_buf:HCQBuffer, out_buf:HCQBuffer, out_buf_pos:int, hist_bufs:list[HCQBuffer], hist_pos:list[int],
chroma_off:int, coloc_buf:HCQBuffer, filter_buf:HCQBuffer, intra_top_off:int, intra_unk_off:int|None, status_buf:HCQBuffer):
self.nvm(4, nv_gpu.NVC9B0_SET_APPLICATION_ID, nv_gpu.NVC9B0_SET_APPLICATION_ID_ID_HEVC)
self.nvm(4, nv_gpu.NVC9B0_SET_CONTROL_PARAMS, 0x52057)
self.nvm(4, nv_gpu.NVC9B0_SET_CONTROL_PARAMS, nv_flags("NVC9B0_SET_CONTROL_PARAMS", codec_type="hevc", testrun_env="prod_run", gptimer_on=1,
err_conceal_on=1, mbtimer_on=1, event_trace_logging_on=1))
self.nvm(4, nv_gpu.NVC9B0_SET_DRV_PIC_SETUP_OFFSET, pic_desc.va_addr >> 8)
self.nvm(4, nv_gpu.NVC9B0_SET_IN_BUF_BASE_OFFSET, in_buf.va_addr >> 8)
for pos, buf in zip(hist_pos + [out_buf_pos], hist_bufs + [out_buf]):
@ -216,7 +223,7 @@ class NVVideoQueue(NVCommandQueue):
def signal(self, signal:HCQSignal, value:sint=0):
self.nvm(4, nv_gpu.NVC9B0_SEMAPHORE_A, *data64(signal.value_addr), value)
self.nvm(4, nv_gpu.NVC9B0_SEMAPHORE_D, (1 << 24) | (1 << 0))
self.nvm(4, nv_gpu.NVC9B0_SEMAPHORE_D, nv_flags("NVC9B0_SEMAPHORE_D", structure_size="four", payload_size="64bit"))
return self
def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.vid_gpfifo)

View file

@ -264,7 +264,8 @@ class AM_GFX(AM_IP):
self.adev.regGRBM_CNTL.update(read_timeout=0xff, inst=xcc)
for i in range(0, 16):
self._grbm_select(vmid=i, inst=xcc)
self.adev.regSH_MEM_CONFIG.write(**({'initial_inst_prefetch':3} if self.adev.ip_ver[am.GC_HWIP][0]>=10 else {'retry_disable':1, 'f8_mode':1}),
self.adev.regSH_MEM_CONFIG.write(**({'initial_inst_prefetch':3} if self.adev.ip_ver[am.GC_HWIP][0]>=10 else {'retry_disable':1}),
**({'f8_mode':1} if self.adev.ip_ver[am.GC_HWIP][:2]==(9,4) else {}),
address_mode=self.adev.soc.module.SH_MEM_ADDRESS_MODE_64, alignment_mode=self.adev.soc.module.SH_MEM_ALIGNMENT_MODE_UNALIGNED, inst=xcc)
# Configure apertures:

View file

@ -257,11 +257,10 @@ class HCQSignal(Generic[HCQDeviceType]):
value: The value to wait for.
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
"""
start_time = last_sleep_time = int(time.perf_counter() * 1000)
start_time = int(time.perf_counter() * 1000)
while (not_passed:=(prev_value:=self.value) < value) and (cur_time:=int(time.perf_counter() * 1000)) - start_time < timeout:
self._sleep(cur_time - last_sleep_time)
last_sleep_time = int(time.perf_counter() * 1000)
if self.value != prev_value: start_time = last_sleep_time # progress was made, reset timer
self._sleep(cur_time - start_time)
if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
if not_passed and self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
@contextlib.contextmanager

View file

@ -3,13 +3,13 @@ import functools, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
from tinygrad.uop.ops import consumer_map_from_toposort
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink, pm_gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL, Ops.ENCDEC}
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
@ -17,7 +17,7 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
pm_generate_realize_map = PatternMatcher([
pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
# always realize SINK src
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize
@ -91,20 +91,25 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
if assign.src[1].op is Ops.KERNEL: return None
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
ret = assign.replace(src=assign.src+(to_mop,))
ctx.range_map[ret] = ctx.range_map[assign]
return ret
def handle_assign_mops(ctx:IndexingContext, assign:UOp, target:UOp, src:UOp):
if target.op in GroupOp.Movement and src.op is not Ops.CALL:
mops = []
while target.op in GroupOp.Movement:
mops.append((target.op, target.marg))
target = target.src[0]
if mops and assign in ctx.range_map:
ret = assign.replace(arg=tuple(mops))
ctx.range_map[ret] = ctx.range_map[assign]
return ret
return None
pm_apply_rangeify = PatternMatcher([
# REDUCE_AXIS -> REDUCE
(UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
# PAD -> WHERE
(UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
# add third op to assign
(UPat(Ops.ASSIGN, src=(UPat(), UPat()), name="assign"), add_third_op_to_assign_to_track_shape),
# store movement ops in ASSIGN arg
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(name="src")), name="assign"), handle_assign_mops),
# finally, apply_rangeify
(UPat(GroupOp.All, name="x"), create_bufferize_and_index_based_on_ranges),
# remove movement op
@ -159,7 +164,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# get the consumer map
with cpu_profile("consumer map in rangeify", "TINY"):
consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort())
consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort(gate_kernel_sink))
# explicit rangeify
ending_ranges: dict[UOp, list[UOp]] = {}
@ -167,7 +172,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
# no ranges on kernels, they are internal
if x.op is Ops.KERNEL: continue
if x.op is Ops.CALL: continue
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])

View file

@ -1,7 +1,8 @@
from typing import cast
import functools, itertools
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
from tinygrad.dtype import dtypes
from tinygrad.device import Device
# *** allreduce implementation ***
@ -214,17 +215,18 @@ multi_pm = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
# we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
(UPat(Ops.CUSTOM_KERNEL, src=UPat((Ops.MULTI, Ops.CONTIGUOUS)), name="ck"),
lambda ck: ck.replace(src=tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src))),
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
# after CALL
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
])+replace_allreduce
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")
if getenv("VIZ"): graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
if VIZ: graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
return ret

View file

@ -1,10 +1,10 @@
from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ
from tinygrad.helpers import PCONTIG, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
@ -33,13 +33,17 @@ pm_mops = PatternMatcher([
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
def fix_assign_hazard(dest:UOp, src:UOp, assign:UOp):
def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
if (t := target.base).op is Ops.BUFFER or (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src)): return None
return src.f(Ops.CONTIGUOUS, tag=assign.tag)
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if dest.op_in_backward_slice_with_self(Ops.SHRINK) else set())
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
if not (hazards:=[s for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS) if s.op in unsafe]): return
for h in hazards:
if any(s is dest.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})):
return assign.replace(src=(dest, src.contiguous()))
if any(s is target.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})):
return assign.replace(src=(target, src.contiguous()))
def split_reduceop(reduce:UOp, x:UOp):
if prod(reduce.shape) == 0: return None
@ -70,10 +74,6 @@ mop_cleanup = PatternMatcher([
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
])
def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
@ -95,9 +95,6 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
# resolve custom kernels
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
# remove CONTIGUOUS if the BUFFER is already contiguous
(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
@ -137,15 +134,13 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
lambda target, src, assign: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
lambda assign, target, src: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
# assign only to buffer, otherwise make it a CONTIGUOUS
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
not (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src))) else None),
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="src")), name="assign"), assign_to_contiguous),
# make source contiguous if it has hazardous movement ops on the dest buffer
(UPat(Ops.ASSIGN, src=(UPat.var("dest"), UPat.var("src")), name="assign"), fix_assign_hazard),
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
])
# *****************
@ -334,19 +329,13 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {size}"
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
if x.src[0].op is Ops.ASSIGN:
assign_target, assign_src, assign_mops = x.src[0].src
if (assign := x.src[0]).op is Ops.ASSIGN:
assign_target, assign_src = assign.src[0], assign.src[1]
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
# in assign, this is the buffer size, not the bufferize size
# TODO: assign_mops here
do_store = assign_target.replace(dtype=sdtype).store(assign_src, tag=x.tag).end(*rngs)
ret = assign_target.src[0].after(do_store)
mops = []
walk = assign_mops
while walk is not assign_mops.base:
mops.append((walk.op, walk.marg))
walk = walk.src[0]
for m in mops[::-1]: ret = ret._mop(*m)
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
return ret
# lower outerworld reduce here
@ -393,7 +382,7 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
# remove any RESHAPEs on KERNEL
(UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
])
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
@ -525,10 +514,9 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
kernel_arg = Kernel(ret, metadata)
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
return kernel
split_kernels = PatternMatcher([
@ -543,9 +531,9 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
if x.dtype.scalar() == dtypes.index: return None
ctx[0].append(x)
return x.replace(tag=(len(ctx[0])-1,))
add_tags = PatternMatcher([
add_tags = pm_gate_kernel_sink+PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
])
@ -566,7 +554,7 @@ replace_contiguous = PatternMatcher([
])
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
uop_list: list[UOp] = []
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
@ -584,12 +572,13 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.BUFFER, Ops.AFTER} and \
x.tag is not None and len(x.tag)])
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
# bufferize -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True,
name="bufferize to store")
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
@ -609,7 +598,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
sink_tags = [s.tag for s in tsink.src]
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
becomes_map: dict[UOp, UOp] = {}
for tag, s in zip(sink_tags, tsink.src):

View file

@ -234,8 +234,7 @@ class Tensor(OpMixin):
def as_param(self, slot:int):
if self.uop.axis is not None:
multi_shape = tuple([s//len(self.device) if i==self.uop.axis else s for i,s in enumerate(self.shape)])
param = UOp.param(slot, self.dtype, multi_shape, self.device).multi(self.uop.axis)
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device).multi(self.uop.axis)
else:
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param, device=self.device)
@ -756,8 +755,7 @@ class Tensor(OpMixin):
dtype = kwargs.pop("dtype", self.dtype)
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
@ -3786,7 +3784,7 @@ class Tensor(OpMixin):
S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
new_indices = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2)
V = V.gather(-1, new_indices).realize()
V = V.gather(-1, new_indices)
padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
padded_u[..., 0:num, 0:num] = U

View file

@ -84,8 +84,7 @@ class Ops(FastEnum):
# ** 6 -- ops that don't exist in programs **
# tensor graph ops
UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); ASSIGN = auto()
CUSTOM_KERNEL = auto()
UNIQUE = auto(); DEVICE = auto(); ASSIGN = auto()
# local unique
LUNIQUE = auto()

View file

@ -14,8 +14,8 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
# *** helper functions for bit manipulation ***
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()]
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]
def exponent_bias(d:DType) -> int: return (1 << (dtypes.finfo(d.scalar())[0] - 1)) - 1
def exponent_mask(d:DType) -> int: return (1 << dtypes.finfo(d.scalar())[0]) - 1
# **** utils ****
def shr(x:UOp|int, y:UOp|int) -> UOp: return x // (2**(y.simplify().arg) if isinstance(y, UOp) else 2**y)
@ -378,33 +378,46 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
case _: raise NotImplementedError(f"long decomposition of {op} unsupported")
# ***** floats *****
f2f_dt = { dtypes.half: dtypes.ushort, dtypes.float: dtypes.uint }
f2f_dt = { f:getattr(dtypes, f"uint{f.bitsize}") for f in dtypes.floats }
def rne(v: UOp, s) -> UOp: return shr(v, s) + ((shr(v, s - 1) & 1) & ((v & ((1 << (s - 1)) - 1)).ne(0).cast(v.dtype) | (shr(v, s) & 1)))
def f2f(v, fr:DType, to:DType):
fs, fb, (fe, fm), ts, tb, (te, tm) = fr.bitsize, exponent_bias(fr), dtypes.finfo(fr), to.bitsize, exponent_bias(to), dtypes.finfo(to)
# NB: denormals are zero!
if fe < te and fm < tm:
if fe <= te and fm < tm:
sign, nosign = shl((v & shl(1, fs-1)).cast(f2f_dt[to]), ts - fs), (v & (shl(1, fs-1) - 1)).cast(f2f_dt[to])
exp, norm = shr(nosign, fm), shl(nosign, tm - fm) + shl(tb - fb, tm)
inf_or_nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm)
return (sign | exp.eq(0).where(0, exp.eq(shl(1, fe) - 1).where(inf_or_nan, norm))).bitcast(to)
elif fe > te and fm > tm:
sign, nosign, exp = shr(v, fs - ts) & shl(1, ts - 1), v & (shl(1, fs - 1) - 1), shr(v, fm) & (shl(1, fe) - 1)
nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm)
# fp8e4m3 has only one nan
is_nan = (nosign.eq(shl(1, fm + fe) - 1) if fr == dtypes.fp8e4m3 else exp.eq(shl(1, fe) - 1))
return (sign | exp.eq(0).where(0, is_nan.where(nan, norm))).bitcast(to)
elif fe >= te and fm > tm:
v = f2f_clamp(v.bitcast(fr), to).bitcast(f2f_dt[fr])
sign, nosign = shr(v, fs - ts) & shl(1, ts - 1), v & (shl(1, fs - 1) - 1)
norm = (rne(nosign, fm - tm) - shl(fb - tb, tm)).cast(f2f_dt[to])
infnan = (sign | (shr(nosign, fm - tm) & (shl(1, tm) - 1)) | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to])
underflow, overflow = exp < (1 + fb - tb), exp > (shl(1, te) - 2 + (fb - tb))
return exp.eq(shl(1, fe) - 1).where(infnan, sign.cast(f2f_dt[to]) | underflow.where(0, overflow.where(shl(shl(1, te) - 1, tm), norm)))
underflow = (shr(v, fm) & (shl(1, fe) - 1)) < (1 + fb - tb)
nan_mantissa = (shl(1, tm) - 1) if to == dtypes.fp8e4m3 else (shr(nosign, fm - tm) & (shl(1, tm) - 1))
nan = (sign | nan_mantissa | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to])
is_nan = (shr(v, fm) & (shl(1, fe) - 1)).eq(shl(1, fe) - 1)
return is_nan.where(nan, sign.cast(f2f_dt[to]) | underflow.where(0, norm))
else: raise NotImplementedError(f"unsupported decomp {fr} -> {to}")
def f2f_load(x: UOp) -> UOp:
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=dtypes.ushort), dtypes.half, dtypes.float)
return UOp.vectorize(*(f2f(x.replace(dtype=dtypes.ushort, src=(reindex(x.src[0].src[0], i, 1),)), dtypes.half, dtypes.float) for i in range(n)))
def f2f_clamp(val:UOp, dt:DType) -> UOp:
e, m = dtypes.finfo(dt)
max_exp, max_man = ((1 << e) - 1, (1 << m) - 2) if dt == dtypes.fp8e4m3 else ((1 << e) - 2, (1 << m) - 1)
mx = val.const_like(2.0**(max_exp - exponent_bias(dt)) * (1.0 + max_man / (1 << m)))
sat = mx if dt in dtypes.fp8s else val.const_like(float('inf'))
# FIXME: CMPLT of nan is undefined
return val.ne(val).where(val, (val < -mx).where(-sat, (mx < val).where(sat, val)))
def f2f_store(st, idx, val):
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(dtypes.uint), dtypes.float, dtypes.half)))
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(dtypes.uint), dtypes.float, dtypes.half))) for i in range(n)))
def f2f_load(x: UOp, fr:DType, to:DType) -> UOp:
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to)
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n)))
def f2f_store(st, idx, val, fr:DType, to:DType):
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr)))
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(f2f_dt[to]), to, fr))) for i in range(n)))
# ***** decomposition patterns *****
@ -463,40 +476,44 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, disable_fast_idiv
pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
return PatternMatcher(pat)
@functools.cache
def get_unsupported_dtypes_patterns(device:str, emulated_dtypes:tuple[DType, ...]) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
if not is_dtype_supported(dtypes.long, device) or dtypes.long in emulated_dtypes:
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None)]
pat += [(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype]))]
pat += [(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None)]
pat += [(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt)))]
pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None and a.dtype not in l2i_dt else None)]
pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
(a.rtag(0).cast(dt:=l2i_dt[a.dtype]).bitcast(xdt:=l2i_dt[x.dtype]), a.rtag(1).cast(dt).bitcast(xdt))[x.tag])]
pat += [(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None)]
pat += [(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag])]
pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag),)))]
pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))]
if dtypes.half in emulated_dtypes:
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
x.replace(dtype=dtypes.uint16.ptr(x.dtype.size), tag=dtypes.half) if x.dtype.base == dtypes.half else None)]
pat += [(UPat(Ops.LOAD, dtypes.half, name="x"), f2f_load)]
pat += [(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, dtypes.half, name="ld"),), name="bc"), lambda bc,ld:
ld.replace(dtype=dtypes.ushort).bitcast(bc.dtype))]
pat += [(UPat(Ops.BITCAST, (dtypes.ushort, dtypes.short, dtypes.bfloat16), src=(UPat.var("x", dtypes.float),), name="bc"), lambda bc,x:
bc.replace(src=(f2f(x.bitcast(dtypes.uint), dtypes.float, dtypes.half),)))]
pat += [(UPat(GroupOp.All, dtypes.half, name="x"), lambda x:
x.replace(dtype=dtypes.float.vec(x.dtype.count), src=tuple(s.cast(dtypes.float) if s.dtype == dtypes.half else s for s in x.src)))]
pat += [(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.float)), name='st'), lambda st,idx,val:
f2f_store(st, idx, val) if (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == dtypes.half else None)]
return PatternMatcher(pat)
pm_long_decomp = PatternMatcher([
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None),
(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt))),
(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None and a.dtype not in l2i_dt else None),
(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
(a.rtag(0).cast(dt:=l2i_dt[a.dtype]).bitcast(xdt:=l2i_dt[x.dtype]), a.rtag(1).cast(dt).bitcast(xdt))[x.tag]),
(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None),
(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag]),
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))
])
# float decomposition patterns - ctx is (fr, to) tuple
pm_float_decomp = PatternMatcher([
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x:
x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None),
(UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None),
(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"), lambda ctx,bc,ld:
ld.replace(dtype=f2f_dt[ctx[0]]).bitcast(bc.dtype) if ld.dtype.bitsize == ctx[0].bitsize else None),
(UPat(Ops.BITCAST, src=(UPat.var("x", dtypes.floats),), name="bc"), lambda ctx,bc,x:
bc.replace(src=(f2f(x.bitcast(f2f_dt[ctx[1]]), ctx[1], ctx[0]),)) if x.dtype == ctx[1] and bc.dtype.bitsize == ctx[0].bitsize else None),
(UPat(Ops.CAST, dtypes.floats, src=(UPat.var("val"),), name="x"), lambda ctx,x,val:
f2f_clamp(val.cast(ctx[1]), ctx[0]) if x.dtype.scalar() == ctx[0] else None),
(UPat(GroupOp.All-{Ops.BITCAST}, dtypes.floats, name="x"), lambda ctx,x:
x.replace(dtype=ctx[1].vec(x.dtype.count), src=tuple(s.cast(ctx[1]) if s.dtype == ctx[0] else s for s in x.src))
if x.dtype.scalar() == ctx[0] else None),
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat(Ops.BITCAST, dtypes.floats, name="val")), name='st'), lambda ctx,st,idx,val:
st.replace(src=(idx, val.replace(dtype=f2f_dt[ctx[0]]))) if val.dtype == ctx[0] and idx.tag == ctx[0] else None),
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), name='st'), lambda ctx,st,idx,val:
f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None),
])

View file

@ -207,7 +207,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | Ops.SINK | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
return None
@ -236,9 +236,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
return self.src[0]._shape
# ops with custom handling
case Ops.KERNEL: return self.arg.ast._shape
# TODO: disallow shape changing bitcast
case Ops.BITCAST:
ps = self.src[0]._shape
@ -288,10 +285,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
raise ValueError(f"invalid type for axis: {axis_arg}")
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
if self.op is Ops.ASSIGN: return self.src[1]._shape
# elementwise ops keep the shape the same. all inputs with shape must match
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
# TODO: remove this hack for 3 op assign
input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None]
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
input_shapes = [x._shape for x in self.src if x._shape is not None]
if len(input_shapes) == 0: return None
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
return input_shapes[0]
@ -371,9 +369,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@recursive_property
def trace_num(self):
num = next(ucount)
# KERNEL also has a UOp in the arg
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), self.arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
return num
# *** uop syntactic sugar ***
@ -806,6 +802,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# *** uop high level syntactic sugar ***
@property
def shard_shape(self):
if self.axis is None: return self.shape
return tuple(x//len(self.device) if i == self.axis else x for i,x in enumerate(self.shape))
@staticmethod
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
@ -814,7 +815,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return ret
def placeholder_like(self, slot:int):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.shape, self.dtype, slot)
return UOp.placeholder(self.shard_shape, self.dtype, slot)
# set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
@ -827,11 +828,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call"
# TODO: reenable this after ENCDEC is fixed
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
kernel = fxn(*placeholders).call(*contig_srcs, grad_fxn=grad_fxn)
return [s.after(kernel) for s in contig_srcs]
@dataclass(frozen=True)
@ -845,14 +848,6 @@ class KernelInfo:
@property
def function_name(self): return to_function_name(self.name)
@dataclass(frozen=True)
class CustomKernel:
fxn: Callable
grad_fxn: Callable|None = None
# sadly CustomKernel can't be pickled or reconstructed as a str
def __reduce__(self): return (CustomKernel, (panic,))
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
@dataclass(frozen=True)
class CallInfo:
grad_fxn: Callable|None = None
@ -861,12 +856,6 @@ class CallInfo:
def __reduce__(self): return (CallInfo, (None, self.metadata))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
@dataclass(frozen=True)
class Kernel:
ast: UOp
metadata: tuple[Metadata, ...] = ()
grad_fxn: Callable|None = None
# ******** ops in python ********
def safe_exp2(x):
@ -1326,6 +1315,9 @@ def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lowe
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
def gate_kernel_sink(x:UOp) -> bool: return not (x.op is Ops.SINK and isinstance(x.arg, KernelInfo))
pm_gate_kernel_sink = PatternMatcher([(UPat(Ops.SINK, name="sink"), lambda sink: None if gate_kernel_sink(sink) else panic(BottomUpGate))])
def do_unbind(ctx:dict[Variable, int], x:UOp):
v,i = x.unbind()
ctx[v] = i
@ -1423,7 +1415,6 @@ pm_pyrender_extra = PatternMatcher([
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
pm_pyrender = pm_pyrender_extra+PatternMatcher([
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"),
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
])
@ -1433,7 +1424,7 @@ def pyrender(ast:UOp) -> str:
cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN}
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
to_render: set[UOp] = {ast}
for u in lst:
@ -1441,7 +1432,7 @@ def pyrender(ast:UOp) -> str:
for s in u.src: to_render.add(s)
if u.op is Ops.STORE: to_render.add(u.src[1])
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
if u.op in {Ops.CUSTOM_KERNEL, Ops.CALL}: raise NotImplementedError("custom_kernel / call can't be pyrendered")
if u.op is Ops.CALL: raise NotImplementedError("call can't be pyrendered")
if u.op in not_rendered: continue
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue
@ -1456,11 +1447,6 @@ def pyrender(ast:UOp) -> str:
op_depth = 1 + max([depth[s] for s in u.src], default=0)
if op_depth > 100: to_render.add(u)
depth[u] = 0 if u in to_render else op_depth
# do the rendering
if u.op is Ops.KERNEL:
if u.arg.ast not in kernels:
kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")
r[u.arg.ast] = kernels[u.arg.ast][0]
ren = cast(str, pm_pyrender.rewrite(u, ctx=r))
assert isinstance(ren, str)
if u.tag is not None: ren += f".rtag({repr(u.tag)})"

View file

@ -1,6 +1,6 @@
import math
from typing import cast, Any
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel, CustomKernel
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
@ -77,9 +77,6 @@ movement_ops = PatternMatcher([
# AFTER on Movement Op
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS})),), allow_any_len=True), lambda: True),
# custom kernels allowed here
(UPat(Ops.CUSTOM_KERNEL), lambda: True),
])
_tensor_spec = PatternMatcher([
@ -92,7 +89,7 @@ _tensor_spec = PatternMatcher([
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
(UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
@ -143,12 +140,19 @@ _tensor_spec = PatternMatcher([
# allow CALL/PARAM
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
(UPat(Ops.PARAM), lambda: True),
])+movement_ops+shared_spec
tensor_spec = PatternMatcher([
# no tags allowed in tensor graph
(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False),
])+_tensor_spec
# ** for custom kernels **
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
])+movement_ops+shared_spec
# ***** UOp spec in codegen shared between kernel and program *****
@ -208,6 +212,11 @@ kernel_spec = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.index, dtypes.int) for y in x.src[1:])),
])+movement_ops+shared_codegen_spec+shared_spec
tensor_spec = PatternMatcher([
# no tags allowed in tensor graph
(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False),
])+_tensor_spec+kernel_spec
# ***** UOp spec in linearized programs *****
program_spec = PatternMatcher([
@ -239,8 +248,6 @@ full_spec = PatternMatcher([
# rangeify: buffer view with index or load is okay
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),)), lambda: True),
# assign on index. the third op is the shape
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True),
# expander: unroll/contract/gep/ptrcat/cat
(UPat((Ops.UNROLL, Ops.CONTRACT), src=(UPat(),)), lambda: True),
@ -254,7 +261,7 @@ full_spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
# linearizer: outputs + intermediate KERNELs
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
@ -270,16 +277,6 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),
(UPat(Ops.INDEX), lambda: True),
@ -307,8 +304,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
# late imports to avoid circular import
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.schedule.rangeify import BufferizeOpts
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Kernel": Kernel, "Metadata": Metadata,
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, "CustomKernel": CustomKernel,
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Metadata": Metadata,
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid,
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic,
"ConstFloat": ConstFloat}
def eval_pyrender(code:str) -> UOp:

View file

@ -251,7 +251,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
# VECTORIZE/CONST

View file

@ -25,9 +25,10 @@ z3_renderer = PatternMatcher([
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
# loads are variables bounded by the min/max of the dtype
(UPat(Ops.LOAD, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
(UPat(Ops.LOAD, dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
# constants
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
(UPat(Ops.CONST, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),

View file

@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
@ -106,9 +106,6 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u in excluded: continue
argst = codecs.decode(str(u.arg), "unicode_escape")
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
if u.op is Ops.KERNEL:
ast_str = f"SINK{tuple(s.op for s in u.arg.ast.src)}" if u.arg.ast.op is Ops.SINK else repr(u.arg.ast.op)
argst = f"<Kernel {len(list(u.arg.ast.toposort()))} {ast_str} {[str(m) for m in u.arg.metadata]}>"
if u.op is Ops.BINARY: argst = f"<{len(u.arg)} bytes>"
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
@ -130,9 +127,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
# NOTE: kernel already has metadata in arg
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+str(u.metadata)
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src) if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
return graph
@ -140,12 +137,10 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
@functools.cache
def _reconstruct(a:int):
op, dtype, src, arg, *rest = trace.uop_fields[a]
arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
next_sink = _reconstruct(ctx.sink)
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):