mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
9 commits
master
...
new_expand
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f4c17ce251 |
||
|
|
c64c598d43 |
||
|
|
90efe50594 |
||
|
|
32fe500814 | ||
|
|
3412382cae | ||
|
|
7c585be215 |
||
|
|
babce51ce6 |
||
|
|
b6d0e2edd9 | ||
|
|
2815f7cb90 |
6 changed files with 103 additions and 11 deletions
|
|
@ -12,6 +12,7 @@ from tinygrad.codegen.gpudims import pm_add_gpudims
|
|||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.expander2 import expander2, devectorizer2, unbroadcast
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
|
|
@ -50,7 +51,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, b
|
|||
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
|
||||
|
||||
# expand
|
||||
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||
#sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||
sink = graph_rewrite(sink, expander2, name="expander2", ctx={}, bottom_up=True)
|
||||
#sink = graph_rewrite(sink, expander_broadcast, name="fix broadcast")
|
||||
|
||||
# add locals
|
||||
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
||||
|
|
@ -62,6 +65,10 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, b
|
|||
# add gpu dims (late). this works after devectorize, but it's faster here
|
||||
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
||||
|
||||
sink = graph_rewrite(sink, unbroadcast, name="unbroadcast2")
|
||||
|
||||
sink = graph_rewrite(sink, devectorizer2, name="devectorize2")
|
||||
|
||||
# **** optimizations are done, now we lower to actual code ****
|
||||
|
||||
# add loads
|
||||
|
|
@ -71,10 +78,12 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, b
|
|||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
|
||||
|
||||
# devectorize (TODO: does this need opts?)
|
||||
"""
|
||||
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
|
||||
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
|
||||
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
|
||||
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
|
||||
"""
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
|
|
|
|||
62
tinygrad/codegen/late/expander2.py
Normal file
62
tinygrad/codegen/late/expander2.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import itertools
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, AxisType, UOp, GroupOp, _align_left, _broadcast_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import all_same
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.schedule.rangeify import pm_index_mops
|
||||
|
||||
def build_range_map(ctx, sink:UOp):
  """Give every UNROLL/UPCAST range reachable from ``sink`` a position in ``ctx``.

  ``ctx`` is a dict keyed by the first element of a RANGE's arg; the value is an integer
  position handed out in toposort order. The function only fills ``ctx`` and returns None.
  """
  for x in sink.toposort():
    if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
      # setdefault keeps the position an axis already has. A plain assignment would move a
      # known axis to len(ctx), which can collide with the position given to the next new axis
      # whenever this function sees the same axis more than once.
      ctx.setdefault(x.arg[0], len(ctx))
|
||||
|
||||
def _range_to_vconst(ctx, r:UOp):
  """Swap a RANGE registered in ``ctx`` for a VCONST of 0..r.vmax placed on its own axis."""
  if r.arg[0] not in ctx: return None
  size = r.vmax+1
  # one dimension per registered axis: this range's slot holds its size, every other slot is 1
  shape = tuple([size if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])
  return UOp(Ops.VCONST, r.dtype, arg=tuple(range(size))).reshape(shape)

expander2 = PatternMatcher([
  # fill ctx with the position of each UNROLL/UPCAST range found under the SINK
  (UPat(Ops.SINK, name="sink"), build_range_map),
  # ranges that were registered become shaped VCONSTs, all others are left alone
  (UPat(Ops.RANGE, name="r"), _range_to_vconst),
])+pm_flatten_range
|
||||
|
||||
def broadcast_binary(x:UOp):
  """Make the broadcasting between the sources of ``x`` explicit.

  Returns None when all sources already share one shape. Otherwise each source is reshaped
  to its entry from ``_align_left`` and expanded to the shape from ``_broadcast_shape``, and
  ``x`` is returned with those sources in place of the originals.
  """
  shapes = [u.shape for u in x.src]
  # nothing to do when no source needs broadcasting
  if all_same(shapes): return None
  shaped_aligned = _align_left(*shapes)
  broadcasted = _broadcast_shape(*shapes)
  return x.replace(src=tuple(u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)))
|
||||
|
||||
# ops whose sources get their broadcasting written out by broadcast_binary
_BROADCASTING_OPS = GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}

unbroadcast = PatternMatcher([(UPat(_BROADCASTING_OPS, name="x"), broadcast_binary)])
|
||||
|
||||
def do_devectorize(b:UOp):
  """Split a shaped op into one copy per element position and CAT the copies together.

  Returns None when ``b`` has the empty shape, or when its sources do not all share one
  shape (broadcasting has to be unpacked before this rule can apply).
  """
  shape = b.shape
  if shape == () or not all_same([s.shape for s in b.src]): return None
  lanes = [
    # every source is indexed with the same constant coordinate
    b.replace(src=tuple([s.index(*[UOp.const(dtypes.weakint, c) for c in coord]) for s in b.src]))
    for coord in itertools.product(*map(range, shape))
  ]
  return UOp.cat(*lanes)
|
||||
|
||||
def _index_vconst(a:UOp, i:UOp):
  """INDEX into a VCONST with a constant index: that element of the VCONST arg as a CONST."""
  return UOp.const(a.dtype, a.arg[i.arg])

def _index_cat(a:UOp, i:UOp):
  """INDEX into a CAT with a constant index: that source of the CAT, only when its arg is -1."""
  if a.arg == -1: return a.src[i.arg]
  return None

def _index_by_cat(a:UOp, c:UOp):
  """INDEX whose index is a CAT: index ``a`` with each CAT source and CAT the results."""
  return UOp.cat(*[a.index(x) for x in c.src])

devectorizer2 = pm_index_mops+PatternMatcher([
  # an INDEX that carries no indices is just its source
  (UPat(Ops.INDEX, src=(UPat.var("x"),)), lambda x: x),
  # constant INDEX into VCONST folds to a CONST
  (UPat(Ops.INDEX, src=(UPat(Ops.VCONST, name="a"), UPat.cvar("i", vec=False))), _index_vconst),
  # constant INDEX into CAT selects one of the CAT sources
  (UPat(Ops.INDEX, src=(UPat(Ops.CAT, name="a"), UPat.cvar("i", vec=False))), _index_cat),

  # a CAT used as the index moves outside the INDEX
  (UPat(Ops.INDEX, src=(UPat.var("a"), UPat(Ops.CAT, name="c"))), _index_by_cat),

  # a CAT of STOREs becomes a group (TODO: do we need group?)
  (UPat(Ops.CAT, src=UPat(Ops.STORE), name="x"), lambda x: UOp.group(*x.src)),

  # split shaped elementwise ops and stores into per-element copies
  (UPat(GroupOp.Elementwise|{Ops.STORE}, name="b"), do_devectorize),
])
|
||||
|
|
@ -63,12 +63,15 @@ pm_fold_moved_after = PatternMatcher([
|
|||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
|
||||
# TODO: clean up .src[0]._shape is not None
|
||||
pm_mops = PatternMatcher([
|
||||
pm_index_mops = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
|
||||
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
|
||||
if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
|
||||
])
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
|
||||
# TODO: clean up .src[0]._shape is not None
|
||||
pm_mops = pm_index_mops+PatternMatcher([
|
||||
# move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
|
||||
(UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),
|
||||
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
|
||||
|
|
|
|||
|
|
@ -100,6 +100,7 @@ class Ops(FastEnum):
|
|||
# the core 6 movement ops! these only exist in the tensor graph
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()
|
||||
MULTI = auto() # MULTI is really a movement op
|
||||
CAT = auto() # see CAT in spec
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
|
||||
|
|
|
|||
|
|
@ -233,17 +233,26 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
shp = []
|
||||
for s in self.src[1:]: shp.extend(list(s.shape))
|
||||
return tuple(shp)
|
||||
|
||||
# non pointer index doesn't have a shape
|
||||
if not isinstance(self.dtype, PtrDType): return None
|
||||
#if not isinstance(self.dtype, PtrDType): return None
|
||||
# fully indexed doesn't have a shape. TODO: remove this
|
||||
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
||||
#if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
||||
# pointer index
|
||||
return self.src[0].shape[len(self.src[1:]):]
|
||||
#return self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
case Ops.CAT:
|
||||
if self.arg == -1:
|
||||
assert all_same([x.shape for x in self.src])
|
||||
return (len(self.src),)+self.src[0].shape
|
||||
# TODO: write the non arg=-1 path
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
# TODO: VCONST should have the shape of the arg
|
||||
case Ops.VCONST: return ()
|
||||
case Ops.VCONST: return (len(self.arg),)
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||
case Ops.CUSTOM_FUNCTION: return None
|
||||
|
|
@ -311,6 +320,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
raise ValueError(f"invalid type for axis: {axis_arg}")
|
||||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||
|
||||
# broadcasting here
|
||||
# TODO: STORE can only broadcast a smaller src[1] into a larger src[0]
|
||||
if self.op in GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}:
|
||||
return _broadcast_shape(*[u.shape for u in self.src])
|
||||
|
||||
# elementwise ops keep the shape the same. all inputs with shape must match
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
|
||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||
|
|
@ -410,6 +424,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
# *** uop syntactic sugar ***
|
||||
|
||||
def cat(*srcs:UOp, axis=-1): # pylint: disable=no-self-argument
  """Build a CAT UOp over ``srcs``; ``axis`` is stored as the arg. All srcs must share one dtype."""
  src_dtypes = [s.dtype for s in srcs]
  assert len(srcs) >= 1 and all_same(src_dtypes)
  # the CAT takes the dtype shared by its sources
  return UOp(Ops.CAT, src_dtypes[0], src=tuple(srcs), arg=axis)
|
||||
def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
|
||||
return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ from tinygrad.dtype import dtypes
|
|||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff", Ops.CAT: "#D8F9E4",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||
|
|
@ -116,7 +116,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
#if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
for u in toposort:
|
||||
if u in excluded: continue
|
||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue