Compare commits

...

39 commits

Author SHA1 Message Date
George Hotz
d319b5f614
Merge branch 'master' into codegen2 2026-06-21 19:18:24 -07:00
George Hotz
0decf136fe
Merge branch 'master' into codegen2 2026-06-19 18:29:13 -07:00
George Hotz
a96db15419 fixes 2026-06-19 17:57:59 -07:00
George Hotz
b0f7e2a5e1 fix imports 2026-06-19 17:44:39 -07:00
George Hotz
566fe4c7dc
Merge branch 'master' into codegen2 2026-06-19 16:57:09 -07:00
George Hotz
c2f73b102e
Merge branch 'master' into codegen2 2026-06-19 12:39:00 -07:00
George Hotz
c6f195f410 fixes 2026-06-18 18:08:14 -07:00
George Hotz
f26378264b
Merge branch 'master' into codegen2 2026-06-18 18:03:15 -07:00
George Hotz
1168ed9730
Merge branch 'master' into codegen2 2026-06-17 00:37:09 -07:00
George Hotz
017edbbbb5 param -1 2026-06-16 21:52:07 -07:00
George Hotz
daa72812b0 add gpu dims 2026-06-16 21:37:59 -07:00
George Hotz
fd325d662c
Merge branch 'master' into codegen2 2026-06-16 21:29:09 -07:00
George Hotz
8d36539656 test tiny passes 2026-06-16 14:57:12 -07:00
George Hotz
db2c71536b almost passing 2026-06-16 13:29:38 -07:00
George Hotz
4112b34a32 closer 2026-06-16 13:23:30 -07:00
George Hotz
1ad72dff08 more passing 2026-06-16 12:54:54 -07:00
George Hotz
6f1eaa8d46 fixes 2026-06-16 12:38:17 -07:00
George Hotz
35d2882991 no vec 2026-06-16 10:47:19 -07:00
George Hotz
a31732d819
Merge branch 'master' into codegen2 2026-06-16 10:33:34 -07:00
George Hotz
43d62c4211 hreduce 2026-06-16 09:36:47 -07:00
George Hotz
4d0429090c split reduce types 2026-06-16 09:27:09 -07:00
George Hotz
2c7a1450e7 fix reduce 2026-06-16 08:40:00 -07:00
George Hotz
6ffb55cc74
Merge branch 'master' into codegen2 2026-06-15 17:19:25 -07:00
George Hotz
1a280829ca
Merge branch 'master' into codegen2 2026-06-15 12:48:46 -07:00
George Hotz
3b426b1072 devec 2026-06-15 08:57:52 -07:00
George Hotz
ce2cdc3708
Merge branch 'master' into codegen2 2026-06-14 16:43:48 -07:00
George Hotz
333f062eee new expander 2026-06-14 13:54:13 -07:00
George Hotz
0d5bf3ca6d revert that 2026-06-14 13:28:28 -07:00
George Hotz
56bad940df disable that 2026-06-14 13:28:02 -07:00
George Hotz
f98deb9250 preprocess 2026-06-14 13:24:19 -07:00
George Hotz
bdfcb1cb98 test ops passes 2026-06-14 12:58:18 -07:00
George Hotz
a6fdb53a1e
Merge branch 'master' into codegen2 2026-06-14 10:09:00 -07:00
George Hotz
49deb9714b test_tiny passes 2026-06-14 09:36:51 -07:00
George Hotz
afab220947
Merge branch 'master' into codegen2 2026-06-14 08:52:36 -07:00
George Hotz
a7523b2596 simpler 2026-06-13 10:40:52 -07:00
George Hotz
21806848df improve new codegen 2026-06-12 20:08:20 -07:00
George Hotz
6fda6c704d
Merge branch 'master' into codegen2 2026-06-12 20:01:43 -07:00
George Hotz
3f7ec187df work 2026-06-12 19:24:56 -07:00
George Hotz
af9284e9b1 try for a full rewrite of codegen 2026-06-12 19:11:54 -07:00
5 changed files with 200 additions and 12 deletions

View file

@ -1,14 +1,16 @@
from typing import cast from typing import cast
from dataclasses import replace from dataclasses import replace
import itertools import itertools
import functools
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, all_same, flatten
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
from tinygrad.uop.ops import AxisType, _align_left, _broadcast_shape, identity_element
from tinygrad.uop.render import pyrender from tinygrad.uop.render import pyrender
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
from tinygrad.renderer import Renderer, Estimates from tinygrad.renderer import Renderer, Estimates
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
from tinygrad.dtype import dtypes, PtrDType, ImageDType from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
# import all pattern matchers here # import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.codegen.gpudims import pm_add_gpudims
@ -43,6 +45,98 @@ pm_remove_vec_dtypes = PatternMatcher([
lambda x: x.replace(dtype=x.dtype.base.scalar().base)), lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
])+pm_clean_up_group_sink ])+pm_clean_up_group_sink
def maybe_load(u:UOp):
  # only values whose addrspace is GLOBAL, LOCAL or REG are wrapped in a load; everything else is returned as-is
  if u.addrspace not in (AddrSpace.GLOBAL, AddrSpace.LOCAL, AddrSpace.REG): return u
  return u.load()
# route operands through maybe_load: all sources of elementwise ops, and the stored value of STOREs
pm_move_regs = PatternMatcher([
  # BITCAST?
  # elementwise: every source goes through maybe_load
  (UPat(GroupOp.Elementwise, name="x"), lambda x: x.replace(src=tuple(maybe_load(s) for s in x.src))),
  # store: only src[1] goes through maybe_load, src[0] and any trailing sources are kept untouched
  (UPat(Ops.STORE, name="x"), lambda x: x.replace(src=(x.src[0], maybe_load(x.src[1]), *x.src[2:]))),
])
# single rule: any UOp whose dtype is dtypes.weakint is replaced by the same UOp with dtype dtypes.int
pm_lower_weakints = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.weakint, name="x"), lambda x: x.replace(dtype=dtypes.int)),
])
def build_range_map(ctx, sink:UOp):
  # Fill ctx (a dict, mutated in place) with one dense index per UNROLL/UPCAST RANGE found under sink.
  # Keys are the range's arg[0]; values are assigned in toposort order starting from the current len(ctx).
  # Always returns None, so no replacement UOp is produced for the SINK.
  for x in sink.toposort():
    if x.op is Ops.RANGE and x.arg[1] in {AxisType.UNROLL, AxisType.UPCAST}:
      # setdefault keeps an index that was already assigned: running this again on a populated ctx (or seeing the
      # same arg[0] twice) must not overwrite it with len(ctx), which would produce duplicate or skipped positions
      ctx.setdefault(x.arg[0], len(ctx))
def fix_reduce(ctx, r:UOp):
  # collect the ended UNROLL ranges of r that have an entry in ctx, mapped to their index from ctx
  unrolled = {}
  for rng in r.ended_ranges:
    if rng.arg[0] in ctx and rng.arg[1] == AxisType.UNROLL: unrolled[rng] = ctx[rng.arg[0]]
  # those ranges are removed from the sources and their indices are appended to arg[1]
  kept_src = tuple(s for s in r.src if s not in unrolled)
  return r.replace(src=kept_src, arg=(r.arg[0], r.arg[1]+tuple(unrolled.values())))
def _range_to_const(ctx, r:UOp):
  # ranges without an entry in ctx are left alone
  if r.arg[0] not in ctx: return None
  size, axis = r.vmax+1, ctx[r.arg[0]]
  # constant holding 0..vmax, reshaped so its values sit on this range's axis and every other axis has size 1
  values = UOp.const(r.dtype, tuple(range(size)))
  return values.reshape(tuple(size if i == axis else 1 for i in range(len(ctx))))

# ctx is the dict filled by build_range_map on the SINK; REDUCEs and RANGEs are then rewritten against it
expander2 = PatternMatcher([
  (UPat(Ops.SINK, name="sink"), build_range_map),
  (UPat(Ops.REDUCE, name="r"), fix_reduce),
  (UPat(Ops.RANGE, name="r"), _range_to_const),
])+pm_flatten_range
def broadcast_binary(x:UOp):
  # nothing to rewrite when every source already has the same shape
  src_shapes = [s.shape for s in x.src]
  if all_same(src_shapes): return None
  # per-source aligned shape from _align_left, common target shape from _broadcast_shape
  aligned = _align_left(*src_shapes)
  target = _broadcast_shape(*src_shapes)
  # each source is reshaped to its aligned shape and then expanded to the target
  return x.replace(src=tuple(s.reshape(shp).expand(target) for s,shp in zip(x.src, aligned)))
# single rule: binary ops, ternary ops and STORE are passed to broadcast_binary, which rewrites the sources
# with explicit reshape+expand when their shapes differ (and returns None when they already match)
unbroadcast = PatternMatcher([
(UPat(GroupOp.Binary|GroupOp.Ternary|{Ops.STORE}, name="x"), broadcast_binary),
])
def do_devectorize(b:UOp):
  # an empty shape means there is nothing to unpack
  if b.shape == (): return None
  # broadcasting needs to be already unpacked
  if not all_same([s.shape for s in b.src]): return None
  def element(pos):
    # copy of b whose sources are each indexed at pos (indices are weakint constants)
    consts = [UOp.const(dtypes.weakint, p) for p in pos]
    return b.replace(src=tuple(s.index(*consts) for s in b.src))
  # one copy per position of b.shape, stacked and reshaped back to b.shape
  elems = [element(pos) for pos in itertools.product(*(range(dim) for dim in b.shape))]
  return UOp.vectorize(*elems).reshape(b.shape)
# devectorizer2: pm_mops followed by rules that split shaped ops into per-element ops (do_devectorize)
# and then simplify the INDEX / STACK / RESHAPE nodes that splitting leaves behind.
devectorizer2 = pm_mops+PatternMatcher([
# unpack broadcasting: shaped elementwise ops, LOADs and STOREs are handed to do_devectorize
(UPat(GroupOp.Elementwise|{Ops.LOAD, Ops.STORE}, name="b"), do_devectorize),
# INDEX into STACK is src: the index's arg selects which source of the STACK is returned
(UPat(Ops.INDEX, src=(UPat(Ops.STACK, name="a"), UPat.cvar("i"))), lambda a,i: a.src[i.arg]),
# stacked INDEX is many INDEX: a PARAM/BUFFER indexed by a STACK becomes a vectorize of one INDEX per STACK source
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.STACK, name="s"))),
lambda b,s: UOp.vectorize(*[b.index(u) for u in s.src])),
# INDEX into RESHAPE moves the RESHAPE: index with the RESHAPE's source, then reshape the result to s.shape
(UPat(Ops.INDEX, src=(UPat((Ops.PARAM, Ops.BUFFER), name="b"), UPat(Ops.RESHAPE, name="s"))),
lambda b,s: b.index(s.src[0]).reshape(s.shape)),
# RESHAPE a void is removed (hack for AFTER)
(UPat(Ops.RESHAPE, dtype=dtypes.void, name="x"), lambda x: x.src[0]),
# reshape of a single element shaped value to scalar is an index: only when marg is () and the source shape is (1,)
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0].index(UOp.const(dtypes.weakint, 0)) if x.marg == () and x.src[0].shape == (1,) else None),
# INDEX without src is nothing: an INDEX with a single source is replaced by that source
(UPat(Ops.INDEX, src=(UPat.var('x'),)), lambda x: x),
])
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
  # new accumulator placeholder in AddrSpace.REG, numbered from the context counter (which is then bumped)
  acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
  ctx.acc_num += 1
  body, reduce_ranges = r.src[0], r.src[1:]
  # RANGEs under the body that are neither sources of this REDUCE nor ended by an END under the body
  topo = body.toposort()
  closed = flatten([u.ended_ranges for u in topo if u.op is Ops.END])
  outer = tuple(u for u in topo if u.op is Ops.RANGE and u not in reduce_ranges and u not in closed)
  # store the identity element of the reduce op into the accumulator, ordered after those outer ranges
  init = acc.after(*outer).store(identity_element(r.arg[0], r.dtype.scalar()))
  acc_ready = acc.after(init, *reduce_ranges)
  # when arg[1] is non-empty the body keeps a REDUCE with the same arg, otherwise the body is used directly
  val = body.reduce(arg=r.arg) if r.arg[1] else body
  # accumulate into the placeholder and end the reduce ranges; the result is the accumulator after that
  update = acc_ready.store(acc_ready.alu(r.arg[0], val)).end(*reduce_ranges)
  return acc.after(update)
def expand_horizontal_reduce(r:UOp):
  op, axes = r.arg[0], r.arg[1]
  inp = r.src[0]
  def pick(idx):
    # shrink argument: (i, i+1) on each axis listed in axes, None on every other axis
    bounds = []
    for i in range(inp.ndim):
      if i in axes:
        at = idx[axes.index(i)]
        bounds.append((at, at+1))
      else: bounds.append(None)
    return inp.shrink(tuple(bounds))
  # one shrunk view per combination of positions along the reduced axes
  vals = [pick(idx) for idx in itertools.product(*[range(inp.max_shape[a]) for a in axes])]
  # combine all views left to right with the reduce op
  return functools.reduce(lambda acc,v: acc.alu(op, v), vals)
# REDUCE lowering, followed by pm_clean_up_group_sink:
#  - first rule: src pattern has two entries with allow_any_len=True (presumably value + one or more ranges — confirm
#    against UPat semantics) and goes to reduce_ranges_to_acc
#  - second rule: REDUCE with exactly one source goes to expand_horizontal_reduce
pm_reduce_local = PatternMatcher([
(UPat(Ops.REDUCE, src=(UPat(), UPat()), allow_any_len=True, name="r"), reduce_ranges_to_acc),
(UPat(Ops.REDUCE, src=(UPat(),), name="r"), expand_horizontal_reduce),
])+pm_clean_up_group_sink
def do_number_param(ctx:list[int], x:UOp): def do_number_param(ctx:list[int], x:UOp):
if x.arg.slot != -1: return None if x.arg.slot != -1: return None
ctx[0] += 1 ctx[0] += 1
@ -56,6 +150,90 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast)) if DEBUG >= 5: print(pyrender(ast))
if SPEC: type_verify(ast, spec_tensor) if SPEC: type_verify(ast, spec_tensor)
sink = ast
# preprocess. we need to simplify these
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)
# this is new style
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="transform to new style")
# first we optimize
if optimize:
# do postrange optimization, BEAM or hand_coded_optimizations
sink = apply_opts(sink, ren, beam=ast.arg.beam)
# do expander
sink = graph_rewrite(sink, expander2, ctx={}, name="expander", bottom_up=True)
# add locals (STAGE -> BUFFER)
sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers")
# rewrite reduce after optimizations
sink = graph_rewrite(sink, pm_reduce_local, ctx=ReduceContext(), name="remove_reduce")
# add gpu dims
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
# add loads
sink = graph_rewrite(sink, pm_move_regs, name="move to registers", walk=True)
# symbolic (note: this does POW decomp)
sink = graph_rewrite(sink, sym, name="post index symbolic")
# ***** make it renderable (within spec, tighten) *****
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV))
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="*** decompositions")
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions more")
# split ends
sink = graph_rewrite(sink, pm_split_ends, name="split ends")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# ***** this is where it gets large *****
# unbroadcast
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
# devectorizer
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, name="devectorizer")
# ***** make it renderable (outside spec, transform) *****
# final symbolic
sink = graph_rewrite(sink, sym, name="post devectorizer sym")
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# put registers in slots
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot != -1])
name_to_slot = {x:x.replace(arg=replace(x.arg, slot=num_params+i))
for i,x in enumerate(sorted([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot == -1]))}
sink = sink.substitute(name_to_slot, name="put variables in slots")
# remove all weakints
sink = graph_rewrite(sink, pm_lower_weakints, name="lower weakints", bottom_up=True)
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST")
if SPEC: type_verify(sink, spec_program)
# return the rewritten sink
return sink
def old_full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
if SPEC: type_verify(ast, spec_tensor)
# preprocess # preprocess
sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True) sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True)

View file

@ -274,7 +274,7 @@ SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward # allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0) USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# don't allow broadcast # don't allow broadcast
DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 1) DISALLOW_BROADCAST = ContextVar("DISALLOW_BROADCAST", 0)
@dataclass(frozen=True) @dataclass(frozen=True)
class Metadata: class Metadata:

View file

@ -75,7 +75,9 @@ def fold_divmod_general(d: UOp) -> UOp|None:
# divide_by_gcd: x//y -> (x//gcd)//(y//gcd) # divide_by_gcd: x//y -> (x//gcd)//(y//gcd)
gcd = UOp.gcd(*all_uops, y).simplify() gcd = UOp.gcd(*all_uops, y).simplify()
if not (gcd.op is Ops.CONST and gcd.arg==1): if not (gcd.op is Ops.CONST and gcd.arg==1):
ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd))) x_div, y_div = x.divide_exact(gcd), y.divide_exact(gcd)
if x_div is None or y_div is None: return None
ret = x_div.alu(d.op, y_div)
return ret*gcd if d.op is Ops.FLOORMOD else ret return ret*gcd if d.op is Ops.FLOORMOD else ret
# factor_remainder: (d*x+y)//d -> x+y//d # factor_remainder: (d*x+y)//d -> x+y//d

View file

@ -84,8 +84,8 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp: def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.STACK) if len(arg) == 0: return UOp(Ops.STACK)
elif all_int(arg): return UOp.const(dtypes.weakint.vec(len(arg)), arg) elif all_int(arg): return UOp.const(dtypes.weakint, arg)
else: return UOp(Ops.STACK, dtypes.weakint.vec(len(arg)), tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg)) else: return UOp(Ops.STACK, dtypes.weakint, tuple(UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in arg))
def consumer_map_from_toposort(lst:Iterable[UOp]): def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {} ret: dict[UOp, dict[UOp, None]] = {}
@ -306,9 +306,10 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END: Ops.COPY | Ops.ALLREDUCE | Ops.STORE | Ops.END:
return self.src[0]._shape return self.src[0]._shape
# REDUCE with empty axis is passthrough (lowered form) # REDUCE with empty axis is passthrough (lowered form)
case Ops.REDUCE if len(self.arg[1]) == 0: # no longer true
#case Ops.REDUCE if len(self.arg[1]) == 0:
# these can mismatch if there's a horizontal reduce # these can mismatch if there's a horizontal reduce
return (self.dtype.count,) if self.dtype.count > 1 else () #return (self.dtype.count,) if self.dtype.count > 1 else ()
# TODO: disallow shape changing bitcast # TODO: disallow shape changing bitcast
case Ops.BITCAST: case Ops.BITCAST:
@ -473,12 +474,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0] if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None])) return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs): def vectorize(self, *srcs):
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs) return UOp(Ops.STACK, self.dtype, (self,)+srcs)
def index(self, *srcs:UOp|None, ptr=False, **kwargs): def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): def __getitem__(self, idx):
# pointers index into INDEX UOps (scalar lookup); everything else uses the shared mixin view path # pointers index into INDEX UOps (scalar lookup); everything else uses the shared mixin view path
if not isinstance(self.dtype, PtrDType): return super(UOp, self).__getitem__(idx) #if not isinstance(self.dtype, PtrDType): return super(UOp, self).__getitem__(idx)
idx = self._normalize_indices(list(argfix(idx))) idx = self._normalize_indices(list(argfix(idx)))
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]): if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
# apply SHRINK for slices that aren't the full range # apply SHRINK for slices that aren't the full range
@ -1682,7 +1683,7 @@ pm_lower_index_dtype = PatternMatcher([
UPat.var("gate").where(UPat.var("idx_x", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))), UPat.var("gate").where(UPat.var("idx_x", dtypes.ints).cast(), UPat(Ops.CONST, arg=Invalid)))),
lambda buf,idx_x,idx_y,gate: buf.index(gate.where(idx_y, idx_y.const_like(Invalid)), lambda buf,idx_x,idx_y,gate: buf.index(gate.where(idx_y, idx_y.const_like(Invalid)),
gate.where(idx_x, idx_x.const_like(Invalid)), ptr=True)), gate.where(idx_x, idx_x.const_like(Invalid)), ptr=True)),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"), (UPat((Ops.SINK, Ops.NOOP, Ops.END, Ops.AFTER, Ops.BUFFER), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))), lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
]) ])
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0] def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View file

@ -29,7 +29,14 @@ def const_arg(u:UOp) -> ConstType|tuple[ConstType, ...]|None:
def fold_const_alu(a:UOp) -> UOp|None: def fold_const_alu(a:UOp) -> UOp|None:
vals = [const_arg(s) for s in a.src] vals = [const_arg(s) for s in a.src]
return None if any(v is None for v in vals) else a.const_like(exec_alu(a.op, a.dtype, vals, False)) if any(v is None for v in vals): return None
if any(isinstance(v, tuple) for v in vals):
out_len = prod(a.shape)
if not all(not isinstance(v, tuple) or len(v) in {1, out_len} for v in vals): return None
return a.const_like(tuple(exec_alu(a.op, a.dtype.scalar(),
[v[0] if isinstance(v, tuple) and len(v) == 1 else v[i] if isinstance(v, tuple) else v for v in vals], False)
for i in range(out_len)))
return a.const_like(exec_alu(a.op, a.dtype, vals, False))
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i") invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat) invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)