Compare commits

..

20 commits

Author SHA1 Message Date
George Hotz
30cc426916 order 2026-06-23 12:39:06 -07:00
George Hotz
c7fd74d523 order 2026-06-23 12:30:19 -07:00
George Hotz
4e398e3f1f lengths 2026-06-23 09:18:51 -07:00
George Hotz
5238f304c7 coalese 2026-06-23 09:04:23 -07:00
George Hotz
87e12fad81 cleanup 2026-06-23 08:43:00 -07:00
George Hotz
5eaf02c719 memory coalesing 2026-06-23 08:30:23 -07:00
George Hotz
4424af7bd8 revert that 2026-06-22 18:36:50 -07:00
George Hotz
c01d75a651 fixes 2026-06-22 18:18:19 -07:00
George Hotz
63ec8ad21d
Merge branch 'master' into codegen_try_2 2026-06-22 17:44:37 -07:00
George Hotz
cccd9c2c03 loads are grouped 2026-06-22 16:12:18 -07:00
George Hotz
303b6ba14c
Merge branch 'master' into codegen_try_2 2026-06-22 13:02:37 -07:00
George Hotz
2bf3c48c1b reduce hack remove 2026-06-22 13:01:36 -07:00
George Hotz
ba75b68c12 simple 2026-06-22 11:36:13 -07:00
George Hotz
958088cc13 Merge remote-tracking branch 'origin/master' into codegen_try_2 2026-06-22 11:25:46 -07:00
George Hotz
257ed03f57 fix store 2026-06-22 08:46:56 -07:00
George Hotz
15476572cd
Merge branch 'master' into codegen_try_2 2026-06-22 08:34:49 -07:00
George Hotz
f8949a0de1 fix exec_alu 2026-06-20 17:01:28 -07:00
George Hotz
62c6c75657 test tiny almost passes 2026-06-20 16:46:01 -07:00
George Hotz
530aed739d devec 2026-06-20 08:34:20 -07:00
George Hotz
da402953d9 new codegen, try 2 2026-06-20 08:30:38 -07:00
21 changed files with 327 additions and 360 deletions

View file

@ -327,7 +327,7 @@ jobs:
llvm: 'true' llvm: 'true'
- name: Test openpilot model kernel count and gate usage - name: Test openpilot model kernel count and gate usage
run: | run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness) - name: Test openpilot CL compile fp32 (test correctness)
run: | run: |
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
@ -370,6 +370,7 @@ jobs:
with: with:
key: optim key: optim
deps: testing deps: testing
pydeps: "tensorflow==2.19"
opencl: 'true' opencl: 'true'
#- name: Test Optimization Helpers #- name: Test Optimization Helpers
# run: DEBUG=1 python3 extra/optimization/test_helpers.py # run: DEBUG=1 python3 extra/optimization/test_helpers.py
@ -378,7 +379,7 @@ jobs:
- name: Test Beam Search - name: Test Beam Search
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Test MLPerf stuff - name: Test MLPerf stuff
run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20 run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
- name: DEV=NULL beautiful_mnist_multigpu - name: DEV=NULL beautiful_mnist_multigpu
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
- name: Test Bert training - name: Test Bert training

View file

@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None, def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None, x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]: grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
if not fp8: if not fp8:
if ASM_GEMM: if ASM_GEMM:
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
@ -47,14 +47,12 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)" assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
if MXFP8: if MXFP8:
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d) x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
if can_use_asm_gemm(x_q, w.T): if can_use_asm_gemm(x_q, w.T):
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale), out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
mx_w_stored=True).reshape(*l_shape, w.shape[0]) mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0])
else: else:
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1]) x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1])
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
return out, (amax_x.detach() if amax_x is not None else None), x_q return out, (amax_x.detach() if amax_x is not None else None), x_q
if x_fp8 is None: if x_fp8 is None:
@ -216,13 +214,6 @@ class FlatTransformer:
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"]) x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
amaxs.append(new_amax) amaxs.append(new_amax)
saves.extend([*s, x_w3]) saves.extend([*s, x_w3])
if FUSED_SILU_W13 and MXFP8:
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
else:
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"], out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
grad_amax_state=kwargs["grad_amax_xout"]) grad_amax_state=kwargs["grad_amax_xout"])
amaxs.append(new_amax) amaxs.append(new_amax)

View file

@ -143,17 +143,14 @@ def make_getaddr(u, device=None):
def make_ins(op, *srcs): def make_ins(op, *srcs):
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op) return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
dt = dtype or val.dtype
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
def make_cmdbuf(lin, devs, tag): def make_cmdbuf(lin, devs, tag):
blob, patches = b'', [] blob, patches = b'', []
for s in (s for ins in lin.src for s in ins.src): for s in (s for ins in lin.src for s in ins.src):
if s.op is not Ops.CONST: patches.append((len(blob), s)) if s.op is not Ops.CONST: patches.append((len(blob), s))
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0) blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag) buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches]) stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops)) def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
@ -214,11 +211,15 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call)) return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
def prep_kernargs(call:UOp, prg:UOp) -> UOp: def prep_kernargs(call:UOp, prg:UOp) -> UOp:
(data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device) data, info = prg.arg
buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs") patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \ dtypes.uint64) for i,gi in enumerate(info.globals)] \
+ [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)] + [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
pm_prep_runtime = PatternMatcher([ pm_prep_runtime = PatternMatcher([
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering # bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
@ -531,9 +532,9 @@ pm_resolve_patches = PatternMatcher([
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack), (UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack), (UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
# shrink on slice is shrink on base at offset # index on slice is index
(UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"), (UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))), lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
# getaddr # getaddr
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x) (UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
@ -541,8 +542,8 @@ pm_resolve_patches = PatternMatcher([
# folders # folders
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store), (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
(UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast() (UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
.store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store), fold_const_store),
]) + symbolic_simple ]) + symbolic_simple
# ***************** # *****************

View file

@ -1,104 +0,0 @@
import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
BLK = 32
PACK = 4
LOG2E = 1.4426950408889634
@functools.cache
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
n_super = n_elems // (BLK * PACK)
sk4 = scale_K // PACK
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
nwg = n_super // THREADS_PER_WG
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
e8_out = e8_out.reshape(rows * scale_K)
si_out = si_out.reshape(sk4 * rows)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
sb = UOp.range(PACK, 2, AxisType.UNROLL)
lane = UOp.range(BLK, 3, AxisType.UNROLL)
super_idx = wg * THREADS_PER_WG + tid
idx = super_idx * (BLK * PACK) + sb * BLK + lane
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
act = w1 * sig * w3
abs_a = (act < 0.0).where(-act, act)
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
qscale = (127.0 - e8f).exp2()
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
e8u8 = e8f.cast(dtypes.uint8)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
row, col4 = super_idx // sk4, super_idx % sk4
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
@functools.cache
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
rows, K = x_w1.shape
scale_K = K // BLK
n_elems = rows * K
VEC = 8
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
nwg = n_elems // (THREADS_PER_WG * VEC)
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
lane = UOp.range(VEC, 2, AxisType.UNROLL)
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
e8v = e8[idx // BLK].cast(dtypes.float)
qscale = (127.0 - e8v).exp2()
ga = grad_aq[idx].cast(dtypes.float) * qscale
w1 = x_w1[idx].cast(dtypes.float)
w3 = x_w3[idx].cast(dtypes.float)
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
s = w1 * sig
sprime = sig * (1.0 + w1 * (1.0 - sig))
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
device = x_w1.device
rows, K = x_w1.shape
axis = x_w1.axis if isinstance(device, tuple) else None
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
fxn=_custom_silu_mul_bwd_mxfp8)
return (None, None, None, gx1.uop, gx3.uop)
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
rows, K = x_w1.shape
scale_K = K // BLK
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
return fp8_out, e8_out, si_out

View file

@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
step = THREADS_PER_WG // 2 step = THREADS_PER_WG // 2
while step: while step:
active = tid < step active = tid < step
other = lds[(tid + step).valid(active)].load() other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier()) lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
step //= 2 step //= 2
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0]) amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])

View file

@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
renderer=Device[Device.DEFAULT].renderer).src[2].src) renderer=Device[Device.DEFAULT].renderer).src[2].src)
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD]) num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
assert num_loads <= 4, "more load uops than needed" assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 1, "expected at least one load uop" assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
@unittest.skip("this is handled at higher level now") @unittest.skip("this is handled at higher level now")
def test_upcast_cse(self): def test_upcast_cse(self):

View file

@ -1,67 +0,0 @@
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import AdamW
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)
W_init = np.random.randn(4,4).astype(np.float32)
m_init = np.random.randn(1,4).astype(np.float32)
class TinyNet:
def __init__(self):
self.x = Tensor(x_init.copy())
self.W = Tensor(W_init.copy())
self.m = Tensor(m_init.copy())
def forward(self):
out = self.x.matmul(self.W).relu()
out = out.log_softmax(1)
out = out.mul(self.m).add(self.m).sum()
return out
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__':
unittest.main()

View file

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest import unittest, math
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.optimizers import Lamb from tensorflow.keras.optimizers import Lamb
@ -7,11 +7,11 @@ from tensorflow.python.ops import math_ops
from extra.lr_scheduler import LRSchedulerGroup from extra.lr_scheduler import LRSchedulerGroup
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
np.random.seed(1337) np.random.seed(1337)
@ -173,5 +173,48 @@ class ExternalTestOptim(unittest.TestCase):
'warmup': steps_per_epoch * warmup_epochs, 'warmup': steps_per_epoch * warmup_epochs,
}, 1e-5, 1e-5, do_optim=False) }, 1e-5, 1e-5, do_optim=False)
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
# only tests the lr
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
net = TinyNet()
optim = AdamW([net.W], lr=0.0)
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
lr = []
for _ in range(warmup_steps+decay_steps):
lr.append(optim.lr.item())
tiny_lr.step()
# reimplemented in python
expected = []
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
np.testing.assert_allclose(lr, expected, rtol=1e-5)
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -16,17 +16,41 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
def get_gated_load_uop(valid:UOp, idx:UOp): def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, ( return UOp(Ops.LOAD, dtypes.float, (
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True), UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
UOp.const(dtypes.float, 0.0)
)) ))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]): def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
return UOp(Ops.LOAD, dtypes.float.vec(4), ( return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True), UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
)) ))
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr) def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp.range(nmax, n) def Range(n, nmax): return UOp.range(nmax, n)
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
idx1 = Special("idx1", 32)
idx2 = Special("idx2", 64)
ridx0 = Variable("ridx0", 0, 5)
ridx1 = Variable("ridx1", 0, 2)
ridx2 = Variable("ridx2", 0, 2)
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
f3 = (idx2*2)+ridx1+(-1)
self.assertFalse(f0.is_increasing())
self.assertTrue(f1.is_increasing())
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp.range(5, 2)
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())
class TestValidIdxSimplification(unittest.TestCase): class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid, extra=()): def check(self, load, sidx, svalid, extra=()):
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0] load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
@ -482,16 +506,6 @@ class TestImageSimplification(unittest.TestCase):
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))", self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0") "(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
def test_drop_non_monotonic_window(self):
# two-sided window valid (645 <= gidx0 < 653) on a non-monotonic index (lane split via %4 and //4):
# gidx0 outside the window pushes idx_x out of the (1, 48) image, so the gate is dropped
gidx0 = Special("gidx0", 1064)
r12 = Range(12, 3)
valid = ((gidx0 < 645).ne(True)) & (gidx0 < 653)
idx = (r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888, UOp.const(dtypes.weakint, 0))
load = get_load_image_uop((1, 48, 4), valid, idx)
self.check(load, None, "(r12*4+(gidx0+3)%4+(gidx0+3)//4*24+-3888)", "0")
class TestDropTrueGate(unittest.TestCase): class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self): def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True valid gets simplified to drop the valid # test that INDEX with a constant True valid gets simplified to drop the valid

View file

@ -12,18 +12,17 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
# import all pattern matchers here # import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps, get_simplifying_rewrite_patterns from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_indexing, ReduceContext, pm_render, pm_make_images
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
from tinygrad.codegen.opt.postrange import apply_opts from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.late.gater import pm_move_gates_from_index from tinygrad.codegen.late.gater import pm_move_gates_from_index
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
from tinygrad.codegen.late.coalese import memory_coalesing
from tinygrad.codegen.codegen2 import expander2, pm_move_regs, devectorizer2, unbroadcast, pm_reduce_local, pm_horizontal_reduce, memory_coalesing
pm_index_is_shrink = PatternMatcher([ pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK # rewrite non-image INDEX to SHRINK
@ -53,10 +52,6 @@ pm_number_params = PatternMatcher([
(UPat(Ops.PARAM, name="x"), do_number_param), (UPat(Ops.PARAM, name="x"), do_number_param),
]) ])
pm_no_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int))
])
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast)) if DEBUG >= 5: print(pyrender(ast))
@ -86,14 +81,16 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic") sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
# expand # expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander") #sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
# add locals # add locals
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers") sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
# ** devectorizer (full_graph_rewrite) ** # ** devectorizer (full_graph_rewrite) **
# remove reduce # remove reduce
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce") #sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
sink = graph_rewrite(sink, pm_reduce_local+pm_horizontal_reduce, ctx=ReduceContext(), name="remove_reduce")
# add gpu dims (late). this works after devectorize, but it's faster here # add gpu dims (late). this works after devectorize, but it's faster here
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims") sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
@ -101,15 +98,21 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# **** optimizations are done, now we lower to actual code **** # **** optimizations are done, now we lower to actual code ****
# add loads and remove invalids # add loads and remove invalids
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)") #sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
sink = graph_rewrite(sink, pm_move_regs, name="** add loads")
# create image buffers # create image buffers
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}: if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch) sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
# hreduce
#sink = graph_rewrite(sink, pm_mops+pm_horizontal_reduce, name="hreduce")
# devectorize # devectorize
sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing, #sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
ctx=ren, name="devectorize") # ctx=ren, name="devectorize")
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, ctx=ren, name="devectorize2")
# lower the index dtype to a concrete int # lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes") sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
@ -118,23 +121,27 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# optional pre matcher # optional pre matcher
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher") if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
# floordiv+mod / dtype decomp (early) # dtypes
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_simplifying_rewrite_patterns(supported_ops)
sink = graph_rewrite(sink, pm_decomp, name="early decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes") sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
# do memory coalesing (late) # memory coalesing
sink = memory_coalesing(sink, ren) sink = memory_coalesing(sink)
# instruction selection decompositions # again
pm_decomp = pm_decomp+\ sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\ sink = graph_rewrite(sink, symbolic, name="post index symbolic")
get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="late decompositions")
# this is new style (TODO: this should all be removed) # decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
# GEP/STACK stuff
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack") sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
# this is new style
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink") sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style") sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
@ -143,7 +150,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# final rules for the renderer (without sym) # final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([]) extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends+pm_no_weakints pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite") sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
# this was the linearizer # this was the linearizer

View file

@ -0,0 +1,161 @@
from typing import Any
import itertools, functools
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, resolve, graph_rewrite
from tinygrad.dtype import dtypes, AddrSpace, ImageDType, Invalid
from tinygrad.helpers import all_same, flatten, getenv
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
from tinygrad.codegen.late.devectorizer import ReduceContext
from tinygrad.uop.symbolic import pm_clean_up_group_sink
from collections import defaultdict
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
pm_move_regs = PatternMatcher([
# BITCAST?
(UPat(GroupOp.Elementwise|{Ops.REDUCE}, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
])
pm_lower_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
])
def build_range_map(ctx, sink:UOp):
for x in sink.toposort():
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
ctx[x.arg[0]] = len(ctx)
def fix_reduce(ctx, r:UOp):
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
expander2 = PatternMatcher([
(UPat(Ops.SINK, name="sink"), build_range_map),
(UPat(Ops.REDUCE, name="r"), fix_reduce),
(UPat(Ops.RANGE, name="r"),
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
])+pm_flatten_range
def broadcast_binary(x:UOp):
shapes = [u.shape for u in x.src]
if all_same(shapes): return None
shaped_aligned = _align_left(*shapes)
broadcasted = _broadcast_shape(*shapes)
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
return x.replace(src=tuple(src_reshaped))
unbroadcast = PatternMatcher([
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
])
def do_devectorize(b:UOp):
if b.shape == (): return None
# broadcasting needs to be already unpacked
if not all_same([x.shape for x in b.src]): return None
src = []
for idx in itertools.product(*[range(x) for x in b.shape]):
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
return UOp.vectorize(*src).reshape(b.shape) if b.op is not Ops.STORE else UOp.group(*src)
devectorizer2 = pm_mops+PatternMatcher([
# unpack broadcasting
(UPat(GroupOp.Elementwise|{Ops.LOAD,Ops.STORE}, name="b"), do_devectorize),
# const INDEX into STACK is src
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
# stacked INDEX is many INDEX
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
# INDEX into RESHAPE moves the RESHAPE
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
# RESHAPE a void is removed (hack for AFTER)
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
# reshape of a single element shaped value to scalar is an index
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
# INDEX without src is nothing
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
# RESHAPE+EXPAND -> STACK
(UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(UPat.var("x"), UPat())), UPat()), name="out"),
lambda x,out: UOp.vectorize(*([x]*out.max_numel())) if out.shape == (out.max_numel(),) else None),
])
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
ctx.acc_num += 1
topo = r.src[0].toposort()
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
acc_initted = acc.after(acc_init, *r.src[1:])
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
return acc.after(acc_out)
def expand_horizontal_reduce(r:UOp):
axes = r.arg[1]
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
pm_reduce_local = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
])+pm_clean_up_group_sink
pm_horizontal_reduce = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
])
# *** memory coalesing ***
def memory_coalesing(sink:UOp):
if getenv("DMC"): return sink
# collect
memory: defaultdict[tuple[UOp, UOp, UOp], dict[int, list[UOp]]] = defaultdict(dict)
for u in sink.toposort():
if u.op in {Ops.LOAD, Ops.STORE} and u.src[0].addrspace != AddrSpace.REG:
assert u.src[0].op is Ops.INDEX
buf,idx_u = u.src[0].src
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# allowed lengths
lengths = [8,4,2,1]
# build replacements
replacements = {}
for (op,buf,base,valid),offsets in memory.items():
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.weakint, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and offset.divides(l) is not None][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op is Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
full_grp = full_grp[length:]
# apply
return sink.substitute(replacements, name="memory coalesing")

View file

@ -1,73 +0,0 @@
from typing import Any
import itertools
from collections import defaultdict
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import getenv
from tinygrad.renderer import Renderer
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
if getenv("DMC"): return sink
# collect
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
for u in sink.toposort():
# TODO: this should handle images too, it's just memory coalesing
if u.op in {Ops.LOAD, Ops.STORE} and not isinstance(u.src[0].src[0].dtype, ImageDType):
assert len(u.src) == (2 if u.op is Ops.STORE else 1), "memory coalesing does not support gated loads/stores"
assert u.src[0].op is Ops.INDEX
buf, idx_u = u.src[0].src
if buf.addrspace == AddrSpace.REG: continue
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
# build replacements
replacements = {}
for (op,buf,base,valid),offsets in memory.items():
# allowed lengths (copied in)
lengths = []
must_divide = True
if ctx is not None and ctx.target.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded
# do the grouping
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for full_grp in grouped_offsets:
while len(full_grp):
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
grp = full_grp[:length]
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
if op == Ops.STORE:
datas = []
for i,g in enumerate(grp):
assert len(offsets[g]) == 1, f"attempting multiple stores: {len(offsets[g])}"
datas.append(offsets[g][0].src[1])
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
store = idx.store(data, valid) if valid is not None else idx.store(data)
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
else:
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
for i,g in enumerate(grp):
for oo in offsets[g]:
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
full_grp = full_grp[length:]
# apply
return sink.substitute(replacements, name="memory coalesing")

View file

@ -14,7 +14,7 @@ from tinygrad.renderer import Renderer
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]: def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
# can drop valid if idx is out of bound when valid is False # can drop valid if idx is out of bound when valid is False
drop_stmt = [] drop_stmt = []
for i,stmt in enumerate(valid.split_uop(Ops.AND)): for stmt in valid.split_uop(Ops.AND):
if (res:=parse_valid(stmt)) is None: continue if (res:=parse_valid(stmt)) is None: continue
X, is_upper_bound, c = res X, is_upper_bound, c = res
@ -25,12 +25,12 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
drop_stmt.append(stmt) drop_stmt.append(stmt)
continue continue
# check if idx is out of bound when X is on the wrong side of the bound: X in [c+1, vmax] or [vmin, c-1] # if X <= c, check if it's out of bound when X = c+1
lo, hi = (c + 1, X.vmax) if is_upper_bound else (X.vmin, c - 1) # if X >= c, check if it's out of bound when X = c-1
if lo <= hi: test_value = c + 1 if is_upper_bound else c - 1
fake = UOp.variable(f"fake{i}", lo, hi, X.dtype) for i,b in zip(idx.src, (width, height)):
for coord,b in zip(idx.src, (width, height)): if i.is_increasing():
rw = coord.substitute({X:fake}).simplify() rw = i.substitute({X:X.const_like(test_value)})
if rw.vmin >= b or rw.vmax < 0: if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt) drop_stmt.append(stmt)
break break
@ -162,8 +162,18 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# determine fold lengths # determine fold lengths
lengths = [] lengths = []
must_divide = True must_divide = True
# TODO: this belongs in coalese if ctx is not None and ctx.target.device == "DSP":
if isinstance(buf.dtype, ImageDType): lengths = [4] lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
pass
elif buf.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded lengths.append(1) # worst case, it's not folded
# filter fold lengths that don't divide # filter fold lengths that don't divide

View file

@ -101,12 +101,6 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# for Schedule, we check if the range is used in INDEX gates or WHERE gates # for Schedule, we check if the range is used in INDEX gates or WHERE gates
is_masked = k.rngs[axis] in where_gate_rngs is_masked = k.rngs[axis] in where_gate_rngs
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7: if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
# upcasting a masked global axis moves that range out of the launch grid into each work-item
# under IMAGE, skip the upcast unless enough global work-items remain after it to hide memory latency
if IMAGE and k.axis_types[axis] is AxisType.GLOBAL:
global_upcast = prod(k.full_shape[i] for i in to_upcast if k.axis_types[i] is AxisType.GLOBAL) * k.full_shape[axis]
global_items_after = prod(k.full_shape[i] for i in k.axes_of(AxisType.GLOBAL)) // global_upcast
if resolve(global_items_after < getenv("OCCUPANCY_FLOOR", 4096), False): continue
if DEBUG >= 4: print(f"upcasting masked axis : {axis}") if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis) to_upcast.append(axis)
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0)) for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))

View file

@ -275,7 +275,7 @@ SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward # allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0) USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# don't allow broadcast # don't allow broadcast
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1) DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
@dataclass(frozen=True) @dataclass(frozen=True)
class Metadata: class Metadata:

View file

@ -22,7 +22,6 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \ (UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None), if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: ctx[x.src[0]] if x.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else None),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
# GPU stuff # GPU stuff

View file

@ -18,7 +18,7 @@ import sys
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
def add_ranges_to_store(ctx, x): def add_ranges_to_store(ctx, x):
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == () or x.src[0].max_numel() == x.src[1].max_numel() == 1: return None if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == (): return None
assert x.src[0].shape == x.src[1].shape, "bad store shape" assert x.src[0].shape == x.src[1].shape, "bad store shape"
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape] idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs) return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)

View file

@ -454,8 +454,7 @@ def floormod_to_mod(a:UOp, b:UOp) -> UOp:
powers_of_two: dict[int, int] = {2**i:i for i in range(64)} powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache @functools.cache
def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher: def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
# these are rewrites that make things simpler
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)] pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod # FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)) if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
@ -464,11 +463,6 @@ def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32)) if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends) # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]))) if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
return PatternMatcher(pat)
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(), if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())] lambda x,y: (x | y).logical_not())]
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y) # rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)

View file

@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp: def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.STACK) if len(arg) == 0: return UOp(Ops.STACK)
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg) elif len(arg) == 1: return UOp.const(dtypes.weakint, arg[0])
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg)) else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
def consumer_map_from_toposort(lst:Iterable[UOp]): def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {} ret: dict[UOp, dict[UOp, None]] = {}
@ -305,10 +305,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \ case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END: Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
return self.src[0]._shape return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0:
# these can mismatch if there's a horizonal reduce
return (self.dtype.count,) if self.dtype.count > 1 else ()
# TODO: disallow shape changing bitcast # TODO: disallow shape changing bitcast
case Ops.BITCAST: case Ops.BITCAST:
@ -473,7 +469,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0] if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None])) return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs): def vectorize(self, *srcs):
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs) return UOp(Ops.STACK, self.dtype, (self,)+srcs)
def index(self, *srcs:UOp|None, ptr=False, **kwargs): def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): def __getitem__(self, idx):
@ -919,6 +915,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop symbolic stuff *** # *** uop symbolic stuff ***
def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function regards its input
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
return False # False if not sure
def const_factor(self) -> int: def const_factor(self) -> int:
"""largest known int that divides self""" """largest known int that divides self"""
# TODO: for negatives it's not the largest # TODO: for negatives it's not the largest

View file

@ -192,8 +192,8 @@ spec_program = PatternMatcher([
# no more of these in programs # no more of these in programs
(UPat(Ops.GEP), lambda: False), (UPat(Ops.GEP), lambda: False),
# weakint is not allowed in programs # weakint is not allowed in programs, except on CONST and STACK
(UPat(GroupOp.All, dtypes.weakint), lambda: False), (UPat(GroupOp.All-{Ops.CONST,Ops.STACK}, dtypes.weakint), lambda: False),
# allow special SHRINK # allow special SHRINK
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True), (UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),

View file

@ -121,8 +121,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
# TODO: combine this with "# rules for threefry" below # TODO: combine this with "# rules for threefry" below
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"), ((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None), lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
((UPat.var("x") & UPat.cvar("mask")) // UPat.cvar("c"),
lambda x,mask,c: x // c.arg if c.arg > 0 and c.arg & (c.arg-1) == 0 and mask.arg | (c.arg-1) == -1 else None),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"), (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints) lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding ** # ** constant folding **
@ -162,7 +160,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x), (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x), (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
# ** simple where folding ** # ** simple where folding **
# a conditional with the same results either way is a noop, also fold const conditionals # a conditional with the same results either way is a noop, also fold const conditionals
@ -170,9 +167,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), (UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# a.where(b.where(c, d), d) -> (a & b).where(c, d) # a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)), (UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# STACK on INDEX CONST (TODO: remove all the GEP crap)
(UPat(Ops.STACK, src=UPat(Ops.INDEX, src=(UPat.var("src"), UPat(Ops.CONST))), name="stk"),
lambda src,stk: src if stk.shape == src.shape and list(range(len(stk.src))) == [x.src[1].arg for x in stk.src] else None),
]) ])
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ******** # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********