mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
minor_view
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27b7680e03 |
4 changed files with 37 additions and 37 deletions
|
|
@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
|
|||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||
from tinygrad.schedule.kernelize import view_left
|
||||
from tinygrad.opt.swizzler import view_left
|
||||
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
||||
from tinygrad.helpers import all_same, prod, unwrap
|
||||
from tinygrad.helpers import all_same, prod, unwrap, colored
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
|
||||
merge_views = PatternMatcher([
|
||||
# merge adjacent views
|
||||
|
|
@ -100,3 +101,33 @@ view_right = merge_views+PatternMatcher([
|
|||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||
])
|
||||
|
||||
def check_load_st(glbl:UOp, view:UOp):
|
||||
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
||||
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
||||
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
||||
# if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
||||
# otherwise, it's not fine
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
|
||||
fix_kernel_ops = PatternMatcher([
|
||||
# add the LOAD
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
# passthrough ASSIGN
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||
# VALID
|
||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||
# no ImageDType after index
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -3,12 +3,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
|
|||
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
from tinygrad.opt.swizzler import merge_views, view_left, view_right, apply_swizzle, swizzle_reduceop
|
||||
from tinygrad.opt.swizzler import merge_views, view_left, view_right, fix_kernel_ops, apply_swizzle, swizzle_reduceop
|
||||
|
||||
# creation can recurse a lot
|
||||
import sys
|
||||
|
|
@ -157,36 +156,6 @@ early_buffer_ops = PatternMatcher([
|
|||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||
])
|
||||
|
||||
def check_load_st(glbl:UOp, view:UOp):
|
||||
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
||||
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
||||
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
||||
# if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
||||
# otherwise, it's not fine
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
|
||||
fix_kernel_ops = PatternMatcher([
|
||||
# add the LOAD
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
# passthrough ASSIGN
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||
# VALID
|
||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||
# no ImageDType after index
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
||||
])
|
||||
|
||||
replace_globals = PatternMatcher([
|
||||
# replace ASSIGN with the target BUFFER
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
|
||||
|
|
|
|||
|
|
@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
|
||||
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
||||
def variables(self) -> list[Variable]:
|
||||
st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer]
|
||||
st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
|
||||
return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue