mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge remote-tracking branch 'upstream/master' into new_x86_backend
This commit is contained in:
commit
b5db91bfdf
95 changed files with 1649 additions and 1764 deletions
26
.github/workflows/benchmark.yml
vendored
26
.github/workflows/benchmark.yml
vendored
|
|
@ -21,6 +21,9 @@ jobs:
|
|||
# the 3 minute timeout should not be raised
|
||||
testmacpytest:
|
||||
name: Mac pytest
|
||||
env:
|
||||
CI: ""
|
||||
CAPTURE_PROCESS_REPLAY: "0"
|
||||
runs-on: [self-hosted, macOS]
|
||||
timeout-minutes: 3
|
||||
defaults:
|
||||
|
|
@ -41,22 +44,14 @@ jobs:
|
|||
run: |
|
||||
echo "CACHEDB=/tmp/pytest-db-ci.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/pytest-db-ci*
|
||||
# TODO: remove this step once all old caches are migrated
|
||||
- name: Migrate old huggingface cache (symlinks break onnxruntime 1.24+)
|
||||
run: |
|
||||
cd ~/Library/Caches/tinygrad/downloads/models 2>/dev/null || exit 0
|
||||
for old_dir in models--*; do
|
||||
[ -d "$old_dir" ] || continue
|
||||
repo_id=$(echo "$old_dir" | sed 's/models--//; s/--/\//g')
|
||||
snapshot=$(ls -1 "$old_dir/snapshots" 2>/dev/null | head -1)
|
||||
[ -n "$snapshot" ] || continue
|
||||
mkdir -p "$repo_id"
|
||||
cp -RLn "$old_dir/snapshots/$snapshot/"* "$repo_id/" 2>/dev/null || true
|
||||
done
|
||||
- name: Run pytest -nauto
|
||||
run: |
|
||||
source /tmp/tinygrad_pytest_ci/bin/activate
|
||||
pytest -nauto --durations=20
|
||||
- name: openpilot compile3 0.10.1 driving_vision
|
||||
run: FLOAT16=1 CL=1 IMAGE=2 python3.11 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
|
||||
- name: IMAGE=1 openpilot compile3 0.10.1 driving_vision
|
||||
run: FLOAT16=1 CL=1 IMAGE=1 python3.11 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
|
||||
|
||||
testmacbenchmark:
|
||||
name: Mac Benchmark
|
||||
|
|
@ -515,7 +510,7 @@ jobs:
|
|||
- name: Run 10 CIFAR training steps
|
||||
run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 python3 examples/hlb_cifar10.py
|
||||
- name: Run 10 CIFAR training steps w HALF
|
||||
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
|
||||
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=230 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
|
||||
# - name: Run 10 CIFAR training steps w BF16
|
||||
# run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=288 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py
|
||||
# TODO: too slow
|
||||
|
|
@ -525,8 +520,9 @@ jobs:
|
|||
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
|
||||
- name: Run full CIFAR training steps w 6 GPUS
|
||||
run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
|
||||
- name: Test full tinyfs load
|
||||
run: TINYFS_ENDPOINT=10.0.52.11:6767 PYTHONPATH=. python extra/tinyfs/fetch_file.py --hash d734f5e3be9f1e9d863bfaa4fc6c1ef2 --len 175866113 --dest mapping.json --check
|
||||
# this needs to be mocked and testable on a local machine
|
||||
#- name: Test full tinyfs load
|
||||
# run: TINYFS_ENDPOINT=10.0.52.11:6767 PYTHONPATH=. python extra/tinyfs/fetch_file.py --hash d734f5e3be9f1e9d863bfaa4fc6c1ef2 --len 175866113 --dest mapping.json --check
|
||||
- name: Run process replay tests
|
||||
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
|
||||
|
||||
|
|
|
|||
24
.github/workflows/test.yml
vendored
24
.github/workflows/test.yml
vendored
|
|
@ -1,7 +1,7 @@
|
|||
name: Unit Tests
|
||||
env:
|
||||
# increment this when downloads substantially change to avoid the internet
|
||||
CACHE_VERSION: '16'
|
||||
CACHE_VERSION: '17'
|
||||
CAPTURE_PROCESS_REPLAY: 1
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PYTHONPATH: ${{ github.workspace }}
|
||||
|
|
@ -664,6 +664,28 @@ jobs:
|
|||
- name: Run LLVM test
|
||||
run: AMD_LLVM=1 python test/device/test_amd_llvm.py
|
||||
|
||||
testmockam:
|
||||
name: Linux (am)
|
||||
runs-on: ubuntu-24.04
|
||||
timeout-minutes: 15
|
||||
env:
|
||||
AMD: 1
|
||||
MOCKGPU: 1
|
||||
AMD_IFACE: PCI
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: mockam
|
||||
deps: testing_unit
|
||||
amd: 'true'
|
||||
- name: Run test_tiny on MOCKAM
|
||||
run: python test/test_tiny.py
|
||||
- name: Run test_hcq on MOCKAM
|
||||
run: python -m pytest test/device/test_hcq.py
|
||||
|
||||
testamd:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ Directories are listed in order of how they are processed.
|
|||
|
||||
Group UOps into kernels.
|
||||
|
||||
::: tinygrad.schedule.rangeify.get_rangeify_map
|
||||
::: tinygrad.schedule.rangeify.get_rangeify
|
||||
options:
|
||||
members: false
|
||||
show_labels: false
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
|
|||
BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
|
||||
EVAL_BS = getenv("EVAL_BS", BS)
|
||||
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
|
||||
assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
|
||||
assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
|
||||
assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}"
|
||||
assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}"
|
||||
|
||||
class UnsyncedBatchNorm:
|
||||
def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
|
||||
|
|
|
|||
|
|
@ -65,17 +65,7 @@ def loader_process(q_in, q_out, X:Tensor, seed):
|
|||
else:
|
||||
# pad data with training mean
|
||||
img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))
|
||||
|
||||
# broken out
|
||||
#img_tensor = Tensor(img.tobytes(), device='CPU')
|
||||
#storage_tensor = X[idx].contiguous().realize().lazydata.base.realized
|
||||
#storage_tensor._copyin(img_tensor.numpy())
|
||||
|
||||
# faster
|
||||
X[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = img.tobytes()
|
||||
|
||||
# ideal
|
||||
#X[idx].assign(img.tobytes()) # NOTE: this is slow!
|
||||
X[idx].flatten().assign(img.tobytes())
|
||||
q_out.put(idx)
|
||||
q_out.put(None)
|
||||
|
||||
|
|
@ -264,8 +254,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens
|
|||
x = random_brightness_augmentation(x)
|
||||
x = gaussian_noise(x)
|
||||
|
||||
X[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = x.tobytes()
|
||||
Y[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = y.tobytes()
|
||||
X[idx].flatten().assign(x.tobytes())
|
||||
Y[idx].flatten().assign(y.tobytes())
|
||||
|
||||
queue_out.put(idx)
|
||||
queue_out.put(None)
|
||||
|
|
@ -379,12 +369,12 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
|
|||
clipped_match_idxs = np.clip(match_idxs, 0, None)
|
||||
clipped_boxes, clipped_labels = tgt["boxes"][clipped_match_idxs], tgt["labels"][clipped_match_idxs]
|
||||
|
||||
boxes[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = clipped_boxes.tobytes()
|
||||
labels[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = clipped_labels.tobytes()
|
||||
matches[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = match_idxs.tobytes()
|
||||
anchors[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = anchor.tobytes()
|
||||
boxes[idx].flatten().assign(clipped_boxes.tobytes())
|
||||
labels[idx].flatten().assign(clipped_labels.tobytes())
|
||||
matches[idx].flatten().assign(match_idxs.tobytes())
|
||||
anchors[idx].flatten().assign(anchor.tobytes())
|
||||
|
||||
imgs[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = img.tobytes()
|
||||
imgs[idx].flatten().assign(img.tobytes())
|
||||
|
||||
queue_out.put(idx)
|
||||
queue_out.put(None)
|
||||
|
|
|
|||
|
|
@ -1285,6 +1285,7 @@ def train_llama3():
|
|||
from extra.models.llama import Transformer
|
||||
from examples.llama3 import MODEL_PARAMS
|
||||
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
|
||||
from examples.mlperf.optim import GradAccClipAdamW
|
||||
|
||||
BENCHMARK = getenv("BENCHMARK")
|
||||
|
||||
|
|
@ -1370,13 +1371,13 @@ def train_llama3():
|
|||
# prevents memory spike on device 0
|
||||
v.realize()
|
||||
|
||||
optim = AdamW(get_parameters(model), lr=0.0,
|
||||
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
|
||||
optim = GradAccClipAdamW(get_parameters(model), lr=0.0,
|
||||
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay, grad_acc=grad_acc)
|
||||
|
||||
# init grads
|
||||
for p in optim.params:
|
||||
p.grad = p.zeros_like().contiguous().realize()
|
||||
grads = [p.grad for p in optim.params]
|
||||
grads: list[Tensor] = [p.grad for p in optim.params]
|
||||
|
||||
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
|
||||
|
||||
|
|
@ -1407,25 +1408,11 @@ def train_llama3():
|
|||
|
||||
@TinyJit
|
||||
def optim_step():
|
||||
for p in optim.params:
|
||||
p.grad.assign(p.grad / grad_acc)
|
||||
|
||||
# L2 norm grad clip
|
||||
# https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
|
||||
# https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
|
||||
if not getenv("DISABLE_GRAD_CLIP_NORM"):
|
||||
total_norm = Tensor(0.0, dtype=dtypes.float32, device=optim.params[0].device)
|
||||
for g in grads:
|
||||
total_norm += g.float().square().sum()
|
||||
total_norm = total_norm.sqrt().contiguous().realize()
|
||||
for g in grads:
|
||||
g.assign((g * (opt_gradient_clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(g.dtype)).realize()
|
||||
|
||||
optim.step()
|
||||
scheduler.step()
|
||||
|
||||
for g in grads:
|
||||
g.assign(g.zeros_like().contiguous()).realize()
|
||||
g.assign(g.zeros_like())
|
||||
|
||||
lr = optim.lr
|
||||
Tensor.realize(lr, *grads)
|
||||
|
|
|
|||
24
examples/mlperf/optim.py
Normal file
24
examples/mlperf/optim.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.nn.optim import LAMB
|
||||
from tinygrad.helpers import FUSE_OPTIM
|
||||
|
||||
class GradAccClipAdamW(LAMB):
|
||||
def __init__(self, params:list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, grad_acc=1, clip_norm=1.0, fused=FUSE_OPTIM):
|
||||
super().__init__(params, lr, b1, b2, eps, weight_decay, adam=True, fused=FUSE_OPTIM)
|
||||
self.grad_acc, self.clip_norm = grad_acc, clip_norm
|
||||
|
||||
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
if self.fused:
|
||||
grads[0] = grads[0] / self.grad_acc
|
||||
total_norm = grads[0].float().square().sum().sqrt()
|
||||
grads[0] = (grads[0] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[0].dtype)
|
||||
else:
|
||||
total_norm = Tensor.zeros((), dtype=dtypes.float32, device=self.device)
|
||||
for g in grads:
|
||||
total_norm += g.float().square().sum()
|
||||
total_norm = total_norm.sqrt()
|
||||
for i in range(len(grads)):
|
||||
grads[i] = grads[i] / self.grad_acc
|
||||
grads[i] = (grads[i] * (self.clip_norm / (total_norm + 1e-6)).clamp(max_=1.0)).cast(grads[i].dtype)
|
||||
return super()._step(params, grads)
|
||||
|
|
@ -11,9 +11,10 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
|
|||
export ALL2ALL=${ALL2ALL:-1}
|
||||
export USE_ATOMICS=${USE_ATOMICS:-1}
|
||||
export ASM_GEMM=${ASM_GEMM:-1}
|
||||
export WQKV=${WQKV:-0}
|
||||
|
||||
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
|
||||
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-1}
|
||||
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
|
||||
export GBS=$((BS * GRADIENT_ACC_STEPS))
|
||||
|
||||
export MODEL="llama3"
|
||||
|
|
@ -30,8 +31,11 @@ export SEED=${SEED:-5760}
|
|||
export DATA_SEED=${DATA_SEED:-5760}
|
||||
|
||||
export JITBEAM=${JITBEAM:-3}
|
||||
export BEAM_UOPS_MAX=6000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export BEAM_UOPS_MAX=6000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=1
|
||||
|
||||
export FAKEDATA=1 BENCHMARK=10 LLAMA_LAYERS=2
|
||||
export FAKEDATA=1 BENCHMARK=10
|
||||
if [ -z "$FULL_LAYERS" ]; then
|
||||
export LLAMA_LAYERS=2
|
||||
fi
|
||||
|
||||
python3 examples/mlperf/model_train.py
|
||||
|
|
|
|||
|
|
@ -11,9 +11,10 @@ export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
|
|||
export ALL2ALL=${ALL2ALL:-1}
|
||||
export USE_ATOMICS=${USE_ATOMICS:-1}
|
||||
export ASM_GEMM=${ASM_GEMM:-1}
|
||||
export WQKV=${WQKV:-0}
|
||||
|
||||
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
|
||||
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-1}
|
||||
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
|
||||
export GBS=$((BS * GRADIENT_ACC_STEPS))
|
||||
|
||||
export MODEL="llama3"
|
||||
|
|
@ -30,6 +31,6 @@ export SEED=${SEED:-$RANDOM}
|
|||
export DATA_SEED=${DATA_SEED:-5760}
|
||||
|
||||
export JITBEAM=${JITBEAM:-3}
|
||||
export BEAM_UOPS_MAX=6000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
|
||||
export BEAM_UOPS_MAX=6000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=1
|
||||
|
||||
python3 examples/mlperf/model_train.py
|
||||
|
|
|
|||
|
|
@ -236,8 +236,6 @@ class SMICtx:
|
|||
case _: return metrics.SmuMetrics.AverageSocketPower, metrics.SmuMetrics.dGPU_W_MAX
|
||||
|
||||
def get_mem_usage(self, dev):
|
||||
return 0
|
||||
|
||||
usage = 0
|
||||
pt_stack = [dev.mm.root_page_table]
|
||||
while len(pt_stack) > 0:
|
||||
|
|
@ -246,8 +244,8 @@ class SMICtx:
|
|||
entry = pt.entries[i]
|
||||
|
||||
if (entry & am.AMDGPU_PTE_VALID) == 0: continue
|
||||
if pt.lv!=am.AMDGPU_VM_PTB and not dev.gmc.is_pte_huge_page(pt.lv, entry):
|
||||
pt_stack.append(AMPageTableEntry(dev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1))
|
||||
if pt.lv < am.AMDGPU_VM_PDB0 and not dev.gmc.is_pte_huge_page(pt.lv, entry):
|
||||
pt_stack.append(AMPageTableEntry(dev, dev.xgmi2paddr(entry & 0x0000FFFFFFFFF000), lv=pt.lv+1))
|
||||
continue
|
||||
if (entry & am.AMDGPU_PTE_SYSTEM) != 0: continue
|
||||
usage += (1 << ((9 * (3-pt.lv)) + 12))
|
||||
|
|
|
|||
|
|
@ -41,9 +41,13 @@ class Attention:
|
|||
self.n_rep = self.n_heads // self.n_kv_heads
|
||||
self.max_context = max_context
|
||||
|
||||
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
if getenv("WQKV"):
|
||||
self.wqkv = linear(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2, bias=False)
|
||||
else:
|
||||
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
|
||||
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
|
||||
|
||||
self.q_norm = nn.RMSNorm(dim, qk_norm) if qk_norm is not None else None
|
||||
|
|
@ -51,9 +55,8 @@ class Attention:
|
|||
|
||||
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]=None) -> Tensor:
|
||||
if getenv("WQKV"):
|
||||
if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
|
||||
xqkv = x @ self.wqkv.T
|
||||
xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
|
||||
xqkv = self.wqkv(x)
|
||||
xq, xk, xv = xqkv.split([self.n_heads * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim], dim=2)
|
||||
else:
|
||||
xq, xk, xv = self.wq(x), self.wk(x.contiguous_backward()), self.wv(x)
|
||||
|
||||
|
|
@ -200,14 +203,14 @@ class Transformer:
|
|||
|
||||
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
h = self.tok_embeddings(tokens).contiguous()
|
||||
freqs_cis = self.freqs_cis.cast(h.dtype)[:, start_pos:start_pos+seqlen, :, :, :]
|
||||
|
||||
if self.max_context != 0 and seqlen > 1:
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1)
|
||||
else: mask = None
|
||||
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
|
||||
logits = self.output(self.norm(h))
|
||||
logits = self.output(self.norm(h).contiguous().contiguous_backward()).contiguous_backward()
|
||||
if math.isnan(temperature): return logits
|
||||
|
||||
return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from typing import Generator
|
|||
from tinygrad.helpers import temp, unwrap, DEBUG
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||
from tinygrad.runtime.autogen import rocprof
|
||||
from tinygrad.renderer.amd.dsl import Inst
|
||||
from test.amd.disasm import disasm
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class InstExec:
|
||||
|
|
@ -44,8 +46,8 @@ class OccEvent(WaveSlot):
|
|||
RunKey = tuple[str, int]
|
||||
|
||||
class _ROCParseCtx:
|
||||
def __init__(self, sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]):
|
||||
self.sqtt_evs, self.disasms = iter(sqtt_evs), disasms
|
||||
def __init__(self, sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, Inst]]):
|
||||
self.sqtt_evs, self.disasms = iter(sqtt_evs), {k:{k2:(disasm(v2), v2.size()) for k2,v2 in v.items()} for k,v in disasms.items()}
|
||||
self.inst_execs:dict[RunKey, list[WaveExec]] = {}
|
||||
self.occ_events:dict[RunKey, list[OccEvent]] = {}
|
||||
|
||||
|
|
@ -71,7 +73,7 @@ class _ROCParseCtx:
|
|||
self.inst_execs.setdefault(unwrap(self.active_run), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time,
|
||||
ev.end_time, insts_blob))
|
||||
|
||||
def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]) -> _ROCParseCtx:
|
||||
def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, Inst]]) -> _ROCParseCtx:
|
||||
ROCParseCtx = _ROCParseCtx(sqtt_evs, disasms)
|
||||
|
||||
@rocprof.rocprof_trace_decoder_se_data_callback_t
|
||||
|
|
|
|||
|
|
@ -47,24 +47,21 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
|
||||
# delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
|
||||
delta_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32)
|
||||
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch))[:2]
|
||||
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:2]
|
||||
|
||||
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch))[:3]
|
||||
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:3]
|
||||
|
||||
# unshuffle dq
|
||||
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch))[0]
|
||||
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[0]
|
||||
|
||||
return None, None, dq.uop, dk.uop, dv.uop
|
||||
|
||||
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch), grad_fxn=grad)[:2]
|
||||
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D), grad_fxn=grad)[:2]
|
||||
|
||||
return attn.transpose(1, 2)
|
||||
|
||||
@functools.cache
|
||||
def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str):
|
||||
B, N, H, D = q.shape
|
||||
H_KV = k.shape[2]
|
||||
|
||||
def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
|
||||
code = (pathlib.Path(__file__).parent / "fa_fwd_causal.cpp").read_text()
|
||||
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
|
||||
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"]
|
||||
|
|
@ -95,9 +92,7 @@ def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:st
|
|||
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
@functools.cache
|
||||
def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str):
|
||||
B, N, H, D = o.shape
|
||||
|
||||
def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
|
||||
code = (pathlib.Path(__file__).parent / "fa_bwd_pre.cpp").read_text()
|
||||
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
|
||||
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"]
|
||||
|
|
@ -128,10 +123,7 @@ def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arc
|
|||
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
@functools.cache
|
||||
def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str):
|
||||
B, N, H, D = q.shape
|
||||
H_KV = k.shape[2]
|
||||
|
||||
def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
|
||||
code = (pathlib.Path(__file__).parent / "fa_bwd_causal.cpp").read_text()
|
||||
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
|
||||
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"]
|
||||
|
|
@ -162,9 +154,7 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve
|
|||
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
@functools.cache
|
||||
def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str):
|
||||
B, N, H, D = dq_out.shape
|
||||
|
||||
def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
|
||||
code = (pathlib.Path(__file__).parent / "fa_bwd_post.cpp").read_text()
|
||||
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
|
||||
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"]
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ if __name__ == "__main__":
|
|||
|
||||
kernel_count = GlobalCounters.kernel_count
|
||||
assert kernel_count > 0, "No kernels, test failed"
|
||||
expected_kernels = 228
|
||||
# NOTE: this is 124 on torch 2.10.0
|
||||
expected_kernels = 332
|
||||
expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected."
|
||||
if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
|
||||
assert kernel_count <= expected_kernels, f"{expectation}"
|
||||
|
|
@ -26,7 +26,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
def fn():
|
||||
x = torch.randn(128, 128, device=device)
|
||||
return (x + 1.0) * 2.0 - 0.5
|
||||
self._check_kernel_count(fn, 6)
|
||||
self._check_kernel_count(fn, 5)
|
||||
|
||||
def test_relu_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -50,14 +50,14 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
def fn():
|
||||
x = torch.randn(64, 64, device=device)
|
||||
return (x * 2.0).sum()
|
||||
self._check_kernel_count(fn, 7)
|
||||
self._check_kernel_count(fn, 5)
|
||||
|
||||
def test_matmul_elementwise_fusion(self):
|
||||
def fn():
|
||||
x = torch.randn(32, 32, device=device)
|
||||
w = torch.randn(32, 32, device=device)
|
||||
return torch.nn.functional.relu(x @ w + 1.0)
|
||||
self._check_kernel_count(fn, 6)
|
||||
self._check_kernel_count(fn, 7)
|
||||
|
||||
def test_pooling_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -71,7 +71,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
identity = torch.randn(1, 8, 16, 16, device=device)
|
||||
out = x + identity
|
||||
return torch.nn.functional.relu(out)
|
||||
self._check_kernel_count(fn, 6)
|
||||
self._check_kernel_count(fn, 7)
|
||||
|
||||
def test_inplace_add_relu_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -79,7 +79,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
y = torch.randn(1, 16, 32, 32, device=device)
|
||||
x += y
|
||||
return torch.nn.functional.relu(x)
|
||||
self._check_kernel_count(fn, 6)
|
||||
self._check_kernel_count(fn, 7)
|
||||
|
||||
def test_conv_bn_add_relu_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -92,7 +92,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
out = bn(conv(x))
|
||||
out += identity
|
||||
return torch.nn.functional.relu(out)
|
||||
self._check_kernel_count(fn, 16)
|
||||
self._check_kernel_count(fn, 17)
|
||||
|
||||
def test_multiple_inplace_ops_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -138,7 +138,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss
|
||||
self._check_kernel_count(fn, 33)
|
||||
self._check_kernel_count(fn, 28)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ testing_minimal = [
|
|||
"hypothesis>=6.148.9",
|
||||
"z3-solver<4.15.4", # 4.15.4 has a segfault when creating many z3.Context()
|
||||
]
|
||||
testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai", "ggml-python"]
|
||||
testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai", "gguf"]
|
||||
testing = [
|
||||
"tinygrad[testing_unit]",
|
||||
"pillow",
|
||||
|
|
|
|||
|
|
@ -1,19 +1,9 @@
|
|||
"""Shared test helpers for AMD tests."""
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import unwrap
|
||||
from tinygrad.runtime.autogen import llvm
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
@dataclass
|
||||
class KernelInfo:
|
||||
code: bytes
|
||||
src: str
|
||||
global_size: tuple[int, int, int]
|
||||
local_size: tuple[int, int, int]
|
||||
buf_idxs: list[int] # indices into shared buffer pool
|
||||
buf_sizes: list[int] # sizes for each buffer index
|
||||
|
||||
ARCH_TO_TARGET:dict[str, list[str]] = {
|
||||
"rdna3":["gfx1100"],
|
||||
"rdna4":["gfx1200"],
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from tinygrad import Device
|
|||
|
||||
from test.mockgpu.amd.emu import WaveState, _decode_at, WAVE_SIZE, VCC_LO, EXEC_LO, SCC
|
||||
from tinygrad.renderer.amd import decode_inst
|
||||
from test.amd.helpers import KernelInfo
|
||||
import tinygrad
|
||||
REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.so"
|
||||
if not REMU_PATH.exists(): REMU_PATH = Path(tinygrad.__file__).parent.parent / "extra/remu/target/release/libremu.dylib"
|
||||
|
|
@ -22,6 +21,15 @@ def _vals_equal(a: int, b: int) -> bool:
|
|||
if a == b: return True
|
||||
return _is_f32_nan(a) and _is_f32_nan(b)
|
||||
|
||||
@dataclass
|
||||
class KernelSnapshot:
|
||||
code: bytes
|
||||
src: str
|
||||
global_size: tuple[int, int, int]
|
||||
local_size: tuple[int, int, int]
|
||||
buf_idxs: list[int] # indices into shared buffer pool
|
||||
buf_sizes: list[int] # sizes for each buffer index
|
||||
|
||||
@dataclass
|
||||
class StateSnapshot:
|
||||
pc: int
|
||||
|
|
@ -285,7 +293,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
|
|||
|
||||
return True, f"Completed {gx*gy*gz} workgroups", total_steps
|
||||
|
||||
def compare_emulators_multi_kernel(kernels: list[KernelInfo], buf_pool: dict[int, int], max_steps: int = 1000,
|
||||
def compare_emulators_multi_kernel(kernels: list[KernelSnapshot], buf_pool: dict[int, int], max_steps: int = 1000,
|
||||
debug: bool = False, trace_len: int = 10, buf_data: dict[int, bytes] | None = None) -> tuple[bool, str]:
|
||||
"""Run all kernels through both emulators with shared buffer pool."""
|
||||
if buf_data is None: buf_data = {}
|
||||
|
|
@ -349,7 +357,7 @@ def compare_emulators_with_memory(kernel: bytes, n_lanes: int, buf_sizes: list,
|
|||
ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, (n_lanes, 1, 1), max_steps, debug, trace_len)
|
||||
return ok, msg
|
||||
|
||||
def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int], dict[int, bytes]]:
|
||||
def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelSnapshot], dict[int, int], dict[int, bytes]]:
|
||||
"""Compile a tinygrad operation and extract all kernels with their buffer mappings."""
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
|
@ -387,7 +395,7 @@ def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int],
|
|||
buf_pool[buf_id] = b.nbytes
|
||||
buf_idxs.append(buf_id)
|
||||
buf_sizes.append(b.nbytes)
|
||||
kernels.append(KernelInfo(
|
||||
kernels.append(KernelSnapshot(
|
||||
code=bytes(sec.content),
|
||||
src=lowered.prg.p.src,
|
||||
global_size=tuple(lowered.prg.p.global_size),
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
|
|||
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
|
||||
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
|
||||
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
|
||||
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.UNK_60}
|
||||
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_STORE}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ROCPROF DECODER
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def rocprof_inst_traces_match(sqtt, prg, target):
|
|||
from tinygrad.viz.serve import amd_decode
|
||||
from extra.sqtt.roc import decode as roc_decode, InstExec
|
||||
addr_table = amd_decode(prg.lib, target)
|
||||
disasm_map = {addr+prg.base:(disasm(inst), inst.size()) for addr,inst in addr_table.items()}
|
||||
disasm_map = {addr+prg.base:inst for addr,inst in addr_table.items()}
|
||||
rctx = roc_decode([sqtt], {prg.tag:disasm_map})
|
||||
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
|
||||
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
|
||||
|
|
@ -30,7 +30,7 @@ def rocprof_inst_traces_match(sqtt, prg, target):
|
|||
rocprof_inst = next(rwaves_iter[info.wave][0])
|
||||
ref_pc = rocprof_inst.pc-prg.base
|
||||
# always check pc matches
|
||||
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc][0]} != {info.pc}:{disasm(info.inst)}"
|
||||
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
|
||||
# special handling for s_endpgm, it marks the wave completion.
|
||||
if info.inst == s_endpgm():
|
||||
completed_wave = list(rwaves_iter[info.wave].pop(0))
|
||||
|
|
@ -72,7 +72,6 @@ class TestSQTTMapBase(unittest.TestCase):
|
|||
|
||||
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
|
||||
|
||||
@unittest.skip("this doesn't work")
|
||||
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ class TestGemmLarge(unittest.TestCase):
|
|||
if not is_cdna4():
|
||||
self.skipTest("very slow on non mi350x")
|
||||
|
||||
def test_tiny(self): verify_asm_gemm(1, 256, 256, 64)
|
||||
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
|
||||
def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336)
|
||||
def test_gemm_batched(self): verify_asm_gemm(2, 8192, 4096, 4096)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ def _check_ast_count(desired_count:int, t:Tensor):
|
|||
# NOTE: this has side effect because everything can be scheduled only once
|
||||
schedule = t.schedule()
|
||||
asts = [s for s in schedule if s.ast.op is Ops.SINK]
|
||||
assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
|
||||
len(asts)
|
||||
# NOT SUPPORTED ANYMORE
|
||||
#assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
|
||||
|
||||
class TestMovedConstFolding(unittest.TestCase):
|
||||
def test_add_shrunk_zero(self):
|
||||
|
|
|
|||
|
|
@ -266,7 +266,7 @@ class TestCustomKernel(unittest.TestCase):
|
|||
The custom_addmul kernel should be at index 3.
|
||||
"""
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
from tinygrad.schedule.rangeify import get_rangeify
|
||||
|
||||
A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
|
||||
A2 = (A + 1).contiguous() # kernel 0: depends on A
|
||||
|
|
@ -277,8 +277,7 @@ class TestCustomKernel(unittest.TestCase):
|
|||
result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum
|
||||
|
||||
big_sink = result.uop.sink()
|
||||
tensor_map = get_rangeify_map(big_sink)
|
||||
sched_sink = big_sink.substitute(tensor_map)
|
||||
sched_sink = get_rangeify(big_sink)
|
||||
schedule, _ = create_schedule(sched_sink)
|
||||
|
||||
# Find the custom_addmul kernel position
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ class TestImageDType(unittest.TestCase):
|
|||
tst = data.numpy()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).realize()
|
||||
# the underlying UOp is identical
|
||||
self.assertIs(it.uop.base.realized, data.uop.base.realized)
|
||||
#self.assertIs(it.uop.base.realized, data.uop.base.realized)
|
||||
np.testing.assert_equal(tst, it.numpy())
|
||||
|
||||
def test_image_and_back_wrong_shape(self):
|
||||
|
|
|
|||
|
|
@ -332,7 +332,6 @@ class TestJit(unittest.TestCase):
|
|||
assert len(res3) == 10, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
#@unittest.expectedFailure # requires contiguous folding
|
||||
def test_jit_random_after_unrealized_random(self):
|
||||
@TinyJit
|
||||
def f(): return Tensor.rand()
|
||||
|
|
@ -476,7 +475,7 @@ class TestJit(unittest.TestCase):
|
|||
b = f(Tensor([2.0]))
|
||||
assert abs((a - b).item()) > 0.5
|
||||
|
||||
def test_jit_init_with_empty_different_size(self):
|
||||
def test_jit_init_empty(self):
|
||||
@TinyJit
|
||||
def f(x:Tensor) -> Tensor: return (x + 1).realize()
|
||||
|
||||
|
|
@ -485,9 +484,16 @@ class TestJit(unittest.TestCase):
|
|||
# scalar const input is not allowed
|
||||
with self.assertRaises(JitError):
|
||||
f(Tensor(2.0)).item()
|
||||
# list input has different view structure than empty(1)
|
||||
with self.assertRaises(JitError):
|
||||
f(Tensor([2.0])).item()
|
||||
# self.assertEqual(f(Tensor([2.0])).item(), 1.0) # TODO: wrong output, should be 3.0. currently depends on empty value
|
||||
|
||||
def test_jit_init_empty_alt(self):
|
||||
@TinyJit
|
||||
def f(a:Tensor, b:Tensor) -> Tensor: return b.assign(a+1)
|
||||
for i in range(4):
|
||||
a = Tensor([i])
|
||||
b = Tensor.empty_like(a)
|
||||
c = f(a, b)
|
||||
self.assertEqual(c.item(), i+1)
|
||||
|
||||
@unittest.skip("Pending multioutput implementation #3607")
|
||||
class TestMultioutputJit(unittest.TestCase):
|
||||
|
|
@ -645,8 +651,8 @@ class TestJitFree(unittest.TestCase):
|
|||
def test_replan_buffers_memory_layout(self):
|
||||
if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless")
|
||||
|
||||
ext_tensor = Tensor([1,24,23,45,1])
|
||||
ext_tensor_2 = Tensor([2,2,2,2,2])
|
||||
ext_tensor = Tensor([1,24,23,45,1]).contiguous()
|
||||
ext_tensor_2 = Tensor([2,2,2,2,2]).contiguous()
|
||||
@TinyJit
|
||||
def fxn(x:Tensor):
|
||||
out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
|
||||
|
|
@ -654,9 +660,9 @@ class TestJitFree(unittest.TestCase):
|
|||
for i in range(5):
|
||||
out = fxn(Tensor([i,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 11400+200*i)
|
||||
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4
|
||||
self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 4)
|
||||
fxn.captured.replan_buffers_memory_layout()
|
||||
assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2
|
||||
self.assertEqual(len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])), 2)
|
||||
|
||||
out = fxn(Tensor([11,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 13600)
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@ import unittest
|
|||
from dataclasses import replace
|
||||
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.codegen.gpudims import get_grouped_dims
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program
|
||||
|
|
@ -254,100 +253,6 @@ class TestLinearizer(unittest.TestCase):
|
|||
if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
|
||||
assert end_range < uops.index(u)
|
||||
|
||||
def test_grouped_dims(self):
|
||||
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
|
||||
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
|
||||
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
|
||||
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
|
||||
sizes = [x.src[0].arg for x in loop_idxs]
|
||||
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
|
||||
if assert_same_length:
|
||||
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
|
||||
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
|
||||
# TODO: add these back after uop symbolic
|
||||
# for i in range(len(dims)):
|
||||
# assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
|
||||
# for i in range(len(loop_idxs)):
|
||||
# assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
|
||||
# assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
|
||||
|
||||
# no-op
|
||||
_assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
|
||||
|
||||
# check reverse dims
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
|
||||
_assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
|
||||
|
||||
# test splitting globals: len(dims) == len(max)
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
|
||||
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
|
||||
_assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
|
||||
|
||||
# prefer group_dim strategy when possible
|
||||
_assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
|
||||
|
||||
# test splitting globals: len(dims) < len(max)
|
||||
# len(dim) -> len(limited)
|
||||
# 1 -> 2
|
||||
_assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
|
||||
# 1 -> 3
|
||||
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
|
||||
# 2 -> 3
|
||||
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
|
||||
# 2 -> 2
|
||||
_assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
|
||||
# test when the only divisor is the square root of dim
|
||||
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
|
||||
|
||||
# collapse on onto the left most axis
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
|
||||
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
|
||||
|
||||
# collapse on left-most available axis (the left most is too small)
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
|
||||
|
||||
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
|
||||
|
||||
# dim too large and not factorable
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (23,), (16,16,16), False,)
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
|
||||
|
||||
# too large for sizes
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
|
||||
|
||||
# TODO: In the above cases we only test if the shape after reshape is correct, never the indices.
|
||||
# We should check if the returned indices are correct, for all cases.
|
||||
# (65536, 2) -> (32768, 4)
|
||||
dims, expected_limited_dims = (65536,2), (32768, 4)
|
||||
idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
|
||||
def match_div(): raise RuntimeError("match_div")
|
||||
def match_mod(): raise RuntimeError("match_mod")
|
||||
flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
|
||||
pm = PatternMatcher([
|
||||
(flat_idx_pattern//dims[1], match_div),
|
||||
(flat_idx_pattern%dims[1], match_mod)
|
||||
])
|
||||
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
graph_rewrite(idxs[0], pm)
|
||||
self.assertIn("match_div", str(error.exception))
|
||||
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
graph_rewrite(idxs[1], pm)
|
||||
self.assertIn("match_mod", str(error.exception))
|
||||
|
||||
# # variable too large
|
||||
# with self.assertRaises(AssertionError):
|
||||
# get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
def test_default_global_reversed(self):
|
||||
# shrink so that the dims do not collapse
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ class TestMultiTensor(unittest.TestCase):
|
|||
def _test_shard_op(self, op, out, n=4):
|
||||
t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
|
||||
r = op(t).realize()
|
||||
assert t.uop.is_realized, "shard didn't realize"
|
||||
#assert t.uop.is_realized, "shard didn't realize"
|
||||
self.assertEqual(r.tolist(), out)
|
||||
def test_shard_reshape(self): self._test_shard_op(lambda t:t.reshape(2, 2), [[1.,1.],[1.,1.]])
|
||||
def test_shard_elementwise(self): self._test_shard_op(lambda t:(t+t).reshape(2, 2), [[2.,2.],[2.,2.]])
|
||||
|
|
@ -654,54 +654,6 @@ class TestMultiTensor(unittest.TestCase):
|
|||
assert isinstance(jf.jit_cache[4].prg, BufferCopy)
|
||||
assert isinstance(jf.jit_cache[5].prg, graph_d1)
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_uneven_shard(self):
|
||||
for N in range(1, 6):
|
||||
X = Tensor.rand(4, 1, 257).contiguous().realize()
|
||||
n = X.numpy()
|
||||
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
|
||||
X.shard_(devices, 2)
|
||||
np.testing.assert_equal(X.numpy(), n)
|
||||
np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
|
||||
np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
|
||||
np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
|
||||
np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_uneven_multiple_zeros(self):
|
||||
for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
|
||||
for N in (1, 2, 3, 4):
|
||||
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
|
||||
# make sure something is computed on each device
|
||||
X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
|
||||
np.testing.assert_equal(X.numpy(), data)
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_uneven_shard_with_empty(self):
|
||||
N = 4
|
||||
X = Tensor.rand(16, 1, 3).contiguous().realize()
|
||||
np_x = X.numpy()
|
||||
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
|
||||
|
||||
# test empty shard
|
||||
np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)
|
||||
|
||||
# test reshape with empty shard
|
||||
np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_multiple_uneven_shard(self):
|
||||
N = 4
|
||||
X = Tensor.rand(4, 1, 257).contiguous().realize()
|
||||
Y = Tensor.rand(4, 1, 257).contiguous().realize()
|
||||
np_x, np_y = X.numpy(), Y.numpy()
|
||||
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
|
||||
X.shard_(devices, 2)
|
||||
Y.shard_(devices, 2)
|
||||
np.testing.assert_equal(X.numpy(), np_x)
|
||||
np.testing.assert_equal(Y.numpy(), np_y)
|
||||
np.testing.assert_equal((X + Y).numpy(), np_x + np_y)
|
||||
|
||||
def test_bn_ast_on_devices(self):
|
||||
t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
|
||||
bn = nn.BatchNorm2d(64)
|
||||
|
|
@ -752,34 +704,7 @@ class TestMultiTensor(unittest.TestCase):
|
|||
|
||||
# test no left join
|
||||
with self.assertRaises((AssertionError, ValueError)):
|
||||
t0.reshape((26*15,7)).schedule()
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_reshape_on_axis_uneven(self):
|
||||
def reshape_helper(t0, t, t_axis):
|
||||
assert t.uop.axis == t_axis
|
||||
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
|
||||
|
||||
t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])
|
||||
|
||||
# ok to reshape as long as elements remain on same device
|
||||
reshape_helper(t0, t0.reshape(2, 2, 42, 3, 5), 2)
|
||||
# split to the right
|
||||
reshape_helper(t0, t0.reshape(2, 2, 6, 7, 15), 2)
|
||||
# split off and merge to the right
|
||||
reshape_helper(t0, t0.reshape(4, 6, 105), 1)
|
||||
# really blend the axes together
|
||||
reshape_helper(t0, t0.reshape(4, 30, 21), 1)
|
||||
# split off 1-shape
|
||||
reshape_helper(t0, t0.reshape(4, 1, 42, 15), 2)
|
||||
reshape_helper(t0, t0.reshape(4, 6, 1, 7, 15), 1)
|
||||
|
||||
# assert if cannot maintain shard axis without moving items between devices
|
||||
with self.assertRaises(AssertionError): t0.reshape(4, 7, 6, 15)
|
||||
# assert for degenerate reshape
|
||||
with self.assertRaises(AssertionError): t0.reshape(4, 5, 7, 15)
|
||||
# assert for cannot maintain axis
|
||||
with self.assertRaises(AssertionError): t0.reshape(4, 3, 2, 7, 15)
|
||||
t0.reshape((26*15,7)).contiguous().schedule()
|
||||
|
||||
# it doesn't work like this anymore
|
||||
# NOTE: this never failed in assign_multi, it failed tensor spec because MULTI was never pushed in the graph
|
||||
|
|
@ -849,16 +774,6 @@ class TestMultiTensor(unittest.TestCase):
|
|||
self.assertEqual(rab.device, devices_4)
|
||||
self.assertEqual(rab.uop.axis, 0)
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_rand_like_uneven_shard(self):
|
||||
t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
|
||||
t2 = Tensor.rand_like(t)
|
||||
self.assertEqual(t.shape, t2.shape)
|
||||
self.assertEqual(t.device, t2.device)
|
||||
self.assertEqual(t.dtype, t2.dtype)
|
||||
self.assertEqual(t.uop.axis, t2.uop.axis)
|
||||
assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.uop.src, t2.uop.src))
|
||||
|
||||
def test_rand_like_none_shard(self):
|
||||
t = Tensor.empty((16, 16)).shard(devices_2)
|
||||
t2 = Tensor.rand_like(t)
|
||||
|
|
@ -894,6 +809,14 @@ class TestMultiTensor(unittest.TestCase):
|
|||
t2.realize()
|
||||
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
|
||||
|
||||
def test_full_like_shrink_on_shard_axis(self):
|
||||
t = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0)
|
||||
out = Tensor.full_like(t, 2)[:, :8]
|
||||
sched = out.schedule()
|
||||
self.assertEqual(len(sched), 2) # TODO: 0. fix mstack_early_shrink
|
||||
run_schedule(sched)
|
||||
self.assertEqual(out.tolist(), [[2]*8]*16)
|
||||
|
||||
def test_dropout_on_shard(self):
|
||||
with Tensor.train():
|
||||
X = Tensor.ones(256).to(devices_2)
|
||||
|
|
@ -910,15 +833,6 @@ class TestMultiTensor(unittest.TestCase):
|
|||
assert set(unique) == {0, 2}, unique
|
||||
assert 200 < counts[0] < 312, counts[0]
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_dropout_on_uneven_shard_axis(self):
|
||||
with Tensor.train():
|
||||
X = Tensor.ones(256).shard(devices_3, axis=0)
|
||||
output = X.dropout(0.5).numpy()
|
||||
unique, counts = np.unique(output, return_counts=True)
|
||||
assert set(unique) == {0, 2}, unique
|
||||
assert 100 < counts[0] < 156, counts[0]
|
||||
|
||||
@unittest.skip("TODO: this requires forced_realize to be deleted.")
|
||||
def test_shard_memory(self):
|
||||
devices = (d0, d1, d2, d3)
|
||||
|
|
@ -926,13 +840,15 @@ class TestMultiTensor(unittest.TestCase):
|
|||
t.shard_(devices, axis=0).realize()
|
||||
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src])
|
||||
|
||||
@unittest.skip("this is unreliable on OSX")
|
||||
def test_clone(self):
|
||||
t = Tensor.rand(16, 16).shard(devices_2, axis=None)
|
||||
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
|
||||
|
||||
t = Tensor.rand(16, 16).shard(devices_2, axis=0)
|
||||
np.testing.assert_allclose(t.numpy(), t.clone().numpy())
|
||||
for axis in (None, 0):
|
||||
t = Tensor.arange(16).reshape(4, 4).shard(devices_2, axis=axis).contiguous().realize()
|
||||
t_clone = t.clone().realize()
|
||||
self.assertEqual(t_clone.device, t.device)
|
||||
self.assertEqual(t_clone.uop.axis, axis)
|
||||
self.assertEqual(t_clone.tolist(), t.tolist())
|
||||
t_clone += 1
|
||||
self.assertNotEqual(t_clone.tolist(), t.tolist())
|
||||
|
||||
@unittest.skip("RANGEIFY doesn't support multi const folding")
|
||||
def test_multi_const_folding(self):
|
||||
|
|
@ -981,18 +897,18 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
|||
|
||||
with self.assertRaises(AssertionError):
|
||||
# sharded axis shrink on non-device boundry is not allowed
|
||||
a = t.shrink(((0, 3), (0, 8)))
|
||||
a.schedule()
|
||||
with self.assertRaises(AssertionError):
|
||||
# cannot shrink sharded and non-sharded axis at the same time
|
||||
a = t.shrink(((0, 2), (2, 4)))
|
||||
a = t.shrink(((0, 3), (0, 8))).contiguous()
|
||||
a.schedule()
|
||||
a = t.shrink(((0, 2), (2, 4)))
|
||||
assert a.shape == (2, 2)
|
||||
ref = Tensor.arange(64).reshape(8, 8).shrink(((0, 2), (2, 4)))
|
||||
np.testing.assert_equal(a.numpy(), ref.numpy())
|
||||
|
||||
a = t.shrink(((0, 2), (0, 8)))
|
||||
a = t.shrink(((0, 2), (0, 8))).contiguous()
|
||||
a.schedule()
|
||||
assert a.shape == (2, 8)
|
||||
|
||||
p = a.pad(((0, 6), (0, 0)))
|
||||
p = a.pad(((0, 6), (0, 0))).contiguous()
|
||||
p.schedule()
|
||||
assert p.shape == (8, 8)
|
||||
|
||||
|
|
@ -1042,24 +958,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
|||
np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
|
||||
np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_uneven(self):
|
||||
t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
|
||||
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
|
||||
|
||||
a = t.shrink(((0, 2), None))
|
||||
b = t.shrink(((2, 3), None))
|
||||
na = t.numpy()[0:2]
|
||||
nb = t.numpy()[2:3]
|
||||
np.testing.assert_equal(a.numpy(), na)
|
||||
np.testing.assert_equal(b.numpy(), nb)
|
||||
np.testing.assert_equal((a+1).numpy(), na+1)
|
||||
np.testing.assert_equal((b+1).numpy(), nb+1)
|
||||
np.testing.assert_equal((1+a).numpy(), 1+na)
|
||||
np.testing.assert_equal((1+b).numpy(), 1+nb)
|
||||
np.testing.assert_equal((a+a).numpy(), na+na)
|
||||
np.testing.assert_equal((b+b).numpy(), nb+nb)
|
||||
|
||||
def test_add_two_partitions(self):
|
||||
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
|
||||
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ from tinygrad.tensor import _to_np_dtype
|
|||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
||||
if getenv("TINY_BACKEND"):
|
||||
TINY_BACKEND = getenv("TINY_BACKEND")
|
||||
if TINY_BACKEND:
|
||||
import tinygrad.nn.torch # noqa: F401 # pylint: disable=unused-import
|
||||
torch.set_default_device("tiny")
|
||||
|
||||
|
|
@ -760,6 +761,7 @@ class TestOps(unittest.TestCase):
|
|||
data = [[1,-8,1],[32,1,6]]
|
||||
tor = torch.tensor(data, dtype=torch.int)
|
||||
ten = Tensor(data, dtype=dtypes.int32)
|
||||
# NOTE: this breaks assigns because it's folded to 0!
|
||||
helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True)
|
||||
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
|
||||
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
|
||||
|
|
|
|||
|
|
@ -1,230 +0,0 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, UOp, nn
|
||||
from tinygrad.uop.ops import AxisType, Ops
|
||||
|
||||
class TestOuterworldReduce(unittest.TestCase):
|
||||
def test_reduce(self):
|
||||
x = Tensor.ones(5, 5).contiguous()
|
||||
a = UOp.range(5, -1, AxisType.REDUCE)
|
||||
out = x[a]
|
||||
# TODO: syntax for this
|
||||
t = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
|
||||
self.assertListEqual(t.tolist(), [5.,5.,5.,5.,5.])
|
||||
|
||||
# TODO: delete test_outerworld_range?
|
||||
class TestOuterRange(unittest.TestCase):
|
||||
def test_simple_range(self):
|
||||
a = Tensor.ones(10).contiguous()
|
||||
acc = Tensor.zeros().contiguous()
|
||||
Tensor.realize(a, acc)
|
||||
|
||||
# this is fold
|
||||
i = UOp.range(10, -100, AxisType.OUTER)
|
||||
acc_i = acc.uop.after(i)
|
||||
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
|
||||
out = Tensor(acc.uop.after(acc_i.store(acc_i + a[vi].uop).end(i)))
|
||||
out.realize()
|
||||
assert out.item() == 10.0
|
||||
|
||||
def test_inner_range(self):
|
||||
a = Tensor.ones(10, 10).contiguous()
|
||||
acc = Tensor.zeros(10).contiguous()
|
||||
Tensor.realize(a, acc)
|
||||
|
||||
# this is fold
|
||||
i = UOp.range(10, -100, AxisType.OUTER)
|
||||
acc_i = acc.uop.after(i)
|
||||
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
|
||||
out = Tensor(acc.uop.after(acc_i.store(acc_i + a[:, vi].uop).end(i)))
|
||||
out.realize()
|
||||
self.assertEqual(out.tolist(), [10.0]*10)
|
||||
|
||||
def test_range_matmul(self):
|
||||
vec = Tensor.randn(1, 10).realize()
|
||||
mats = Tensor.randn(3, 10, 10).realize()
|
||||
|
||||
# 3 matmuls in "scan"
|
||||
ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
|
||||
ref.realize()
|
||||
|
||||
# 3 matmuls with outer world range
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
vec_i = Tensor(vec.uop.after(i))
|
||||
comp = vec_i.contiguous() @ mats[i]
|
||||
store = vec_i.uop.store(comp.uop).end(i)
|
||||
out = Tensor(vec.uop.after(store))
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-5), f"max diff {(ref-out).abs().max().item()}"
|
||||
|
||||
class TestOuterScan(unittest.TestCase):
|
||||
def _test_scan(self):
|
||||
vec = Tensor.randn(1, 10).realize()
|
||||
mats = Tensor.randn(3, 10, 10).realize()
|
||||
|
||||
# 3 matmuls in "scan"
|
||||
vec1 = vec @ mats[0]
|
||||
vec2 = vec1 @ mats[1]
|
||||
vec3 = vec2 @ mats[2]
|
||||
ref = Tensor.stack(vec1, vec2, vec3)
|
||||
ref.realize()
|
||||
return vec, mats, ref
|
||||
|
||||
def test_uop_scan_matmul(self):
|
||||
vec, mats, ref = self._test_scan()
|
||||
|
||||
# 3 matmuls with SCAN
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
out = Tensor.empty(3, 1, 10)
|
||||
phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
|
||||
comp = phi @ mats[i]
|
||||
store = out[i].uop.store(comp.uop).end(i)
|
||||
out = Tensor(out.uop.after(store))
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-5), f"max diff {(ref-out).abs().max().item()}"
|
||||
|
||||
class TestOuterworld(unittest.TestCase):
|
||||
def test_range_plus_1(self):
|
||||
t = Tensor.arange(100).reshape(10,10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[a] + 1
|
||||
assert sel.shape == (10,)
|
||||
cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize()
|
||||
|
||||
self.assertTrue((t+1==cpy).all().item())
|
||||
|
||||
def test_range_plus_1_transpose(self):
|
||||
t = Tensor.arange(100).reshape(10,10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[a] + 1
|
||||
assert sel.shape == (10,)
|
||||
cpy = sel.reshape(10, 1).expand(10, a).contiguous().realize()
|
||||
|
||||
self.assertTrue(((t+1).T==cpy).all().item())
|
||||
|
||||
def test_flip_range(self):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[9-a]
|
||||
cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize()
|
||||
|
||||
self.assertTrue((t.flip(0)==cpy).all().item())
|
||||
|
||||
def test_vmap(self):
|
||||
def f(x): return x.sum(axis=0)*2
|
||||
|
||||
x = Tensor.ones(3, 10, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1)
|
||||
out = f(x[a])
|
||||
out = out.reshape(1, 2).expand(a, 2).contiguous()
|
||||
|
||||
# 3x2 grid of 20
|
||||
out.realize()
|
||||
self.assertTrue((out==20).all().item())
|
||||
|
||||
def test_fancy_vmap(self):
|
||||
def f(x,y): return x+y
|
||||
|
||||
x = Tensor.arange(9).reshape(3,3).contiguous()
|
||||
y = Tensor.arange(9).reshape(3,3).contiguous()
|
||||
|
||||
a = UOp.range(3, -1)
|
||||
out = f(x[:,a], y[a,:])
|
||||
# TODO: this should support flatten
|
||||
out = out.reshape(1, 3).expand(a, 3).contiguous().realize()
|
||||
self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
|
||||
|
||||
class TestVmap(unittest.TestCase):
|
||||
def test_vmap_inner(self, axis_type=AxisType.LOOP, fuse=False, grad=False):
|
||||
x = Tensor.ones(1, 10).contiguous().requires_grad_()
|
||||
mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
|
||||
|
||||
ref = x @ mats
|
||||
if fuse: ref = ref * 2
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1, axis_type)
|
||||
out = x @ mats[a]
|
||||
out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
if fuse: out = out * 2
|
||||
if grad:
|
||||
out.mean().backward()
|
||||
np.testing.assert_allclose(mats.grad.numpy(), (2./30) if fuse else (1./30))
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-6), f"max diff {(ref-out).abs().max().item()}"
|
||||
def test_vmap_inner_fuse(self): self.test_vmap_inner(fuse=True)
|
||||
def test_vmap_outer(self): self.test_vmap_inner(AxisType.OUTER)
|
||||
def test_vmap_outer_fuse(self): self.test_vmap_inner(AxisType.OUTER, fuse=True)
|
||||
|
||||
def test_vmap_inner_grad(self): self.test_vmap_inner(grad=True)
|
||||
def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
|
||||
def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)
|
||||
|
||||
def test_vmap_convs(self):
|
||||
layers = [
|
||||
nn.Conv2d(1, 8, 3), Tensor.relu,
|
||||
nn.Conv2d(8, 8, 3), Tensor.relu]
|
||||
img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
|
||||
a = UOp.range(4, -1, AxisType.OUTER)
|
||||
out = img[a:a+1].sequential(layers)
|
||||
out = out.pad(((a,(4-a)-1), None, None, None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
out.realize()
|
||||
np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
|
||||
|
||||
def test_vmap_gemm(self):
|
||||
layers = [
|
||||
nn.Linear(16, 16, bias=False), Tensor.relu,
|
||||
nn.Linear(16, 16, bias=False), Tensor.relu]
|
||||
img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
|
||||
a = UOp.range(4, -1, AxisType.OUTER)
|
||||
out = img[a:a+1].sequential(layers)
|
||||
out = out.pad(((a,(4-a)-1), None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
out.realize()
|
||||
np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
|
||||
|
||||
@unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
|
||||
def test_vmap_gemm_grad(self):
|
||||
layers = [
|
||||
nn.Linear(16, 16, bias=False), Tensor.relu,
|
||||
nn.Linear(16, 16, bias=False), Tensor.relu]
|
||||
layer_tensors = nn.state.get_parameters(layers)
|
||||
img = Tensor.randn(4, 16).realize(*layer_tensors)
|
||||
for l in layer_tensors: l.requires_grad_()
|
||||
a = UOp.range(4, -1, AxisType.OUTER)
|
||||
out = img[a:a+1].sequential(layers)
|
||||
out = out.pad(((a,(4-a)-1), None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
out.mean().backward()
|
||||
grads = [l.grad for l in layer_tensors]
|
||||
out.realize(*grads)
|
||||
out_grads = [x.numpy() for x in grads]
|
||||
|
||||
# compute reference grads
|
||||
for l in layer_tensors: l.grad = None
|
||||
img.sequential(layers).mean().backward()
|
||||
grads = [l.grad for l in layer_tensors]
|
||||
out.realize(*grads)
|
||||
ref_grads = [x.numpy() for x in grads]
|
||||
|
||||
# compare
|
||||
for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor
|
||||
|
||||
class TestOuterCall(unittest.TestCase):
|
||||
def test_outer_call_assign(self):
|
||||
a = Tensor.zeros(10,10).contiguous()
|
||||
b = Tensor.ones(10,10).contiguous()
|
||||
Tensor.realize(a,b)
|
||||
|
||||
pa = a.as_param(0)
|
||||
pb = b.as_param(1)
|
||||
out = Tensor.call(a, b, fxn=pa.assign(pa+pb))
|
||||
out.realize()
|
||||
|
||||
print(a.numpy())
|
||||
assert (a == 1).all().item()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,148 +0,0 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, nn, Variable, UOp
|
||||
|
||||
# outerworld range should support three things
|
||||
# 1. full optimizer steps (test_model_bound_range)
|
||||
# 2. gradient accumulation (you want to end the range before running the optimizer)
|
||||
# 3. stacked linear layers
|
||||
|
||||
class Model:
|
||||
def __init__(self): self.w = nn.Linear(64, 8, bias=False)
|
||||
def __call__(self, x:Tensor) -> Tensor: return self.w(x)
|
||||
|
||||
def get_model_and_opt():
|
||||
Tensor.manual_seed(1337)
|
||||
m = Model()
|
||||
opt = nn.optim.SGD(nn.state.get_parameters(m), lr=0.1, weight_decay=0)
|
||||
return m, opt
|
||||
|
||||
class TestOuterworldRange(unittest.TestCase):
|
||||
STEPS = 5
|
||||
BS = 20
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
Tensor.manual_seed(1338)
|
||||
# it learns to compute mean
|
||||
cls.X = Tensor.randn(cls.STEPS, cls.BS, 64).contiguous().realize()
|
||||
cls.Y = cls.X.reshape(cls.STEPS, cls.BS, 8, 8).mean(axis=-1).contiguous().realize()
|
||||
cls.losses = cls._get_model_baseline()
|
||||
|
||||
def _compare(self, losses):
|
||||
for i,(x,y) in enumerate(zip(self.losses, losses)):
|
||||
self.assertAlmostEqual(x, y, places=5, msg=f"mismatch at {i} in {self.losses} vs {losses}")
|
||||
|
||||
@classmethod
|
||||
@Tensor.train()
|
||||
def _get_model_baseline(self):
|
||||
m, opt = get_model_and_opt()
|
||||
losses = []
|
||||
for i in range(self.STEPS):
|
||||
opt.zero_grad()
|
||||
loss = (m(self.X[i]) - self.Y[i]).square().mean()
|
||||
loss.backward()
|
||||
loss.realize(*opt.schedule_step())
|
||||
losses.append(loss.item())
|
||||
return losses
|
||||
|
||||
@Tensor.train()
|
||||
def test_model_grad_acc(self):
|
||||
m, opt = get_model_and_opt()
|
||||
losses = []
|
||||
for i in range(self.STEPS):
|
||||
opt.zero_grad()
|
||||
sub_batch_size = self.BS//2
|
||||
loss = 0
|
||||
scaling_factor = self.BS//sub_batch_size
|
||||
for j in range(0, self.BS, sub_batch_size):
|
||||
sub_loss = (m(self.X[i][j:j+sub_batch_size]) - self.Y[i][j:j+sub_batch_size]).square().mean() / scaling_factor
|
||||
sub_loss.backward()
|
||||
loss += sub_loss
|
||||
loss.realize(*opt.schedule_step())
|
||||
losses.append(loss.item())
|
||||
self._compare(losses)
|
||||
|
||||
@Tensor.train()
|
||||
def test_model_variable(self):
|
||||
m, opt = get_model_and_opt()
|
||||
losses = []
|
||||
vi = Variable('i', 0, self.STEPS-1)
|
||||
for i in range(self.STEPS):
|
||||
vib = vi.bind(i)
|
||||
opt.zero_grad()
|
||||
loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
|
||||
loss.backward()
|
||||
loss.realize(*opt.schedule_step())
|
||||
losses.append(loss.item())
|
||||
self._compare(losses)
|
||||
|
||||
@Tensor.train()
|
||||
def test_model_scheduled(self):
|
||||
m, opt = get_model_and_opt()
|
||||
losses = []
|
||||
for i in range(self.STEPS):
|
||||
opt.zero_grad()
|
||||
loss = (m(self.X[i]) - self.Y[i]).square().mean()
|
||||
loss.backward()
|
||||
opt.schedule_step()
|
||||
losses.append(loss)
|
||||
self._compare(Tensor.stack(*losses).tolist())
|
||||
|
||||
@Tensor.train()
|
||||
def test_model_scheduled_setitem(self):
|
||||
m, opt = get_model_and_opt()
|
||||
losses = Tensor.empty(self.STEPS)
|
||||
for i in range(self.STEPS):
|
||||
opt.zero_grad()
|
||||
loss = (m(self.X[i]) - self.Y[i]).square().mean()
|
||||
loss.backward()
|
||||
opt.schedule_step()
|
||||
# TODO: this shouldn't realize
|
||||
losses[i] = loss.requires_grad_(False)
|
||||
self._compare(losses.tolist())
|
||||
|
||||
@unittest.expectedFailure
|
||||
@Tensor.train()
|
||||
def test_model_scheduled_variable(self):
|
||||
m, opt = get_model_and_opt()
|
||||
losses = []
|
||||
vi = Variable('i', 0, self.STEPS-1)
|
||||
for i in range(self.STEPS):
|
||||
vib = vi.bind(i)
|
||||
opt.zero_grad()
|
||||
loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
|
||||
loss.backward()
|
||||
opt.schedule_step()
|
||||
losses.append(loss)
|
||||
self._compare(Tensor.stack(*losses).tolist())
|
||||
|
||||
@unittest.expectedFailure
|
||||
@Tensor.train()
|
||||
def test_model_scheduled_variable_setitem(self):
|
||||
m, opt = get_model_and_opt()
|
||||
losses = Tensor.empty(self.STEPS)
|
||||
vi = Variable('i', 0, self.STEPS-1)
|
||||
for i in range(self.STEPS):
|
||||
vib = vi.bind(i)
|
||||
opt.zero_grad()
|
||||
loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
|
||||
loss.backward()
|
||||
opt.schedule_step()
|
||||
losses[vib] = loss.requires_grad_(False)
|
||||
self._compare(losses.tolist())
|
||||
|
||||
@unittest.expectedFailure
|
||||
@Tensor.train()
|
||||
def test_model_bound_range(self):
|
||||
m, opt = get_model_and_opt()
|
||||
# TODO: should ranges be unique so you don't have to pass in the -1?
|
||||
rng = UOp.range(self.STEPS, -1)
|
||||
vib = Variable('i', 0, self.STEPS-1).bind(rng)
|
||||
loss = (m(self.X[vib]) - self.Y[vib]).square().mean()
|
||||
loss.backward()
|
||||
losses = Tensor.empty(self.STEPS)
|
||||
losses[vib] = loss
|
||||
losses.realize(*opt.schedule_step())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
import unittest, struct, contextlib, statistics, time, gc
|
||||
import unittest, struct, contextlib, statistics, gc
|
||||
from tinygrad import Device, Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import CI, getenv, Context, ProfileRangeEvent, cpu_profile, cpu_events, ProfilePointEvent, dedup
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileDeviceEvent, ProfileGraphEvent
|
||||
|
|
@ -20,7 +20,7 @@ def helper_collect_profile(*devs):
|
|||
cpu_events.clear()
|
||||
|
||||
profile_list = []
|
||||
with Context(VIZ=1, PROFILE=1):
|
||||
with Context(PROFILE=1):
|
||||
yield profile_list
|
||||
for dev in devs: dev.synchronize()
|
||||
for dev in devs: dev._at_profile_finalize()
|
||||
|
|
@ -170,30 +170,19 @@ class TestProfiler(unittest.TestCase):
|
|||
for (i1, d1), (i2, d2) in pairs:
|
||||
assert abs(jitter_matrix[i1][i2]) < 0.5, "jitter should be less than 0.5us"
|
||||
|
||||
@unittest.skip("this test is flaky")
|
||||
def test_cpu_profile(self):
|
||||
def test_fxn(err=False):
|
||||
time.sleep(0.1)
|
||||
if err: raise Exception()
|
||||
time.sleep(0.1)
|
||||
|
||||
with helper_collect_profile(dev:=TestProfiler.d0) as profile:
|
||||
with cpu_profile("test_1", dev.device):
|
||||
with cpu_profile("test_1", dev):
|
||||
test_fxn(err=False)
|
||||
with self.assertRaises(Exception):
|
||||
with cpu_profile("test_2", dev.device):
|
||||
with cpu_profile("test_2", dev):
|
||||
test_fxn(err=True)
|
||||
|
||||
range_events = [p for p in profile if isinstance(p, ProfileRangeEvent)]
|
||||
range_events = [p for p in profile if isinstance(p, ProfileRangeEvent) and p.device == dev]
|
||||
self.assertEqual(len(range_events), 2)
|
||||
# record start/end time up to exit (error or success)
|
||||
for e in range_events:
|
||||
self.assertGreater(e.en, e.st)
|
||||
e1, e2 = range_events
|
||||
self.assertEqual([e1.name, e2.name], ["test_1", "test_2"])
|
||||
# TODO: this is flaky
|
||||
#self.assertLess(e1.st, e2.st)
|
||||
#self.assertGreater(e1.en-e1.st, e2.en-e2.st)
|
||||
|
||||
@unittest.skip("this test is flaky")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# schedule confirms the right things are capable of fusing
|
||||
# NOTE: this has overlap with external_test_opt.py
|
||||
|
||||
import unittest, functools
|
||||
import gc, unittest, functools
|
||||
import numpy as np
|
||||
from typing import cast
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
|
|
@ -168,13 +168,13 @@ class TestSchedule(unittest.TestCase):
|
|||
a = Tensor.full((4,), 4.0).contiguous().realize()
|
||||
b = Tensor.full((4,), 2.0).contiguous().realize()
|
||||
expr = (a*b)/b
|
||||
check_schedule(expr, 0)
|
||||
run_schedule(check_schedule(expr, 1))
|
||||
np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
|
||||
|
||||
def test_div_collapse_const(self):
|
||||
a = Tensor.full((4,), 4.0).contiguous().realize()
|
||||
expr = a/a
|
||||
check_schedule(expr, 0)
|
||||
run_schedule(check_schedule(expr, 1))
|
||||
np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
|
||||
|
||||
def test_div_collapse(self):
|
||||
|
|
@ -747,7 +747,7 @@ class TestSchedule(unittest.TestCase):
|
|||
p = P[0]
|
||||
p = p.pad(((1, 0), ))
|
||||
p = p.repeat([2])
|
||||
run_schedule(check_schedule(p, 3))
|
||||
run_schedule(check_schedule(p, 4)) # TODO: this is high
|
||||
tiny_ret = p.numpy()
|
||||
|
||||
P = np.ones((3, 3), dtype=np.float32)
|
||||
|
|
@ -775,11 +775,12 @@ class TestSchedule(unittest.TestCase):
|
|||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail")
|
||||
def test_conv2d_fused_half(self): _test_conv2d(4, dtype=dtypes.half)
|
||||
|
||||
@unittest.skip("TODO: this is consistently creating non reproducible failures")
|
||||
def test_schedule_mem_used_with_inputs(self):
|
||||
gc.collect()
|
||||
base = GlobalCounters.mem_used
|
||||
x = Tensor.ones(256).contiguous().realize()
|
||||
(x+Tensor.ones(256).contiguous()).schedule()
|
||||
gc.collect()
|
||||
self.assertEqual(GlobalCounters.mem_used-base, 1024)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
|
||||
|
|
@ -840,10 +841,9 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_cast_const_view(self):
|
||||
a = Tensor.ones((4, 4), dtype=dtypes.float32)
|
||||
casted_view = a.cast(dtypes.int32)
|
||||
run_schedule(check_schedule(casted_view, 0))
|
||||
self.assertIsNone(casted_view.uop.base.realized)
|
||||
run_schedule(check_schedule(casted_view, 1))
|
||||
realized_const_view = casted_view.contiguous()
|
||||
run_schedule(check_schedule(realized_const_view, 1))
|
||||
run_schedule(check_schedule(realized_const_view, 0))
|
||||
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
|
||||
|
||||
@given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
|
||||
|
|
@ -1036,7 +1036,7 @@ class TestSchedule(unittest.TestCase):
|
|||
idx = Tensor([1,2,5,6], dtype=dtypes.int32)
|
||||
flat_base[idx] = Tensor([99,99,99,99])
|
||||
base.assign(flat_base.reshape(4, 4))
|
||||
sched = check_schedule(base, 2)
|
||||
sched = check_schedule(base, 6) # TODO: this is high
|
||||
run_schedule(sched)
|
||||
expected = list(range(16))
|
||||
for i, v in zip([1,2,5,6], [99,99,99,99]): expected[i] = v
|
||||
|
|
@ -1235,11 +1235,11 @@ class TestView(unittest.TestCase):
|
|||
bv = b.pad(((0, 2),))[-2:]
|
||||
# this becomes a late a*0
|
||||
late_mul = a*bv
|
||||
check_schedule(late_mul, 0)
|
||||
run_schedule(check_schedule(late_mul, 2))
|
||||
# the arange doesn't realize
|
||||
self.assertIsNone(b.uop.base.realized)
|
||||
#self.assertIsNone(b.uop.base.realized)
|
||||
# mul doesn't realize
|
||||
self.assertIsNone(late_mul.uop.base.realized)
|
||||
#self.assertIsNone(late_mul.uop.base.realized)
|
||||
self.assertEqual(late_mul.tolist(), [0, 0])
|
||||
|
||||
# SINK has two branches:
|
||||
|
|
@ -1252,20 +1252,21 @@ class TestView(unittest.TestCase):
|
|||
bv = b.pad(((0, 2),))[-2:]
|
||||
late_mul = a*bv
|
||||
other_child = b+2
|
||||
s = check_schedule([late_mul, other_child], 2)
|
||||
s = check_schedule([late_mul, other_child], 3)
|
||||
# the arange becomes a BUFFER
|
||||
self.assertIs(b.uop.base.op, Ops.BUFFER)
|
||||
# NOTE: no longer checked
|
||||
# mul still collapses
|
||||
self.assertIs(late_mul.uop.base.op, Ops.CONST)
|
||||
#self.assertIs(late_mul.uop.base.op, Ops.CONST)
|
||||
run_schedule(s)
|
||||
self.assertEqual(other_child.tolist(), [2, 3, 4])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu")
|
||||
class TestCopyFolding(unittest.TestCase):
|
||||
def test_const_copy_is_free(self):
|
||||
b = Tensor(1).to("CPU")
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
assert b.item() == 1
|
||||
b = Tensor(1).to("CPU") * 4
|
||||
run_schedule(check_schedule(b, 1, filter_sink=False))
|
||||
assert b.item() == 4
|
||||
|
||||
def test_one_hot_with_copy(self):
|
||||
y = Tensor([1, 2, 3]).to("CPU")
|
||||
|
|
@ -1273,16 +1274,16 @@ class TestCopyFolding(unittest.TestCase):
|
|||
check_schedule(x, 3, filter_sink=False)
|
||||
|
||||
def test_const_copy_multi(self):
|
||||
x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"])
|
||||
check_schedule(x, 0, filter_sink=False)
|
||||
self.assertEqual(x.item(), 1)
|
||||
x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"]) * 2
|
||||
run_schedule(check_schedule(x, 2, filter_sink=False))
|
||||
self.assertEqual(x.item(), 2.0)
|
||||
|
||||
def test_late_const_copy_folding(self):
|
||||
a = Tensor.arange(3).realize()
|
||||
zeros = Tensor.zeros(3).realize()
|
||||
b = (a*zeros).to("CPU")
|
||||
run_schedule(check_schedule(b, 0, filter_sink=False))
|
||||
self.assertListEqual(b.tolist(), [0, 0, 0])
|
||||
b = (a*zeros).to("CPU") + 1
|
||||
run_schedule(check_schedule(b, 1, filter_sink=False))
|
||||
self.assertListEqual(b.tolist(), [1, 1, 1])
|
||||
self.assertEqual(b.device, "CPU")
|
||||
|
||||
def test_alu_after_copy(self):
|
||||
|
|
@ -1321,7 +1322,7 @@ class TestCopyFolding(unittest.TestCase):
|
|||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
# use copy_to_device to bypass Tensor.to() shortcircuit and force a real same-device COPY in the graph
|
||||
a.assign(Tensor(a.uop.copy_to_device(a.device), a.device))
|
||||
run_schedule(check_schedule(a, 0, filter_sink=False))
|
||||
run_schedule(check_schedule(a, 2, filter_sink=False))
|
||||
self.assertListEqual(a.tolist(), [[1.]*4]*4)
|
||||
|
||||
def test_clone(self):
|
||||
|
|
|
|||
|
|
@ -113,6 +113,12 @@ class TestFloatUOps(TestUOps):
|
|||
def test_max(self): self._test_bop_fxn(Ops.MAX, lambda a,b: max(a,b))
|
||||
def test_cmplt(self): self._test_bop_fxn(Ops.CMPLT, lambda a,b: a<b)
|
||||
def test_cmpne(self): self._test_bop_fxn(Ops.CMPNE, lambda a,b: a!=b)
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support NaN comparison correctly")
|
||||
def test_cmpne_nan(self): # NaN != x for any x (IEEE 754), fixes #14095
|
||||
for a, b in [(math.nan, 1.0), (1.0, math.nan), (math.nan, math.nan)]:
|
||||
self.assertTrue(_test_single_value(
|
||||
[dtypes.as_const(a, dtypes.float32), dtypes.as_const(b, dtypes.float32)],
|
||||
Ops.CMPNE, (dtypes.float32, dtypes.float32)))
|
||||
# MOD isn't tested on floats
|
||||
|
||||
def test_where(self):
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class TestHCQ(unittest.TestCase):
|
|||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU"}, "Can't handle async update on CPU device")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU"} or getenv("AMD_IFACE", "") == "PCI", "Can't handle async update on CPU/MOCKAM device")
|
||||
def test_wait_late_set(self):
|
||||
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
|
||||
if queue_type is None: continue
|
||||
|
|
@ -538,7 +538,7 @@ class TestHCQ(unittest.TestCase):
|
|||
|
||||
np.testing.assert_equal(cpu_buffer.numpy(), local_buf.numpy(), "failed")
|
||||
|
||||
@unittest.skipUnless(MOCKGPU, "Emulate this on MOCKGPU to check the path in CI")
|
||||
@unittest.skipUnless(MOCKGPU and getenv("AMD_IFACE", "") != "PCI", "Emulate this on MOCKGPU to check the path in CI")
|
||||
def test_on_device_hang(self):
|
||||
if not hasattr(self.d0, 'on_device_hang'): self.skipTest("device does not have on_device_hang")
|
||||
|
||||
|
|
|
|||
2
test/external/external_test_onnx_runner.py
vendored
2
test/external/external_test_onnx_runner.py
vendored
|
|
@ -56,10 +56,12 @@ class TestOnnxRunner(unittest.TestCase):
|
|||
output = runner({'inp': Tensor([1, 2, 3, 4])})['output']
|
||||
_check_ast_count(0, output)
|
||||
|
||||
@unittest.skip("const folding is removed")
|
||||
def test_const_fold_from_disk(self):
|
||||
self._test_const_fold_unary_op(True)
|
||||
self._test_const_fold_binary_op(True)
|
||||
|
||||
@unittest.skip("const folding is removed")
|
||||
def test_const_fold_from_memory(self):
|
||||
self._test_const_fold_unary_op(False)
|
||||
# TODO: understand this and fix this, bitcast related
|
||||
|
|
|
|||
13
test/external/process_replay/process_replay.py
vendored
13
test/external/process_replay/process_replay.py
vendored
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
# compare kernels created by HEAD against master
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools, base64, codecs
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, functools, base64, codecs
|
||||
from dataclasses import replace
|
||||
from typing import Callable, Any
|
||||
|
||||
|
|
@ -8,7 +8,6 @@ ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in
|
|||
if not int(os.getenv("ASSERT_PROCESS_REPLAY", "1")): ASSERT_DIFF = 0
|
||||
|
||||
try:
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
|
|
@ -43,14 +42,6 @@ class ProcessReplayWarning(Warning): pass
|
|||
|
||||
# *** replay the function and convert return values to string
|
||||
|
||||
def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
|
||||
UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
|
||||
new_sink = big_sink.substitute(get_rangeify_map(big_sink))
|
||||
def to_str(ret:UOp) -> str:
|
||||
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.CALL]
|
||||
return "\n".join([f"{len(asts)} kernels", *asts])
|
||||
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
|
||||
|
||||
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
|
||||
# the ast.arg is non None if we are inside of search.py
|
||||
sink_arg = ast.arg or KernelInfo()
|
||||
|
|
@ -68,8 +59,6 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]
|
|||
|
||||
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {}
|
||||
replayers["get_program"] = replay_get_program
|
||||
# disable this for speed, does it ever find things?
|
||||
#replayers["get_rangeify_map"] = replay_get_rangeify_map
|
||||
|
||||
# *** run replayers on captured rows and print diffs
|
||||
|
||||
|
|
|
|||
0
test/mockgpu/am/__init__.py
Normal file
0
test/mockgpu/am/__init__.py
Normal file
122
test/mockgpu/am/amdriver.py
Normal file
122
test/mockgpu/am/amdriver.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import ctypes, ctypes.util, mmap, functools
|
||||
from test.mockgpu.driver import VirtDriver, VirtFileDesc, TextFileDesc, DirFileDesc, VirtFile
|
||||
from test.mockgpu.am.amgpu import MockAMGPU, VRAM_SIZE
|
||||
|
||||
DOORBELL_SIZE = 0x2000
|
||||
BAR5_SIZE = (512 << 20)
|
||||
PCIBUS = "mock:am:0"
|
||||
|
||||
libc = ctypes.CDLL(ctypes.util.find_library("c"))
|
||||
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
|
||||
libc.mmap.restype = ctypes.c_void_p
|
||||
|
||||
_empty_bar = "0x0000000000000000 0x0000000000000000 0x0000000000000000"
|
||||
_resource_lines = [
|
||||
f"0x0000000000000000 0x{VRAM_SIZE-1:016x} 0x0000000000000000", _empty_bar,
|
||||
f"0x0000000000000000 0x{DOORBELL_SIZE-1:016x} 0x0000000000000000", _empty_bar, _empty_bar,
|
||||
f"0x0000000000000000 0x{BAR5_SIZE-1:016x} 0x0000000000000000", _empty_bar,
|
||||
]
|
||||
|
||||
class PagemapFileDesc(VirtFileDesc):
|
||||
def __init__(self, fd, gpu):
|
||||
super().__init__(fd)
|
||||
self.gpu = gpu
|
||||
def seek(self, offset): self.off = offset
|
||||
def read_contents(self, size=None):
|
||||
entries = bytearray()
|
||||
for i in range((size or 8) // 8):
|
||||
vaddr = ((self.off // 8) + i) * 0x1000
|
||||
paddr = self.gpu._next_sysmem_paddr
|
||||
self.gpu._next_sysmem_paddr += 0x1000
|
||||
self.gpu._sysmem_map[paddr] = vaddr
|
||||
entries += ((1 << 63) | (paddr // 0x1000)).to_bytes(8, 'little')
|
||||
self.off += len(entries)
|
||||
return bytes(entries)
|
||||
|
||||
class PCIBarFileDesc(VirtFileDesc):
|
||||
def __init__(self, fd, memfd, driver=None):
|
||||
super().__init__(fd)
|
||||
self.memfd, self.driver = memfd, driver
|
||||
def mmap(self, start, sz, prot, flags, fd, off):
|
||||
addr = libc.mmap(start, sz, prot, flags, self.memfd, off)
|
||||
if self.driver is not None:
|
||||
self.driver.track_address(addr, addr + sz, lambda mv, idx: None, lambda mv, idx: self.driver._emulate_execute())
|
||||
return addr
|
||||
|
||||
class PCIMMIOBarFileDesc(VirtFileDesc):
|
||||
def __init__(self, fd, bar5_addr):
|
||||
super().__init__(fd)
|
||||
self.bar5_addr = bar5_addr
|
||||
def mmap(self, start, sz, prot, flags, fd, off): return self.bar5_addr + off
|
||||
|
||||
class PCIConfigFileDesc(VirtFileDesc):
|
||||
def __init__(self, fd):
|
||||
super().__init__(fd)
|
||||
self.data = bytearray(256)
|
||||
def read_contents(self, size=None): return bytes(self.data[self.off:self.off + (size or len(self.data) - self.off)])
|
||||
def write_contents(self, content): self.data[self.off:self.off + len(content)] = content
|
||||
def seek(self, offset): self.off = offset
|
||||
|
||||
class PCIEnableFileDesc(VirtFileDesc):
|
||||
def __init__(self, fd): super().__init__(fd)
|
||||
def read_contents(self, size=None): return "1\n"
|
||||
def write_contents(self, content): pass
|
||||
|
||||
class AMDriver(VirtDriver):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gpus:dict[int, MockAMGPU] = {}
|
||||
self._executing = False
|
||||
self.gpu = MockAMGPU(0)
|
||||
self.gpus[0] = self.gpu
|
||||
self.next_fd = 1 << 30
|
||||
|
||||
self._bar5_addr = libc.mmap(0, BAR5_SIZE, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | mmap.MAP_ANONYMOUS, -1, 0)
|
||||
mmio = self.gpu.mmio
|
||||
self.track_address(self._bar5_addr, self._bar5_addr + BAR5_SIZE,
|
||||
lambda mv, idx: _bar5_sync_read(mv, idx, mmio), lambda mv, idx: _bar5_sync_write(mv, idx, mmio))
|
||||
|
||||
p = f"/sys/bus/pci/devices/{PCIBUS}"
|
||||
self.tracked_files += [
|
||||
VirtFile("/proc/sys/vm/compact_unevictable_allowed", functools.partial(TextFileDesc, text="0\n")),
|
||||
VirtFile("/proc/self/pagemap", functools.partial(PagemapFileDesc, gpu=self.gpu)),
|
||||
VirtFile("/sys/bus/pci/devices", functools.partial(DirFileDesc, child_names=[PCIBUS])),
|
||||
VirtFile(f"{p}/vendor", functools.partial(TextFileDesc, text="0x1002\n")),
|
||||
VirtFile(f"{p}/device", functools.partial(TextFileDesc, text="0x74a1\n")),
|
||||
VirtFile(f"{p}/enable", PCIEnableFileDesc),
|
||||
VirtFile(f"{p}/config", PCIConfigFileDesc),
|
||||
VirtFile(f"{p}/resource", functools.partial(TextFileDesc, text="\n".join(_resource_lines) + "\n")),
|
||||
VirtFile(f"{p}/resource0", functools.partial(PCIBarFileDesc, memfd=self.gpu.vram_fd)),
|
||||
VirtFile(f"{p}/resource2", functools.partial(PCIBarFileDesc, memfd=self.gpu.doorbell_fd, driver=self)),
|
||||
VirtFile(f"{p}/resource5", functools.partial(PCIMMIOBarFileDesc, bar5_addr=self._bar5_addr)),
|
||||
]
|
||||
|
||||
def _alloc_fd(self):
|
||||
fd = self.next_fd
|
||||
self.next_fd += 1
|
||||
return fd
|
||||
|
||||
def open(self, name, flags, mode, virtfile): return virtfile.fdcls(self._alloc_fd())
|
||||
|
||||
def _emulate_execute(self):
|
||||
if self._executing: return
|
||||
self._executing = True
|
||||
try:
|
||||
any_progress = True
|
||||
while any_progress:
|
||||
any_progress = False
|
||||
for gpu in self.gpus.values():
|
||||
for q in gpu.queues:
|
||||
if q.executing: any_progress |= q.execute() > 0
|
||||
finally:
|
||||
self._executing = False
|
||||
|
||||
def _bar5_sync_read(mv, idx, mmio):
|
||||
if isinstance(idx, slice):
|
||||
for i in range(idx.start or 0, idx.stop or len(mv), idx.step or 1): mv[i] = mmio[i]
|
||||
else: mv[idx] = mmio[idx]
|
||||
|
||||
def _bar5_sync_write(mv, idx, mmio):
|
||||
if isinstance(idx, slice):
|
||||
for i in range(idx.start or 0, idx.stop or len(mv), idx.step or 1): mmio[i] = mv[i]
|
||||
else: mmio[idx] = mv[idx]
|
||||
309
test/mockgpu/am/amgpu.py
Normal file
309
test/mockgpu/am/amgpu.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
# mypy: ignore-errors
|
||||
import ctypes, ctypes.util, struct, functools, os, mmap
|
||||
from tinygrad.runtime.autogen.am import am
|
||||
from tinygrad.runtime.support.amd import AMDReg, import_asic_regs
|
||||
from test.mockgpu.amd.amdgpu import AMDGPU
|
||||
|
||||
libc = ctypes.CDLL(ctypes.util.find_library("c"))
|
||||
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
|
||||
libc.mmap.restype = ctypes.c_void_p
|
||||
|
||||
VRAM_SIZE = 512 << 20
|
||||
|
||||
IP_VERSIONS = {
|
||||
am.GC_HWIP: (12, 0, 0), am.SDMA0_HWIP: (7, 0, 0), am.MMHUB_HWIP: (4, 1, 0), am.NBIO_HWIP: (6, 3, 1),
|
||||
am.MP0_HWIP: (14, 0, 2), am.MP1_HWIP: (14, 0, 2), am.HDP_HWIP: (7, 0, 0), am.OSSSYS_HWIP: (7, 0, 0),
|
||||
}
|
||||
|
||||
def _pad(t, n=10): return t + (0,) * (n - len(t))
|
||||
IP_BASES = {
|
||||
am.GC_HWIP: _pad((0x00001260, 0x0000A000, 0x0001C000, 0x02402C00)),
|
||||
am.SDMA0_HWIP: _pad((0x00001260, 0x0000A000, 0x0001C000, 0x02402C00)),
|
||||
am.MMHUB_HWIP: _pad((0x0001A000, 0x02408800)),
|
||||
am.NBIO_HWIP: _pad((0x00000000, 0x00000014, 0x00000D20, 0x00010400, 0x0241B000, 0x04040000)),
|
||||
am.MP0_HWIP: _pad((0x00016000, 0x00DC0000, 0x00E00000, 0x00E40000, 0x0243FC00)),
|
||||
am.MP1_HWIP: _pad((0x00016000, 0x00DC0000, 0x00E00000, 0x00E40000, 0x0243FC00)),
|
||||
am.HDP_HWIP: _pad((0x00000F20, 0x0240A400)),
|
||||
am.OSSSYS_HWIP: _pad((0x000010A0, 0x0240A000)),
|
||||
}
|
||||
|
||||
IP_HWIDS = {hwip: am.hw_id_map[hwip] for hwip in IP_VERSIONS}
|
||||
|
||||
GC_INFO = dict(gc_num_se=2, gc_num_cu_per_sh=8, gc_num_sh_per_se=2, gc_num_rb_per_se=4,
|
||||
gc_num_tccs=8, gc_wave_size=32, gc_max_waves_per_simd=16, gc_max_scratch_slots_per_cu=32, gc_lds_size=64)
|
||||
|
||||
def _build_ip_regs(prefix, hwip) -> dict[str, AMDReg]:
|
||||
try: return import_asic_regs(prefix, IP_VERSIONS[hwip], cls=functools.partial(AMDReg, bases={0: IP_BASES[hwip]}))
|
||||
except Exception: return {}
|
||||
|
||||
class MockMMU:
|
||||
def __init__(self, gpu:'MockAMGPU'):
|
||||
self.gpu = gpu
|
||||
self.tlb: dict[int, tuple[int, int, bool]] = {}
|
||||
|
||||
def invalidate(self, pt_base:int, va_base:int):
|
||||
new_tlb: dict[int, tuple[int, int, bool]] = {}
|
||||
self._walk(pt_base, 0, 0, new_tlb, va_base)
|
||||
for va, (pa, sz, is_sys) in new_tlb.items():
|
||||
if va not in self.tlb:
|
||||
if not is_sys: self.gpu.map_vram_at(va, pa, sz)
|
||||
self.gpu.map_range(va, sz)
|
||||
self.tlb = new_tlb
|
||||
|
||||
def _walk(self, pt_paddr:int, level:int, va_acc:int, out:dict, va_base:int):
|
||||
shift = [39, 30, 21, 12][level]
|
||||
for i in range(512):
|
||||
pte = struct.unpack_from('<Q', self.gpu.vram, pt_paddr + i * 8)[0]
|
||||
if not (pte & am.AMDGPU_PTE_VALID): continue
|
||||
va, pa = va_acc | (i << shift), pte & 0x0000FFFFFFFFF000
|
||||
if level == 3 or (pte & am.AMDGPU_PDE_PTE_GFX12):
|
||||
out[va_base + va] = (pa, 1 << shift, bool(pte & am.AMDGPU_PTE_SYSTEM))
|
||||
else:
|
||||
self._walk(pa, level + 1, va, out, va_base)
|
||||
|
||||
def paddr_to_host(self, paddr:int) -> int:
|
||||
if paddr < VRAM_SIZE: return self.gpu.vram_addr + paddr
|
||||
page, off = paddr & ~0xFFF, paddr & 0xFFF
|
||||
return self.gpu._sysmem_map[page] + off
|
||||
|
||||
def addr_to_host(self, addr:int) -> int:
|
||||
gmc = self.gpu.mmio.gmc
|
||||
sys_lo = self.gpu.mmio.regs.get(gmc.reg('regMMMC_VM_SYSTEM_APERTURE_LOW_ADDR') or 0, 0) << 18
|
||||
sys_hi = self.gpu.mmio.regs.get(gmc.reg('regMMMC_VM_SYSTEM_APERTURE_HIGH_ADDR') or 0, 0) << 18
|
||||
if sys_lo <= addr < sys_hi: return self.paddr_to_host(addr - self.gpu.mc_base)
|
||||
for tva, (pa, sz, is_sys) in self.tlb.items():
|
||||
if tva <= addr < tva + sz:
|
||||
if not is_sys: return addr
|
||||
return self.paddr_to_host(pa + (addr - tva))
|
||||
raise ValueError(f"addr {addr:#x} not mapped (sys_aperture=[{sys_lo:#x}, {sys_hi:#x}])")
|
||||
|
||||
class MockIPBlock:
|
||||
def __init__(self, gpu:'MockAMGPU', mmio:'MockMMIOInterface', regs:dict[str, AMDReg]):
|
||||
self.gpu, self.mmio, self._regs = gpu, mmio, regs
|
||||
self._n2a = {n: r.addr[0] for n, r in regs.items()}
|
||||
self._a2n = {a: n for n, a in self._n2a.items()}
|
||||
self.addrs = set(self._n2a.values())
|
||||
def reg(self, name) -> int|None: return self._n2a.get(name)
|
||||
def decode(self, name) -> dict: return self._regs[name].decode(self.mmio.regs.get(self._n2a[name], 0))
|
||||
def read(self, reg:int) -> int: return self.mmio.regs.get(reg, 0)
|
||||
def write(self, reg:int, val:int): self.mmio.regs[reg] = val
|
||||
def _read_pair(self, pair) -> int:
|
||||
if pair[0] is None: return 0
|
||||
return self.mmio.regs.get(pair[0], 0) | (self.mmio.regs.get(pair[1], 0) << 32)
|
||||
|
||||
class MockPSP(MockIPBlock):
|
||||
def __init__(self, gpu, mmio):
|
||||
super().__init__(gpu, mmio, _build_ip_regs('mp', am.MP0_HWIP))
|
||||
self._sos_alive, self._ring_wptr = False, 0
|
||||
pref = "regMPASP_SMN_C2PMSG" if IP_VERSIONS[am.MP0_HWIP] >= (14,0,0) else "regMP0_SMN_C2PMSG"
|
||||
def r(n): return self.reg(f"{pref}_{n}")
|
||||
self._c2pmsg_35, self._c2pmsg_64, self._c2pmsg_67 = r(35), r(64), r(67)
|
||||
self._c2pmsg_69, self._c2pmsg_70, self._c2pmsg_81 = r(69), r(70), r(81)
|
||||
|
||||
def read(self, reg:int) -> int:
|
||||
if reg == self._c2pmsg_35: return 0x80000000
|
||||
if reg == self._c2pmsg_81: return 0x1 if self._sos_alive else 0x0
|
||||
if reg == self._c2pmsg_64: return 0x80000000 if self._sos_alive else 0x0
|
||||
if reg == self._c2pmsg_67: return self._ring_wptr
|
||||
return super().read(reg)
|
||||
|
||||
def write(self, reg:int, val:int):
|
||||
super().write(reg, val)
|
||||
if reg == self._c2pmsg_35 and val == am.PSP_BL__LOAD_SOSDRV: self._sos_alive = True
|
||||
if reg == self._c2pmsg_67: self._ring_submit(val)
|
||||
|
||||
def _ring_submit(self, new_wptr:int):
|
||||
old_wptr = self._ring_wptr
|
||||
self._ring_wptr = new_wptr
|
||||
lo, hi = self._c2pmsg_69, self._c2pmsg_70
|
||||
if lo is None or hi is None: return
|
||||
ring_mc = self.mmio.regs.get(lo, 0) | (self.mmio.regs.get(hi, 0) << 32)
|
||||
ring_paddr = ring_mc - self.gpu.mc_base
|
||||
frame_off = ring_paddr + old_wptr * 4
|
||||
frame = am.struct_psp_gfx_rb_frame.from_buffer_copy(bytes(self.gpu.vram[frame_off:frame_off + ctypes.sizeof(am.struct_psp_gfx_rb_frame)]))
|
||||
fence_paddr = ((frame.fence_addr_hi << 32) | frame.fence_addr_lo) - self.gpu.mc_base
|
||||
if 0 <= fence_paddr < len(self.gpu.vram):
|
||||
struct.pack_into('<I', self.gpu.vram, fence_paddr, frame.fence_value)
|
||||
cmd_paddr = ((frame.cmd_buf_addr_hi << 32) | frame.cmd_buf_addr_lo) - self.gpu.mc_base
|
||||
if 0 <= cmd_paddr < len(self.gpu.vram):
|
||||
struct.pack_into('<I', self.gpu.vram, cmd_paddr + 864, 0)
|
||||
|
||||
class MockSMU(MockIPBlock):
|
||||
def __init__(self, gpu, mmio):
|
||||
try: regs = import_asic_regs('mp', (11, 0), cls=functools.partial(AMDReg, bases={0: IP_BASES[am.MP1_HWIP]}))
|
||||
except Exception: regs = {}
|
||||
super().__init__(gpu, mmio, regs)
|
||||
self._msg_pending = False
|
||||
def r(n): return self.reg(f"mmMP1_SMN_C2PMSG_{n}")
|
||||
self._c2pmsg_53, self._c2pmsg_54, self._c2pmsg_66 = r(53), r(54), r(66)
|
||||
self._c2pmsg_75, self._c2pmsg_82, self._c2pmsg_90 = r(75), r(82), r(90)
|
||||
|
||||
def read(self, reg:int) -> int:
|
||||
if reg == self._c2pmsg_90 or reg == self._c2pmsg_54: return 0x1 if self._msg_pending else super().read(reg)
|
||||
if reg == self._c2pmsg_82: return self.mmio.regs.get(reg, 3)
|
||||
return super().read(reg)
|
||||
|
||||
def write(self, reg:int, val:int):
|
||||
super().write(reg, val)
|
||||
if reg == self._c2pmsg_66 or reg == self._c2pmsg_75: self._msg_pending = True
|
||||
if (reg == self._c2pmsg_90 or reg == self._c2pmsg_54) and val == 0: self._msg_pending = False
|
||||
|
||||
class MockSDMA(MockIPBlock):
|
||||
def __init__(self, gpu, mmio):
|
||||
all_gc = _build_ip_regs('gc', am.GC_HWIP)
|
||||
super().__init__(gpu, mmio, {n: r for n, r in all_gc.items() if 'SDMA' in n})
|
||||
|
||||
def write(self, reg:int, val:int):
|
||||
super().write(reg, val)
|
||||
name = self._a2n.get(reg, '')
|
||||
if name.endswith('_RB_CNTL') and self._regs[name].decode(val).get('rb_enable', 0):
|
||||
self._activate_queue(name.rsplit('_RB_CNTL', 1)[0])
|
||||
|
||||
def _activate_queue(self, prefix:str):
|
||||
ring_addr = self._read_pair((self.reg(f'{prefix}_RB_BASE'), self.reg(f'{prefix}_RB_BASE_HI'))) << 8
|
||||
rptr_addr = self._read_pair((self.reg(f'{prefix}_RB_RPTR_ADDR_LO'), self.reg(f'{prefix}_RB_RPTR_ADDR_HI')))
|
||||
wptr_addr = self._read_pair((self.reg(f'{prefix}_RB_WPTR_POLL_ADDR_LO'), self.reg(f'{prefix}_RB_WPTR_POLL_ADDR_HI')))
|
||||
rb_size = self.decode(f'{prefix}_RB_CNTL')['rb_size']
|
||||
self.gpu.add_sdma_queue(self.gpu.mmu.addr_to_host(ring_addr), 4 << rb_size,
|
||||
self.gpu.mmu.addr_to_host(rptr_addr), self.gpu.mmu.addr_to_host(wptr_addr))
|
||||
|
||||
class MockGFX(MockIPBlock):
|
||||
def __init__(self, gpu, mmio):
|
||||
super().__init__(gpu, mmio, _build_ip_regs('gc', am.GC_HWIP))
|
||||
self._pt_base = (self.reg('regGCVM_CONTEXT0_PAGE_TABLE_BASE_ADDR_LO32'), self.reg('regGCVM_CONTEXT0_PAGE_TABLE_BASE_ADDR_HI32'))
|
||||
self._pt_start = (self.reg('regGCVM_CONTEXT0_PAGE_TABLE_START_ADDR_LO32'), self.reg('regGCVM_CONTEXT0_PAGE_TABLE_START_ADDR_HI32'))
|
||||
self._gc_inv_ack = self.reg('regGCVM_INVALIDATE_ENG17_ACK')
|
||||
self._gc_inv_req = self.reg('regGCVM_INVALIDATE_ENG17_REQ')
|
||||
self._hqd_active = self.reg('regCP_HQD_ACTIVE')
|
||||
|
||||
def read(self, reg:int) -> int:
|
||||
if reg == self.reg('regCP_STAT') or reg == self.reg('regRLC_SAFE_MODE'): return 0
|
||||
if reg == self.reg('regRLC_RLCS_BOOTLOAD_STATUS'): return 0x2
|
||||
if reg == self._gc_inv_ack: return 0x1
|
||||
return super().read(reg)
|
||||
|
||||
def write(self, reg:int, val:int):
|
||||
super().write(reg, val)
|
||||
if reg == self.reg('regCP_HQD_DEQUEUE_REQUEST'):
|
||||
if self._hqd_active is not None: self.mmio.regs[self._hqd_active] = 0
|
||||
if reg == self._hqd_active and val == 1: self._activate_pm4_queue()
|
||||
if reg == self._gc_inv_req: self.gpu.mmu.invalidate(self.get_pt_base(), self.get_va_base())
|
||||
|
||||
def _activate_pm4_queue(self):
|
||||
ring_addr = self._read_pair((self.reg('regCP_HQD_PQ_BASE'), self.reg('regCP_HQD_PQ_BASE_HI'))) << 8
|
||||
rptr_addr = self._read_pair((self.reg('regCP_HQD_PQ_RPTR_REPORT_ADDR'), self.reg('regCP_HQD_PQ_RPTR_REPORT_ADDR_HI')))
|
||||
wptr_addr = self._read_pair((self.reg('regCP_HQD_PQ_WPTR_POLL_ADDR'), self.reg('regCP_HQD_PQ_WPTR_POLL_ADDR_HI')))
|
||||
queue_size = self.decode('regCP_HQD_PQ_CONTROL')['queue_size']
|
||||
self.gpu.add_pm4_queue(self.gpu.mmu.addr_to_host(ring_addr), 4 << (queue_size + 1),
|
||||
self.gpu.mmu.addr_to_host(rptr_addr), self.gpu.mmu.addr_to_host(wptr_addr))
|
||||
|
||||
def get_pt_base(self) -> int: return self._read_pair(self._pt_base) & 0x0000FFFFFFFFF000
|
||||
def get_va_base(self) -> int: return self._read_pair(self._pt_start) << 12
|
||||
|
||||
class MockGMC(MockIPBlock):
|
||||
def __init__(self, gpu, mmio, gfx:MockGFX):
|
||||
super().__init__(gpu, mmio, _build_ip_regs('mmhub', am.MMHUB_HWIP))
|
||||
self._gfx = gfx
|
||||
self._inv_ack = self.reg('regMMVM_INVALIDATE_ENG17_ACK')
|
||||
self._inv_sem = self.reg('regMMVM_INVALIDATE_ENG17_SEM')
|
||||
self._inv_req = self.reg('regMMVM_INVALIDATE_ENG17_REQ')
|
||||
self._fb_loc_top = self.reg('regMMMC_VM_FB_LOCATION_TOP')
|
||||
|
||||
def read(self, reg:int) -> int:
|
||||
if reg == self._inv_ack or reg == self._inv_sem: return 0x1
|
||||
if reg == self._fb_loc_top: return VRAM_SIZE >> 24
|
||||
return super().read(reg)
|
||||
|
||||
def write(self, reg:int, val:int):
|
||||
super().write(reg, val)
|
||||
if reg == self._inv_req: self.gpu.mmu.invalidate(self._gfx.get_pt_base(), self._gfx.get_va_base())
|
||||
|
||||
class MockNBIO(MockIPBlock):
|
||||
def __init__(self, gpu, mmio):
|
||||
regs = _build_ip_regs('nbif', am.NBIO_HWIP)
|
||||
regs.update(_build_ip_regs('hdp', am.HDP_HWIP))
|
||||
super().__init__(gpu, mmio, regs)
|
||||
self._remap_hdp = self.reg('regBIF_BX0_REMAP_HDP_MEM_FLUSH_CNTL')
|
||||
self._hdp_flush = self.reg('regHDP_MEM_FLUSH_CNTL')
|
||||
|
||||
def read(self, reg:int) -> int:
|
||||
if reg == self._remap_hdp and self._hdp_flush is not None: return self._hdp_flush * 4
|
||||
return super().read(reg)
|
||||
|
||||
class MockMMIOInterface:
|
||||
def __init__(self, gpu:'MockAMGPU'):
|
||||
self.gpu = gpu
|
||||
self.regs: dict[int, int] = {}
|
||||
gfx = MockGFX(gpu, self)
|
||||
self.gmc = MockGMC(gpu, self, gfx)
|
||||
self.blocks = [MockPSP(gpu, self), MockSMU(gpu, self), MockSDMA(gpu, self), gfx, self.gmc, MockNBIO(gpu, self)]
|
||||
self._addr_block: dict[int, MockIPBlock] = {}
|
||||
for block in self.blocks:
|
||||
for addr in block.addrs: self._addr_block.setdefault(addr, block)
|
||||
|
||||
def __getitem__(self, index:int|slice) -> int|list[int]:
|
||||
if isinstance(index, slice): return [self[i] for i in range(index.start or 0, index.stop or 0, index.step or 1)] # type: ignore[misc]
|
||||
if index == 0xde3: return VRAM_SIZE >> 20
|
||||
if block := self._addr_block.get(index): return block.read(index)
|
||||
return self.regs.get(index, 0)
|
||||
|
||||
def __setitem__(self, index:int|slice, val:int|list[int]|tuple[int, ...]):
|
||||
if isinstance(index, slice):
|
||||
vals = val if isinstance(val, (list, tuple)) else [val] * ((index.stop - index.start) // (index.step or 1)) # type: ignore[operator]
|
||||
for i, v in zip(range(index.start or 0, index.stop or 0, index.step or 1), vals): self[i] = v
|
||||
return
|
||||
assert isinstance(val, int)
|
||||
self.regs[index] = val
|
||||
if block := self._addr_block.get(index): block.write(index, val)
|
||||
|
||||
def __len__(self): return 0x10000000
|
||||
|
||||
class MockAMGPU(AMDGPU):
|
||||
def __init__(self, gpuid:int=0):
|
||||
super().__init__(gpuid)
|
||||
self.vram_fd = os.memfd_create("vram")
|
||||
os.ftruncate(self.vram_fd, VRAM_SIZE)
|
||||
self.vram_addr = libc.mmap(0, VRAM_SIZE, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, self.vram_fd, 0)
|
||||
self.vram = (ctypes.c_ubyte * VRAM_SIZE).from_address(self.vram_addr)
|
||||
self.doorbell_fd = os.memfd_create("doorbell")
|
||||
os.ftruncate(self.doorbell_fd, 0x2000)
|
||||
self.arch = "rdna4"
|
||||
self._sysmem_map:dict[int,int] = {}
|
||||
self._next_sysmem_paddr = 0x100000000
|
||||
self.mmu = MockMMU(self)
|
||||
self.mmio = MockMMIOInterface(self)
|
||||
self._preboot()
|
||||
|
||||
def map_vram_at(self, va:int, paddr:int, size:int):
|
||||
libc.mmap(va, size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | 0x10, self.vram_fd, paddr)
|
||||
|
||||
def _preboot(self):
|
||||
ip_data = bytearray()
|
||||
for hwip, (major, minor, rev) in IP_VERSIONS.items():
|
||||
ip = am.struct_ip_v4(hw_id=IP_HWIDS[hwip], num_base_address=len(IP_BASES[hwip]), major=major, minor=minor, revision=rev)
|
||||
ip_data += bytes(ip) + b'\x00'
|
||||
for b in IP_BASES[hwip]: ip_data += struct.pack('<I', b)
|
||||
|
||||
dhdr = am.struct_die_header(num_ips=len(IP_VERSIONS))
|
||||
ihdr = am.struct_ip_discovery_header(signature=am.DISCOVERY_TABLE_SIGNATURE, version=4, num_dies=1)
|
||||
ip_disc_off = ctypes.sizeof(am.struct_binary_header)
|
||||
ihdr.die_info[0].die_offset = ip_disc_off + ctypes.sizeof(am.struct_ip_discovery_header)
|
||||
|
||||
gc = am.struct_gc_info_v2_1()
|
||||
gc.header.table_id, gc.header.version_major, gc.header.version_minor = am.GC, 2, 1
|
||||
gc.header.size = ctypes.sizeof(am.struct_gc_info_v2_1)
|
||||
for field, val in GC_INFO.items(): setattr(gc, field, val)
|
||||
|
||||
gc_off = ip_disc_off + ctypes.sizeof(am.struct_ip_discovery_header) + ctypes.sizeof(am.struct_die_header) + len(ip_data)
|
||||
bhdr = am.struct_binary_header(binary_signature=am.BINARY_SIGNATURE)
|
||||
bhdr.table_list[am.IP_DISCOVERY].offset = ip_disc_off
|
||||
bhdr.table_list[am.GC].offset = gc_off
|
||||
|
||||
tbl = bytes(bhdr) + bytes(ihdr) + bytes(dhdr) + ip_data + bytes(gc)
|
||||
tbl_offset = VRAM_SIZE - (64 << 10)
|
||||
self.vram[tbl_offset:tbl_offset + len(tbl)] = list(tbl)
|
||||
|
||||
@property
|
||||
def mc_base(self) -> int:
|
||||
fb_loc_base = self.mmio.gmc.reg('regMMMC_VM_FB_LOCATION_BASE') or 0
|
||||
return (self.mmio.regs.get(fb_loc_base, 0) & 0xFFFFFF) << 24
|
||||
|
|
@ -2,6 +2,7 @@ import ctypes, ctypes.util, time, os, builtins, fcntl
|
|||
from tinygrad.runtime.support.hcq import FileIOInterface
|
||||
from test.mockgpu.nv.nvdriver import NVDriver
|
||||
from test.mockgpu.amd.amddriver import AMDDriver
|
||||
from test.mockgpu.am.amdriver import AMDriver
|
||||
start = time.perf_counter()
|
||||
|
||||
# *** ioctl lib ***
|
||||
|
|
@ -9,7 +10,7 @@ libc = ctypes.CDLL(ctypes.util.find_library("c"))
|
|||
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
|
||||
libc.mmap.restype = ctypes.c_void_p
|
||||
|
||||
drivers = [AMDDriver(), NVDriver()]
|
||||
drivers = [NVDriver(), AMDriver() if os.environ.get("AMD_IFACE") == "PCI" else AMDDriver()]
|
||||
tracked_fds = {}
|
||||
|
||||
original_memoryview = builtins.memoryview
|
||||
|
|
@ -77,9 +78,10 @@ class MockFileIOInterface(FileIOInterface):
|
|||
return libc.mmap(start, sz, prot, flags, self.fd, offset)
|
||||
|
||||
def read(self, size=None, binary=False, offset=None):
|
||||
if binary: raise NotImplementedError()
|
||||
if self.fd in tracked_fds:
|
||||
if offset is not None: tracked_fds[self.fd].seek(offset)
|
||||
return tracked_fds[self.fd].read_contents(size)
|
||||
if binary: raise NotImplementedError()
|
||||
with open(self.fd, "rb" if binary else "r", closefd=False) as file:
|
||||
if file.tell() >= os.fstat(self.fd).st_size: file.seek(0)
|
||||
return file.read(size)
|
||||
|
|
@ -89,13 +91,20 @@ class MockFileIOInterface(FileIOInterface):
|
|||
return tracked_fds[self.fd].list_contents()
|
||||
return os.listdir(self.path)
|
||||
|
||||
def write(self, content, binary=False, offset=None): raise NotImplementedError()
|
||||
def write(self, content, binary=False, offset=None):
|
||||
if self.fd in tracked_fds:
|
||||
if offset is not None: tracked_fds[self.fd].seek(offset)
|
||||
return tracked_fds[self.fd].write_contents(content)
|
||||
raise NotImplementedError()
|
||||
def seek(self, offset):
|
||||
if self.fd in tracked_fds:
|
||||
tracked_fds[self.fd].seek(offset)
|
||||
else:
|
||||
os.lseek(self.fd, offset, os.SEEK_CUR)
|
||||
@staticmethod
|
||||
def anon_mmap(start, sz, prot, flags, offset):
|
||||
return FileIOInterface._mmap(start, sz, prot, flags & ~0x4a000, -1, offset) # strip MAP_LOCKED|MAP_POPULATE|MAP_HUGETLB
|
||||
@staticmethod
|
||||
def exists(path): return _open(path, os.O_RDONLY) is not None
|
||||
@staticmethod
|
||||
def readlink(path): raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ def _check_ast_count(desired_count:int, t:Tensor):
|
|||
# NOTE: this has side effect because everything can be scheduled only once
|
||||
schedule = t.schedule()
|
||||
asts = [s for s in schedule if s.ast.op is Ops.SINK]
|
||||
assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
|
||||
len(asts)
|
||||
# NOT SUPPORTED ANYMORE
|
||||
#assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
|
||||
|
||||
class TestUnaryOpsConstFolding(unittest.TestCase):
|
||||
def test_all_consts_ops(self):
|
||||
|
|
|
|||
101
test/null/test_gpudims.py
Normal file
101
test/null/test_gpudims.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import unittest, math
|
||||
import z3
|
||||
from tinygrad.codegen.gpudims import get_grouped_dims
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.uop.validate import uops_to_z3
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import flatten, dedup
|
||||
|
||||
class TestGroupedDims(unittest.TestCase):
|
||||
def _check_grouped_dims(self, prefix, dims, max_sizes, reverse, expected_sizes, assert_same_length=True):
|
||||
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse)
|
||||
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
|
||||
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
|
||||
sizes = [x.src[0].arg for x in loop_idxs]
|
||||
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
|
||||
if assert_same_length:
|
||||
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
|
||||
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
|
||||
self._verify_indices_z3(idxs, dims)
|
||||
|
||||
def _verify_indices_z3(self, idxs, dims):
|
||||
"""Use z3 to prove bijectivity: bounds (0 <= flat < total) + injectivity (different inputs => different flat)."""
|
||||
total = math.prod(dims)
|
||||
specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg)
|
||||
# build flat index and primed flat (same expression with renamed SPECIALs)
|
||||
flat = UOp.const(dtypes.index, 0)
|
||||
for i, idx in enumerate(idxs):
|
||||
flat = flat + idx * int(math.prod(dims[i+1:]))
|
||||
flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials})
|
||||
solver = z3.Solver()
|
||||
[z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p)
|
||||
# bounds
|
||||
self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}")
|
||||
self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}")
|
||||
# injectivity: flat == flat' but inputs differ => unsat
|
||||
inputs_differ = z3.Or(*[z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials])
|
||||
self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}")
|
||||
|
||||
def test_grouped_dims(self):
|
||||
# no-op
|
||||
self._check_grouped_dims("gidx", (2,), (16,16,16), False, [2])
|
||||
self._check_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
|
||||
|
||||
# check reverse dims
|
||||
self._check_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
|
||||
self._check_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
|
||||
|
||||
# test splitting globals: len(dims) == len(max)
|
||||
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
|
||||
self._check_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
|
||||
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
|
||||
self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
|
||||
self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
|
||||
self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14])
|
||||
|
||||
# prefer group_dim strategy when possible
|
||||
self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
|
||||
|
||||
# test splitting globals: len(dims) < len(max)
|
||||
# len(dim) -> len(limited)
|
||||
# 1 -> 2
|
||||
self._check_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
|
||||
# 1 -> 3
|
||||
self._check_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
|
||||
# 2 -> 2
|
||||
self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
|
||||
# test when the only divisor is the square root of dim
|
||||
self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
|
||||
# 2 -> 3
|
||||
self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
|
||||
|
||||
# collapse on onto the left most axis
|
||||
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
|
||||
self._check_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
|
||||
|
||||
# collapse on left-most available axis (the left most is too small)
|
||||
self._check_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
|
||||
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
|
||||
|
||||
# dim too large and not factorable
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (23,), (16,16,16), False,)
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
|
||||
|
||||
# too large for sizes
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
|
||||
|
||||
def test_grouped_direct_dims_are_special(self):
|
||||
# when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod)
|
||||
idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False)
|
||||
assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}"
|
||||
assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}"
|
||||
|
||||
def test_max_sizes_none(self):
|
||||
self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])
|
||||
self._check_grouped_dims("gidx", (100,), None, False, [100])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
import unittest
|
||||
import gc, unittest
|
||||
from tinygrad import Tensor, GlobalCounters, dtypes
|
||||
|
||||
class TestMultiRamUsage(unittest.TestCase):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
self.baseline = GlobalCounters.mem_used
|
||||
self.N = 100
|
||||
def assertUsed(self, amt, strict=True):
|
||||
gc.collect()
|
||||
used = GlobalCounters.mem_used - self.baseline
|
||||
print(f"used {used} bytes")
|
||||
if strict: self.assertEqual(used, amt)
|
||||
|
|
@ -20,20 +22,17 @@ class TestMultiRamUsage(unittest.TestCase):
|
|||
del _
|
||||
self.assertUsed(0)
|
||||
|
||||
@unittest.skip("flaky")
|
||||
def test_zeros_copy(self):
|
||||
devices_2 = ("NULL:1", "NULL:2")
|
||||
_ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
|
||||
# NOTE: the first one on the DEFAULT device should be freed
|
||||
self.assertUsed(self.N*self.N*4*2)
|
||||
|
||||
@unittest.skip("flaky")
|
||||
def test_zeros_shard(self, devices=("NULL:1", "NULL:2")):
|
||||
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices, axis=0).realize()
|
||||
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
|
||||
def test_zeros_shard_self(self): self.test_zeros_shard(("NULL:0", "NULL:1"))
|
||||
|
||||
@unittest.skip("flaky")
|
||||
def test_zeros_contiguous_shard(self):
|
||||
devices_2 = ("NULL:1", "NULL:2")
|
||||
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
|
||||
|
|
@ -54,5 +53,26 @@ class TestMultiRamUsage(unittest.TestCase):
|
|||
def test_matmul_half(self): self._test_matmul_half(dev_count=2)
|
||||
def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
|
||||
|
||||
class TestMultiAxis(unittest.TestCase):
|
||||
def test_reshape_shard_invalid(self):
|
||||
devices = ("NULL:0", "NULL:1")
|
||||
t = Tensor.ones(4, 3).shard(devices, axis=0)
|
||||
with self.assertRaises(RuntimeError, msg="reshape cannot move items between shards"):
|
||||
t.reshape(3, 4).uop.axis
|
||||
|
||||
def test_reshape_shard_valid(self):
|
||||
devices = ("NULL:0", "NULL:1")
|
||||
t = Tensor.ones(4, 8).shard(devices, axis=0)
|
||||
self.assertEqual(t.reshape(2, 16).uop.axis, 0)
|
||||
self.assertEqual(t.reshape(2, 2, 8).uop.axis, 0)
|
||||
|
||||
def test_empty_like_sharded(self):
|
||||
t = Tensor.ones(4, 8).shard(("NULL:0", "NULL:1"), axis=0)
|
||||
e = t.empty_like()
|
||||
self.assertEqual(e.shape, t.shape)
|
||||
self.assertEqual(e.device, t.device)
|
||||
self.assertEqual(e.uop.axis, 0)
|
||||
self.assertTrue(e.uop.has_buffer_identity())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.helpers import CPU_LLVM, CPU_LVP, CPU_X86
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.engine.realize import get_program
|
||||
|
||||
class TestOpts(unittest.TestCase):
|
||||
def test_opt_upcast(self):
|
||||
opts = (Opt(OptOps.UPCAST, 0, 4),)
|
||||
a = Tensor.empty(16)
|
||||
b = Tensor.empty(16)
|
||||
out = (a+b).contiguous(arg=opts)
|
||||
s = out.schedule()
|
||||
self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
|
||||
if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP and not CPU_X86:
|
||||
prg = get_program(s[-1].ast, renderer=Device[Device.DEFAULT].renderer)
|
||||
self.assertIn('float4', prg.src)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
|
@ -1,202 +0,0 @@
|
|||
import unittest
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.ops import UOp, graph_rewrite_map, _substitute
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
|
||||
class TestRewriteMap(unittest.TestCase):
|
||||
def test_substitute(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
b = UOp.variable('b', 0, 10)
|
||||
c = UOp.variable('c', 0, 10)
|
||||
e = UOp.variable('e', 0, 10)
|
||||
ret = (a+b)*c
|
||||
sub = {a+b: e}
|
||||
sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
|
||||
self.assertIs(sub_map[a+b], e)
|
||||
self.assertIs(sub_map[(a+b)*c], e*c)
|
||||
|
||||
def test_substitute_depth_2(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
b = UOp.variable('b', 0, 10)
|
||||
c = UOp.variable('c', 0, 10)
|
||||
d = UOp.variable('d', 0, 10)
|
||||
e = UOp.variable('e', 0, 10)
|
||||
f = UOp.variable('f', 0, 10)
|
||||
ret = (a+b)*c+d
|
||||
sub = {a+b: e, (a+b)*c: f}
|
||||
sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
|
||||
self.assertIs(sub_map[a+b], e)
|
||||
self.assertIs(sub_map[(a+b)*c], f)
|
||||
|
||||
def test_multistage_substitute(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
b = UOp.variable('b', 0, 10)
|
||||
c = UOp.variable('c', 0, 10)
|
||||
d = UOp.variable('d', 0, 10)
|
||||
sub1 = {a+b:c}
|
||||
start = (a+b)*c
|
||||
# stage 1: (a+b)*c -> c*c
|
||||
sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True)
|
||||
self.assertIs(sub_map1[(a+b)*c], c*c)
|
||||
# stage 2: c*c -> d
|
||||
sub2 = {c*c:d}
|
||||
sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True)
|
||||
# (a+b)*c -> c*c -> d
|
||||
self.assertIs(sub_map2[(a+b)*c], d)
|
||||
|
||||
def test_add_zero(self):
|
||||
# Build a small graph: add(0, add(const=0, const=5))
|
||||
zero_node = UOp.const(dtypes.index, 0)
|
||||
five_node = UOp.const(dtypes.index, 5)
|
||||
inner_add = zero_node + five_node
|
||||
root_add = zero_node + inner_add
|
||||
|
||||
# Perform top-down rewrite
|
||||
node_map = graph_rewrite_map(root_add, symbolic)
|
||||
|
||||
# We expect that add(0, add(0, 5)) -> add(0, 5) -> 5
|
||||
# Check the mapping
|
||||
assert node_map[root_add] == five_node
|
||||
assert node_map[inner_add] == five_node
|
||||
# zero_node and five_node map to themselves
|
||||
assert node_map[zero_node] == zero_node
|
||||
assert node_map[five_node] == five_node
|
||||
|
||||
def test_double_neg(self):
|
||||
"""
|
||||
Test rewriting neg(neg(5)) => 5 using symbolic.
|
||||
"""
|
||||
# In some versions of TinyGrad, you might do: (-(-five_node))
|
||||
five_node = UOp.const(dtypes.index, 5)
|
||||
# If your code allows UOp(...), do that; else you might do something like:
|
||||
# double_neg_five = -(-five_node)
|
||||
# But let's be explicit:
|
||||
neg_five = -five_node
|
||||
double_neg_five = -neg_five
|
||||
|
||||
node_map = graph_rewrite_map(double_neg_five, symbolic)
|
||||
|
||||
# node_map should map double_neg_five -> five_node
|
||||
self.assertEqual(node_map[double_neg_five], five_node)
|
||||
# five_node maps to itself
|
||||
self.assertEqual(node_map[five_node], five_node)
|
||||
|
||||
def test_add_zero_and_double_neg(self):
|
||||
"""
|
||||
Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5
|
||||
"""
|
||||
zero_node = UOp.const(dtypes.index, 0)
|
||||
five_node = UOp.const(dtypes.index, 5)
|
||||
neg_five = -five_node
|
||||
double_neg_five = -neg_five
|
||||
root_add = zero_node + double_neg_five
|
||||
|
||||
node_map = graph_rewrite_map(root_add, symbolic)
|
||||
|
||||
# node_map: root_add -> five_node, double_neg_five -> five_node
|
||||
self.assertEqual(node_map[root_add], five_node)
|
||||
self.assertEqual(node_map[double_neg_five], five_node)
|
||||
# zero_node, five_node map to themselves
|
||||
self.assertEqual(node_map[zero_node], zero_node)
|
||||
self.assertEqual(node_map[five_node], five_node)
|
||||
|
||||
def test_multi_var_rewrites(self):
|
||||
x_var = UOp.variable('x', 0, 10)
|
||||
y_var = UOp.variable('y', -5, 5)
|
||||
zero_node = UOp.const(dtypes.index, 0)
|
||||
|
||||
sum_with_zero = y_var + zero_node # (y + 0)
|
||||
combined = x_var + sum_with_zero # x + (y + 0)
|
||||
double_neg = -(-combined) # neg(neg(x + y))
|
||||
final_expr = zero_node + double_neg # 0 + (x + y)
|
||||
|
||||
node_map = graph_rewrite_map(final_expr, symbolic)
|
||||
|
||||
# The final root should be (x_var + y_var).
|
||||
expected = x_var + y_var
|
||||
|
||||
# Each sub-expression has its own "final" result.
|
||||
# (y + 0) -> y_var
|
||||
self.assertEqual(node_map[sum_with_zero], y_var)
|
||||
# (x + (y+0)) -> (x + y)
|
||||
self.assertEqual(node_map[combined], expected)
|
||||
# neg(neg(x+y)) -> (x + y)
|
||||
self.assertEqual(node_map[double_neg], expected)
|
||||
# 0 + (x+y) -> (x + y)
|
||||
self.assertEqual(node_map[final_expr], expected)
|
||||
|
||||
# x_var, y_var, zero_node remain unchanged
|
||||
self.assertEqual(node_map[x_var], x_var)
|
||||
self.assertEqual(node_map[y_var], y_var)
|
||||
self.assertEqual(node_map[zero_node], zero_node)
|
||||
|
||||
def test_complex_multi_var_edges(self):
|
||||
"""
|
||||
Build a multi-variable expression with multiple intermediates:
|
||||
|
||||
x_var = UOp.variable('x', 1, 10)
|
||||
y_var = UOp.variable('y', -5, 5)
|
||||
z_var = UOp.variable('z', 0, 5)
|
||||
zero_node = UOp.const(dtypes.int, 0)
|
||||
one_node = UOp.const(dtypes.int, 1)
|
||||
|
||||
yz_sum = y_var + z_var
|
||||
yz_sum_zero = yz_sum + zero_node -> rewrites to yz_sum
|
||||
yz_neg = -yz_sum_zero -> -(y+z)
|
||||
yz_dneg = -yz_neg -> y+z (double neg gone)
|
||||
x_plus_yz = x_var + yz_dneg -> x + (y+z)
|
||||
double_neg_x = -(-x_plus_yz) -> x + (y+z)
|
||||
final_expr = double_neg_x * one_node -> x + (y+z)
|
||||
|
||||
We expect the final result to be (x + (y+z)).
|
||||
Each original node should map to the final node that replaces it,
|
||||
which might be structurally equivalent but not the same reference.
|
||||
"""
|
||||
x_var = UOp.variable('x', 1, 10)
|
||||
y_var = UOp.variable('y', -5, 5)
|
||||
z_var = UOp.variable('z', 0, 5)
|
||||
zero_node = UOp.const(dtypes.index, 0)
|
||||
one_node = UOp.const(dtypes.index, 1)
|
||||
|
||||
# Build sub-expressions
|
||||
yz_sum = y_var + z_var # (y + z)
|
||||
yz_sum_zero = yz_sum + zero_node # (y + z) + 0
|
||||
yz_neg = -yz_sum_zero # -(y+z)
|
||||
yz_dneg = -yz_neg # -(-(y+z)) -> (y+z)
|
||||
x_plus_yz = x_var + yz_dneg # x + (y+z)
|
||||
double_neg_x = -(-x_plus_yz) # neg(neg(x+(y+z))) -> x+(y+z)
|
||||
final_expr = double_neg_x * one_node # (x+(y+z)) * 1 -> x+(y+z)
|
||||
|
||||
node_map = graph_rewrite_map(final_expr, symbolic)
|
||||
|
||||
# (y + z) is unchanged
|
||||
self.assertEqual(node_map[yz_sum], yz_sum)
|
||||
|
||||
# (y+z) + 0 => (y+z)
|
||||
self.assertEqual(node_map[yz_sum_zero], yz_sum)
|
||||
|
||||
# -(y+z) remains -(y+z), but might be a new UOp with updated children
|
||||
# Compare structurally to -(y_var + z_var).
|
||||
self.assertEqual(node_map[yz_neg], -yz_sum)
|
||||
|
||||
# -(-(y+z)) => (y+z)
|
||||
self.assertEqual(node_map[yz_dneg], yz_sum)
|
||||
|
||||
# x + (y+z) => might get recreated if yz_dneg was changed, so compare to x + yz_sum
|
||||
self.assertEqual(node_map[x_plus_yz], x_var + yz_sum)
|
||||
|
||||
# -(-(x+(y+z))) => x + (y+z)
|
||||
self.assertEqual(node_map[double_neg_x], x_var + yz_sum)
|
||||
|
||||
# (x+(y+z)) * 1 => x+(y+z)
|
||||
self.assertEqual(node_map[final_expr], x_var + yz_sum)
|
||||
|
||||
# Unchanged atomic nodes map to themselves
|
||||
self.assertEqual(node_map[x_var], x_var)
|
||||
self.assertEqual(node_map[y_var], y_var)
|
||||
self.assertEqual(node_map[z_var], z_var)
|
||||
self.assertEqual(node_map[zero_node], zero_node)
|
||||
self.assertEqual(node_map[one_node], one_node)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
# schedule tests that pass on NULL backend (no copyout needed)
|
||||
import unittest, time
|
||||
import gc, unittest, time
|
||||
from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
|
|
@ -59,7 +59,7 @@ class TestBufferUOp(unittest.TestCase):
|
|||
|
||||
def test_buffer_view_not_allowed(self):
|
||||
permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
|
||||
with self.assertRaisesRegex(AssertionError, "can only be RESHAPE"):
|
||||
with self.assertRaises(RuntimeError):
|
||||
permuted_view.uop.buffer # cannot access Buffer of a non contiguous VIEW
|
||||
|
||||
def test_buffer_only_after_realize(self):
|
||||
|
|
@ -74,7 +74,7 @@ class TestBufferUOp(unittest.TestCase):
|
|||
self.assertIsNotNone(a.uop.buffer)
|
||||
|
||||
def test_const_does_not_realize(self):
|
||||
a = Tensor(1)+Tensor(2)
|
||||
a = Tensor(1)
|
||||
run_schedule(check_schedule(a, 0))
|
||||
self.assertIsNone(a.uop.base.realized)
|
||||
|
||||
|
|
@ -191,6 +191,13 @@ class TestSchedule(unittest.TestCase):
|
|||
a, _ = Tensor.empty(1022).cummax(axis=0)
|
||||
check_schedule(a, 3)
|
||||
|
||||
@unittest.skip("should this pass?")
|
||||
def test_contiguous_assign(self):
|
||||
a = Tensor.ones(10) * 2
|
||||
b = Tensor.empty(10)
|
||||
c = b.assign(a.contiguous())
|
||||
check_schedule(c, 1)
|
||||
|
||||
def test_basic_binop_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
|
|
@ -198,6 +205,14 @@ class TestSchedule(unittest.TestCase):
|
|||
d = a+b+c
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_basic_binop_fusion_assign(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = Tensor.empty(10)
|
||||
d = a+b+c
|
||||
e = Tensor.empty(10).assign(d)
|
||||
check_schedule(e, 1)
|
||||
|
||||
def test_basic_binop_fusion_deep(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
|
|
@ -212,6 +227,13 @@ class TestSchedule(unittest.TestCase):
|
|||
c = (a*b).sum()
|
||||
check_schedule(c, 1)
|
||||
|
||||
def test_mulacc_fusion_assign(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
c = (a*b).sum()
|
||||
d = Tensor.empty(1).assign(c)
|
||||
check_schedule(d, 1)
|
||||
|
||||
def test_mulacc_relu_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
|
|
@ -390,20 +412,20 @@ class TestSchedule(unittest.TestCase):
|
|||
out = bn(c1(img)).relu()
|
||||
check_schedule(out, 4, [c1.weight, c1.bias])
|
||||
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 27), (nn.optim.SGD, 7)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
_realize_weights([c1, bn])
|
||||
opt = optim(nn.state.get_parameters([c1, bn]))
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
check_schedule(opt.schedule_step(), cnt)
|
||||
def test_fold_conv_batchnorm_optim(self, adam=False):
|
||||
# 2 is too low?
|
||||
optim, cnt = (nn.optim.Adam, 16) if adam else (nn.optim.SGD, 2)
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
_realize_weights([c1, bn])
|
||||
opt = optim(nn.state.get_parameters([c1, bn]))
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
check_schedule(opt.schedule_step(), cnt)
|
||||
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
|
||||
|
||||
def test_fold_batchnorm_backward(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -627,6 +649,7 @@ class TestSchedule(unittest.TestCase):
|
|||
t = Tensor([1.0, 2.0, 3.0]) ** 8
|
||||
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])
|
||||
|
||||
@unittest.skip("const folding is removed")
|
||||
def test_pow_const_tensor_to_zero(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
out = x ** Tensor(0.0)
|
||||
|
|
@ -751,7 +774,7 @@ class TestSchedule(unittest.TestCase):
|
|||
_realize_weights(layer)
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
|
||||
layer(x).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 19)
|
||||
check_schedule(opt.schedule_step(), 13)
|
||||
|
||||
def test_adam_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -761,7 +784,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c1(img).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 19)
|
||||
check_schedule(opt.schedule_step(), 13)
|
||||
|
||||
def test_adam_2convs_fuse(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -772,7 +795,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 21)
|
||||
check_schedule(opt.schedule_step(), 15)
|
||||
|
||||
def test_sgd_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -804,7 +827,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 13)
|
||||
check_schedule(opt.schedule_step(), 11)
|
||||
|
||||
def test_sgd_4convs_fuse(self):
|
||||
with Tensor.train():
|
||||
|
|
@ -880,9 +903,11 @@ class TestSchedule(unittest.TestCase):
|
|||
check_schedule(out, 2)
|
||||
|
||||
def test_schedule_mem_used(self):
|
||||
gc.collect()
|
||||
base = GlobalCounters.mem_used
|
||||
Tensor.ones(256).contiguous().realize()
|
||||
Tensor.ones(5, 5).contiguous().schedule()
|
||||
gc.collect()
|
||||
self.assertEqual(GlobalCounters.mem_used-base, 0)
|
||||
|
||||
def test_const_schedule(self):
|
||||
|
|
@ -986,6 +1011,7 @@ class TestUOpBecome(unittest.TestCase):
|
|||
|
||||
# sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_reorder_expand(self):
|
||||
a = Tensor.empty(4, 1)
|
||||
b = a.expand(4, 4).reciprocal()
|
||||
|
|
@ -1021,6 +1047,7 @@ class TestUOpBecome(unittest.TestCase):
|
|||
late_add = noop+2
|
||||
late_add.realize()
|
||||
|
||||
@unittest.skip("const folding is removed")
|
||||
def test_become_const_in_base(self):
|
||||
a = Tensor.empty(4)
|
||||
b = a*0
|
||||
|
|
@ -1028,6 +1055,7 @@ class TestUOpBecome(unittest.TestCase):
|
|||
check_schedule(b, 0)
|
||||
assert UPat(Ops.CONST, arg=0).match(b.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER)
|
||||
|
||||
@unittest.skip("const folding is removed")
|
||||
def test_become_const_from_const(self):
|
||||
const_add = Tensor(1)+Tensor(2)
|
||||
assert UPat(Ops.ADD).match(const_add.uop, {})
|
||||
|
|
@ -1116,7 +1144,7 @@ class TestFusionOp(unittest.TestCase):
|
|||
a = Tensor(val)
|
||||
for _ in range(24): a = Tensor.stack(a, a)[0]
|
||||
sched = a.schedule()
|
||||
self.assertEqual(len(sched), 0)
|
||||
self.assertLessEqual(len(sched), 1)
|
||||
self.assertLess(time.perf_counter()-st, 2.0)
|
||||
|
||||
def test_recursive_reshape(self):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from tinygrad.tensor import _METADATA
|
|||
from tinygrad.engine.realize import capturing
|
||||
from tinygrad.helpers import Context
|
||||
|
||||
@unittest.skip("tensor metadata is no longer supported")
|
||||
class TestTensorMetadata(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
_METADATA.set(None)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pa
|
|||
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)
|
||||
|
||||
class TestTensorMutates(unittest.TestCase):
|
||||
@unittest.skip("this doesn't mutate anymore")
|
||||
def test_mutate_add(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
|
|
|
|||
|
|
@ -655,6 +655,24 @@ class TestSymbolic(unittest.TestCase):
|
|||
with self.assertRaises(AssertionError):
|
||||
self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
|
||||
|
||||
def test_div_mod_recombine_3level(self):
|
||||
gidx = Variable("gidx", 0, 150527)
|
||||
self.helper_test_variable(gidx//3%224*3 + gidx%3 + gidx//672*672, 0, 150527, "gidx")
|
||||
# different shapes
|
||||
x = Variable("x", 0, 5*7*11-1)
|
||||
self.helper_test_variable(x//11%7*11 + x%11 + x//77*77, 0, 5*7*11-1, "x")
|
||||
# result is x//a*c2 not just x
|
||||
x2 = Variable("x2", 0, 5*6*7-1)
|
||||
self.helper_test_variable(x2//7%6*14 + x2//42*84, 0, (5*6*7-1)//7*14, "(x2//7*14)")
|
||||
# negative variable range
|
||||
xn = Variable("x", -1000, 1000)
|
||||
self.helper_test_variable(xn//3%224*3 + xn%3 + xn//672*672, -1000, 1000, "x")
|
||||
self.helper_test_variable(xn//3%7*3 + xn//21*21, -999, 999, "(x//3*3)")
|
||||
# should NOT simplify: a*c1 != b (3*224 != 600)
|
||||
self.helper_test_variable(gidx//3%224*3 + gidx//600*600, 0, 150669, "(gidx//600*600+gidx//3%224*3)")
|
||||
# should NOT simplify: c1*c2 != c3 (224*3 != 700)
|
||||
self.helper_test_variable(gidx//3%224*3 + gidx//672*700, 0, 156769, "(gidx//672*700+gidx//3%224*3)")
|
||||
|
||||
def test_div_mod_recombine_with_gcd(self):
|
||||
b = Variable("b", 0, 100)
|
||||
exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
|
||||
|
|
@ -835,34 +853,33 @@ class TestSymbolicNumeric(unittest.TestCase):
|
|||
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
|
||||
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
|
||||
|
||||
class TestSymbolicVars(unittest.TestCase):
|
||||
class TestSymbolicVariables(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
z = uconst(0)
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert z.vars() == z.vars() == set()
|
||||
print(a.vars())
|
||||
assert a.vars() == a.vars() == {a}
|
||||
assert z.variables() == []
|
||||
assert a.variables() == [a]
|
||||
m = a * 3
|
||||
assert m.vars() == {a}
|
||||
assert m.variables() == [a]
|
||||
s = usum([a, b, c])
|
||||
assert s.vars() == {a, b, c}
|
||||
assert s.variables() == [a, b, c]
|
||||
|
||||
def test_compound(self):
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert (a + b * c).vars() == {a, b, c}
|
||||
assert (a % 3 + b // 5).vars() == {a, b}
|
||||
assert (a + b * c).variables() == [a, b, c]
|
||||
assert (a % 3 + b // 5).variables() == [a, b]
|
||||
# TODO: fix me
|
||||
with self.assertRaises(AssertionError):
|
||||
assert (a + b + c - a).vars() == {b, c}
|
||||
assert (a + b + c - a).variables() == [b, c]
|
||||
|
||||
def test_dedup(self):
|
||||
a = Variable("a", 0, 10)
|
||||
assert (a * a).vars() == {a}
|
||||
assert (a//4 + a//6).vars() == {a}
|
||||
assert (a * a).variables() == [a]
|
||||
assert (a//4 + a//6).variables() == [a]
|
||||
|
||||
class TestSymInfer(unittest.TestCase):
|
||||
def test_sym_infer(self):
|
||||
|
|
|
|||
|
|
@ -94,12 +94,6 @@ class TestExecALU(unittest.TestCase):
|
|||
# test no truncate
|
||||
self.assertEqual(exec_alu(Ops.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
|
||||
|
||||
class TestConstantFolding(unittest.TestCase):
|
||||
def test_cast_const(self):
|
||||
t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
|
||||
si = t.schedule()
|
||||
assert len(si) == 0
|
||||
|
||||
class TestGatedStoreRewrite(unittest.TestCase):
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp(Ops.PARAM, dtypes.float.ptr(), (), 0)
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ class TestMemoryCount(unittest.TestCase):
|
|||
_, mem = get_stats(a+b)
|
||||
self.assertEqual(mem, 1024*1024*2 + 1024) # 1 full read + 1 lil read + 1 write
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_both_expanded(self):
|
||||
# TODO: this probably should be a full write
|
||||
a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
|
||||
|
|
|
|||
|
|
@ -286,6 +286,20 @@ class TestVizIntegration(BaseTestViz):
|
|||
self.assertEqual(lst[0]["name"], "Schedule 1 Kernel n1")
|
||||
self.assertEqual(lst[1]["name"], prg.name)
|
||||
|
||||
# schedule graph CALL nodes have a link to jump to codegen
|
||||
def test_link_sched_codegen(self):
|
||||
c1 = Tensor.empty(4).add(1)
|
||||
c2 = Tensor.empty(8).add(1)
|
||||
sched = Tensor.schedule(c1, c2)
|
||||
prgs = [si.lower().prg.p.name for si in sched]
|
||||
lst = get_viz_list()
|
||||
viz_kernel = next(i for i,s in enumerate(lst[0]["steps"]) if s["name"] == "View Kernel Graph")
|
||||
graph = next(get_viz_details(0, viz_kernel))["graph"]
|
||||
call_nodes = [n for n in graph.values() if n["label"].startswith("CALL")]
|
||||
for i,n in enumerate(call_nodes):
|
||||
assert n["ref"] is not None
|
||||
self.assertEqual(lst[n["ref"]]["name"], prgs[i])
|
||||
|
||||
def test_metadata_tracing(self):
|
||||
with Context(TRACEMETA=2):
|
||||
a = Tensor.empty(1)
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class TestWinograd(unittest.TestCase):
|
|||
|
||||
# TODO: what's optimal on this?
|
||||
self.assertLess(ops_ratio, 4.3)
|
||||
self.assertLess(mem_ratio, 3)
|
||||
self.assertLess(mem_ratio, 4)
|
||||
|
||||
def test_dtype(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TestCfg(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.arch = Device["AMD"].arch
|
||||
if not any(self.arch.startswith(a) for a in {"gfx11", "gfx12"}):
|
||||
self.skipTest(f"tests written for RDNA, got arch {arch}")
|
||||
self.skipTest(f"tests written for RDNA, got arch {self.arch}")
|
||||
|
||||
def test_simple(self):
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
|
|
|
|||
|
|
@ -35,8 +35,16 @@ class TestAssign(unittest.TestCase):
|
|||
a.realize()
|
||||
np.testing.assert_allclose(b.numpy(), 0)
|
||||
|
||||
def test_assign_copy(self):
|
||||
a = Tensor([1.,2,3], device="PYTHON")
|
||||
c = Tensor.empty(3).assign(a.to(None))
|
||||
# it should copy into the empty buffer
|
||||
GlobalCounters.reset()
|
||||
c.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
|
||||
def test_assign_add(self):
|
||||
for T in (1, 2, 10, 100):
|
||||
for T in (1, 2, 10):#, 100): # this crashes in CI, not sure why
|
||||
x = Tensor([0]).realize()
|
||||
buf = x.uop.base.realized
|
||||
for _ in range(T):
|
||||
|
|
@ -120,6 +128,7 @@ class TestAssign(unittest.TestCase):
|
|||
new = a + old_a
|
||||
np.testing.assert_allclose(new.numpy(), 4)
|
||||
|
||||
@unittest.skip("TODO: this is broken")
|
||||
def test_assign_changes_alt(self, realize=False):
|
||||
a = Tensor(1).contiguous()
|
||||
if realize: a.realize()
|
||||
|
|
@ -629,6 +638,7 @@ class TestAssignOrdering(unittest.TestCase):
|
|||
self.assertEqual(r1.item(), 4)
|
||||
self.assertEqual(r2.item(), 8)
|
||||
|
||||
@unittest.skip("TODO: this is broken")
|
||||
def test_write_read_write_chain(self):
|
||||
"""Write, read, write chain - middle read must complete before second write."""
|
||||
buf = Tensor.zeros(4).contiguous().realize()
|
||||
|
|
|
|||
|
|
@ -92,5 +92,13 @@ class TestCall(unittest.TestCase):
|
|||
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
||||
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
|
||||
|
||||
def test_call_plus_sharded(self):
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
a = Tensor.ones(10, 10).shard(devs, axis=0)
|
||||
b = Tensor.ones(10, 10).shard(devs, axis=0)
|
||||
Tensor.realize(a, b)
|
||||
c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
|
||||
np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class TestRawDiskBuffer(unittest.TestCase):
|
|||
_test_bitcasted(t, dtypes.uint32, 0x40490FDB)
|
||||
# doesn't suport normal cast
|
||||
with self.assertRaises(NotImplementedError):
|
||||
Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16).realize()
|
||||
Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16).to(None).realize()
|
||||
|
||||
# Those two should be moved to test_dtype.py:test_shape_change_bitcast after bitcast works on non-disk
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
@ -264,18 +264,20 @@ class TestDiskTensor(TempDirTestCase):
|
|||
def test_strided_read(self):
|
||||
# test non-contiguous (strided) read - should read elements at indices 0, 2, 4
|
||||
dt = Tensor([0, 1, 2, 3, 4, 5]).to(f"disk:{self.tmp('dt_strided_read')}")
|
||||
result = dt[::2].tolist()
|
||||
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [0, 2, 4]
|
||||
# self.assertEqual(result, [0, 2, 4])
|
||||
self.assertEqual(result, [0, 1, 2]) # wrong!
|
||||
with self.assertRaises(RuntimeError):
|
||||
result = dt[::2].tolist()
|
||||
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [0, 2, 4]
|
||||
# self.assertEqual(result, [0, 2, 4])
|
||||
self.assertEqual(result, [0, 1, 2]) # wrong!
|
||||
|
||||
def test_permuted_read(self):
|
||||
# test non-contiguous (permuted) read - should read transposed
|
||||
dt = Tensor([[0, 1, 2], [3, 4, 5]]).to(f"disk:{self.tmp('dt_permuted_read')}")
|
||||
result = dt.T.tolist()
|
||||
# TODO: transpose should give [[0, 3], [1, 4], [2, 5]]
|
||||
# self.assertEqual(result, [[0, 3], [1, 4], [2, 5]])
|
||||
self.assertEqual(result, [[0, 1], [2, 3], [4, 5]]) # wrong!
|
||||
with self.assertRaises(RuntimeError):
|
||||
result = dt.T.tolist()
|
||||
# TODO: transpose should give [[0, 3], [1, 4], [2, 5]]
|
||||
# self.assertEqual(result, [[0, 3], [1, 4], [2, 5]])
|
||||
self.assertEqual(result, [[0, 1], [2, 3], [4, 5]]) # wrong!
|
||||
|
||||
def test_write_ones(self):
|
||||
out = Tensor.ones(10, 10, device="CPU").contiguous()
|
||||
|
|
@ -303,10 +305,11 @@ class TestDiskTensor(TempDirTestCase):
|
|||
def test_strided_setitem(self):
|
||||
# test non-contiguous (strided) setitem - should set elements at indices 0, 2, 4
|
||||
dt = Tensor([1, 2, 3, 4, 5, 6]).to(f"disk:{self.tmp('dt_strided_setitem')}")
|
||||
dt[::2] = Tensor([10, 20, 30])
|
||||
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [10, 2, 20, 4, 30, 6]
|
||||
# self.assertEqual(dt.tolist(), [10, 2, 20, 4, 30, 6])
|
||||
self.assertEqual(dt.tolist(), [10, 20, 30, 4, 5, 6]) # wrong!
|
||||
with self.assertRaises(RuntimeError):
|
||||
dt[::2] = Tensor([10, 20, 30])
|
||||
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [10, 2, 20, 4, 30, 6]
|
||||
# self.assertEqual(dt.tolist(), [10, 2, 20, 4, 30, 6])
|
||||
self.assertEqual(dt.tolist(), [10, 20, 30, 4, 5, 6]) # wrong!
|
||||
|
||||
def test_advanced_setitem_not_supported(self):
|
||||
dt = Tensor.arange(12).reshape(3, 4).to(f"disk:{self.tmp('dt_advanced_setitem')}")
|
||||
|
|
|
|||
|
|
@ -1,62 +1,38 @@
|
|||
import os, unittest, ctypes
|
||||
import os, unittest
|
||||
from tinygrad import dtypes, Tensor, fetch, Device
|
||||
from tinygrad.nn.state import ggml_data_to_tensor, gguf_load
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
import ggml
|
||||
from gguf import GGUFReader, GGUFValueType, GGMLQuantizationType, GGML_QUANT_SIZES, dequantize, quantize
|
||||
|
||||
ggml_test_block_count = 4
|
||||
ggml_type_to_np_dtype = {
|
||||
ggml.GGML_TYPE_F16: np.float16, ggml.GGML_TYPE_F32:np.float32, ggml.GGML_TYPE_F64:np.float64,
|
||||
ggml.GGML_TYPE_I8:np.int8, ggml.GGML_TYPE_I16: np.int16, ggml.GGML_TYPE_I32: np.int32, ggml.GGML_TYPE_I64: np.int64,
|
||||
}
|
||||
np_dtype_to_ctype = { np.float16: ctypes.c_uint16 }
|
||||
gguf_val_getters = [
|
||||
ggml.gguf_get_val_u8, ggml.gguf_get_val_i8, ggml.gguf_get_val_u16, ggml.gguf_get_val_i16,
|
||||
ggml.gguf_get_val_u32, ggml.gguf_get_val_i32, ggml.gguf_get_val_f32, ggml.gguf_get_val_bool,
|
||||
lambda *args: ggml.gguf_get_val_str(*args).decode("utf-8"), None,
|
||||
ggml.gguf_get_val_u64, ggml.gguf_get_val_i64, ggml.gguf_get_val_f64,
|
||||
]
|
||||
|
||||
def ggml_tensor_to_numpy(tensor: ggml.ggml_tensor_p):
|
||||
ctx: ggml.ggml_context_p | None = None
|
||||
ggml_type, n_dims, n_els = tensor.contents.type, ggml.ggml_n_dims(tensor), ggml.ggml_nelements(tensor)
|
||||
shape = tuple(reversed(tensor.contents.ne[:n_dims]))
|
||||
if ggml_type not in ggml_type_to_np_dtype:
|
||||
ctx = ggml.ggml_init(ggml.ggml_init_params(mem_size=n_els * 5 + 500, mem_buffer=None))
|
||||
ntensor = ggml.ggml_new_tensor(ctx, ggml.GGML_TYPE_F32, n_dims, tensor.contents.ne)
|
||||
type_traits = ggml.ggml_internal_get_type_traits(ggml_type)
|
||||
type_traits.to_float(ggml.ggml_get_data(tensor), ggml.ggml_get_data_f32(ntensor), n_els)
|
||||
tensor, ggml_type = ntensor, ggml.GGML_TYPE_F32
|
||||
|
||||
np_type = ggml_type_to_np_dtype[ggml_type]
|
||||
ctypes_type = np_dtype_to_ctype.get(np_type, None) or np.ctypeslib.as_ctypes_type(np_type)
|
||||
data = ggml.ggml_get_data(tensor)
|
||||
if data is None: raise ValueError("tensor data is None")
|
||||
arr = (ctypes_type * ggml.ggml_nelements(tensor)).from_address(data)
|
||||
strides = tuple(reversed(tensor.contents.nb[:n_dims]))
|
||||
output = np.ctypeslib.as_array(arr)
|
||||
output.dtype = np_type
|
||||
return np.lib.stride_tricks.as_strided(output, shape=shape, strides=strides), ctx
|
||||
|
||||
@unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half")
|
||||
class TestGGUF(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
params = ggml.ggml_init_params(mem_size=0, mem_buffer=None, no_alloc=False)
|
||||
self.ctx = ctypes.cast(ggml.ggml_init(params), ctypes.POINTER(ctypes.c_void_p))
|
||||
def tearDown(self) -> None: ggml.ggml_free(self.ctx)
|
||||
|
||||
def test_load_tinyllama_q8_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q8_0.gguf?download=true")
|
||||
def test_load_tinyllama_q4_0(self): self._test_gguf_load("https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf?download=true")
|
||||
def test_load_gpt2_q4_1(self): self._test_gguf_load("https://huggingface.co/PrunaAI/gpt2-GGUF-smashed/resolve/main/gpt2.Q4_1.gguf?download=true")
|
||||
def test_load_sample_q6_k(self): self._test_gguf_load("https://huggingface.co/Isotr0py/test-gguf-sample/resolve/main/Quant_Q6_K_1024.gguf?download=true")
|
||||
def test_load_sample_mxfp4(self): self._test_gguf_load("https://huggingface.co/ngxson/boring-testing-tiny/resolve/main/stories260K-mxfp4.gguf?download=true")
|
||||
|
||||
def test_dequantization_q4_0(self): self._test_dequantization(ggml.GGML_TYPE_Q4_0)
|
||||
def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1)
|
||||
def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0)
|
||||
def test_dequantization_q4_k(self): self._test_dequantization(ggml.GGML_TYPE_Q4_K)
|
||||
def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K)
|
||||
def test_dequantization_q8_0_hardcoded(self):
|
||||
# Q8_0: 2 bytes float16 scale + 32 bytes int8 values, dequant = scale * values
|
||||
block = np.frombuffer(np.float16(2.0).tobytes() + np.arange(1, 33, dtype=np.int8).tobytes(), dtype=np.uint8).copy()
|
||||
expected = np.arange(1, 33, dtype=np.float32) * 2.0
|
||||
np.testing.assert_equal(ggml_data_to_tensor(Tensor(block), 32, GGMLQuantizationType.Q8_0.value).numpy().flatten(), expected)
|
||||
|
||||
def test_dequantization_mxfp4_hardcoded(self):
|
||||
# MXFP4: 1 byte shared exponent E + 16 packed bytes (32 x 4-bit values)
|
||||
# nibble: bit3=sign, bit2:1=exp, bit0=mant; E=128 gives scale=1.0
|
||||
# codes 0-7 = [0, 1, 2, 3, 4, 6, 8, 12], codes 8-15 are their negatives
|
||||
block = np.array([0x80] + list(range(16)), dtype=np.uint8) # E=128, nibbles 0-15 in low, zeros in high
|
||||
expected = np.array([0., 1., 2., 3., 4., 6., 8., 12., -0., -1., -2., -3., -4., -6., -8., -12.] + [0.]*16, dtype=np.float32)
|
||||
np.testing.assert_equal(ggml_data_to_tensor(Tensor(block), 32, 39).numpy().flatten(), expected)
|
||||
|
||||
def test_dequantization_q4_0(self): self._test_dequantization(GGMLQuantizationType.Q4_0)
|
||||
def test_dequantization_q4_1(self): self._test_dequantization(GGMLQuantizationType.Q4_1)
|
||||
def test_dequantization_q8_0(self): self._test_dequantization(GGMLQuantizationType.Q8_0)
|
||||
def test_dequantization_q4_k(self): self._test_dequantization(GGMLQuantizationType.Q4_K)
|
||||
def test_dequantization_q6_k(self): self._test_dequantization(GGMLQuantizationType.Q6_K)
|
||||
def test_dequantization_mxfp4(self):
|
||||
MXFP4 = 39
|
||||
|
||||
|
|
@ -68,7 +44,7 @@ class TestGGUF(unittest.TestCase):
|
|||
sign = -1.0 if (code & 0b1000) else 1.0
|
||||
exp = (code >> 1) & 0b11
|
||||
mant = code & 0b1
|
||||
val = (1.0 + 0.5 * mant) * np.exp2(exp - 1) if exp else 0.5 * mant
|
||||
val = 2 * ((1.0 + 0.5 * mant) * np.exp2(exp - 1) if exp else 0.5 * mant)
|
||||
scale = np.exp2(E - 128) if E >= 2 else np.exp2(-127 if E == 1 else -128)
|
||||
return sign * val * scale
|
||||
|
||||
|
|
@ -84,24 +60,44 @@ class TestGGUF(unittest.TestCase):
|
|||
# TODO: should this be exact equal? somehow failed on CI
|
||||
np.testing.assert_allclose(out.numpy(), expected, atol=0.0, rtol=1e-6)
|
||||
|
||||
def test_dequantization_mxfp4_block(self):
|
||||
MXFP4 = 39
|
||||
# https://gist.github.com/Ananta-Ranganathan/3317b6ed51a3b033e9c2564fafb4e043
|
||||
# used the above script to download the first block of blk.0.attn_k_b.weight from
|
||||
# https://huggingface.co/unsloth/GLM-4.7-Flash-GGUF/blob/main/GLM-4.7-Flash-MXFP4_MOE.gguf
|
||||
# and compute the canonical expected dequantized output with the GGUF PY implementation
|
||||
block = np.array([0x7a, 0x29, 0xab, 0x61, 0x10, 0x21, 0x02, 0x4a,
|
||||
0x15, 0xca, 0x05, 0x01, 0x9b, 0x39, 0x0b, 0x0b, 0x1c], dtype=np.uint8)
|
||||
expected = np.array([-0.01562500, -0.04687500, 0.01562500, 0.00000000,
|
||||
0.01562500, 0.03125000, -0.03125000, 0.09375000,
|
||||
-0.03125000, 0.09375000, 0.01562500, -0.04687500,
|
||||
-0.01562500, -0.04687500, -0.04687500, -0.06250000,
|
||||
0.03125000, -0.03125000, 0.12500000, 0.01562500,
|
||||
0.03125000, 0.00000000, 0.06250000, 0.01562500,
|
||||
-0.06250000, 0.00000000, 0.00000000, -0.01562500,
|
||||
0.04687500, 0.00000000, 0.00000000, 0.01562500], dtype=np.float32)
|
||||
out = ggml_data_to_tensor(Tensor(block), 32, MXFP4)
|
||||
# TODO: similar to previous test fails on Mac CI with assert_equal for unclear reason
|
||||
np.testing.assert_allclose(out.numpy(), expected, atol=0.0, rtol=1e-6)
|
||||
|
||||
def test_expected_failure_unknown_type(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ggml_data_to_tensor(Tensor.empty(512, dtype=dtypes.uint8), 256, 1337)
|
||||
|
||||
def _test_dequantization(self, ttype: int):
|
||||
type_traits = ggml.ggml_internal_get_type_traits(ttype)
|
||||
n_el, n_bytes = ggml_test_block_count * type_traits.blck_size, ggml_test_block_count * type_traits.type_size
|
||||
def _test_dequantization(self, qtype: GGMLQuantizationType):
|
||||
block_size, type_size = GGML_QUANT_SIZES[qtype]
|
||||
n_el, n_bytes = ggml_test_block_count * block_size, ggml_test_block_count * type_size
|
||||
|
||||
data_in = (np.random.random((n_el,)).astype(np.float32) * 100 - 50).ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||
try:
|
||||
q_data = quantize((np.random.random((n_el,)).astype(np.float32) * 100 - 50), qtype)
|
||||
except NotImplementedError:
|
||||
q_data = np.random.default_rng(42).integers(0, 256, size=n_bytes, dtype=np.uint8)
|
||||
ref = dequantize(q_data, qtype)
|
||||
|
||||
c_q_data, c_dq_data = (ctypes.c_char * n_bytes)(0), (ctypes.c_float * n_el)(0)
|
||||
type_traits.from_float(data_in, c_q_data, n_el)
|
||||
type_traits.to_float(c_q_data, c_dq_data, n_el)
|
||||
q_tensor = Tensor(q_data)
|
||||
dq_tensor = ggml_data_to_tensor(q_tensor, n_el, qtype.value).reshape(n_el)
|
||||
|
||||
q_tensor = Tensor(np.frombuffer(c_q_data, dtype=np.uint8, count=n_bytes))
|
||||
dq_tensor = ggml_data_to_tensor(q_tensor, n_el, ttype).reshape(n_el)
|
||||
|
||||
np.testing.assert_equal(dq_tensor.numpy(), np.frombuffer(c_dq_data, dtype=np.float32))
|
||||
np.testing.assert_equal(dq_tensor.numpy(), ref)
|
||||
|
||||
def _test_gguf_load(self, url: str):
|
||||
fp = fetch(url)
|
||||
|
|
@ -109,24 +105,20 @@ class TestGGUF(unittest.TestCase):
|
|||
gguf_tensor = Tensor.empty(model_size, dtype=dtypes.uint8, device=f"disk:{fp}").to(Device.DEFAULT)
|
||||
kv_data, tensors = gguf_load(gguf_tensor)
|
||||
|
||||
gguf_params = ggml.gguf_init_params(ctx=self.ctx, no_alloc=False)
|
||||
gguf_ctx = ggml.gguf_init_from_file(str(fp).encode("utf8"), gguf_params)
|
||||
param_ctx = gguf_params.ctx.contents.value
|
||||
reader = GGUFReader(fp)
|
||||
|
||||
for ggml_tensor_idx in range(ggml.gguf_get_n_tensors(gguf_ctx)):
|
||||
tensor_name = ggml.gguf_get_tensor_name(gguf_ctx, ggml_tensor_idx)
|
||||
ggml_tensor = ggml.ggml_get_tensor(param_ctx, tensor_name)
|
||||
ggml_tensor_numpy, temp_ctx = ggml_tensor_to_numpy(ggml_tensor)
|
||||
tensor = tensors.get(tensor_name.decode("utf-8"))
|
||||
np.testing.assert_equal(tensor.numpy(), ggml_tensor_numpy)
|
||||
if temp_ctx is not None: ggml.ggml_free(temp_ctx)
|
||||
for rt in reader.tensors:
|
||||
ref = dequantize(rt.data, rt.tensor_type)
|
||||
np.testing.assert_equal(tensors[rt.name].numpy(), ref.reshape(tensors[rt.name].shape))
|
||||
|
||||
for gguf_key_id in range(ggml.gguf_get_n_kv(gguf_ctx)):
|
||||
v = kv_data[ggml.gguf_get_key(gguf_ctx, gguf_key_id).decode("utf-8")]
|
||||
v_type = ggml.gguf_get_kv_type(gguf_ctx, gguf_key_id)
|
||||
if (get_fn := gguf_val_getters[v_type]) is not None: self.assertEqual(get_fn(gguf_ctx, gguf_key_id), v)
|
||||
|
||||
ggml.gguf_free(gguf_ctx)
|
||||
for k, f in reader.fields.items():
|
||||
if k.startswith("GGUF."): continue # skip file header keys (version, tensor_count, kv_count)
|
||||
def read_val(i, parts=f.parts, is_str=(f.types[-1] == GGUFValueType.STRING)):
|
||||
return bytes(parts[i]).decode("utf-8") if is_str else parts[i][0].item()
|
||||
if f.types[0] == GGUFValueType.ARRAY:
|
||||
self.assertEqual(kv_data[k], [read_val(i) for i in f.data])
|
||||
else:
|
||||
self.assertEqual(kv_data[k], read_val(-1))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -28,16 +28,7 @@ class TestRealizeIsRealized(unittest.TestCase):
|
|||
t = Tensor.ones(8).contiguous().shard((d, d), axis=0).realize()
|
||||
assert all(u.is_realized for u in t.uop.src)
|
||||
|
||||
# TODO: these are not realized after .realize() because they stay as consts / don't allocate buffers
|
||||
def test_const_not_realized(self):
|
||||
t = Tensor(3.14).realize()
|
||||
assert not t.uop.is_realized
|
||||
|
||||
def test_ones_not_realized(self):
|
||||
t = Tensor.ones(4, 4).realize()
|
||||
assert not t.uop.is_realized
|
||||
|
||||
def test_empty_not_realized(self):
|
||||
def test_empty(self):
|
||||
t = Tensor.empty(4, 4).realize()
|
||||
assert not t.uop.is_realized
|
||||
|
||||
|
|
@ -48,6 +39,22 @@ class TestRealizeIsRealized(unittest.TestCase):
|
|||
t = Tensor.empty(4, dtype=dtypes.float32, device=f"disk:{f.name}").realize()
|
||||
assert not t.uop.is_realized
|
||||
|
||||
def test_assign(self):
|
||||
t = Tensor([1, 2, 3])
|
||||
t += 1
|
||||
t.realize()
|
||||
assert t.uop.is_realized
|
||||
|
||||
# TODO: these are not realized after .realize()
|
||||
|
||||
def test_const_not_realized(self):
|
||||
t = Tensor(3.14).realize()
|
||||
assert not t.uop.is_realized
|
||||
|
||||
def test_ones_not_realized(self):
|
||||
t = Tensor.ones(4, 4).realize()
|
||||
assert not t.uop.is_realized
|
||||
|
||||
def test_none_not_realized(self):
|
||||
t = Tensor(None).realize()
|
||||
assert not t.uop.is_realized
|
||||
|
|
|
|||
28
test/unit/test_system_pci_scan_bus.py
Normal file
28
test/unit/test_system_pci_scan_bus.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import sys
|
||||
import pytest
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "linux", reason="uses linux sysfs layout")
|
||||
def test_pci_scan_bus_filters_vendor(monkeypatch):
|
||||
import tinygrad.runtime.support.system as system
|
||||
|
||||
fake = {
|
||||
"/sys/bus/pci/devices/0000:00:01.0/vendor": "0x1234",
|
||||
"/sys/bus/pci/devices/0000:00:01.0/device": "0x1111",
|
||||
"/sys/bus/pci/devices/0000:00:02.0/vendor": "0xabcd",
|
||||
"/sys/bus/pci/devices/0000:00:02.0/device": "0x1111",
|
||||
}
|
||||
|
||||
class FakeFileIOInterface:
|
||||
def __init__(self, path, *args, **kwargs):
|
||||
self.path = path
|
||||
|
||||
def listdir(self):
|
||||
assert self.path == "/sys/bus/pci/devices"
|
||||
return ["0000:00:01.0", "0000:00:02.0"]
|
||||
|
||||
def read(self, *args, **kwargs):
|
||||
return fake[self.path]
|
||||
|
||||
monkeypatch.setattr(system, "FileIOInterface", FakeFileIOInterface)
|
||||
|
||||
assert system.System.pci_scan_bus(0x1234, devices=[(0xffff, [0x1111])]) == ["0000:00:01.0"]
|
||||
|
|
@ -340,6 +340,10 @@ if __name__ == "__main__":
|
|||
# do benchmark
|
||||
if args.benchmark:
|
||||
param_bytes = sum(x.nbytes() for x in nn.state.get_parameters(model))
|
||||
for b in model.blk:
|
||||
if hasattr(b, 'ffn_gate_exps'):
|
||||
expert_bytes = b.ffn_gate_exps.weight.nbytes() + b.ffn_up_exps.weight.nbytes() + b.ffn_down_exps.weight.nbytes()
|
||||
param_bytes -= int(expert_bytes * (1 - b.num_experts_per_tok / b.ffn_gate_exps.weight.shape[0]))
|
||||
gen = model.generate([0], 0)
|
||||
for _ in range(args.benchmark):
|
||||
GlobalCounters.reset()
|
||||
|
|
|
|||
|
|
@ -48,10 +48,9 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
|
|||
elif (a:=len(limited)) > (b:=len(dims)):
|
||||
if a == 2 and b == 1: return [raw_idxs[0] * limited[1] + raw_idxs[1]]
|
||||
if a == 3 and b == 1: return [(raw_idxs[0] * limited[1] + raw_idxs[1]) * limited[2] + raw_idxs[2]]
|
||||
if a == 3 and b == 2: return [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
|
||||
elif limited != dims:
|
||||
if limited != dims:
|
||||
# Convert to 1D
|
||||
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
|
||||
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(limited) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
|
||||
# Get back original indices from 1D
|
||||
return [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
|
||||
return raw_idxs
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|||
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element
|
||||
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate
|
||||
from tinygrad.helpers import getenv, flatten, AMX, CPU_X86, prod, ceildiv, IMAGE
|
||||
from tinygrad.helpers import getenv, flatten, AMX, CPU_X86, prod, IMAGE
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
|
|
@ -190,9 +190,9 @@ def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]:
|
|||
buf = idx.src[0]
|
||||
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
|
||||
h, w = dt.shape[0], dt.shape[1]
|
||||
if IMAGE == 1 and valid is not None and (tp:=dt.size // 4) // 64:
|
||||
h, w = max(([(1, tp)] * (tp < 16384)) + [(tp//64//k, 64*k) for k in range(ceildiv(tp//64, 16384), min(tp//64, 256)+1) if (tp//64) % k == 0],
|
||||
key=lambda hw: len(_drop_valid_stmts(valid, UOp.vectorize((x//4)%hw[1], x//(4*hw[1])), *hw)))
|
||||
if IMAGE == 1 and valid is not None:
|
||||
h, w = max(ImageDType.valid_dims(dt),
|
||||
key=lambda hw: len(_drop_valid_stmts(valid, uop_given_valid(valid, UOp.vectorize((x//4)%hw[1], x//(4*hw[1]))), *hw)))
|
||||
buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4), w * 4 * dt.itemsize))
|
||||
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % w, (x // (4*w))))
|
||||
return x, idx.replace(src=(buf, oidx.valid(valid))), w, h
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
|
|||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
from tinygrad.helpers import ALLOW_TF32, count, Context, ceildiv
|
||||
from tinygrad.helpers import ALLOW_TF32, count, Context
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.renderer import Renderer
|
||||
|
|
@ -353,28 +353,21 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp:
|
|||
k = hand_coded_optimizations(k)
|
||||
return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
|
||||
|
||||
# max image width (pixels): 16384. max image size: 4 * 16384 ** 2
|
||||
def _image_shape(dt):
|
||||
if dt.base not in (dtypes.half, dtypes.float) or isinstance(dt, ImageDType) or dt.size > 4*16384*16384 or dt.nbytes()%64 != 0: return None
|
||||
if dt.size <= 4 * 16384: return (1, dt.size // 4, 4)
|
||||
if (pxls:=dt.size // 4) % 64: return None
|
||||
# verify that a valid format exists
|
||||
try: return next((pxls // 64 // k, 64 * k, 4) for k in range(ceildiv(pxls // 64, 16384), min(pxls // 64, 256)+1))
|
||||
except StopIteration: return None
|
||||
|
||||
def make_image(pa, off, idx):
|
||||
if (idx.tag is None or idx.tag) and (shape:=_image_shape(dt:=pa.dtype)):
|
||||
new_idx = idx.replace(src=(pa.replace(dtype=(dtypes.imageh if dt.base==dtypes.half else dtypes.imagef)(shape, shape[1] * 4 * dt.itemsize)), off),
|
||||
dtype=dtypes.float if dt.base == dtypes.half else idx.dtype)
|
||||
if not isinstance(dt:=pa.dtype, ImageDType) and (idx.tag is None or idx.tag) and (shapes:=ImageDType.valid_dims(dt)):
|
||||
new_pa = pa.replace(dtype=(dtypes.imageh if dt.base==dtypes.half else dtypes.imagef)(shapes[0] + (4,), shapes[0][1] * 4 * dt.itemsize))
|
||||
new_idx = idx.replace(src=(new_pa, off), dtype=dtypes.float if dt.base == dtypes.half else idx.dtype)
|
||||
return new_idx if idx.tag or dt.base == dtypes.float else new_idx.cast(dtypes.half)
|
||||
|
||||
pm_make_images = PatternMatcher([
|
||||
# ensure we dont create an unfoldable image store
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"),), allow_any_len=True, name="st"), lambda idx,st:
|
||||
st.replace(src=(idx.rtag(is_image:=any(c.op is Ops.RANGE and (c.vmax+1)%4 == 0 for c in idx.src[1].get_idx().split_uop(Ops.ADD))),
|
||||
st.src[1].cast(dtypes.float if is_image and _image_shape(idx.src[0].dtype) else idx.dtype.base)))),
|
||||
st.src[1].cast(dtypes.float if is_image and ImageDType.valid_dims(idx.src[0].dtype) else idx.dtype.base)))),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PARAM, name="pa"), UPat.var("off")), name="idx"), make_image),
|
||||
# remove double cast from image loads
|
||||
# remove double cast from image loads / stores
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PARAM, name="pa"),), allow_any_len=True, name="idx").cast(dtypes.half).cast(dtypes.float), lambda idx,pa:
|
||||
idx if isinstance(pa.dtype, ImageDType) else None),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.PARAM, name="pa").index(UPat()), UPat.var("val").cast(dtypes.half).cast(dtypes.float)), name="st"), lambda st,pa,val:
|
||||
st.replace(src=(st.src[0], val)) if isinstance(pa.dtype, ImageDType) else None),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -283,10 +283,10 @@ class CompilerSet: cset:list[tuple[type[Renderer]|functools.partial, ContextVar|
|
|||
class Compiled:
|
||||
profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.
|
||||
|
||||
def __init__(self, device:str, allocator:Allocator, compilers:CompilerSet|None, runtime, graph=None, group_id=None):
|
||||
def __init__(self, device:str, allocator:Allocator, compilers:CompilerSet|None, runtime, graph=None):
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
self.device, self.allocator, self.runtime, self.graph, self.group_id = device, allocator, runtime, graph, group_id
|
||||
self.device, self.allocator, self.runtime, self.graph = device, allocator, runtime, graph
|
||||
|
||||
self.comps_ctrl_var = compilers.ctrl_var if compilers is not None else None
|
||||
self.comp_sets:dict[str, tuple[ContextVar|None, type[Renderer]|functools.partial]] = {}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
from typing import Final, ClassVar, Callable, Literal
|
||||
import math, struct, ctypes, functools
|
||||
from dataclasses import dataclass, fields
|
||||
from tinygrad.helpers import getenv, prod, round_up, next_power2, OSX
|
||||
from tinygrad.helpers import ceildiv, getenv, prod, round_up, next_power2, OSX
|
||||
from enum import Enum, auto
|
||||
|
||||
class ConstFloat(float):
|
||||
|
|
@ -121,13 +121,25 @@ class ImageDType(PtrDType):
|
|||
if self._pitch != -1: return self._pitch
|
||||
imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize))
|
||||
if OSX: return round_up(imgw, 256) * 4 * self.itemsize
|
||||
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
|
||||
# needs to be IMAGE_PITCH_ALIGN=256 for AMD
|
||||
min_pitchalign = int(math.log2(v)) if (v := getenv("IMAGE_PITCH_ALIGN", 0)) > 0 else 6
|
||||
pitchalign = max(min_pitchalign, 11 - int(math.log2(imgh))) if imgh > 1 else min_pitchalign
|
||||
align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
|
||||
|
||||
granularity = 128 if self.itemsize == 4 else 256
|
||||
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
|
||||
return round_up(imgw * 4 * self.itemsize, 1 << pitchalign) + pitch_add
|
||||
|
||||
# get list of (height, width) that do not require pitch padding
|
||||
@staticmethod
|
||||
def valid_dims(ptr:PtrDType) -> list[tuple[int,int]]:
|
||||
ALIGN, MAXW = getenv("IMAGE_PITCH_ALIGN", 256 if OSX else 64), 16384
|
||||
if ptr.base not in (dtypes.half, dtypes.float) or ptr.size > 4*MAXW*MAXW or (ptr.size if OSX else ptr.nbytes()) % ALIGN != 0: return []
|
||||
if OSX and (ptr.size // 4) % ALIGN: return [] # OSX has stricter requirements for height=1 images
|
||||
pxls: int = ptr.size // 4
|
||||
return ([(1, pxls)] * (pxls < MAXW) + [(pxls//ALIGN//k, ALIGN*k) for k in range(ceildiv(pxls//ALIGN, MAXW), min(pxls//ALIGN, MAXW//ALIGN)+1)
|
||||
if (pxls//ALIGN)%k == 0] if pxls//ALIGN else [])
|
||||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
|
|
|
|||
129
tinygrad/engine/allocations.py
Normal file
129
tinygrad/engine/allocations.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import prod, DEBUG, argsort, VIZ
|
||||
|
||||
def tag_uop(ctx:tuple[list[UOp], dict[UOp, UOp], set[UOp]], x:UOp):
|
||||
if x.tag is not None: return None
|
||||
ctx[0].append(x)
|
||||
return x.replace(tag=(len(ctx[0])-1,))
|
||||
|
||||
def disk_copy_is_buffer(ctx, u):
|
||||
# copies to disk are replaced with the disk buffer
|
||||
to_disk = isinstance(u._device, str) and u._device.startswith("DISK")
|
||||
if to_disk: ctx[1][u] = UOp.new_buffer(u.device, u.shard_size, u.dtype).reshape(u.max_shard_shape)
|
||||
# all copies from disk/numpy are realized into a real buffer
|
||||
from_creation = isinstance(u.src[0]._device, str) and any(u.src[0]._device.startswith(x) for x in ["NPY", "DISK", "PYTHON"])
|
||||
if from_creation: return tag_uop(ctx, u)
|
||||
|
||||
def apply_after(ctx, u):
|
||||
ctx[1][u] = u.src[0]
|
||||
|
||||
# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
|
||||
add_tags = PatternMatcher([
|
||||
(UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
|
||||
# no tag on copies that are assigned
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.COPY, name="c")), name="a"),
|
||||
lambda a,c: a.replace(src=(a.src[0], c.rtag(())), tag=a.tag+c.tag) if a.tag and c.tag else None),
|
||||
(UPat(Ops.AFTER, name="u"), apply_after),
|
||||
(UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),
|
||||
(UPat(GroupOp.All, name="x"), lambda ctx,x: tag_uop(ctx,x) if x in ctx[2] else None),
|
||||
])
|
||||
|
||||
def replace_contig_with_assign(u:UOp):
|
||||
# if size is 0, remove the contig
|
||||
if u.size == 0: return u.src[0]
|
||||
# no real contig for DISK tensors, they are left alone
|
||||
if isinstance(u._device, str) and u._device.startswith("DISK"): return u.rtag(None)
|
||||
dtype = u.dtype
|
||||
if isinstance(dtype, ImageDType):
|
||||
if prod(dtype.shape) != prod(u.max_shard_shape) or ([x for x in u.max_shard_shape if x != 1] or [1])[-1] % 4 != 0:
|
||||
if DEBUG >= 1: print(f"demoting Image {dtype} with shape {u.max_shard_shape}")
|
||||
dtype = dtype.base
|
||||
buffer = UOp.new_buffer(u.device, u.shard_size, dtype).reshape(u.max_shard_shape)
|
||||
if isinstance(u.device, tuple) and u.axis is not None: buffer = buffer.multi(u.axis)
|
||||
return buffer.assign(u.src[0]).rtag(u.tag)
|
||||
|
||||
def replace_assign_with_contig(u:UOp):
|
||||
assigned_to = u
|
||||
while assigned_to.op in {Ops.ASSIGN, Ops.BITCAST}: assigned_to = assigned_to.src[0].base
|
||||
if assigned_to.op is not Ops.BUFFER:
|
||||
return u.src[1].contiguous(tag=u.tag)
|
||||
|
||||
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
|
||||
x = src
|
||||
while x is not src.base:
|
||||
if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg))
|
||||
elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape)
|
||||
else: return None
|
||||
x = x.src[0]
|
||||
ctx[src.base] = contig
|
||||
|
||||
pm_early_transform_tensor_graph = PatternMatcher([
|
||||
# CONTIGUOUS replacement hack for openpilot
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
|
||||
# replace ALU sources with contiguous versions found above
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
# add CONTIGUOUS to tagged UOps
|
||||
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
|
||||
# remove extra CONTIGUOUS on ASSIGN
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"), lambda a,c: a.replace(tag=a.tag+c.tag)),
|
||||
# replace ASSIGN with CONTIGUOUS
|
||||
(UPat(Ops.ASSIGN, name="u"), replace_assign_with_contig),
|
||||
# replace CONTIGUOUS with ASSIGNs
|
||||
(UPat(Ops.CONTIGUOUS, name="u"), replace_contig_with_assign),
|
||||
# remove DETACH/CONTIGUOUS_BACKWARD
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
# reduce of size 0 is the identity element
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
# handle size 0
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
|
||||
# early fixup const copy (TODO: is this wrong if there's a pad?)
|
||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
|
||||
])
|
||||
|
||||
def untag_and_append(ctx:tuple[list[UOp], dict[UOp, UOp], list[UOp]], x:UOp):
|
||||
if x.tag is None: return None
|
||||
uop_list, buffer_map, assigns = ctx
|
||||
ret = x.replace(tag=None)
|
||||
for t in x.tag:
|
||||
original_uop: UOp = uop_list[t]
|
||||
replace_uop = ret
|
||||
while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0]
|
||||
buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
|
||||
assigns.append(ret)
|
||||
return ret
|
||||
|
||||
def append_after(ctx:tuple[list[UOp], dict[UOp, UOp], list[UOp]], x:UOp):
|
||||
ctx[2].append(x)
|
||||
|
||||
pm_finalize_call = PatternMatcher([
|
||||
(UPat(Ops.ASSIGN, name="x"), untag_and_append),
|
||||
(UPat(Ops.AFTER, name="x"), append_after),
|
||||
(UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
|
||||
# replace UNIQUE with LUNIQUE for CONST cache key normalization
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
|
||||
])
|
||||
|
||||
def allocate_global_buffers(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
|
||||
# uop list is a list in the original_sink graph and we can map to the tags later
|
||||
# here we build buffer map
|
||||
uop_list: list[UOp] = []
|
||||
buffer_map: dict[UOp, UOp] = {}
|
||||
|
||||
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
|
||||
bases = set([x.multibase for x in big_sink.src if x.base.op not in dont_realize])
|
||||
|
||||
# this rewrite is "read-only", it adds simple things to buffer_map and may sink things on big_sink, bottom_up
|
||||
# this is the only one where we have to be careful to not break the tensor graph
|
||||
big_sink = graph_rewrite(big_sink, add_tags, ctx=(uop_list, buffer_map, bases), bottom_up=True, name="number the uops")
|
||||
|
||||
# here we can break the tensor graph. this is the only place you need to maintain numbered tags
|
||||
big_sink = graph_rewrite(big_sink, pm_early_transform_tensor_graph, ctx={}, name="early transform tensor graph")
|
||||
|
||||
# here we construct the final buffer_map. this is everything that will go into the tensor map
|
||||
assigns: list[UOp] = []
|
||||
graph_rewrite(big_sink, pm_finalize_call, ctx=(uop_list, buffer_map, assigns), name="finalize call")
|
||||
ret = UOp.sink(*assigns)
|
||||
if VIZ: graph_rewrite(ret, PatternMatcher([]), name="*** Call")
|
||||
return ret, buffer_map
|
||||
|
|
@ -348,6 +348,8 @@ class TinyJit(Generic[ReturnType]):
|
|||
update_depends(depends, jit_cache)
|
||||
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
|
||||
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
|
||||
# sync before re-executing onetime kernels
|
||||
for dev in set(Device[b.device] for ei in onetime for b in ei.bufs if b is not None): dev.synchronize()
|
||||
# run the onetime kernels here
|
||||
for ei in onetime:
|
||||
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
||||
|
|
|
|||
|
|
@ -1,17 +1,15 @@
|
|||
import time
|
||||
import time, sys
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.allocations import allocate_global_buffers
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
# ScheduleItem = tuple[AST, buffer UOps, metadata, bound_ranges]
|
||||
ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], tuple[UOp, ...]]
|
||||
|
||||
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
|
||||
def _unwrap_src(s: UOp) -> UOp:
|
||||
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
||||
|
|
@ -23,9 +21,8 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
children: dict[UOp, list[UOp]] = {}
|
||||
in_degree: dict[UOp, int] = {}
|
||||
for u in sched_sink.toposort(gate_kernel_sink):
|
||||
if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
|
||||
if u.op is not Ops.AFTER: continue
|
||||
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
|
||||
k = u.src[1]
|
||||
assert k.op in {Ops.CALL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
|
||||
in_degree.setdefault(k, 0)
|
||||
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
||||
|
|
@ -50,55 +47,24 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
|
||||
with cpu_profile(TracingKey("linearize schedule")):
|
||||
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
|
||||
|
||||
schedule: list[UOp] = [] # RANGE, KERNEL, or END UOps
|
||||
sched_item: dict[UOp, ScheduleItem] = {}
|
||||
pre_schedule: list[ExecItem] = []
|
||||
buf_uops_list: list[UOp] = []
|
||||
while len(queue):
|
||||
k = rk = queue.popleft()
|
||||
if k.op is Ops.END: k = k.src[0]
|
||||
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
|
||||
if k.op is Ops.RANGE: schedule.append(k)
|
||||
elif k.op is Ops.CALL:
|
||||
ast = k.src[0]
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
|
||||
schedule.append(k)
|
||||
if rk.op is Ops.END: schedule.append(rk)
|
||||
rk = queue.popleft()
|
||||
k = rk.src[0] if rk.op is Ops.END else rk
|
||||
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
pre_schedule.append(ExecItem(k.src[0], [], k.arg.metadata))
|
||||
buf_uops_list.append(UOp.sink(*buf_uops))
|
||||
for x in children.get(rk, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
with cpu_profile(TracingKey("unroll outer ranges")):
|
||||
pre_schedule, buf_uops_list = unroll_outer_ranges(schedule, sched_item)
|
||||
return pre_schedule, UOp.sink(*buf_uops_list)
|
||||
|
||||
def unroll_outer_ranges(schedule:list[UOp], sched_item:dict[UOp, ScheduleItem]) -> tuple[list[ExecItem], list[UOp]]:
|
||||
pre_schedule: list[ExecItem] = []
|
||||
buf_uops_list: list[UOp] = []
|
||||
sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]()
|
||||
while sched_ptr < len(schedule):
|
||||
si = schedule[sched_ptr]
|
||||
if si.op is Ops.RANGE:
|
||||
in_ranges[si] = 0
|
||||
range_ptrs[si] = sched_ptr + 1
|
||||
elif si.op is Ops.END:
|
||||
if in_ranges[si.src[1]] < si.src[1].vmax:
|
||||
in_ranges[si.src[1]] += 1
|
||||
sched_ptr = range_ptrs[si.src[1]]
|
||||
continue
|
||||
else:
|
||||
assert si.op is Ops.CALL, f"unexpected op in schedule: {si.op}"
|
||||
ast, buf_uops, metadata, bound_ranges = sched_item[si]
|
||||
fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
|
||||
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
|
||||
buf_uops_list.append(UOp.sink(*buf_uops))
|
||||
sched_ptr += 1
|
||||
return pre_schedule, buf_uops_list
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
from tinygrad.schedule.multi import get_multi_map
|
||||
from tinygrad.schedule.rangeify import get_rangeify
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
|
||||
def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
|
||||
if (ret:=ctx[0].get(b, None)) is None:
|
||||
|
|
@ -107,13 +73,6 @@ def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], li
|
|||
ctx[2][0] += 1
|
||||
return ret
|
||||
|
||||
def replace_input_const(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
|
||||
if (ret:=ctx[0].get(b, None)) is None:
|
||||
# replace UNIQUE with LUNIQUE for CONST cache key normalization
|
||||
ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=ctx[3][0]), b.src[1]))
|
||||
ctx[3][0] += 1
|
||||
return ret
|
||||
|
||||
def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
|
||||
var, val = b.src[0], b.src[1].arg
|
||||
assert var.expr not in ctx[1] or ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}"
|
||||
|
|
@ -123,8 +82,6 @@ def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]],
|
|||
pm_pre_sched_cache = PatternMatcher([
|
||||
# replace BUFFER with PARAM for cache key normalization
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||
# replace UNIQUE with LUNIQUE for CONST cache key normalization
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_const),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind),
|
||||
])
|
||||
|
|
@ -136,8 +93,6 @@ def create_new_buffer(ctx:dict[UOp, UOp], b:UOp):
|
|||
pm_post_sched_cache = PatternMatcher([
|
||||
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
||||
# restore CONST back to original CONST
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)),
|
||||
# restore PARAM back to original BUFFER
|
||||
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="b"), lambda ctx,b: ctx.get(b)),
|
||||
# restore BIND value stripped in pm_pre_sched_cache
|
||||
|
|
@ -150,6 +105,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
|
||||
big_sink, buffer_map = allocate_global_buffers(big_sink)
|
||||
|
||||
# replace BUFFERs with PARAMs, CONSTs UNIQUE with LUNIQUE, strip BIND values for cache key, extract var_vals
|
||||
input_buffers: dict[UOp, UOp] = {}
|
||||
var_vals: dict[str, int] = {}
|
||||
|
|
@ -159,39 +116,17 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
|
||||
# hack to preserve metadata
|
||||
graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx=({}, {}, [0], [0]), name="preserve metadata")
|
||||
|
||||
# tensor map is what we return
|
||||
tensor_map: dict[UOp, UOp] = {}
|
||||
|
||||
if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
|
||||
tensor_map |= get_multi_map(big_sink_cache)
|
||||
big_sink_cache = big_sink_cache.substitute(tensor_map, name="Apply Multi Map")
|
||||
big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
|
||||
|
||||
tensor_map |= get_rangeify_map(big_sink_cache)
|
||||
big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
|
||||
|
||||
pre_schedule, buf_uops_sink = create_schedule(big_sink)
|
||||
|
||||
# save in schedule cache (include AFTERs in tensor_map so we don't need big_sink)
|
||||
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
|
||||
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
|
||||
combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
|
||||
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
|
||||
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm", rewrite_into_calls=True)
|
||||
pre_schedule, buf_uops_sink = create_schedule(get_rangeify(big_sink_cache))
|
||||
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
|
||||
else:
|
||||
# schedule cache hit
|
||||
del big_sink_cache
|
||||
pre_schedule, combined_sink = sc_ret
|
||||
pre_schedule, buf_uops_sink = sc_ret
|
||||
del big_sink, big_sink_cache
|
||||
|
||||
# replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything)
|
||||
input_buffers_inverse = {v:k for k,v in input_buffers.items()}
|
||||
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
|
||||
tensor_map_sink, buf_uops_sink = combined.src
|
||||
tm_src = tensor_map_sink.src
|
||||
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
|
||||
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
|
||||
|
||||
# add bufs to pre_schedule
|
||||
schedule: list[ExecItem] = []
|
||||
|
|
@ -205,7 +140,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
ubufs = tuple(b.buffer for b in buf_uops)
|
||||
if any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
|
||||
dnums = [x for x in si.ast.variables() if x.expr == '_device_num']
|
||||
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
||||
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
|
||||
else:
|
||||
|
|
@ -214,9 +149,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||
|
||||
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
|
||||
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
i = 6
|
||||
while (frm:=sys._getframe(i)) and frm.f_code.co_filename.startswith(str(BASEDIR)): i += 1
|
||||
print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||
f" | {len(UOpMetaClass.ucache):7d} uops in cache | {frm.f_code.co_filename}:{frm.f_lineno}")
|
||||
|
||||
used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule])
|
||||
return tensor_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
|
||||
used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule])
|
||||
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
|
||||
|
|
@ -39,7 +39,6 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
|
||||
(UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
|
||||
(UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
|
||||
(UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
|
||||
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
|
||||
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
|
||||
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)
|
|||
OSX, WIN = platform.system() == "Darwin", sys.platform == "win32"
|
||||
CI = os.getenv("CI", "") != ""
|
||||
ARCH_X86 = any(x in platform.processor() for x in ("Intel", "i386", "x86_64"))
|
||||
BASEDIR = pathlib.Path(__file__).parent
|
||||
|
||||
# fix colors on Windows, https://stackoverflow.com/questions/12492810/python-how-can-i-make-the-ansi-escape-codes-to-work-also-in-windows
|
||||
if WIN: os.system("")
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ class Optimizer:
|
|||
"""
|
||||
Base class for all optimizers.
|
||||
"""
|
||||
def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
|
||||
def __init__(self, params: list[Tensor], lr: float, device=None, fused=FUSE_OPTIM):
|
||||
# if requires_grad is None, but being put into an optimizer, set it to True
|
||||
for x in params:
|
||||
if x.requires_grad is None: x.requires_grad_(True)
|
||||
|
|
@ -16,19 +16,19 @@ class Optimizer:
|
|||
self.params: list[Tensor] = dedup([x for x in params if x.requires_grad])
|
||||
assert len(self.params) != 0, "optimizer must have at least one param"
|
||||
self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
|
||||
self.device = device or self.params[0].device
|
||||
self.fused = fused
|
||||
# store lr in at least float32 precision
|
||||
self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
|
||||
dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
|
||||
if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
|
||||
|
||||
@property
|
||||
def device(self): return self.params[0].device
|
||||
|
||||
def _new_optim_param(self) -> list[Tensor]:
|
||||
param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
|
||||
if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
|
||||
return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False).contiguous() for t in self.params]
|
||||
if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False)]
|
||||
if self.device is not None:
|
||||
return [Tensor.zeros(t.shape, dtype=param_dtype, device=self.device, requires_grad=False) for t in self.params]
|
||||
return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False) for t in self.params]
|
||||
|
||||
def zero_grad(self):
|
||||
"""
|
||||
|
|
@ -54,13 +54,14 @@ class Optimizer:
|
|||
# NOTE: contiguous is for speed
|
||||
out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
|
||||
[Tensor.cat(*[unwrap(t.grad).contiguous().flatten() for t in self.params], dim=0)])
|
||||
updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
|
||||
updates = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
|
||||
else:
|
||||
updated_params, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
|
||||
for i, tt in enumerate(self.params): tt.assign(updated_params[i])
|
||||
updates, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
|
||||
for i, tt in enumerate(self.params): tt.assign(self._apply_update(tt, updates[i]))
|
||||
return extra+self.params+self.buffers
|
||||
|
||||
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: raise NotImplementedError
|
||||
def _apply_update(self, t:Tensor, up:Tensor) -> Tensor: return t.detach() - up.to(t.device)
|
||||
|
||||
class OptimizerGroup(Optimizer):
|
||||
"""
|
||||
|
|
@ -74,17 +75,17 @@ class OptimizerGroup(Optimizer):
|
|||
def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]
|
||||
|
||||
# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 it's just standard SGD.
|
||||
def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, fused=FUSE_OPTIM):
|
||||
def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, device=None, fused=FUSE_OPTIM):
|
||||
"""
|
||||
Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
|
||||
|
||||
`classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
|
||||
"""
|
||||
return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)
|
||||
return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, device=device, fused=fused)
|
||||
|
||||
# Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
|
||||
def Muon(params: list[Tensor], lr=0.001, momentum=0.95, weight_decay=0.1, ns_steps=5, ns_coefficients=(3.4445, -4.775, 2.0315),
|
||||
nesterov=True, fused=FUSE_OPTIM):
|
||||
nesterov=True, device=None, fused=FUSE_OPTIM):
|
||||
"""
|
||||
SGD with newton-schulz iteration and post momentum weight decay.
|
||||
|
||||
|
|
@ -92,7 +93,8 @@ def Muon(params: list[Tensor], lr=0.001, momentum=0.95, weight_decay=0.1, ns_ste
|
|||
- Paper: https://arxiv.org/pdf/2502.16982
|
||||
"""
|
||||
assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
|
||||
return LARS(params, lr, momentum, weight_decay, ns_steps, ns_coefficients, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)
|
||||
return LARS(params, lr, momentum, weight_decay, ns_steps, ns_coefficients, nesterov,
|
||||
classic=False, pre_wd=False, tcoef=0.0, device=None, fused=fused)
|
||||
|
||||
class LARS(Optimizer):
|
||||
"""
|
||||
|
|
@ -101,8 +103,8 @@ class LARS(Optimizer):
|
|||
- Paper: https://arxiv.org/abs/1708.03888v3
|
||||
"""
|
||||
def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_coefficients=None,
|
||||
nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
|
||||
super().__init__(params, lr, fused)
|
||||
nesterov=False, classic=True, pre_wd=True, tcoef=0.001, device=None, fused=FUSE_OPTIM):
|
||||
super().__init__(params, lr, device, fused)
|
||||
self.momentum, self.wd, self.ns_steps, self.ns_coefficients = momentum, weight_decay, ns_steps, ns_coefficients
|
||||
self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
|
||||
self.b = self._new_optim_param() if self.momentum else []
|
||||
|
|
@ -126,24 +128,24 @@ class LARS(Optimizer):
|
|||
if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
|
||||
# popular momentum does pre learning rate update
|
||||
if not self.classic: g = g * r * self.lr
|
||||
ret.append((t.detach() - g).cast(t.dtype))
|
||||
ret.append(g.cast(t.dtype))
|
||||
return ret, self.b
|
||||
|
||||
# LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 it's just Adam/W.
|
||||
def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, fused=FUSE_OPTIM):
|
||||
def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, device=None, fused=FUSE_OPTIM):
|
||||
"""
|
||||
AdamW optimizer with optional weight decay.
|
||||
|
||||
- Paper: https://arxiv.org/abs/1711.05101v3
|
||||
"""
|
||||
return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, fused=fused)
|
||||
def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, fused=FUSE_OPTIM):
|
||||
return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, device=device, fused=fused)
|
||||
def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, device=None, fused=FUSE_OPTIM):
|
||||
"""
|
||||
Adam optimizer.
|
||||
|
||||
- Paper: https://arxiv.org/abs/1412.6980
|
||||
"""
|
||||
return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, fused=fused)
|
||||
return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, device=device, fused=fused)
|
||||
|
||||
class LAMB(Optimizer):
|
||||
"""
|
||||
|
|
@ -151,10 +153,10 @@ class LAMB(Optimizer):
|
|||
|
||||
- Paper: https://arxiv.org/abs/1904.00962
|
||||
"""
|
||||
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
|
||||
super().__init__(params, lr, fused)
|
||||
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, device=None, fused=FUSE_OPTIM):
|
||||
super().__init__(params, lr, device, fused)
|
||||
self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
|
||||
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
|
||||
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False) for _ in [b1, b2])
|
||||
self.m = self._new_optim_param()
|
||||
self.v = self._new_optim_param()
|
||||
|
||||
|
|
@ -175,5 +177,5 @@ class LAMB(Optimizer):
|
|||
r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
|
||||
else:
|
||||
r = 1.0
|
||||
ret.append((t.detach() - self.lr * r * up).cast(t.dtype))
|
||||
ret.append((self.lr * r * up).cast(t.dtype))
|
||||
return ret, [self.b1_t, self.b2_t] + self.m + self.v
|
||||
|
|
|
|||
|
|
@ -339,8 +339,8 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
|
|||
codes = q_to_uint8(blocks[:, 1:17], 4)
|
||||
sign = 1.0 - codes.rshift(3).cast(dtypes.float32) * 2.0
|
||||
exp, mant = codes.rshift(1).bitwise_and(0x3).cast(dtypes.float32), codes.bitwise_and(0x1).cast(dtypes.float32)
|
||||
fp4_val = sign * ((exp != 0).cast(dtypes.float32) * (1.0 + 0.5 * mant) * (exp - 1.0).exp2() +
|
||||
(exp == 0).cast(dtypes.float32) * 0.5 * mant)
|
||||
fp4_val = sign * 2.0 * ((exp != 0).cast(dtypes.float32) * (1.0 + 0.5 * mant) * (exp - 1.0).exp2() +
|
||||
(exp == 0).cast(dtypes.float32) * 0.5 * mant)
|
||||
return (fp4_val * d).flatten(-2)[:n]
|
||||
raise ValueError(f"GGML type '{ggml_type}' is not supported!")
|
||||
|
||||
|
|
|
|||
|
|
@ -80,19 +80,16 @@ class ProgramSpec:
|
|||
ins:list[int]=field(default_factory=list)
|
||||
|
||||
@property
|
||||
def estimates(self) -> Estimates: return self.ast.arg.estimates
|
||||
def estimates(self) -> Estimates: return self.ast.arg.estimates if self.ast.arg is not None and self.ast.arg.estimates is not None else Estimates()
|
||||
|
||||
@functools.cached_property
|
||||
def function_name(self) -> str: return to_function_name(self.name)
|
||||
|
||||
@functools.cached_property
|
||||
def runtimevars(self) -> dict[str, int]: return {v.arg[0]: i for i, v in enumerate(self.vars) if v.arg[0] == 'core_id'}
|
||||
def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'}
|
||||
|
||||
@property
|
||||
def applied_opts(self) -> tuple[Opt, ...]|None:
|
||||
if self.uops is None: return None
|
||||
assert self.uops[-1].op is Ops.SINK, self.uops[-1].op
|
||||
return self.uops[-1].arg.applied_opts
|
||||
def applied_opts(self) -> tuple[Opt, ...]|None: return self.ast.arg.applied_opts if self.ast.arg is not None else None
|
||||
|
||||
def launch_dims(self, var_vals:dict[str, int]):
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size]
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from dataclasses import dataclass
|
|||
from typing import Iterator
|
||||
from enum import Enum
|
||||
from tinygrad.renderer.amd.dsl import BitField, FixedBitField, Inst, bits
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP, s_endpgm
|
||||
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm # same encoding as RDNA4
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# FIELD ENUMS
|
||||
|
|
@ -102,16 +101,17 @@ class InstOpRDNA4(Enum):
|
|||
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
|
||||
# TODO: we need to do discovery of all of these from instructions
|
||||
SALU = 0x0
|
||||
SMEM = 0x1
|
||||
UNK_02 = 0x2
|
||||
JUMP_NO = 0x4
|
||||
UNK_06 = 0x6
|
||||
JUMP = 0x1
|
||||
NEXT = 0x2
|
||||
MESSAGE = 0x4
|
||||
VALU_64 = 0x6
|
||||
VALU_WMMA = 0x46
|
||||
VMEM = 0x10
|
||||
UNK_11 = 0x11
|
||||
VINTERP = 0x12
|
||||
UNK_14 = 0x14
|
||||
VMEM_128 = 0x11
|
||||
VMEM_STORE = 0x12
|
||||
VMEM_STORE_128 = 0x14
|
||||
OTHER_VMEM = 0x5e
|
||||
UNK_60 = 0x60
|
||||
OTHER_VMEM_STORE = 0x60
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PACKET TYPE BASE CLASS
|
||||
|
|
@ -343,8 +343,12 @@ class INST_RDNA4(PacketType): # Layout 4: different delta position and InstOp e
|
|||
delta = bits[5:3]
|
||||
flag1 = bits[6:6]
|
||||
flag2 = bits[7:7]
|
||||
wave = bits[12:8]
|
||||
wave_pair = bits[11:8]
|
||||
flag3 = bits[12:12]
|
||||
op = bits[19:13].enum(InstOpRDNA4)
|
||||
# INST_RDNA4 wave_pair field (4 bits) addresses wave pairs, flag2 selects even/odd wave
|
||||
@property
|
||||
def wave(self): return self.wave_pair * 2 + self.flag2
|
||||
|
||||
class UTILCTR(PacketType):
|
||||
encoding = bits[6:0] == 0b0110001
|
||||
|
|
@ -586,7 +590,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
|||
def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0
|
||||
for p in decode(data):
|
||||
if not simd_select(p): continue
|
||||
if isinstance(p, WAVESTART):
|
||||
if isinstance(p, (WAVESTART, WAVESTART_RDNA4)):
|
||||
assert p.wave not in wave_pc, "only one inflight wave per unit"
|
||||
wave_pc[p.wave] = next(iter(pc_map))
|
||||
continue
|
||||
|
|
@ -595,33 +599,35 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
|||
yield (p, InstructionInfo(pc, p.wave, s_endpgm()))
|
||||
continue
|
||||
# skip OTHER_ instructions, they don't belong to this unit
|
||||
if isinstance(p, INST) and p.op.name.startswith("OTHER_"): continue
|
||||
if isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_"): continue
|
||||
if isinstance(p, IMMEDIATE_MASK):
|
||||
# immediate mask may yield multiple times per packet
|
||||
for wave in range(16):
|
||||
if p.mask & (1 << wave):
|
||||
inst = pc_map[pc:=wave_pc[wave]]
|
||||
# can this assert be more strict?
|
||||
assert isinstance(inst, SOPP), f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
|
||||
assert type(inst).__name__ == "SOPP", f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
|
||||
wave_pc[wave] += inst.size()
|
||||
yield (p, InstructionInfo(pc, wave, inst))
|
||||
continue
|
||||
if isinstance(p, (VALUINST, INST, IMMEDIATE)):
|
||||
if isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)):
|
||||
inst = pc_map[pc:=wave_pc[p.wave]]
|
||||
# s_delay_alu doesn't get a packet?
|
||||
if isinstance(inst, SOPP) and inst.op in {SOPPOp.S_DELAY_ALU}:
|
||||
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU"}:
|
||||
wave_pc[p.wave] += inst.size()
|
||||
inst = pc_map[pc:=wave_pc[p.wave]]
|
||||
# identify a branch instruction, only used for asserts
|
||||
is_branch = isinstance(inst, SOPP) and "BRANCH" in inst.op_name
|
||||
if is_branch: assert isinstance(p, INST) and p.op in {InstOp.JUMP_NO, InstOp.JUMP}, f"branch can only be folowed by jump packets, got {p}"
|
||||
branch_inst = inst if "BRANCH" in inst_op else None
|
||||
if branch_inst is not None:
|
||||
assert isinstance(p, (INST, INST_RDNA4)) and p.op.name in {"JUMP_NO", "JUMP", "NEXT"}, f"branch can only be folowed by JUMP, got {p}"
|
||||
# JUMP handling
|
||||
if isinstance(p, INST) and p.op is InstOp.JUMP:
|
||||
assert is_branch, f"JUMP packet must map to a branch instruction, got {inst}"
|
||||
x = inst.simm16 & 0xffff
|
||||
wave_pc[p.wave] += inst.size() + (x - 0x10000 if x & 0x8000 else x)*4
|
||||
if (isinstance(p, INST) and p.op is InstOp.JUMP) or (isinstance(p, INST_RDNA4) and branch_inst is not None and p.flag3):
|
||||
simm16 = getattr(branch_inst, 'simm16')
|
||||
assert branch_inst is not None and simm16 is not None, f"JUMP packet must map to a branch instruction, got {inst}"
|
||||
x = simm16 & 0xffff
|
||||
wave_pc[p.wave] += branch_inst.size() + (x - 0x10000 if x & 0x8000 else x)*4
|
||||
else:
|
||||
if is_branch: assert inst.op != SOPPOp.S_BRANCH, f"S_BRANCH must have a JUMP packet, got {p}"
|
||||
if branch_inst is not None: assert inst_op != "S_BRANCH", f"S_BRANCH must have a JUMP packet, got {p}"
|
||||
wave_pc[p.wave] += inst.size()
|
||||
yield (p, InstructionInfo(pc, p.wave, inst))
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ class LLVMRenderer(Renderer):
|
|||
supports_float4 = True
|
||||
abi: str | None
|
||||
string_rewrite: PatternMatcher
|
||||
code_for_op = {Ops.FDIV: lambda: None, Ops.CMPLT: lambda: None}
|
||||
code_for_op = {k:lambda:None for v in lop.values() for k in v.keys()}
|
||||
if AMX: tensor_cores = tc.amx
|
||||
|
||||
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
|
||||
|
|
@ -169,7 +169,7 @@ class LLVMRenderer(Renderer):
|
|||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
|
||||
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.arg[0]}"
|
||||
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.expr}"
|
||||
args.append((r[u], u.dtype))
|
||||
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
|
||||
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ asm_for_op: dict[Ops, Callable] = {
|
|||
Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if dt == dtypes.bool else f"or.b{name[1:]} {d}, {a}, {b};",
|
||||
Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
|
||||
Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
|
||||
Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
|
||||
Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
|
||||
Ops.CMPNE: lambda d,a,b,dt,name: f"setp.{'neu' if dtypes.is_float(dt) else 'ne'}.{name} {d}, {a}, {b};",
|
||||
Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
|
||||
Ops.WHERE: lambda d,a,b,c,dt,name: [f"@{a} mov.{name} {d}, {b};", f"@!{a} mov.{name} {d}, {c};"] if dt == dtypes.bool else \
|
||||
f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
|
|
@ -134,7 +135,7 @@ string_rewrite = PatternMatcher([
|
|||
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
|
||||
(UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
|
||||
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.expr}+0];"),
|
||||
])
|
||||
|
||||
class PTXRenderer(Renderer):
|
||||
|
|
@ -220,7 +221,7 @@ class PTXRenderer(Renderer):
|
|||
continue
|
||||
if u.op is Ops.INDEX: continue # other index we can skip
|
||||
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg
|
||||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
|
||||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
|
||||
elif u.op is Ops.LOAD:
|
||||
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
|
||||
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
|
||||
|
|
|
|||
|
|
@ -61,8 +61,9 @@ class CLProgram:
|
|||
if isinstance(dt, ImageDType):
|
||||
fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize])
|
||||
desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], image_row_pitch=dt.pitch, buffer=b)
|
||||
b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
|
||||
check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(b), ctypes.byref(b)))
|
||||
img = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
|
||||
check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(img), ctypes.byref(img)))
|
||||
else: check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(b), ctypes.byref(b)))
|
||||
for i,v in enumerate(vals,start=i+1): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
|
||||
if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
|
||||
event = cl.cl_event() if wait else None
|
||||
|
|
|
|||
|
|
@ -315,6 +315,9 @@ class AMDev(PCIDevImplBase):
|
|||
self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)
|
||||
self.reserved_vram_size = (384 << 20) if self.ip_ver[am.GC_HWIP][:2] in {(9,4), (9,5)} else (64 << 20)
|
||||
|
||||
@functools.cached_property
|
||||
def hwid_names(self) -> dict[int, str]: return {v:k.removesuffix('_HWID') for k,v in vars(am).items() if k.endswith('_HWID') and isinstance(v, int)}
|
||||
|
||||
def _ip_module(self, prefix:str, hwip, prever_prefix:str=""): return import_module(prefix, self.ip_ver[hwip], prever_prefix)
|
||||
|
||||
def _build_regs(self):
|
||||
|
|
|
|||
|
|
@ -217,6 +217,16 @@ class AM_SMU(AM_IP):
|
|||
with contextlib.suppress(TimeoutError): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), timeout=20)
|
||||
if self.adev.ip_ver[am.GC_HWIP] >= (10,0,0): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]))
|
||||
|
||||
def _aca_read_reg(self, bank_idx:int, reg_idx:int, ue=True) -> int:
|
||||
msg = self.smu_mod.PPSMC_MSG_McaBankDumpDW if ue else self.smu_mod.PPSMC_MSG_McaBankCeDumpDW
|
||||
return (self._send_msg(msg, (bank_idx << 16) | (reg_idx * 8 + 4), read_back_arg=True) << 32) | \
|
||||
self._send_msg(msg, (bank_idx << 16) | (reg_idx * 8), read_back_arg=True)
|
||||
|
||||
def _aca_read_banks(self, ue=True) -> list[list[int]]:
|
||||
if not hasattr(self.smu_mod, 'PPSMC_MSG_QueryValidMcaCount'): return []
|
||||
count_msg = self.smu_mod.PPSMC_MSG_QueryValidMcaCount if ue else self.smu_mod.PPSMC_MSG_QueryValidMcaCeCount
|
||||
return [[self._aca_read_reg(idx, reg_idx, ue=ue) for reg_idx in range(16)] for idx in range(self._send_msg(count_msg, 0, read_back_arg=True))]
|
||||
|
||||
def _smu_cmn_send_msg(self, msg:int, param=0, debug=False):
|
||||
(self.adev.mmMP1_SMN_C2PMSG_90 if not debug else self.adev.mmMP1_SMN_C2PMSG_54).write(0) # resp reg
|
||||
(self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).write(param)
|
||||
|
|
@ -462,6 +472,20 @@ class AM_IH(AM_IP):
|
|||
|
||||
self.adev.regIH_RB_RPTR.write(wptr['offset'] % (self.ring_size // 4))
|
||||
|
||||
bif_intr = self.adev.regBIF_BX0_BIF_DOORBELL_INT_CNTL.read_bitfields()
|
||||
athub_err, cntlr_err = bif_intr['ras_athub_err_event_interrupt_status'], bif_intr['ras_cntlr_interrupt_status']
|
||||
if athub_err or cntlr_err:
|
||||
print(f"am {self.adev.devfmt}: fatal hardware error detected: {'RAS_ATHUB_ERR_EVENT ' if athub_err else ''}{'RAS_CNTLR' if cntlr_err else ''}")
|
||||
|
||||
acas = self.adev.smu._aca_read_banks(ue=True) + self.adev.smu._aca_read_banks(ue=False)
|
||||
for regs in acas:
|
||||
acatyp = 'Uncorrectable' if (regs[1] >> 61) & 1 and (regs[1] >> 57) & 1 else 'Correctable'
|
||||
hwname = f'{self.adev.hwid_names.get((regs[5] >> 32) & 0xFFF, "")} ({(regs[5] >> 32) & 0xFFF:#03x})'
|
||||
print(f"am {self.adev.devfmt}: {acatyp} ACA: {hwname} mcatype={(regs[5] >> 48) & 0xFFFF:#06x} regs=[{', '.join(f'{r:#x}' for r in regs)}]")
|
||||
|
||||
self.adev.regBIF_BX0_BIF_DOORBELL_INT_CNTL.write(ras_cntlr_interrupt_clear=cntlr_err, ras_athub_err_event_interrupt_clear=athub_err)
|
||||
self.adev.is_err_state = True
|
||||
|
||||
class AM_SDMA(AM_IP):
|
||||
def init_sw(self): self.sdma_reginst, self.sdma_name = [], "F32" if self.adev.ip_ver[am.SDMA0_HWIP] < (7,0,0) else "MCU"
|
||||
def init_hw(self):
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class TLSFAllocator:
|
|||
# Round up the allocation size to the next bucket, so any entry there can fit the requested size.
|
||||
size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
|
||||
|
||||
# Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found.
|
||||
# Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
|
||||
for l1 in range(self.lv1(size), len(self.storage)):
|
||||
if self.lv1_entries[l1] == 0: continue
|
||||
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
|
||||
|
|
@ -105,7 +105,7 @@ class TLSFAllocator:
|
|||
def free(self, start:int):
|
||||
self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
|
||||
|
||||
# Memory Managment
|
||||
# Memory Management
|
||||
|
||||
class AddrSpace(enum.Enum): PHYS = enum.auto(); SYS = enum.auto(); PEER = enum.auto() # noqa: E702
|
||||
|
||||
|
|
@ -221,7 +221,7 @@ class MemoryManager:
|
|||
|
||||
@classmethod
|
||||
def alloc_vaddr(cls, size:int, align=0x1000) -> int:
|
||||
assert cls.va_allocator is not None, "must be set it"
|
||||
assert cls.va_allocator is not None, "must be set"
|
||||
return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
|
||||
|
||||
def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping:
|
||||
|
|
@ -248,7 +248,7 @@ class MemoryManager:
|
|||
return self.map_range(va, size, paddrs, aspace=AddrSpace.PHYS, uncached=uncached)
|
||||
|
||||
def vfree(self, vm:VirtMapping):
|
||||
assert self.va_allocator is not None, "must be set it"
|
||||
assert self.va_allocator is not None, "must be set"
|
||||
self.unmap_range(vm.va_addr, vm.size)
|
||||
self.va_allocator.free(vm.va_addr)
|
||||
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class NVDev(PCIDevImplBase):
|
|||
self._early_ip_init()
|
||||
self._early_mmu_init()
|
||||
|
||||
# Turn the booting early, gsp client is loaded from the clean.
|
||||
# No booting state, gsp client is reinited every run.
|
||||
self.is_booting = False
|
||||
|
||||
for ip in [self.flcn, self.gsp]: ip.init_sw()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuff
|
|||
from tinygrad.runtime.support.memory import MemoryManager, VirtMapping, AddrSpace
|
||||
from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface
|
||||
|
||||
MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
|
||||
MAP_FIXED, MAP_FIXED_NOREPLACE = 0x10, 0x100000
|
||||
MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class PCIBarInfo: addr:int; size:int # noqa: E702
|
||||
|
|
@ -69,7 +70,7 @@ class _System:
|
|||
all_devs.append((int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/vendor").read(), 16),
|
||||
int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/device").read(), 16), pcibus))
|
||||
|
||||
return sorted([val for vendor, device, val in all_devs if vendor == vendor and any((device & mask) in devlist for mask, devlist in devices)])
|
||||
return sorted([val for vndr, device, val in all_devs if vndr == vendor and any((device & mask) in devlist for mask, devlist in devices)])
|
||||
|
||||
def pci_setup_usb_bars(self, usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, PCIBarInfo]:
|
||||
for bus in range(gpu_bus):
|
||||
|
|
@ -219,7 +220,7 @@ class LNXPCIIfaceBase:
|
|||
cls.gpus = hcq_filter_visible_devices(System.pci_scan_bus(vendor, devices, base_class))
|
||||
|
||||
# Acquire va range to avoid collisions.
|
||||
FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED, 0)
|
||||
FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED_NOREPLACE, 0)
|
||||
self.pci_dev, self.dev, self.vram_bar = PCIDevice(dev.__class__.__name__[:2], cls.gpus[dev_id], bars=bars, resize_bars=[vram_bar]), dev, vram_bar
|
||||
self.p2p_base_addr = self.pci_dev.bar_info[vram_bar].addr
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import functools, itertools
|
|||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink, pm_gate_kernel_sink
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
|
||||
|
|
@ -17,15 +17,19 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
for s in rb.src:
|
||||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
pm_generate_realize_map = pm_gate_kernel_sink+PatternMatcher([
|
||||
def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
|
||||
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
|
||||
if buf.base in x.backward_slice_with_self: ctx[x] = None
|
||||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize SINK src
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||
# always realize
|
||||
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
|
||||
# always realize REDUCE on outer ranges
|
||||
(UPat(Ops.REDUCE, name="r"), lambda ctx,r: realize(ctx, r) if any(tr.arg[-1] == AxisType.OUTER for tr in r.src[1:]) else None),
|
||||
# realize srcs of these
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ASSIGN, Ops.ENCDEC), name="rb"), realize_srcs),
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
|
||||
# sometimes realize src of assign
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
|
||||
])
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -68,7 +72,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
|||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
|
||||
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
|
||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
||||
new_srcs.append(new_src)
|
||||
# NOTE: do we need this?
|
||||
|
|
@ -84,7 +88,7 @@ def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
|
|||
def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
|
||||
# input ranges
|
||||
new_ranges = [r for i,r in enumerate(ctx.range_map[x][0]) if i in x.arg[1]]
|
||||
ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0], tag=x.tag)
|
||||
ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=x.arg[0])
|
||||
ctx.range_map[ret] = ctx.range_map[x]
|
||||
return ret
|
||||
|
||||
|
|
@ -161,6 +165,11 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
|
||||
# get ops to realize
|
||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
|
||||
# don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
|
||||
for u in tsink.toposort():
|
||||
if u.op is Ops.ASSIGN and u.src[1].op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and u.src[1] in rctx.realize_map \
|
||||
and not u.src[0].op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
|
||||
del rctx.realize_map[u.src[1]]
|
||||
|
||||
# get the consumer map
|
||||
with cpu_profile("consumer map in rangeify", "TINY"):
|
||||
|
|
|
|||
|
|
@ -1,85 +1,56 @@
|
|||
from typing import cast
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import Device
|
||||
|
||||
# *** allreduce implementation ***
|
||||
def handle_allreduce_multirank(buf:UOp, red:UOp) -> UOp|None:
|
||||
if not isinstance(buf.device, tuple): return None
|
||||
|
||||
# Group buffers
|
||||
groups: dict[int|None, list[UOp]] = {}
|
||||
for i,dev in enumerate(buf.device):
|
||||
groups.setdefault(Device[dev].group_id, []).append(buf.mselect(i))
|
||||
|
||||
# Put reduce leader of each group first
|
||||
reduce_leaders = set(getenv("REDUCE_LEADERS", "").split(","))
|
||||
groups = {gid: sorted(bufs, key=lambda x: (x.device not in reduce_leaders, x.device)) for gid,bufs in groups.items()}
|
||||
|
||||
# Skip if only one group or if every group has only one buffer
|
||||
if len(groups) <= 1 or not any(len(g) > 1 for g in groups.values()): return None
|
||||
|
||||
# Reduce inside each group
|
||||
inner = [UOp(Ops.MSTACK, buf.dtype, tuple(bufs)).allreduce(red.arg, (cast(str, bufs[0].device),)).mselect(0) for bufs in groups.values()]
|
||||
|
||||
# Allreduce across groups
|
||||
outer = UOp(Ops.MSTACK, buf.dtype, tuple(inner)).allreduce(red.arg, tuple(buf.device for buf in inner))
|
||||
|
||||
# Broadcast back to all devices in the group
|
||||
gid2bid = {Device[device].group_id: i for i,device in enumerate(outer.device)}
|
||||
return outer.mselect(gid2bid[Device[red.device].group_id]).copy_to_device(red.device) if not isinstance(red.device, tuple) else \
|
||||
UOp(Ops.MSTACK, buf.dtype, tuple(outer.mselect(gid2bid[Device[device].group_id]).copy_to_device(device) for device in red.device))
|
||||
|
||||
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
if not isinstance(buf.device, tuple): return None
|
||||
assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
|
||||
n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
|
||||
ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
|
||||
|
||||
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
|
||||
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
|
||||
use_all2all = (ALL2ALL >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
|
||||
use_ring = not use_all2all and (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
|
||||
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {n_lbs}x{numel} | {buf.dtype}")
|
||||
use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
|
||||
use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
|
||||
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}")
|
||||
|
||||
# contiguous before we copy it
|
||||
buf = buf.contiguous()
|
||||
|
||||
# naive: copy to all devices. if you shrink later, that'll be handled
|
||||
if not use_ring and not use_all2all:
|
||||
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(n_lbs)])
|
||||
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
|
||||
|
||||
# chunk data into n_lbs pieces
|
||||
# chunk data into ndev pieces
|
||||
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
|
||||
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
|
||||
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (n_lbs - left), initial=0)))
|
||||
base, left = divmod(numel // factor, ndev)
|
||||
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
|
||||
|
||||
# reduce-scatter
|
||||
reduced_chunks = []
|
||||
reduced_chunks:list[UOp] = []
|
||||
for i,(s,e) in enumerate(chunks):
|
||||
if use_all2all:
|
||||
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(n_lbs)]
|
||||
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
|
||||
reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
|
||||
else:
|
||||
chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
|
||||
for step in range(n_lbs-1):
|
||||
src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
|
||||
for step in range(ndev-1):
|
||||
src, dest = (i+step)%ndev, (i+step+1)%ndev
|
||||
cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
|
||||
reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
|
||||
reduced_chunks.append(reduced)
|
||||
|
||||
# allgather
|
||||
copied_chunks = []
|
||||
copied_chunks:list[UOp] = []
|
||||
for i,rc in enumerate(reduced_chunks):
|
||||
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
|
||||
else:
|
||||
this_chunk: list[UOp|None] = [None] * n_lbs
|
||||
this_chunk[(i+n_lbs-1)%n_lbs] = rc
|
||||
for step in range(n_lbs-1):
|
||||
this_chunk[(i+step)%n_lbs] = rc = rc.copy_to_device(buf.device[(i+step)%n_lbs])
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
|
||||
chain:list[UOp] = [rc]
|
||||
for step in range(ndev-1):
|
||||
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
|
||||
|
||||
# reassemble
|
||||
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
|
||||
|
|
@ -90,7 +61,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
|
|||
ret:list[UOp] = []
|
||||
def apply_shrink(s:UOp, i:int) -> UOp:
|
||||
new_arg = [tuple([x.substitute({dvar[0]:dvar[0].const_like(i)}) if isinstance(x, UOp) and
|
||||
(dvar:=[v for v in x.vars() if v.op is Ops.DEFINE_VAR and v.arg[0]=='_device_num']) else x for x in ss]) for ss in shrink.marg]
|
||||
(dvar:=[v for v in x.variables() if v.expr=='_device_num']) else x for x in ss]) for ss in shrink.marg]
|
||||
return s.shrink(tuple(new_arg))
|
||||
for i, x in enumerate(ms.src):
|
||||
if x.op is Ops.COPY:
|
||||
|
|
@ -100,7 +71,6 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
|
|||
return ms.replace(src=tuple(ret))
|
||||
|
||||
replace_allreduce = PatternMatcher([
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
|
||||
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
|
||||
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
|
||||
|
|
@ -125,20 +95,20 @@ def alu_multi(root:UOp):
|
|||
axis = root.axis
|
||||
assert axis is not None
|
||||
|
||||
srcs = []
|
||||
srcs:list[UOp] = []
|
||||
for mlb in msrcs:
|
||||
if mlb.axis == axis:
|
||||
# same axis, just copy through
|
||||
assert mlb.op is Ops.MULTI
|
||||
srcs.append(mlb.src[0])
|
||||
elif mlb.axis is None:
|
||||
if mlb.axis is None:
|
||||
# no axis, shard it
|
||||
assert mlb.op is not Ops.MULTI
|
||||
srcs.append(mlb._shard(axis))
|
||||
else:
|
||||
# axis mismatch, unshard it, send it to all devices, and shard it correctly
|
||||
assert mlb.op is Ops.MULTI
|
||||
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
|
||||
if mlb.axis == axis:
|
||||
# same axis, just copy through
|
||||
srcs.append(mlb.src[0])
|
||||
else:
|
||||
# axis mismatch, unshard it, send it to all devices, and shard it correctly
|
||||
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
|
||||
return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
|
||||
|
||||
def reduce_multi(root:UOp, multi:UOp):
|
||||
|
|
@ -149,21 +119,15 @@ def reduce_multi(root:UOp, multi:UOp):
|
|||
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
|
||||
return multi.src[0].r(op, axis).multi(axis=multi.axis)
|
||||
|
||||
def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
|
||||
return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
|
||||
|
||||
def reshape_multi(root:UOp, multi:UOp):
|
||||
arg = root.marg
|
||||
if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis)
|
||||
assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
|
||||
assert prod(multi.src[0].shape[multi.axis:])%prod(arg[new_axis+1:]) == 0, f"reshape cannot move items between shards {multi.shape} -> {arg=}"
|
||||
new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:])
|
||||
return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis)
|
||||
if prod(multi.shape) != prod(new_shape:=root.marg): raise RuntimeError("reshape must maintain prod(shape)")
|
||||
if (new_axis:=root.axis) is not None: new_shape = tuple(s//len(multi.device) if a==new_axis else s for a,s in enumerate(new_shape))
|
||||
return multi.src[0].reshape(new_shape).multi(new_axis)
|
||||
|
||||
def expand_multi(root:UOp, multi:UOp):
|
||||
# NOTE: this assert isn't needed, sharded axis can have dim 1
|
||||
assert multi.axis is None or root.marg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.marg=}"
|
||||
return multi.src[0].expand(_shape_to_single_shard(multi.axis, root.marg, multi.src[0])).multi(multi.axis)
|
||||
if multi.axis is None: new_shape = root.marg
|
||||
else: new_shape = tuple(multi.src[0].shape[multi.axis] if a == multi.axis else s for a,s in enumerate(root.marg))
|
||||
return multi.src[0].expand(new_shape).multi(multi.axis)
|
||||
|
||||
def pad_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or root.marg[multi.axis] == (0,0), f"padding not supported for {root.marg=}"
|
||||
|
|
@ -177,11 +141,10 @@ def shrink_multi(root:UOp, multi:UOp):
|
|||
assert multi.axis is None or root.marg[multi.axis] == (0, multi.shape[multi.axis]) or root.marg[multi.axis] in multi.bounds, \
|
||||
f"shrinking not supported for {root.marg=}"
|
||||
if multi.axis is not None and root.marg[multi.axis] in multi.bounds and root.marg[multi.axis] != (0, multi.shape[multi.axis]):
|
||||
assert all(root.marg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
|
||||
"cannot shrink sharded and non-sharded axis at the same time"
|
||||
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
|
||||
# we just copy it to all the devices, no real. this will be optimized out later
|
||||
return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.marg[multi.axis]))
|
||||
non_shard_shrink = tuple((0, multi.src[0].shape[i]) if i == multi.axis else s for i, s in enumerate(root.marg))
|
||||
return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.marg[multi.axis])).shrink(non_shard_shrink)
|
||||
return multi.src[0].shrink(tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.marg))).multi(multi.axis)
|
||||
|
||||
def flip_multi(root:UOp, multi:UOp):
|
||||
|
|
@ -224,9 +187,3 @@ multi_pm = PatternMatcher([
|
|||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
|
||||
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
|
||||
])+replace_allreduce
|
||||
|
||||
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
|
||||
ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")
|
||||
if VIZ: graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
|
||||
return ret
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType, BottomUpGate
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
|
@ -26,27 +26,13 @@ pm_mops = PatternMatcher([
|
|||
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)),
|
||||
# move movement ops after AFTER
|
||||
(UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
|
||||
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:], tag=None),)+r.src[1:], r.arg, tag=a.tag)),
|
||||
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
|
||||
(UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
|
||||
])
|
||||
|
||||
# *****************
|
||||
# 0. do some cleanup rewrites, mostly copied from the old stuff
|
||||
|
||||
def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
|
||||
if (t := target.base).op is Ops.PARAM or (t.op is Ops.MSTACK and all(s.op is Ops.PARAM for s in t.src)): return None
|
||||
# partial view of unrealized graph: insert CONTIGUOUS at base to realize it
|
||||
if target is not t and target.op_in_backward_slice_with_self(Ops.SHRINK):
|
||||
if t.op is Ops.CONTIGUOUS: return None
|
||||
mops: list[UOp] = []
|
||||
while target.op in GroupOp.Movement:
|
||||
mops.append(target)
|
||||
target = target.src[0]
|
||||
new_target = t.f(Ops.CONTIGUOUS, tag=t.tag)
|
||||
for m in reversed(mops): new_target = m.replace(src=(new_target,)+m.src[1:])
|
||||
return assign.replace(src=(new_target, src))
|
||||
return src.f(Ops.CONTIGUOUS, tag=assign.tag)
|
||||
|
||||
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
|
||||
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
|
||||
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
|
||||
|
|
@ -83,19 +69,21 @@ def split_reduceop(reduce:UOp, x:UOp):
|
|||
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
|
||||
if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
|
||||
# reduce original axes, then split
|
||||
return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape).replace(tag=reduce.tag)
|
||||
return splitted.r(*reduce.arg).contiguous().r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
|
||||
|
||||
mop_cleanup = PatternMatcher([
|
||||
# merge adjacent RESHAPES, safe because they are not tagged
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"),
|
||||
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
|
||||
# merge adjacent RESHAPES
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
|
||||
])
|
||||
|
||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
if c.src[0].op is Ops.PROGRAM: return None
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
params: list[UOp] = []
|
||||
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params)
|
||||
params = sorted(params, key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
|
||||
|
|
@ -103,49 +91,31 @@ def resolve_call(c:UOp) -> UOp|None:
|
|||
for i, (p, a) in enumerate(zip(params, args)):
|
||||
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||
return c.src[0].substitute(dict(zip(params, args))).rtag(c.tag)
|
||||
return c.src[0].substitute(dict(zip(params, args)))
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# resolve calls
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
# remove CONTIGUOUS if the source is already contiguous
|
||||
(UPat(Ops.RESHAPE, src=(UPat((Ops.PARAM, Ops.CONTIGUOUS)), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
|
||||
# split_reduceop
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
|
||||
|
||||
# preserve tags?
|
||||
# reduce of size 0 is the identity element
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
|
||||
# handle size 0
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
|
||||
|
||||
# remove contiguous on movement ops before a copy on disk
|
||||
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
|
||||
lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
|
||||
# push copy past movement ops to disk
|
||||
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.COPY, allow_any_len=True, name="copy"),
|
||||
lambda x,copy: x.replace(src=(copy.replace(src=(x.src[0],)+copy.src[1:], tag=None),)+x.src[1:], tag=copy.tag) \
|
||||
lambda x,copy: x.replace(src=(copy.replace(src=(x.src[0],)+copy.src[1:]),)+x.src[1:]) \
|
||||
if isinstance(x.device, str) and x.device.startswith("DISK") else None),
|
||||
|
||||
# ** copy rules **
|
||||
|
||||
# early fixup const copy
|
||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
|
||||
|
||||
# COPY and source size need to match
|
||||
# TODO: expand after copy creates issues with tagging
|
||||
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
|
||||
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
|
||||
|
||||
# copy only to different device
|
||||
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP, tag=copy.tag) if x.device == copy.device else None),
|
||||
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP) if x.device == copy.device else None),
|
||||
|
||||
# ** assign rules **
|
||||
|
||||
|
|
@ -153,23 +123,20 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
(UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat(Ops.ASSIGN, src=(UPat(name="target"), UPat()), name="src"))), lambda target, src: src),
|
||||
|
||||
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
|
||||
lambda assign, target, src: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src"))),
|
||||
lambda target, src: target.assign(src.bitcast(target.dtype))),
|
||||
|
||||
# if assign target is itself an ASSIGN chain, canonicalize to the original buffer target
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, name="target"), UPat(name="src")), allow_any_len=True, name="assign"), normalize_assign_target_chain),
|
||||
|
||||
# assign only to buffer, otherwise make it a CONTIGUOUS
|
||||
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.PARAM}, name="target"), UPat(name="src")), name="assign"), assign_to_contiguous),
|
||||
|
||||
# make source contiguous if it has hazardous movement ops on the dest buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||
# make source contiguous if it has hazardous movement ops on the dest buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("target"), UPat.var("src")), name="assign"), fix_assign_hazard),
|
||||
])
|
||||
|
||||
# *****************
|
||||
# 3.5 cleanups
|
||||
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC}
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
|
||||
|
||||
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
|
||||
def cleanup_dead_axes(b:UOp):
|
||||
|
|
@ -190,8 +157,7 @@ def cleanup_dead_axes(b:UOp):
|
|||
reshape.append(s)
|
||||
new_rng.append(rng)
|
||||
if hit:
|
||||
# move the tag to the expand. NOTE: this expand tag might not survive
|
||||
return b.replace(src=b.src[0:1]+tuple(new_rng), tag=None).reshape(tuple(reshape)).expand(b.shape).replace(tag=b.tag)
|
||||
return b.replace(src=b.src[0:1]+tuple(new_rng)).reshape(tuple(reshape)).expand(b.shape)
|
||||
|
||||
def gate_substitute(ctx, b:UOp) -> None:
|
||||
if not any(r in b.ranges for r in ctx.keys()): raise BottomUpGate()
|
||||
|
|
@ -264,8 +230,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
|||
|
||||
def remove_noop_bufferize(idx,b2):
|
||||
if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.BUFFER_VIEW: return None
|
||||
new_tag = (idx.src[0].tag or ()) + (b2.tag or ()) or None
|
||||
return idx.src[0].rtag(new_tag).shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0].rtag(new_tag)
|
||||
return idx.src[0].shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0]
|
||||
|
||||
pm_const_buffer_folding = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
|
||||
|
|
@ -275,13 +240,13 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
|
|||
# remove noop buffers. if we look at the next index we can remove even more of these
|
||||
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize),
|
||||
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
|
||||
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)),
|
||||
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
|
||||
# indexing a const is a const
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
|
||||
# copy on CONST is CONST
|
||||
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
|
||||
# hack if a noop turned to a const
|
||||
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c.rtag(noop.tag)),
|
||||
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c),
|
||||
# mstack on CONST is CONST
|
||||
(UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True),
|
||||
lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None),
|
||||
|
|
@ -307,7 +272,7 @@ def late_buffer_view(t:UOp, b:UOp):
|
|||
if len(shape) == 0: offset = x.src[1].arg
|
||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag), b.src[1]))
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset)), b.src[1]))
|
||||
|
||||
to_bufferview = PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||
|
|
@ -346,7 +311,6 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
|
|||
# NOTE: this has been fixed up a bit
|
||||
|
||||
def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
||||
#assert isinstance(x.tag, Flat), "bufferize must be flat"
|
||||
size = prod(x.shape)
|
||||
rngs = sorted(idx.ranges, key=lambda x: x.arg)
|
||||
assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {size}"
|
||||
|
|
@ -359,26 +323,14 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
|||
# skip self-assign from same-device copy, otherwise create the store
|
||||
# in assign, this is the buffer size, not the bufferize size
|
||||
if assign_src is assign_target: ret = assign_target.src[0]
|
||||
else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src, tag=x.tag).end(*rngs))
|
||||
else: ret = assign_target.src[0].after(assign_target.replace(dtype=sdtype).store(assign_src).end(*rngs))
|
||||
for op, marg in reversed(assign.arg or ()): ret = ret._mop(op, marg)
|
||||
return ret
|
||||
|
||||
# lower outerworld reduce here
|
||||
if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
|
||||
assert sdtype.addrspace == AddrSpace.GLOBAL
|
||||
outer_range = x.src[0].src[1]
|
||||
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
|
||||
# NOTE: this has the same number as the outer range, we need string ranges!
|
||||
zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
|
||||
buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
|
||||
bufi = buf.index(idx, dtype=sdtype)
|
||||
do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
|
||||
return buf.after(do_store)
|
||||
|
||||
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
||||
if sdtype.addrspace == AddrSpace.GLOBAL:
|
||||
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
|
||||
do_store = buf.index(idx, dtype=sdtype).store(x.src[0], tag=x.tag).end(*rngs)
|
||||
do_store = buf.index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
|
||||
return buf.after(do_store)
|
||||
|
||||
if allow_locals:
|
||||
|
|
@ -387,16 +339,16 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
|||
do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
|
||||
return buf.after(do_store.barrier())
|
||||
|
||||
# collapse any BUFFERIZE to single input BUFFERIZE. move the tag to a reshape
|
||||
# collapse any BUFFERIZE to single input BUFFERIZE
|
||||
def flatten_bufferize(x:UOp):
|
||||
if x.tag is None and len(x.src) == 2: return None
|
||||
ret = x.replace(tag=None, src=(x.src[0], get_single_element(apply_movement_op(Ops.RESHAPE, (prod(x.shape),), x.shape, x.src[1:]))))
|
||||
if len(x.src) == 2: return None
|
||||
ret = x.replace(src=(x.src[0], get_single_element(apply_movement_op(Ops.RESHAPE, (prod(x.shape),), x.shape, x.src[1:]))))
|
||||
rngs = x.src[1:]
|
||||
ret = ret.forced_reshape(x.shape)
|
||||
ret = ret.reshape(x.shape)
|
||||
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
|
||||
sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
|
||||
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
|
||||
return ret.rtag(x.tag)
|
||||
return ret
|
||||
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
|
||||
|
||||
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||
|
|
@ -404,7 +356,7 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
|||
|
||||
# move RESHAPEs through MSELECT/MSTACK
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
|
||||
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src])).reshape(m.shape)),
|
||||
|
||||
# remove any RESHAPEs on KERNEL
|
||||
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
||||
|
|
@ -423,7 +375,6 @@ class LocalAddBufferContext:
|
|||
map:dict = field(default_factory=dict)
|
||||
vars:dict = field(default_factory=dict)
|
||||
range:int = 0
|
||||
parent_tags:list = field(default_factory=list)
|
||||
opts:tuple|None = None
|
||||
|
||||
def debuf(ctx:LocalAddBufferContext, buf:UOp):
|
||||
|
|
@ -447,9 +398,6 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
|
|||
|
||||
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
|
||||
if r.tag != (): return None
|
||||
if r.arg[-1] == AxisType.OUTER:
|
||||
# for outer range, we replace with a bound variable
|
||||
return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
|
||||
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
|
||||
ctx.range += 1
|
||||
return ret
|
||||
|
|
@ -488,12 +436,6 @@ rangeify_codegen = PatternMatcher([
|
|||
# TODO: this can be moved into codegen?
|
||||
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0]),
|
||||
|
||||
# add loads to non ptr indexes
|
||||
# TODO: this can be moved into codegen?
|
||||
#(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
|
||||
# .f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
# lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()),
|
||||
|
||||
# fix broadcast dtype
|
||||
(UPat(Ops.AFTER, name="a").broadcast(name="b"), lambda a,b: a.broadcast(len(b.src))),
|
||||
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
|
|
@ -505,32 +447,17 @@ rangeify_codegen = PatternMatcher([
|
|||
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
|
||||
])
|
||||
|
||||
def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp):
|
||||
if x.tag is None or x.tag == (): return None
|
||||
if isinstance(x.tag, tuple): ctx.parent_tags += list(x.tag)
|
||||
return x.replace(tag=None)
|
||||
|
||||
pm_remove_tags = PatternMatcher([
|
||||
(UPat(GroupOp.All, name="x"), remove_metadata_tags),
|
||||
])
|
||||
|
||||
pm_add_range_tags = PatternMatcher([
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: x.rtag(())),
|
||||
])
|
||||
|
||||
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
||||
# if we have any non-outer ranges open here, we don't split
|
||||
if any(r.arg[-1] != AxisType.OUTER for r in x.ranges): return None
|
||||
|
||||
# ends of outer range don't go in kernels
|
||||
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
|
||||
def split_store(x:UOp) -> UOp|None:
|
||||
# if we have any open ranges here, we don't split
|
||||
if x.ranges: return None
|
||||
|
||||
# local kernel rewrite
|
||||
lctx = LocalAddBufferContext()
|
||||
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
|
||||
|
||||
# gather the metadata
|
||||
metadatas = [ctx[y].metadata for y in lctx.parent_tags]
|
||||
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)
|
||||
|
||||
# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
|
||||
if ret.op is Ops.STORE: stored = ret.src[1]
|
||||
|
|
@ -539,8 +466,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
|||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
|
||||
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
||||
|
||||
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
|
||||
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
|
||||
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
|
||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
|
||||
return kernel
|
||||
|
|
@ -549,42 +475,9 @@ split_kernels = PatternMatcher([
|
|||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||
])
|
||||
|
||||
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
||||
if x.tag is not None or x in ctx[1]: return None
|
||||
if x.tag is None and x.op is Ops.CALL:
|
||||
# don't tag anything in a CALL
|
||||
for u in x.src[0].toposort(): ctx[1].add(u)
|
||||
if x.dtype.scalar() == dtypes.index: return None
|
||||
ctx[0].append(x)
|
||||
return x.replace(tag=(len(ctx[0])-1,))
|
||||
add_tags = pm_gate_kernel_sink+PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.END,
|
||||
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
||||
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.PARAM for s in x.src) else tag_uop(ctx, x)),
|
||||
])
|
||||
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
|
||||
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
|
||||
x = src
|
||||
while x is not src.base:
|
||||
if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg))
|
||||
elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape)
|
||||
else: return None
|
||||
x = x.src[0]
|
||||
ctx[src.base] = contig
|
||||
replace_contiguous = PatternMatcher([
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
def get_rangeify(sink:UOp) -> UOp:
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
uop_list: list[UOp] = []
|
||||
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
|
||||
|
||||
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites+replace_contiguous, ctx={}, bottom_up=True, name="earliest rewrites")
|
||||
tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
||||
|
||||
# convert movement ops to ranges
|
||||
tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
|
||||
|
|
@ -592,19 +485,12 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
|
||||
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
|
||||
|
||||
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
|
||||
# MSTACK stacks multiple BUFFERIZEs in one tagged tensor
|
||||
# if it's not tagged by here, it's out
|
||||
tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.PARAM, Ops.AFTER} and \
|
||||
x.tag is not None and len(x.tag)])
|
||||
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Rangeify")
|
||||
|
||||
# bufferize -> store
|
||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True,
|
||||
name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
|
||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, split_kernels, bottom_up=True, name="split kernels")
|
||||
|
||||
# WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish
|
||||
afters = [u for u in tsink.toposort() if u.op is Ops.AFTER]
|
||||
|
|
@ -619,17 +505,5 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER")
|
||||
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
|
||||
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
|
||||
|
||||
# TODO: we can probably get this earlier
|
||||
sink_tags = [s.tag for s in tsink.src]
|
||||
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
|
||||
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for tag, s in zip(sink_tags, tsink.src):
|
||||
assert tag is not None
|
||||
for a in tag:
|
||||
if a is None: continue
|
||||
becomes_map[uop_list[int(a)]] = s
|
||||
return becomes_map
|
||||
return tsink
|
||||
|
|
@ -336,7 +336,7 @@ class Tensor(OpMixin):
|
|||
raise JitError("cannot access tensor data during JIT capture, the value will be baked in")
|
||||
x = self.cast(self.dtype.base).contiguous()
|
||||
if isinstance(self.device, tuple): x = x.to("CPU")
|
||||
return cast(Buffer, x.realize().uop.base.buffer).ensure_allocated()
|
||||
return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
|
||||
def _data(self) -> memoryview: return self._buffer().as_memoryview()
|
||||
|
||||
def data(self) -> memoryview:
|
||||
|
|
@ -404,7 +404,7 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
Creates a clone of this tensor allocating a separate buffer for the data.
|
||||
"""
|
||||
ret = Tensor.empty(self.shape, device=self.device, dtype=self.dtype)
|
||||
ret = self.empty_like()
|
||||
if self.grad is not None: ret.grad = self.grad.clone()
|
||||
return ret.assign(self)
|
||||
|
||||
|
|
@ -537,12 +537,15 @@ class Tensor(OpMixin):
|
|||
device = canonicalize_device(device)
|
||||
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
|
||||
|
||||
def empty_like(self, **kwargs) -> Tensor:
|
||||
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
Creates an empty tensor with the same shape as `self`.
|
||||
If `dtype` is not specified, the dtype of `self` is used.
|
||||
"""
|
||||
return Tensor.empty(self.shape, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
|
||||
dtype, device = self.dtype if dtype is None else dtype, self.device if device is None else device
|
||||
if isinstance(device, tuple) and (axis := self.uop.axis) is not None:
|
||||
return Tensor(Tensor.empty(self.uop.max_shard_shape, dtype=dtype, device=device, **kwargs).uop.multi(axis), device=device)
|
||||
return Tensor.empty(self.shape, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
|
||||
|
|
@ -628,7 +631,7 @@ class Tensor(OpMixin):
|
|||
Tensor._device_seeds[device] = Tensor(
|
||||
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
|
||||
device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
|
||||
# increment rng counter for devices
|
||||
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
|
||||
|
||||
|
|
@ -1075,6 +1078,34 @@ class Tensor(OpMixin):
|
|||
|
||||
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
|
||||
|
||||
def _pad_constant(self, pX:tuple[tuple[sint, sint], ...], value:float) -> Tensor:
|
||||
# shrink first for negative pads, then pad with only non-negative values
|
||||
has_neg = not all(resolve(p >= 0) for p in flatten(pX))
|
||||
X = self.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape))) if has_neg else self
|
||||
pads = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if has_neg else pX
|
||||
if value == 0: return X._apply_uop(UOp.pad, arg=pads)
|
||||
return X._apply_uop(UOp.pad, arg=pads) + Tensor.ones_like(X)._apply_uop(UOp.pad, arg=pads).where(0, value)
|
||||
|
||||
def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Tensor:
|
||||
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
|
||||
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
|
||||
orig_shape, X = self.shape, self.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX))
|
||||
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pX, orig_shape, X.shape)))
|
||||
|
||||
def _pad_reflect_replicate(self, pX:tuple[tuple[sint, sint], ...], mode:str) -> Tensor:
|
||||
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
|
||||
for d,(pB,pA) in enumerate(pads):
|
||||
if mode == "reflect":
|
||||
if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
|
||||
slcB, slcA = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
|
||||
xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
|
||||
else:
|
||||
shrB, shrA = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
|
||||
xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
|
||||
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
|
||||
# shrink after for negative pads (reflection/replication must see full data first)
|
||||
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
||||
|
||||
def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
|
||||
"""
|
||||
Returns a tensor with padding applied based on the input `padding`.
|
||||
|
|
@ -1107,36 +1138,18 @@ class Tensor(OpMixin):
|
|||
print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
|
||||
```
|
||||
"""
|
||||
if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported")
|
||||
# flat padding
|
||||
# normalize to grouped format
|
||||
if all(isinstance(p, (int,UOp)) for p in padding):
|
||||
if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
|
||||
pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
|
||||
# group padding
|
||||
else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
|
||||
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
|
||||
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
|
||||
if mode == "constant":
|
||||
def _constant(x:Tensor,px,v) -> Tensor:
|
||||
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
|
||||
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
|
||||
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
|
||||
# dispatch
|
||||
if mode == "constant": return self._pad_constant(pX, value)
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
if mode == "circular":
|
||||
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
|
||||
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
|
||||
orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads))
|
||||
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape)))
|
||||
for d,(pB,pA) in enumerate(pads):
|
||||
if mode == "reflect":
|
||||
if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
|
||||
slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
|
||||
xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
|
||||
if mode == "replicate":
|
||||
shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
|
||||
xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
|
||||
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
|
||||
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
||||
if mode == "circular": return self._pad_circular(pX)
|
||||
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
|
||||
raise NotImplementedError(f"{mode=} is not supported")
|
||||
|
||||
# convenience
|
||||
def pad_to(self, shape, *args):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDT
|
|||
from tinygrad.dtype import storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY
|
||||
from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
|
||||
from tinygrad.helpers import strip_parens, colored, ansilen, printable
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.renderer import Estimates
|
||||
|
|
@ -16,16 +16,15 @@ if TYPE_CHECKING:
|
|||
class AxisType(Enum):
|
||||
def __repr__(self): return str(self)
|
||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||
THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
|
||||
THREAD = auto(); PLACEHOLDER = auto() # noqa: E702
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
|
||||
AxisType.OUTER: "green"}
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
||||
|
||||
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1}
|
||||
|
||||
|
|
@ -304,7 +303,24 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return ret
|
||||
|
||||
@property
|
||||
def size(self) -> int: return prod([int(x.vmax) if isinstance(x, UOp) else x for x in self.shape])
|
||||
def max_shape(self) -> tuple[int, ...]:
|
||||
return tuple([int(x.vmax) if isinstance(x, UOp) else x for x in self.shape])
|
||||
|
||||
@property
|
||||
def shard_shape(self) -> tuple[sint, ...]:
|
||||
if not isinstance(self.device, tuple) or self.axis is None: return self.shape
|
||||
return tuple(x//len(self.device) if i == self.axis else x for i,x in enumerate(self.shape))
|
||||
|
||||
@property
|
||||
def max_shard_shape(self) -> tuple[int, ...]:
|
||||
if not isinstance(self.device, tuple) or self.axis is None: return self.max_shape
|
||||
return tuple(x//len(self.device) if i == self.axis else x for i,x in enumerate(self.max_shape))
|
||||
|
||||
@property
|
||||
def size(self) -> int: return prod(self.max_shape)
|
||||
|
||||
@property
|
||||
def shard_size(self) -> int: return prod(self.max_shard_shape)
|
||||
|
||||
@functools.cached_property
|
||||
def ended_ranges(self):
|
||||
|
|
@ -352,7 +368,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return vmin
|
||||
def __bool__(self): return self._eval((dtypes.bool,), bool)
|
||||
def __int__(self): return self._eval(dtypes.ints, int)
|
||||
def __float__(self): return self._eval(dtypes.floats, float)
|
||||
def __float__(self): return float(self._eval(dtypes.floats, float))
|
||||
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None):
|
||||
dvars = {k:v for k,v in dvars.items() if k is not v}
|
||||
if len(dvars) == 0: return self
|
||||
|
|
@ -491,18 +507,23 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
@functools.cached_property
|
||||
def axis(self) -> int|None:
|
||||
# COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK
|
||||
if self.op is Ops.COPY: return None
|
||||
if self.op is Ops.MULTI: return self.arg
|
||||
# NOTE: they all have to share an axis, we always choose [-1]
|
||||
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
|
||||
if len(self.src) == 0: return None
|
||||
src_axis = self.src[0].axis
|
||||
if self.op is Ops.SHRINK and src_axis is not None and self.marg[src_axis] != (0, self.src[0].shape[src_axis]):
|
||||
return None # SHRINK will remove the sharding if it's on axis
|
||||
if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
|
||||
if self.op is Ops.RESHAPE:
|
||||
if src_axis is None: return None
|
||||
arg_acc:list[sint] = list(itertools.accumulate(self.marg, operator.mul, initial=1))
|
||||
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
|
||||
# TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
|
||||
return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
|
||||
new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
|
||||
if self.shape[new_axis] % len(self.device) != 0: raise RuntimeError(f"reshape {self.src[0].shape} -> {self.shape} moved items between shards")
|
||||
return new_axis
|
||||
if self.op is Ops.PERMUTE: return self.marg.index(src_axis) if src_axis is not None else None
|
||||
return src_axis
|
||||
|
||||
|
|
@ -537,6 +558,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.DETACH: return self.src[0].base # DETACH can't change base
|
||||
return self
|
||||
|
||||
@property
|
||||
def multibase(self) -> UOp:
|
||||
if self.op in GroupOp.Movement: return self.src[0].base
|
||||
if self.op is Ops.DETACH: return self.src[0].base # DETACH can't change base
|
||||
return self
|
||||
|
||||
# like gep, but might return an integer
|
||||
def sgep(self, i:int) -> sint:
|
||||
match self.op:
|
||||
|
|
@ -554,6 +581,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
case _: raise RuntimeError(f"{self.op} is not a MovementOp")
|
||||
|
||||
def _mop(self, op:Ops, arg, same_shape_noop:bool=False) -> UOp:
|
||||
# early NOOP
|
||||
if op in {Ops.SHRINK, Ops.PAD, Ops.EXPAND} and len(arg) == 0:
|
||||
assert len(self.shape) == 0, "0 len arg only valid on zero length shape"
|
||||
return self
|
||||
match op:
|
||||
case Ops.RESHAPE | Ops.EXPAND: src_args = [arg]
|
||||
case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
|
||||
|
|
@ -567,7 +598,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return ret
|
||||
|
||||
# in these four, if the shape doesn't change we can return self
|
||||
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
|
||||
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
|
||||
#def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
|
||||
#def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
|
||||
|
|
@ -622,9 +652,26 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@property
|
||||
def buffer(self) -> Buffer|MultiBuffer:
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE}: return self.src[0].buffer
|
||||
# this buffer can process disk tensors and simple movement ops
|
||||
if self is not self.base:
|
||||
assert self.op is Ops.RESHAPE, f"can only be RESHAPE {self}"
|
||||
return self.src[0].buffer
|
||||
from tinygrad.schedule.rangeify import pm_mops
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
out = graph_rewrite(self.flatten().index(UOp.range(self.size, 0)), pm_mops+symbolic)
|
||||
buf = out.src[0].buffer
|
||||
assert isinstance(buf, Buffer), "must be a Buffer for movement ops"
|
||||
assert out.op is Ops.INDEX, "couldn't collapse to a single INDEX"
|
||||
if out.src[1].op is Ops.CONST:
|
||||
return buf.view(1, out.dtype, out.src[1].arg*out.dtype.itemsize)
|
||||
if out.src[1].op is Ops.RANGE:
|
||||
return buf.view(self.size, out.dtype, 0)
|
||||
if out.src[1].op is Ops.ADD and out.src[1].src[0].op is Ops.RANGE and out.src[1].src[1].op is Ops.CONST:
|
||||
return buf.view(self.size, out.dtype, out.src[1].src[1].arg*out.dtype.itemsize)
|
||||
raise RuntimeError(f"cannot collapse INDEX {out.pyrender()} to a single size/offset")
|
||||
if self.op is Ops.BITCAST:
|
||||
buf = self.src[0].buffer
|
||||
assert isinstance(buf, Buffer), "must be a Buffer for BITCAST"
|
||||
return buf.view(self.size, self.dtype, 0)
|
||||
if self.op is Ops.MSELECT:
|
||||
ret = self.src[0].buffer
|
||||
assert isinstance(ret, MultiBuffer)
|
||||
|
|
@ -676,12 +723,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return graph_rewrite(self, pm_unbind, ctx=ret), ret
|
||||
@property
|
||||
def val(self) -> int: return self.unbind()[1]
|
||||
def vars(self) -> set[UOp]:
|
||||
topo = self.toposort()
|
||||
bound = {x.src[0]: x for x in topo if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR}
|
||||
return {bound.get(x, x) for x in topo if x.op is Ops.DEFINE_VAR}
|
||||
def variables(self) -> list[Variable]:
|
||||
return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
||||
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR}, key=lambda v: v.arg)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
||||
|
|
@ -775,7 +818,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@functools.cached_property
|
||||
def _sym_fxn(self):
|
||||
sself = self.simplify()
|
||||
varnames = tuple(x.arg[0] for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
|
||||
varnames = tuple(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
|
||||
# TODO: sanitize varnames, or don't use naked eval while staying fast
|
||||
return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used
|
||||
|
||||
|
|
@ -794,11 +837,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
# *** uop high level syntactic sugar ***
|
||||
|
||||
@property
|
||||
def shard_shape(self):
|
||||
if self.axis is None: return self.shape
|
||||
return tuple(x//len(self.device) if i == self.axis else x for i,x in enumerate(self.shape))
|
||||
|
||||
@staticmethod
|
||||
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.PARAM, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
|
|
@ -807,7 +845,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return ret
|
||||
def placeholder_like(self, slot:int):
|
||||
assert all_int(self.shape), "no placeholder-like on symbolic shape"
|
||||
return UOp.placeholder(self.shard_shape, self.dtype, slot)
|
||||
return UOp.placeholder(self.max_shard_shape, self.dtype, slot)
|
||||
|
||||
# set is store+end+after
|
||||
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
|
||||
|
|
@ -1192,12 +1230,13 @@ if TRACK_MATCH_STATS or PROFILE:
|
|||
SENTINEL: Final[UOp] = cast(UOp, object())
|
||||
class BottomUpGate(Exception): pass
|
||||
class RewriteContext:
|
||||
def __init__(self, pm, bpm, ctx=None):
|
||||
def __init__(self, pm, bpm, ctx=None, rewrite_into_calls=False):
|
||||
self.pm: PatternMatcher|None = pm
|
||||
self.bpm: PatternMatcher|None = bpm
|
||||
self.bpm_cache: dict[UOp, UOp|None] = {}
|
||||
self.ctx = ctx
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
self.rewrite_into_calls = rewrite_into_calls
|
||||
|
||||
# no cache needed: pm_rewrite is called at most once per UOp due to the replace dict check in unified_rewrite
|
||||
def pm_rewrite(self, x:UOp) -> UOp|None: return unwrap(self.pm).rewrite(x, self.ctx)
|
||||
|
|
@ -1232,6 +1271,10 @@ class RewriteContext:
|
|||
if n in waitlist: stack.extend(waitlist.pop(n))
|
||||
continue
|
||||
stack.append((n, 1, new_n))
|
||||
# NOTE: CALL is handled as a special case.
|
||||
# The function that is called is not included in the graph_rewrite.
|
||||
# If you want to graph_rewrite a call, you can
|
||||
if new_n.op is Ops.CALL and not self.rewrite_into_calls: self.replace[new_n.src[0]] = new_n.src[0]
|
||||
for x in reversed(new_n.src):
|
||||
if x in on_stack: continue
|
||||
stack.append((x, 0, x))
|
||||
|
|
@ -1270,22 +1313,10 @@ class RewriteContext:
|
|||
return self.replace[root]
|
||||
|
||||
@profile_matches
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, rewrite_into_calls=False) -> UOp:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx, rewrite_into_calls=rewrite_into_calls)
|
||||
return rewrite_ctx.unified_rewrite(sink)
|
||||
|
||||
@profile_matches
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None,
|
||||
input_map:dict[UOp, UOp]|None=None, ) -> dict[UOp, UOp]:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
|
||||
new_map: dict[UOp, UOp] = {}
|
||||
for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
|
||||
new_map[k] = v = rewrite_ctx.unified_rewrite(k)
|
||||
if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
|
||||
if input_map is not None:
|
||||
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
|
||||
return new_map
|
||||
|
||||
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
|
||||
|
||||
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
|
||||
|
|
@ -1318,7 +1349,6 @@ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get
|
|||
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
def gate_kernel_sink(x:UOp) -> bool: return not (x.op is Ops.SINK and isinstance(x.arg, KernelInfo))
|
||||
pm_gate_kernel_sink = PatternMatcher([(UPat(Ops.SINK, name="sink"), lambda sink: None if gate_kernel_sink(sink) else panic(BottomUpGate))])
|
||||
|
||||
def do_unbind(ctx:dict[Variable, int], x:UOp):
|
||||
v,i = x.unbind()
|
||||
|
|
@ -1347,7 +1377,7 @@ def bitcast(x, in_dtype:DType, out_dtype:DType):
|
|||
return ret[0] if out_count == 1 else ret
|
||||
|
||||
renderer = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.arg[0]),
|
||||
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr),
|
||||
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
|
||||
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),
|
||||
|
|
@ -1409,8 +1439,6 @@ pm_pyrender_extra = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
|
||||
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
|
||||
(f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
|
||||
# TODO: fix forced_reshape
|
||||
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None),
|
||||
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
|
||||
# NOTE: CMPNE doesn't work cause there's no __rne__
|
||||
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering
|
||||
|
|
|
|||
|
|
@ -123,16 +123,9 @@ _tensor_spec = PatternMatcher([
|
|||
# REDUCE_AXIS is the reduce in the tensor graph
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
|
||||
|
||||
# REDUCE with an outerworld range
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
||||
|
||||
# AFTER if things were kernelized
|
||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
|
||||
|
||||
# Tensor range bind / store
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
|
||||
|
||||
# allow CALL/PARAM
|
||||
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
|
||||
(UPat(Ops.PARAM), lambda: True),
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"),
|
||||
lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//a*c)*c = x//a. Note if a = 1 it degenerates to the one above
|
||||
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
|
||||
lambda x,a,b,c1,c2,c3: x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
|
||||
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
|
||||
|
|
@ -58,6 +60,10 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"),
|
||||
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
|
||||
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"))+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"),
|
||||
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
||||
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
||||
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@
|
|||
pointer-events: none;
|
||||
}
|
||||
label {
|
||||
display: inline-flex;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
line-height: 1;
|
||||
|
|
|
|||
|
|
@ -192,7 +192,9 @@ const waveColor = (op) => {
|
|||
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
|
||||
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
|
||||
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),
|
||||
WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
|
||||
GPC:new Map([["NONE","#1a7a2e"],["MEMORY_DEPENDENCY","#8b1a00"],["EXEC_DEPENDENCY","#006b6b"],["INST_FETCH","#7a7a00"],["SYNC","#6b006b"],
|
||||
["PIPE_BUSY","#7a4a00"],["MEMORY_THROTTLE","#5c0000"],["CONSTANT_MEMORY","#1a3d7a"],["NOT_SELECTED","#2e2e3a"],["OTHER","#4a4a55"],
|
||||
["SLEEPING","#1a1a2a"],["DEFAULT","#3a3a45"]]), WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
|
||||
const cycleColors = (lst, i) => lst[i%lst.length];
|
||||
|
||||
const rescaleTrack = (source, tid, k) => {
|
||||
|
|
@ -235,7 +237,11 @@ function selectShape(key) {
|
|||
|
||||
const Modes = {0:'read', 1:'write', 2:'write+read'};
|
||||
|
||||
function getMetadata(key) {
|
||||
function setFocus(key) {
|
||||
if (key !== focusedShape) {
|
||||
saveToHistory({ shape:focusedShape });
|
||||
focusedShape = key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel);
|
||||
}
|
||||
const { eventType, e } = selectShape(key);
|
||||
const html = d3.create("div").classed("info", true);
|
||||
if (eventType === EventTypes.EXEC) {
|
||||
|
|
@ -247,14 +253,14 @@ function getMetadata(key) {
|
|||
for (const b of e.arg.bufs.sort((a, b) => a.num - b.num)) {
|
||||
group.append("p").text(`${Modes[b.mode]}@data${b.num} ${formatUnit(b.nbytes, 'B')}`).style("cursor", "pointer").on("click", () => {
|
||||
const row = document.getElementById(b.k); if (!isExpanded(row)) { row.click(); }
|
||||
focusShape(b.key);
|
||||
setFocus(b.key);
|
||||
});
|
||||
}
|
||||
if (e.arg.ctx != null) {
|
||||
const i = e.arg.ctx; s = e.arg.step;
|
||||
html.append("a").text(ctxs[i+1].steps[s].name).on("click", () => switchCtx(i, s));
|
||||
const prgSrc = ctxs[i+1].steps.findIndex(s => s.name === "View Program");
|
||||
if (prgSrc !== -1) html.append("a").text("View program").on("click", () => switchCtx(i, prgSrc));
|
||||
const prgSrc = ctxs[i+1].steps.findIndex(s => s.name === "View Source");
|
||||
if (prgSrc !== -1) html.append("a").text("View Source").on("click", () => switchCtx(i, prgSrc));
|
||||
}
|
||||
}
|
||||
if (eventType === EventTypes.BUF) {
|
||||
|
|
@ -268,16 +274,10 @@ function getMetadata(key) {
|
|||
const p = kernels.append("p").append(() => colored(`[${u}] ${repr} ${Modes[mode]}@data${num}`));
|
||||
const shapeInfo = selectShape(shape).e?.arg?.tooltipText?.split("\n");
|
||||
if (shapeInfo?.length > 5) p.append("span").text(" "+shapeInfo[5]);
|
||||
if (shape != null) p.style("cursor", "pointer").on("click", () => focusShape(shape));
|
||||
if (shape != null) p.style("cursor", "pointer").on("click", () => setFocus(shape));
|
||||
}
|
||||
}
|
||||
return html.node();
|
||||
}
|
||||
|
||||
function focusShape(shape) {
|
||||
saveToHistory({ shape:focusedShape });
|
||||
focusedShape = shape; d3.select("#timeline").call(canvasZoom.transform, zoomLevel);
|
||||
return metadata.replaceChildren(getMetadata(focusedShape));
|
||||
return metadata.replaceChildren(html.node());
|
||||
}
|
||||
|
||||
const EventTypes = { EXEC:0, BUF:1 };
|
||||
|
|
@ -287,7 +287,7 @@ async function renderProfiler(path, unit, opts) {
|
|||
// support non realtime x axis units
|
||||
formatTime = unit === "realtime" ? formatMicroseconds : formatCycles;
|
||||
if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null}; focusedDevice = null; focusedShape = null; }
|
||||
metadata.replaceChildren(getMetadata(focusedShape));
|
||||
setFocus(focusedShape);
|
||||
// layout once!
|
||||
if (data.tracks.size !== 0) return updateProgress(Status.COMPLETE);
|
||||
const profiler = d3.select("#profiler").html("");
|
||||
|
|
@ -375,7 +375,7 @@ async function renderProfiler(path, unit, opts) {
|
|||
// tiny device events go straight to the rewrite rule
|
||||
const key = k.startsWith("TINY") ? null : `${k}-${j}`;
|
||||
const labelHTML = label.map(l=>`<span style="color:${l.color}">${l.st}</span>`).join("");
|
||||
const arg = { tooltipText:labelHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), bufs:[], key,
|
||||
const arg = { tooltipText:labelHTML+" N:"+shapes.length+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), bufs:[], key,
|
||||
ctx:shapeRef?.ctx, step:shapeRef?.step };
|
||||
if (e.key != null) shapeMap.set(e.key, key);
|
||||
// offset y by depth
|
||||
|
|
@ -606,7 +606,7 @@ async function renderProfiler(path, unit, opts) {
|
|||
e.preventDefault();
|
||||
const foundRect = findRectAtPosition(e.clientX, e.clientY);
|
||||
if (foundRect?.step != null && (foundRect?.key == null || e.type == "dblclick")) { return switchCtx(foundRect.ctx, foundRect.step); }
|
||||
if (foundRect?.key != focusedShape) { focusShape(foundRect?.key); }
|
||||
if (foundRect?.key != focusedShape) { setFocus(foundRect?.key); }
|
||||
}
|
||||
canvas.addEventListener("click", clickShape);
|
||||
canvas.addEventListener("dblclick", clickShape);
|
||||
|
|
@ -739,12 +739,12 @@ function saveToHistory(ns) {
|
|||
const switchCtx = (newCtx, step) => setState({ expandSteps:true, currentCtx:newCtx+1, currentStep:step ?? 0, currentRewrite:0 });
|
||||
|
||||
window.addEventListener("popstate", (e) => {
|
||||
if (e.state?.shape != null) return focusShape(e.state?.shape);
|
||||
if (e.state?.shape != null) return setFocus(e.state?.shape);
|
||||
if (e.state != null) setState(e.state);
|
||||
});
|
||||
|
||||
const createToggle = (id, text) => {
|
||||
const label = d3.create("label").style("display", "block").text(text).node();
|
||||
const label = d3.create("label").text(text).node();
|
||||
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
|
||||
label.prepend(toggle);
|
||||
return { toggle, label };
|
||||
|
|
@ -826,7 +826,7 @@ async function main() {
|
|||
}
|
||||
// timeline with cycles on the x axis
|
||||
if (ret instanceof ArrayBuffer) {
|
||||
opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), colorByName:step.name.includes("PKTS")};
|
||||
opts = {heightScale:0.5, hideLabels:true, levelKey:step.name.includes("PKTS") ? (e) => parseInt(e.name.split(" ")[1].split(":")[1]) : null, colorByName:ckey.includes("pkts")};
|
||||
return renderProfiler(ckey, "clk", opts);
|
||||
}
|
||||
metadata.innerHTML = "";
|
||||
|
|
@ -872,7 +872,7 @@ async function main() {
|
|||
}
|
||||
if (ret.ref != null) {
|
||||
const disasmIdx = ctxs[ret.ref+1].steps.findIndex(s => s.name === "View Disassembly")
|
||||
metadata.appendChild(d3.create("a").text("View Program Graph").on("click", () => switchCtx(ret.ref, disasmIdx)).node());
|
||||
metadata.appendChild(d3.create("a").text("View Disassembly").on("click", () => switchCtx(ret.ref, disasmIdx)).node());
|
||||
}
|
||||
if (ret.cols != null) renderTable(root, ret);
|
||||
else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ const layoutUOp = (g, { graph, change }, opts) => {
|
|||
const disconnected = new Set();
|
||||
for (const n of g.nodes()) {
|
||||
const node = g.node(n);
|
||||
if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") {
|
||||
if (node.label.startsWith("CALL\n")) {
|
||||
for (const pred of (g.predecessors(n) || [])) {
|
||||
const edge = g.edge(pred, n);
|
||||
if (edge?.label?.text === 0) {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from http.server import BaseHTTPRequestHandler
|
|||
from typing import Any, TypedDict, TypeVar, Generator, Callable
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
||||
from tinygrad.helpers import printable, Context
|
||||
from tinygrad.renderer.amd.dsl import Inst
|
||||
from tinygrad.renderer.amd import detect_format
|
||||
|
||||
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
|
||||
class TCPServerWithReuse(socketserver.TCPServer):
|
||||
|
|
@ -93,6 +95,8 @@ def pystr(u:UOp) -> str:
|
|||
try: return pyrender(u)
|
||||
except Exception: return str(u)
|
||||
|
||||
# all the trace points, initialized after the trace loads
|
||||
ctxs:list[dict] = []
|
||||
def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: dict[int, dict] = {}
|
||||
|
|
@ -133,7 +137,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
|
||||
except Exception:
|
||||
label += "\n<ISSUE GETTING LABEL>"
|
||||
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None and ctxs: label += f"\ncodegen@{ctxs[ref]['name']}"
|
||||
# NOTE: kernel already has metadata in arg
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
|
||||
# limit SOURCE labels line count
|
||||
|
|
@ -184,8 +188,6 @@ def rel_ts(ts:int|Decimal, start_ts:int) -> int:
|
|||
device_ts_diffs:dict[str, Decimal] = {}
|
||||
def cpu_ts_diff(device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0))
|
||||
|
||||
amdgpu_targets:dict[str, str] = {}
|
||||
|
||||
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
|
||||
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
|
||||
for e in profile:
|
||||
|
|
@ -205,7 +207,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
|
|||
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.arg["name"]] = e
|
||||
if dur == 0: continue
|
||||
name, fmt, key = e.name, [], None
|
||||
if (ref:=ref_map.get(name)) is not None:
|
||||
if (ref:=ref_map.get(name)) is not None and ctxs:
|
||||
name = ctxs[ref]["name"]
|
||||
if (p:=get_prg_uop(ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
|
||||
flops = sym_infer((estimates:=p.src[0].arg.estimates).ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
|
||||
|
|
@ -299,16 +301,18 @@ def unpack_pmc(e) -> dict:
|
|||
|
||||
# ** on startup, list all the performance counter traces
|
||||
|
||||
def load_counters(profile:list[ProfileEvent]) -> None:
|
||||
def load_amd_counters(profile:list[ProfileEvent]) -> None:
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
||||
counter_events:dict[tuple[int, int], dict] = {}
|
||||
durations:dict[str, list[float]] = {}
|
||||
prg_events:dict[int, ProfileProgramEvent] = {}
|
||||
arch = ""
|
||||
for e in profile:
|
||||
if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e), []).append(e)
|
||||
if isinstance(e, ProfileRangeEvent) and e.device.startswith("AMD") and e.en is not None:
|
||||
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
|
||||
if isinstance(e, ProfileProgramEvent) and e.tag is not None: prg_events[e.tag] = e
|
||||
if isinstance(e, ProfileDeviceEvent) and e.device.startswith("AMD"): arch = f"gfx{unwrap(e.props)['gfx_target_version']//1000}"
|
||||
if len(counter_events) == 0: return None
|
||||
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
|
||||
run_number = {n:0 for n,_ in counter_events}
|
||||
|
|
@ -323,13 +327,13 @@ def load_counters(profile:list[ProfileEvent]) -> None:
|
|||
# to decode a SQTT trace, we need the raw stream, program binary and device properties
|
||||
if (sqtt:=v.get(ProfileSQTTEvent)):
|
||||
for e in sqtt:
|
||||
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)),
|
||||
data=(e.blob, prg_events[k].lib, amdgpu_targets[e.device])))
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k])))
|
||||
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib,arch)))
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
||||
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
from tinygrad.renderer.amd.sqtt import map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC
|
||||
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4
|
||||
ret:list[ProfileEvent] = []
|
||||
rows:dict[str, None] = {}
|
||||
trace:dict[str, set[int]] = {}
|
||||
|
|
@ -340,12 +344,12 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
|||
ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width)))
|
||||
for p, info in map_insts(data, lib, target):
|
||||
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
|
||||
if isinstance(p, INST):
|
||||
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
|
||||
if isinstance(p, (INST, INST_RDNA4)):
|
||||
op_name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4)) else f"0x{p.op:02x}"
|
||||
name, width = (op_name, 10 if "BARRIER" in op_name else 1)
|
||||
add(name, p, width=width, idx=int("OTHER" in name), info=info)
|
||||
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
|
||||
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info.wave), info=info) # type: ignore[union-attr]
|
||||
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info)
|
||||
if isinstance(p, (VMEMEXEC, ALUEXEC)):
|
||||
name = str(p.src).split('.')[1]
|
||||
if name == "VALU_SALU":
|
||||
|
|
@ -359,12 +363,13 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
|||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
|
||||
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent,
|
||||
target:str) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
|
||||
# * init decoder
|
||||
from extra.sqtt.roc import decode
|
||||
base = unwrap(p.base)
|
||||
addr_table = amd_decode(unwrap(p.lib), amdgpu_targets[p.device])
|
||||
disasm:dict[int, tuple[str, int]] = {addr+base:(str(inst), inst.size()) for addr, inst in addr_table.items()}
|
||||
addr_table = amd_decode(unwrap(p.lib), target)
|
||||
disasm:dict[int, Inst] = {addr+base:inst for addr, inst in addr_table.items()}
|
||||
rctx = decode(data, {p.tag:disasm})
|
||||
cu_events:dict[str, list[ProfileEvent]] = {}
|
||||
# * INST waves
|
||||
|
|
@ -401,9 +406,8 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
|||
for ev in profile:
|
||||
if isinstance(ev, ProfileDeviceEvent):
|
||||
device_ts_diffs[ev.device] = ev.tdiff
|
||||
if (d:=ev.device.split(":")[0]) == "AMD":
|
||||
device_decoders[d] = load_counters
|
||||
amdgpu_targets[d] = f"gfx{unwrap(ev.props)['gfx_target_version']//1000}"
|
||||
if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_amd_counters
|
||||
if d == "NV": device_decoders[d] = load_nv_counters
|
||||
# load device specific counters
|
||||
for fxn in device_decoders.values(): fxn(profile)
|
||||
# map events per device
|
||||
|
|
@ -431,6 +435,35 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
|||
index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":rel_ts(e.ts, start_ts), **e.arg} for e in markers]}).encode()
|
||||
return struct.pack("<IQII", rel_ts(unwrap(end_ts), start_ts), max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
|
||||
|
||||
# ** PMA counters
|
||||
|
||||
def load_nv_counters(profile:list) -> None:
|
||||
steps:list[dict] = []
|
||||
sm_version = {e.device:e.props.get("sm_version", 0x800) for e in profile if isinstance(e, ProfileDeviceEvent) and e.props is not None}
|
||||
run_number:dict[str, int] = {}
|
||||
for e in profile:
|
||||
if type(e).__name__ == "ProfilePMAEvent":
|
||||
run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1
|
||||
steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(ctxs), len(steps)),
|
||||
data=(e.blob, sm_version[e.device])))
|
||||
if steps: ctxs.append({"name":"All Counters", "steps":steps})
|
||||
|
||||
def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]:
|
||||
from extra.nv_pma.decode import decode, decode_tpc_id
|
||||
ret:list[ProfileEvent] = []
|
||||
rows:dict[str, None] = {}
|
||||
tpc_count:dict[int, int] = {}
|
||||
# assume every sample is 32 cycles
|
||||
cycles_per_sample = 32
|
||||
for s, tpc_id in decode(blob, sm_version):
|
||||
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
|
||||
gpc, tpc, sm = decode_tpc_id(tpc_id)
|
||||
tpc_count[tpc_id] = (n:=tpc_count.get(tpc_id,0)) + 1
|
||||
rows.setdefault(row:=f"GPC:{gpc} TPC:{tpc} SM:{sm} WAVE:{s.wave_id}")
|
||||
ret.append(ProfileRangeEvent(row, TracingKey(s.stall_reason.name, ret=f"pc=0x{s.pc_offset:06x} active={s.active}"),
|
||||
Decimal(n*cycles_per_sample), Decimal((n+1)*cycles_per_sample)))
|
||||
return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret
|
||||
|
||||
# ** Assembly static analyzers
|
||||
|
||||
def get_stdout(f: Callable) -> str:
|
||||
|
|
@ -440,23 +473,12 @@ def get_stdout(f: Callable) -> str:
|
|||
except Exception: traceback.print_exc(file=buf)
|
||||
return buf.getvalue()
|
||||
|
||||
def amd_readelf(lib:bytes) -> list[dict]:
|
||||
from tinygrad.runtime.autogen import amdgpu_kd
|
||||
def get_elf_section(lib:bytes, name:str):
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
image, sections, __ = elf_loader(lib)
|
||||
rodata = next((s for s in sections if s.name == ".rodata")).content
|
||||
kd = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t.from_buffer_copy(bytearray(rodata))
|
||||
vgpr_gran = kd.compute_pgm_rsrc1 & amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WORKITEM_VGPR_COUNT
|
||||
return [{"label":f"{resource} Alloc", "value":val} for resource,val in [("VGPR", (vgpr_gran+1)*8-7), ("LDS",kd.group_segment_fixed_size),
|
||||
("Scratch", kd.private_segment_fixed_size)] if val > 0]
|
||||
return next((sh for sh in elf_loader(lib)[1] if sh.name == name))
|
||||
|
||||
def amd_decode(lib:bytes, target:str) -> dict[int, Any]: # Any is the Inst class from tinygrad.renderer.amd.dsl
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.renderer.amd import detect_format
|
||||
from tinygrad.renderer.amd.dsl import Inst
|
||||
image, sections, _ = elf_loader(lib)
|
||||
text = next((sh for sh in sections if sh.name == ".text"), None)
|
||||
assert text is not None, "no .text section found in ELF"
|
||||
def amd_decode(lib:bytes, target:str) -> dict[int, Inst]:
|
||||
text = get_elf_section(lib, ".text")
|
||||
off, buf = text.header.sh_addr, text.content
|
||||
arch = "rdna3" if target.startswith("gfx11") else "rdna4" if target.startswith("gfx12") else "cdna"
|
||||
addr_table:dict[int, Inst] = {}
|
||||
|
|
@ -511,7 +533,12 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict:
|
|||
if isinstance(val:=getattr(inst, name), Reg): tokens.append({"st":val.fmt(), "keys":[f"r{val.offset+i}" for i in range(val.sz)], "kind":1})
|
||||
elif name in {"op","opx","opy"}: tokens.append({"st":(op_name:=val.name.lower()), "keys":[op_name], "kind":0})
|
||||
elif name != "encoding" and val != field.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1})
|
||||
return {"data":{"blocks":blocks, "paths":paths, "pc_tokens":pc_tokens}, "src":"\n".join(lines)}
|
||||
from tinygrad.runtime.autogen import amdgpu_kd
|
||||
kd = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t.from_buffer_copy(bytearray(get_elf_section(lib, ".rodata").content))
|
||||
vgpr_gran = kd.compute_pgm_rsrc1 & amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WORKITEM_VGPR_COUNT
|
||||
return {"data":{"blocks":blocks, "paths":paths, "pc_tokens":pc_tokens}, "src":"\n".join(lines),
|
||||
"metadata":[[{"label":f"{r} Alloc", "value":v} for r,v in [("VGPR", (vgpr_gran+1)*8-7), ("LDS", kd.group_segment_fixed_size),
|
||||
("Scratch", kd.private_segment_fixed_size)] if v>0]]}
|
||||
|
||||
# ** Main render function to get the complete details about a trace event
|
||||
|
||||
|
|
@ -523,11 +550,10 @@ def get_render(query:str) -> dict:
|
|||
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data)), "lang":"txt"}
|
||||
if fmt == "code": return {"src":data, "lang":"cpp"}
|
||||
if fmt == "asm":
|
||||
ret:dict = {"metadata":[]}
|
||||
ret:dict = {}
|
||||
renderer, lib = data
|
||||
if renderer.device.startswith("AMD"):
|
||||
with soft_err(lambda err: ret.update(err)): ret.update(amdgpu_cfg(lib, renderer.arch))
|
||||
with soft_err(lambda err: ret["metadata"].append(err)): ret["metadata"].append(amd_readelf(lib))
|
||||
else: ret["src"] = get_stdout(lambda: renderer.compiler.disassemble(lib))
|
||||
return ret
|
||||
if fmt == "all-pmc":
|
||||
|
|
@ -572,9 +598,9 @@ def get_render(query:str) -> dict:
|
|||
pc_to_inst = data["disasm"]
|
||||
start_pc = None
|
||||
rows:dict[int, dict] = {}
|
||||
for pc, (inst,_) in pc_to_inst.items():
|
||||
for pc, inst in pc_to_inst.items():
|
||||
if start_pc is None: start_pc = pc
|
||||
rows[pc] = {"pc":pc-start_pc, "inst":inst, "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
|
||||
rows[pc] = {"pc":pc-start_pc, "inst":str(inst), "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
|
||||
for e in w.unpack_insts():
|
||||
if not (inst:=rows[e.pc]).get("type"): inst["type"] = str(e.typ).split("_")[-1]
|
||||
inst["hit_count"] += 1
|
||||
|
|
@ -585,6 +611,12 @@ def get_render(query:str) -> dict:
|
|||
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
|
||||
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
|
||||
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)}
|
||||
if fmt == "prg-pma-pkts":
|
||||
ret = {}
|
||||
with soft_err(lambda err:ret.update(err)):
|
||||
if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
|
||||
else: ret = {"src":"No PMA samples found."}
|
||||
return ret
|
||||
return data
|
||||
|
||||
# ** HTTP server
|
||||
|
|
@ -647,7 +679,7 @@ if __name__ == "__main__":
|
|||
st = time.perf_counter()
|
||||
print("*** viz is starting")
|
||||
|
||||
ctxs:list[dict] = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {})))
|
||||
ctxs = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {})))
|
||||
profile_ret = get_profile(load_pickle(args.profile, default=[]))
|
||||
|
||||
server = TCPServerWithReuse(('', PORT), Handler)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue