recursive swizzle with just graph_rewrite [pr] (#7626)

This commit is contained in:
qazal 2024-11-10 20:14:21 +02:00 committed by GitHub
commit a8da84cce0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 6 deletions

View file

@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
from tinygrad.engine.realize import CompiledRunner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from test.helpers import ast_const, timeit
@ -1626,13 +1626,14 @@ class TestIndexing(unittest.TestCase):
ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)
def test_recursive_st_fixup(self):
def test_recursive_swizzle(self):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
ast = a.schedule()[0].ast
new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {})
swizzle = ast.src[0].src[2].reshape((4, 1))
new_uop = swizzle_rewrite(swizzle)
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertLess(et, 1e3)
self.assertEqual(swizzle_cnt(new_uop), 0)
def test_strongly_connected_DAG(self):
val = 1.0

View file

@ -6,7 +6,7 @@ from dataclasses import dataclass, field
from collections import defaultdict
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, T
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
@ -311,7 +311,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st:ShapeTracker): return UOp(Ops.VIEW, self.dtype, (self,), st)
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
@ -347,6 +346,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
# *** uop movement ops ***
def view(self, st:ShapeTracker): return self if self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
# *** uop Variable stuff ***
@staticmethod