mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
20 commits
master
...
codegen_tr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30cc426916 | ||
|
|
c7fd74d523 | ||
|
|
4e398e3f1f | ||
|
|
5238f304c7 | ||
|
|
87e12fad81 | ||
|
|
5eaf02c719 | ||
|
|
4424af7bd8 | ||
|
|
c01d75a651 | ||
|
|
63ec8ad21d |
||
|
|
cccd9c2c03 | ||
|
|
303b6ba14c |
||
|
|
2bf3c48c1b | ||
|
|
ba75b68c12 | ||
|
|
958088cc13 | ||
|
|
257ed03f57 | ||
|
|
15476572cd |
||
|
|
f8949a0de1 | ||
|
|
62c6c75657 | ||
|
|
530aed739d | ||
|
|
da402953d9 |
21 changed files with 327 additions and 360 deletions
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
|
|
@ -327,7 +327,7 @@ jobs:
|
||||||
llvm: 'true'
|
llvm: 'true'
|
||||||
- name: Test openpilot model kernel count and gate usage
|
- name: Test openpilot model kernel count and gate usage
|
||||||
run: |
|
run: |
|
||||||
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1361 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1468 ALLOWED_GATED_READ_IMAGE=10 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
|
||||||
- name: Test openpilot CL compile fp32 (test correctness)
|
- name: Test openpilot CL compile fp32 (test correctness)
|
||||||
run: |
|
run: |
|
||||||
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
DEV=CL IMAGE=1 SELFTEST=1 python examples/openpilot/compile3.py https://github.com/haraschax/filedump/raw/refs/heads/master/driving_vision_fp32.onnx
|
||||||
|
|
@ -370,6 +370,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
key: optim
|
key: optim
|
||||||
deps: testing
|
deps: testing
|
||||||
|
pydeps: "tensorflow==2.19"
|
||||||
opencl: 'true'
|
opencl: 'true'
|
||||||
#- name: Test Optimization Helpers
|
#- name: Test Optimization Helpers
|
||||||
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
|
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
|
||||||
|
|
@ -378,7 +379,7 @@ jobs:
|
||||||
- name: Test Beam Search
|
- name: Test Beam Search
|
||||||
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
|
run: DEV=CL IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
|
||||||
- name: Test MLPerf stuff
|
- name: Test MLPerf stuff
|
||||||
run: DEV=CL python -m pytest -n=auto test/external/external_test_lr_schedule.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
|
run: DEV=CL python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
|
||||||
- name: DEV=NULL beautiful_mnist_multigpu
|
- name: DEV=NULL beautiful_mnist_multigpu
|
||||||
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
|
run: DEV=NULL NULL_ALLOW_COPYOUT=1 python examples/beautiful_mnist_multigpu.py
|
||||||
- name: Test Bert training
|
- name: Test Bert training
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
|
||||||
|
|
||||||
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
|
||||||
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
|
x_fp8:Tensor|None=None, x_new_amax:Tensor|None=None,
|
||||||
grad_amax_state:Tensor|None=None, x_prequant_mx:tuple|None=None) -> tuple[Tensor,...]:
|
grad_amax_state:Tensor|None=None) -> tuple[Tensor,...]:
|
||||||
if not fp8:
|
if not fp8:
|
||||||
if ASM_GEMM:
|
if ASM_GEMM:
|
||||||
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
|
||||||
|
|
@ -47,14 +47,12 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca
|
||||||
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
|
assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
|
||||||
if MXFP8:
|
if MXFP8:
|
||||||
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
|
from extra.gemm.cdna_asm_gemm import asm_gemm, quantize_mxfp8, mx_pack, can_use_asm_gemm, _mx_block_scale
|
||||||
if x_prequant_mx is not None: x_q, x_e8, x_si = x_prequant_mx # fused producer already quantized (2d)
|
x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
|
||||||
else: x_q, x_e8, x_si = quantize_mxfp8(x.reshape(-1, x.shape[-1]))
|
|
||||||
l_shape = x.shape[:-1] if x is not None else x_q.shape[:-1]
|
|
||||||
if can_use_asm_gemm(x_q, w.T):
|
if can_use_asm_gemm(x_q, w.T):
|
||||||
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
|
out = asm_gemm(x_q, w.T, mx=True, mx_scales=(x_si, x_e8, mx_pack(w_inv_scale), w_inv_scale),
|
||||||
mx_w_stored=True).reshape(*l_shape, w.shape[0])
|
mx_w_stored=True).reshape(*x.shape[:-1], w.shape[0])
|
||||||
else:
|
else:
|
||||||
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*l_shape, x_q.shape[-1])
|
x_phys = (x_q.cast(dtypes.bfloat16) * _mx_block_scale(x_e8)).reshape(*x.shape[:-1], x.shape[-1])
|
||||||
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
|
out = x_phys @ (w.cast(dtypes.bfloat16) * _mx_block_scale(w_inv_scale)).T
|
||||||
return out, (amax_x.detach() if amax_x is not None else None), x_q
|
return out, (amax_x.detach() if amax_x is not None else None), x_q
|
||||||
if x_fp8 is None:
|
if x_fp8 is None:
|
||||||
|
|
@ -216,15 +214,8 @@ class FlatTransformer:
|
||||||
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
|
x_w3, new_amax, *s = matmul(inp, kwargs["w3"], amax_x=kwargs["amax_x3"], w_inv_scale=kwargs["s_3"], grad_amax_state=kwargs["grad_amax_xw3"])
|
||||||
amaxs.append(new_amax)
|
amaxs.append(new_amax)
|
||||||
saves.extend([*s, x_w3])
|
saves.extend([*s, x_w3])
|
||||||
if FUSED_SILU_W13 and MXFP8:
|
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
|
||||||
from extra.llama_kernels.fused_silu_mul_quantize_mxfp8 import fused_silu_mul_quantize_mxfp8
|
grad_amax_state=kwargs["grad_amax_xout"])
|
||||||
aq, ae8, asi = fused_silu_mul_quantize_mxfp8(x_w1.reshape(-1, x_w1.shape[-1]), x_w3.reshape(-1, x_w3.shape[-1]))
|
|
||||||
out, new_amax, *s = matmul(None, kwargs["w2"], x_prequant_mx=(aq, ae8, asi), amax_x=kwargs["amax_x2"],
|
|
||||||
w_inv_scale=kwargs["s_2"], grad_amax_state=kwargs["grad_amax_xout"])
|
|
||||||
out = out.reshape(*x_w1.shape[:-1], kwargs["w2"].shape[0])
|
|
||||||
else:
|
|
||||||
out, new_amax, *s = matmul(x_w1.silu() * x_w3, kwargs["w2"], amax_x=kwargs["amax_x2"], w_inv_scale=kwargs["s_2"],
|
|
||||||
grad_amax_state=kwargs["grad_amax_xout"])
|
|
||||||
amaxs.append(new_amax)
|
amaxs.append(new_amax)
|
||||||
saves.extend([*s, out])
|
saves.extend([*s, out])
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -143,17 +143,14 @@ def make_getaddr(u, device=None):
|
||||||
def make_ins(op, *srcs):
|
def make_ins(op, *srcs):
|
||||||
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
|
return UOp(Ops.INS, dtypes.void, tuple(UOp.const(dtypes.uint32, s) if isinstance(s, int) else s.cast(dtypes.uint32) for s in srcs), op)
|
||||||
|
|
||||||
def make_patch(buf:UOp, off:sint, val:UOp, dtype=None) -> UOp:
|
|
||||||
dt = dtype or val.dtype
|
|
||||||
return UOp(Ops.SHRINK, buf.dtype.base, (buf, UOp.const(dtypes.int, off), UOp.const(dtypes.int, dt.itemsize))).bitcast(dt).store(val.cast(dt))
|
|
||||||
|
|
||||||
def make_cmdbuf(lin, devs, tag):
|
def make_cmdbuf(lin, devs, tag):
|
||||||
blob, patches = b'', []
|
blob, patches = b'', []
|
||||||
for s in (s for ins in lin.src for s in ins.src):
|
for s in (s for ins in lin.src for s in ins.src):
|
||||||
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
if s.op is not Ops.CONST: patches.append((len(blob), s))
|
||||||
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
blob += struct.pack(f'<{s.dtype.fmt}', s.arg if s.op is Ops.CONST else 0x0)
|
||||||
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
|
buf = UOp.new_buffer(devs, len(blob), dtypes.uint8).rtag(tag)
|
||||||
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *[make_patch(buf, off, s) for off, s in patches])
|
stores = [buf.index(UOp.const(dtypes.int, off), dtype=buf.dtype.ptr()).cast(s.dtype.ptr()).store(s) for off, s in patches]
|
||||||
|
return buf.after(buf.store(UOp(Ops.BINARY, dtypes.void, src=(), arg=blob)), *stores)
|
||||||
|
|
||||||
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
def make_mstack(uops): return uops[0] if len(uops) == 1 else UOp(Ops.MSTACK, uops[0].dtype, tuple(uops))
|
||||||
|
|
||||||
|
|
@ -214,11 +211,15 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
|
||||||
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
|
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
|
||||||
|
|
||||||
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
|
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
|
||||||
(data, info), dev_uop = prg.arg, UOp(Ops.DEVICE, arg=call.src[1].device)
|
data, info = prg.arg
|
||||||
buf = UOp.new_buffer(dev_uop.arg, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
patches = [(i*dtypes.uint64.itemsize, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], UOp(Ops.DEVICE, arg=call.src[1+gi].device))),
|
||||||
patches = [make_patch(buf, i*8, UOp(Ops.GETADDR, dtypes.uint64, src=(call.src[1+gi], dev_uop))) for i,gi in enumerate(info.globals)] \
|
dtypes.uint64) for i,gi in enumerate(info.globals)] \
|
||||||
+ [make_patch(buf, len(info.globals)*8 + i*4, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
+ [(len(info.globals)*dtypes.uint64.itemsize + i*dtypes.uint32.itemsize, v, dtypes.uint32) for i,v in enumerate(info.vars)]
|
||||||
return call.replace(src=(prg.replace(src=prg.src + (buf.after(*patches),), arg=(data, info)),) + call.src[1:])
|
|
||||||
|
buf = UOp.new_buffer(call.src[1].device, data.kernargs_alloc_size, dtypes.uint8).rtag("kernargs")
|
||||||
|
kernargs = buf.after(*tuple(buf.index(UOp.const(dtypes.int, o), dtype=buf.dtype.ptr()).cast(dt.ptr()).store(val.cast(dt)) for o, val, dt in patches))
|
||||||
|
|
||||||
|
return call.replace(src=(prg.replace(src=prg.src + (kernargs,), arg=(data, info)),) + call.src[1:])
|
||||||
|
|
||||||
pm_prep_runtime = PatternMatcher([
|
pm_prep_runtime = PatternMatcher([
|
||||||
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
|
# bind generic PROGRAM device to the call's actual dev(s), then run device-specific lowering
|
||||||
|
|
@ -531,9 +532,9 @@ pm_resolve_patches = PatternMatcher([
|
||||||
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
(UPat(GroupOp.ALU, src=[UPat(Ops.STACK, name="s"), UPat(Ops.CONST)], name="op"), push_stack),
|
||||||
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
(UPat(Ops.CAST, src=(UPat(Ops.STACK, name="s"),), name="op"), push_stack),
|
||||||
|
|
||||||
# shrink on slice is shrink on base at offset
|
# index on slice is index
|
||||||
(UPat(Ops.SHRINK, src=(UPat(Ops.SLICE, name="bv"), UPat(), UPat()), name="shr"),
|
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat()), name="idx", allow_any_len=True),
|
||||||
lambda shr, bv: shr.replace(src=(bv.src[0], shr.src[1] + bv.src[1].cast(shr.src[1].dtype), shr.src[2]))),
|
lambda idx, bv: idx.replace(src=(bv.src[0], idx.src[1] + bv.src[1].cast(idx.src[1].dtype), *idx.src[2:]))),
|
||||||
|
|
||||||
# getaddr
|
# getaddr
|
||||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), resolve_getaddr_slice), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||||
|
|
@ -541,8 +542,8 @@ pm_resolve_patches = PatternMatcher([
|
||||||
|
|
||||||
# folders
|
# folders
|
||||||
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||||
(UPat(Ops.SHRINK, src=(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf"), UPat.cvar("off"), UPat(Ops.CONST))).bitcast()
|
(UPat({Ops.BUFFER, Ops.SLICE, Ops.MSTACK}, name="buf").index(UPat.cvar("off")).or_casted().store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))),
|
||||||
.store(UPat.any(UPat.cvar("val"), UPat(Ops.STACK, name="val"))), fold_const_store),
|
fold_const_store),
|
||||||
]) + symbolic_simple
|
]) + symbolic_simple
|
||||||
|
|
||||||
# *****************
|
# *****************
|
||||||
|
|
|
||||||
|
|
@ -1,104 +0,0 @@
|
||||||
import functools
|
|
||||||
from tinygrad import Tensor, dtypes
|
|
||||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
|
||||||
from extra.llama_kernels import FP8_MAX, THREADS_PER_WG, alloc_like
|
|
||||||
|
|
||||||
BLK = 32
|
|
||||||
PACK = 4
|
|
||||||
LOG2E = 1.4426950408889634
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def _custom_silu_mul_quantize_mxfp8(fp8_out:UOp, e8_out:UOp, si_out:UOp, x_w1:UOp, x_w3:UOp) -> UOp:
|
|
||||||
rows, K = x_w1.shape
|
|
||||||
scale_K = K // BLK
|
|
||||||
n_elems = rows * K
|
|
||||||
n_super = n_elems // (BLK * PACK)
|
|
||||||
sk4 = scale_K // PACK
|
|
||||||
assert n_super % THREADS_PER_WG == 0, f"{n_super=} must divide over {THREADS_PER_WG=}"
|
|
||||||
nwg = n_super // THREADS_PER_WG
|
|
||||||
|
|
||||||
x_w1, x_w3 = x_w1.reshape(n_elems), x_w3.reshape(n_elems)
|
|
||||||
fp8_out = fp8_out.reshape(n_elems)
|
|
||||||
e8_out = e8_out.reshape(rows * scale_K)
|
|
||||||
si_out = si_out.reshape(sk4 * rows)
|
|
||||||
|
|
||||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
|
||||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
|
||||||
sb = UOp.range(PACK, 2, AxisType.UNROLL)
|
|
||||||
lane = UOp.range(BLK, 3, AxisType.UNROLL)
|
|
||||||
|
|
||||||
super_idx = wg * THREADS_PER_WG + tid
|
|
||||||
idx = super_idx * (BLK * PACK) + sb * BLK + lane
|
|
||||||
|
|
||||||
w1 = x_w1[idx].cast(dtypes.float)
|
|
||||||
w3 = x_w3[idx].cast(dtypes.float)
|
|
||||||
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
|
|
||||||
act = w1 * sig * w3
|
|
||||||
abs_a = (act < 0.0).where(-act, act)
|
|
||||||
blk_max = abs_a.reduce(lane, arg=Ops.MAX)
|
|
||||||
e8f = (blk_max.maximum(1e-38).log2().floor() + 127.0).maximum(0.0).minimum(254.0)
|
|
||||||
qscale = (127.0 - e8f).exp2()
|
|
||||||
scaled = (act * qscale).maximum(-FP8_MAX).minimum(FP8_MAX)
|
|
||||||
e8u8 = e8f.cast(dtypes.uint8)
|
|
||||||
|
|
||||||
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
|
|
||||||
e8_store = e8_out.after(fp8_store)[super_idx * PACK + sb].store(e8u8)
|
|
||||||
packed = (e8u8.cast(dtypes.uint32) << (sb.cast(dtypes.uint32) * 8)).reduce(sb, arg=Ops.ADD)
|
|
||||||
row, col4 = super_idx // sk4, super_idx % sk4
|
|
||||||
si_store = si_out.after(e8_store.end(sb))[col4 * rows + row].store(packed)
|
|
||||||
return si_store.end(tid, wg).sink(arg=KernelInfo(f"silu_mul_quantize_mxfp8_{n_elems}", opts_to_apply=()))
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def _custom_silu_mul_bwd_mxfp8(gx1_out:UOp, gx3_out:UOp, x_w1:UOp, x_w3:UOp, grad_aq:UOp, e8:UOp) -> UOp:
|
|
||||||
rows, K = x_w1.shape
|
|
||||||
scale_K = K // BLK
|
|
||||||
n_elems = rows * K
|
|
||||||
VEC = 8
|
|
||||||
assert n_elems % (THREADS_PER_WG * VEC) == 0, f"{n_elems=} must divide {THREADS_PER_WG*VEC=}"
|
|
||||||
nwg = n_elems // (THREADS_PER_WG * VEC)
|
|
||||||
x_w1, x_w3, grad_aq = x_w1.reshape(n_elems), x_w3.reshape(n_elems), grad_aq.reshape(n_elems)
|
|
||||||
gx1_out, gx3_out, e8 = gx1_out.reshape(n_elems), gx3_out.reshape(n_elems), e8.reshape(rows * scale_K)
|
|
||||||
|
|
||||||
wg = UOp.range(nwg, 0, AxisType.GLOBAL)
|
|
||||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
|
||||||
lane = UOp.range(VEC, 2, AxisType.UNROLL)
|
|
||||||
idx = (wg * THREADS_PER_WG + tid) * VEC + lane
|
|
||||||
|
|
||||||
e8v = e8[idx // BLK].cast(dtypes.float)
|
|
||||||
qscale = (127.0 - e8v).exp2()
|
|
||||||
ga = grad_aq[idx].cast(dtypes.float) * qscale
|
|
||||||
w1 = x_w1[idx].cast(dtypes.float)
|
|
||||||
w3 = x_w3[idx].cast(dtypes.float)
|
|
||||||
sig = (1.0 + (w1 * -LOG2E).exp2()).reciprocal()
|
|
||||||
s = w1 * sig
|
|
||||||
sprime = sig * (1.0 + w1 * (1.0 - sig))
|
|
||||||
gx1 = gx1_out[idx].store((ga * sprime * w3).cast(gx1_out.dtype.base))
|
|
||||||
gx3 = gx3_out.after(gx1)[idx].store((ga * s).cast(gx3_out.dtype.base))
|
|
||||||
return gx3.end(lane, tid, wg).sink(arg=KernelInfo(f"silu_mul_bwd_mxfp8_{n_elems}", opts_to_apply=()))
|
|
||||||
|
|
||||||
def _silu_mul_quantize_mxfp8_bwd(gradient:UOp, kernel:UOp):
|
|
||||||
_, e8_out, _, x_w1, x_w3 = kernel.src[1:]
|
|
||||||
device = x_w1.device
|
|
||||||
rows, K = x_w1.shape
|
|
||||||
axis = x_w1.axis if isinstance(device, tuple) else None
|
|
||||||
gx1 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
|
|
||||||
gx3 = alloc_like((rows, K), dtypes.bfloat16, device, axis)
|
|
||||||
gx1, gx3, *_ = Tensor.custom_kernel(gx1, gx3, Tensor(x_w1, device=device), Tensor(x_w3, device=device),
|
|
||||||
Tensor(gradient, device=device).cast(dtypes.bfloat16), Tensor(e8_out.after(kernel), device=device),
|
|
||||||
fxn=_custom_silu_mul_bwd_mxfp8)
|
|
||||||
return (None, None, None, gx1.uop, gx3.uop)
|
|
||||||
|
|
||||||
def fused_silu_mul_quantize_mxfp8(x_w1:Tensor, x_w3:Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
|
||||||
assert x_w1.shape == x_w3.shape, f"{x_w1.shape} != {x_w3.shape}"
|
|
||||||
assert x_w1.dtype == dtypes.bfloat16 and x_w3.dtype == dtypes.bfloat16
|
|
||||||
assert x_w1.ndim == 2, f"expected 2d, got {x_w1.shape}"
|
|
||||||
from extra.gemm.cdna_asm_gemm import FP8_DTYPE
|
|
||||||
rows, K = x_w1.shape
|
|
||||||
scale_K = K // BLK
|
|
||||||
axis = x_w1.uop.axis if isinstance(x_w1.device, tuple) else None
|
|
||||||
fp8_out = alloc_like((rows, K), FP8_DTYPE, x_w1.device, axis)
|
|
||||||
e8_out = alloc_like((rows, scale_K), dtypes.uint8, x_w1.device, axis)
|
|
||||||
si_out = alloc_like((scale_K // PACK, rows), dtypes.uint32, x_w1.device, None if axis is None else (1 if axis == 0 else 0))
|
|
||||||
fp8_out, e8_out, si_out, *_ = Tensor.custom_kernel(fp8_out, e8_out, si_out, x_w1, x_w3,
|
|
||||||
fxn=_custom_silu_mul_quantize_mxfp8, grad_fxn=_silu_mul_quantize_mxfp8_bwd)
|
|
||||||
return fp8_out, e8_out, si_out
|
|
||||||
|
|
@ -42,8 +42,8 @@ def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_st
|
||||||
step = THREADS_PER_WG // 2
|
step = THREADS_PER_WG // 2
|
||||||
while step:
|
while step:
|
||||||
active = tid < step
|
active = tid < step
|
||||||
other = lds[(tid + step).valid(active)].load()
|
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
|
||||||
lds = lds.after(lds[tid.valid(active)].store(lds[tid].maximum(other)).barrier())
|
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
|
||||||
step //= 2
|
step //= 2
|
||||||
|
|
||||||
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
|
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
|
||||||
|
|
|
||||||
|
|
@ -140,7 +140,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||||
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
|
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
|
||||||
assert num_loads <= 4, "more load uops than needed"
|
assert num_loads <= 4, "more load uops than needed"
|
||||||
assert num_loads >= 1, "expected at least one load uop"
|
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
|
||||||
|
|
||||||
@unittest.skip("this is handled at higher level now")
|
@unittest.skip("this is handled at higher level now")
|
||||||
def test_upcast_cse(self):
|
def test_upcast_cse(self):
|
||||||
|
|
|
||||||
67
test/external/external_test_lr_schedule.py
vendored
67
test/external/external_test_lr_schedule.py
vendored
|
|
@ -1,67 +0,0 @@
|
||||||
import unittest, math
|
|
||||||
import numpy as np
|
|
||||||
from tinygrad.tensor import Tensor
|
|
||||||
from tinygrad.nn.optim import AdamW
|
|
||||||
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
|
|
||||||
|
|
||||||
np.random.seed(1337)
|
|
||||||
x_init = np.random.randn(1,4).astype(np.float32)
|
|
||||||
W_init = np.random.randn(4,4).astype(np.float32)
|
|
||||||
m_init = np.random.randn(1,4).astype(np.float32)
|
|
||||||
|
|
||||||
class TinyNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.x = Tensor(x_init.copy())
|
|
||||||
self.W = Tensor(W_init.copy())
|
|
||||||
self.m = Tensor(m_init.copy())
|
|
||||||
|
|
||||||
def forward(self):
|
|
||||||
out = self.x.matmul(self.W).relu()
|
|
||||||
out = out.log_softmax(1)
|
|
||||||
out = out.mul(self.m).add(self.m).sum()
|
|
||||||
return out
|
|
||||||
|
|
||||||
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
|
|
||||||
# only tests the lr
|
|
||||||
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
|
|
||||||
net = TinyNet()
|
|
||||||
optim = AdamW([net.W], lr=0.0)
|
|
||||||
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
|
|
||||||
lr = []
|
|
||||||
for _ in range(warmup_steps+decay_steps):
|
|
||||||
lr.append(optim.lr.item())
|
|
||||||
tiny_lr.step()
|
|
||||||
# reimplemented in python
|
|
||||||
expected = []
|
|
||||||
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
|
|
||||||
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
|
|
||||||
np.testing.assert_allclose(lr, expected, rtol=1e-5)
|
|
||||||
|
|
||||||
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
|
|
||||||
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
|
|
||||||
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
|
|
||||||
|
|
||||||
class TestLambdaLRLinearWarmup(unittest.TestCase):
|
|
||||||
def test_linear_lr_warmup(self):
|
|
||||||
BS, BASE_LR = 304, 2.5e-7
|
|
||||||
lr = BS * BASE_LR
|
|
||||||
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
|
|
||||||
optimizer = AdamW([Tensor([1.])])
|
|
||||||
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
|
|
||||||
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
|
|
||||||
lrs = {}
|
|
||||||
|
|
||||||
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
|
|
||||||
for i in range(1200):
|
|
||||||
lr_scheduler.step()
|
|
||||||
if i in {0, 499, 998, 999, 1000, 1199}:
|
|
||||||
lrs[i] = optimizer.lr.item()
|
|
||||||
|
|
||||||
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
|
|
||||||
np.testing.assert_equal(lrs[999], lrs[1000])
|
|
||||||
np.testing.assert_equal(lrs[999], lrs[1199])
|
|
||||||
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
|
|
||||||
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
49
test/external/external_test_optim.py
vendored
49
test/external/external_test_optim.py
vendored
|
|
@ -1,5 +1,5 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import unittest
|
import unittest, math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.optimizers import Lamb
|
from tensorflow.keras.optimizers import Lamb
|
||||||
|
|
@ -7,11 +7,11 @@ from tensorflow.python.ops import math_ops
|
||||||
from extra.lr_scheduler import LRSchedulerGroup
|
from extra.lr_scheduler import LRSchedulerGroup
|
||||||
|
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
|
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
|
||||||
|
|
||||||
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
|
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
|
||||||
|
|
||||||
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
|
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
|
||||||
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
|
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
|
||||||
|
|
||||||
np.random.seed(1337)
|
np.random.seed(1337)
|
||||||
|
|
@ -173,5 +173,48 @@ class ExternalTestOptim(unittest.TestCase):
|
||||||
'warmup': steps_per_epoch * warmup_epochs,
|
'warmup': steps_per_epoch * warmup_epochs,
|
||||||
}, 1e-5, 1e-5, do_optim=False)
|
}, 1e-5, 1e-5, do_optim=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
|
||||||
|
# only tests the lr
|
||||||
|
def _test_lr(self, base_lr, end_lr, warmup_steps, decay_steps):
|
||||||
|
net = TinyNet()
|
||||||
|
optim = AdamW([net.W], lr=0.0)
|
||||||
|
tiny_lr = CosineAnnealingLRWithWarmup(optim, base_lr, end_lr, warmup_steps, decay_steps)
|
||||||
|
lr = []
|
||||||
|
for _ in range(warmup_steps+decay_steps):
|
||||||
|
lr.append(optim.lr.item())
|
||||||
|
tiny_lr.step()
|
||||||
|
# reimplemented in python
|
||||||
|
expected = []
|
||||||
|
for i in range(warmup_steps): expected.append((i+1)/warmup_steps*base_lr)
|
||||||
|
for i in range(decay_steps): expected.append(end_lr+(base_lr-end_lr)*(1+math.cos((i+1)/decay_steps*math.pi))/2)
|
||||||
|
np.testing.assert_allclose(lr, expected, rtol=1e-5)
|
||||||
|
|
||||||
|
def test_lr_0(self): self._test_lr(3e-4, 8e-5, 3, 5)
|
||||||
|
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
|
||||||
|
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
|
||||||
|
|
||||||
|
class TestLambdaLRLinearWarmup(unittest.TestCase):
|
||||||
|
def test_linear_lr_warmup(self):
|
||||||
|
BS, BASE_LR = 304, 2.5e-7
|
||||||
|
lr = BS * BASE_LR
|
||||||
|
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
|
||||||
|
optimizer = AdamW([Tensor([1.])])
|
||||||
|
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
|
||||||
|
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
|
||||||
|
lrs = {}
|
||||||
|
|
||||||
|
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
|
||||||
|
for i in range(1200):
|
||||||
|
lr_scheduler.step()
|
||||||
|
if i in {0, 499, 998, 999, 1000, 1199}:
|
||||||
|
lrs[i] = optimizer.lr.item()
|
||||||
|
|
||||||
|
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
|
||||||
|
np.testing.assert_equal(lrs[999], lrs[1000])
|
||||||
|
np.testing.assert_equal(lrs[999], lrs[1199])
|
||||||
|
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
|
||||||
|
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -16,17 +16,41 @@ def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move
|
||||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||||
return UOp(Ops.LOAD, dtypes.float, (
|
return UOp(Ops.LOAD, dtypes.float, (
|
||||||
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
|
UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True),
|
||||||
|
UOp.const(dtypes.float, 0.0)
|
||||||
))
|
))
|
||||||
|
|
||||||
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
|
||||||
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
||||||
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
|
UOp.param(0, dtypes.imagef(image_shape)).index(idx[1].valid(valid), idx[0].valid(valid), ptr=True),
|
||||||
|
UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
|
||||||
))
|
))
|
||||||
|
|
||||||
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
|
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.weakint, (UOp.const(dtypes.weakint, nmax),), expr)
|
||||||
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
||||||
def Range(n, nmax): return UOp.range(nmax, n)
|
def Range(n, nmax): return UOp.range(nmax, n)
|
||||||
|
|
||||||
|
class TestHelpers(unittest.TestCase):
|
||||||
|
def test_is_increasing(self):
|
||||||
|
idx1 = Special("idx1", 32)
|
||||||
|
idx2 = Special("idx2", 64)
|
||||||
|
ridx0 = Variable("ridx0", 0, 5)
|
||||||
|
ridx1 = Variable("ridx1", 0, 2)
|
||||||
|
ridx2 = Variable("ridx2", 0, 2)
|
||||||
|
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
|
||||||
|
f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
|
||||||
|
f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
|
||||||
|
f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
|
||||||
|
f3 = (idx2*2)+ridx1+(-1)
|
||||||
|
|
||||||
|
self.assertFalse(f0.is_increasing())
|
||||||
|
self.assertTrue(f1.is_increasing())
|
||||||
|
self.assertTrue(f2.is_increasing())
|
||||||
|
self.assertTrue(f3.is_increasing())
|
||||||
|
|
||||||
|
rng = UOp.range(5, 2)
|
||||||
|
self.assertTrue(rng.is_increasing())
|
||||||
|
self.assertTrue((rng+2).is_increasing())
|
||||||
|
|
||||||
class TestValidIdxSimplification(unittest.TestCase):
|
class TestValidIdxSimplification(unittest.TestCase):
|
||||||
def check(self, load, sidx, svalid, extra=()):
|
def check(self, load, sidx, svalid, extra=()):
|
||||||
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
|
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
|
||||||
|
|
@ -482,16 +506,6 @@ class TestImageSimplification(unittest.TestCase):
|
||||||
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
|
self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
|
||||||
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
|
"(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
|
||||||
|
|
||||||
def test_drop_non_monotonic_window(self):
|
|
||||||
# two-sided window valid (645 <= gidx0 < 653) on a non-monotonic index (lane split via %4 and //4):
|
|
||||||
# gidx0 outside the window pushes idx_x out of the (1, 48) image, so the gate is dropped
|
|
||||||
gidx0 = Special("gidx0", 1064)
|
|
||||||
r12 = Range(12, 3)
|
|
||||||
valid = ((gidx0 < 645).ne(True)) & (gidx0 < 653)
|
|
||||||
idx = (r12*4 + (gidx0+3)%4 + (gidx0+3)//4*24 - 3888, UOp.const(dtypes.weakint, 0))
|
|
||||||
load = get_load_image_uop((1, 48, 4), valid, idx)
|
|
||||||
self.check(load, None, "(r12*4+(gidx0+3)%4+(gidx0+3)//4*24+-3888)", "0")
|
|
||||||
|
|
||||||
class TestDropTrueGate(unittest.TestCase):
|
class TestDropTrueGate(unittest.TestCase):
|
||||||
def test_drop_true_gate_on_index(self):
|
def test_drop_true_gate_on_index(self):
|
||||||
# test that INDEX with a constant True valid gets simplified to drop the valid
|
# test that INDEX with a constant True valid gets simplified to drop the valid
|
||||||
|
|
|
||||||
|
|
@ -12,18 +12,17 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||||
|
|
||||||
# import all pattern matchers here
|
# import all pattern matchers here
|
||||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
|
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
|
||||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps, get_simplifying_rewrite_patterns
|
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
from tinygrad.codegen.late.devectorizer import load_store_indexing, ReduceContext, pm_render, pm_make_images
|
||||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
|
||||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
|
||||||
from tinygrad.codegen.opt.postrange import apply_opts
|
from tinygrad.codegen.opt.postrange import apply_opts
|
||||||
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
||||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
|
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
|
||||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||||
from tinygrad.codegen.late.coalese import memory_coalesing
|
|
||||||
|
from tinygrad.codegen.codegen2 import expander2, pm_move_regs, devectorizer2, unbroadcast, pm_reduce_local, pm_horizontal_reduce, memory_coalesing
|
||||||
|
|
||||||
pm_index_is_shrink = PatternMatcher([
|
pm_index_is_shrink = PatternMatcher([
|
||||||
# rewrite non-image INDEX to SHRINK
|
# rewrite non-image INDEX to SHRINK
|
||||||
|
|
@ -53,10 +52,6 @@ pm_number_params = PatternMatcher([
|
||||||
(UPat(Ops.PARAM, name="x"), do_number_param),
|
(UPat(Ops.PARAM, name="x"), do_number_param),
|
||||||
])
|
])
|
||||||
|
|
||||||
pm_no_weakints = PatternMatcher([
|
|
||||||
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int))
|
|
||||||
])
|
|
||||||
|
|
||||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||||
if DEBUG >= 5: print(pyrender(ast))
|
if DEBUG >= 5: print(pyrender(ast))
|
||||||
|
|
@ -86,14 +81,16 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
|
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
|
||||||
|
|
||||||
# expand
|
# expand
|
||||||
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
#sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||||
|
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
|
||||||
|
|
||||||
# add locals
|
# add locals
|
||||||
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
||||||
|
|
||||||
# ** devectorizer (full_graph_rewrite) **
|
# ** devectorizer (full_graph_rewrite) **
|
||||||
# remove reduce
|
# remove reduce
|
||||||
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
#sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
||||||
|
sink = graph_rewrite(sink, pm_reduce_local+pm_horizontal_reduce, ctx=ReduceContext(), name="remove_reduce")
|
||||||
|
|
||||||
# add gpu dims (late). this works after devectorize, but it's faster here
|
# add gpu dims (late). this works after devectorize, but it's faster here
|
||||||
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
||||||
|
|
@ -101,15 +98,21 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
# **** optimizations are done, now we lower to actual code ****
|
# **** optimizations are done, now we lower to actual code ****
|
||||||
|
|
||||||
# add loads and remove invalids
|
# add loads and remove invalids
|
||||||
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
#sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
||||||
|
sink = graph_rewrite(sink, pm_move_regs, name="** add loads")
|
||||||
|
|
||||||
# create image buffers
|
# create image buffers
|
||||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
||||||
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
|
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
|
||||||
|
|
||||||
|
# hreduce
|
||||||
|
#sink = graph_rewrite(sink, pm_mops+pm_horizontal_reduce, name="hreduce")
|
||||||
|
|
||||||
# devectorize
|
# devectorize
|
||||||
sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
|
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
|
||||||
ctx=ren, name="devectorize")
|
# ctx=ren, name="devectorize")
|
||||||
|
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
|
||||||
|
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, ctx=ren, name="devectorize2")
|
||||||
|
|
||||||
# lower the index dtype to a concrete int
|
# lower the index dtype to a concrete int
|
||||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||||
|
|
@ -118,23 +121,27 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
# optional pre matcher
|
# optional pre matcher
|
||||||
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
||||||
|
|
||||||
# floordiv+mod / dtype decomp (early)
|
# dtypes
|
||||||
supported_ops = tuple(ren.code_for_op.keys())
|
|
||||||
pm_decomp = symbolic_simple+get_simplifying_rewrite_patterns(supported_ops)
|
|
||||||
sink = graph_rewrite(sink, pm_decomp, name="early decompositions")
|
|
||||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
||||||
|
|
||||||
# do memory coalesing (late)
|
# memory coalesing
|
||||||
sink = memory_coalesing(sink, ren)
|
sink = memory_coalesing(sink)
|
||||||
|
|
||||||
# instruction selection decompositions
|
# again
|
||||||
pm_decomp = pm_decomp+\
|
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||||
get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))+\
|
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||||
get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
|
||||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="late decompositions")
|
|
||||||
|
|
||||||
# this is new style (TODO: this should all be removed)
|
# decompositions
|
||||||
|
supported_ops = tuple(ren.code_for_op.keys())
|
||||||
|
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
|
||||||
|
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||||
|
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
|
||||||
|
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
|
||||||
|
|
||||||
|
# GEP/STACK stuff
|
||||||
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
|
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
|
||||||
|
|
||||||
|
# this is new style
|
||||||
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
|
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
|
||||||
|
|
||||||
|
|
@ -143,7 +150,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
|
|
||||||
# final rules for the renderer (without sym)
|
# final rules for the renderer (without sym)
|
||||||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||||
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends+pm_no_weakints
|
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends
|
||||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
|
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
|
||||||
|
|
||||||
# this was the linearizer
|
# this was the linearizer
|
||||||
|
|
|
||||||
161
tinygrad/codegen/codegen2.py
Normal file
161
tinygrad/codegen/codegen2.py
Normal file
|
|
@ -0,0 +1,161 @@
|
||||||
|
from typing import Any
|
||||||
|
import itertools, functools
|
||||||
|
from tinygrad.schedule.rangeify import pm_mops
|
||||||
|
from tinygrad.codegen.simplify import pm_flatten_range
|
||||||
|
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, resolve, graph_rewrite
|
||||||
|
from tinygrad.dtype import dtypes, AddrSpace, ImageDType, Invalid
|
||||||
|
from tinygrad.helpers import all_same, flatten, getenv
|
||||||
|
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
|
||||||
|
from tinygrad.codegen.late.devectorizer import ReduceContext
|
||||||
|
from tinygrad.uop.symbolic import pm_clean_up_group_sink
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
|
||||||
|
pm_move_regs = PatternMatcher([
|
||||||
|
# BITCAST?
|
||||||
|
(UPat(GroupOp.Elementwise|{Ops.REDUCE}, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
|
||||||
|
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
|
||||||
|
])
|
||||||
|
|
||||||
|
pm_lower_weakints = PatternMatcher([
|
||||||
|
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def build_range_map(ctx, sink:UOp):
|
||||||
|
for x in sink.toposort():
|
||||||
|
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
|
||||||
|
ctx[x.arg[0]] = len(ctx)
|
||||||
|
|
||||||
|
def fix_reduce(ctx, r:UOp):
|
||||||
|
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
|
||||||
|
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
|
||||||
|
|
||||||
|
expander2 = PatternMatcher([
|
||||||
|
(UPat(Ops.SINK, name="sink"), build_range_map),
|
||||||
|
(UPat(Ops.REDUCE, name="r"), fix_reduce),
|
||||||
|
(UPat(Ops.RANGE, name="r"),
|
||||||
|
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
|
||||||
|
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
|
||||||
|
])+pm_flatten_range
|
||||||
|
|
||||||
|
def broadcast_binary(x:UOp):
|
||||||
|
shapes = [u.shape for u in x.src]
|
||||||
|
if all_same(shapes): return None
|
||||||
|
shaped_aligned = _align_left(*shapes)
|
||||||
|
broadcasted = _broadcast_shape(*shapes)
|
||||||
|
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
|
||||||
|
return x.replace(src=tuple(src_reshaped))
|
||||||
|
|
||||||
|
unbroadcast = PatternMatcher([
|
||||||
|
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
|
||||||
|
])
|
||||||
|
|
||||||
|
def do_devectorize(b:UOp):
|
||||||
|
if b.shape == (): return None
|
||||||
|
# broadcasting needs to be already unpacked
|
||||||
|
if not all_same([x.shape for x in b.src]): return None
|
||||||
|
src = []
|
||||||
|
for idx in itertools.product(*[range(x) for x in b.shape]):
|
||||||
|
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
|
||||||
|
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
|
||||||
|
return UOp.vectorize(*src).reshape(b.shape) if b.op is not Ops.STORE else UOp.group(*src)
|
||||||
|
|
||||||
|
devectorizer2 = pm_mops+PatternMatcher([
|
||||||
|
# unpack broadcasting
|
||||||
|
(UPat(GroupOp.Elementwise|{Ops.LOAD,Ops.STORE}, name="b"), do_devectorize),
|
||||||
|
# const INDEX into STACK is src
|
||||||
|
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
|
||||||
|
# stacked INDEX is many INDEX
|
||||||
|
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
|
||||||
|
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
|
||||||
|
# INDEX into RESHAPE moves the RESHAPE
|
||||||
|
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
|
||||||
|
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
|
||||||
|
# RESHAPE a void is removed (hack for AFTER)
|
||||||
|
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
|
||||||
|
# reshape of a single element shaped value to scalar is an index
|
||||||
|
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
|
||||||
|
# INDEX without src is nothing
|
||||||
|
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
|
||||||
|
# RESHAPE+EXPAND -> STACK
|
||||||
|
(UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(UPat.var("x"), UPat())), UPat()), name="out"),
|
||||||
|
lambda x,out: UOp.vectorize(*([x]*out.max_numel())) if out.shape == (out.max_numel(),) else None),
|
||||||
|
])
|
||||||
|
|
||||||
|
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
|
||||||
|
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
|
||||||
|
ctx.acc_num += 1
|
||||||
|
topo = r.src[0].toposort()
|
||||||
|
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
|
||||||
|
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
|
||||||
|
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
|
||||||
|
acc_initted = acc.after(acc_init, *r.src[1:])
|
||||||
|
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
|
||||||
|
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
|
||||||
|
return acc.after(acc_out)
|
||||||
|
|
||||||
|
def expand_horizontal_reduce(r:UOp):
|
||||||
|
axes = r.arg[1]
|
||||||
|
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
|
||||||
|
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
|
||||||
|
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
|
||||||
|
|
||||||
|
pm_reduce_local = PatternMatcher([
|
||||||
|
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
|
||||||
|
])+pm_clean_up_group_sink
|
||||||
|
|
||||||
|
pm_horizontal_reduce = PatternMatcher([
|
||||||
|
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
|
||||||
|
])
|
||||||
|
|
||||||
|
# *** memory coalesing ***
|
||||||
|
|
||||||
|
def memory_coalesing(sink:UOp):
|
||||||
|
if getenv("DMC"): return sink
|
||||||
|
|
||||||
|
# collect
|
||||||
|
memory: defaultdict[tuple[UOp, UOp, UOp], dict[int, list[UOp]]] = defaultdict(dict)
|
||||||
|
for u in sink.toposort():
|
||||||
|
if u.op in {Ops.LOAD, Ops.STORE} and u.src[0].addrspace != AddrSpace.REG:
|
||||||
|
assert u.src[0].op is Ops.INDEX
|
||||||
|
buf,idx_u = u.src[0].src
|
||||||
|
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
|
||||||
|
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
|
||||||
|
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
|
||||||
|
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
||||||
|
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
|
||||||
|
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
||||||
|
else: root_src, arg = idx, 0
|
||||||
|
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
|
||||||
|
|
||||||
|
# allowed lengths
|
||||||
|
lengths = [8,4,2,1]
|
||||||
|
|
||||||
|
# build replacements
|
||||||
|
replacements = {}
|
||||||
|
for (op,buf,base,valid),offsets in memory.items():
|
||||||
|
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
||||||
|
for full_grp in grouped_offsets:
|
||||||
|
while len(full_grp):
|
||||||
|
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.weakint, full_grp[0])
|
||||||
|
length = [l for l in lengths if l <= len(full_grp) and offset.divides(l) is not None][0]
|
||||||
|
grp = full_grp[:length]
|
||||||
|
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
|
||||||
|
if op is Ops.STORE:
|
||||||
|
datas = []
|
||||||
|
for i,g in enumerate(grp):
|
||||||
|
assert len(offsets[g]) == 1
|
||||||
|
datas.append(offsets[g][0].src[1])
|
||||||
|
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
|
||||||
|
store = idx.store(data, valid) if valid is not None else idx.store(data)
|
||||||
|
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
|
||||||
|
else:
|
||||||
|
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
|
||||||
|
for i,g in enumerate(grp):
|
||||||
|
for oo in offsets[g]:
|
||||||
|
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
|
||||||
|
full_grp = full_grp[length:]
|
||||||
|
|
||||||
|
# apply
|
||||||
|
return sink.substitute(replacements, name="memory coalesing")
|
||||||
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
from typing import Any
|
|
||||||
import itertools
|
|
||||||
from collections import defaultdict
|
|
||||||
from tinygrad.dtype import dtypes, AddrSpace, Invalid, ImageDType
|
|
||||||
from tinygrad.uop.ops import UOp, Ops
|
|
||||||
from tinygrad.helpers import getenv
|
|
||||||
from tinygrad.renderer import Renderer
|
|
||||||
|
|
||||||
def memory_coalesing(sink:UOp, ctx:Renderer) -> UOp:
|
|
||||||
if getenv("DMC"): return sink
|
|
||||||
|
|
||||||
# collect
|
|
||||||
memory: defaultdict[tuple[Ops, UOp, Any, Any], dict[int, list[UOp]]] = defaultdict(dict)
|
|
||||||
for u in sink.toposort():
|
|
||||||
# TODO: this should handle images too, it's just memory coalesing
|
|
||||||
if u.op in {Ops.LOAD, Ops.STORE} and not isinstance(u.src[0].src[0].dtype, ImageDType):
|
|
||||||
assert len(u.src) == (2 if u.op is Ops.STORE else 1), "memory coalesing does not support gated loads/stores"
|
|
||||||
assert u.src[0].op is Ops.INDEX
|
|
||||||
buf, idx_u = u.src[0].src
|
|
||||||
if buf.addrspace == AddrSpace.REG: continue
|
|
||||||
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
|
|
||||||
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
|
|
||||||
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
|
|
||||||
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
|
||||||
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
|
|
||||||
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
|
||||||
else: root_src, arg = idx, 0
|
|
||||||
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
|
|
||||||
|
|
||||||
# build replacements
|
|
||||||
replacements = {}
|
|
||||||
for (op,buf,base,valid),offsets in memory.items():
|
|
||||||
# allowed lengths (copied in)
|
|
||||||
lengths = []
|
|
||||||
must_divide = True
|
|
||||||
if ctx is not None and ctx.target.device == "DSP":
|
|
||||||
lengths = [128,64,32,16,8,4]
|
|
||||||
must_divide = False
|
|
||||||
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
|
|
||||||
pass
|
|
||||||
elif buf.addrspace == AddrSpace.REG:
|
|
||||||
pass
|
|
||||||
elif isinstance(buf.dtype, ImageDType):
|
|
||||||
lengths = [4]
|
|
||||||
elif ctx is not None and ctx.supports_float4:
|
|
||||||
# TODO: a better way to get this than ctx
|
|
||||||
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
|
|
||||||
lengths.append(1) # worst case, it's not folded
|
|
||||||
# do the grouping
|
|
||||||
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
|
||||||
for full_grp in grouped_offsets:
|
|
||||||
while len(full_grp):
|
|
||||||
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.int, full_grp[0])
|
|
||||||
length = [l for l in lengths if l <= len(full_grp) and (not must_divide or offset.divides(l) is not None)][0]
|
|
||||||
grp = full_grp[:length]
|
|
||||||
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
|
|
||||||
if op == Ops.STORE:
|
|
||||||
datas = []
|
|
||||||
for i,g in enumerate(grp):
|
|
||||||
assert len(offsets[g]) == 1, f"attempting multiple stores: {len(offsets[g])}"
|
|
||||||
datas.append(offsets[g][0].src[1])
|
|
||||||
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
|
|
||||||
store = idx.store(data, valid) if valid is not None else idx.store(data)
|
|
||||||
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
|
|
||||||
else:
|
|
||||||
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
|
|
||||||
for i,g in enumerate(grp):
|
|
||||||
for oo in offsets[g]:
|
|
||||||
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
|
|
||||||
full_grp = full_grp[length:]
|
|
||||||
|
|
||||||
# apply
|
|
||||||
return sink.substitute(replacements, name="memory coalesing")
|
|
||||||
|
|
@ -14,7 +14,7 @@ from tinygrad.renderer import Renderer
|
||||||
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
||||||
# can drop valid if idx is out of bound when valid is False
|
# can drop valid if idx is out of bound when valid is False
|
||||||
drop_stmt = []
|
drop_stmt = []
|
||||||
for i,stmt in enumerate(valid.split_uop(Ops.AND)):
|
for stmt in valid.split_uop(Ops.AND):
|
||||||
if (res:=parse_valid(stmt)) is None: continue
|
if (res:=parse_valid(stmt)) is None: continue
|
||||||
X, is_upper_bound, c = res
|
X, is_upper_bound, c = res
|
||||||
|
|
||||||
|
|
@ -25,12 +25,12 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
||||||
drop_stmt.append(stmt)
|
drop_stmt.append(stmt)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# check if idx is out of bound when X is on the wrong side of the bound: X in [c+1, vmax] or [vmin, c-1]
|
# if X <= c, check if it's out of bound when X = c+1
|
||||||
lo, hi = (c + 1, X.vmax) if is_upper_bound else (X.vmin, c - 1)
|
# if X >= c, check if it's out of bound when X = c-1
|
||||||
if lo <= hi:
|
test_value = c + 1 if is_upper_bound else c - 1
|
||||||
fake = UOp.variable(f"fake{i}", lo, hi, X.dtype)
|
for i,b in zip(idx.src, (width, height)):
|
||||||
for coord,b in zip(idx.src, (width, height)):
|
if i.is_increasing():
|
||||||
rw = coord.substitute({X:fake}).simplify()
|
rw = i.substitute({X:X.const_like(test_value)})
|
||||||
if rw.vmin >= b or rw.vmax < 0:
|
if rw.vmin >= b or rw.vmax < 0:
|
||||||
drop_stmt.append(stmt)
|
drop_stmt.append(stmt)
|
||||||
break
|
break
|
||||||
|
|
@ -162,8 +162,18 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||||
# determine fold lengths
|
# determine fold lengths
|
||||||
lengths = []
|
lengths = []
|
||||||
must_divide = True
|
must_divide = True
|
||||||
# TODO: this belongs in coalese
|
if ctx is not None and ctx.target.device == "DSP":
|
||||||
if isinstance(buf.dtype, ImageDType): lengths = [4]
|
lengths = [128,64,32,16,8,4]
|
||||||
|
must_divide = False
|
||||||
|
elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
|
||||||
|
pass
|
||||||
|
elif buf.addrspace == AddrSpace.REG:
|
||||||
|
pass
|
||||||
|
elif isinstance(buf.dtype, ImageDType):
|
||||||
|
lengths = [4]
|
||||||
|
elif ctx is not None and ctx.supports_float4:
|
||||||
|
# TODO: a better way to get this than ctx
|
||||||
|
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
|
||||||
lengths.append(1) # worst case, it's not folded
|
lengths.append(1) # worst case, it's not folded
|
||||||
|
|
||||||
# filter fold lengths that don't divide
|
# filter fold lengths that don't divide
|
||||||
|
|
|
||||||
|
|
@ -101,12 +101,6 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
|
||||||
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
|
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
|
||||||
is_masked = k.rngs[axis] in where_gate_rngs
|
is_masked = k.rngs[axis] in where_gate_rngs
|
||||||
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
||||||
# upcasting a masked global axis moves that range out of the launch grid into each work-item
|
|
||||||
# under IMAGE, skip the upcast unless enough global work-items remain after it to hide memory latency
|
|
||||||
if IMAGE and k.axis_types[axis] is AxisType.GLOBAL:
|
|
||||||
global_upcast = prod(k.full_shape[i] for i in to_upcast if k.axis_types[i] is AxisType.GLOBAL) * k.full_shape[axis]
|
|
||||||
global_items_after = prod(k.full_shape[i] for i in k.axes_of(AxisType.GLOBAL)) // global_upcast
|
|
||||||
if resolve(global_items_after < getenv("OCCUPANCY_FLOOR", 4096), False): continue
|
|
||||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||||
to_upcast.append(axis)
|
to_upcast.append(axis)
|
||||||
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
||||||
|
|
|
||||||
|
|
@ -275,7 +275,7 @@ SCACHE = ContextVar("SCACHE", 1)
|
||||||
# allow use of atomics for embedding backward
|
# allow use of atomics for embedding backward
|
||||||
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
||||||
# don't allow broadcast
|
# don't allow broadcast
|
||||||
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1)
|
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Metadata:
|
class Metadata:
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ base_rewrite = PatternMatcher([
|
||||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
|
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_type(x)})" \
|
||||||
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
|
if x.max_numel() > 1 and x.addrspace is AddrSpace.REG else None),
|
||||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
|
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x, ctx[x.src[0]])})"),
|
||||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: ctx[x.src[0]] if x.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else None),
|
|
||||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
|
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"__builtin_bit_cast({ctx.render_type(x)}, ({ctx.render_type(x.src[0])})({ctx[x.src[0]]}))"),
|
||||||
|
|
||||||
# GPU stuff
|
# GPU stuff
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import sys
|
||||||
sys.setrecursionlimit(10000)
|
sys.setrecursionlimit(10000)
|
||||||
|
|
||||||
def add_ranges_to_store(ctx, x):
|
def add_ranges_to_store(ctx, x):
|
||||||
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == () or x.src[0].max_numel() == x.src[1].max_numel() == 1: return None
|
if x.src[0]._shape is None or x.src[1]._shape is None or x.src[0].shape == (): return None
|
||||||
assert x.src[0].shape == x.src[1].shape, "bad store shape"
|
assert x.src[0].shape == x.src[1].shape, "bad store shape"
|
||||||
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
|
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
|
||||||
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)
|
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)
|
||||||
|
|
|
||||||
|
|
@ -454,8 +454,7 @@ def floormod_to_mod(a:UOp, b:UOp) -> UOp:
|
||||||
|
|
||||||
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
|
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
|
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
|
||||||
# these are rewrites that make things simpler
|
|
||||||
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
|
pat: list[tuple[UPat, Callable]] = [(UPat.var("a")//UPat.var("b"), floordiv_to_idiv)]
|
||||||
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
|
# FLOORMOD by 2**y -> x & (2**y-1) (correct floor mod for any sign in two's complement); fires before floormod_to_mod
|
||||||
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
|
if Ops.AND in ops: pat.append((UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None))
|
||||||
|
|
@ -464,11 +463,6 @@ def get_simplifying_rewrite_patterns(ops:tuple[Ops, ...]) -> PatternMatcher:
|
||||||
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
|
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
|
||||||
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
|
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||||
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
|
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
|
||||||
return PatternMatcher(pat)
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
|
|
||||||
pat: list[tuple[UPat, Callable]] = []
|
|
||||||
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
|
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
|
||||||
lambda x,y: (x | y).logical_not())]
|
lambda x,y: (x | y).logical_not())]
|
||||||
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
|
# rewrite MUL/CDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
|
||||||
|
|
|
||||||
|
|
@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
||||||
|
|
||||||
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
||||||
if len(arg) == 0: return UOp(Ops.STACK)
|
if len(arg) == 0: return UOp(Ops.STACK)
|
||||||
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
|
elif len(arg) == 1: return UOp.const(dtypes.weakint, arg[0])
|
||||||
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
||||||
|
|
||||||
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||||
ret: dict[UOp, dict[UOp, None]] = {}
|
ret: dict[UOp, dict[UOp, None]] = {}
|
||||||
|
|
@ -305,10 +305,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
|
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
|
||||||
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
|
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
|
||||||
return self.src[0]._shape
|
return self.src[0]._shape
|
||||||
# REDUCE with empty axis is passthrough (lowered form)
|
|
||||||
case Ops.REDUCE if len(self.arg[1]) == 0:
|
|
||||||
# these can mismatch if there's a horizonal reduce
|
|
||||||
return (self.dtype.count,) if self.dtype.count > 1 else ()
|
|
||||||
|
|
||||||
# TODO: disallow shape changing bitcast
|
# TODO: disallow shape changing bitcast
|
||||||
case Ops.BITCAST:
|
case Ops.BITCAST:
|
||||||
|
|
@ -473,7 +469,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||||
def vectorize(self, *srcs):
|
def vectorize(self, *srcs):
|
||||||
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
|
return UOp(Ops.STACK, self.dtype, (self,)+srcs)
|
||||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
|
@ -919,6 +915,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
|
|
||||||
# *** uop symbolic stuff ***
|
# *** uop symbolic stuff ***
|
||||||
|
|
||||||
|
def is_increasing(self:UOp) -> bool:
|
||||||
|
# is f a monotonically increasing function regards its input
|
||||||
|
if self.op in GroupOp.Irreducible: return True
|
||||||
|
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
|
||||||
|
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
|
||||||
|
return False # False if not sure
|
||||||
def const_factor(self) -> int:
|
def const_factor(self) -> int:
|
||||||
"""largest known int that divides self"""
|
"""largest known int that divides self"""
|
||||||
# TODO: for negatives it's not the largest
|
# TODO: for negatives it's not the largest
|
||||||
|
|
|
||||||
|
|
@ -192,8 +192,8 @@ spec_program = PatternMatcher([
|
||||||
# no more of these in programs
|
# no more of these in programs
|
||||||
(UPat(Ops.GEP), lambda: False),
|
(UPat(Ops.GEP), lambda: False),
|
||||||
|
|
||||||
# weakint is not allowed in programs
|
# weakint is not allowed in programs, except on CONST and STACK
|
||||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
(UPat(GroupOp.All-{Ops.CONST,Ops.STACK}, dtypes.weakint), lambda: False),
|
||||||
|
|
||||||
# allow special SHRINK
|
# allow special SHRINK
|
||||||
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),
|
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),
|
||||||
|
|
|
||||||
|
|
@ -121,8 +121,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||||
# TODO: combine this with "# rules for threefry" below
|
# TODO: combine this with "# rules for threefry" below
|
||||||
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
|
((UPat.var("x") & UPat.cvar("mask")) >> UPat.cvar("k"),
|
||||||
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
|
lambda x,mask,k: x >> k.arg if mask.arg | ((1 << k.arg) - 1) == -1 else None),
|
||||||
((UPat.var("x") & UPat.cvar("mask")) // UPat.cvar("c"),
|
|
||||||
lambda x,mask,c: x // c.arg if c.arg > 0 and c.arg & (c.arg-1) == 0 and mask.arg | (c.arg-1) == -1 else None),
|
|
||||||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
|
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
|
||||||
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
||||||
# ** constant folding **
|
# ** constant folding **
|
||||||
|
|
@ -162,7 +160,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||||
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||||
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
||||||
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||||
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
|
||||||
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
|
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
|
||||||
# ** simple where folding **
|
# ** simple where folding **
|
||||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||||
|
|
@ -170,9 +167,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||||
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
(UPat.cvar("gate").where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||||
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
|
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
|
||||||
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
|
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
|
||||||
# STACK on INDEX CONST (TODO: remove all the GEP crap)
|
|
||||||
(UPat(Ops.STACK, src=UPat(Ops.INDEX, src=(UPat.var("src"), UPat(Ops.CONST))), name="stk"),
|
|
||||||
lambda src,stk: src if stk.shape == src.shape and list(range(len(stk.src))) == [x.src[1].arg for x in stk.src] else None),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
|
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue