mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
10 commits
master
...
swizzle_in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2017dec02 | ||
|
|
35530b40fb |
||
|
|
259f207af6 |
||
|
|
c632329943 |
||
|
|
9c5e2b0113 | ||
|
|
8a6f0f49ef | ||
|
|
2b90cdb5d6 | ||
|
|
23eca7fc54 | ||
|
|
3f71dd5fd0 | ||
|
|
7116948f34 |
6 changed files with 43 additions and 27 deletions
|
|
@ -3,6 +3,7 @@ import unittest
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
|
||||||
from tinygrad.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType
|
from tinygrad.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType
|
||||||
|
from tinygrad.codegen import rewrites_for_views, apply_rewrites
|
||||||
from tinygrad.codegen.gpudims import get_grouped_dims
|
from tinygrad.codegen.gpudims import get_grouped_dims
|
||||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo
|
from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo
|
||||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||||
|
|
@ -22,7 +23,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
|
||||||
# now all input buffers in s[-1] should be realized
|
# now all input buffers in s[-1] should be realized
|
||||||
# create fresh buffers for the outputs
|
# create fresh buffers for the outputs
|
||||||
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
|
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
|
||||||
return s[-1].ast, bufs
|
return apply_rewrites(s[-1].ast, rewrites_for_views), bufs
|
||||||
|
|
||||||
def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1):
|
def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1):
|
||||||
a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)
|
a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
|
||||||
ReduceContext, correct_load_store, pm_render
|
ReduceContext, correct_load_store, pm_render
|
||||||
from tinygrad.codegen.optional import get_late_rewrite_patterns
|
from tinygrad.codegen.optional import get_late_rewrite_patterns
|
||||||
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||||
|
from tinygrad.opt.swizzler import view_left, view_right, cleanup_pm
|
||||||
from tinygrad.opt import pm_optimize
|
from tinygrad.opt import pm_optimize
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -29,6 +30,12 @@ class RewriteStep:
|
||||||
|
|
||||||
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
|
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
|
||||||
|
|
||||||
|
rewrites_for_views = [
|
||||||
|
RewriteStep(view_left, name="view left"),
|
||||||
|
RewriteStep(view_right, name="view right"),
|
||||||
|
RewriteStep(cleanup_pm, name="cleanup view"),
|
||||||
|
]
|
||||||
|
|
||||||
rewrites_for_linearizer = [
|
rewrites_for_linearizer = [
|
||||||
RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
|
RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
|
||||||
RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
|
RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
|
||||||
|
|
@ -44,6 +51,9 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
||||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||||
ret: list[RewriteStep] = []
|
ret: list[RewriteStep] = []
|
||||||
|
|
||||||
|
# this used to be in schedule
|
||||||
|
ret.extend(rewrites_for_views)
|
||||||
|
|
||||||
# this is kernel.py
|
# this is kernel.py
|
||||||
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
|
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
|
||||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||||
from tinygrad.schedule.kernelize import view_left
|
from tinygrad.opt.swizzler import view_left
|
||||||
|
|
||||||
class OptOps(Enum):
|
class OptOps(Enum):
|
||||||
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||||
|
|
@ -73,7 +73,7 @@ class Kernel:
|
||||||
self.sts.append(unwrap(x.src[0].st))
|
self.sts.append(unwrap(x.src[0].st))
|
||||||
|
|
||||||
# add a shapetracker to the end to track the full shape, with 0 strides so it can merge
|
# add a shapetracker to the end to track the full shape, with 0 strides so it can merge
|
||||||
full_shape = ast.full_shape
|
full_shape = self.ast.full_shape
|
||||||
self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape)))
|
self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape)))
|
||||||
|
|
||||||
# parameters for optimization
|
# parameters for optimization
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
||||||
from tinygrad.helpers import all_same, prod, unwrap
|
from tinygrad.helpers import all_same, prod, unwrap, colored
|
||||||
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||||
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
||||||
|
|
@ -100,3 +101,25 @@ view_right = merge_views+PatternMatcher([
|
||||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def check_load_st(glbl:UOp, view:UOp):
|
||||||
|
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
||||||
|
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
||||||
|
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
||||||
|
# if it has a single view and it's equal when you shrink a contig, it's fine
|
||||||
|
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
||||||
|
# otherwise, it's not fine
|
||||||
|
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||||
|
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||||
|
|
||||||
|
cleanup_pm = PatternMatcher([
|
||||||
|
# add VIEW to any DEFINE_GLOBAL that somehow lost its view
|
||||||
|
(UPat(Ops.STORE, src=(UPat(Ops.DEFINE_GLOBAL, name="d"),), name="x", allow_any_len=True), lambda d,x: x.replace(src=(d.view(d.st),)+x.src[1:])),
|
||||||
|
# VALID
|
||||||
|
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||||
|
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||||
|
# VIEW on SINK is SINK
|
||||||
|
(UPat(Ops.VIEW, name="v").sink(), lambda v: v.src[0].sink()),
|
||||||
|
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||||
|
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
||||||
|
])
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
|
|
||||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
|
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL}
|
||||||
|
|
||||||
# **** Grouper decides which of the UOps realize
|
# **** Grouper decides which of the UOps realize
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
|
||||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||||
from tinygrad.uop.symbolic import symbolic_simple
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||||
from tinygrad.dtype import ImageDType, dtypes
|
from tinygrad.dtype import ImageDType
|
||||||
from tinygrad.schedule.multi import multi_pm
|
from tinygrad.schedule.multi import multi_pm
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
|
||||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||||
from tinygrad.opt.swizzler import merge_views, view_left, view_right, apply_swizzle, swizzle_reduceop
|
from tinygrad.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
||||||
|
|
||||||
# creation can recurse a lot
|
# creation can recurse a lot
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -159,29 +158,14 @@ add_buffer_ops = PatternMatcher([
|
||||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||||
# passthrough ASSIGN
|
# passthrough ASSIGN
|
||||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||||
# VALID
|
|
||||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
|
||||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
def check_load_st(glbl:UOp, view:UOp):
|
|
||||||
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
|
||||||
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
|
||||||
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
|
||||||
# if it has a single view and it's equal when you shrink a contig, it's fine
|
|
||||||
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
|
||||||
# otherwise, it's not fine
|
|
||||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
|
||||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
|
||||||
|
|
||||||
fix_kernel_ops = PatternMatcher([
|
fix_kernel_ops = PatternMatcher([
|
||||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||||
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
||||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||||
# no ImageDType after index
|
# no ImageDType after index
|
||||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
|
||||||
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
|
||||||
])
|
])
|
||||||
|
|
||||||
replace_globals = PatternMatcher([
|
replace_globals = PatternMatcher([
|
||||||
|
|
@ -195,8 +179,6 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||||
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
|
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
|
||||||
# replace global memory ops with the BUFFER they write to
|
# replace global memory ops with the BUFFER they write to
|
||||||
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
||||||
# push views to edges
|
|
||||||
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
|
|
||||||
# replace buffer with define_global + add load/store last
|
# replace buffer with define_global + add load/store last
|
||||||
bufs = []
|
bufs = []
|
||||||
for s in k.src:
|
for s in k.src:
|
||||||
|
|
@ -204,7 +186,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||||
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
|
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
|
||||||
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
||||||
bufs.append(s)
|
bufs.append(s)
|
||||||
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
||||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
||||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue