mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
new codegen, try 2
This commit is contained in:
parent
30830850a9
commit
da402953d9
4 changed files with 111 additions and 6 deletions
|
|
@ -24,6 +24,8 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
|
|||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||
|
||||
from tinygrad.codegen.codegen2 import expander2, pm_move_regs
|
||||
|
||||
pm_index_is_shrink = PatternMatcher([
|
||||
# rewrite non-image INDEX to SHRINK
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
|
||||
|
|
@ -81,7 +83,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
|
||||
|
||||
# expand
|
||||
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||
#sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
|
||||
|
||||
# add locals
|
||||
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
||||
|
|
@ -96,7 +99,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# **** optimizations are done, now we lower to actual code ****
|
||||
|
||||
# add loads and remove invalids
|
||||
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
||||
#sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
||||
sink = graph_rewrite(sink, pm_move_regs, name="** add loads")
|
||||
|
||||
# create image buffers
|
||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
||||
|
|
|
|||
101
tinygrad/codegen/codegen2.py
Normal file
101
tinygrad/codegen/codegen2.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import itertools, functools
|
||||
from tinygrad.schedule.rangeify import pm_mops
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.helpers import all_same, flatten
|
||||
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
|
||||
from tinygrad.codegen.late.devectorizer import ReduceContext
|
||||
from tinygrad.uop.symbolic import pm_clean_up_group_sink
|
||||
|
||||
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
|
||||
pm_move_regs = PatternMatcher([
|
||||
# BITCAST?
|
||||
(UPat(GroupOp.Elementwise, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
|
||||
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
|
||||
])
|
||||
|
||||
pm_lower_weakints = PatternMatcher([
|
||||
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
|
||||
])
|
||||
|
||||
def build_range_map(ctx, sink:UOp):
|
||||
for x in sink.toposort():
|
||||
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
|
||||
ctx[x.arg[0]] = len(ctx)
|
||||
|
||||
def fix_reduce(ctx, r:UOp):
|
||||
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
|
||||
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
|
||||
|
||||
expander2 = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="sink"), build_range_map),
|
||||
(UPat(Ops.REDUCE, name="r"), fix_reduce),
|
||||
(UPat(Ops.RANGE, name="r"),
|
||||
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
|
||||
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
|
||||
])+pm_flatten_range
|
||||
|
||||
def broadcast_binary(x:UOp):
|
||||
shapes = [u.shape for u in x.src]
|
||||
if all_same(shapes): return None
|
||||
shaped_aligned = _align_left(*shapes)
|
||||
broadcasted = _broadcast_shape(*shapes)
|
||||
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
|
||||
return x.replace(src=tuple(src_reshaped))
|
||||
|
||||
unbroadcast = PatternMatcher([
|
||||
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
|
||||
])
|
||||
|
||||
def do_devectorize(b:UOp):
|
||||
if b.shape == (): return None
|
||||
# broadcasting needs to be already unpacked
|
||||
if not all_same([x.shape for x in b.src]): return None
|
||||
src = []
|
||||
for idx in itertools.product(*[range(x) for x in b.shape]):
|
||||
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
|
||||
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
|
||||
return UOp.vectorize(*src).reshape(b.shape)
|
||||
|
||||
devectorizer2 = pm_mops+PatternMatcher([
|
||||
# unpack broadcasting
|
||||
(UPat(GroupOp.Elementwise|{Ops.LOAD, Ops.STORE}, name="b"), do_devectorize),
|
||||
# INDEX into STACK is src
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
|
||||
# stacked INDEX is many INDEX
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
|
||||
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
|
||||
# INDEX into RESHAPE moves the RESHAPE
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
|
||||
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
|
||||
# RESHAPE a void is removed (hack for AFTER)
|
||||
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
|
||||
# reshape of a single element shaped value to scalar is an index
|
||||
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
|
||||
# INDEX without src is nothing
|
||||
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
|
||||
])
|
||||
|
||||
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
|
||||
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
|
||||
ctx.acc_num += 1
|
||||
topo = r.src[0].toposort()
|
||||
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
|
||||
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
|
||||
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
|
||||
acc_initted = acc.after(acc_init, *r.src[1:])
|
||||
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
|
||||
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
|
||||
return acc.after(acc_out)
|
||||
|
||||
def expand_horizontal_reduce(r:UOp):
|
||||
axes = r.arg[1]
|
||||
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
|
||||
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
|
||||
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
|
||||
|
||||
pm_reduce_local = PatternMatcher([
|
||||
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
|
||||
])+pm_clean_up_group_sink
|
||||
|
|
@ -274,7 +274,7 @@ SCACHE = ContextVar("SCACHE", 1)
|
|||
# allow use of atomics for embedding backward
|
||||
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
||||
# don't allow broadcast
|
||||
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1)
|
||||
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Metadata:
|
||||
|
|
|
|||
|
|
@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
|||
|
||||
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
||||
if len(arg) == 0: return UOp(Ops.STACK)
|
||||
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
|
||||
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
||||
elif all_int(arg): return UOp.const(dtypes.weakint, arg)
|
||||
else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
||||
|
||||
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||
ret: dict[UOp, dict[UOp, None]] = {}
|
||||
|
|
@ -473,7 +473,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||
def vectorize(self, *srcs):
|
||||
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
|
||||
return UOp(Ops.STACK, self.dtype, (self,)+srcs)
|
||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def __getitem__(self, idx):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue