mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove DEFINE_LOCAL and DEFINE_REG (gpt) (#16673)
* remove define_local and define_reg (gpt)
* fix precommit
* cleanups
* regalloc fix
* cleanups 2
This commit is contained in:
parent
b05bea81ce
commit
649971f02a
22 changed files with 53 additions and 61 deletions
|
|
@ -458,7 +458,8 @@ def test_matmul():
|
|||
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
|
||||
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
|
||||
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ def amd_flash_attention(o:UOp, q:UOp, k:UOp, v:UOp) -> UOp:
|
|||
P_lds = QP_lds[:, :BLOCK_N]
|
||||
P_write = P_lds.reshape(WAVES_M, TM // WMMA_ACC, WMMA_ACC, LANES_PER_WAVE_M, WAVES_N, TN, LANES_PER_WAVE_N)
|
||||
P_write = P_write.permute((0, 4, 3, 6, 1, 2, 5)).reshape(THREADS_PER_BLOCK, TM, TN)
|
||||
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) — shaped store fails due to RESHAPE(DEFINE_LOCAL) surviving linearization
|
||||
# TODO: P_write[tid].store(S_reg.cast(dtypes.half)) -- shaped store fails due to RESHAPE(local BUFFER) surviving linearization
|
||||
rw1 = UOp.range(TM, 296, AxisType.LOOP)
|
||||
rw2 = UOp.range(TN, 297, AxisType.LOOP)
|
||||
P_store = P_write[tid, rw1, rw2].store(S_reg[rw1, rw2].cast(dtypes.half)).end(rw1, rw2)
|
||||
|
|
|
|||
|
|
@ -2619,7 +2619,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
|||
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
|
||||
gidx = UOp.special(NUM_WG, "gidx0")
|
||||
insts = build_kernel(batch, M, N, K, A.dtype.base)
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds = UOp.placeholder((133_120,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
|
||||
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),
|
||||
|
|
|
|||
|
|
@ -219,7 +219,8 @@ def test_matmul():
|
|||
def asm_kernel(A, B, C):
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(THREADS, "lidx0")]
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC",2))
|
||||
lds = UOp.placeholder((lds_size,), dtypes.uint8, 0, AddrSpace.LOCAL)
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs,
|
||||
arg=KernelInfo(name=colored("kernel","cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*2*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def custom_lds_sync(A:UOp, arch:str) -> UOp:
|
|||
num_threads = A.shape[0]
|
||||
threads = UOp.special(num_threads, "lidx0")
|
||||
wg = UOp.special(1, "gidx0")
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
|
||||
lds = UOp.placeholder((512,), dtypes.uint8, 0, AddrSpace.LOCAL) # 128 * 4 bytes
|
||||
isa = r4 if arch == "rdna4" else r3
|
||||
wait_kmcnt = [isa.s_wait_kmcnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
|
||||
wait_dscnt = [isa.s_wait_dscnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
|
||||
|
|
@ -103,7 +103,7 @@ def custom_handwritten(A:UOp) -> UOp:
|
|||
A = A.flatten()
|
||||
threads = UOp.special(128, "lidx0")
|
||||
wg = UOp.special(1, "gidx0")
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
|
||||
lds = UOp.placeholder((512,), dtypes.uint8, 0, AddrSpace.LOCAL) # 128 * 4 bytes
|
||||
pipes = {getenv("PIPE", "")} if getenv("PIPE", "") else {"SALU", "VALU", "TRANSCENDENTAL", "WMMA"}
|
||||
k = Kernel()
|
||||
# wrap in loop to filter out icache misses
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
|
||||
def _test_no_nested_ranges(self, lins, skip=None):
|
||||
for l in lins:
|
||||
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
|
||||
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG])
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
|
||||
for i,u in enumerate(ranges):
|
||||
if skip and i in skip: continue
|
||||
|
|
@ -161,7 +161,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
|
||||
uops = tuple(to_program(replace_opts(r.schedule_linear().src[-1].src[0],
|
||||
[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]), renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
|
||||
accs = [u for u in uops if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG]
|
||||
stores = [u for u in uops if u.op is Ops.STORE]
|
||||
assert len(accs) == 0 # it's removed now
|
||||
assert len(stores) == 1
|
||||
|
|
@ -210,14 +210,14 @@ class TestLinearizer(unittest.TestCase):
|
|||
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
|
||||
realized_ast = a.schedule_linear().src[-1].src[0]
|
||||
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
|
||||
local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
|
||||
local = [uop for uop in tuple(program.src[2].src) if uop.op is Ops.BUFFER and uop.addrspace in (AddrSpace.LOCAL, AddrSpace.REG)]
|
||||
assert local[0].dtype.base == acc_dtype
|
||||
|
||||
def test_arg_acc_dtype(self):
|
||||
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
|
||||
realized_ast = c.schedule_linear().src[-1].src[0]
|
||||
program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
|
||||
local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
|
||||
local = [uop for uop in tuple(program.src[2].src) if uop.op is Ops.BUFFER and uop.addrspace in (AddrSpace.LOCAL, AddrSpace.REG)]
|
||||
self.assertEqual(local[0].dtype.base, expected_dtype)
|
||||
|
||||
tests = (
|
||||
|
|
@ -243,7 +243,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
r = (x@y).relu()
|
||||
opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
|
||||
ast = helper_linearizer_opt(r, [opt])
|
||||
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
|
||||
# the uops graph is reg BUFFER -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
|
||||
uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
|
||||
end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
|
||||
|
|
@ -361,7 +361,8 @@ class TestLinearizer(unittest.TestCase):
|
|||
ast = helper_linearizer_opt(out, opts=[opt])
|
||||
def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
|
||||
uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
|
||||
local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
|
||||
local_stores = [u for u in uops if u.op is Ops.STORE and any(
|
||||
x.op is Ops.BUFFER and x.addrspace is AddrSpace.LOCAL for x in get_recursive(u.src[0]))]
|
||||
global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.PARAM for x in get_recursive(u.src[0]))]
|
||||
barrier = [u for u in uops if u.op is Ops.BARRIER]
|
||||
assert len(barrier) == 1
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import numpy as np
|
|||
import tempfile, unittest
|
||||
from tinygrad import Tensor, Context, Device, dtypes, UOp
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.engine.realize import run_linear
|
||||
from tinygrad.codegen import to_program
|
||||
|
|
@ -80,7 +81,7 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
|
|||
with Context(QUANTIZE=1):
|
||||
linear = run_onnx({"input":inp})["output"].schedule_linear()
|
||||
prg = to_program(linear.src[-2].src[0], renderer=Device[Device.DEFAULT].renderer)
|
||||
daccs = [u for u in tuple(prg.src[2].src) if u.op is Ops.DEFINE_REG]
|
||||
daccs = [u for u in tuple(prg.src[2].src) if u.op is Ops.BUFFER and u.addrspace is AddrSpace.REG]
|
||||
assert all(u.dtype.scalar() is dtypes.int for u in daccs)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
|
||||
|
|
|
|||
|
|
@ -1425,7 +1425,7 @@ def _compile_mfma(inst: irc.VOP3P, ctx: _Ctx) -> UOp:
|
|||
mant = h & UOp.const(dtypes.uint32, 0x3FF)
|
||||
# Use bf16 path: shift left by 16 to create bf16 bits, then shift mantissa and adjust exponent in float domain
|
||||
# bf16 bits = (sign << 15) | (exp_bf16 << 7) | mant_bf16 -- but f16 and bf16 have different formats
|
||||
# Instead: construct f32 bits properly, use a DEFINE_LOCAL uint32 array to force materialization
|
||||
# Instead: construct f32 bits properly, use a local uint32 array to force materialization
|
||||
f32_bits = (sign << UOp.const(dtypes.uint32, 31)) | \
|
||||
((exp + UOp.const(dtypes.uint32, 112)) << UOp.const(dtypes.uint32, 23)) | \
|
||||
(mant << UOp.const(dtypes.uint32, 13))
|
||||
|
|
|
|||
|
|
@ -35,11 +35,11 @@ pm_index_is_shrink = PatternMatcher([
|
|||
|
||||
pm_remove_vec_dtypes = PatternMatcher([
|
||||
# rewrite PARAM to non pointer
|
||||
(UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
|
||||
(UPat((Ops.PARAM, Ops.BUFFER), name="buf"), lambda buf:
|
||||
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
|
||||
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
|
||||
# remove all vec dtypes
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.BUFFER}, name="x"),
|
||||
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
|
|
|
|||
|
|
@ -268,12 +268,10 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
|
|||
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.weakint.vec(len(pairs)), offsets), ptr=True)
|
||||
|
||||
devectorize_buf_and_index = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER), name="buf"), no_vectorized_buf),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index),
|
||||
(UPat(Ops.BUFFER, name="buf"), no_vectorized_buf),
|
||||
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat(Ops.BUFFER).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")), no_vectorized_index),
|
||||
])
|
||||
|
||||
devectorize_alu = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import heapq
|
|||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
|
||||
|
||||
def linearize(sink:UOp) -> list[UOp]:
|
||||
|
|
@ -23,7 +24,7 @@ def linearize(sink:UOp) -> list[UOp]:
|
|||
match u.op:
|
||||
# the order and placement of these defines is important
|
||||
case Ops.PARAM: priority, extra = -20, u.arg.slot
|
||||
case Ops.BUFFER: priority = -18
|
||||
case Ops.BUFFER: priority = -17 if u.addrspace == AddrSpace.LOCAL else -18
|
||||
case Ops.LOAD: priority = -1 # place loads early
|
||||
case Ops.STORE: priority = 1 # place stores late
|
||||
case Ops.RANGE: priority = 5 # placing RANGE is good
|
||||
|
|
|
|||
|
|
@ -83,10 +83,10 @@ class LinearScanRegallocContext:
|
|||
live[v] = alloc(cons, i+1 if u.op is not Ops.RANGE else i)
|
||||
self.reals.setdefault(i, {})[v] = live[v]
|
||||
|
||||
# allocate stack array, BUFFER size is in src[0]
|
||||
# allocate stack array
|
||||
if u.op is Ops.BUFFER:
|
||||
self.locals[u] = UOp.const(dtypes.int32, self.stack_size)
|
||||
self.stack_size += u.src[0].arg * u.dtype.itemsize
|
||||
self.stack_size += u.max_numel() * u.dtype.itemsize
|
||||
|
||||
# loop prologue, avoid loading inside the loop
|
||||
if u.op is Ops.RANGE:
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
|
|||
replaces: dict[UOp, UOp] = {}
|
||||
for u in included:
|
||||
for s in u.src:
|
||||
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL}: continue
|
||||
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.BUFFER}: continue
|
||||
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
|
||||
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
|
||||
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def assemble_linear(prg:UOp, lin:UOp, arch:str) -> bytes:
|
|||
for u in sink.toposort():
|
||||
if u.op is Ops.PARAM and u.addrspace is AddrSpace.ALU: n_vars += 1
|
||||
elif u.op is Ops.PARAM: n_bufs += 1
|
||||
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
|
||||
elif u.op is Ops.BUFFER and u.addrspace is AddrSpace.LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
|
||||
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
|
||||
code_bytes = b"".join(inst.to_bytes() for inst in insts)
|
||||
arch = next(v for k, v in _arch_map.items() if arch.startswith(k))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sys, struct, functools
|
|||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, DType, truncate, AddrSpace
|
||||
from tinygrad.uop import FastEnum, auto, Ops, GroupOp
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, ParamArg
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, Register, PreRegAllocContext
|
||||
from tinygrad.helpers import getenv, CPU_COUNT, unwrap, Target
|
||||
|
||||
|
|
@ -174,7 +174,7 @@ extra_matcher = PatternMatcher([
|
|||
# ***** X86 pre instruction selection *****
|
||||
|
||||
def scratch_buffer(elem_dt:DType, count:int, slot:int) -> UOp:
|
||||
return UOp(Ops.BUFFER, elem_dt, src=(UOp.const(dtypes.int, count),), arg=ParamArg(slot, addrspace=AddrSpace.LOCAL))
|
||||
return UOp.placeholder((count,), elem_dt, slot, AddrSpace.LOCAL)
|
||||
|
||||
def gated_load(ctx, addr:UOp, alt:UOp, gate:UOp, x:UOp):
|
||||
local = scratch_buffer(addr.src[0].dtype.scalar(), x.dtype.count, next(ctx))
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ dsp_pm_late = PatternMatcher([
|
|||
(UPat.var("x")+UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat.var("x")*UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat.var("x")//UPat(Ops.STACK,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat(Ops.DEFINE_REG, src=(UPat(Ops.STACK, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.STACK, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:]) if d.addrspace is AddrSpace.REG else None),
|
||||
])
|
||||
|
||||
# NOTE: this just increases readability of the generated code
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_claus
|
|||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.SLICE,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.FUNCTION}
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.LOAD, Ops.CALL, Ops.FUNCTION}
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -427,7 +427,7 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
|||
ended_stores.append(store_target.replace(dtype=sdtype).store(store.src[1]).end(*end_rngs))
|
||||
return buf.after(*ended_stores)
|
||||
|
||||
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
||||
# NOTE: the local BUFFER needs to be disambiguated here
|
||||
if sdtype.addrspace == AddrSpace.GLOBAL:
|
||||
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
|
||||
if x.src[0].op is Ops.SLICE:
|
||||
|
|
@ -561,11 +561,11 @@ rangeify_codegen = PatternMatcher([
|
|||
# TODO: this can be moved into codegen?
|
||||
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0] if len(x.src) else None),
|
||||
|
||||
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: None if isinstance(idx.dtype, PtrDType) else
|
||||
(UPat(Ops.BUFFER).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: None if dg.addrspace is not AddrSpace.LOCAL or isinstance(idx.dtype, PtrDType) else
|
||||
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
|
||||
(UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).gep(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: None if isinstance(idx.dtype, PtrDType) else
|
||||
(UPat(Ops.BUFFER).f(Ops.AFTER, allow_any_len=True).gep(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True),
|
||||
lambda dg,idx: None if dg.addrspace is not AddrSpace.LOCAL or isinstance(idx.dtype, PtrDType) else
|
||||
idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))),
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -19,10 +19,7 @@ class Ops(FastEnum):
|
|||
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
|
||||
SPECIAL = auto()
|
||||
|
||||
# define LOCAL/REG allocate things
|
||||
DEFINE_LOCAL = auto(); DEFINE_REG = auto()
|
||||
|
||||
# BUFFER is the new LOCAL/REG
|
||||
# BUFFER allocates global/local/register storage depending on its addrspace
|
||||
BUFFER = auto()
|
||||
|
||||
# ** 2 -- non op uops **
|
||||
|
|
@ -125,7 +122,7 @@ class GroupOp:
|
|||
# TODO: is BITCAST always Elementwise if it's shape changing?
|
||||
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
|
||||
|
||||
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||
Defines = {Ops.PARAM, Ops.BUFFER}
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
||||
|
|
|
|||
|
|
@ -279,7 +279,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
case Ops.GETADDR: return ()
|
||||
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
case Ops.BINARY: return (len(self.arg),)
|
||||
case Ops.BUFFER: return self.src[0].as_shape if isinstance(self.arg, ParamArg) else (self.arg,)
|
||||
case Ops.BUFFER:
|
||||
if isinstance(self.arg, ParamArg):
|
||||
if len(self.src): return self.src[0].as_shape
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size, self.dtype.count) if self.dtype.count > 1 else (self.ptrdtype.size,)
|
||||
return (self.dtype.count,) if self.dtype.count > 1 else ()
|
||||
return (self.arg,)
|
||||
case Ops.SLICE:
|
||||
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
if self.src[0].op is Ops.INDEX: return ()
|
||||
|
|
@ -288,13 +293,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
case Ops.STAGE:
|
||||
# STAGE adds the existing shape to the front, opposite of INDEX
|
||||
return tuple([int(r.vmax+1) for r in self.src[1:]])+self.src[0].shape
|
||||
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG:
|
||||
if len(self.src) >= 1:
|
||||
# NOTE: this is the same as PARAM
|
||||
return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
|
||||
if isinstance(self.dtype, PtrDType):
|
||||
return (self.ptrdtype.size, self.dtype.count) if self.dtype.count > 1 else (self.ptrdtype.size,)
|
||||
return (self.dtype.count,) if self.dtype.count > 1 else ()
|
||||
case Ops.PARAM:
|
||||
if isinstance(self.dtype, ImageDType): return self.dtype.shape
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
|
||||
|
|
@ -794,8 +792,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
def addrspace(self) -> AddrSpace|None:
|
||||
if self.op is Ops.PARAM: return self.arg.addrspace
|
||||
if self.op is Ops.BUFFER: return self.arg.addrspace if isinstance(self.arg, ParamArg) else AddrSpace.GLOBAL
|
||||
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
|
||||
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
|
||||
if self.op in {Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
|
||||
if self.op is Ops.LOAD: return AddrSpace.ALU # LOAD brings things into the ALU
|
||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
|
||||
|
|
|
|||
|
|
@ -86,12 +86,8 @@ spec_shared = PatternMatcher([
|
|||
# TODO: remove UNROLL here, it's for SPEC=2
|
||||
(UPat(Ops.GROUP, dtypes.void, src=UPat((Ops.GROUP, Ops.STORE, Ops.NOOP, Ops.UNROLL, Ops.INS))), lambda: True),
|
||||
|
||||
# TOOD: these should be buffer with different addrspace everywhere.
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), lambda: True),
|
||||
|
||||
# AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.AFTER, Ops.MULTI,
|
||||
Ops.BITCAST, Ops.INS})),),
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.PARAM, Ops.BUFFER, Ops.CONTIGUOUS, Ops.AFTER, Ops.MULTI, Ops.BITCAST, Ops.INS})),),
|
||||
allow_any_len=True), lambda: True),
|
||||
|
||||
# CUSTOM (inline and non inline)
|
||||
|
|
@ -178,7 +174,7 @@ spec_tensor = PatternMatcher([
|
|||
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None),
|
||||
lambda root,x: root.dtype == x.dtype),
|
||||
|
||||
# TODO: this should not be here. STAGE is transformed to DEFINE_LOCAL later
|
||||
# TODO: this should not be here. STAGE is transformed to BUFFER later
|
||||
(UPat(Ops.STAGE, src=(UPat(),), allow_any_len=True), lambda: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
|
|
@ -198,7 +194,7 @@ spec_tensor = PatternMatcher([
|
|||
# these ops can exist in programs but not the tensor spec. example: LOAD
|
||||
spec_program = PatternMatcher([
|
||||
# no more of these in programs
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.GEP)), lambda: False),
|
||||
(UPat(Ops.GEP), lambda: False),
|
||||
|
||||
# weakint is not allowed in programs
|
||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphE
|
|||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#D8F9E4", Ops.STACK: "#D8F9E4",
|
||||
Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue