mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
39 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d319b5f614 |
||
|
|
0decf136fe |
||
|
|
a96db15419 | ||
|
|
b0f7e2a5e1 | ||
|
|
566fe4c7dc |
||
|
|
c2f73b102e |
||
|
|
c6f195f410 | ||
|
|
f26378264b |
||
|
|
1168ed9730 |
||
|
|
017edbbbb5 | ||
|
|
daa72812b0 | ||
|
|
fd325d662c |
||
|
|
8d36539656 | ||
|
|
db2c71536b | ||
|
|
4112b34a32 | ||
|
|
1ad72dff08 | ||
|
|
6f1eaa8d46 | ||
|
|
35d2882991 | ||
|
|
a31732d819 |
||
|
|
43d62c4211 | ||
|
|
4d0429090c | ||
|
|
2c7a1450e7 | ||
|
|
6ffb55cc74 |
||
|
|
1a280829ca |
||
|
|
3b426b1072 | ||
|
|
ce2cdc3708 |
||
|
|
333f062eee | ||
|
|
0d5bf3ca6d | ||
|
|
56bad940df | ||
|
|
f98deb9250 | ||
|
|
bdfcb1cb98 | ||
|
|
a6fdb53a1e |
||
|
|
49deb9714b | ||
|
|
afab220947 |
||
|
|
a7523b2596 | ||
|
|
21806848df | ||
|
|
6fda6c704d |
||
|
|
3f7ec187df | ||
|
|
af9284e9b1 |
5 changed files with 200 additions and 12 deletions
|
|
@ -1,14 +1,16 @@
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
import itertools
|
import itertools
|
||||||
|
import functools
|
||||||
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
|
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
|
||||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
|
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, all_same, flatten
|
||||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
|
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
|
||||||
|
from tinygrad.uop.ops import AxisType, _align_left, _broadcast_shape, identity_element
|
||||||
from tinygrad.uop.render import pyrender
|
from tinygrad.uop.render import pyrender
|
||||||
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
|
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
|
||||||
from tinygrad.renderer import Renderer, Estimates
|
from tinygrad.renderer import Renderer, Estimates
|
||||||
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
|
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
|
||||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||||
|
|
||||||
# import all pattern matchers here
|
# import all pattern matchers here
|
||||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||||
|
|
@ -43,6 +45,98 @@ pm_remove_vec_dtypes = PatternMatcher([
|
||||||
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
|
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
|
||||||
])+pm_clean_up_group_sink
|
])+pm_clean_up_group_sink
|
||||||
|
|
||||||
|
def maybe_load(u:UOp): return u.load() if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG) else u
|
||||||
|
pm_move_regs = PatternMatcher([
|
||||||
|
# BITCAST?
|
||||||
|
(UPat(GroupOp.Elementwise, name="x"), lambda x: x.replace(src=tuple([maybe_load(u) for u in x.src]))),
|
||||||
|
(UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]))+x.src[2:])),
|
||||||
|
])
|
||||||
|
|
||||||
|
pm_lower_weakints = PatternMatcher([
|
||||||
|
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def build_range_map(ctx, sink:UOp):
|
||||||
|
for x in sink.toposort():
|
||||||
|
if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
|
||||||
|
ctx[x.arg[0]] = len(ctx)
|
||||||
|
|
||||||
|
def fix_reduce(ctx, r:UOp):
|
||||||
|
range_to_axis = {u:ctx[u.arg[0]] for u in r.ended_ranges if u.arg[0] in ctx if u.arg[1] == AxisType.UNROLL}
|
||||||
|
return r.replace(src=tuple([u for u in r.src if u not in range_to_axis]), arg=(r.arg[0], r.arg[1]+tuple(range_to_axis.values())))
|
||||||
|
|
||||||
|
expander2 = PatternMatcher([
|
||||||
|
(UPat(Ops.SINK, name="sink"), build_range_map),
|
||||||
|
(UPat(Ops.REDUCE, name="r"), fix_reduce),
|
||||||
|
(UPat(Ops.RANGE, name="r"),
|
||||||
|
lambda ctx, r: UOp.const(r.dtype, tuple(range(r.vmax+1))) \
|
||||||
|
.reshape(tuple([r.vmax+1 if i == ctx[r.arg[0]] else 1 for i in range(len(ctx))])) if r.arg[0] in ctx else None),
|
||||||
|
])+pm_flatten_range
|
||||||
|
|
||||||
|
def broadcast_binary(x:UOp):
|
||||||
|
shapes = [u.shape for u in x.src]
|
||||||
|
if all_same(shapes): return None
|
||||||
|
shaped_aligned = _align_left(*shapes)
|
||||||
|
broadcasted = _broadcast_shape(*shapes)
|
||||||
|
src_reshaped = [u.reshape(shp).expand(broadcasted) for u,shp in zip(x.src, shaped_aligned)]
|
||||||
|
return x.replace(src=tuple(src_reshaped))
|
||||||
|
|
||||||
|
unbroadcast = PatternMatcher([
|
||||||
|
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
|
||||||
|
])
|
||||||
|
|
||||||
|
def do_devectorize(b:UOp):
|
||||||
|
if b.shape == (): return None
|
||||||
|
# broadcasting needs to be already unpacked
|
||||||
|
if not all_same([x.shape for x in b.src]): return None
|
||||||
|
src = []
|
||||||
|
for idx in itertools.product(*[range(x) for x in b.shape]):
|
||||||
|
idx_c = [UOp.const(dtypes.weakint, i) for i in idx]
|
||||||
|
src.append(b.replace(src=tuple([x.index(*idx_c) for x in b.src])))
|
||||||
|
return UOp.vectorize(*src).reshape(b.shape)
|
||||||
|
|
||||||
|
devectorizer2 = pm_mops+PatternMatcher([
|
||||||
|
# unpack broadcasting
|
||||||
|
(UPat(GroupOp.Elementwise|{Ops.LOAD, Ops.STORE}, name="b"), do_devectorize),
|
||||||
|
# INDEX into STACK is src
|
||||||
|
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
|
||||||
|
# stacked INDEX is many INDEX
|
||||||
|
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
|
||||||
|
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
|
||||||
|
# INDEX into RESHAPE moves the RESHAPE
|
||||||
|
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
|
||||||
|
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
|
||||||
|
# RESHAPE a void is removed (hack for AFTER)
|
||||||
|
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
|
||||||
|
# reshape of a single element shaped value to scalar is an index
|
||||||
|
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
|
||||||
|
# INDEX without src is nothing
|
||||||
|
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
|
||||||
|
])
|
||||||
|
|
||||||
|
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
|
||||||
|
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
|
||||||
|
ctx.acc_num += 1
|
||||||
|
topo = r.src[0].toposort()
|
||||||
|
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
|
||||||
|
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
|
||||||
|
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
|
||||||
|
acc_initted = acc.after(acc_init, *r.src[1:])
|
||||||
|
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
|
||||||
|
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
|
||||||
|
return acc.after(acc_out)
|
||||||
|
|
||||||
|
def expand_horizontal_reduce(r:UOp):
|
||||||
|
axes = r.arg[1]
|
||||||
|
vals = [r.src[0].shrink(tuple((idx[axes.index(i)], idx[axes.index(i)]+1) if i in axes else None for i in range(r.src[0].ndim)))
|
||||||
|
for idx in itertools.product(*[range(r.src[0].max_shape[a]) for a in axes])]
|
||||||
|
return functools.reduce(lambda x,y: x.alu(r.arg[0], y), vals)
|
||||||
|
|
||||||
|
pm_reduce_local = PatternMatcher([
|
||||||
|
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
|
||||||
|
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
|
||||||
|
])+pm_clean_up_group_sink
|
||||||
|
|
||||||
def do_number_param(ctx:list[int], x:UOp):
|
def do_number_param(ctx:list[int], x:UOp):
|
||||||
if x.arg.slot != -1: return None
|
if x.arg.slot != -1: return None
|
||||||
ctx[0] += 1
|
ctx[0] += 1
|
||||||
|
|
@ -56,6 +150,90 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||||
if DEBUG >= 5: print(pyrender(ast))
|
if DEBUG >= 5: print(pyrender(ast))
|
||||||
if SPEC: type_verify(ast, spec_tensor)
|
if SPEC: type_verify(ast, spec_tensor)
|
||||||
|
sink = ast
|
||||||
|
|
||||||
|
# preprocess. we need to simplify these
|
||||||
|
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
|
||||||
|
|
||||||
|
# this is new style
|
||||||
|
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||||
|
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
|
||||||
|
|
||||||
|
# first we optimize
|
||||||
|
if optimize:
|
||||||
|
# do postrange optimization, BEAM or hand_coded_optimizations
|
||||||
|
sink = apply_opts(sink, ren, beam=ast.arg.beam)
|
||||||
|
|
||||||
|
# do expander
|
||||||
|
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
|
||||||
|
|
||||||
|
# add locals (STAGE -> BUFFER)
|
||||||
|
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
|
||||||
|
|
||||||
|
# rewrite reduce after optimizations
|
||||||
|
sink = graph_rewrite(sink, pm_reduce_local, ctx=ReduceContext(), name="remove_reduce")
|
||||||
|
|
||||||
|
# add gpu dims
|
||||||
|
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
||||||
|
|
||||||
|
# add loads
|
||||||
|
sink = graph_rewrite(sink, pm_move_regs, name="move to registers", walk=True)
|
||||||
|
|
||||||
|
# symbolic (note: this does POW decomp)
|
||||||
|
sink = graph_rewrite(sink, sym, name="post index symbolic")
|
||||||
|
|
||||||
|
# ***** make it rendererable (within spec, tighten) *****
|
||||||
|
|
||||||
|
# decompositions
|
||||||
|
supported_ops = tuple(ren.code_for_op.keys())
|
||||||
|
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
|
||||||
|
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||||
|
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="*** decompositions")
|
||||||
|
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
|
||||||
|
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
|
||||||
|
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions more")
|
||||||
|
|
||||||
|
# split ends
|
||||||
|
sink = graph_rewrite(sink, pm_split_ends, name="split ends")
|
||||||
|
|
||||||
|
# this was the linearizer
|
||||||
|
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||||
|
|
||||||
|
# ***** this is where it gets large *****
|
||||||
|
|
||||||
|
# unbroadcast
|
||||||
|
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
|
||||||
|
|
||||||
|
# devectorizer
|
||||||
|
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, name="devectorizer")
|
||||||
|
|
||||||
|
# ***** make it rendererable (outside spec, transform) *****
|
||||||
|
|
||||||
|
# final symbolic
|
||||||
|
sink = graph_rewrite(sink, sym, name="post devectorizer sym")
|
||||||
|
|
||||||
|
# move gates from unrenderable INVALID where
|
||||||
|
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
|
||||||
|
|
||||||
|
# put registers in slots
|
||||||
|
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot != -1])
|
||||||
|
name_to_slot = {x:x.replace(arg=replace(x.arg, slot=num_params+i))
|
||||||
|
for i,x in enumerate(sorted([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot == -1]))}
|
||||||
|
sink = sink.substitute(name_to_slot, name="put variables in slots")
|
||||||
|
|
||||||
|
# remove all weakints
|
||||||
|
sink = graph_rewrite(sink, pm_lower_weakints, name="lower weakints", bottom_up=True)
|
||||||
|
|
||||||
|
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST")
|
||||||
|
if SPEC: type_verify(sink, spec_program)
|
||||||
|
|
||||||
|
# return the rewritten sink
|
||||||
|
return sink
|
||||||
|
|
||||||
|
def old_full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||||
|
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||||
|
if DEBUG >= 5: print(pyrender(ast))
|
||||||
|
if SPEC: type_verify(ast, spec_tensor)
|
||||||
|
|
||||||
# preprocess
|
# preprocess
|
||||||
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
|
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
|
||||||
|
|
|
||||||
|
|
@ -274,7 +274,7 @@ SCACHE = ContextVar("SCACHE", 1)
|
||||||
# allow use of atomics for embedding backward
|
# allow use of atomics for embedding backward
|
||||||
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
||||||
# don't allow broadcast
|
# don't allow broadcast
|
||||||
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1)
|
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Metadata:
|
class Metadata:
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,9 @@ def fold_divmod_general(d: UOp) -> UOp|None:
|
||||||
# divide_by_gcd: x//y -> (x//gcd)//(y//gcd)
|
# divide_by_gcd: x//y -> (x//gcd)//(y//gcd)
|
||||||
gcd = UOp.gcd(*all_uops, y).simplify()
|
gcd = UOp.gcd(*all_uops, y).simplify()
|
||||||
if not (gcd.op is Ops.CONST and gcd.arg==1):
|
if not (gcd.op is Ops.CONST and gcd.arg==1):
|
||||||
ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd)))
|
x_div, y_div = x.divide_exact(gcd), y.divide_exact(gcd)
|
||||||
|
if x_div is None or y_div is None: return None
|
||||||
|
ret = x_div.alu(d.op, y_div)
|
||||||
return ret*gcd if d.op is Ops.FLOORMOD else ret
|
return ret*gcd if d.op is Ops.FLOORMOD else ret
|
||||||
|
|
||||||
# factor_remainder: (d*x+y)//d -> x+y//d
|
# factor_remainder: (d*x+y)//d -> x+y//d
|
||||||
|
|
|
||||||
|
|
@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
||||||
|
|
||||||
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
||||||
if len(arg) == 0: return UOp(Ops.STACK)
|
if len(arg) == 0: return UOp(Ops.STACK)
|
||||||
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg)
|
elif all_int(arg): return UOp.const(dtypes.weakint, arg)
|
||||||
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
|
||||||
|
|
||||||
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||||
ret: dict[UOp, dict[UOp, None]] = {}
|
ret: dict[UOp, dict[UOp, None]] = {}
|
||||||
|
|
@ -306,9 +306,10 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
|
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
|
||||||
return self.src[0]._shape
|
return self.src[0]._shape
|
||||||
# REDUCE with empty axis is passthrough (lowered form)
|
# REDUCE with empty axis is passthrough (lowered form)
|
||||||
case Ops.REDUCE if len(self.arg[1]) == 0:
|
# no longer true
|
||||||
|
#case Ops.REDUCE if len(self.arg[1]) == 0:
|
||||||
# these can mismatch if there's a horizonal reduce
|
# these can mismatch if there's a horizonal reduce
|
||||||
return (self.dtype.count,) if self.dtype.count > 1 else ()
|
#return (self.dtype.count,) if self.dtype.count > 1 else ()
|
||||||
|
|
||||||
# TODO: disallow shape changing bitcast
|
# TODO: disallow shape changing bitcast
|
||||||
case Ops.BITCAST:
|
case Ops.BITCAST:
|
||||||
|
|
@ -473,12 +474,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||||
def vectorize(self, *srcs):
|
def vectorize(self, *srcs):
|
||||||
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs)
|
return UOp(Ops.STACK, self.dtype, (self,)+srcs)
|
||||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
# pointers index into INDEX UOps (scalar lookup); everything else uses the shared mixin view path
|
# pointers index into INDEX UOps (scalar lookup); everything else uses the shared mixin view path
|
||||||
if not isinstance(self.dtype, PtrDType): return super(UOp, self).__getitem__(idx)
|
#if not isinstance(self.dtype, PtrDType): return super(UOp, self).__getitem__(idx)
|
||||||
idx = self._normalize_indices(list(argfix(idx)))
|
idx = self._normalize_indices(list(argfix(idx)))
|
||||||
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
|
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
|
||||||
# apply SHRINK for slices that aren't the full range
|
# apply SHRINK for slices that aren't the full range
|
||||||
|
|
@ -1682,7 +1683,7 @@ pm_lower_index_dtype = PatternMatcher([
|
||||||
UPat.var("gate").where(UPat.var("idx_x", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
|
UPat.var("gate").where(UPat.var("idx_x", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
|
||||||
lambda buf,idx_x,idx_y,gate: buf.index(gate.where(idx_y, idx_y.const_like(Invalid)),
|
lambda buf,idx_x,idx_y,gate: buf.index(gate.where(idx_y, idx_y.const_like(Invalid)),
|
||||||
gate.where(idx_x, idx_x.const_like(Invalid)), ptr=True)),
|
gate.where(idx_x, idx_x.const_like(Invalid)), ptr=True)),
|
||||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
(UPat((Ops.SINK, Ops.NOOP, Ops.END, Ops.AFTER, Ops.BUFFER), name="n"),
|
||||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
||||||
])
|
])
|
||||||
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,14 @@ def const_arg(u:UOp) -> ConstType|tuple[ConstType, ...]|None:
|
||||||
|
|
||||||
def fold_const_alu(a:UOp) -> UOp|None:
|
def fold_const_alu(a:UOp) -> UOp|None:
|
||||||
vals = [const_arg(s) for s in a.src]
|
vals = [const_arg(s) for s in a.src]
|
||||||
return None if any(v is None for v in vals) else a.const_like(exec_alu(a.op, a.dtype, vals, False))
|
if any(v is None for v in vals): return None
|
||||||
|
if any(isinstance(v, tuple) for v in vals):
|
||||||
|
out_len = prod(a.shape)
|
||||||
|
if not all(not isinstance(v, tuple) or len(v) in {1, out_len} for v in vals): return None
|
||||||
|
return a.const_like(tuple(exec_alu(a.op, a.dtype.scalar(),
|
||||||
|
[v[0] if isinstance(v, tuple) and len(v) == 1 else v[i] if isinstance(v, tuple) else v for v in vals], False)
|
||||||
|
for i in range(out_len)))
|
||||||
|
return a.const_like(exec_alu(a.op, a.dtype, vals, False))
|
||||||
|
|
||||||
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
|
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
|
||||||
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
|
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue