mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
9 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0b9416432 |
||
|
|
b1db99cefe | ||
|
|
de28eaa610 | ||
|
|
f3da3c9be5 | ||
|
|
c43f22b143 | ||
|
|
34e631eb26 | ||
|
|
16983e9c95 | ||
|
|
3eb00a421f | ||
|
|
b54493b003 |
5 changed files with 132 additions and 299 deletions
|
|
@ -1,145 +1,51 @@
|
|||
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType, PatternMatcher, UPat
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.uop.ops import UOp, KernelInfo
|
||||
from tinygrad.engine.realize import ExecItem, get_runner
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import getenv, colored, prod, unwrap
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.opt.kernel import axis_colors, Opt, OptOps
|
||||
from tinygrad.codegen.opt.swizzler import merge_views, view_left
|
||||
|
||||
def to_colored(full_shape, axis_types): return '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
N = 4096
|
||||
run_count = 5
|
||||
|
||||
# block for locals
|
||||
BN = 128
|
||||
BM = 128
|
||||
BK = 8
|
||||
|
||||
# t for registers
|
||||
TN = 4
|
||||
TM = 4
|
||||
|
||||
# NOTE: this is from testgrad
|
||||
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
|
||||
# src->r->view --> src->view->r
|
||||
def swizzle_reduceop(src:UOp, r:UOp, view:UOp):
|
||||
if r.tag is not None: return None
|
||||
# confirm the input is in order
|
||||
# TODO: replace this with a UOp that allows for nothing else then remove this
|
||||
permute = tuple(i for i in range(len(src.shape)) if i not in r.axis_arg)+r.axis_arg
|
||||
assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't"
|
||||
|
||||
# append the reduce shape to each of the views
|
||||
prshape = prod(rshape:=src.shape[-len(r.axis_arg):])
|
||||
rstrides = strides_for_shape(rshape)
|
||||
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+rstrides, v.offset*prshape,
|
||||
v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
|
||||
def hand_spec_kernel3(kernel5=getenv("K5", 0)):
|
||||
# ---------------------------
|
||||
# launch/config constants
|
||||
# ---------------------------
|
||||
|
||||
# no reshape required with shrinking REDUCE_AXIS
|
||||
return UOp(Ops.REDUCE_AXIS, r.dtype, (src.view(ShapeTracker(tuple(nv))),),
|
||||
(r.arg[0], tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))))
|
||||
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
|
||||
])
|
||||
|
||||
def rangeify_kernel3():
|
||||
a = Tensor.empty(N,N)
|
||||
b = Tensor.empty(N,N)
|
||||
c = a@b
|
||||
#c = c.reshape((32,2,16,4,32,2,16,4)).contiguous()
|
||||
sink = c.schedule()[-1].ast
|
||||
#print(sink)
|
||||
|
||||
opts = [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UPCAST, 0, 2)]
|
||||
opts += [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.LOCAL, 1, 16), Opt(OptOps.UPCAST, 1, 2)]
|
||||
opts += [Opt(OptOps.UNROLL, 0, 8)]
|
||||
|
||||
return sink.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
|
||||
|
||||
def top_spec_kernel3():
|
||||
a = Tensor.empty(N,N)
|
||||
b = Tensor.empty(N,N)
|
||||
c = a@b
|
||||
sink = c.schedule()[-1].ast
|
||||
L = 16
|
||||
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(N//BM, 0), 2:UOp.range(N//BN, 1)})
|
||||
sink = graph_rewrite(sink, view_left+pm)
|
||||
axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
|
||||
return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types))
|
||||
|
||||
def hl_spec_kernel3():
|
||||
nbIterWaveM = 2
|
||||
nbIterWaveN = 2
|
||||
|
||||
# define buffers
|
||||
# TODO: remove these views once the defines have a shape
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N)))
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N))).permute((1,0))
|
||||
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK, BM))).permute((1,0))
|
||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK, BN))).permute((1,0))
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1).view(ShapeTracker.from_shape((nbIterWaveN * TN,)))
|
||||
|
||||
# shape buffers. TODO: permutes
|
||||
full_shape = (N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK)
|
||||
a = a.reshape((N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, 1, 1, 1, 1, N//BK, BK)).expand(full_shape)
|
||||
b = b.reshape((1, 1, 1, 1, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK)).expand(full_shape)
|
||||
c = c.reshape((N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, 1, 1))
|
||||
As = As.reshape((1, nbIterWaveM, BM//(nbIterWaveM * TM), TM, 1, 1, 1, 1, 1, BK)).expand(full_shape)
|
||||
Bs = Bs.reshape((1, 1, 1, 1, 1, nbIterWaveN, BN//(nbIterWaveN * TN), TN, 1, BK)).expand(full_shape)
|
||||
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
|
||||
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
|
||||
|
||||
# U1 L2 L3 L4 L5 U6 U7 U9 L10 L11 L12 L13 U14 U15 U17 U18 U19
|
||||
expanded_shape = (32, 2, 2, 2, 2, 2, 2, 2, 32, 2, 2, 2, 2, 2, 2, 2, 512, 2, 2, 2)
|
||||
assert len(expanded_shape) == 20
|
||||
permute_a = list(range(len(expanded_shape)))
|
||||
permute_b = permute_a[:]
|
||||
|
||||
# this makes all the global loads match
|
||||
# this can also be more simply done by rebinding the RANGEs
|
||||
# but sadly, rebinding the RANGEs doesn't work to change the order of the local axes
|
||||
permute_a[17:20] = [11,12,13]
|
||||
permute_a[11:14] = [17,18,19]
|
||||
permute_a[7], permute_a[10] = permute_a[10], permute_a[7]
|
||||
permute_a[2:7] = [3,4,5,6,2]
|
||||
|
||||
permute_b[2:16] = [19,9,10,11,17,18,8,2,12,13,14,15,3,4]
|
||||
permute_b[17:20] = [5,6,7]
|
||||
|
||||
a_permute = a.reshape(expanded_shape).permute(tuple(permute_a)).reshape(full_shape)
|
||||
As_permute = As.reshape(expanded_shape).permute(tuple(permute_a)).reshape(full_shape)
|
||||
|
||||
b_permute = b.reshape(expanded_shape).permute(tuple(permute_b)).reshape(full_shape)
|
||||
Bs_permute = Bs.reshape(expanded_shape).permute(tuple(permute_b)).reshape(full_shape)
|
||||
|
||||
#out = (a.load() * b.load()).r(Ops.ADD, (8, 9))
|
||||
out = (As.load(As_permute.store(a_permute.load())) * Bs.load(Bs_permute.store(b_permute.load()))).r(Ops.ADD, (8, 9))
|
||||
#out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
|
||||
|
||||
axis_types = (
|
||||
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
|
||||
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
|
||||
AxisType.REDUCE, AxisType.REDUCE)
|
||||
|
||||
sink = c.store(out).sink(arg=KernelInfo(name="tg_"+to_colored(full_shape, axis_types), axis_types=axis_types))
|
||||
sink = graph_rewrite(sink, merge_views)
|
||||
return sink
|
||||
|
||||
def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
BLOCK_SIZE = 128 if kernel5 else 256
|
||||
|
||||
nbWaves = BLOCK_SIZE // 32
|
||||
WN = 128 if kernel5 else 64
|
||||
WM = BN * BM // nbWaves // WN
|
||||
|
||||
# Sanity checks (fail fast if shapes/tiles misalign)
|
||||
assert BN % WN == 0, "BN must be a multiple of WN"
|
||||
assert BM % WM == 0, "BM must be a multiple of WM"
|
||||
nbWaveX = BN // WN
|
||||
nbWaveY = BM // WM
|
||||
|
||||
threadIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("lidx0", BLOCK_SIZE))
|
||||
assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN"
|
||||
assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK"
|
||||
|
||||
assert (BN * BK) % BLOCK_SIZE == 0
|
||||
assert (BM * BK) % BLOCK_SIZE == 0
|
||||
|
||||
# ---------------------------
|
||||
# per-thread read mapping
|
||||
# ---------------------------
|
||||
# A: read BK x BN tiles; B: read BN x BK tiles
|
||||
|
||||
threadIdx_x = UOp.special(BLOCK_SIZE, "lidx0")
|
||||
waveIndex = threadIdx_x // 32
|
||||
waveIdx = waveIndex % nbWaveX
|
||||
waveIdy = waveIndex // nbWaveX
|
||||
|
|
@ -157,197 +63,122 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
|||
SUBWN = WN // nbIterWaveN
|
||||
SUBWM = WM // nbIterWaveM
|
||||
|
||||
# Thread mapping to read BKxBN block from A
|
||||
rAIdx = threadIdx_x % BK
|
||||
rAIdy = threadIdx_x // BK
|
||||
# Thread mapping to read BNxBK block from B
|
||||
rBIdx = threadIdx_x % BN
|
||||
rBIdy = threadIdx_x // BN
|
||||
# ---------------------------
|
||||
# block indices & placeholders
|
||||
# ---------------------------
|
||||
blockIdx_x = UOp.special(N // BN, "gidx0")
|
||||
blockIdx_y = UOp.special(N // BM, "gidx1")
|
||||
|
||||
strideReadB = BLOCK_SIZE // BN
|
||||
strideReadA = BLOCK_SIZE // BK
|
||||
nbReadsB = BN * BK // BLOCK_SIZE
|
||||
nbReadsA = BM * BK // BLOCK_SIZE
|
||||
a = UOp.placeholder(dtypes.float, (N, N), slot=1)
|
||||
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
|
||||
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
|
||||
|
||||
blockIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx0", N//BN))
|
||||
blockIdx_y = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx1", N//BM))
|
||||
BM_As_stride = (BM + 4) if kernel5 else BM
|
||||
As = UOp.placeholder(dtypes.float, (BK, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
|
||||
Bs = UOp.placeholder(dtypes.float, (BK, BN), slot=1, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
|
||||
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
|
||||
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
|
||||
|
||||
BM_As_stride = (BM+4) if kernel5 else BM
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM_As_stride, AddrSpace.LOCAL), arg=0)
|
||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
||||
|
||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
||||
A_col = UOp.placeholder(dtypes.float, (nbIterWaveM, TM), slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder(dtypes.float, (nbIterWaveN, TN), slot=1, addrspace=AddrSpace.REG)
|
||||
c_regs = UOp.placeholder(dtypes.float, (nbIterWaveM, TM, nbIterWaveN, TN), slot=2, addrspace=AddrSpace.REG)
|
||||
|
||||
i = UOp.range(c_regs.dtype.size, 16)
|
||||
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
|
||||
c_regs = c_regs[i].set(0.0, end=i)
|
||||
|
||||
if kernel4:
|
||||
regA = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsA, AddrSpace.REG), arg=3)
|
||||
regB = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsB, AddrSpace.REG), arg=4)
|
||||
kId_range = UOp.range(N // BK, 0)
|
||||
kId = kId_range * BK
|
||||
|
||||
# initial load from globals into locals (0)
|
||||
kId = 0
|
||||
# ---------------------------
|
||||
# GLOBAL -> LOCAL (As, Bs)
|
||||
# ---------------------------
|
||||
nbReadsB = BN * BK // BLOCK_SIZE
|
||||
i = UOp.range(nbReadsB, 1)
|
||||
rBIdx = threadIdx_x % BN
|
||||
rBIdy = threadIdx_x // BN
|
||||
strideReadB = BLOCK_SIZE // BN
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[index_y % BK, index_x % BN].store(b[index_y, index_x]).end(i)
|
||||
|
||||
# load from globals into locals
|
||||
i = UOp.range(nbReadsB, 0)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||
nbReadsA = BM * BK // BLOCK_SIZE
|
||||
i = UOp.range(nbReadsA, 2)
|
||||
rAIdx = threadIdx_x % BK
|
||||
rAIdy = threadIdx_x // BK
|
||||
strideReadA = BLOCK_SIZE // BK
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[index_x % BK, index_y % BM].store(a[index_y, index_x]).end(i)
|
||||
|
||||
i = UOp.range(nbReadsA, 1)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||
# TODO: can we automate barrier?
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
Bs = Bs.after(barrier)
|
||||
As = As.after(barrier)
|
||||
|
||||
# iterate over the middle chunk
|
||||
kId_range = UOp.range(N//BK-1, 2)
|
||||
kId = kId_range*BK
|
||||
# open inner k range
|
||||
k = UOp.range(BK, 3)
|
||||
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
# ---------------------------
|
||||
# LOCAL -> REG (per-wave tiles)
|
||||
# ---------------------------
|
||||
iterWave = UOp.range(nbIterWaveN, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row = B_row[iterWave, i].set(Bs[k, index], end=(iterWave, i))
|
||||
|
||||
# load from globals into registers (next round)
|
||||
i = UOp.range(nbReadsB, 3)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId + BK
|
||||
regB_store = regB[i].store(b[N * index_y + index_x].load(), i)
|
||||
iterWave = UOp.range(nbIterWaveM, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col = A_col[iterWave, i].set(As[k, index], end=(iterWave, i))
|
||||
|
||||
i = UOp.range(nbReadsA, 4)
|
||||
index_x = rAIdx + kId + BK
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
regA_store = regA[i].store(a[N * index_y + index_x].load(), i)
|
||||
# ---------------------------
|
||||
# FMA: c_regs += A_col * B_row
|
||||
# ---------------------------
|
||||
iterWaveM = UOp.range(nbIterWaveM, 8)
|
||||
yt = UOp.range(TM, 9)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 10)
|
||||
xt = UOp.range(TN, 12)
|
||||
c_idx = c_regs.after(k, kId_range)[iterWaveM, yt, iterWaveN, xt]
|
||||
sink = c_idx.store(c_idx + A_col[iterWaveM, yt] * B_row[iterWaveN, xt]).end(iterWaveM, iterWaveN, yt, xt)
|
||||
|
||||
def inner_loop(first_range, inp_dep=()):
|
||||
# inner unroll
|
||||
k = UOp.range(BK, first_range+0)
|
||||
# Close k, sync, and close K tiles
|
||||
sink = sink.end(k).barrier().end(kId_range)
|
||||
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(nbIterWaveN, first_range+1)
|
||||
i = UOp.range(TN, first_range+2)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i)
|
||||
|
||||
iterWave = UOp.range(nbIterWaveM, first_range+3)
|
||||
i = UOp.range(TM, first_range+4)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(nbIterWaveM, first_range+5)
|
||||
yt = UOp.range(TM, first_range+6)
|
||||
iterWaveN = UOp.range(nbIterWaveN, first_range+7)
|
||||
xt = UOp.range(TN, first_range+8)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||
# sketchy, this should end the kId_range but it doesn't
|
||||
sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
|
||||
iterWaveM, iterWaveN, yt, xt, k)
|
||||
return sink
|
||||
|
||||
# TODO: kId_range should endrange after a barrier
|
||||
sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
|
||||
|
||||
# load from registers into locals
|
||||
i = UOp.range(nbReadsB, 14)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId + BK
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
|
||||
|
||||
i = UOp.range(nbReadsA, 15)
|
||||
index_x = rAIdx + kId + BK
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
|
||||
|
||||
# final iteration without the copy
|
||||
sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
|
||||
else:
|
||||
kId_range = UOp.range(N//BK, 0)
|
||||
kId = kId_range*BK
|
||||
|
||||
# load from globals into locals
|
||||
i = UOp.range(nbReadsB, 1)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||
|
||||
i = UOp.range(nbReadsA, 2)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
|
||||
k = UOp.range(BK, 3)
|
||||
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(nbIterWaveN, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
|
||||
|
||||
iterWave = UOp.range(nbIterWaveM, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(nbIterWaveM, 8)
|
||||
yt = UOp.range(TM, 9)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 10)
|
||||
xt = UOp.range(TN, 12)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||
sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
|
||||
iterWaveM, iterWaveN, yt, xt, k, kId_range)
|
||||
|
||||
# store c_regs into c
|
||||
# ---------------------------
|
||||
# REG -> GLOBAL (epilogue)
|
||||
# ---------------------------
|
||||
iterWaveM = UOp.range(nbIterWaveM, 1000)
|
||||
yt = UOp.range(TM, 1001)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 1002)
|
||||
xt = UOp.range(TN, 1003)
|
||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||
indexC = N * (yOut + yt) + xOut + xt
|
||||
sink = c[indexC].store(c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)].load(sink),
|
||||
iterWaveM, iterWaveN, yt, xt)
|
||||
sink = c[yOut + yt, xOut + xt].store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt])
|
||||
sink = sink.end(iterWaveM, iterWaveN, yt, xt)
|
||||
|
||||
return sink.sink(arg=KernelInfo(opts_to_apply=()))
|
||||
|
||||
return sink.sink(arg=KernelInfo(name="tinygemm"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
HL = getenv("HL")
|
||||
if HL == 3: hprg = rangeify_kernel3()
|
||||
elif HL == 2: hprg = top_spec_kernel3()
|
||||
elif HL == 1: hprg = hl_spec_kernel3()
|
||||
else: hprg = hand_spec_kernel3()
|
||||
if HL == 3:
|
||||
prg = get_program(hprg, Device.default.renderer)
|
||||
else:
|
||||
prg = get_program(hprg, Device.default.renderer)
|
||||
print(prg.src)
|
||||
if getenv("SRC"): exit(0)
|
||||
hrunner = CompiledRunner(prg)
|
||||
with Context(DEBUG=0):
|
||||
a = Tensor.randn(N, N)
|
||||
b = Tensor.randn(N, N)
|
||||
hc = Tensor.empty(N, N)
|
||||
Tensor.realize(a, b, hc)
|
||||
|
||||
a = Tensor.randn(N, N).realize()
|
||||
b = Tensor.randn(N, N).realize()
|
||||
hc = Tensor.zeros(N, N).contiguous().realize()
|
||||
sink = hand_spec_kernel3()
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])
|
||||
|
||||
GlobalCounters.reset()
|
||||
ets = []
|
||||
with Context(DEBUG=2):
|
||||
for _ in range(run_count):
|
||||
ets.append(ei.run(wait=True))
|
||||
print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}")
|
||||
|
||||
GlobalCounters.reset()
|
||||
with Context(DEBUG=2):
|
||||
for _ in range(run_count): tc = (a@b).realize()
|
||||
|
||||
GlobalCounters.reset()
|
||||
buffers = [hc.uop.buffer, a.uop.buffer, b.uop.buffer]
|
||||
ei = ExecItem(hrunner, buffers)
|
||||
with Context(DEBUG=2):
|
||||
for _ in range(run_count): ei.run(wait=True)
|
||||
err = (hc-tc).square().mean().item()
|
||||
print(f"hrunner {err}")
|
||||
if err > 1e-06: raise RuntimeError("matmul is wrong!")
|
||||
tc = (a @ b).realize()
|
||||
with Context(DEBUG=0):
|
||||
err = (hc - tc).square().mean().item()
|
||||
print(f"mean squared error {err}")
|
||||
if err > 1e-06:
|
||||
raise RuntimeError("matmul is wrong!")
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ pm_add_control_flow = PatternMatcher([
|
|||
|
||||
def do_split_ends(e:UOp):
|
||||
ret = e.src[0]
|
||||
for r in list(UOp.sink(*e.src[1:]).ranges)[::-1]: ret = ret.end(r)
|
||||
for r in sorted(UOp.sink(*e.src[1:]).ranges, key=lambda x: x.arg, reverse=True): ret = ret.end(r)
|
||||
return ret
|
||||
|
||||
pm_split_ends = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class Estimates:
|
|||
if ignore_indexing:
|
||||
def range_gate(x): return x.op is not Ops.RANGE
|
||||
for u in uops:
|
||||
if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
# if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER
|
||||
dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
|
||||
# TODO: is this correct? this all needs to be cleaned up
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|||
from enum import Enum, auto
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.uop.mathtraits import MathTrait
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
|
||||
from tinygrad.helpers import strip_parens, colored
|
||||
|
|
@ -756,8 +756,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
# *** uop high level syntactic sugar ***
|
||||
|
||||
@staticmethod
|
||||
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int):
|
||||
ret = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(prod(shape)), arg=slot)
|
||||
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot)
|
||||
if len(shape) > 1: ret = ret.reshape(shape)
|
||||
return ret
|
||||
def placeholder_like(self, slot:int):
|
||||
|
|
|
|||
|
|
@ -146,8 +146,8 @@ shared_codegen_spec = PatternMatcher([
|
|||
# SPECIAL
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)),
|
||||
|
||||
# BARRIER
|
||||
(UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True),
|
||||
# BARRIER (on any length)
|
||||
(UPat(Ops.BARRIER, dtypes.void), lambda: True),
|
||||
])
|
||||
|
||||
# ***** UOp spec in kernel graph *****
|
||||
|
|
@ -156,6 +156,7 @@ kernel_spec = PatternMatcher([
|
|||
# RESHAPE (but only RESHAPE) is allowed here
|
||||
(UPat(Ops.RESHAPE, name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.RESHAPE),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.VCONST, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# index is allowed here
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue