mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
define var (#3548)
* define var * remove vars from there * fix python symbolic ops * fix llvm * pypath
This commit is contained in:
parent
83cdc85790
commit
2c19ab6561
8 changed files with 26 additions and 22 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -49,6 +49,8 @@ jobs:
|
|||
run: DEBUG=2 PYTHON=1 python3 test/test_dtype.py
|
||||
- name: Test ops with Python emulator
|
||||
run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal)"
|
||||
- name: Test symbolic with Python emulator
|
||||
run: PYTHONPATH=. PYTHON=1 python3 test/test_symbolic_ops.py
|
||||
|
||||
linter:
|
||||
name: Linters
|
||||
|
|
|
|||
|
|
@ -190,20 +190,15 @@ class Linearizer(Kernel):
|
|||
self.loop_uops: Dict[str, UOp] = {}
|
||||
|
||||
# add global buffers
|
||||
buf_count = 0
|
||||
buf_index = {}
|
||||
for i,buf in enumerate(self.bufs):
|
||||
if isinstance(buf, MemBuffer):
|
||||
if buf.idx not in buf_index:
|
||||
buf_index[buf.idx] = buf_count
|
||||
buf_count += 1
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL,
|
||||
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
|
||||
(buf_index[buf.idx], f"data{buf.idx}"))
|
||||
(buf.idx, f"data{buf.idx}"))
|
||||
# add var vals
|
||||
for i,var in enumerate(self.ast.vars()):
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (len(buf_index)+i, var.expr))
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_VAR, dtypes.int32, (), var)
|
||||
# define local buffers
|
||||
for lb in self.local_alias.values():
|
||||
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
from __future__ import annotations
|
||||
from typing import List, Set, Optional, Tuple, Any, Dict
|
||||
from typing import List, Set, Optional, Tuple, Any
|
||||
from tinygrad.helpers import DEBUG, flatten
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
|
||||
# bottom ones are asm only
|
||||
class UOps(Enum):
|
||||
LOOP = auto(); IF = auto(); ENDLOOP = auto(); ENDIF = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
|
||||
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
|
||||
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ def get_recursive_children(uops:List[UOp], x:UOp) -> Set[UOp]:
|
|||
deps.add(u)
|
||||
return deps
|
||||
|
||||
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
|
||||
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_VAR}
|
||||
def remove_childless_uops(uops:List[UOp]) -> List[UOp]:
|
||||
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
|
||||
while 1:
|
||||
|
|
@ -83,17 +83,17 @@ def uops_type_verify(uops:List[UOp]):
|
|||
assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
|
||||
assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
|
||||
|
||||
def uops_alu_resolve(u:UOp, vars:Dict[str, Variable]) -> sint:
|
||||
def uops_alu_resolve(u:UOp) -> sint:
|
||||
if u.uop == UOps.CONST: return u.arg
|
||||
elif u.uop == UOps.DEFINE_GLOBAL: return vars[u.arg[1]]
|
||||
elif u.uop == UOps.DEFINE_VAR: return u.arg
|
||||
elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL:
|
||||
return uops_alu_resolve(u.vin[0], vars) * uops_alu_resolve(u.vin[1], vars)
|
||||
return uops_alu_resolve(u.vin[0]) * uops_alu_resolve(u.vin[1])
|
||||
elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD:
|
||||
return uops_alu_resolve(u.vin[0], vars) + uops_alu_resolve(u.vin[1], vars)
|
||||
return uops_alu_resolve(u.vin[0]) + uops_alu_resolve(u.vin[1])
|
||||
else:
|
||||
raise RuntimeError(f"ALU resolve fail @ {u.uop}")
|
||||
|
||||
def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint]:
|
||||
def uops_flops_mem(uops:List[UOp]) -> Tuple[sint, sint]:
|
||||
flops: sint = 0
|
||||
mem: sint = 0
|
||||
mults: sint = 1
|
||||
|
|
@ -101,7 +101,7 @@ def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint
|
|||
for u in uops:
|
||||
if u.uop is UOps.LOOP:
|
||||
mult_stack.append(mults)
|
||||
mults *= uops_alu_resolve(u.vin[1], vars)
|
||||
mults *= uops_alu_resolve(u.vin[1])
|
||||
if u.uop is UOps.ENDLOOP:
|
||||
mults = mult_stack.pop(-1)
|
||||
if u.uop is UOps.ALU:
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ class Compiled:
|
|||
ret = CompiledASTRunner(k.ast, k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)
|
||||
from tinygrad.codegen.uops import uops_flops_mem
|
||||
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
|
||||
ops, mem = uops_flops_mem(k.uops, {x.expr:x for x in ret.vars})
|
||||
ops, mem = uops_flops_mem(k.uops)
|
||||
# NOTE: we use min here to ignore the indexing FLOPS
|
||||
ret.op_estimate = min(ret.op_estimate, ops * run_count)
|
||||
ret.mem_estimate = min(ret.mem_estimate, mem * run_count)
|
||||
|
|
|
|||
|
|
@ -141,7 +141,8 @@ class LazyBuffer:
|
|||
# *** movement ops ***
|
||||
|
||||
def _view(self, new_st:ShapeTracker) -> LazyBuffer:
|
||||
if self.st.size == 0: return self.const(0, new_st.shape)
|
||||
if self.st.size == 0 or (new_st.views[-1].mask is not None and all((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
|
||||
return self.const(0, new_st.shape)
|
||||
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
|
||||
return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
|
||||
|
||||
|
|
|
|||
|
|
@ -161,6 +161,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
|||
elif uop is UOps.DEFINE_LOCAL:
|
||||
kk(lang.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
r[u] = args.expr
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
|
||||
bufs.append((args[1], dtype))
|
||||
|
|
|
|||
|
|
@ -70,8 +70,8 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
|
|||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
|
||||
# extract global buffers
|
||||
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop == UOps.DEFINE_GLOBAL}
|
||||
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
|
||||
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
|
||||
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||
|
||||
# create llvm function
|
||||
|
|
@ -144,7 +144,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
|
|||
elif uop is UOps.ALU:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
|
||||
elif uop is UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
|
||||
elif uop is UOps.DEFINE_GLOBAL: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
||||
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ class PythonProgram:
|
|||
ul: Dict[int, Any] = {}
|
||||
dl: Dict[int, DType] = {}
|
||||
pbufs: List[memoryview] = list(bufs)
|
||||
pvals: List[int] = list(vals)
|
||||
i = 0
|
||||
loop_ends: Dict[int, int] = {}
|
||||
while i < len(self.uops):
|
||||
|
|
@ -97,6 +98,8 @@ class PythonProgram:
|
|||
assert dtype.fmt is not None
|
||||
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
|
||||
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
ul[i] = [pvals.pop(0)] * warp_size
|
||||
elif uop is UOps.SPECIAL:
|
||||
if arg[1][0] == 'g':
|
||||
ul[i] = [idxs[2-arg[0]]] * warp_size
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue