mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
recursive swizzle with just graph_rewrite [pr] (#7626)
This commit is contained in:
parent
7275cfb9d8
commit
a8da84cce0
2 changed files with 11 additions and 6 deletions
|
|
@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
|||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
|
||||
from test.helpers import ast_const, timeit
|
||||
|
|
@ -1626,13 +1626,14 @@ class TestIndexing(unittest.TestCase):
|
|||
ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
|
||||
np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_recursive_st_fixup(self):
|
||||
def test_recursive_swizzle(self):
|
||||
a = Tensor([1,2,3,4]).realize()
|
||||
for _ in range(24): a = a + a
|
||||
ast = a.schedule()[0].ast
|
||||
new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {})
|
||||
swizzle = ast.src[0].src[2].reshape((4, 1))
|
||||
new_uop = swizzle_rewrite(swizzle)
|
||||
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
|
||||
self.assertLess(et, 1e3)
|
||||
self.assertEqual(swizzle_cnt(new_uop), 0)
|
||||
|
||||
def test_strongly_connected_DAG(self):
|
||||
val = 1.0
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from dataclasses import dataclass, field
|
|||
from collections import defaultdict
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, T
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
|
|
@ -311,7 +311,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
return ret
|
||||
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
||||
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st:ShapeTracker): return UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
|
||||
def broadcast(self, count:int):
|
||||
assert self.dtype.count == 1
|
||||
|
|
@ -347,6 +346,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
def view(self, st:ShapeTracker): return self if self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
|
||||
|
||||
# *** uop Variable stuff ***
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue