mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
revert that
This commit is contained in:
parent
c01d75a651
commit
4424af7bd8
3 changed files with 11 additions and 11 deletions
|
|
@ -91,7 +91,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# ** devectorizer (full_graph_rewrite) **
|
||||
# remove reduce
|
||||
#sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
||||
sink = graph_rewrite(sink, pm_reduce_local, ctx=ReduceContext(), name="remove_reduce")
|
||||
sink = graph_rewrite(sink, pm_reduce_local+pm_horizontal_reduce, ctx=ReduceContext(), name="remove_reduce")
|
||||
|
||||
# add gpu dims (late). this works after devectorize, but it's faster here
|
||||
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
||||
|
|
@ -107,7 +107,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
|
||||
|
||||
# hreduce
|
||||
sink = graph_rewrite(sink, pm_mops+pm_horizontal_reduce, name="hreduce")
|
||||
#sink = graph_rewrite(sink, pm_mops+pm_horizontal_reduce, name="hreduce")
|
||||
|
||||
# devectorize
|
||||
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Any
|
|||
import itertools, functools
|
||||
from tinygrad.schedule.rangeify import pm_mops
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, resolve
|
||||
from tinygrad.dtype import dtypes, AddrSpace, ImageDType, Invalid
|
||||
from tinygrad.helpers import all_same, flatten, getenv
|
||||
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
|
||||
|
|
@ -58,7 +58,7 @@ def do_devectorize(b:UOp):
|
|||
for idx in itertools.product(*[range(x) for x in b.shape]):
|
||||
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
|
||||
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
|
||||
return UOp.vectorize(*src).reshape(b.shape)
|
||||
return UOp.vectorize(*src).reshape(b.shape) if b.op is not Ops.STORE else UOp.group(*src)
|
||||
|
||||
def new_split_load_store(ls:UOp, midx:UOp):
|
||||
# extract all the relevant offsets
|
||||
|
|
@ -81,7 +81,7 @@ def new_split_load_store(ls:UOp, midx:UOp):
|
|||
lidx = midx.src[offsets[grp[0]][0]]
|
||||
if len(grp) > 1: lidx = lidx.src[0]._mop(Ops.SHRINK, arg=[(lidx.src[1], len(grp))])
|
||||
# do load
|
||||
lidx = lidx.load(lidx.vconst_like(0), valid) if not valid else lidx.load()
|
||||
lidx = lidx.load(lidx.vconst_like(0), valid) if not resolve(valid, False) else lidx.load()
|
||||
# set the idxs of the output
|
||||
for i,g in enumerate(grp):
|
||||
for oo in offsets[g]:
|
||||
|
|
@ -92,12 +92,12 @@ def new_split_load_store(ls:UOp, midx:UOp):
|
|||
#from tinygrad.codegen.late.devectorizer import fold_expanded_index
|
||||
devectorizer2 = pm_mops+PatternMatcher([
|
||||
# LOAD+INDEX -> INDEX+LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat.var("buf"),)).index(allow_any_len=True),
|
||||
lambda buf: buf.index(UOp.const(dtypes.int, 0)).load() if buf.shape == (1,) else None),
|
||||
#(UPat(Ops.LOAD, src=(UPat.var("buf"),)).index(allow_any_len=True),
|
||||
# lambda buf: buf.index(UOp.const(dtypes.int, 0)).load() if buf.shape == (1,) else None),
|
||||
# TODO: support STORE
|
||||
(UPat((Ops.LOAD,), src=(UPat(Ops.STACK, src=UPat(Ops.INDEX), name="midx"),), name="ls", allow_any_len=True), new_split_load_store),
|
||||
#(UPat((Ops.LOAD,), src=(UPat(Ops.STACK, src=UPat(Ops.INDEX), name="midx"),), name="ls", allow_any_len=True), new_split_load_store),
|
||||
# unpack broadcasting
|
||||
(UPat(GroupOp.Elementwise|{Ops.STORE}, name="b"), do_devectorize),
|
||||
(UPat(GroupOp.Elementwise|{Ops.LOAD,Ops.STORE}, name="b"), do_devectorize),
|
||||
# const INDEX into STACK is src
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
|
||||
# stacked INDEX is many INDEX
|
||||
|
|
|
|||
|
|
@ -1047,11 +1047,11 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
@staticmethod
|
||||
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
if addrspace is AddrSpace.GLOBAL:
|
||||
ret = UOp(Ops.PARAM, dtype, arg=ParamArg(slot, addrspace=addrspace))
|
||||
ret = UOp(Ops.PARAM, dtype.ptr(prod(shape), addrspace), arg=ParamArg(slot, addrspace=addrspace))
|
||||
else:
|
||||
assert addrspace in (AddrSpace.LOCAL, AddrSpace.REG)
|
||||
buf_shape = (prod(shape),) + ((dtype.count,) if dtype.count > 1 else ())
|
||||
ret = UOp(Ops.BUFFER, dtype, src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
|
||||
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
|
||||
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
|
||||
return ret
|
||||
def placeholder_like(self, slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue