mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
17 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f597ec8d5f |
||
|
|
4524c78e87 | ||
|
|
651031e3da |
||
|
|
79ad8f3ef4 |
||
|
|
c54151235e | ||
|
|
8a25d18135 |
||
|
|
4e3c3e66b0 | ||
|
|
474ec6e44d | ||
|
|
7274239315 | ||
|
|
3378626308 | ||
|
|
9e408e239f | ||
|
|
b7e3b92c54 | ||
|
|
96ad6f05bf |
||
|
|
f3e7efd39e | ||
|
|
a263eca378 |
||
|
|
fede39db53 | ||
|
|
0e90a46ae0 |
9 changed files with 169 additions and 16 deletions
|
|
@ -375,7 +375,7 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32,
|
|||
"""Conditional memory store with sub-word support. Returns list of store UOps."""
|
||||
adt = dtypes.uint64 if addr_bits == 64 else dtypes.uint32
|
||||
word_addr = addr >> UOp.const(adt, 2)
|
||||
idx = mem.index(word_addr.cast(dtypes.int).valid(active))
|
||||
idx = mem.index(word_addr.valid(active))
|
||||
if data_bits == 32: return [idx.store(active.where(_to_u32(val), idx))]
|
||||
# Sub-word store: read-modify-write with mask
|
||||
byte_pos = addr.cast(dtypes.uint32) & _c(3)
|
||||
|
|
@ -388,7 +388,7 @@ def _mem_store(mem: UOp, addr: UOp, val: UOp, active: UOp, addr_bits: int = 32,
|
|||
is_cross = byte_pos.eq(_c(3))
|
||||
cross_word0 = (idx & _c(0x00FFFFFF)) | ((val_u32 & _c(0xFF)) << _c(24))
|
||||
store0 = idx.store(active.where(is_cross.where(cross_word0, new_word), idx))
|
||||
next_idx = mem.index((word_addr + UOp.const(adt, 1)).cast(dtypes.int).valid(active & is_cross))
|
||||
next_idx = mem.index((word_addr + UOp.const(adt, 1)).valid(active & is_cross))
|
||||
cross_word1 = (next_idx & _c(0xFFFFFF00)) | ((val_u32 >> _c(8)) & _c(0xFF))
|
||||
return [store0, next_idx.store((active & is_cross).where(cross_word1, next_idx))]
|
||||
|
||||
|
|
@ -398,7 +398,7 @@ def _mem_store_bytes(mem: UOp, addr: UOp, val: UOp, active: UOp, data_bits: int
|
|||
val_u32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
|
||||
for i in range(data_bits // 8):
|
||||
byte_val = (val_u32 >> UOp.const(dtypes.uint32, i * 8)) & UOp.const(dtypes.uint32, 0xFF)
|
||||
stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).cast(dtypes.int).valid(active)).store(byte_val.cast(dtypes.uint8)))
|
||||
stores.append(mem.index((addr + UOp.const(dtypes.uint64, i)).valid(active)).store(byte_val.cast(dtypes.uint8)))
|
||||
return stores
|
||||
|
||||
def _collect_data_slices(assigns: list[tuple[str, UOp]], data_prefix: str, pcode_vars: dict | None = None, op_name: str = "") -> dict[int, UOp]:
|
||||
|
|
@ -463,7 +463,7 @@ class _Ctx:
|
|||
"""Read instruction dword from vmem at PC + dword_idx*4."""
|
||||
pc = self.rpc()
|
||||
addr = pc if dword_idx == 0 else pc + UOp.const(dtypes.uint64, dword_idx * 4)
|
||||
return self.vmem.index((addr >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int), ptr=True).load()
|
||||
return self.vmem.index(addr >> UOp.const(dtypes.uint64, 2), ptr=True).load()
|
||||
|
||||
def inst_field(self, field) -> UOp:
|
||||
"""Extract field bits from instruction encoding. Tracks field for canonical key computation."""
|
||||
|
|
@ -516,8 +516,8 @@ class _Ctx:
|
|||
# Dynamic register access (takes UOp index instead of int)
|
||||
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read SGPR with dynamic register index."""
|
||||
if valid is not None: return self.sgpr.index(reg.cast(dtypes.int).valid(valid), ptr=True).load()
|
||||
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
|
||||
if valid is not None: return self.sgpr.index(reg.valid(valid), ptr=True).load()
|
||||
return self.sgpr.index(reg, ptr=True).load()
|
||||
|
||||
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
|
||||
"""Write SGPR with dynamic register index. On RDNA, index 124 = NULL (writes discarded). On CDNA, index 124 = M0 (read/write)."""
|
||||
|
|
@ -833,7 +833,7 @@ def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
|
|||
nval = int(part.removeprefix('DWORD').removeprefix('X') or '1') if 'DWORD' in part else int(part[1:]) / 32 * (-1 if part[0] == 'I' else 1)
|
||||
ndwords = max(1, int(abs(nval)))
|
||||
dword_base = addr >> UOp.const(dtypes.uint64, 2)
|
||||
vals = [ctx.vmem.index((dword_base + UOp.const(dtypes.uint64, i)).cast(dtypes.int)) for i in range(ndwords)]
|
||||
vals = [ctx.vmem.index(dword_base + UOp.const(dtypes.uint64, i)) for i in range(ndwords)]
|
||||
if abs(nval) < 1:
|
||||
nbits = int(abs(nval) * 32)
|
||||
byte_off = (addr & UOp.const(dtypes.uint64, 3)).cast(dtypes.uint32) * UOp.const(dtypes.uint32, 8)
|
||||
|
|
@ -1847,7 +1847,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
|
|||
def wmem(addr: UOp, val: UOp, active: UOp, data_bits: int = 32) -> UOp:
|
||||
if data_bits < 32:
|
||||
# Sub-dword LDS write: read-modify-write within the uint32 slot
|
||||
word_addr = (addr >> addr_shift).cast(dtypes.int)
|
||||
word_addr = addr >> addr_shift
|
||||
idx = mem.index(word_addr.valid(active))
|
||||
byte_pos = addr.cast(dtypes.uint32) & _c(3)
|
||||
byte_shift = byte_pos * _c(8)
|
||||
|
|
@ -1855,7 +1855,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
|
|||
mask = size_mask << byte_shift
|
||||
new_word = (idx & (mask ^ _c(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & size_mask) << byte_shift)
|
||||
return idx.store(active.where(new_word, idx))
|
||||
idx = mem.index((addr >> addr_shift).cast(dtypes.int))
|
||||
idx = mem.index(addr >> addr_shift)
|
||||
return idx.store(active.where(val, idx.load()))
|
||||
|
||||
def make_srcs(lane: UOp) -> dict:
|
||||
|
|
@ -2005,7 +2005,7 @@ def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp:
|
|||
for i in range(n_dwords):
|
||||
word_addr = (addr + UOp.const(dtypes.uint64, i * 4)) >> UOp.const(dtypes.uint64, 2)
|
||||
val = in_bounds.where(mem.index(word_addr.cast(dtypes.int64), ptr=True).load(), _c(0))
|
||||
lds_idx = ((lds_addr + _c(i * 4)) >> _c(2)).cast(dtypes.int)
|
||||
lds_idx = (lds_addr + _c(i * 4)) >> _c(2)
|
||||
lds_slot = ctx.lds.index(lds_idx.valid(active))
|
||||
stores.append(lds_slot.store(active.where(val, lds_slot)))
|
||||
elif is_store:
|
||||
|
|
|
|||
|
|
@ -845,10 +845,10 @@ class Parser:
|
|||
val = _u32(0)
|
||||
for i in range(4): val = val | (mindex(idx + _const(dtypes.int, i), ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
|
||||
else:
|
||||
idx = (addr >> _const(addr.dtype, 2)).cast(dtypes.int)
|
||||
idx = (addr >> _const(addr.dtype, 2))
|
||||
val = mindex(idx)
|
||||
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)
|
||||
idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2))
|
||||
val = val.cast(dtypes.uint64) | (mindex(idx2).cast(dtypes.uint64) << _u64(32))
|
||||
elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
|
||||
elif dt in (dtypes.uint16, dtypes.int16):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import cast
|
|||
from dataclasses import replace
|
||||
import itertools
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
|
||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
|
||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, MINIGEN
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo
|
||||
from tinygrad.uop.render import pyrender
|
||||
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
|
||||
|
|
@ -22,6 +22,7 @@ from tinygrad.codegen.late.gater import pm_move_gates_from_index
|
|||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.codegen.minigen import minigen_to_sink
|
||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||
|
||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
|
|
@ -29,6 +30,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
if DEBUG >= 5: print(pyrender(ast))
|
||||
if SPEC: type_verify(ast, spec_tensor)
|
||||
|
||||
# mini codegen
|
||||
if MINIGEN: return minigen_to_sink(ast, ren, optimize)
|
||||
|
||||
# preprocess
|
||||
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
|
||||
|
||||
|
|
|
|||
142
tinygrad/codegen/minigen.py
Normal file
142
tinygrad/codegen/minigen.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
from dataclasses import dataclass
|
||||
from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, GroupOp, UPat, Ops, AxisType
|
||||
from tinygrad.uop.spec import type_verify, spec_program
|
||||
from tinygrad.dtype import dtypes, Invalid, AddrSpace, least_upper_dtype
|
||||
from tinygrad.helpers import SPEC
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.late.devectorizer import pm_add_loads, reduce_to_acc, merge_reduce_ends, pm_render
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow
|
||||
from tinygrad.schedule.rangeify import pm_index_on_index, pm_mops
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.codegen.late.expander import pm_group_for_reduce
|
||||
|
||||
def lower_weakint(x:UOp):
|
||||
# no sources, it's an int
|
||||
if len(x.src) == 0: return x.replace(dtype=dtypes.int)
|
||||
return x.replace(dtype=least_upper_dtype(dtypes.int, *[u.dtype for u in x.src]))
|
||||
|
||||
pm_lower_weakint = PatternMatcher([
|
||||
(UPat(GroupOp.All, dtypes.weakint, name="x"), lower_weakint),
|
||||
])
|
||||
|
||||
@dataclass
|
||||
class ReduceContext:
|
||||
acc_num: int = 0
|
||||
local_num: int = 0
|
||||
|
||||
def stage_to_local(ctx:ReduceContext, x:UOp):
|
||||
# TODO: addrspace shouldn't be on dtype
|
||||
ret = UOp(Ops.DEFINE_LOCAL, x.dtype.ptr(size=x.numel(), addrspace=AddrSpace.LOCAL), (), arg=ctx.local_num).reshape(*[u.vmax+1 for u in x.src[1:]])
|
||||
ctx.local_num += 1
|
||||
return ret.after(ret.index(*x.src[1:]).store(x.src[0]).end(*x.src[1:]))
|
||||
|
||||
pm_minimal_reduce = PatternMatcher([
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
|
||||
(UPat(Ops.STAGE, name="x"), stage_to_local),
|
||||
|
||||
# TODO: this should not be here! this is caused by pm_load_collapse
|
||||
(UPat(Ops.SINK, name="sink"), merge_reduce_ends),
|
||||
])
|
||||
|
||||
pm_minimal_move_gates_from_index = PatternMatcher([
|
||||
# here we create the alt value for load to be 0s and remove the where Invalid
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
|
||||
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
|
||||
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
|
||||
])
|
||||
|
||||
@dataclass
|
||||
class ExpanderState:
|
||||
total_dim: int
|
||||
current_dim: int = 0
|
||||
|
||||
def expand_range(ctx:ExpanderState, r:UOp):
|
||||
if r.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}: return None
|
||||
ret = UOp.const(r.dtype, tuple(range(r.vmax+1)))
|
||||
ret = ret.reshape([-1 if x == ctx.current_dim else 1 for x in range(ctx.total_dim)])
|
||||
ctx.current_dim += 1
|
||||
return ret
|
||||
|
||||
pm_mini_expander = PatternMatcher([
|
||||
(UPat(Ops.RANGE, name="r"), expand_range),
|
||||
])
|
||||
|
||||
def minigen_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
sink = ast
|
||||
|
||||
if optimize:
|
||||
# collapse loads reduce (indexing by a tensor), must run while REDUCE is still in the graph
|
||||
sink = graph_rewrite(sink, pm_load_collapse, name="load collapse")
|
||||
|
||||
# split ranges
|
||||
sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
|
||||
|
||||
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
|
||||
sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
|
||||
|
||||
# optimize (schedule) the AST
|
||||
sink = graph_rewrite(sink, pm_flatten_range+pm_simplify_ranges, ctx={}, name="simplify ranges")
|
||||
|
||||
# do postrange optimization, BEAM or hand_coded_optimizations
|
||||
sink = apply_opts(sink, ren, beam=ast.arg.beam)
|
||||
|
||||
# expanded
|
||||
rcount = sum([1 for u in sink.toposort() if u.op is Ops.RANGE and u.arg[-1] in {AxisType.UPCAST, AxisType.UNROLL}])
|
||||
sink = graph_rewrite(sink, pm_mini_expander+pm_group_for_reduce, ctx=ExpanderState(rcount), name="expander 2.0")
|
||||
|
||||
# REDUCE is not allowed in programs, we need to do that REDUCE somewhere
|
||||
# this creates a register where the reduce happens
|
||||
# STAGE is also not allowed in programs, this is similar to REDUCE
|
||||
sink = graph_rewrite(sink, pm_minimal_reduce, ctx=ReduceContext(), name="remove reduce/stage")
|
||||
|
||||
# if there's any movement ops left, we need to remove them. removing stage might add movement ops
|
||||
# also handle INDEX on INDEX
|
||||
sink = graph_rewrite(sink, pm_index_on_index+pm_mops, name="remove movement ops")
|
||||
|
||||
# do single symbolic (this rewrites POW)
|
||||
sink = graph_rewrite(sink, sym, name="symbolic")
|
||||
|
||||
# add gpu dims (late). this works after devectorize, but it's faster here
|
||||
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
||||
|
||||
# we need to add loads
|
||||
# this is really a store to DEFINE_REG, but load is simpler
|
||||
# LOAD(DATA) is anonymous store -> AFTER(anon_buf, STORE(anon_buf, DATA))
|
||||
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
|
||||
|
||||
# we need to lower weakint for the program
|
||||
# this will be simpler when we have implicit dtype
|
||||
sink = graph_rewrite(sink, pm_lower_weakint, name="remove weakint")
|
||||
|
||||
# move gates from INDEX to LOAD/STORE (Invalid isn't renderable)
|
||||
sink = graph_rewrite(sink, pm_minimal_move_gates_from_index, name="move gates")
|
||||
|
||||
# split ENDs, renderable ENDs can only end one RANGE
|
||||
sink = graph_rewrite(sink, pm_split_ends, name="split ends")
|
||||
|
||||
# **** enter decanonicalize *****
|
||||
|
||||
# decompose dtypes we don't support into renderable versions
|
||||
# NOTE: this adds Ops.FLOORDIV, so it needs to come before decompose ops
|
||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren.target), name="decomp dtypes")
|
||||
|
||||
# decompose ops like SIN/THREEFRY into renderable versions
|
||||
supported_ops = tuple(ren.code_for_op.keys())
|
||||
pm_decomp = get_late_rewrite_patterns(supported_ops, disable_fast_idiv=True) + \
|
||||
get_transcendental_patterns(supported_ops, force_transcendental=False)
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren.target, name="decompose ops to renderable")
|
||||
|
||||
# extra matcher from the renderer
|
||||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||
sink = graph_rewrite(sink, pm_render+extra_matcher, ctx=ren.target, name="final rewrite")
|
||||
|
||||
# this was the linearizer, add control flow edges where they are needed on RANGEs
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
if SPEC: type_verify(sink, spec_program)
|
||||
return sink
|
||||
|
|
@ -91,7 +91,7 @@ class DType(metaclass=DTypeMetaClass):
|
|||
return float("inf") if dtypes.is_float(self) else True
|
||||
def const(self, val: tuple[ConstType, ...]|ConstType):
|
||||
if isinstance(val, tuple):
|
||||
assert len(val) == self.count, f"mismatch {val} {self}"
|
||||
#assert len(val) == self.count, f"mismatch {val} {self}"
|
||||
return tuple(map(self.const, val))
|
||||
if isinstance(val, InvalidType): return val
|
||||
# NOTE: float('nan') != float('nan'), so we canonicalize here
|
||||
|
|
|
|||
|
|
@ -268,6 +268,8 @@ ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
|
|||
SCACHE = ContextVar("SCACHE", 1)
|
||||
# allow use of atomics for embedding backward
|
||||
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
||||
# enable mini codegen
|
||||
MINIGEN = ContextVar("MINIGEN", 1)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Metadata:
|
||||
|
|
|
|||
|
|
@ -37,10 +37,13 @@ pm_store_ranges = PatternMatcher([
|
|||
(UPat(Ops.STORE, name="x"), add_ranges_to_store),
|
||||
])
|
||||
|
||||
pm_syntactic_sugar = PatternMatcher([
|
||||
pm_index_on_index = PatternMatcher([
|
||||
# INDEX on ptr INDEX concats them
|
||||
(UPat(Ops.INDEX, name="i1").f(Ops.INDEX, name="i2", allow_any_len=True),
|
||||
lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
|
||||
])
|
||||
|
||||
pm_syntactic_sugar = pm_index_on_index+PatternMatcher([
|
||||
# early rangeify
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise | {Ops.CONST}, name="x"),), allow_any_len=True, name="idx"),
|
||||
lambda idx,x: x.replace(src=tuple([s.index(*idx.src[1:]) for s in x.src]))),
|
||||
|
|
|
|||
|
|
@ -247,6 +247,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return tuple(shp) + self.src[0].shape[len(self.src[1:]):]
|
||||
|
||||
# TODO: these should have the shape of the dtype.count
|
||||
case Ops.VCONST: return (len(self.arg),)
|
||||
case Ops.CONST | Ops.DEFINE_VAR: return ()
|
||||
case Ops.GEP | Ops.STACK | Ops.VCAT | Ops.GETADDR: return ()
|
||||
|
||||
|
|
|
|||
|
|
@ -205,7 +205,8 @@ spec_program = PatternMatcher([
|
|||
(UPat(Ops.CONST, arg=Invalid), lambda: False),
|
||||
|
||||
# STACK/GEP in program. TODO: this should match Tensor
|
||||
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
#(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat(Ops.STACK, name="x"), lambda x: True),
|
||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
|
||||
# if has a <gate, index_for_dedup>
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue