Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
8612385ccb add all codegen stages to spec_tensor 2026-05-12 10:23:03 -07:00
2 changed files with 6 additions and 9 deletions

View file

@ -341,7 +341,6 @@ class TestCustomKernel(unittest.TestCase):
self.assertEqual(y.tolist(), [1, 2, 3, 4])
@Context(DEV="CPU")
@unittest.expectedFailure
def test_simple_from_source(self):
a = Tensor([0., 1., 2.]).realize()

View file

@ -179,9 +179,14 @@ spec_tensor = PatternMatcher([
# TODO: this should not be here. STAGE is transformed to DEFINE_LOCAL later
(UPat(Ops.STAGE, src=(UPat(),), allow_any_len=True), lambda: True),
# LINEAR
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# UNROLL/CONTRACT is used here for WMMA
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
@ -211,13 +216,6 @@ spec_full = PatternMatcher([
# codegen may end ranges after gpudims has replaced RANGE with SPECIAL.
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# allow any AFTER
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),