move that

This commit is contained in:
George Hotz 2025-08-05 16:23:33 -07:00
commit b2017dec02
2 changed files with 10 additions and 4 deletions

View file

@ -3,6 +3,7 @@ import unittest
from dataclasses import replace
from tinygrad.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType
from tinygrad.codegen import rewrites_for_views, apply_rewrites
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo
from tinygrad.device import Device, Buffer, is_dtype_supported
@ -22,7 +23,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
# now all input buffers in s[-1] should be realized
# create fresh buffers for the outputs
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
return s[-1].ast, bufs
return apply_rewrites(s[-1].ast, rewrites_for_views), bufs
def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1):
a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)

View file

@ -30,6 +30,12 @@ class RewriteStep:
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
rewrites_for_views = [
RewriteStep(view_left, name="view left"),
RewriteStep(view_right, name="view right"),
RewriteStep(cleanup_pm, name="cleanup view"),
]
rewrites_for_linearizer = [
RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
@ -45,9 +51,8 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
# ** lowerer (rewrite_shapetracker_with_index) **
ret: list[RewriteStep] = []
ret.append(RewriteStep(view_left, name="view left"))
ret.append(RewriteStep(view_right, name="view right"))
ret.append(RewriteStep(cleanup_pm, name="cleanup view"))
# this used to be in schedule
ret.extend(rewrites_for_views)
# this is kernel.py
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))