revert that

This commit is contained in:
George Hotz 2026-06-22 18:36:50 -07:00
commit 4424af7bd8
3 changed files with 11 additions and 11 deletions

View file

@ -91,7 +91,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# ** devectorizer (full_graph_rewrite) **
# remove reduce
#sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
sink = graph_rewrite(sink, pm_reduce_local, ctx=ReduceContext(), name="remove_reduce")
sink = graph_rewrite(sink, pm_reduce_local+pm_horizontal_reduce, ctx=ReduceContext(), name="remove_reduce")
# add gpu dims (late). this works after devectorize, but it's faster here
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
@ -107,7 +107,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
# hreduce
sink = graph_rewrite(sink, pm_mops+pm_horizontal_reduce, name="hreduce")
#sink = graph_rewrite(sink, pm_mops+pm_horizontal_reduce, name="hreduce")
# devectorize
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,

View file

@ -2,7 +2,7 @@ from typing import Any
import itertools, functools
from tinygrad.schedule.rangeify import pm_mops
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, resolve
from tinygrad.dtype import dtypes, AddrSpace, ImageDType, Invalid
from tinygrad.helpers import all_same, flatten, getenv
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
@ -58,7 +58,7 @@ def do_devectorize(b:UOp):
for idx in itertools.product(*[range(x) for x in b.shape]):
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
return UOp.vectorize(*src).reshape(b.shape)
return UOp.vectorize(*src).reshape(b.shape) if b.op is not Ops.STORE else UOp.group(*src)
def new_split_load_store(ls:UOp, midx:UOp):
# extract all the relevant offsets
@ -81,7 +81,7 @@ def new_split_load_store(ls:UOp, midx:UOp):
lidx = midx.src[offsets[grp[0]][0]]
if len(grp) > 1: lidx = lidx.src[0]._mop(Ops.SHRINK, arg=[(lidx.src[1], len(grp))])
# do load
lidx = lidx.load(lidx.vconst_like(0), valid) if not valid else lidx.load()
lidx = lidx.load(lidx.vconst_like(0), valid) if not resolve(valid, False) else lidx.load()
# set the idxs of the output
for i,g in enumerate(grp):
for oo in offsets[g]:
@ -92,12 +92,12 @@ def new_split_load_store(ls:UOp, midx:UOp):
#from tinygrad.codegen.late.devectorizer import fold_expanded_index
devectorizer2 = pm_mops+PatternMatcher([
# LOAD+INDEX -> INDEX+LOAD
(UPat(Ops.LOAD, src=(UPat.var("buf"),)).index(allow_any_len=True),
lambda buf: buf.index(UOp.const(dtypes.int, 0)).load() if buf.shape == (1,) else None),
#(UPat(Ops.LOAD, src=(UPat.var("buf"),)).index(allow_any_len=True),
# lambda buf: buf.index(UOp.const(dtypes.int, 0)).load() if buf.shape == (1,) else None),
# TODO: support STORE
(UPat((Ops.LOAD,), src=(UPat(Ops.STACK, src=UPat(Ops.INDEX), name="midx"),), name="ls", allow_any_len=True), new_split_load_store),
#(UPat((Ops.LOAD,), src=(UPat(Ops.STACK, src=UPat(Ops.INDEX), name="midx"),), name="ls", allow_any_len=True), new_split_load_store),
# unpack broadcasting
(UPat(GroupOp.Elementwise|{Ops.STORE}, name="b"), do_devectorize),
(UPat(GroupOp.Elementwise|{Ops.LOAD,Ops.STORE}, name="b"), do_devectorize),
# const INDEX into STACK is src
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
# stacked INDEX is many INDEX

View file

@ -1047,11 +1047,11 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
@staticmethod
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
if addrspace is AddrSpace.GLOBAL:
ret = UOp(Ops.PARAM, dtype, arg=ParamArg(slot, addrspace=addrspace))
ret = UOp(Ops.PARAM, dtype.ptr(prod(shape), addrspace), arg=ParamArg(slot, addrspace=addrspace))
else:
assert addrspace in (AddrSpace.LOCAL, AddrSpace.REG)
buf_shape = (prod(shape),) + ((dtype.count,) if dtype.count > 1 else ())
ret = UOp(Ops.BUFFER, dtype, src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
return ret
def placeholder_like(self, slot:int, addrspace=AddrSpace.GLOBAL):