mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
8 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cfd4c915f | ||
|
|
abb01ce6a0 | ||
|
|
6853da2a2f | ||
|
|
154ddd98fd | ||
|
|
e4ef94cf10 | ||
|
|
0c274151ad | ||
|
|
c4e32d4f63 | ||
|
|
a9d91ffcfc |
9 changed files with 112 additions and 20 deletions
|
|
@ -3,10 +3,40 @@ from tinygrad import Tensor, nn, Device
|
|||
from tinygrad.helpers import Context, GlobalCounters, CI, getenv, PCONTIG, DEBUG
|
||||
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
|
||||
from tinygrad.codegen.opt import OptOps, Opt
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
@unittest.skipUnless(Device.DEFAULT == "METAL" and not CI, "only for METAL TC")
class TestBigDoubleMatmul(unittest.TestCase):
  """Chained matmul (a @ b @ c) on 1024x1024 inputs, checked against an unoptimized reference.

  Only runs when the default device is METAL and CI is not set; skipped for NIR/PTX renderers.
  """

  def setUp(self):
    # three random square inputs, realized quietly
    size = 1024
    with Context(DEBUG=0):
      inputs = [Tensor.randn(size, size).contiguous().realize() for _ in range(3)]
      self.a, self.b, self.c = inputs
    # reference result, computed with DEBUG=2
    with Context(DEBUG=2):
      self.ref = (self.a @ self.b @ self.c).realize()

  def _test(self, opts):
    """Realize a @ b @ c with `opts` passed as the contiguous arg and compare to the reference."""
    debug_level = max(2, DEBUG.value)
    with Context(PCONTIG=2, DEBUG=debug_level):
      out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()

    # squared error vs the reference: bound both the worst element and the mean
    with Context(DEBUG=0):
      sq_err = (out-self.ref).square()
      self.assertLess(sq_err.max().item(), 1e-4)
      self.assertLess(sq_err.mean().item(), 1e-6)

  def test_demote_tc_both(self):
    # DEMOTE, then one TC opt per reduce (4th arg element 1, then 0), then two UPCASTs
    opts = (
      Opt(OptOps.DEMOTE, 2, 8),
      Opt(OptOps.TC, 0, (0, 0, 1, 1)),
      Opt(OptOps.TC, 0, (0, 0, 1, 0)),
      Opt(OptOps.UPCAST, 0, 4),
      Opt(OptOps.UPCAST, 1, 4),
      # UNROLL on axes 0 and 1 is left disabled:
      #Opt(OptOps.UNROLL, 0, 4),
      #Opt(OptOps.UNROLL, 1, 4),
    )
    self._test(opts)
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer, CUDARenderer)), "broken in LVP and PTX")
|
||||
class TestDoubleMatmul(unittest.TestCase):
|
||||
def setUp(self):
|
||||
with Context(DEBUG=0):
|
||||
|
|
@ -51,6 +81,20 @@ class TestDoubleMatmul(unittest.TestCase):
|
|||
def test_upcast_12_unroll_01(self):
|
||||
self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||
|
||||
def test_demote(self):
  # a single DEMOTE opt on axis 2 with amount 8
  opts = (Opt(OptOps.DEMOTE, 2, 8),)
  self._test(opts)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "only for METAL TC")
def test_demote_tc_top(self):
  # DEMOTE followed by one tensor-core opt; the 4th TC arg element (0) is forwarded as reduce_choice
  demote = Opt(OptOps.DEMOTE, 2, 8)
  tc = Opt(OptOps.TC, 0, (0, 0, 1, 0))
  self._test((demote, tc))
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "only for METAL TC")
def test_demote_tc_bottom(self):
  # DEMOTE followed by one tensor-core opt; the 4th TC arg element (1) is forwarded as reduce_choice
  demote = Opt(OptOps.DEMOTE, 2, 8)
  tc = Opt(OptOps.TC, 0, (0, 0, 1, 1))
  self._test((demote, tc))
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "only for METAL TC")
def test_demote_tc_both(self):
  # DEMOTE, then tensor-core opts with reduce_choice 1 and then 0
  opts = (
    Opt(OptOps.DEMOTE, 2, 8),
    Opt(OptOps.TC, 0, (0, 0, 1, 1)),
    Opt(OptOps.TC, 0, (0, 0, 1, 0)),
  )
  self._test(opts)
|
||||
|
||||
class TestRangeifyAssign(unittest.TestCase):
|
||||
def test_assign_permuted(self):
|
||||
A = Tensor.empty(4, 4, dtype='int')
|
||||
|
|
@ -114,7 +158,7 @@ def fa_bw():
|
|||
Tensor.realize(*ret)
|
||||
return ret
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer, CUDARenderer)), "broken in LVP and PTX")
|
||||
class TestPcontig(unittest.TestCase):
|
||||
def test_flash_attention_bw(self):
|
||||
with Context(PCONTIG=max(2, PCONTIG.value), DEBUG=2):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
|||
ReduceContext, correct_load_store, pm_render
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_flatten_bufferize
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
|
|
@ -46,6 +46,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
|||
# ** expander (expand_rewrite) **
|
||||
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
|
||||
|
||||
# flatten bufferize for expander
|
||||
sink = graph_rewrite(sink, pm_flatten_bufferize, name="flatten bufferize")
|
||||
|
||||
# expand
|
||||
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,10 +16,10 @@ def linearize(u:UOp) -> list[UOp]:
|
|||
in_degree[u] = len(u.src)
|
||||
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
|
||||
priority = [0] + [priorities[x] for x in consumers[u]]
|
||||
if u.op is Ops.LOAD: priority.append(-1000)
|
||||
if u.op is Ops.LOAD: priority.append(-5000)
|
||||
if u.op is Ops.BARRIER: priority.append(-1500)
|
||||
# ranges are scheduled as late as possible so anything that can be outside is
|
||||
# if u.op is Ops.RANGE: priority = [2000]
|
||||
if u.op is Ops.RANGE: priority = [2000]
|
||||
if u.op is Ops.END: priority = [-1000]
|
||||
# move defines and consts to the top
|
||||
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|||
class OptOps(Enum):
  """Kinds of kernel optimization; members are ordered by their auto-assigned value."""
  # declaration order fixes the auto() values, so keep it stable
  TC = auto()
  UPCAST = auto()
  UNROLL = auto()
  LOCAL = auto()
  THREAD = auto()
  GROUP = auto()
  GROUPTOP = auto()
  NOLOCALS = auto()
  PADTO = auto()
  SWAP = auto()
  DEMOTE = auto()

  def __lt__(self, x:OptOps):
    # compare by underlying enum value
    return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,15 @@ remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(
|
|||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
def do_demote(ctx, x:UOp, last=False):
  """Rewrite rule that threads the range held in ctx[0] into the srcs of `x`.

  ctx is a one-element mutable list holding the current range UOp. Each successful call
  replaces ctx[0] with a copy whose second-to-last arg element is incremented by one.
  Returns None when `x` already carries a tag (it was rewritten before); otherwise returns
  an APPENDINDEX UOp wrapping the rewritten `x` together with the range that was in ctx[0]
  on entry.
  """
  # tagged nodes have already been processed
  if x.tag is not None: return None

  cur_rng = ctx[0]
  # same range, with arg[-2] bumped by one; all other arg elements are kept
  next_rng = cur_rng.replace(arg=cur_rng.arg[:-2] + (cur_rng.arg[-2]+1, cur_rng.arg[-1]))
  ctx[0] = next_rng

  # the range goes at the end of the srcs when `last`, otherwise right after src[0]
  new_src = x.src+(cur_rng,) if last else (x.src[0], cur_rng)+x.src[1:]
  # tag=1 marks the node as done; inside it every use of the old range becomes the new one
  buf = x.replace(src=new_src, tag=1).substitute({cur_rng:next_rng})
  return UOp(Ops.APPENDINDEX, dtypes.void, (buf, cur_rng))
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self, ast:UOp, ren:Renderer):
|
||||
self.ast, self.ren = ast, ren
|
||||
|
|
@ -174,14 +183,35 @@ class Scheduler:
|
|||
check(not self.dont_use_locals, "can't use locals")
|
||||
check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
|
||||
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
|
||||
elif opt.op is OptOps.DEMOTE:
|
||||
_, rr = self.shift_to(rng, cast(int, opt.arg), AxisType.LOOP)
|
||||
|
||||
# do the demotion
|
||||
LAST = True
|
||||
if LAST:
|
||||
pm_demote = PatternMatcher([
|
||||
(UPat(Ops.END, src=(UPat(Ops.END, name="e1"),), allow_any_len=True, name="e2"), lambda e1,e2: e1.replace(src=e1.src+e2.src[1:])),
|
||||
(UPat(Ops.BUFFERIZE, name="x"), lambda ctx, x: do_demote(ctx, x, True)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.APPENDINDEX, name="x"),), name="y", allow_any_len=True),
|
||||
lambda x,y: y.replace(src=(x.src[0],)+y.src[1:]+x.src[1:])),
|
||||
])
|
||||
else:
|
||||
pm_demote = PatternMatcher([
|
||||
(UPat(Ops.END, src=(UPat(Ops.END, name="e1"),), allow_any_len=True, name="e2"), lambda e1,e2: e1.replace(src=e1.src+e2.src[1:])),
|
||||
(UPat(Ops.BUFFERIZE, name="x"), do_demote),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.APPENDINDEX, name="x"),), name="y", allow_any_len=True),
|
||||
lambda x,y: y.replace(src=(x.src[0],)+x.src[1:]+y.src[1:])),
|
||||
])
|
||||
self.ast = graph_rewrite(self.ast.src[0].end(rr).sink(), pm_demote, ctx=[rr], bottom_up=True, name="demote")
|
||||
elif opt.op is OptOps.TC:
|
||||
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
|
||||
#check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
|
||||
check(opt.axis is not None, "tensor core opts must have an axis")
|
||||
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
|
||||
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.ren.tensor_cores), "tensor core opts must have valid tc_select")
|
||||
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
|
||||
check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
|
||||
try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
|
||||
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) >= 3, "tensor core opts must have valid arg")
|
||||
assert isinstance(opt.arg, tuple)
|
||||
check(-1 <= (tc_select:=opt.arg[0]) < len(self.ren.tensor_cores), "tensor core opts must have valid tc_select")
|
||||
check(0 <= (tc_opt:=opt.arg[1]) <= 2, "tensor core opts must have valid tc_opt")
|
||||
check(0 < (use_tensor_cores:=opt.arg[2]) <= 2, "use_tensor_cores value is not valid")
|
||||
try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt, opt.arg[3] if len(opt.arg) > 3 else 0)
|
||||
except ValueError as e: raise KernelOptError(str(e))
|
||||
check(ret is not None, "no tensor core available")
|
||||
elif opt.op is OptOps.PADTO:
|
||||
|
|
@ -216,10 +246,10 @@ class Scheduler:
|
|||
if append_opt: self.applied_opts.append(opt)
|
||||
return ret
|
||||
|
||||
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
|
||||
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int, reduce_choice:int) -> None|list[UOp]:
|
||||
reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
|
||||
if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore")
|
||||
reduceop = reduceops[0]
|
||||
reduceop = reduceops[reduce_choice]
|
||||
if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
|
||||
mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
|
||||
if mul.op is not Ops.MUL: return None
|
||||
|
|
@ -235,8 +265,8 @@ class Scheduler:
|
|||
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
|
||||
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
|
||||
if DEBUG >= 3:
|
||||
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
|
||||
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
|
||||
print(f"TC({axis}, {reduce_choice}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
|
||||
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
|
||||
if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue
|
||||
|
||||
# pick ranges
|
||||
|
|
@ -259,21 +289,27 @@ class Scheduler:
|
|||
axes[i] = self.rngs[idx]
|
||||
except KernelOptError: continue
|
||||
|
||||
upcast_ranges = []
|
||||
reduce_ranges = []
|
||||
|
||||
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
|
||||
warp = UOp.range(tc.threads, -1, AxisType.WARP)
|
||||
warp_num = 0
|
||||
ne: list[UOp] = []
|
||||
for opt in tc.opts:
|
||||
if opt[0] == "l":
|
||||
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
|
||||
warp //= 2
|
||||
warp = UOp.range(2, -1, warp_num, AxisType.WARP)
|
||||
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp)
|
||||
warp_num += 1
|
||||
elif opt[0] == "u":
|
||||
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
|
||||
upcast_ranges.append(new_range)
|
||||
else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
|
||||
ne.append(new_range)
|
||||
|
||||
for _, amt in tc.get_reduce_axes():
|
||||
axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
|
||||
ne.append(new_range)
|
||||
reduce_ranges.append(new_range)
|
||||
|
||||
if use_tensor_cores != 2:
|
||||
# fix the srcs
|
||||
|
|
@ -291,6 +327,12 @@ class Scheduler:
|
|||
# axes to range number (was done in lowerer)
|
||||
tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
|
||||
tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
|
||||
print(tc_reduce_axes, tc_upcast_axes)
|
||||
|
||||
# DIRECT: get range number from ranges
|
||||
tc_upcast_axes = (((upcast_ranges[0].arg[0], 2),), ((upcast_ranges[0].arg[0], 2),), ((upcast_ranges[0].arg[0], 2),))
|
||||
tc_reduce_axes = tuple([x.arg[0] for x in reduce_ranges])
|
||||
#print(tc_reduce_axes, tc_upcast_axes)
|
||||
|
||||
# construct the op
|
||||
# TODO: remove tc_upcast_axes from the arg
|
||||
|
|
|
|||
|
|
@ -345,7 +345,7 @@ def flatten_bufferize(x:UOp):
|
|||
sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
|
||||
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
|
||||
return ret.rtag(x.tag)
|
||||
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
|
||||
pm_flatten_bufferize = pm_mops+PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
|
||||
|
||||
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda x, idx: bufferize_to_store(x, idx, allow_locals=False)),
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ class Ops(FastEnum):
|
|||
|
||||
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
|
||||
INDEX = auto()
|
||||
APPENDINDEX = auto()
|
||||
|
||||
# BinaryOps
|
||||
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
|
||||
|
|
|
|||
|
|
@ -188,7 +188,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT | Ops.APPENDINDEX:
|
||||
return None
|
||||
|
||||
# some ops init the shape
|
||||
|
|
|
|||
|
|
@ -233,6 +233,7 @@ full_spec = PatternMatcher([
|
|||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
(UPat(Ops.INDEX), lambda: True),
|
||||
(UPat(Ops.APPENDINDEX), lambda: True),
|
||||
|
||||
# all loads/stores
|
||||
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue