mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove DEFINE_VAR from codebase (gpt) (#16666)
* remove DEFINE_VAR from codebase * junk * remove junk
This commit is contained in:
parent
eda0a402d1
commit
4a4b6956df
20 changed files with 52 additions and 61 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Tuple, Dict, List, Optional
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.dtype import DType, dtypes, AddrSpace
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
|
@ -39,7 +39,7 @@ def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], Li
|
|||
prg = to_program(call.src[0], Device[arg_uops[0].device].renderer)
|
||||
info = prg.arg
|
||||
functions[info.function_name] = prg.src[3].arg
|
||||
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + [v for v in info.vars if v.op is Ops.DEFINE_VAR]
|
||||
cargs = [name_of(bu, i == 0) for i, bu in enumerate(arg_uops)] + list(info.vars)
|
||||
statements.append((info.function_name, cargs, info.global_size, info.local_size))
|
||||
|
||||
return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
|
||||
|
|
@ -253,17 +253,18 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = "model"
|
|||
symbolic_vars = OrderedDict()
|
||||
for i, (_, args, global_size, _) in enumerate(statements):
|
||||
for j, var in enumerate(args):
|
||||
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
|
||||
if getattr(var, "op", None) is Ops.PARAM and var.addrspace is AddrSpace.ALU and var.arg.name is not None:
|
||||
if var not in symbolic_vars:
|
||||
symbolic_vars[var] = var.arg[0]
|
||||
symbolic_vars[var] = var.expr
|
||||
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
|
||||
statements[i][1][j] = symbolic_vars[var]
|
||||
|
||||
if global_size:
|
||||
for j, dim in enumerate(global_size):
|
||||
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
|
||||
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and \
|
||||
any(s.op is Ops.PARAM and s.addrspace is AddrSpace.ALU for s in dim.src) and any(s.op is Ops.CONST for s in dim.src):
|
||||
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
|
||||
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
|
||||
global_size[j] = f"_{name.expr}[0] + {val.arg}"
|
||||
|
||||
prg = ""
|
||||
if target == "clang":
|
||||
|
|
|
|||
|
|
@ -737,7 +737,7 @@ class Parser:
|
|||
return _u32(0)
|
||||
|
||||
def _find_var_name(self, base: UOp) -> str | None:
|
||||
if base.op == Ops.DEFINE_VAR and base.arg: return base.arg[0]
|
||||
if base.op == Ops.PARAM and base.arg.name is not None: return base.arg.name
|
||||
for name, v in self.vars.items():
|
||||
if isinstance(v, UOp) and v is base: return name
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def apply_rewrite_values(expr):
|
|||
def evaluate_uop(uop, variables):
|
||||
if uop.op == Ops.CONST:
|
||||
return uop.arg
|
||||
elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is AddrSpace.ALU):
|
||||
elif uop.op == Ops.PARAM and uop.arg.addrspace is AddrSpace.ALU:
|
||||
return variables[uop.expr]
|
||||
elif uop.op in GroupOp.ALU:
|
||||
src_values = [evaluate_uop(src, variables) for src in uop.src]
|
||||
|
|
@ -301,13 +301,13 @@ class TestRecurse(unittest.TestCase):
|
|||
@given(matchers)
|
||||
def test_no_inf_loop(self, PatternMatcher):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
pm = PatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x)])
|
||||
pm = PatternMatcher([(UPat(Ops.PARAM, name="x"), lambda x: x)])
|
||||
graph_rewrite(a, pm)
|
||||
|
||||
@given(matchers)
|
||||
def test_no_inf_loop_bottom_up(self, PatternMatcher):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
pm = PatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x)])
|
||||
pm = PatternMatcher([(UPat(Ops.PARAM, name="x"), lambda x: x)])
|
||||
graph_rewrite(a, pm, bottom_up=True)
|
||||
|
||||
def test_inf_loop(self):
|
||||
|
|
|
|||
|
|
@ -113,11 +113,11 @@ class TestGraphRewrite(unittest.TestCase):
|
|||
# NOTE: this shows why we can't have a UOp in arg
|
||||
@unittest.expectedFailure
|
||||
def test_no_dedup_args(self):
|
||||
a1 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
|
||||
a2 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
|
||||
a1 = UOp.variable("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11), dtypes.int)
|
||||
a2 = UOp.variable("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11), dtypes.int)
|
||||
sink = a1.sink(a2)
|
||||
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).toposort() if x.op is Ops.DEFINE_VAR]
|
||||
self.assertEqual(len(define_vars), 1)
|
||||
variables = [x for x in graph_rewrite(sink, PatternMatcher([])).toposort() if x.op is Ops.PARAM and x.addrspace is AddrSpace.ALU]
|
||||
self.assertEqual(len(variables), 1)
|
||||
|
||||
def test_simple(self):
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
|
|
@ -344,19 +344,19 @@ class TestUOpGraph(unittest.TestCase):
|
|||
for i in [4, 8]:
|
||||
vec = UOp(Ops.STACK, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
|
||||
tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
tuple(UOp.variable(f'tmp{j}', 0, 1, dtypes.half) for j in range(i//2)))
|
||||
var = UOp.variable(f'tmp{i}', 0, 1, dtypes.half.vec(i))
|
||||
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
for i in [4, 8]:
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
var = UOp.variable(f'tmp{i}', 0, 1, dtypes.half.vec(i))
|
||||
vec = UOp(Ops.STACK, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
|
||||
tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
|
||||
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
tuple(UOp.variable(f'tmp{j}', 0, 1, dtypes.half) for j in range(i//2)))
|
||||
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
|
@ -364,17 +364,17 @@ class TestUOpGraph(unittest.TestCase):
|
|||
for i in [2, 4, 8]:
|
||||
vec = UOp(Ops.STACK, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
var = UOp.variable(f'tmp{i}', 0, 1, dtypes.half.vec(i))
|
||||
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
var = UOp.variable(f'tmp{i}', 0, 1, dtypes.half.vec(i))
|
||||
vec = UOp(Ops.STACK, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_w
|
|||
from tinygrad.uop.validate import uops_to_z3
|
||||
|
||||
def check_uop_against_string(self, v:UOp, s:str):
|
||||
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL, Ops.PARAM)}
|
||||
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.RANGE, Ops.SPECIAL, Ops.PARAM)}
|
||||
s_eval = eval(s, sym_vars)
|
||||
if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval)
|
||||
elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class TestVminVmaxProperties(unittest.TestCase):
|
|||
self.assertEqual(uop.vmax, 8)
|
||||
|
||||
def test_vmin_vmax_variable_inside_special(self):
|
||||
uop = UOp(Ops.SPECIAL, dtypes.int, arg='gidx0', src=(UOp(Ops.DEFINE_VAR, dtypes.int, arg=('i', 1, 10)),))
|
||||
uop = UOp(Ops.SPECIAL, dtypes.int, arg='gidx0', src=(UOp.variable('i', 1, 10, dtypes.int),))
|
||||
self.assertEqual(uop.vmin, 0)
|
||||
self.assertEqual(uop.vmax, 9)
|
||||
|
||||
|
|
|
|||
|
|
@ -236,7 +236,7 @@ class TestViz(unittest.TestCase):
|
|||
def test_const_reshape_expand_folded(self):
|
||||
# CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
|
||||
c = UOp.const(dtypes.float, 1.0, shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
|
||||
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
|
||||
a = UOp.variable("a", 0.0, 10.0, dtypes.float)
|
||||
alu = a + c
|
||||
with save_viz() as viz:
|
||||
graph_rewrite(alu, PatternMatcher([]))
|
||||
|
|
|
|||
|
|
@ -205,8 +205,8 @@ def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
|
|||
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Tensor Graph")
|
||||
# uop list is a list in the original_sink graph and we can map to the tags later
|
||||
# here we build buffer map
|
||||
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
|
||||
ctx = AllocCtx(bases=set([x.multibase for x in big_sink.src if x.base.op not in dont_realize]))
|
||||
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.AFTER}
|
||||
ctx = AllocCtx(bases=set([x.multibase for x in big_sink.src if x.base.op not in dont_realize and x.base.addrspace is not AddrSpace.ALU]))
|
||||
|
||||
# this rewrite is "read-only", it adds simple things to buffer_map and may sink things on big_sink, bottom_up
|
||||
# this is the only one where we have to be careful to not break the tensor graph
|
||||
|
|
|
|||
|
|
@ -45,9 +45,6 @@ pm_remove_vec_dtypes = PatternMatcher([
|
|||
# replace DEFINE_LOCAL/DEFINE_REG with BUFFER
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="x"), lambda x:
|
||||
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
|
||||
# replace DEFINE_VAR with PARAM
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
|
||||
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=-1, name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=AddrSpace.ALU))),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
def do_number_param(ctx:list[int], x:UOp):
|
||||
|
|
@ -146,7 +143,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
# put unnumbered DEFINE_VAR in slots
|
||||
# put unnumbered variable PARAMs in slots
|
||||
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM and x.arg.slot != -1])
|
||||
sink = graph_rewrite(sink, pm_number_params, ctx=[num_params], name="number params with -1", walk=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,9 +22,7 @@ def linearize(sink:UOp) -> list[UOp]:
|
|||
extra = None
|
||||
match u.op:
|
||||
# the order and placement of these defines is important
|
||||
case Ops.PARAM if u.arg.addrspace is None: priority, extra = -19, u.expr # var params sort after global params
|
||||
case Ops.PARAM: priority, extra = -20, u.arg.slot
|
||||
case Ops.DEFINE_VAR: priority, extra = -19, u.arg
|
||||
case Ops.BUFFER: priority = -18
|
||||
case Ops.DEFINE_REG: priority = -18
|
||||
case Ops.DEFINE_LOCAL: priority = -17
|
||||
|
|
|
|||
|
|
@ -132,6 +132,6 @@ def regalloc_rewrite(ctx:LinearScanRegallocContext, x:UOp):
|
|||
return nx, before + [nx] + after
|
||||
|
||||
pm_regalloc_rewrite = PatternMatcher([
|
||||
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
|
||||
(UPat({Ops.INS, Ops.RANGE, Ops.END, Ops.DEFINE_REG, Ops.DEFINE_LOCAL, Ops.PARAM, Ops.SPECIAL} | PSEUDO_OPS, name="x"),
|
||||
regalloc_rewrite),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
|
|||
replaces: dict[UOp, UOp] = {}
|
||||
for u in included:
|
||||
for s in u.src:
|
||||
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: continue
|
||||
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL}: continue
|
||||
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
|
||||
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
|
||||
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# minimal amdgpu elf packer
|
||||
import ctypes
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import ceildiv, round_up
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.runtime.autogen import amdgpu_kd, hsa, libc
|
||||
|
|
@ -36,9 +37,8 @@ def assemble_linear(prg:UOp, lin:UOp, arch:str) -> bytes:
|
|||
# ** scan sink for metadata
|
||||
sink, n_bufs, n_vars, lds_size, gids = prg.src[0], 0, 0, 0, set()
|
||||
for u in sink.toposort():
|
||||
if u.op is Ops.PARAM and u.addrspace is not None: n_bufs += 1
|
||||
elif u.op is Ops.PARAM and u.addrspace is None: n_vars += 1
|
||||
elif u.op is Ops.DEFINE_VAR: n_vars += 1
|
||||
if u.op is Ops.PARAM and u.addrspace is AddrSpace.ALU: n_vars += 1
|
||||
elif u.op is Ops.PARAM: n_bufs += 1
|
||||
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
|
||||
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
|
||||
code_bytes = b"".join(inst.to_bytes() for inst in insts)
|
||||
|
|
|
|||
|
|
@ -383,7 +383,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp):
|
|||
bufs: set[UOp] = set()
|
||||
def gate_input(u:UOp):
|
||||
# TODO: add cache to fix n^2
|
||||
if is_load:=(u.op in {Ops.STAGE, Ops.AFTER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
|
||||
if is_load:=(u.op in {Ops.STAGE, Ops.AFTER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK}): bufs.add(u)
|
||||
return not is_load
|
||||
root.toposort(gate=gate_input)
|
||||
|
||||
|
|
@ -534,7 +534,7 @@ to_define_global = PatternMatcher([
|
|||
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
|
||||
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
|
||||
|
||||
# this was DEFINE_VAR, clean this up and make it universal
|
||||
# ALU params are scalar symbolic values, not buffers.
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PARAM, name="v"),)), lambda v: v if v.addrspace == AddrSpace.ALU else None),
|
||||
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ class FastEnum(IntEnum):
|
|||
class Ops(FastEnum):
|
||||
# ** 1 -- defines/special **
|
||||
|
||||
# define GLOBAL/VAR are ptrs to outside the Kernel
|
||||
DEFINE_VAR = auto(); BIND = auto()
|
||||
# BIND pairs a symbolic PARAM with a concrete value
|
||||
BIND = auto()
|
||||
|
||||
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
|
||||
SPECIAL = auto()
|
||||
|
|
@ -127,7 +127,7 @@ class GroupOp:
|
|||
|
||||
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
|
||||
Irreducible = {Ops.CONST, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
||||
|
||||
# BinaryOps that can be flipped
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
else:
|
||||
return (len(self.src),) + self.src[0].shape
|
||||
# TODO: contract and unroll should be deleted
|
||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.CONTRACT | Ops.UNROLL | Ops.VCAT:
|
||||
case Ops.CONST | Ops.CONTRACT | Ops.UNROLL | Ops.VCAT:
|
||||
return (self.dtype.count,) if self.dtype.count > 1 else ()
|
||||
|
||||
# some ops init the shape
|
||||
|
|
@ -796,7 +796,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.BUFFER: return self.arg.addrspace if isinstance(self.arg, ParamArg) else AddrSpace.GLOBAL
|
||||
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
|
||||
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
|
||||
if self.op in {Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
|
||||
if self.op in {Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
|
||||
if self.op is Ops.LOAD: return AddrSpace.ALU # LOAD brings things into the ALU
|
||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
|
||||
return self.src[0].addrspace
|
||||
|
|
@ -924,7 +924,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
@property
|
||||
def val(self) -> int: return self.unbind()[1]
|
||||
def variables(self) -> list[Variable]:
|
||||
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.ALU)},
|
||||
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.ALU},
|
||||
key=lambda v: v.expr)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
|
@ -1018,7 +1018,6 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is Ops.PARAM and self.arg.vmin_vmax is not None: return self.arg.vmin_vmax
|
||||
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
|
||||
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
|
||||
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
|
||||
if self.op in {Ops.UNROLL, Ops.STACK}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
||||
|
|
@ -1033,7 +1032,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
def _sym_fxn(self):
|
||||
from tinygrad.uop.render import _render_with_splits, renderer_infer
|
||||
sself = self.simplify()
|
||||
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace == AddrSpace.ALU)))
|
||||
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.PARAM and x.arg.addrspace == AddrSpace.ALU))
|
||||
# TODO: sanitize varnames, or don't use naked eval while staying fast
|
||||
ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself})
|
||||
lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"]
|
||||
|
|
@ -1149,7 +1148,7 @@ class ProgramInfo:
|
|||
global_size: list[int] = [1, 1, 1]
|
||||
local_size: list[int]|None = [1, 1, 1]
|
||||
for u in sink.toposort():
|
||||
if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace == AddrSpace.ALU): _vars.append(u)
|
||||
if u.op is Ops.PARAM and u.addrspace == AddrSpace.ALU: _vars.append(u)
|
||||
if u.op is Ops.PARAM and u.addrspace != AddrSpace.ALU: _globals.append(u.arg.slot)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
||||
|
|
@ -1158,7 +1157,7 @@ class ProgramInfo:
|
|||
if u.arg[0] == 'i': local_size = None
|
||||
special_size = local_size if u.arg[0] == 'l' else global_size
|
||||
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
|
||||
if u.op in (Ops.DEFINE_VAR, Ops.PARAM) and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1
|
||||
if u.op is Ops.PARAM and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1
|
||||
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size),
|
||||
tuple(local_size) if local_size is not None else None, tuple(sorted(dedup(_vars), key=lambda v: v.arg.slot)),
|
||||
tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> str:
|
|||
precedence.get(x.src[1].op,99)<precedence[x.op] else right)
|
||||
|
||||
renderer = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr),
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: x.arg.name if x.arg.name is not None else f"p{x.arg.slot}"),
|
||||
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
|
||||
|
|
@ -81,8 +80,6 @@ pm_pyrender_extra = PatternMatcher([
|
|||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
|
||||
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
|
||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
|
||||
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
|
||||
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
|
||||
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
|
||||
|
|
|
|||
|
|
@ -52,9 +52,8 @@ spec_shared = PatternMatcher([
|
|||
# NOOP. TODO: remove this
|
||||
(UPat(Ops.NOOP), lambda: True),
|
||||
|
||||
# CONST/DEFINE_VAR are everywhere
|
||||
# CONST is everywhere
|
||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)),
|
||||
|
||||
# STACK is everywhere too
|
||||
(UPat(Ops.STACK, dtype=dtypes.void, src=()), lambda: True),
|
||||
|
|
@ -197,7 +196,7 @@ spec_tensor = PatternMatcher([
|
|||
# these ops can exist in programs but not the tensor spec. example: LOAD
|
||||
spec_program = PatternMatcher([
|
||||
# no more of these in programs
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.GEP)), lambda: False),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.GEP)), lambda: False),
|
||||
|
||||
# weakint is not allowed in programs
|
||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||
|
|
|
|||
|
|
@ -259,7 +259,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
|
||||
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
|
||||
# ALU/variable min==max -> CONST
|
||||
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
|
||||
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
|
||||
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def _get_clause(self:UPat, base:UOp, depth=0) -> UOp:
|
|||
else: and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=self.arg)), arg="{0}.arg == {1}"))
|
||||
if self.strict_length or self.required_len > 0:
|
||||
and_clause.append(UOp(Ops.CUSTOM, src=(base,), arg=("len({0}.src)"+(" == " if self.strict_length else " >= ")+str(self.required_len))))
|
||||
if self.name is not None: and_clause.append(UOp(Ops.STORE, src=(UOp(Ops.DEFINE_VAR, arg=self.name), base)))
|
||||
if self.name is not None: and_clause.append(UOp(Ops.STORE, src=(UOp(Ops.CUSTOMI, arg=self.name), base)))
|
||||
if self.match_dtype is not None:
|
||||
if len(self.match_dtype) > 1:
|
||||
and_clause.append(UOp(Ops.CUSTOM, src=(base, UOp(Ops.BIND, arg=tuple(self.match_dtype))), arg="({0}.dtype in {1} or {0}.dtype._scalar in {1})"))
|
||||
|
|
@ -126,7 +126,7 @@ def _final_render(x:UOp, has_ctx:bool, depth=1) -> list[str]:
|
|||
assert len(or_pieces) == 0 and len(s.src) >= 1
|
||||
for ss in s.src: or_pieces.extend(_final_render(ss, has_ctx, depth+1))
|
||||
elif s.op is Ops.STORE:
|
||||
assert s.src[0].op is Ops.DEFINE_VAR and s.src[1].op is Ops.NOOP
|
||||
assert s.src[0].op is Ops.CUSTOMI and s.src[1].op is Ops.NOOP
|
||||
store_pieces.append(f"{s.src[0].arg}={s.src[1].arg}")
|
||||
elif s.op is Ops.NOOP: and_pieces.append(s.arg)
|
||||
else: raise UPatCompileError(f"can't compile this {s}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue