mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
user_failu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8a3d451ee4 |
1 changed files with 10 additions and 0 deletions
|
|
@ -1,11 +1,21 @@
|
||||||
# ruff: noqa: E501
|
# ruff: noqa: E501
|
||||||
import unittest
|
import unittest
|
||||||
|
from tinygrad import Tensor
|
||||||
from tinygrad.uop.ops import UOp, Ops, AxisType
|
from tinygrad.uop.ops import UOp, Ops, AxisType
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.engine.realize import get_program
|
from tinygrad.engine.realize import get_program
|
||||||
from tinygrad.device import Device
|
from tinygrad.device import Device
|
||||||
|
|
||||||
class TestLinearizerFailures(unittest.TestCase):
|
class TestLinearizerFailures(unittest.TestCase):
|
||||||
|
def test_cumsum_repeat_reshape_multiply(self):
|
||||||
|
# cumsum + repeat + reshape + multiply fails when step > 512
|
||||||
|
step, num_steps = 513, 10
|
||||||
|
t = Tensor.arange(step).float()
|
||||||
|
phase = t.cumsum()
|
||||||
|
tiled = phase.repeat((num_steps,)).reshape(num_steps, step)
|
||||||
|
pattern = Tensor([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)
|
||||||
|
result = (tiled * pattern).flatten()
|
||||||
|
result.numpy() # should not raise AssertionError in CFGContext
|
||||||
def test_fail_1(self):
|
def test_fail_1(self):
|
||||||
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0, src=())
|
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=0, src=())
|
||||||
c1 = UOp.range(UOp.const(dtypes.index, 2), 1, AxisType.LOOP)
|
c1 = UOp.range(UOp.const(dtypes.index, 2), 1, AxisType.LOOP)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue