mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
view_in_co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fcdd2a480 | ||
|
|
00d33d706e |
5 changed files with 43 additions and 41 deletions
|
|
@ -17,6 +17,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
|
||||||
from tinygrad.codegen.optional import get_late_rewrite_patterns
|
from tinygrad.codegen.optional import get_late_rewrite_patterns
|
||||||
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||||
from tinygrad.opt import pm_optimize
|
from tinygrad.opt import pm_optimize
|
||||||
|
from tinygrad.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RewriteStep:
|
class RewriteStep:
|
||||||
|
|
@ -44,6 +45,11 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
||||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||||
ret: list[RewriteStep] = []
|
ret: list[RewriteStep] = []
|
||||||
|
|
||||||
|
# TODO: move these to codegen
|
||||||
|
ret.append(RewriteStep(view_left, name="Main View Left"))
|
||||||
|
ret.append(RewriteStep(view_right, name="Main View Right"))
|
||||||
|
ret.append(RewriteStep(view_left+fix_kernel_ops, bottom_up=True, name="replace buffer"))
|
||||||
|
|
||||||
# this is kernel.py
|
# this is kernel.py
|
||||||
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
|
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
|
||||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||||
from tinygrad.schedule.kernelize import view_left
|
from tinygrad.opt.swizzler import view_left
|
||||||
|
|
||||||
class OptOps(Enum):
|
class OptOps(Enum):
|
||||||
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
||||||
from tinygrad.helpers import all_same, prod, unwrap
|
from tinygrad.helpers import all_same, prod, unwrap, colored
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||||
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
||||||
|
from tinygrad.dtype import ImageDType, dtypes
|
||||||
|
|
||||||
merge_views = PatternMatcher([
|
merge_views = PatternMatcher([
|
||||||
# merge adjacent views
|
# merge adjacent views
|
||||||
|
|
@ -100,3 +101,33 @@ view_right = merge_views+PatternMatcher([
|
||||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def check_load_st(glbl:UOp, view:UOp):
|
||||||
|
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
||||||
|
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
||||||
|
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
||||||
|
# if it has a single view and it's equal when you shrink a contig, it's fine
|
||||||
|
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
||||||
|
# otherwise, it's not fine
|
||||||
|
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||||
|
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||||
|
|
||||||
|
fix_kernel_ops = PatternMatcher([
|
||||||
|
# add the LOAD
|
||||||
|
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
|
||||||
|
# STORE (except for meta ops)
|
||||||
|
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
|
||||||
|
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||||
|
# passthrough ASSIGN
|
||||||
|
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||||
|
# VALID
|
||||||
|
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||||
|
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||||
|
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||||
|
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
||||||
|
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||||
|
# no ImageDType after index
|
||||||
|
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||||
|
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||||
|
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
||||||
|
])
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
|
||||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||||
from tinygrad.uop.symbolic import symbolic_simple
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||||
from tinygrad.dtype import ImageDType, dtypes
|
from tinygrad.dtype import ImageDType
|
||||||
from tinygrad.schedule.multi import multi_pm
|
from tinygrad.schedule.multi import multi_pm
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
|
||||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||||
from tinygrad.opt.swizzler import merge_views, view_left, view_right, apply_swizzle, swizzle_reduceop
|
from tinygrad.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
||||||
|
|
||||||
# creation can recurse a lot
|
# creation can recurse a lot
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -157,36 +156,6 @@ early_buffer_ops = PatternMatcher([
|
||||||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||||
])
|
])
|
||||||
|
|
||||||
def check_load_st(glbl:UOp, view:UOp):
|
|
||||||
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
|
||||||
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
|
||||||
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
|
||||||
# if it has a single view and it's equal when you shrink a contig, it's fine
|
|
||||||
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
|
||||||
# otherwise, it's not fine
|
|
||||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
|
||||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
|
||||||
|
|
||||||
fix_kernel_ops = PatternMatcher([
|
|
||||||
# add the LOAD
|
|
||||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
|
|
||||||
# STORE (except for meta ops)
|
|
||||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
|
|
||||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
|
||||||
# passthrough ASSIGN
|
|
||||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
|
||||||
# VALID
|
|
||||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
|
||||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
|
||||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
|
||||||
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
|
||||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
|
||||||
# no ImageDType after index
|
|
||||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
|
||||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
|
||||||
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
|
||||||
])
|
|
||||||
|
|
||||||
replace_globals = PatternMatcher([
|
replace_globals = PatternMatcher([
|
||||||
# replace ASSIGN with the target BUFFER
|
# replace ASSIGN with the target BUFFER
|
||||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
|
(UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
|
||||||
|
|
@ -208,10 +177,6 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||||
ast = graph_rewrite(ast, early_buffer_ops, bufs, bottom_up=True, name="replace buffer early")
|
ast = graph_rewrite(ast, early_buffer_ops, bufs, bottom_up=True, name="replace buffer early")
|
||||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
||||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||||
# TODO: move these to codegen
|
|
||||||
ast = graph_rewrite(ast, view_left, name="Main View Left")
|
|
||||||
ast = graph_rewrite(ast, view_right, name="Main View Right")
|
|
||||||
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="replace buffer")
|
|
||||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||||
|
|
||||||
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
|
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
|
||||||
|
|
|
||||||
|
|
@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
|
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
|
||||||
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
||||||
def variables(self) -> list[Variable]:
|
def variables(self) -> list[Variable]:
|
||||||
st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer]
|
st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
|
||||||
return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
|
return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
|
||||||
|
|
||||||
# *** uop symbolic stuff ***
|
# *** uop symbolic stuff ***
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue