Compare commits

...

10 commits

Author SHA1 Message Date
George Hotz
b2017dec02 move that 2025-08-05 16:23:41 -07:00
George Hotz
35530b40fb
Merge branch 'master' into swizzle_in_kernel 2025-08-05 15:59:06 -07:00
George Hotz
259f207af6
Merge branch 'master' into swizzle_in_kernel 2025-08-05 15:53:07 -07:00
George Hotz
c632329943
Merge branch 'master' into swizzle_in_kernel 2025-08-05 15:34:21 -07:00
George Hotz
9c5e2b0113 move to codegen 2025-08-05 15:04:03 -07:00
George Hotz
8a6f0f49ef bugfixes 2025-08-05 14:48:51 -07:00
George Hotz
2b90cdb5d6 fix variables 2025-08-05 14:44:02 -07:00
George Hotz
23eca7fc54 move that 2025-08-05 14:05:05 -07:00
George Hotz
3f71dd5fd0 cleanup after 2025-08-05 13:59:23 -07:00
George Hotz
7116948f34 move swizzle to kernel 2025-08-05 13:19:24 -07:00
6 changed files with 43 additions and 27 deletions

View file

@ -3,6 +3,7 @@ import unittest
from dataclasses import replace from dataclasses import replace
from tinygrad.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType from tinygrad.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType
from tinygrad.codegen import rewrites_for_views, apply_rewrites
from tinygrad.codegen.gpudims import get_grouped_dims from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo
from tinygrad.device import Device, Buffer, is_dtype_supported from tinygrad.device import Device, Buffer, is_dtype_supported
@ -22,7 +23,7 @@ def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
# now all input buffers in s[-1] should be realized # now all input buffers in s[-1] should be realized
# create fresh buffers for the outputs # create fresh buffers for the outputs
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)] bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
return s[-1].ast, bufs return apply_rewrites(s[-1].ast, rewrites_for_views), bufs
def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1): def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, use_tensor_cores:int=1):
a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in) a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)

View file

@ -16,6 +16,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
ReduceContext, correct_load_store, pm_render ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.optional import get_late_rewrite_patterns from tinygrad.codegen.optional import get_late_rewrite_patterns
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.opt.swizzler import view_left, view_right, cleanup_pm
from tinygrad.opt import pm_optimize from tinygrad.opt import pm_optimize
@dataclass @dataclass
@ -29,6 +30,12 @@ class RewriteStep:
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink) def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]): return functools.reduce(lambda x,f: f(x), rewrites, sink)
# View-handling passes, listed in the order they run: view_left, then view_right, then cleanup_pm.
# apply_rewrites folds over this list front to back, so the order of the pairs below is the execution order.
rewrites_for_views = [RewriteStep(matcher, name=label) for matcher, label in (
  (view_left, "view left"),
  (view_right, "view right"),
  (cleanup_pm, "cleanup view"),
)]
rewrites_for_linearizer = [ rewrites_for_linearizer = [
RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True), RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"), RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
@ -44,6 +51,9 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
# ** lowerer (rewrite_shapetracker_with_index) ** # ** lowerer (rewrite_shapetracker_with_index) **
ret: list[RewriteStep] = [] ret: list[RewriteStep] = []
# this used to be in schedule
ret.extend(rewrites_for_views)
# this is kernel.py # this is kernel.py
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast")) ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))

View file

@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape, get_contraction from tinygrad.shape.view import strides_for_shape, get_contraction
from tinygrad.schedule.kernelize import view_left from tinygrad.opt.swizzler import view_left
class OptOps(Enum): class OptOps(Enum):
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702 TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
@ -73,7 +73,7 @@ class Kernel:
self.sts.append(unwrap(x.src[0].st)) self.sts.append(unwrap(x.src[0].st))
# add a shapetracker to the end to track the full shape, with 0 strides so it can merge # add a shapetracker to the end to track the full shape, with 0 strides so it can merge
full_shape = ast.full_shape full_shape = self.ast.full_shape
self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape))) self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape)))
# parameters for optimization # parameters for optimization

View file

@ -1,5 +1,6 @@
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
from tinygrad.helpers import all_same, prod, unwrap from tinygrad.helpers import all_same, prod, unwrap, colored
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
@ -100,3 +101,25 @@ view_right = merge_views+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"), (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None), lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
]) ])
def check_load_st(glbl:UOp, view:UOp):
  """Validate the view of a LOAD that reads global buffer 0.

  Returns None when the access is acceptable. Loads whose `glbl.arg` is not 0 are never checked.

  Raises:
    RuntimeError: the view's ShapeTracker is not contiguous and matches neither single-view exception below.
  """
  # only loads from global 0 are checked
  if glbl.arg != 0: return
  tracker = unwrap(view.st)
  if tracker.contiguous: return
  # both exceptions apply only to a ShapeTracker made of exactly one view
  if len(tracker.views) == 1:
    only_view = tracker.views[0]
    # shrink every stride-0 (expanded) axis to size 1; if what is left is contiguous, it's fine
    unexpanded = tuple((0,1) if stride == 0 else (0,dim) for dim,stride in zip(tracker.shape, only_view.strides))
    if tracker.shrink(unexpanded).contiguous: return
    # a masked view is fine if shrinking to the mask equals the same shrink of a plain contiguous tracker
    mask = only_view.mask
    if mask is not None and ShapeTracker.from_shape(tracker.shape).shrink(mask) == tracker.shrink(mask): return
  # otherwise, it's not fine
  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                     +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# Cleanup rewrite patterns; each entry below is a (pattern, rewrite function) pair.
cleanup_pm = PatternMatcher([
# add VIEW to any DEFINE_GLOBAL that somehow lost its view:
# a STORE whose first source is a bare DEFINE_GLOBAL gets it wrapped as d.view(d.st); the remaining sources are kept
(UPat(Ops.STORE, src=(UPat(Ops.DEFINE_GLOBAL, name="d"),), name="x", allow_any_len=True), lambda d,x: x.replace(src=(d.view(d.st),)+x.src[1:])),
# VALID: a VIEW over a constant becomes where(VALID(VIEW carrying the same st), const, 0)
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
# VIEW on SINK is SINK: the result is a sink of the VIEW's first source, so the VIEW itself is dropped
(UPat(Ops.VIEW, name="v").sink(), lambda v: v.src[0].sink()),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
# (check_load_st returns None when the view is acceptable and raises RuntimeError otherwise)
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
])

View file

@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK} Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL}
# **** Grouper decides which of the UOps realize # **** Grouper decides which of the UOps realize

View file

@ -3,12 +3,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
from tinygrad.uop.ops import track_rewrites, _substitute from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.spec import type_verify, tensor_uop_spec from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType, dtypes from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm from tinygrad.schedule.multi import multi_pm
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.opt.swizzler import merge_views, view_left, view_right, apply_swizzle, swizzle_reduceop from tinygrad.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
# creation can recurse a lot # creation can recurse a lot
import sys import sys
@ -159,29 +158,14 @@ add_buffer_ops = PatternMatcher([
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])), UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
# passthrough ASSIGN # passthrough ASSIGN
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]), (UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
# VALID
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
]) ])
def check_load_st(glbl:UOp, view:UOp):
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
# if it has a single view and it's equal when you shrink a contig, it's fine
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
# otherwise, it's not fine
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = PatternMatcher([ fix_kernel_ops = PatternMatcher([
# remove CONTIGUOUS/DEVICE from kernel AST # remove CONTIGUOUS/DEVICE from kernel AST
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x), (UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# no ImageDType after index # no ImageDType after index
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
]) ])
replace_globals = PatternMatcher([ replace_globals = PatternMatcher([
@ -195,8 +179,6 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
# replace global memory ops with the BUFFER they write to # replace global memory ops with the BUFFER they write to
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals") ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
# push views to edges
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
# replace buffer with define_global + add load/store last # replace buffer with define_global + add load/store last
bufs = [] bufs = []
for s in k.src: for s in k.src:
@ -204,7 +186,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only # traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0] while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
bufs.append(s) bufs.append(s)
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer") ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]): if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}") raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
return k.replace(arg=Kernel(ast, k.arg.metadata)) return k.replace(arg=Kernel(ast, k.arg.metadata))