Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
27b7680e03 prereqs: move views to codegen 2025-08-05 18:49:17 -07:00
4 changed files with 37 additions and 37 deletions

View file

@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape, get_contraction
from tinygrad.schedule.kernelize import view_left
from tinygrad.opt.swizzler import view_left
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702

View file

@ -1,8 +1,9 @@
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
from tinygrad.helpers import all_same, prod, unwrap
from tinygrad.helpers import all_same, prod, unwrap, colored
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
from tinygrad.dtype import ImageDType, dtypes
merge_views = PatternMatcher([
# merge adjacent views
@ -100,3 +101,33 @@ view_right = merge_views+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
])
def check_load_st(glbl:UOp, view:UOp):
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
# if it has a single view and it's equal when you shrink a contig, it's fine
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
# otherwise, it's not fine
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = PatternMatcher([
# add the LOAD
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
# STORE (except for meta ops)
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
# passthrough ASSIGN
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
# VALID
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# no ImageDType after index
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
])

View file

@ -3,12 +3,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.opt.swizzler import merge_views, view_left, view_right, apply_swizzle, swizzle_reduceop
from tinygrad.opt.swizzler import merge_views, view_left, view_right, fix_kernel_ops, apply_swizzle, swizzle_reduceop
# creation can recurse a lot
import sys
@ -157,36 +156,6 @@ early_buffer_ops = PatternMatcher([
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
])
def check_load_st(glbl:UOp, view:UOp):
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
# if it has a single view and it's equal when you shrink a contig, it's fine
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
# otherwise, it's not fine
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = PatternMatcher([
# add the LOAD
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
# STORE (except for meta ops)
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
# passthrough ASSIGN
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
# VALID
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# no ImageDType after index
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
])
replace_globals = PatternMatcher([
# replace ASSIGN with the target BUFFER
(UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),

View file

@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> list[Variable]:
st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer]
st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
# *** uop symbolic stuff ***