Compare commits

...

8 commits

Author SHA1 Message Date
George Hotz
3cfd4c915f warp_num 2025-10-29 19:08:10 +08:00
George Hotz
abb01ce6a0 last 2025-10-29 18:57:38 +08:00
George Hotz
6853da2a2f works 2025-10-29 18:45:16 +08:00
George Hotz
154ddd98fd it works, it's just slow... 2025-10-29 18:35:44 +08:00
George Hotz
e4ef94cf10 both is broken 2025-10-29 18:15:38 +08:00
George Hotz
0c274151ad tensor core works 2025-10-29 18:06:05 +08:00
George Hotz
c4e32d4f63 demote op 2025-10-29 17:47:12 +08:00
George Hotz
a9d91ffcfc DEMOTE op for putting globals in locals 2025-10-29 17:22:59 +08:00
9 changed files with 112 additions and 20 deletions

View file

@ -3,10 +3,40 @@ from tinygrad import Tensor, nn, Device
from tinygrad.helpers import Context, GlobalCounters, CI, getenv, PCONTIG, DEBUG
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
from tinygrad.codegen.opt import OptOps, Opt
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
@unittest.skipUnless(Device.DEFAULT == "METAL" and not CI, "only for METAL TC")
class TestBigDoubleMatmul(unittest.TestCase):
def setUp(self):
N = 1024
with Context(DEBUG=0):
self.a, self.b, self.c = [Tensor.randn(N, N).contiguous().realize() for _ in range(3)]
with Context(DEBUG=2):
self.ref = (self.a @ self.b @ self.c).realize()
def _test(self, opts):
with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()
with Context(DEBUG=0):
err = (out-self.ref).square()
self.assertLess(err.max().item(), 1e-4)
self.assertLess(err.mean().item(), 1e-6)
def test_demote_tc_both(self):
outs = ()
outs += (Opt(OptOps.DEMOTE, 2, 8),)
outs += (Opt(OptOps.TC, 0, (0, 0, 1, 1)),)
outs += (Opt(OptOps.TC, 0, (0, 0, 1, 0)),)
outs += (Opt(OptOps.UPCAST, 0, 4),)
outs += (Opt(OptOps.UPCAST, 1, 4),)
#outs += (Opt(OptOps.UNROLL, 0, 4),)
#outs += (Opt(OptOps.UNROLL, 1, 4),)
self._test(outs)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer, CUDARenderer)), "broken in LVP and PTX")
class TestDoubleMatmul(unittest.TestCase):
def setUp(self):
with Context(DEBUG=0):
@ -51,6 +81,20 @@ class TestDoubleMatmul(unittest.TestCase):
def test_upcast_12_unroll_01(self):
self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
def test_demote(self): self._test((Opt(OptOps.DEMOTE, 2, 8),))
@unittest.skipUnless(Device.DEFAULT == "METAL", "only for METAL TC")
def test_demote_tc_top(self):
self._test((Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 0))))
@unittest.skipUnless(Device.DEFAULT == "METAL", "only for METAL TC")
def test_demote_tc_bottom(self):
self._test((Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 1))))
@unittest.skipUnless(Device.DEFAULT == "METAL", "only for METAL TC")
def test_demote_tc_both(self):
self._test((Opt(OptOps.DEMOTE, 2, 8), Opt(OptOps.TC, 0, (0, 0, 1, 1)), Opt(OptOps.TC, 0, (0, 0, 1, 0))))
class TestRangeifyAssign(unittest.TestCase):
def test_assign_permuted(self):
A = Tensor.empty(4, 4, dtype='int')
@ -114,7 +158,7 @@ def fa_bw():
Tensor.realize(*ret)
return ret
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer, CUDARenderer)), "broken in LVP and PTX")
class TestPcontig(unittest.TestCase):
def test_flash_attention_bw(self):
with Context(PCONTIG=max(2, PCONTIG.value), DEBUG=2):

View file

@ -15,7 +15,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_flatten_bufferize
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
@ -46,6 +46,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
# flatten bufferize for expander
sink = graph_rewrite(sink, pm_flatten_bufferize, name="flatten bufferize")
# expand
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")

View file

@ -16,10 +16,10 @@ def linearize(u:UOp) -> list[UOp]:
in_degree[u] = len(u.src)
# put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
priority = [0] + [priorities[x] for x in consumers[u]]
if u.op is Ops.LOAD: priority.append(-1000)
if u.op is Ops.LOAD: priority.append(-5000)
if u.op is Ops.BARRIER: priority.append(-1500)
# ranges are scheduled as late as possible so anything that can be outside is
# if u.op is Ops.RANGE: priority = [2000]
if u.op is Ops.RANGE: priority = [2000]
if u.op is Ops.END: priority = [-1000]
# move defines and consts to the top
if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}: priority.append(-2000)

View file

@ -6,6 +6,7 @@ from dataclasses import dataclass
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); THREAD = auto() # noqa: E702
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
DEMOTE = auto()
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)

View file

@ -16,6 +16,15 @@ remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
def do_demote(ctx, x:UOp, last=False):
if x.tag is not None: return None
mr = ctx[0]
nr = mr.replace(arg=ctx[0].arg[0:-2]+(mr.arg[-2]+1, mr.arg[-1]))
ctx[0] = nr
if last: buf = x.replace(src=x.src+(mr,), tag=1).substitute({mr:nr})
else: buf = x.replace(src=(x.src[0], mr)+x.src[1:], tag=1).substitute({mr:nr})
return UOp(Ops.APPENDINDEX, dtypes.void, (buf,mr))
class Scheduler:
def __init__(self, ast:UOp, ren:Renderer):
self.ast, self.ren = ast, ren
@ -174,14 +183,35 @@ class Scheduler:
check(not self.dont_use_locals, "can't use locals")
check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
elif opt.op is OptOps.DEMOTE:
_, rr = self.shift_to(rng, cast(int, opt.arg), AxisType.LOOP)
# do the demotion
LAST = True
if LAST:
pm_demote = PatternMatcher([
(UPat(Ops.END, src=(UPat(Ops.END, name="e1"),), allow_any_len=True, name="e2"), lambda e1,e2: e1.replace(src=e1.src+e2.src[1:])),
(UPat(Ops.BUFFERIZE, name="x"), lambda ctx, x: do_demote(ctx, x, True)),
(UPat(Ops.INDEX, src=(UPat(Ops.APPENDINDEX, name="x"),), name="y", allow_any_len=True),
lambda x,y: y.replace(src=(x.src[0],)+y.src[1:]+x.src[1:])),
])
else:
pm_demote = PatternMatcher([
(UPat(Ops.END, src=(UPat(Ops.END, name="e1"),), allow_any_len=True, name="e2"), lambda e1,e2: e1.replace(src=e1.src+e2.src[1:])),
(UPat(Ops.BUFFERIZE, name="x"), do_demote),
(UPat(Ops.INDEX, src=(UPat(Ops.APPENDINDEX, name="x"),), name="y", allow_any_len=True),
lambda x,y: y.replace(src=(x.src[0],)+x.src[1:]+y.src[1:])),
])
self.ast = graph_rewrite(self.ast.src[0].end(rr).sink(), pm_demote, ctx=[rr], bottom_up=True, name="demote")
elif opt.op is OptOps.TC:
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
#check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
check(opt.axis is not None, "tensor core opts must have an axis")
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.ren.tensor_cores), "tensor core opts must have valid tc_select")
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) >= 3, "tensor core opts must have valid arg")
assert isinstance(opt.arg, tuple)
check(-1 <= (tc_select:=opt.arg[0]) < len(self.ren.tensor_cores), "tensor core opts must have valid tc_select")
check(0 <= (tc_opt:=opt.arg[1]) <= 2, "tensor core opts must have valid tc_opt")
check(0 < (use_tensor_cores:=opt.arg[2]) <= 2, "use_tensor_cores value is not valid")
try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt, opt.arg[3] if len(opt.arg) > 3 else 0)
except ValueError as e: raise KernelOptError(str(e))
check(ret is not None, "no tensor core available")
elif opt.op is OptOps.PADTO:
@ -216,10 +246,10 @@ class Scheduler:
if append_opt: self.applied_opts.append(opt)
return ret
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int, reduce_choice:int) -> None|list[UOp]:
reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore")
reduceop = reduceops[0]
reduceop = reduceops[reduce_choice]
if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
if mul.op is not Ops.MUL: return None
@ -235,8 +265,8 @@ class Scheduler:
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
if DEBUG >= 3:
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
print(f"TC({axis}, {reduce_choice}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue
# pick ranges
@ -259,21 +289,27 @@ class Scheduler:
axes[i] = self.rngs[idx]
except KernelOptError: continue
upcast_ranges = []
reduce_ranges = []
# we create the warp as a whole thing, in case some of these ranges are moved/removed later
warp = UOp.range(tc.threads, -1, AxisType.WARP)
warp_num = 0
ne: list[UOp] = []
for opt in tc.opts:
if opt[0] == "l":
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
warp //= 2
warp = UOp.range(2, -1, warp_num, AxisType.WARP)
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp)
warp_num += 1
elif opt[0] == "u":
axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
upcast_ranges.append(new_range)
else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
ne.append(new_range)
for _, amt in tc.get_reduce_axes():
axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
ne.append(new_range)
reduce_ranges.append(new_range)
if use_tensor_cores != 2:
# fix the srcs
@ -291,6 +327,12 @@ class Scheduler:
# axes to range number (was done in lowerer)
tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
print(tc_reduce_axes, tc_upcast_axes)
# DIRECT: get range number from ranges
tc_upcast_axes = (((upcast_ranges[0].arg[0], 2),), ((upcast_ranges[0].arg[0], 2),), ((upcast_ranges[0].arg[0], 2),))
tc_reduce_axes = tuple([x.arg[0] for x in reduce_ranges])
#print(tc_reduce_axes, tc_upcast_axes)
# construct the op
# TODO: remove tc_upcast_axes from the arg

View file

@ -345,7 +345,7 @@ def flatten_bufferize(x:UOp):
sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
return ret.rtag(x.tag)
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
pm_flatten_bufferize = pm_mops+PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda x, idx: bufferize_to_store(x, idx, allow_locals=False)),

View file

@ -59,6 +59,7 @@ class Ops(FastEnum):
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
INDEX = auto()
APPENDINDEX = auto()
# BinaryOps
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702

View file

@ -188,7 +188,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.INDEX | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST:
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT | Ops.APPENDINDEX:
return None
# some ops init the shape

View file

@ -233,6 +233,7 @@ full_spec = PatternMatcher([
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),
(UPat(Ops.INDEX), lambda: True),
(UPat(Ops.APPENDINDEX), lambda: True),
# all loads/stores
(UPat((Ops.LOAD, Ops.STORE)), lambda: True),