define var (#3548)

* define var

* remove vars from there

* fix python symbolic ops

* fix llvm

* pypath
This commit is contained in:
George Hotz 2024-02-29 16:43:27 -08:00 committed by GitHub
commit 2c19ab6561
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 26 additions and 22 deletions

View file

@ -49,6 +49,8 @@ jobs:
run: DEBUG=2 PYTHON=1 python3 test/test_dtype.py
- name: Test ops with Python emulator
run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal)"
- name: Test symbolic with Python emulator
run: PYTHONPATH=. PYTHON=1 python3 test/test_symbolic_ops.py
linter:
name: Linters

View file

@ -190,20 +190,15 @@ class Linearizer(Kernel):
self.loop_uops: Dict[str, UOp] = {}
# add global buffers
buf_count = 0
buf_index = {}
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
if buf.idx not in buf_index:
buf_index[buf.idx] = buf_count
buf_count += 1
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL,
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
(buf_index[buf.idx], f"data{buf.idx}"))
(buf.idx, f"data{buf.idx}"))
# add var vals
for i,var in enumerate(self.ast.vars()):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (len(buf_index)+i, var.expr))
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_VAR, dtypes.int32, (), var)
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))

View file

@ -1,16 +1,16 @@
from __future__ import annotations
from typing import List, Set, Optional, Tuple, Any, Dict
from typing import List, Set, Optional, Tuple, Any
from tinygrad.helpers import DEBUG, flatten
from tinygrad.dtype import dtypes, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.symbolic import sint
from enum import Enum, auto
from dataclasses import dataclass
# bottom ones are asm only
class UOps(Enum):
LOOP = auto(); IF = auto(); ENDLOOP = auto(); ENDIF = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
@ -33,7 +33,7 @@ def get_recursive_children(uops:List[UOp], x:UOp) -> Set[UOp]:
deps.add(u)
return deps
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_VAR}
def remove_childless_uops(uops:List[UOp]) -> List[UOp]:
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
while 1:
@ -83,17 +83,17 @@ def uops_type_verify(uops:List[UOp]):
assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
def uops_alu_resolve(u:UOp, vars:Dict[str, Variable]) -> sint:
def uops_alu_resolve(u:UOp) -> sint:
if u.uop == UOps.CONST: return u.arg
elif u.uop == UOps.DEFINE_GLOBAL: return vars[u.arg[1]]
elif u.uop == UOps.DEFINE_VAR: return u.arg
elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL:
return uops_alu_resolve(u.vin[0], vars) * uops_alu_resolve(u.vin[1], vars)
return uops_alu_resolve(u.vin[0]) * uops_alu_resolve(u.vin[1])
elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD:
return uops_alu_resolve(u.vin[0], vars) + uops_alu_resolve(u.vin[1], vars)
return uops_alu_resolve(u.vin[0]) + uops_alu_resolve(u.vin[1])
else:
raise RuntimeError(f"ALU resolve fail @ {u.uop}")
def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint]:
def uops_flops_mem(uops:List[UOp]) -> Tuple[sint, sint]:
flops: sint = 0
mem: sint = 0
mults: sint = 1
@ -101,7 +101,7 @@ def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint
for u in uops:
if u.uop is UOps.LOOP:
mult_stack.append(mults)
mults *= uops_alu_resolve(u.vin[1], vars)
mults *= uops_alu_resolve(u.vin[1])
if u.uop is UOps.ENDLOOP:
mults = mult_stack.pop(-1)
if u.uop is UOps.ALU:

View file

@ -238,7 +238,7 @@ class Compiled:
ret = CompiledASTRunner(k.ast, k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)
from tinygrad.codegen.uops import uops_flops_mem
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
ops, mem = uops_flops_mem(k.uops, {x.expr:x for x in ret.vars})
ops, mem = uops_flops_mem(k.uops)
# NOTE: we use min here to ignore the indexing FLOPS
ret.op_estimate = min(ret.op_estimate, ops * run_count)
ret.mem_estimate = min(ret.mem_estimate, mem * run_count)

View file

@ -141,7 +141,8 @@ class LazyBuffer:
# *** movement ops ***
def _view(self, new_st:ShapeTracker) -> LazyBuffer:
if self.st.size == 0: return self.const(0, new_st.shape)
if self.st.size == 0 or (new_st.views[-1].mask is not None and all((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
return self.const(0, new_st.shape)
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

View file

@ -161,6 +161,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
elif uop is UOps.DEFINE_LOCAL:
kk(lang.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, dtype))
r[u] = args.expr
elif uop is UOps.DEFINE_GLOBAL:
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
bufs.append((args[1], dtype))

View file

@ -70,8 +70,8 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
# extract global buffers
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop == UOps.DEFINE_GLOBAL}
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
@ -144,7 +144,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
elif uop is UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
elif uop is UOps.DEFINE_GLOBAL: lvars[u] = func.args[buf_index[args]]
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
else: raise RuntimeError(f"failed to render {uop}")

View file

@ -56,6 +56,7 @@ class PythonProgram:
ul: Dict[int, Any] = {}
dl: Dict[int, DType] = {}
pbufs: List[memoryview] = list(bufs)
pvals: List[int] = list(vals)
i = 0
loop_ends: Dict[int, int] = {}
while i < len(self.uops):
@ -97,6 +98,8 @@ class PythonProgram:
assert dtype.fmt is not None
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
elif uop is UOps.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is UOps.SPECIAL:
if arg[1][0] == 'g':
ul[i] = [idxs[2-arg[0]]] * warp_size