mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into new_x86_backend
This commit is contained in:
commit
6ff67781f1
42 changed files with 2672 additions and 293 deletions
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
|
|
@ -525,6 +525,8 @@ jobs:
|
|||
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
|
||||
- name: Run full CIFAR training steps w 6 GPUS
|
||||
run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
|
||||
- name: Test full tinyfs load
|
||||
run: TINYFS_ENDPOINT=10.0.52.11:6767 PYTHONPATH=. python extra/tinyfs/fetch_file.py --hash d734f5e3be9f1e9d863bfaa4fc6c1ef2 --len 175866113 --dest mapping.json --check
|
||||
- name: Run process replay tests
|
||||
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.runtime.support.system import System, PCIDevice, PCIDevImplBase
|
||||
from tinygrad.runtime.support.hcq import FileIOInterface
|
||||
from tinygrad.runtime.support.am.amdev import AMDev
|
||||
|
||||
if __name__ == "__main__":
|
||||
gpus = System.pci_scan_bus(0x1002, [(0xffff, [0x74a1, 0x75a0])])
|
||||
pcidevs = [PCIDevice(f"reset:{gpu}", gpu, bars=[0, 2, 5]) for gpu in gpus]
|
||||
for gpu in gpus:
|
||||
drv_path = f"/sys/bus/pci/devices/{gpu}/driver"
|
||||
if FileIOInterface.exists(drv_path) and os.path.basename(os.readlink(drv_path)) == "amdgpu":
|
||||
raise RuntimeError(f"amdgpu is bound to {gpu}. Stopping...")
|
||||
pcidevs = [PCIDevice("AM", gpu, bars=[0, 2, 5]) for gpu in gpus]
|
||||
amdevs = []
|
||||
with Context(DEBUG=2):
|
||||
for pcidev in pcidevs:
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def custom_matmul(output: UOp, inp: UOp, weight: UOp) -> UOp:
|
|||
return store_op.sink(arg=KernelInfo(name=f"fp8_matmul_{inp.shape}x{weight.shape}"))
|
||||
|
||||
def custom_matmul_backward(gradient: UOp, kernel: UOp) -> tuple[UOp, UOp]:
|
||||
_, input_uop, weight_uop = kernel.src
|
||||
_, input_uop, weight_uop = kernel.src[1:]
|
||||
input_tensor = Tensor(input_uop, device=input_uop.device)
|
||||
grad_tensor = Tensor(gradient, device=gradient.device)
|
||||
weight_tensor = Tensor(weight_uop, device=weight_uop.device)
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
|||
# ** backward gemm, might use the asm gemm
|
||||
|
||||
def custom_gemm_bw(gradient:UOp, kernel:UOp):
|
||||
out, a, b = kernel.src
|
||||
out, a, b = kernel.src[1:]
|
||||
assert all_same([gradient.device, a.device, b.device, out.device])
|
||||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
grad_a = (g_t @ b_t.T).uop
|
||||
|
|
|
|||
|
|
@ -103,9 +103,9 @@ class Attention:
|
|||
def fa_custom_backward(out_q:UOp, out_k:UOp, out_v:UOp, grad:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
|
||||
return UOp.sink(arg=KernelInfo(name="fa_custom_backward"))
|
||||
def fa_backward(grad:UOp, kernel:UOp) -> tuple[None, UOp, UOp, UOp]:
|
||||
grad_q = Tensor.empty_like(q:=Tensor(kernel.src[1]))
|
||||
grad_k = Tensor.empty_like(k:=Tensor(kernel.src[2]))
|
||||
grad_v = Tensor.empty_like(v:=Tensor(kernel.src[3]))
|
||||
grad_q = Tensor.empty_like(q:=Tensor(kernel.src[2]))
|
||||
grad_k = Tensor.empty_like(k:=Tensor(kernel.src[3]))
|
||||
grad_v = Tensor.empty_like(v:=Tensor(kernel.src[4]))
|
||||
ck = Tensor.custom_kernel(grad_q, grad_k, grad_v, Tensor(grad), q, k, v, fxn=fa_custom_backward)[:3]
|
||||
return (None, ck[0].uop, ck[1].uop, ck[2].uop)
|
||||
attn = Tensor.empty_like(attn).custom_kernel(xq, keys, values, fxn=fa_custom_forward, grad_fxn=fa_backward)[0]
|
||||
|
|
|
|||
|
|
@ -1,11 +1,36 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
import argparse
|
||||
import argparse, math, hashlib
|
||||
|
||||
def _python_hash_1mb(data:bytes|bytearray):
|
||||
chunks = [data[i:i+4096] for i in range(0, len(data), 4096)]
|
||||
chunk_hashes = [hashlib.shake_128(chunk).digest(16) for chunk in chunks]
|
||||
return hashlib.shake_128(b''.join(chunk_hashes)).digest(16)
|
||||
|
||||
def hash_file(data: bytes|bytearray):
|
||||
if len(data) % Tensor.CHUNK_SIZE != 0: data += bytes(Tensor.CHUNK_SIZE - len(data) % Tensor.CHUNK_SIZE)
|
||||
base_chunks = math.ceil(len(data) / Tensor.CHUNK_SIZE)
|
||||
tree_depth = math.ceil(math.log(base_chunks, Tensor.CHUNK_SIZE // 16))
|
||||
|
||||
for _ in range(tree_depth + 1):
|
||||
data_chunks = [data[i:i+Tensor.CHUNK_SIZE] for i in range(0, len(data), Tensor.CHUNK_SIZE)]
|
||||
data_chunk_hashes = [_python_hash_1mb(chunk) for chunk in data_chunks]
|
||||
data = b''.join(data_chunk_hashes)
|
||||
if len(data) % Tensor.CHUNK_SIZE != 0: data += bytes(Tensor.CHUNK_SIZE - len(data) % Tensor.CHUNK_SIZE)
|
||||
|
||||
return data[:16]
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--hash", type=str, required=True, help="file hash to fetch")
|
||||
parser.add_argument("--len", type=int, required=True, help="file length to fetch")
|
||||
parser.add_argument("--dest", type=str, required=True, help="destination path to save the file")
|
||||
parser.add_argument("--check", action="store_true", help="verify the file hash after fetching")
|
||||
args = parser.parse_args()
|
||||
|
||||
Tensor(bytes.fromhex(args.hash), device="CPU").fs_load(args.len).to(f"disk:{args.dest}").realize()
|
||||
|
||||
if args.check:
|
||||
with open(args.dest, "rb") as f:
|
||||
data = f.read()
|
||||
assert hash_file(data) == bytes.fromhex(args.hash), "Hash mismatch after fetching file"
|
||||
print("File hash verified successfully!")
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ testing_minimal = [
|
|||
"pytest-timeout",
|
||||
"pytest-split",
|
||||
"hypothesis>=6.148.9",
|
||||
"z3-solver",
|
||||
"z3-solver<4.15.4", # 4.15.4 has a segfault when creating many z3.Context()
|
||||
]
|
||||
testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai", "ggml-python"]
|
||||
testing = [
|
||||
|
|
|
|||
4
test/external/fuzz_symbolic.py
vendored
4
test/external/fuzz_symbolic.py
vendored
|
|
@ -1,3 +1,7 @@
|
|||
# NOTE: z3-solver 4.15.4 segfaults (exit code 139) when creating many z3.Context() with complex expressions.
|
||||
# Reproduces consistently with seed=74 around iteration 1767. Versions <=4.15.3 are fine.
|
||||
# Workaround: reuse a single z3.Context, or pin z3-solver<4.15.4 (see pyproject.toml).
|
||||
# To repro: pip install z3-solver==4.15.4.0 && python test/external/fuzz_symbolic.py 74
|
||||
import random, operator, sys
|
||||
import z3
|
||||
from tinygrad import Variable, dtypes
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str,
|
|||
UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
|
||||
new_sink = big_sink.substitute(get_rangeify_map(big_sink))
|
||||
def to_str(ret:UOp) -> str:
|
||||
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.KERNEL]
|
||||
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.CALL]
|
||||
return "\n".join([f"{len(asts)} kernels", *asts])
|
||||
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
|
||||
|
||||
|
|
|
|||
|
|
@ -130,15 +130,18 @@ class NVDriver(VirtDriver):
|
|||
struct.hObjectNew = self._alloc_handle()
|
||||
self.object_by_handle[struct.hObjectNew] = NVContextShare(self.object_by_handle[struct.hObjectParent])
|
||||
elif struct.hClass == nv_gpu.AMPERE_CHANNEL_GPFIFO_A:
|
||||
assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVChannelGroup)
|
||||
parent = self.object_by_handle.get(struct.hObjectParent)
|
||||
assert parent is not None and isinstance(parent, (NVChannelGroup, NVGPU))
|
||||
struct.hObjectNew = self._alloc_handle()
|
||||
params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS.from_address(params_ptr)
|
||||
gpu = self.object_by_handle[struct.hObjectParent].device
|
||||
gpu = parent.device if isinstance(parent, NVChannelGroup) else parent
|
||||
gpfifo_token = gpu.add_gpfifo(params.gpFifoOffset, params.gpFifoEntries)
|
||||
self.object_by_handle[struct.hObjectNew] = NVGPFIFO(gpu, gpfifo_token)
|
||||
elif struct.hClass == nv_gpu.AMPERE_DMA_COPY_B or struct.hClass == nv_gpu.ADA_COMPUTE_A:
|
||||
elif struct.hClass in (nv_gpu.AMPERE_DMA_COPY_B, nv_gpu.ADA_COMPUTE_A, nv_gpu.NVC9B0_VIDEO_DECODER, nv_gpu.NVCFB0_VIDEO_DECODER):
|
||||
assert struct.hObjectParent in self.object_by_handle and isinstance(self.object_by_handle[struct.hObjectParent], NVGPFIFO)
|
||||
struct.hObjectNew = self._alloc_handle()
|
||||
gpfifo = self.object_by_handle[struct.hObjectParent]
|
||||
gpfifo.device.queues[gpfifo.token].bound_engines.add(struct.hClass)
|
||||
elif struct.hClass == nv_gpu.GT200_DEBUGGER:
|
||||
struct.hObjectNew = self._alloc_handle()
|
||||
elif struct.hClass == nv_gpu.MAXWELL_PROFILER_DEVICE:
|
||||
|
|
@ -194,7 +197,7 @@ class NVDriver(VirtDriver):
|
|||
params = nv_gpu.NVC36F_CTRL_CMD_GPFIFO_GET_WORK_SUBMIT_TOKEN_PARAMS.from_address(params_ptr)
|
||||
gpu_fifo = self.object_by_handle[struct.hObject]
|
||||
params.workSubmitToken = gpu_fifo.token
|
||||
elif struct.cmd == nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE: pass
|
||||
elif struct.cmd in (nv_gpu.NVA06C_CTRL_CMD_GPFIFO_SCHEDULE, nv_gpu.NVA06F_CTRL_CMD_BIND, nv_gpu.NVA06F_CTRL_CMD_GPFIFO_SCHEDULE): pass
|
||||
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_PERF_BOOST: pass
|
||||
elif struct.cmd == nv_gpu.NV2080_CTRL_CMD_FB_FLUSH_GPU_CACHE: pass
|
||||
elif struct.cmd == nv_gpu.NV83DE_CTRL_CMD_DEBUG_READ_ALL_SM_ERROR_STATES:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class GPFIFO:
|
|||
self.gpfifo = to_mv(self.base, self.entries_cnt * 8).cast("Q")
|
||||
self.ctrl = nv_gpu.AmpereAControlGPFifo.from_address(self.base + self.entries_cnt * 8)
|
||||
self.state = {}
|
||||
self.bound_engines: set[int] = set()
|
||||
|
||||
# Buf exec state
|
||||
self.buf = None
|
||||
|
|
@ -115,9 +116,11 @@ class GPFIFO:
|
|||
def execute_cmd(self, cmd) -> SchedResult:
|
||||
if cmd == nv_gpu.NVC56F_SEM_EXECUTE: return self._exec_signal()
|
||||
elif cmd == nv_gpu.NVC6C0_LAUNCH_DMA: return self._exec_nvc6c0_dma()
|
||||
elif cmd == nv_gpu.NVC6B5_LAUNCH_DMA: return self._exec_nvc6b5_dma()
|
||||
elif cmd == nv_gpu.NVC6B5_LAUNCH_DMA: # NOTE: NVC6B5_LAUNCH_DMA == NVC9B0_EXECUTE == 0x300
|
||||
return self._exec_vid_decode() if self.bound_engines & {nv_gpu.NVC9B0_VIDEO_DECODER, nv_gpu.NVCFB0_VIDEO_DECODER} else self._exec_nvc6b5_dma()
|
||||
elif cmd == nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B: return self._exec_pcas2()
|
||||
elif cmd == 0x0320: return self._exec_load_inline_qmd() # NVC6C0_LOAD_INLINE_QMD_DATA
|
||||
elif cmd == nv_gpu.NVC9B0_SEMAPHORE_D: return self._exec_vid_semaphore()
|
||||
else: self.state[cmd] = self._next_dword() # just state update
|
||||
return SchedResult.CONT
|
||||
|
||||
|
|
@ -136,6 +139,27 @@ class GPFIFO:
|
|||
else: raise RuntimeError(f"Unsupported type={typ} in exec wait/signal")
|
||||
return SchedResult.CONT
|
||||
|
||||
def _exec_vid_decode(self) -> SchedResult:
|
||||
self._next_dword() # consume execute flags
|
||||
# validate that all required decode state was set up correctly
|
||||
assert self._state(nv_gpu.NVC9B0_SET_APPLICATION_ID) == nv_gpu.NVC9B0_SET_APPLICATION_ID_ID_HEVC
|
||||
pic_desc_addr = self._state(nv_gpu.NVC9B0_SET_DRV_PIC_SETUP_OFFSET) << 8
|
||||
pic = nv_gpu.nvdec_hevc_pic_s.from_address(pic_desc_addr)
|
||||
assert pic.stream_len > 0 and pic.pic_width_in_luma_samples > 0 and pic.pic_height_in_luma_samples > 0
|
||||
assert self._state(nv_gpu.NVC9B0_SET_IN_BUF_BASE_OFFSET) != 0
|
||||
assert self._state(nv_gpu.NVC9B0_SET_COLOC_DATA_OFFSET) != 0
|
||||
assert self._state(nv_gpu.NVC9B0_SET_NVDEC_STATUS_OFFSET) != 0
|
||||
assert self._state(nv_gpu.NVC9B0_HEVC_SET_FILTER_BUFFER_OFFSET) != 0
|
||||
return SchedResult.CONT
|
||||
|
||||
def _exec_vid_semaphore(self) -> SchedResult:
|
||||
signal = self._state64(nv_gpu.NVC9B0_SEMAPHORE_A)
|
||||
val = self._state(nv_gpu.NVC9B0_SEMAPHORE_C)
|
||||
self._next_dword() # flags
|
||||
to_mv(signal, 8).cast('Q')[0] = val
|
||||
to_mv(signal + 8, 8).cast('Q')[0] = int(time.perf_counter() * 1e9)
|
||||
return SchedResult.CONT
|
||||
|
||||
def _exec_load_inline_qmd(self):
|
||||
qmd_addr = self._state64(nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A) << 8
|
||||
assert qmd_addr != 0x0, f"invalid qmd address {qmd_addr}"
|
||||
|
|
|
|||
|
|
@ -102,8 +102,8 @@ class TestCompiler(unittest.TestCase):
|
|||
|
||||
class TestRunAsModule(unittest.TestCase):
|
||||
def test_module_runs(self):
|
||||
out = '\n'.join(enumerate_devices_str())
|
||||
self.assertIn("CPU", out) # for sanity check
|
||||
cpu_line = [l for l in enumerate_devices_str() if "CPU" in l][0]
|
||||
self.assertIn("PASS", cpu_line, f"expected CPU to PASS, got: {cpu_line}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -88,13 +88,13 @@ def simple_qkv_kernel(O:UOp, Q:UOp, K:UOp, V:UOp) -> UOp:
|
|||
# **** backward callbacks ****
|
||||
|
||||
def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
|
||||
out, a, b = kernel.src
|
||||
out, a, b = kernel.src[1:]
|
||||
grad_a = (Tensor(gradient) @ Tensor(b).T).uop
|
||||
grad_b = (Tensor(a).T @ Tensor(gradient)).uop
|
||||
return (None, grad_a, grad_b)
|
||||
|
||||
def backward_gemm_custom(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
|
||||
out, a, b = kernel.src
|
||||
out, a, b = kernel.src[1:]
|
||||
grad_a = Tensor.empty_like(Tensor(a)).custom_kernel(Tensor(gradient), Tensor(b).T, fxn=custom_gemm)[0].uop
|
||||
grad_b = Tensor.empty_like(Tensor(b)).custom_kernel(Tensor(a).T, Tensor(gradient), fxn=custom_gemm)[0].uop
|
||||
return (None, grad_a, grad_b)
|
||||
|
|
@ -128,6 +128,14 @@ class TestCustomKernel(unittest.TestCase):
|
|||
out = c.flatten().tolist()
|
||||
assert all(x == 2 for x in out), "all 2"
|
||||
|
||||
def test_sharded_add_one(self):
|
||||
# PYTHON backend explicitly checks for OOB access for wrong multi shape regression
|
||||
devs = ("PYTHON:0", "PYTHON:1")
|
||||
a = Tensor.ones(4, 4).contiguous().shard(devs, axis=0)
|
||||
c = Tensor(Tensor.empty(2, 4, device=devs).uop.multi(0), device=devs)
|
||||
c = Tensor.custom_kernel(c, a, fxn=custom_add_one_kernel)[0]
|
||||
assert (c == 2).all().item()
|
||||
|
||||
def test_multioutput(self):
|
||||
a = Tensor.full((16, 16), 3.).contiguous()
|
||||
b = Tensor.full((16, 16), 3.).contiguous()
|
||||
|
|
|
|||
|
|
@ -380,9 +380,40 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool
|
|||
|
||||
class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
|
||||
|
||||
class TestEmulatedBFloat16Type(TestBFloat16Type):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="bfloat16"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3
|
||||
|
||||
class TestEmulatedFp8e4m3(TestFp8e4m3):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e4m3"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2
|
||||
|
||||
class TestEmulatedFp8e5m2(TestFp8e5m2):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e5m2"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
class TestPtrDType(unittest.TestCase):
|
||||
def test_vec_double(self):
|
||||
dt1 = dtypes.float.vec(4).ptr().vec(4)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import unittest, operator, math
|
||||
from tinygrad import Context, Tensor, dtypes, Device
|
||||
from tinygrad.dtype import DType, truncate
|
||||
from tinygrad.dtype import DType, truncate, fp8_to_float
|
||||
from tinygrad.helpers import CI, EMULATED_DTYPES, getenv
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
|
@ -59,9 +59,10 @@ def universal_test(a, b, dtype, op):
|
|||
# lt and max with nan is undefined in tinygrad
|
||||
if op[0] in (operator.lt, Tensor.maximum) and (math.isnan(a) or math.isnan(b)): return
|
||||
ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
|
||||
tensor_value = (op[0](ta, tb)).numpy()
|
||||
numpy_value = op[1](ta.numpy(), tb.numpy())
|
||||
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item())
|
||||
if dtype in dtypes.fp8s and op[0] not in (operator.lt, operator.eq):
|
||||
tensor_value = fp8_to_float((op[0](ta.realize(), tb.realize())).bitcast(dtypes.uint8).item(), dtype)
|
||||
numpy_value = truncate[dtype](op[1](ta.numpy(), tb.numpy()).item())
|
||||
else: tensor_value, numpy_value = (op[0](ta, tb)).numpy(), op[1](ta.numpy(), tb.numpy())
|
||||
if dtype in dtypes.floats:
|
||||
if not is_dtype_supported(dtype) or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero
|
||||
fe, fm = dtypes.finfo(dtype)
|
||||
|
|
@ -76,13 +77,14 @@ def universal_test_unary(a, dtype, op):
|
|||
# TODO: cos does not match for large input
|
||||
if op[0] == Tensor.cos and abs(a) > 30: return
|
||||
if op[0] == Tensor.log and a <= 0: return
|
||||
out: Tensor = op[0](ta)
|
||||
tensor_value = out.numpy()
|
||||
numpy_value = op[1](ta.numpy())
|
||||
if dtype in dtypes.fp8s:
|
||||
# normals are zero
|
||||
if dtype in EMULATED_DTYPES.tolist(dtypes) and abs(ta.numpy().item()) < 0.015625: return
|
||||
tensor_value = fp8_to_float(op[0](ta.realize()).bitcast(dtypes.uint8).item(), dtype)
|
||||
numpy_value = truncate[dtype](v:=op[1](ta.numpy()).item())
|
||||
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
|
||||
if math.isinf(numpy_value.item()): return
|
||||
numpy_value = truncate[dtype](numpy_value.item())
|
||||
if math.isinf(v): return
|
||||
else: tensor_value, numpy_value = op[0](ta).numpy(), op[1](ta.numpy())
|
||||
if dtype in dtypes.floats:
|
||||
atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
|
||||
dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
|
||||
|
|
@ -128,16 +130,31 @@ class TestDTypeALU(unittest.TestCase):
|
|||
def test_bfloat16(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
|
||||
@Context(EMULATED_DTYPES="bfloat16")
|
||||
def test_emulated_bfloat16(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
|
||||
def test_fp8e4m3(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
||||
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
|
||||
@Context(EMULATED_DTYPES="fp8e4m3")
|
||||
def test_emulated_fp8e4m3(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
|
||||
def test_fp8e5m2(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
|
||||
|
||||
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
|
||||
@Context(EMULATED_DTYPES="fp8e5m2")
|
||||
def test_emulated_fp8e5m2(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op)
|
||||
|
||||
@given(ht.float32, strat.sampled_from(unary_operations))
|
||||
def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op)
|
||||
|
||||
|
|
@ -153,18 +170,34 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.bfloat16, strat.sampled_from(unary_operations))
|
||||
def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
@given(ht.bfloat16, strat.sampled_from(unary_operations))
|
||||
@Context(EMULATED_DTYPES="bfloat16")
|
||||
def test_emulated_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
|
||||
def test_fp8e4m3_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
||||
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
|
||||
@Context(EMULATED_DTYPES="fp8e4m3")
|
||||
def test_emulated_fp8e4m3_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
|
||||
def test_fp8e5m2_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op)
|
||||
|
||||
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
|
||||
@Context(EMULATED_DTYPES="fp8e5m2")
|
||||
def test_emulated_fp8e5m2_unary(self, a, op):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op)
|
||||
|
||||
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op)
|
||||
|
||||
|
|
|
|||
|
|
@ -212,6 +212,7 @@ class TestProfiler(unittest.TestCase):
|
|||
for ge in graphs:
|
||||
self.assertEqual(len(ge.ents), len(graphs))
|
||||
|
||||
@unittest.skip("this test is flaky")
|
||||
def test_trace_metadata(self):
|
||||
with Context(TRACEMETA=1):
|
||||
a = Tensor.empty(1)+2
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from hypothesis import assume, given, settings, strategies as strat
|
|||
from tinygrad import nn, dtypes, Device, Tensor, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
|
||||
|
|
@ -677,17 +677,6 @@ class TestSchedule(unittest.TestCase):
|
|||
c = (a.sum(2).contiguous() + b).contiguous()
|
||||
check_schedule(c, 2)
|
||||
|
||||
# TODO: this requires supporting multiple stores in the AST
|
||||
@unittest.expectedFailure
|
||||
def test_multioutput_ast(self):
|
||||
a = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
|
||||
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().uop
|
||||
c = Tensor.arange(4).realize().uop
|
||||
kernel = UOp(Ops.KERNEL, src=(a.base, b.base, c.base), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2)))
|
||||
run_schedule(check_schedule(UOp.sink(a.assign(kernel), b.assign(kernel)), 1))
|
||||
self.assertEqual(a.buffer.numpy(), [7])
|
||||
self.assertEqual(b.buffer.numpy(), [12])
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_double_from(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, CI
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.gemm.asm.cdna.gemm import asm_gemm
|
||||
from test.helpers import needs_second_gpu
|
||||
|
||||
# On non CDNA4 it will only validate the Tensor.custom_kernel integration
|
||||
# Use NULL=1 EMULATE=AMD_CDNA4 to also test the assembly
|
||||
def is_cdna4(): return getattr(Device[Device.DEFAULT].renderer, "arch", "").startswith("gfx950")
|
||||
|
||||
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None:
|
||||
Tensor.manual_seed(0)
|
||||
a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype)
|
||||
|
|
@ -16,36 +20,47 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:i
|
|||
|
||||
a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
|
||||
if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None)
|
||||
tst = asm_gemm(a, b)
|
||||
tst.sum().backward()
|
||||
with Context(ASM_GEMM=1):
|
||||
tst = asm_gemm(a, b)
|
||||
tst.sum().backward()
|
||||
Tensor.realize(tst, a.grad, b.grad)
|
||||
|
||||
a_ref, b_ref = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
|
||||
if multi: a_ref, b_ref = a_ref.shard(devs, axis=0), b_ref.shard(devs, axis=None)
|
||||
with Context(ASM_GEMM=0): ref = a_ref @ b_ref
|
||||
ref.sum().backward()
|
||||
with Context(ASM_GEMM=0):
|
||||
ref = asm_gemm(a_ref, b_ref)
|
||||
ref.sum().backward()
|
||||
Tensor.realize(ref, a_ref.grad, b_ref.grad)
|
||||
|
||||
# no validation on the NULL device
|
||||
if a_rand.device.startswith("NULL"): return None
|
||||
with Context(DEBUG=0):
|
||||
assert (tst - ref).square().max().float().item() < 1e-6, "forward mismatch"
|
||||
assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
|
||||
assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
|
||||
|
||||
SCALE = 128 if CI else 1
|
||||
|
||||
# 128x smaller than usual
|
||||
# uses the UOp GEMM, runs on non CDNA4 and CI
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
class TestGemm(unittest.TestCase):
|
||||
def test_simple(self): verify_asm_gemm(1, N:=(getenv("N", 4096)//SCALE), N, N, dtype=dtypes.half)
|
||||
def test_gemm(self): verify_asm_gemm(1, 8192//SCALE, 4096//SCALE, 14336//SCALE)
|
||||
def test_gemm_batched(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE)
|
||||
def setUp(self):
|
||||
if is_cdna4(): self.skipTest("shapes are too small for the assembly GEMM")
|
||||
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 32), N, N, dtype=dtypes.half)
|
||||
def test_gemm(self): verify_asm_gemm(1, 64, 32, 112)
|
||||
def test_gemm_batched(self): verify_asm_gemm(2, 64, 32, 32)
|
||||
@needs_second_gpu
|
||||
def test_gemm_multi(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE, gpus=2)
|
||||
def test_gemm_multi(self): verify_asm_gemm(2, 64, 32, 32, gpus=2)
|
||||
|
||||
# uses the Asm GEMM on CDNA4 only for speed reasons
|
||||
class TestGemmLarge(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if getattr(Device[Device.DEFAULT].renderer, "arch", "") != "gfx950":
|
||||
if not is_cdna4():
|
||||
self.skipTest("very slow on non mi350x")
|
||||
|
||||
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
|
||||
def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336)
|
||||
def test_gemm_batched(self): verify_asm_gemm(2, 8192, 4096, 4096)
|
||||
|
||||
def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
|
||||
@unittest.skip("disabled, asm in this shape is slower than tinygrad")
|
||||
def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, dtype=dtypes.bfloat16, gpus=8)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import unittest
|
||||
|
||||
from tinygrad import Device
|
||||
from tinygrad.helpers import fetch
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.helpers import fetch, round_up
|
||||
from extra.hevc.hevc import parse_hevc_file_headers, nv_gpu
|
||||
from extra.hevc.decode import hevc_decode
|
||||
|
||||
class TestHevc(unittest.TestCase):
|
||||
def test_hevc_parser(self):
|
||||
|
|
@ -61,5 +62,26 @@ class TestHevc(unittest.TestCase):
|
|||
self.assertEqual(list(frame3.initreflistidxl0), [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
self.assertEqual(list(frame3.initreflistidxl1), [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
self.assertEqual(list(frame3.RefDiffPicOrderCnts), [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "NV", "NV only")
|
||||
def test_hevc_decode(self):
|
||||
url = "https://github.com/haraschax/filedump/raw/09a497959f7fa6fd8dba501a25f2cdb3a41ecb12/comma_video.hevc"
|
||||
dat = fetch(url, headers={"Range": f"bytes=0-{512<<10}"}).read_bytes()
|
||||
|
||||
opaque, frame_info, w, h, luma_w, luma_h, chroma_off = parse_hevc_file_headers(dat)
|
||||
frame_info = frame_info[:4]
|
||||
out_image_size = luma_h + (luma_h + 1) // 2, round_up(luma_w, 64)
|
||||
|
||||
hevc_tensor = Tensor(dat, device="NV")
|
||||
opaque_nv = opaque.to("NV").contiguous().realize()
|
||||
|
||||
frames = list(hevc_decode(hevc_tensor, opaque_nv, frame_info, luma_h, luma_w))
|
||||
Device.default.synchronize()
|
||||
self.assertEqual(len(frames), 4)
|
||||
for f in frames:
|
||||
self.assertEqual(f.shape, out_image_size)
|
||||
self.assertEqual(f.dtype, dtypes.uint8)
|
||||
self.assertEqual(f.device, "NV")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,18 +1,19 @@
|
|||
from typing import cast
|
||||
from dataclasses import replace
|
||||
import itertools
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, TracingKey, Context
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, promo_lattice
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import panic
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_unsupported_dtypes_patterns, get_transcendental_patterns
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_float_decomp, pm_long_decomp
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads
|
||||
|
|
@ -24,7 +25,7 @@ from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_c
|
|||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
if ren is None: ren = Renderer()
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
|
||||
if DEBUG >= 5: print(pyrender(sink))
|
||||
if SPEC: type_verify(sink, kernel_spec)
|
||||
|
||||
|
|
@ -88,10 +89,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
|||
# decompositions
|
||||
supported_ops = tuple(ren.code_for_op.keys())
|
||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV))
|
||||
pm_unsupported = get_unsupported_dtypes_patterns(ren.device, tuple(EMULATED_DTYPES.tolist(dtypes)))
|
||||
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
|
||||
sink = graph_rewrite(sink, pm_unsupported, ctx=ren.device, name="unsupported dtypes", bottom_up=True)
|
||||
if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
|
||||
sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True)
|
||||
for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float))
|
||||
for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]:
|
||||
sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True)
|
||||
sink = graph_rewrite(sink, pm_transcendental, ctx=ren.device, name="transcendental")
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
|
|
@ -108,7 +112,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
|||
# inject IF/ENDIF. only needed if device doesn't support gated stores
|
||||
pm_linearize_cleanups = PatternMatcher([
|
||||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError("if not allowed in graph"))),
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
|
||||
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
|
|
|
|||
|
|
@ -406,9 +406,9 @@ def enumerate_devices_str() -> Generator[str, None, None]:
|
|||
if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
|
||||
set_text = f'({cc_ctrl_var.key}={d._compiler_name(r, c)} to make default)' if cc_ctrl_var is not None else ''
|
||||
default_text = '(default)' if type(default_compiler) is type(d.compiler) else set_text
|
||||
compilers_results.append(f"{colored('+', 'green')} {unwrap_class_type(c).__name__} {default_text}")
|
||||
compilers_results.append(f"{colored('+', 'green')} {d._compiler_name(r, c)} {default_text}")
|
||||
any_works = True
|
||||
except Exception as e: compilers_results.append(f"{colored('-', 'yellow')} {unwrap_class_type(c).__name__}: {e}")
|
||||
except Exception as e: compilers_results.append(f"{colored('-', 'yellow')} {d._compiler_name(r, c)}: {e}")
|
||||
finally:
|
||||
# put the defaults back!
|
||||
d.comp_sets, d.comps_ctrl_var = default_comp_pairs, cc_ctrl_var
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import time
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||
|
|
@ -22,14 +22,14 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
# build kernel dependency graph: edges from producer kernel to consumer kernels
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
in_degree: dict[UOp, int] = {}
|
||||
for u in sched_sink.toposort():
|
||||
for u in sched_sink.toposort(gate_kernel_sink):
|
||||
if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
|
||||
if u.op is not Ops.AFTER: continue
|
||||
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
|
||||
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
|
||||
assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
|
||||
in_degree.setdefault(k, 0)
|
||||
if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
||||
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
||||
for s in k.src[0].src[1:] if k.op is Ops.END else k.src[1:]:
|
||||
match (s := _unwrap_src(s)).op:
|
||||
case Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
|
|
@ -54,13 +54,13 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
while len(queue):
|
||||
k = rk = queue.popleft()
|
||||
if k.op is Ops.END: k = k.src[0]
|
||||
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
|
||||
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
|
||||
if k.op is Ops.RANGE: schedule.append(k)
|
||||
elif k.op is Ops.KERNEL:
|
||||
ast = (kernel:=cast(Kernel, k.arg)).ast
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
sched_item[k] = (ast, buf_uops, kernel.metadata, bound_ranges)
|
||||
elif k.op is Ops.CALL:
|
||||
ast = k.src[0]
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
|
||||
schedule.append(k)
|
||||
if rk.op is Ops.END: schedule.append(rk)
|
||||
for x in children.get(rk, []):
|
||||
|
|
@ -86,7 +86,7 @@ def unroll_outer_ranges(schedule:list[UOp], sched_item:dict[UOp, ScheduleItem])
|
|||
sched_ptr = range_ptrs[si.src[1]]
|
||||
continue
|
||||
else:
|
||||
assert si.op is Ops.KERNEL, f"unexpected op in schedule: {si.op}"
|
||||
assert si.op is Ops.CALL, f"unexpected op in schedule: {si.op}"
|
||||
ast, buf_uops, metadata, bound_ranges = sched_item[si]
|
||||
fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
|
||||
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
|
||||
# NOTE: this is only correct when the KERNEL has a single output
|
||||
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||
# gradient on CALL: use provided grad_fxn or auto-differentiate
|
||||
(UPat(Ops.CALL, name="k"), call_gradient),
|
||||
# there's no gradient for bitcast
|
||||
|
|
|
|||
|
|
@ -86,7 +86,9 @@ def word_wrap(x, wrap=80):
|
|||
while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
|
||||
return x[:i] + "\n" + word_wrap(x[i:], wrap)
|
||||
def pad_bytes(b:bytes, align:int) -> bytes: return b + b'\x00' * ((align - (len(b) % align)) % align)
|
||||
def panic(e:Exception|None=None): raise e if e is not None else RuntimeError("PANIC!")
|
||||
|
||||
# NOTE: you must create the exception inside the function where it's raised or you will get a GC cycle!
|
||||
def panic(e:type[Exception]|None=None, *arg): raise e(*arg) if e is not None else RuntimeError("PANIC!")
|
||||
|
||||
@functools.cache
|
||||
def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]:
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ def __getattr__(nm):
|
|||
], args=[
|
||||
"-include", "{}/src/common/sdk/nvidia/inc/nvtypes.h", "-I{}/src/common/inc", "-I{}/kernel-open/nvidia-uvm", "-I{}/kernel-open/common/inc",
|
||||
"-I{}/src/common/sdk/nvidia/inc", "-I{}/src/nvidia/arch/nvalloc/unix/include", "-I{}/src/common/sdk/nvidia/inc/ctrl"
|
||||
], rules=[(r'MW\(([^:]+):(.+)\)',r'(\1, \2)')], tarball=nv_src[nm], anon_names={"{}/kernel-open/common/inc/nvstatus.h:37":"nv_status_codes"})
|
||||
], rules=[(r'MW\(([^:]+):(.+)\)',r'(\1, \2)'), (r'(\d+):(\d+)', r'(\1, \2)')], tarball=nv_src[nm], anon_names={"{}/kernel-open/common/inc/nvstatus.h:37":"nv_status_codes"})
|
||||
case "nv": return load("nv", None, [
|
||||
*[f"{{}}/src/nvidia/inc/kernel/gpu/{s}.h" for s in ["fsp/kern_fsp_cot_payload", "gsp/gsp_init_args"]],
|
||||
*[f"{{}}/src/nvidia/arch/nvalloc/common/inc/{s}.h" for s in ["gsp/gspifpub", "gsp/gsp_fw_wpr_meta", "gsp/gsp_fw_sr_meta", "rmRiscvUcode",
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -703,13 +703,15 @@ class KFDIface:
|
|||
# Event to wait for queues completion
|
||||
self.dev.queue_event = kfd.AMDKFD_IOC_CREATE_EVENT(KFDIface.kfd, event_type=kfd.KFD_IOC_EVENT_SIGNAL, auto_reset=1)
|
||||
self.dev.queue_event_mailbox_ptr = KFDIface.event_page.va_addr + self.dev.queue_event.event_slot_index * 8
|
||||
self.queue_event_arr = (kfd.struct_kfd_event_data)(event_id=self.dev.queue_event.event_id)
|
||||
self.queue_event_arr_ptr = ctypes.addressof(self.queue_event_arr)
|
||||
|
||||
# OS events to collect memory and hardware faults
|
||||
self.mem_fault_event = kfd.AMDKFD_IOC_CREATE_EVENT(KFDIface.kfd, event_type=kfd.KFD_IOC_EVENT_MEMORY)
|
||||
self.hw_fault_event = kfd.AMDKFD_IOC_CREATE_EVENT(KFDIface.kfd, event_type=kfd.KFD_IOC_EVENT_HW_EXCEPTION)
|
||||
|
||||
self.queue_event_arr = (kfd.struct_kfd_event_data * 3)(kfd.struct_kfd_event_data(event_id=self.dev.queue_event.event_id),
|
||||
kfd.struct_kfd_event_data(event_id=self.mem_fault_event.event_id), kfd.struct_kfd_event_data(event_id=self.hw_fault_event.event_id))
|
||||
self.queue_event_arr_ptr = ctypes.addressof(self.queue_event_arr)
|
||||
|
||||
def alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, cpu_addr=None) -> HCQBuffer:
|
||||
flags = kfd.KFD_IOC_ALLOC_MEM_FLAGS_WRITABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_EXECUTABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_NO_SUBSTITUTE
|
||||
|
||||
|
|
@ -778,18 +780,20 @@ class KFDIface:
|
|||
doorbell=MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q'))
|
||||
|
||||
def sleep(self, tm:int):
|
||||
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
|
||||
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=3, wait_for_all=0, timeout=tm)
|
||||
if self.queue_event_arr[1].memory_exception_data.gpu_id or self.queue_event_arr[2].hw_exception_data.gpu_id: raise RuntimeError("Device fault")
|
||||
|
||||
def on_device_hang(self):
|
||||
def _collect_str(st): return ' '.join(f'{k[0]}={getattr(st, k[0])}' for k in st._real_fields_)
|
||||
def _str(st): return ' '.join(f'{k[0]}={getattr(st, k[0])}' for k in st._real_fields_)
|
||||
|
||||
# try to collect fault info if not already set from sleep().
|
||||
if not self.queue_event_arr[1].memory_exception_data.gpu_id and not self.queue_event_arr[2].hw_exception_data.gpu_id:
|
||||
with contextlib.suppress(RuntimeError): self.sleep(tm=1)
|
||||
|
||||
report = []
|
||||
for evnt in [self.mem_fault_event, self.hw_fault_event]:
|
||||
ev = (kfd.struct_kfd_event_data)(event_id=evnt.event_id)
|
||||
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=ctypes.addressof(ev), num_events=1, wait_for_all=1)
|
||||
if evnt == self.mem_fault_event and ev.memory_exception_data.gpu_id:
|
||||
report += [f"MMU fault: 0x{ev.memory_exception_data.va:X} | {_collect_str(ev.memory_exception_data.failure)}"]
|
||||
if evnt == self.hw_fault_event and ev.hw_exception_data.gpu_id: report += [f"HW fault: {_collect_str(ev.hw_exception_data)}"]
|
||||
if self.queue_event_arr[1].memory_exception_data.gpu_id:
|
||||
report += [f"MMU fault: 0x{self.queue_event_arr[1].memory_exception_data.va:X} | {_str(self.queue_event_arr[1].memory_exception_data.failure)}"]
|
||||
if self.queue_event_arr[2].hw_exception_data.gpu_id: report += [f"HW fault: {_str(self.queue_event_arr[2].hw_exception_data)}"]
|
||||
|
||||
raise RuntimeError("\n".join(report))
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,9 @@ def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status
|
|||
NV_PFAULT_FAULT_TYPE = {dt:name for name,dt in nv_gpu.__dict__.items() if name.startswith("NV_PFAULT_FAULT_TYPE_")}
|
||||
NV_PFAULT_ACCESS_TYPE = {dt:name.split("_")[-1] for name,dt in nv_gpu.__dict__.items() if name.startswith("NV_PFAULT_ACCESS_TYPE_")}
|
||||
|
||||
def nv_flags(reg, **kwargs): return functools.reduce(int.__or__, ((getattr(nv_gpu, f"{reg}_{k}_{v}".upper()) if isinstance(v, str) else v) <<
|
||||
getattr(nv_gpu, f"{reg}_{k}".upper())[1] for k, v in kwargs.items()), 0)
|
||||
|
||||
def nv_iowr(fd:FileIOInterface, nr, args, cmd=None):
|
||||
ret = fd.ioctl(cmd or ((3 << 30) | (ctypes.sizeof(args) & 0x1FFF) << 16 | (ord('F') & 0xFF) << 8 | (nr & 0xFF)), args)
|
||||
if ret != 0: raise RuntimeError(f"ioctl returned {ret}")
|
||||
|
|
@ -94,7 +97,8 @@ class NVCommandQueue(HWQueue[HCQSignal, 'NVDevice', 'NVProgram', 'NVArgsState'])
|
|||
return self
|
||||
|
||||
def wait(self, signal:HCQSignal, value:sint=0):
|
||||
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value), (3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT
|
||||
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value),
|
||||
nv_flags("NVC56F_SEM_EXECUTE", operation="acq_circ_geq", payload_size="64bit"))
|
||||
self.active_qmd = None
|
||||
return self
|
||||
|
||||
|
|
@ -125,7 +129,8 @@ class NVCommandQueue(HWQueue[HCQSignal, 'NVDevice', 'NVProgram', 'NVArgsState'])
|
|||
|
||||
class NVComputeQueue(NVCommandQueue):
|
||||
def memory_barrier(self):
|
||||
self.nvm(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, (1 << 12) | (1 << 4) | (1 << 0))
|
||||
self.nvm(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI,
|
||||
nv_flags("NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI", instruction="true", global_data="true", constant="true"))
|
||||
self.active_qmd:QMD|None = None
|
||||
return self
|
||||
|
||||
|
|
@ -169,7 +174,7 @@ class NVComputeQueue(NVCommandQueue):
|
|||
return self
|
||||
|
||||
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value),
|
||||
(1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
|
||||
nv_flags("NVC56F_SEM_EXECUTE", operation="release", release_wfi="en", payload_size="64bit", release_timestamp="en"))
|
||||
self.nvm(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 0x0)
|
||||
self.active_qmd = None
|
||||
return self
|
||||
|
|
@ -185,12 +190,13 @@ class NVCopyQueue(NVCommandQueue):
|
|||
for off in range(0, copy_size, step:=(1 << 31)):
|
||||
self.nvm(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, *data64(src+off), *data64(dest+off))
|
||||
self.nvm(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, min(copy_size-off, step))
|
||||
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
|
||||
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA,
|
||||
nv_flags("NVC6B5_LAUNCH_DMA", data_transfer_type="non_pipelined", src_memory_layout="pitch", dst_memory_layout="pitch"))
|
||||
return self
|
||||
|
||||
def signal(self, signal:HCQSignal, value:sint=0):
|
||||
self.nvm(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, *data64(signal.value_addr), value)
|
||||
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, 0x14)
|
||||
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, nv_flags("NVC6B5_LAUNCH_DMA", flush_enable="true", semaphore_type="release_four_word_semaphore"))
|
||||
return self
|
||||
|
||||
def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.dma_gpfifo)
|
||||
|
|
@ -199,7 +205,8 @@ class NVVideoQueue(NVCommandQueue):
|
|||
def decode_hevc_chunk(self, pic_desc:HCQBuffer, in_buf:HCQBuffer, out_buf:HCQBuffer, out_buf_pos:int, hist_bufs:list[HCQBuffer], hist_pos:list[int],
|
||||
chroma_off:int, coloc_buf:HCQBuffer, filter_buf:HCQBuffer, intra_top_off:int, intra_unk_off:int|None, status_buf:HCQBuffer):
|
||||
self.nvm(4, nv_gpu.NVC9B0_SET_APPLICATION_ID, nv_gpu.NVC9B0_SET_APPLICATION_ID_ID_HEVC)
|
||||
self.nvm(4, nv_gpu.NVC9B0_SET_CONTROL_PARAMS, 0x52057)
|
||||
self.nvm(4, nv_gpu.NVC9B0_SET_CONTROL_PARAMS, nv_flags("NVC9B0_SET_CONTROL_PARAMS", codec_type="hevc", testrun_env="prod_run", gptimer_on=1,
|
||||
err_conceal_on=1, mbtimer_on=1, event_trace_logging_on=1))
|
||||
self.nvm(4, nv_gpu.NVC9B0_SET_DRV_PIC_SETUP_OFFSET, pic_desc.va_addr >> 8)
|
||||
self.nvm(4, nv_gpu.NVC9B0_SET_IN_BUF_BASE_OFFSET, in_buf.va_addr >> 8)
|
||||
for pos, buf in zip(hist_pos + [out_buf_pos], hist_bufs + [out_buf]):
|
||||
|
|
@ -216,7 +223,7 @@ class NVVideoQueue(NVCommandQueue):
|
|||
|
||||
def signal(self, signal:HCQSignal, value:sint=0):
|
||||
self.nvm(4, nv_gpu.NVC9B0_SEMAPHORE_A, *data64(signal.value_addr), value)
|
||||
self.nvm(4, nv_gpu.NVC9B0_SEMAPHORE_D, (1 << 24) | (1 << 0))
|
||||
self.nvm(4, nv_gpu.NVC9B0_SEMAPHORE_D, nv_flags("NVC9B0_SEMAPHORE_D", structure_size="four", payload_size="64bit"))
|
||||
return self
|
||||
|
||||
def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.vid_gpfifo)
|
||||
|
|
|
|||
|
|
@ -264,7 +264,8 @@ class AM_GFX(AM_IP):
|
|||
self.adev.regGRBM_CNTL.update(read_timeout=0xff, inst=xcc)
|
||||
for i in range(0, 16):
|
||||
self._grbm_select(vmid=i, inst=xcc)
|
||||
self.adev.regSH_MEM_CONFIG.write(**({'initial_inst_prefetch':3} if self.adev.ip_ver[am.GC_HWIP][0]>=10 else {'retry_disable':1, 'f8_mode':1}),
|
||||
self.adev.regSH_MEM_CONFIG.write(**({'initial_inst_prefetch':3} if self.adev.ip_ver[am.GC_HWIP][0]>=10 else {'retry_disable':1}),
|
||||
**({'f8_mode':1} if self.adev.ip_ver[am.GC_HWIP][:2]==(9,4) else {}),
|
||||
address_mode=self.adev.soc.module.SH_MEM_ADDRESS_MODE_64, alignment_mode=self.adev.soc.module.SH_MEM_ALIGNMENT_MODE_UNALIGNED, inst=xcc)
|
||||
|
||||
# Configure apertures:
|
||||
|
|
|
|||
|
|
@ -257,11 +257,10 @@ class HCQSignal(Generic[HCQDeviceType]):
|
|||
value: The value to wait for.
|
||||
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
|
||||
"""
|
||||
start_time = last_sleep_time = int(time.perf_counter() * 1000)
|
||||
start_time = int(time.perf_counter() * 1000)
|
||||
while (not_passed:=(prev_value:=self.value) < value) and (cur_time:=int(time.perf_counter() * 1000)) - start_time < timeout:
|
||||
self._sleep(cur_time - last_sleep_time)
|
||||
last_sleep_time = int(time.perf_counter() * 1000)
|
||||
if self.value != prev_value: start_time = last_sleep_time # progress was made, reset timer
|
||||
self._sleep(cur_time - start_time)
|
||||
if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
|
||||
if not_passed and self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ import functools, itertools
|
|||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink, pm_gate_kernel_sink
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL, Ops.ENCDEC}
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
for s in rb.src:
|
||||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
|
||||
# always realize SINK src
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||
# always realize
|
||||
|
|
@ -91,20 +91,25 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
|
|||
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
||||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||
|
||||
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
||||
if assign.src[1].op is Ops.KERNEL: return None
|
||||
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
|
||||
ret = assign.replace(src=assign.src+(to_mop,))
|
||||
ctx.range_map[ret] = ctx.range_map[assign]
|
||||
return ret
|
||||
def handle_assign_mops(ctx:IndexingContext, assign:UOp, target:UOp, src:UOp):
|
||||
if target.op in GroupOp.Movement and src.op is not Ops.CALL:
|
||||
mops = []
|
||||
while target.op in GroupOp.Movement:
|
||||
mops.append((target.op, target.marg))
|
||||
target = target.src[0]
|
||||
if mops and assign in ctx.range_map:
|
||||
ret = assign.replace(arg=tuple(mops))
|
||||
ctx.range_map[ret] = ctx.range_map[assign]
|
||||
return ret
|
||||
return None
|
||||
|
||||
pm_apply_rangeify = PatternMatcher([
|
||||
# REDUCE_AXIS -> REDUCE
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
|
||||
# PAD -> WHERE
|
||||
(UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
|
||||
# add third op to assign
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat()), name="assign"), add_third_op_to_assign_to_track_shape),
|
||||
# store movement ops in ASSIGN arg
|
||||
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(name="src")), name="assign"), handle_assign_mops),
|
||||
# finally, apply_rangeify
|
||||
(UPat(GroupOp.All, name="x"), create_bufferize_and_index_based_on_ranges),
|
||||
# remove movement op
|
||||
|
|
@ -159,7 +164,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
|
||||
# get the consumer map
|
||||
with cpu_profile("consumer map in rangeify", "TINY"):
|
||||
consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort())
|
||||
consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort(gate_kernel_sink))
|
||||
|
||||
# explicit rangeify
|
||||
ending_ranges: dict[UOp, list[UOp]] = {}
|
||||
|
|
@ -167,7 +172,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
|
||||
|
||||
# no ranges on kernels, they are internal
|
||||
if x.op is Ops.KERNEL: continue
|
||||
if x.op is Ops.CALL: continue
|
||||
|
||||
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
|
||||
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import cast
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import Device
|
||||
|
||||
# *** allreduce implementation ***
|
||||
|
|
@ -214,17 +215,18 @@ multi_pm = PatternMatcher([
|
|||
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
|
||||
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
|
||||
# we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
|
||||
(UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
|
||||
UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
|
||||
(UPat(Ops.CUSTOM_KERNEL, src=UPat((Ops.MULTI, Ops.CONTIGUOUS)), name="ck"),
|
||||
lambda ck: ck.replace(src=tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src))),
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
|
||||
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
|
||||
# after CALL
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
|
||||
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
|
||||
])+replace_allreduce
|
||||
|
||||
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
if getenv("VIZ"): graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
|
||||
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
|
||||
ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")
|
||||
if getenv("VIZ"): graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
|
||||
if VIZ: graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
|
||||
return ret
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
|
@ -33,13 +33,17 @@ pm_mops = PatternMatcher([
|
|||
# *****************
|
||||
# 0. do some cleanup rewrites, mostly copied from the old stuff
|
||||
|
||||
def fix_assign_hazard(dest:UOp, src:UOp, assign:UOp):
|
||||
def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
|
||||
if (t := target.base).op is Ops.BUFFER or (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src)): return None
|
||||
return src.f(Ops.CONTIGUOUS, tag=assign.tag)
|
||||
|
||||
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
|
||||
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
|
||||
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if dest.op_in_backward_slice_with_self(Ops.SHRINK) else set())
|
||||
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
|
||||
if not (hazards:=[s for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS) if s.op in unsafe]): return
|
||||
for h in hazards:
|
||||
if any(s is dest.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})):
|
||||
return assign.replace(src=(dest, src.contiguous()))
|
||||
if any(s is target.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})):
|
||||
return assign.replace(src=(target, src.contiguous()))
|
||||
|
||||
def split_reduceop(reduce:UOp, x:UOp):
|
||||
if prod(reduce.shape) == 0: return None
|
||||
|
|
@ -70,10 +74,6 @@ mop_cleanup = PatternMatcher([
|
|||
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
|
||||
])
|
||||
|
||||
def resolve_custom_kernel(ck:UOp) -> UOp:
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
|
|
@ -95,9 +95,6 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
# resolve calls
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
# resolve custom kernels
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
|
||||
|
||||
# remove CONTIGUOUS if the BUFFER is already contiguous
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
|
||||
|
|
@ -137,15 +134,13 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
|
||||
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
|
||||
lambda target, src, assign: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
|
||||
lambda assign, target, src: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
|
||||
|
||||
# assign only to buffer, otherwise make it a CONTIGUOUS
|
||||
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
|
||||
lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
|
||||
not (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src))) else None),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="src")), name="assign"), assign_to_contiguous),
|
||||
|
||||
# make source contiguous if it has hazardous movement ops on the dest buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("dest"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||
])
|
||||
|
||||
# *****************
|
||||
|
|
@ -334,19 +329,13 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
|||
assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {size}"
|
||||
|
||||
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
|
||||
if x.src[0].op is Ops.ASSIGN:
|
||||
assign_target, assign_src, assign_mops = x.src[0].src
|
||||
if (assign := x.src[0]).op is Ops.ASSIGN:
|
||||
assign_target, assign_src = assign.src[0], assign.src[1]
|
||||
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
|
||||
# in assign, this is the buffer size, not the bufferize size
|
||||
# TODO: assign_mops here
|
||||
do_store = assign_target.replace(dtype=sdtype).store(assign_src, tag=x.tag).end(*rngs)
|
||||
ret = assign_target.src[0].after(do_store)
|
||||
mops = []
|
||||
walk = assign_mops
|
||||
while walk is not assign_mops.base:
|
||||
mops.append((walk.op, walk.marg))
|
||||
walk = walk.src[0]
|
||||
for m in mops[::-1]: ret = ret._mop(*m)
|
||||
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
|
||||
return ret
|
||||
|
||||
# lower outerworld reduce here
|
||||
|
|
@ -393,7 +382,7 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
|||
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
|
||||
|
||||
# remove any RESHAPEs on KERNEL
|
||||
(UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
||||
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
||||
])
|
||||
|
||||
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||
|
|
@ -525,10 +514,9 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
|||
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
||||
|
||||
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
|
||||
kernel_arg = Kernel(ret, metadata)
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
|
||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
|
||||
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
|
||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
|
||||
return kernel
|
||||
|
||||
split_kernels = PatternMatcher([
|
||||
|
|
@ -543,9 +531,9 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
|||
if x.dtype.scalar() == dtypes.index: return None
|
||||
ctx[0].append(x)
|
||||
return x.replace(tag=(len(ctx[0])-1,))
|
||||
add_tags = PatternMatcher([
|
||||
add_tags = pm_gate_kernel_sink+PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
|
||||
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
||||
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
|
||||
])
|
||||
|
|
@ -566,7 +554,7 @@ replace_contiguous = PatternMatcher([
|
|||
])
|
||||
|
||||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
uop_list: list[UOp] = []
|
||||
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
|
||||
|
||||
|
|
@ -584,12 +572,13 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.BUFFER, Ops.AFTER} and \
|
||||
x.tag is not None and len(x.tag)])
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
|
||||
|
||||
# bufferize -> store
|
||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
|
||||
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True,
|
||||
name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
|
||||
|
||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||
kernel_assign: dict[UOp, UOp] = {}
|
||||
|
|
@ -609,7 +598,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
sink_tags = [s.tag for s in tsink.src]
|
||||
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for tag, s in zip(sink_tags, tsink.src):
|
||||
|
|
|
|||
|
|
@ -234,8 +234,7 @@ class Tensor(OpMixin):
|
|||
|
||||
def as_param(self, slot:int):
|
||||
if self.uop.axis is not None:
|
||||
multi_shape = tuple([s//len(self.device) if i==self.uop.axis else s for i,s in enumerate(self.shape)])
|
||||
param = UOp.param(slot, self.dtype, multi_shape, self.device).multi(self.uop.axis)
|
||||
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device).multi(self.uop.axis)
|
||||
else:
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param, device=self.device)
|
||||
|
|
@ -756,8 +755,7 @@ class Tensor(OpMixin):
|
|||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
|
||||
sharded_shape = tuple(s//len(self.device) if a==self.uop.axis else s for a,s in enumerate(self.shape))
|
||||
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
|
||||
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
|
||||
return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
|
||||
|
||||
def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
|
||||
|
|
@ -3786,7 +3784,7 @@ class Tensor(OpMixin):
|
|||
S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
|
||||
new_indices = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
|
||||
U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2)
|
||||
V = V.gather(-1, new_indices).realize()
|
||||
V = V.gather(-1, new_indices)
|
||||
|
||||
padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
|
||||
padded_u[..., 0:num, 0:num] = U
|
||||
|
|
|
|||
|
|
@ -84,8 +84,7 @@ class Ops(FastEnum):
|
|||
# ** 6 -- ops that don't exist in programs **
|
||||
|
||||
# tensor graph ops
|
||||
UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); ASSIGN = auto()
|
||||
CUSTOM_KERNEL = auto()
|
||||
UNIQUE = auto(); DEVICE = auto(); ASSIGN = auto()
|
||||
|
||||
# local unique
|
||||
LUNIQUE = auto()
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
|
|||
|
||||
# *** helper functions for bit manipulation ***
|
||||
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
|
||||
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()]
|
||||
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]
|
||||
def exponent_bias(d:DType) -> int: return (1 << (dtypes.finfo(d.scalar())[0] - 1)) - 1
|
||||
def exponent_mask(d:DType) -> int: return (1 << dtypes.finfo(d.scalar())[0]) - 1
|
||||
|
||||
# **** utils ****
|
||||
def shr(x:UOp|int, y:UOp|int) -> UOp: return x // (2**(y.simplify().arg) if isinstance(y, UOp) else 2**y)
|
||||
|
|
@ -378,33 +378,46 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
|
|||
case _: raise NotImplementedError(f"long decomposition of {op} unsupported")
|
||||
|
||||
# ***** floats *****
|
||||
f2f_dt = { dtypes.half: dtypes.ushort, dtypes.float: dtypes.uint }
|
||||
f2f_dt = { f:getattr(dtypes, f"uint{f.bitsize}") for f in dtypes.floats }
|
||||
|
||||
def rne(v: UOp, s) -> UOp: return shr(v, s) + ((shr(v, s - 1) & 1) & ((v & ((1 << (s - 1)) - 1)).ne(0).cast(v.dtype) | (shr(v, s) & 1)))
|
||||
|
||||
def f2f(v, fr:DType, to:DType):
|
||||
fs, fb, (fe, fm), ts, tb, (te, tm) = fr.bitsize, exponent_bias(fr), dtypes.finfo(fr), to.bitsize, exponent_bias(to), dtypes.finfo(to)
|
||||
# NB: denormals are zero!
|
||||
if fe < te and fm < tm:
|
||||
if fe <= te and fm < tm:
|
||||
sign, nosign = shl((v & shl(1, fs-1)).cast(f2f_dt[to]), ts - fs), (v & (shl(1, fs-1) - 1)).cast(f2f_dt[to])
|
||||
exp, norm = shr(nosign, fm), shl(nosign, tm - fm) + shl(tb - fb, tm)
|
||||
inf_or_nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm)
|
||||
return (sign | exp.eq(0).where(0, exp.eq(shl(1, fe) - 1).where(inf_or_nan, norm))).bitcast(to)
|
||||
elif fe > te and fm > tm:
|
||||
sign, nosign, exp = shr(v, fs - ts) & shl(1, ts - 1), v & (shl(1, fs - 1) - 1), shr(v, fm) & (shl(1, fe) - 1)
|
||||
nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm)
|
||||
# fp8e4m3 has only one nan
|
||||
is_nan = (nosign.eq(shl(1, fm + fe) - 1) if fr == dtypes.fp8e4m3 else exp.eq(shl(1, fe) - 1))
|
||||
return (sign | exp.eq(0).where(0, is_nan.where(nan, norm))).bitcast(to)
|
||||
elif fe >= te and fm > tm:
|
||||
v = f2f_clamp(v.bitcast(fr), to).bitcast(f2f_dt[fr])
|
||||
sign, nosign = shr(v, fs - ts) & shl(1, ts - 1), v & (shl(1, fs - 1) - 1)
|
||||
norm = (rne(nosign, fm - tm) - shl(fb - tb, tm)).cast(f2f_dt[to])
|
||||
infnan = (sign | (shr(nosign, fm - tm) & (shl(1, tm) - 1)) | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to])
|
||||
underflow, overflow = exp < (1 + fb - tb), exp > (shl(1, te) - 2 + (fb - tb))
|
||||
return exp.eq(shl(1, fe) - 1).where(infnan, sign.cast(f2f_dt[to]) | underflow.where(0, overflow.where(shl(shl(1, te) - 1, tm), norm)))
|
||||
underflow = (shr(v, fm) & (shl(1, fe) - 1)) < (1 + fb - tb)
|
||||
nan_mantissa = (shl(1, tm) - 1) if to == dtypes.fp8e4m3 else (shr(nosign, fm - tm) & (shl(1, tm) - 1))
|
||||
nan = (sign | nan_mantissa | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to])
|
||||
is_nan = (shr(v, fm) & (shl(1, fe) - 1)).eq(shl(1, fe) - 1)
|
||||
return is_nan.where(nan, sign.cast(f2f_dt[to]) | underflow.where(0, norm))
|
||||
else: raise NotImplementedError(f"unsupported decomp {fr} -> {to}")
|
||||
|
||||
def f2f_load(x: UOp) -> UOp:
|
||||
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=dtypes.ushort), dtypes.half, dtypes.float)
|
||||
return UOp.vectorize(*(f2f(x.replace(dtype=dtypes.ushort, src=(reindex(x.src[0].src[0], i, 1),)), dtypes.half, dtypes.float) for i in range(n)))
|
||||
def f2f_clamp(val:UOp, dt:DType) -> UOp:
|
||||
e, m = dtypes.finfo(dt)
|
||||
max_exp, max_man = ((1 << e) - 1, (1 << m) - 2) if dt == dtypes.fp8e4m3 else ((1 << e) - 2, (1 << m) - 1)
|
||||
mx = val.const_like(2.0**(max_exp - exponent_bias(dt)) * (1.0 + max_man / (1 << m)))
|
||||
sat = mx if dt in dtypes.fp8s else val.const_like(float('inf'))
|
||||
# FIXME: CMPLT of nan is undefined
|
||||
return val.ne(val).where(val, (val < -mx).where(-sat, (mx < val).where(sat, val)))
|
||||
|
||||
def f2f_store(st, idx, val):
|
||||
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(dtypes.uint), dtypes.float, dtypes.half)))
|
||||
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(dtypes.uint), dtypes.float, dtypes.half))) for i in range(n)))
|
||||
def f2f_load(x: UOp, fr:DType, to:DType) -> UOp:
|
||||
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to)
|
||||
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n)))
|
||||
|
||||
def f2f_store(st, idx, val, fr:DType, to:DType):
|
||||
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr)))
|
||||
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(f2f_dt[to]), to, fr))) for i in range(n)))
|
||||
|
||||
# ***** decomposition patterns *****
|
||||
|
||||
|
|
@ -463,40 +476,44 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, disable_fast_idiv
|
|||
pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
@functools.cache
|
||||
def get_unsupported_dtypes_patterns(device:str, emulated_dtypes:tuple[DType, ...]) -> PatternMatcher:
|
||||
pat: list[tuple[UPat, Callable]] = []
|
||||
if not is_dtype_supported(dtypes.long, device) or dtypes.long in emulated_dtypes:
|
||||
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None)]
|
||||
pat += [(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype]))]
|
||||
pat += [(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
|
||||
st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None)]
|
||||
pat += [(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
|
||||
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt)))]
|
||||
pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
|
||||
l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None and a.dtype not in l2i_dt else None)]
|
||||
pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
|
||||
(a.rtag(0).cast(dt:=l2i_dt[a.dtype]).bitcast(xdt:=l2i_dt[x.dtype]), a.rtag(1).cast(dt).bitcast(xdt))[x.tag])]
|
||||
pat += [(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
|
||||
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None)]
|
||||
pat += [(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
|
||||
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
|
||||
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag])]
|
||||
pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
|
||||
x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag),)))]
|
||||
pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
|
||||
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))]
|
||||
if dtypes.half in emulated_dtypes:
|
||||
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=dtypes.uint16.ptr(x.dtype.size), tag=dtypes.half) if x.dtype.base == dtypes.half else None)]
|
||||
pat += [(UPat(Ops.LOAD, dtypes.half, name="x"), f2f_load)]
|
||||
pat += [(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, dtypes.half, name="ld"),), name="bc"), lambda bc,ld:
|
||||
ld.replace(dtype=dtypes.ushort).bitcast(bc.dtype))]
|
||||
pat += [(UPat(Ops.BITCAST, (dtypes.ushort, dtypes.short, dtypes.bfloat16), src=(UPat.var("x", dtypes.float),), name="bc"), lambda bc,x:
|
||||
bc.replace(src=(f2f(x.bitcast(dtypes.uint), dtypes.float, dtypes.half),)))]
|
||||
pat += [(UPat(GroupOp.All, dtypes.half, name="x"), lambda x:
|
||||
x.replace(dtype=dtypes.float.vec(x.dtype.count), src=tuple(s.cast(dtypes.float) if s.dtype == dtypes.half else s for s in x.src)))]
|
||||
pat += [(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.float)), name='st'), lambda st,idx,val:
|
||||
f2f_store(st, idx, val) if (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == dtypes.half else None)]
|
||||
return PatternMatcher(pat)
|
||||
pm_long_decomp = PatternMatcher([
|
||||
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
|
||||
(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
|
||||
(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
|
||||
st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None),
|
||||
(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
|
||||
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt))),
|
||||
(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
|
||||
l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None and a.dtype not in l2i_dt else None),
|
||||
(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
|
||||
(a.rtag(0).cast(dt:=l2i_dt[a.dtype]).bitcast(xdt:=l2i_dt[x.dtype]), a.rtag(1).cast(dt).bitcast(xdt))[x.tag]),
|
||||
(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
|
||||
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None),
|
||||
(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
|
||||
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
|
||||
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag]),
|
||||
(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
|
||||
(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
|
||||
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))
|
||||
])
|
||||
|
||||
# float decomposition patterns - ctx is (fr, to) tuple
|
||||
pm_float_decomp = PatternMatcher([
|
||||
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x:
|
||||
x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None),
|
||||
(UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None),
|
||||
(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"), lambda ctx,bc,ld:
|
||||
ld.replace(dtype=f2f_dt[ctx[0]]).bitcast(bc.dtype) if ld.dtype.bitsize == ctx[0].bitsize else None),
|
||||
(UPat(Ops.BITCAST, src=(UPat.var("x", dtypes.floats),), name="bc"), lambda ctx,bc,x:
|
||||
bc.replace(src=(f2f(x.bitcast(f2f_dt[ctx[1]]), ctx[1], ctx[0]),)) if x.dtype == ctx[1] and bc.dtype.bitsize == ctx[0].bitsize else None),
|
||||
(UPat(Ops.CAST, dtypes.floats, src=(UPat.var("val"),), name="x"), lambda ctx,x,val:
|
||||
f2f_clamp(val.cast(ctx[1]), ctx[0]) if x.dtype.scalar() == ctx[0] else None),
|
||||
(UPat(GroupOp.All-{Ops.BITCAST}, dtypes.floats, name="x"), lambda ctx,x:
|
||||
x.replace(dtype=ctx[1].vec(x.dtype.count), src=tuple(s.cast(ctx[1]) if s.dtype == ctx[0] else s for s in x.src))
|
||||
if x.dtype.scalar() == ctx[0] else None),
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat(Ops.BITCAST, dtypes.floats, name="val")), name='st'), lambda ctx,st,idx,val:
|
||||
st.replace(src=(idx, val.replace(dtype=f2f_dt[ctx[0]]))) if val.dtype == ctx[0] and idx.tag == ctx[0] else None),
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), name='st'), lambda ctx,st,idx,val:
|
||||
f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | Ops.SINK | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
|
||||
return None
|
||||
|
||||
|
|
@ -236,9 +236,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
|
||||
return self.src[0]._shape
|
||||
|
||||
# ops with custom handling
|
||||
case Ops.KERNEL: return self.arg.ast._shape
|
||||
|
||||
# TODO: disallow shape changing bitcast
|
||||
case Ops.BITCAST:
|
||||
ps = self.src[0]._shape
|
||||
|
|
@ -288,10 +285,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
raise ValueError(f"invalid type for axis: {axis_arg}")
|
||||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||
|
||||
if self.op is Ops.ASSIGN: return self.src[1]._shape
|
||||
|
||||
# elementwise ops keep the shape the same. all inputs with shape must match
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
|
||||
# TODO: remove this hack for 3 op assign
|
||||
input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None]
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
|
||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||
if len(input_shapes) == 0: return None
|
||||
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
||||
return input_shapes[0]
|
||||
|
|
@ -371,9 +369,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@recursive_property
|
||||
def trace_num(self):
|
||||
num = next(ucount)
|
||||
# KERNEL also has a UOp in the arg
|
||||
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
|
||||
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
|
||||
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), self.arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
|
||||
return num
|
||||
|
||||
# *** uop syntactic sugar ***
|
||||
|
|
@ -806,6 +802,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
# *** uop high level syntactic sugar ***
|
||||
|
||||
@property
|
||||
def shard_shape(self):
|
||||
if self.axis is None: return self.shape
|
||||
return tuple(x//len(self.device) if i == self.axis else x for i,x in enumerate(self.shape))
|
||||
|
||||
@staticmethod
|
||||
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
|
|
@ -814,7 +815,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return ret
|
||||
def placeholder_like(self, slot:int):
|
||||
assert all_int(self.shape), "no placeholder-like on symbolic shape"
|
||||
return UOp.placeholder(self.shape, self.dtype, slot)
|
||||
return UOp.placeholder(self.shard_shape, self.dtype, slot)
|
||||
|
||||
# set is store+end+after
|
||||
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
|
||||
|
|
@ -827,11 +828,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call"
|
||||
# TODO: reenable this after ENCDEC is fixed
|
||||
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||
kernel = fxn(*placeholders).call(*contig_srcs, grad_fxn=grad_fxn)
|
||||
return [s.after(kernel) for s in contig_srcs]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -845,14 +848,6 @@ class KernelInfo:
|
|||
@property
|
||||
def function_name(self): return to_function_name(self.name)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CustomKernel:
|
||||
fxn: Callable
|
||||
grad_fxn: Callable|None = None
|
||||
# sadly CustomKernel can't be pickled or reconstructed as a str
|
||||
def __reduce__(self): return (CustomKernel, (panic,))
|
||||
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallInfo:
|
||||
grad_fxn: Callable|None = None
|
||||
|
|
@ -861,12 +856,6 @@ class CallInfo:
|
|||
def __reduce__(self): return (CallInfo, (None, self.metadata))
|
||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
grad_fxn: Callable|None = None
|
||||
|
||||
# ******** ops in python ********
|
||||
|
||||
def safe_exp2(x):
|
||||
|
|
@ -1326,6 +1315,9 @@ def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lowe
|
|||
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
def gate_kernel_sink(x:UOp) -> bool: return not (x.op is Ops.SINK and isinstance(x.arg, KernelInfo))
|
||||
pm_gate_kernel_sink = PatternMatcher([(UPat(Ops.SINK, name="sink"), lambda sink: None if gate_kernel_sink(sink) else panic(BottomUpGate))])
|
||||
|
||||
def do_unbind(ctx:dict[Variable, int], x:UOp):
|
||||
v,i = x.unbind()
|
||||
ctx[v] = i
|
||||
|
|
@ -1423,7 +1415,6 @@ pm_pyrender_extra = PatternMatcher([
|
|||
|
||||
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
|
||||
pm_pyrender = pm_pyrender_extra+PatternMatcher([
|
||||
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"),
|
||||
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
|
||||
])
|
||||
|
||||
|
|
@ -1433,7 +1424,7 @@ def pyrender(ast:UOp) -> str:
|
|||
cmap = consumer_map_from_toposort(lst)
|
||||
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
|
||||
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
|
||||
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||
|
||||
to_render: set[UOp] = {ast}
|
||||
for u in lst:
|
||||
|
|
@ -1441,7 +1432,7 @@ def pyrender(ast:UOp) -> str:
|
|||
for s in u.src: to_render.add(s)
|
||||
if u.op is Ops.STORE: to_render.add(u.src[1])
|
||||
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
|
||||
if u.op in {Ops.CUSTOM_KERNEL, Ops.CALL}: raise NotImplementedError("custom_kernel / call can't be pyrendered")
|
||||
if u.op is Ops.CALL: raise NotImplementedError("call can't be pyrendered")
|
||||
if u.op in not_rendered: continue
|
||||
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
|
||||
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue
|
||||
|
|
@ -1456,11 +1447,6 @@ def pyrender(ast:UOp) -> str:
|
|||
op_depth = 1 + max([depth[s] for s in u.src], default=0)
|
||||
if op_depth > 100: to_render.add(u)
|
||||
depth[u] = 0 if u in to_render else op_depth
|
||||
# do the rendering
|
||||
if u.op is Ops.KERNEL:
|
||||
if u.arg.ast not in kernels:
|
||||
kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")
|
||||
r[u.arg.ast] = kernels[u.arg.ast][0]
|
||||
ren = cast(str, pm_pyrender.rewrite(u, ctx=r))
|
||||
assert isinstance(ren, str)
|
||||
if u.tag is not None: ren += f".rtag({repr(u.tag)})"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import math
|
||||
from typing import cast, Any
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel, CustomKernel
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
|
||||
|
||||
|
|
@ -77,9 +77,6 @@ movement_ops = PatternMatcher([
|
|||
|
||||
# AFTER on Movement Op
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS})),), allow_any_len=True), lambda: True),
|
||||
|
||||
# custom kernels allowed here
|
||||
(UPat(Ops.CUSTOM_KERNEL), lambda: True),
|
||||
])
|
||||
|
||||
_tensor_spec = PatternMatcher([
|
||||
|
|
@ -92,7 +89,7 @@ _tensor_spec = PatternMatcher([
|
|||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
|
||||
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
(UPat(Ops.CALL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
|
||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||
|
|
@ -143,12 +140,19 @@ _tensor_spec = PatternMatcher([
|
|||
# allow CALL/PARAM
|
||||
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
|
||||
(UPat(Ops.PARAM), lambda: True),
|
||||
])+movement_ops+shared_spec
|
||||
|
||||
tensor_spec = PatternMatcher([
|
||||
# no tags allowed in tensor graph
|
||||
(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False),
|
||||
])+_tensor_spec
|
||||
# ** for custom kernels **
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
])+movement_ops+shared_spec
|
||||
|
||||
# ***** UOp spec in codegen shared between kernel and program *****
|
||||
|
||||
|
|
@ -208,6 +212,11 @@ kernel_spec = PatternMatcher([
|
|||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.index, dtypes.int) for y in x.src[1:])),
|
||||
])+movement_ops+shared_codegen_spec+shared_spec
|
||||
|
||||
tensor_spec = PatternMatcher([
|
||||
# no tags allowed in tensor graph
|
||||
(UPat(GroupOp.All, name="x"), lambda x: None if x.tag is None else False),
|
||||
])+_tensor_spec+kernel_spec
|
||||
|
||||
# ***** UOp spec in linearized programs *****
|
||||
|
||||
program_spec = PatternMatcher([
|
||||
|
|
@ -239,8 +248,6 @@ full_spec = PatternMatcher([
|
|||
|
||||
# rangeify: buffer view with index or load is okay
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),)), lambda: True),
|
||||
# assign on index. the third op is the shape
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat(), UPat())), lambda: True),
|
||||
|
||||
# expander: unroll/contract/gep/ptrcat/cat
|
||||
(UPat((Ops.UNROLL, Ops.CONTRACT), src=(UPat(),)), lambda: True),
|
||||
|
|
@ -254,7 +261,7 @@ full_spec = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
|
||||
|
||||
# linearizer: outputs + intermediate KERNELs
|
||||
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
|
||||
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
|
||||
|
||||
# Invalid must have type Index
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||
|
|
@ -270,16 +277,6 @@ full_spec = PatternMatcher([
|
|||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
(UPat(Ops.INDEX), lambda: True),
|
||||
|
|
@ -307,8 +304,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
|
|||
# late imports to avoid circular import
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.schedule.rangeify import BufferizeOpts
|
||||
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Kernel": Kernel, "Metadata": Metadata,
|
||||
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, "CustomKernel": CustomKernel,
|
||||
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Metadata": Metadata,
|
||||
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid,
|
||||
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic,
|
||||
"ConstFloat": ConstFloat}
|
||||
def eval_pyrender(code:str) -> UOp:
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||
# only RANGE/IF/STORE/KERNEL have side effects
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
# VECTORIZE/CONST
|
||||
|
|
|
|||
|
|
@ -25,9 +25,10 @@ z3_renderer = PatternMatcher([
|
|||
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
# loads are variables bounded by the min/max of the dtype
|
||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
|
||||
(UPat(Ops.LOAD, dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
|
||||
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
|
||||
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx:
|
||||
create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
|
||||
(UPat((Ops.LOAD, Ops.INDEX), dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0]), None)),
|
||||
# constants
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x,ctx: (z3.Int("Invalid", ctx=ctx[0]), None)),
|
||||
(UPat(Ops.CONST, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0]), None)),
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
|
|||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
|
||||
|
|
@ -106,9 +106,6 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
if u in excluded: continue
|
||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
|
||||
if u.op is Ops.KERNEL:
|
||||
ast_str = f"SINK{tuple(s.op for s in u.arg.ast.src)}" if u.arg.ast.op is Ops.SINK else repr(u.arg.ast.op)
|
||||
argst = f"<Kernel {len(list(u.arg.ast.toposort()))} {ast_str} {[str(m) for m in u.arg.metadata]}>"
|
||||
if u.op is Ops.BINARY: argst = f"<{len(u.arg)} bytes>"
|
||||
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
|
||||
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
|
||||
|
|
@ -130,9 +127,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
|
||||
except Exception:
|
||||
label += "\n<ISSUE GETTING LABEL>"
|
||||
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||
# NOTE: kernel already has metadata in arg
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+str(u.metadata)
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
|
||||
graph[id(u)] = {"label":label, "src":[(i,id(x)) for i,x in enumerate(u.src) if x not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
|
||||
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
|
||||
return graph
|
||||
|
|
@ -140,12 +137,10 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
@functools.cache
|
||||
def _reconstruct(a:int):
|
||||
op, dtype, src, arg, *rest = trace.uop_fields[a]
|
||||
arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg
|
||||
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
|
||||
|
||||
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
next_sink = _reconstruct(ctx.sink)
|
||||
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
|
||||
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue