mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
revert-115
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ede77facbb |
5 changed files with 27 additions and 8 deletions
|
|
@ -2066,6 +2066,7 @@ class TestSwizzle(unittest.TestCase):
|
|||
np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
|
||||
|
||||
# kernels can only have 1 or n in each dim
|
||||
@unittest.expectedFailure
|
||||
def test_dont_parallelize_different_n(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 2, 2).realize()
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ class TestFuse(unittest.TestCase):
|
|||
with Context(NOOPT=1):
|
||||
self._test_fuse(Tensor.scaled_dot_product_attention, q, k, v, atol=1e-5)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_mismatch_reduce(self):
|
||||
a = Tensor.ones(16, 10).contiguous().realize()
|
||||
b = Tensor.ones(16, 20).contiguous().realize()
|
||||
|
|
|
|||
|
|
@ -34,6 +34,17 @@ class TestUOpSpec(unittest.TestCase):
|
|||
store = UOp(Ops.STORE, dtypes.void, (buf_0.view(ShapeTracker.from_shape((32, 1))), a+b))
|
||||
helper_test_verify_ast(store)
|
||||
|
||||
def test_exactly_one_full_shape(self):
|
||||
dtype = dtypes.int
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
|
||||
a = UOp(Ops.LOAD, dtype, (bufs[2].view(ShapeTracker.from_shape((32, 1))),))
|
||||
b = UOp(Ops.LOAD, dtype, (bufs[3].view(ShapeTracker.from_shape((32, 1))),))
|
||||
st0 = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 1))), a+b)
|
||||
a = UOp(Ops.LOAD, dtype, (bufs[4].view(ShapeTracker.from_shape((32, 32))),))
|
||||
b = UOp(Ops.LOAD, dtype, (bufs[5].view(ShapeTracker.from_shape((32, 32))),))
|
||||
st1 = UOp.store(bufs[1].view(ShapeTracker.from_shape((32, 32))), a+b)
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
|
||||
|
||||
def test_no_implicit_broadcasting(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((4, 32))),))
|
||||
|
|
|
|||
|
|
@ -179,10 +179,8 @@ class Kernel:
|
|||
self.axis_types.insert(insert_at, new_type)
|
||||
move_axis = axis if top else axis+1
|
||||
if move_axis < insert_at: insert_at += 1
|
||||
def new_shape_fxn(x): return x[0:axis] + (((amount,x[axis]//amount) if top else (x[axis]//amount,amount)) if x[axis] > 1 else (1,1)) + x[axis+1:]
|
||||
new_axes = [i for i in range(insert_at) if i != move_axis]+[move_axis]+[i for i in range(insert_at, self.shape_len+1) if i != move_axis]
|
||||
def new_shape_fxn(x:tuple[sint, ...]):
|
||||
amt = amount if amount != 0 else x[axis]
|
||||
return x[0:axis] + (((amt,x[axis]//amt) if top else (x[axis]//amt,amt)) if resolve(x[axis] > 1) else (1,1)) + x[axis+1:]
|
||||
self.reshape(new_shape_fxn)
|
||||
self.permute(new_axes)
|
||||
|
||||
|
|
@ -271,11 +269,9 @@ class Kernel:
|
|||
if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg)) # arg is an axis in the SWAPs
|
||||
elif opt.arg is not None:
|
||||
check(isinstance(opt.arg, int), "arg should be int")
|
||||
amt = cast(int, opt.arg)
|
||||
amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
|
||||
check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
|
||||
if opt.op is not OptOps.PADTO:
|
||||
for st in self.sts: check(st.shape[axis] == 1 or amt == 0 or st.shape[axis] % amt == 0,
|
||||
f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
|
||||
if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
|
||||
else: amt = -1
|
||||
|
||||
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import cast, Callable
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, resolve
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
|
@ -206,7 +206,17 @@ spec = PatternMatcher([
|
|||
])
|
||||
|
||||
# *** this is the UOp AST spec ***
|
||||
|
||||
def verify_sink_dims(sink:UOp):
|
||||
if not all_same([s.shape for s in sink.src]): return False
|
||||
for dims in zip(*[x.shape for x in sink.toposort() if x.op is Ops.VIEW]):
|
||||
if len(n_dims:={s for s in dims if resolve(s!=1)}) > 1:
|
||||
print(f"# INVALID KERNEL DIMS: can only have 1 or n in each dimension: {n_dims}")
|
||||
return False
|
||||
|
||||
ast_spec = PatternMatcher([
|
||||
# shapes must have either 1 or n in each dimension
|
||||
(UPat(Ops.SINK, src=UPat(Ops.STORE), name="sink"), verify_sink_dims),
|
||||
# VIEW can only exist in the edges
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
|
||||
(UPat(Ops.VIEW, name="view"), lambda view: len(view.src) == 0),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue