mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
20 commits
master
...
codegen_tr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30cc426916 | ||
|
|
c7fd74d523 | ||
|
|
4e398e3f1f | ||
|
|
5238f304c7 | ||
|
|
87e12fad81 | ||
|
|
5eaf02c719 | ||
|
|
4424af7bd8 | ||
|
|
c01d75a651 | ||
|
|
63ec8ad21d |
||
|
|
cccd9c2c03 | ||
|
|
303b6ba14c |
||
|
|
2bf3c48c1b | ||
|
|
ba75b68c12 | ||
|
|
958088cc13 | ||
|
|
257ed03f57 | ||
|
|
15476572cd |
||
|
|
f8949a0de1 | ||
|
|
62c6c75657 | ||
|
|
530aed739d | ||
|
|
da402953d9 |
5 changed files with 194 additions and 20 deletions
|
|
@ -12,11 +12,9 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||||
|
|
||||||
# import all pattern matchers here
|
# import all pattern matchers here
|
||||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
|
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
|
||||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
from tinygrad.codegen.late.devectorizer import load_store_indexing, ReduceContext, pm_render, pm_make_images
|
||||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
|
||||||
ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images
|
|
||||||
from tinygrad.codegen.opt.postrange import apply_opts
|
from tinygrad.codegen.opt.postrange import apply_opts
|
||||||
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
||||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||||
|
|
@ -24,6 +22,8 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
|
||||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||||
|
|
||||||
|
from tinygrad.codegen.codegen2 import expander2, pm_move_regs, devectorizer2, unbroadcast, pm_reduce_local, pm_horizontal_reduce, memory_coalesing
|
||||||
|
|
||||||
pm_index_is_shrink = PatternMatcher([
|
pm_index_is_shrink = PatternMatcher([
|
||||||
# rewrite non-image INDEX to SHRINK
|
# rewrite non-image INDEX to SHRINK
|
||||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
|
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
|
||||||
|
|
@ -81,14 +81,16 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
|
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
|
||||||
|
|
||||||
# expand
|
# expand
|
||||||
sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
#sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||||
|
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
|
||||||
|
|
||||||
# add locals
|
# add locals
|
||||||
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
||||||
|
|
||||||
# ** devectorizer (full_graph_rewrite) **
|
# ** devectorizer (full_graph_rewrite) **
|
||||||
# remove reduce
|
# remove reduce
|
||||||
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
#sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
||||||
|
sink = graph_rewrite(sink, pm_reduce_local+pm_horizontal_reduce, ctx=ReduceContext(), name="remove_reduce")
|
||||||
|
|
||||||
# add gpu dims (late). this works after devectorize, but it's faster here
|
# add gpu dims (late). this works after devectorize, but it's faster here
|
||||||
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
||||||
|
|
@ -96,15 +98,21 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
# **** optimizations are done, now we lower to actual code ****
|
# **** optimizations are done, now we lower to actual code ****
|
||||||
|
|
||||||
# add loads and remove invalids
|
# add loads and remove invalids
|
||||||
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
#sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
|
||||||
|
sink = graph_rewrite(sink, pm_move_regs, name="** add loads")
|
||||||
|
|
||||||
# create image buffers
|
# create image buffers
|
||||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:
|
||||||
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
|
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
|
||||||
|
|
||||||
|
# hreduce
|
||||||
|
#sink = graph_rewrite(sink, pm_mops+pm_horizontal_reduce, name="hreduce")
|
||||||
|
|
||||||
# devectorize
|
# devectorize
|
||||||
sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
|
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
|
||||||
ctx=ren, name="devectorize")
|
# ctx=ren, name="devectorize")
|
||||||
|
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
|
||||||
|
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, ctx=ren, name="devectorize2")
|
||||||
|
|
||||||
# lower the index dtype to a concrete int
|
# lower the index dtype to a concrete int
|
||||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||||
|
|
@ -113,12 +121,21 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
# optional pre matcher
|
# optional pre matcher
|
||||||
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher")
|
||||||
|
|
||||||
|
# dtypes
|
||||||
|
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
||||||
|
|
||||||
|
# memory coalesing
|
||||||
|
sink = memory_coalesing(sink)
|
||||||
|
|
||||||
|
# again
|
||||||
|
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||||
|
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||||
|
|
||||||
# decompositions
|
# decompositions
|
||||||
supported_ops = tuple(ren.code_for_op.keys())
|
supported_ops = tuple(ren.code_for_op.keys())
|
||||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
|
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
|
||||||
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
|
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions")
|
||||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
|
||||||
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
|
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
|
||||||
|
|
||||||
# GEP/STACK stuff
|
# GEP/STACK stuff
|
||||||
|
|
|
||||||
161
tinygrad/codegen/codegen2.py
Normal file
161
tinygrad/codegen/codegen2.py
Normal file
|
|
@ -0,0 +1,161 @@
|
||||||
|
from typing import Any
|
||||||
|
import itertools, functools
|
||||||
|
from tinygrad.schedule.rangeify import pm_mops
|
||||||
|
from tinygrad.codegen.simplify import pm_flatten_range
|
||||||
|
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, resolve, graph_rewrite
|
||||||
|
from tinygrad.dtype import dtypes, AddrSpace, ImageDType, Invalid
|
||||||
|
from tinygrad.helpers import all_same, flatten, getenv
|
||||||
|
from tinygrad.uop.ops import _align_left, _broadcast_shape, identity_element
|
||||||
|
from tinygrad.codegen.late.devectorizer import ReduceContext
|
||||||
|
from tinygrad.uop.symbolic import pm_clean_up_group_sink
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
|
||||||
|
pm_move_regs = PatternMatcher([
|
||||||
|
# BITCAST?
|
||||||
|
(UPat(GroupOp.Elementwise|{Ops.REDUCE}, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
|
||||||
|
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
|
||||||
|
])
|
||||||
|
|
||||||
|
pm_lower_weakints = PatternMatcher([
|
||||||
|
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def build_range_map(ctx, sink:UOp):
|
||||||
|
for x in sink.toposort():
|
||||||
|
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
|
||||||
|
ctx[x.arg[0]] = len(ctx)
|
||||||
|
|
||||||
|
def fix_reduce(ctx, r:UOp):
|
||||||
|
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
|
||||||
|
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
|
||||||
|
|
||||||
|
expander2 = PatternMatcher([
|
||||||
|
(UPat(Ops.SINK, name="sink"), build_range_map),
|
||||||
|
(UPat(Ops.REDUCE, name="r"), fix_reduce),
|
||||||
|
(UPat(Ops.RANGE, name="r"),
|
||||||
|
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
|
||||||
|
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
|
||||||
|
])+pm_flatten_range
|
||||||
|
|
||||||
|
def broadcast_binary(x:UOp):
|
||||||
|
shapes = [u.shape for u in x.src]
|
||||||
|
if all_same(shapes): return None
|
||||||
|
shaped_aligned = _align_left(*shapes)
|
||||||
|
broadcasted = _broadcast_shape(*shapes)
|
||||||
|
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
|
||||||
|
return x.replace(src=tuple(src_reshaped))
|
||||||
|
|
||||||
|
unbroadcast = PatternMatcher([
|
||||||
|
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
|
||||||
|
])
|
||||||
|
|
||||||
|
def do_devectorize(b:UOp):
|
||||||
|
if b.shape == (): return None
|
||||||
|
# broadcasting needs to be already unpacked
|
||||||
|
if not all_same([x.shape for x in b.src]): return None
|
||||||
|
src = []
|
||||||
|
for idx in itertools.product(*[range(x) for x in b.shape]):
|
||||||
|
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
|
||||||
|
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
|
||||||
|
return UOp.vectorize(*src).reshape(b.shape) if b.op is not Ops.STORE else UOp.group(*src)
|
||||||
|
|
||||||
|
devectorizer2 = pm_mops+PatternMatcher([
|
||||||
|
# unpack broadcasting
|
||||||
|
(UPat(GroupOp.Elementwise|{Ops.LOAD,Ops.STORE}, name="b"), do_devectorize),
|
||||||
|
# const INDEX into STACK is src
|
||||||
|
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
|
||||||
|
# stacked INDEX is many INDEX
|
||||||
|
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
|
||||||
|
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
|
||||||
|
# INDEX into RESHAPE moves the RESHAPE
|
||||||
|
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
|
||||||
|
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
|
||||||
|
# RESHAPE a void is removed (hack for AFTER)
|
||||||
|
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
|
||||||
|
# reshape of a single element shaped value to scalar is an index
|
||||||
|
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
|
||||||
|
# INDEX without src is nothing
|
||||||
|
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
|
||||||
|
# RESHAPE+EXPAND -> STACK
|
||||||
|
(UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(UPat.var("x"), UPat())), UPat()), name="out"),
|
||||||
|
lambda x,out: UOp.vectorize(*([x]*out.max_numel())) if out.shape == (out.max_numel(),) else None),
|
||||||
|
])
|
||||||
|
|
||||||
|
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
|
||||||
|
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
|
||||||
|
ctx.acc_num += 1
|
||||||
|
topo = r.src[0].toposort()
|
||||||
|
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
|
||||||
|
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
|
||||||
|
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
|
||||||
|
acc_initted = acc.after(acc_init, *r.src[1:])
|
||||||
|
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
|
||||||
|
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
|
||||||
|
return acc.after(acc_out)
|
||||||
|
|
||||||
|
def expand_horizontal_reduce(r:UOp):
|
||||||
|
axes = r.arg[1]
|
||||||
|
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
|
||||||
|
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
|
||||||
|
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
|
||||||
|
|
||||||
|
pm_reduce_local = PatternMatcher([
|
||||||
|
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
|
||||||
|
])+pm_clean_up_group_sink
|
||||||
|
|
||||||
|
pm_horizontal_reduce = PatternMatcher([
|
||||||
|
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
|
||||||
|
])
|
||||||
|
|
||||||
|
# *** memory coalesing ***
|
||||||
|
|
||||||
|
def memory_coalesing(sink:UOp):
|
||||||
|
if getenv("DMC"): return sink
|
||||||
|
|
||||||
|
# collect
|
||||||
|
memory: defaultdict[tuple[UOp, UOp, UOp], dict[int, list[UOp]]] = defaultdict(dict)
|
||||||
|
for u in sink.toposort():
|
||||||
|
if u.op in {Ops.LOAD, Ops.STORE} and u.src[0].addrspace != AddrSpace.REG:
|
||||||
|
assert u.src[0].op is Ops.INDEX
|
||||||
|
buf,idx_u = u.src[0].src
|
||||||
|
idx: Any = idx_u.src[1] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else idx_u
|
||||||
|
valid: Any = idx_u.src[0] if idx_u.op is Ops.WHERE and idx_u.src[2].arg is Invalid else None
|
||||||
|
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
|
||||||
|
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
||||||
|
elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
|
||||||
|
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
||||||
|
else: root_src, arg = idx, 0
|
||||||
|
memory[(u.op, buf, root_src, valid)].setdefault(arg, []).append(u)
|
||||||
|
|
||||||
|
# allowed lengths
|
||||||
|
lengths = [8,4,2,1]
|
||||||
|
|
||||||
|
# build replacements
|
||||||
|
replacements = {}
|
||||||
|
for (op,buf,base,valid),offsets in memory.items():
|
||||||
|
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
||||||
|
for full_grp in grouped_offsets:
|
||||||
|
while len(full_grp):
|
||||||
|
offset = (base+full_grp[0]) if isinstance(base, UOp) else UOp.const(dtypes.weakint, full_grp[0])
|
||||||
|
length = [l for l in lengths if l <= len(full_grp) and offset.divides(l) is not None][0]
|
||||||
|
grp = full_grp[:length]
|
||||||
|
idx = buf._mop(Ops.SHRINK, arg=[(offset, len(grp))]) if len(grp) > 1 else buf.index(offset)
|
||||||
|
if op is Ops.STORE:
|
||||||
|
datas = []
|
||||||
|
for i,g in enumerate(grp):
|
||||||
|
assert len(offsets[g]) == 1
|
||||||
|
datas.append(offsets[g][0].src[1])
|
||||||
|
data = UOp.vectorize(*datas) if len(datas) > 1 else datas[0]
|
||||||
|
store = idx.store(data, valid) if valid is not None else idx.store(data)
|
||||||
|
for i,g in enumerate(grp): replacements[offsets[g][0]] = store
|
||||||
|
else:
|
||||||
|
ld = idx.load(idx.vconst_like(0), valid) if valid is not None else idx.load()
|
||||||
|
for i,g in enumerate(grp):
|
||||||
|
for oo in offsets[g]:
|
||||||
|
replacements[oo] = ld.index(UOp.const(dtypes.int, i)) if len(grp) > 1 else ld
|
||||||
|
full_grp = full_grp[length:]
|
||||||
|
|
||||||
|
# apply
|
||||||
|
return sink.substitute(replacements, name="memory coalesing")
|
||||||
|
|
||||||
|
|
@ -275,7 +275,7 @@ SCACHE = ContextVar("SCACHE", 1)
|
||||||
# allow use of atomics for embedding backward
|
# allow use of atomics for embedding backward
|
||||||
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
||||||
# don't allow broadcast
|
# don't allow broadcast
|
||||||
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1)
|
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Metadata:
|
class Metadata:
|
||||||
|
|
|
||||||
|
|
@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
||||||
|
|
||||||
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
||||||
if len(arg) == 0: return UOp(Ops.STACK)
|
if len(arg) == 0: return UOp(Ops.STACK)
|
||||||
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
|
elif len(arg) == 1: return UOp.const(dtypes.weakint, arg[0])
|
||||||
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
||||||
|
|
||||||
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||||
ret: dict[UOp, dict[UOp, None]] = {}
|
ret: dict[UOp, dict[UOp, None]] = {}
|
||||||
|
|
@ -305,10 +305,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
|
case Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.LOAD | \
|
||||||
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
|
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
|
||||||
return self.src[0]._shape
|
return self.src[0]._shape
|
||||||
# REDUCE with empty axis is passthrough (lowered form)
|
|
||||||
case Ops.REDUCE if len(self.arg[1]) == 0:
|
|
||||||
# these can mismatch if there's a horizonal reduce
|
|
||||||
return (self.dtype.count,) if self.dtype.count > 1 else ()
|
|
||||||
|
|
||||||
# TODO: disallow shape changing bitcast
|
# TODO: disallow shape changing bitcast
|
||||||
case Ops.BITCAST:
|
case Ops.BITCAST:
|
||||||
|
|
@ -473,7 +469,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||||
def vectorize(self, *srcs):
|
def vectorize(self, *srcs):
|
||||||
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
|
return UOp(Ops.STACK, self.dtype, (self,)+srcs)
|
||||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
|
|
||||||
|
|
@ -192,8 +192,8 @@ spec_program = PatternMatcher([
|
||||||
# no more of these in programs
|
# no more of these in programs
|
||||||
(UPat(Ops.GEP), lambda: False),
|
(UPat(Ops.GEP), lambda: False),
|
||||||
|
|
||||||
# weakint is not allowed in programs
|
# weakint is not allowed in programs, except on CONST and STACK
|
||||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
(UPat(GroupOp.All-{Ops.CONST,Ops.STACK}, dtypes.weakint), lambda: False),
|
||||||
|
|
||||||
# allow special SHRINK
|
# allow special SHRINK
|
||||||
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),
|
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue