mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move that
This commit is contained in:
parent
35530b40fb
commit
b2017dec02
2 changed files with 10 additions and 4 deletions
|
|
@ -3,6 +3,7 @@ import unittest
|
|||
from dataclasses import replace
|
||||
|
||||
from tinygrad.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType
|
||||
from tinygrad.codegen import rewrites_for_views, apply_rewrites
|
||||
from tinygrad.codegen.gpudims import get_grouped_dims
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
|
|
@ -22,7 +23,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
|
|||
# now all input buffers in s[-1] should be realized
|
||||
# create fresh buffers for the outputs
|
||||
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
|
||||
return s[-1].ast, bufs
|
||||
return apply_rewrites(s[-1].ast, rewrites_for_views), bufs
|
||||
|
||||
def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1):
|
||||
a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,12 @@ class RewriteStep:
|
|||
|
||||
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
|
||||
|
||||
rewrites_for_views = [
|
||||
RewriteStep(view_left, name="view left"),
|
||||
RewriteStep(view_right, name="view right"),
|
||||
RewriteStep(cleanup_pm, name="cleanup view"),
|
||||
]
|
||||
|
||||
rewrites_for_linearizer = [
|
||||
RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
|
||||
RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
|
||||
|
|
@ -45,9 +51,8 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
|||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
ret: list[RewriteStep] = []
|
||||
|
||||
ret.append(RewriteStep(view_left, name="view left"))
|
||||
ret.append(RewriteStep(view_right, name="view right"))
|
||||
ret.append(RewriteStep(cleanup_pm, name="cleanup view"))
|
||||
# this used to be in schedule
|
||||
ret.extend(rewrites_for_views)
|
||||
|
||||
# this is kernel.py
|
||||
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue