Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
ede77facbb
Revert "fix mismatch reduce (#11547)"
This reverts commit 49d21a9055.
2025-08-06 22:42:34 -07:00
5 changed files with 27 additions and 8 deletions

View file

@ -2066,6 +2066,7 @@ class TestSwizzle(unittest.TestCase):
np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
# kernels can only have 1 or n in each dim
@unittest.expectedFailure
def test_dont_parallelize_different_n(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 2, 2).realize()

View file

@ -111,6 +111,7 @@ class TestFuse(unittest.TestCase):
with Context(NOOPT=1):
self._test_fuse(Tensor.scaled_dot_product_attention, q, k, v, atol=1e-5)
@unittest.expectedFailure
def test_mismatch_reduce(self):
a = Tensor.ones(16, 10).contiguous().realize()
b = Tensor.ones(16, 20).contiguous().realize()

View file

@ -34,6 +34,17 @@ class TestUOpSpec(unittest.TestCase):
store = UOp(Ops.STORE, dtypes.void, (buf_0.view(ShapeTracker.from_shape((32, 1))), a+b))
helper_test_verify_ast(store)
def test_exactly_one_full_shape(self):
  # Two stores whose views have different shapes, (32, 1) and (32, 32),
  # are handed to helper_test_verify_ast together; that must raise InvalidASTException.
  int_t = dtypes.int
  globals_ = [UOp(Ops.DEFINE_GLOBAL, int_t.ptr(), (), idx) for idx in range(6)]
  def viewed(idx, shape): return globals_[idx].view(ShapeTracker.from_shape(shape))
  def load(idx, shape): return UOp(Ops.LOAD, int_t, (viewed(idx, shape),))
  narrow, wide = (32, 1), (32, 32)
  # buffers 2 and 3 feed the narrow store into buffer 0
  narrow_store = UOp.store(viewed(0, narrow), load(2, narrow)+load(3, narrow))
  # buffers 4 and 5 feed the wide store into buffer 1
  wide_store = UOp.store(viewed(1, wide), load(4, wide)+load(5, wide))
  with self.assertRaises(InvalidASTException):
    helper_test_verify_ast(narrow_store, wide_store)
def test_no_implicit_broadcasting(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((4, 32))),))

View file

@ -179,10 +179,8 @@ class Kernel:
self.axis_types.insert(insert_at, new_type)
move_axis = axis if top else axis+1
if move_axis < insert_at: insert_at += 1
def new_shape_fxn(x): return x[0:axis] + (((amount,x[axis]//amount) if top else (x[axis]//amount,amount)) if x[axis] > 1 else (1,1)) + x[axis+1:]
new_axes = [i for i in range(insert_at) if i != move_axis]+[move_axis]+[i for i in range(insert_at, self.shape_len+1) if i != move_axis]
def new_shape_fxn(x:tuple[sint, ...]):
amt = amount if amount != 0 else x[axis]
return x[0:axis] + (((amt,x[axis]//amt) if top else (x[axis]//amt,amt)) if resolve(x[axis] > 1) else (1,1)) + x[axis+1:]
self.reshape(new_shape_fxn)
self.permute(new_axes)
@ -271,11 +269,9 @@ class Kernel:
if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg)) # arg is an axis in the SWAPs
elif opt.arg is not None:
check(isinstance(opt.arg, int), "arg should be int")
amt = cast(int, opt.arg)
amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
if opt.op is not OptOps.PADTO:
for st in self.sts: check(st.shape[axis] == 1 or amt == 0 or st.shape[axis] % amt == 0,
f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
else: amt = -1
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \

View file

@ -1,5 +1,5 @@
from typing import cast, Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, resolve
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
from tinygrad.shape.shapetracker import ShapeTracker
@ -206,7 +206,17 @@ spec = PatternMatcher([
])
# *** this is the UOp AST spec ***
def verify_sink_dims(sink:UOp):
  """Check that the shapes under a sink agree with each other.

  Returns False when the sink's sources do not all share one shape, or when
  some axis of the VIEW shapes holds more than one distinct size other than 1
  (a diagnostic line is printed in that case). When neither check fails the
  function falls off the end and returns None rather than True.
  NOTE(review): the None result looks deliberate for the caller - confirm.
  """
  store_shapes = [store.shape for store in sink.src]
  if not all_same(store_shapes): return False
  view_shapes = [node.shape for node in sink.toposort() if node.op is Ops.VIEW]
  # zip(*shapes) walks the collected shapes axis by axis (stopping at the shortest shape)
  for axis_sizes in zip(*view_shapes):
    non_unit = {size for size in axis_sizes if resolve(size!=1)}
    if len(non_unit) > 1:
      print(f"# INVALID KERNEL DIMS: can only have 1 or n in each dimension: {non_unit}")
      return False
ast_spec = PatternMatcher([
# shapes must have either 1 or n in each dimension
(UPat(Ops.SINK, src=UPat(Ops.STORE), name="sink"), verify_sink_dims),
# VIEW can only exist in the edges
(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
(UPat(Ops.VIEW, name="view"), lambda view: len(view.src) == 0),