remove DEFINE_VAR from codebase (gpt) (#16666)

* remove DEFINE_VAR from codebase

* junk

* remove junk
This commit is contained in:
George Hotz 2026-06-18 15:33:50 -07:00 committed by GitHub
commit 4a4b6956df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 52 additions and 61 deletions

View file

@ -1,5 +1,5 @@
from typing import Tuple, Dict, List, Optional
from tinygrad.dtype import DType, dtypes
from tinygrad.dtype import DType, dtypes, AddrSpace
from tinygrad.tensor import Tensor
from tinygrad.device import Device, Buffer
from tinygrad.engine.jit import TinyJit
@ -39,7 +39,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
info = prg.arg
functions[info.function_name] = prg.src[3].arg
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR]
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + list(info.vars)
statements.append((info.function_name, cargs, info.global_size, info.local_size))
return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
@ -253,17 +253,18 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
symbolic_vars = OrderedDict()
for i, (_, args, global_size, _) in enumerate(statements):
for j, var in enumerate(args):
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
if getattr(var, "op", None) is Ops.PARAM and var.addrspace is AddrSpace.ALU and var.arg.name is not None:
if var not in symbolic_vars:
symbolic_vars[var] = var.arg[0]
symbolic_vars[var] = var.expr
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
statements[i][1][j] = symbolic_vars[var]
if global_size:
for j, dim in enumerate(global_size):
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and \
any(s.op is Ops.PARAM and s.addrspace is AddrSpace.ALU for s in dim.src) and any(s.op is Ops.CONST for s in dim.src):
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
global_size[j] = f"_{name.expr}[0] + {val.arg}"
prg = ""
if target == "clang":

View file

@ -737,7 +737,7 @@ class Parser:
return _u32(0)
def _find_var_name(self, base: UOp) -> str | None:
if base.op == Ops.DEFINE_VAR and base.arg: return base.arg[0]
if base.op == Ops.PARAM and base.arg.name is not None: return base.arg.name
for name, v in self.vars.items():
if isinstance(v, UOp) and v is base: return name
return None

View file

@ -22,7 +22,7 @@ def apply_rewrite_values(expr):
def evaluate_uop(uop, variables):
if uop.op == Ops.CONST:
return uop.arg
elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is AddrSpace.ALU):
elif uop.op == Ops.PARAM and uop.arg.addrspace is AddrSpace.ALU:
return variables[uop.expr]
elif uop.op in GroupOp.ALU:
src_values = [evaluate_uop(src, variables) for src in uop.src]
@ -301,13 +301,13 @@ class TestRecurse(unittest.TestCase):
@given(matchers)
def test_no_inf_loop(self, PatternMatcher):
a = UOp.variable('a', 0, 10)
pm = PatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x)])
pm = PatternMatcher([(UPat(Ops.PARAM, name="x"), lambda x: x)])
graph_rewrite(a, pm)
@given(matchers)
def test_no_inf_loop_bottom_up(self, PatternMatcher):
a = UOp.variable('a', 0, 10)
pm = PatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x)])
pm = PatternMatcher([(UPat(Ops.PARAM, name="x"), lambda x: x)])
graph_rewrite(a, pm, bottom_up=True)
def test_inf_loop(self):

View file

@ -113,11 +113,11 @@ class TestGraphRewrite(unittest.TestCase):
# NOTE: this shows why we can't have a UOp in arg
@unittest.expectedFailure
def test_no_dedup_args(self):
a1 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
a2 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
a1 = UOp.variable("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11), dtypes.int)
a2 = UOp.variable("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11), dtypes.int)
sink = a1.sink(a2)
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).toposort() if x.op is Ops.DEFINE_VAR]
self.assertEqual(len(define_vars), 1)
variables = [x for x in graph_rewrite(sink, PatternMatcher([])).toposort() if x.op is Ops.PARAM and x.addrspace is AddrSpace.ALU]
self.assertEqual(len(variables), 1)
def test_simple(self):
c1 = UOp.const(dtypes.float, 1.0)
@ -344,19 +344,19 @@ class TestUOpGraph(unittest.TestCase):
for i in [4, 8]:
vec = UOp(Ops.STACK, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
tuple(UOp.variable(f'tmp{j}', 0, 1, dtypes.half) for j in range(i//2)))
var = UOp.variable(f'tmp{i}', 0, 1, dtypes.half.vec(i))
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
for i in [4, 8]:
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
var = UOp.variable(f'tmp{i}', 0, 1, dtypes.half.vec(i))
vec = UOp(Ops.STACK, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
tuple(UOp.variable(f'tmp{j}', 0, 1, dtypes.half) for j in range(i//2)))
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
@ -364,17 +364,17 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
vec = UOp(Ops.STACK, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
var = UOp.variable(f'tmp{i}', 0, 1, dtypes.half.vec(i))
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
for i in [2, 4, 8]:
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
var = UOp.variable(f'tmp{i}', 0, 1, dtypes.half.vec(i))
vec = UOp(Ops.STACK, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-2], wmma) # -2 to skip SINK

View file

@ -9,7 +9,7 @@ from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_w
from tinygrad.uop.validate import uops_to_z3
def check_uop_against_string(self, v:UOp, s:str):
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL, Ops.PARAM)}
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.RANGE, Ops.SPECIAL, Ops.PARAM)}
s_eval = eval(s, sym_vars)
if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval)
elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)

View file

@ -75,7 +75,7 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(uop.vmax, 8)
def test_vmin_vmax_variable_inside_special(self):
uop = UOp(Ops.SPECIAL, dtypes.int, arg='gidx0', src=(UOp(Ops.DEFINE_VAR, dtypes.int, arg=('i', 1, 10)),))
uop = UOp(Ops.SPECIAL, dtypes.int, arg='gidx0', src=(UOp.variable('i', 1, 10, dtypes.int),))
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 9)

View file

@ -236,7 +236,7 @@ class TestViz(unittest.TestCase):
def test_const_reshape_expand_folded(self):
# CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
c = UOp.const(dtypes.float, 1.0, shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
a = UOp.variable("a", 0.0, 10.0, dtypes.float)
alu = a + c
with save_viz() as viz:
graph_rewrite(alu, PatternMatcher([]))

View file

@ -205,8 +205,8 @@ def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Tensor Graph")
# uop list is a list in the original_sink graph and we can map to the tags later
# here we build buffer map
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
ctx = AllocCtx(bases=set([x.multibase for x in big_sink.src if x.base.op not in dont_realize]))
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.AFTER}
ctx = AllocCtx(bases=set([x.multibase for x in big_sink.src if x.base.op not in dont_realize and x.base.addrspace is not AddrSpace.ALU]))
# this rewrite is "read-only", it adds simple things to buffer_map and may sink things on big_sink, bottom_up
# this is the only one where we have to be careful to not break the tensor graph

View file

@ -45,9 +45,6 @@ pm_remove_vec_dtypes = PatternMatcher([
# replace DEFINE_LOCAL/DEFINE_REG with BUFFER
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
# replace DEFINE_VAR with PARAM
(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=-1, name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=AddrSpace.ALU))),
])+pm_clean_up_group_sink
def do_number_param(ctx:list[int], x:UOp):
@ -146,7 +143,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# put unnumbered DEFINE_VAR in slots
# put unnumbered variable PARAMs in slots
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot != -1])
sink = graph_rewrite(sink, pm_number_params, ctx=[num_params], name="number params with -1", walk=True)

View file

@ -22,9 +22,7 @@ def linearize(sink:UOp) -> list[UOp]:
extra = None
match u.op:
# the order and placement of these defines is important
case Ops.PARAM if u.arg.addrspace is None: priority, extra = -19, u.expr # var params sort after global params
case Ops.PARAM: priority, extra = -20, u.arg.slot
case Ops.DEFINE_VAR: priority, extra = -19, u.arg
case Ops.BUFFER: priority = -18
case Ops.DEFINE_REG: priority = -18
case Ops.DEFINE_LOCAL: priority = -17

View file

@ -132,6 +132,6 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
return nx, before + [nx] + after
pm_regalloc_rewrite = PatternMatcher([
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
regalloc_rewrite),
])

View file

@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
replaces: dict[UOp, UOp] = {}
for u in included:
for s in u.src:
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: continue
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL}: continue
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")

View file

@ -1,5 +1,6 @@
# minimal amdgpu elf packer
import ctypes
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import ceildiv, round_up
from tinygrad.uop.ops import UOp, Ops
from tinygrad.runtime.autogen import amdgpu_kd, hsa, libc
@ -36,9 +37,8 @@ def assemble_linear(prg:UOp, lin:UOp, arch:str) -> bytes:
# ** scan sink for metadata
sink, n_bufs, n_vars, lds_size, gids = prg.src[0], 0, 0, 0, set()
for u in sink.toposort():
if u.op is Ops.PARAM and u.addrspace is not None: n_bufs += 1
elif u.op is Ops.PARAM and u.addrspace is None: n_vars += 1
elif u.op is Ops.DEFINE_VAR: n_vars += 1
if u.op is Ops.PARAM and u.addrspace is AddrSpace.ALU: n_vars += 1
elif u.op is Ops.PARAM: n_bufs += 1
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
code_bytes = b"".join(inst.to_bytes() for inst in insts)

View file

@ -383,7 +383,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp):
bufs: set[UOp] = set()
def gate_input(u:UOp):
# TODO: add cache to fix n^2
if is_load:=(u.op in {Ops.STAGE, Ops.AFTER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
if is_load:=(u.op in {Ops.STAGE, Ops.AFTER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK}): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)
@ -534,7 +534,7 @@ to_define_global = PatternMatcher([
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
# this was DEFINE_VAR, clean this up and make it universal
# ALU params are scalar symbolic values, not buffers.
(UPat(Ops.INDEX, src=(UPat(Ops.PARAM, name="v"),)), lambda v: v if v.addrspace == AddrSpace.ALU else None),
(UPat(Ops.BIND, name="b"), unbind_kernel),

View file

@ -13,8 +13,8 @@ class FastEnum(IntEnum):
class Ops(FastEnum):
# ** 1 -- defines/special **
# define GLOBAL/VAR are ptrs to outside the Kernel
DEFINE_VAR = auto(); BIND = auto()
# BIND pairs a symbolic PARAM with a concrete value
BIND = auto()
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
@ -127,7 +127,7 @@ class GroupOp:
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
Irreducible = {Ops.CONST, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
# BinaryOps that can be flipped

View file

@ -272,7 +272,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
else:
return (len(self.src),) + self.src[0].shape
# TODO: contract and unroll should be deleted
case Ops.CONST | Ops.DEFINE_VAR | Ops.CONTRACT | Ops.UNROLL | Ops.VCAT:
case Ops.CONST | Ops.CONTRACT | Ops.UNROLL | Ops.VCAT:
return (self.dtype.count,) if self.dtype.count > 1 else ()
# some ops init the shape
@ -796,7 +796,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if self.op is Ops.BUFFER: return self.arg.addrspace if isinstance(self.arg, ParamArg) else AddrSpace.GLOBAL
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
if self.op in {Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
if self.op in {Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
if self.op is Ops.LOAD: return AddrSpace.ALU # LOAD brings things into the ALU
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
return self.src[0].addrspace
@ -924,7 +924,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
@property
def val(self) -> int: return self.unbind()[1]
def variables(self) -> list[Variable]:
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.ALU)},
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.ALU},
key=lambda v: v.expr)
# *** uop symbolic stuff ***
@ -1018,7 +1018,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
# NOTE: returned UOp is assumed to be CONST
if self.op is Ops.PARAM and self.arg.vmin_vmax is not None: return self.arg.vmin_vmax
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.UNROLL, Ops.STACK}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
@ -1033,7 +1032,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def _sym_fxn(self):
from tinygrad.uop.render import _render_with_splits, renderer_infer
sself = self.simplify()
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace == AddrSpace.ALU)))
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.PARAM and x.arg.addrspace == AddrSpace.ALU))
# TODO: sanitize varnames, or don't use naked eval while staying fast
ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself})
lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"]
@ -1149,7 +1148,7 @@ class ProgramInfo:
global_size: list[int] = [1, 1, 1]
local_size: list[int]|None = [1, 1, 1]
for u in sink.toposort():
if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace == AddrSpace.ALU): _vars.append(u)
if u.op is Ops.PARAM and u.addrspace == AddrSpace.ALU: _vars.append(u)
if u.op is Ops.PARAM and u.addrspace != AddrSpace.ALU: _globals.append(u.arg.slot)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
@ -1158,7 +1157,7 @@ class ProgramInfo:
if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
if u.op in (Ops.DEFINE_VAR, Ops.PARAM) and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1
if u.op is Ops.PARAM and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size),
tuple(local_size) if local_size is not None else None, tuple(sorted(dedup(_vars), key=lambda v: v.arg.slot)),
tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)

View file

@ -33,7 +33,6 @@ def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> str:
precedence.get(x.src[1].op,99)<precedence[x.op] else right)
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr),
(UPat(Ops.PARAM, name="x"), lambda x: x.arg.name if x.arg.name is not None else f"p{x.arg.slot}"),
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
@ -81,8 +80,6 @@ pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:

View file

@ -52,9 +52,8 @@ spec_shared = PatternMatcher([
# NOOP. TODO: remove this
(UPat(Ops.NOOP), lambda: True),
# CONST/DEFINE_VAR are everywhere
# CONST is everywhere
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)),
# STACK is everywhere too
(UPat(Ops.STACK, dtype=dtypes.void, src=()), lambda: True),
@ -197,7 +196,7 @@ spec_tensor = PatternMatcher([
# these ops can exist in programs but not the tensor spec. example: LOAD
spec_program = PatternMatcher([
# no more of these in programs
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.GEP)), lambda: False),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.GEP)), lambda: False),
# weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False),

View file

@ -259,7 +259,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU/variable min==max -> CONST
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding

View file

@ -21,7 +21,7 @@ def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
else: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.arg)), arg="{0}.arg == {1}"))
if self.strict_length or self.required_len > 0:
and_clause.append(UOp(Ops.CUSTOM, src=(base,), arg=("len({0}.src)"+(" == " if self.strict_length else " >= ")+str(self.required_len))))
if self.name is not None: and_clause.append(UOp(Ops.STORE, src=(UOp(Ops.DEFINE_VAR, arg=self.name), base)))
if self.name is not None: and_clause.append(UOp(Ops.STORE, src=(UOp(Ops.CUSTOMI, arg=self.name), base)))
if self.match_dtype is not None:
if len(self.match_dtype) > 1:
and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=tuple(self.match_dtype))), arg="({0}.dtype in {1} or {0}.dtype._scalar in {1})"))
@ -126,7 +126,7 @@ def _final_render(x:UOp, has_ctx:bool, depth=1) -> list[str]:
assert len(or_pieces) == 0 and len(s.src) >= 1
for ss in s.src: or_pieces.extend(_final_render(ss, has_ctx, depth+1))
elif s.op is Ops.STORE:
assert s.src[0].op is Ops.DEFINE_VAR and s.src[1].op is Ops.NOOP
assert s.src[0].op is Ops.CUSTOMI and s.src[1].op is Ops.NOOP
store_pieces.append(f"{s.src[0].arg}={s.src[1].arg}")
elif s.op is Ops.NOOP: and_pieces.append(s.arg)
else: raise UPatCompileError(f"can't compile this {s}")