Tensor uop spec (#8311)

* Tensor uop spec

* minor

* feedback

* restrict ShapeTracker of VIEW(BUFFER) to contiguous

* in image base mutates, how do we rewrite the view?

* cast post realize

* now ucache errors

* how strict can this be?

* put constraints on EMPTY

* merge

* save lines

* import import

* overloaded assign target

* more strict

* fine don't overload it

* more

* actually, this is better

* and it even exists

* this way it works for BUFFER

* Revert "this way it works for BUFFER"

This reverts commit 71c15f6b14.

* make it like linearize.py

* assign take 4

* minor

* all int, space and that's already base

* target

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
qazal 2024-12-20 17:47:40 +02:00 committed by GitHub
commit 59f4b8da95
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 86 additions and 5 deletions

View file

@ -2,10 +2,10 @@ import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, exec_alu
from tinygrad.ops import identity_element, buffers, exec_alu, type_verify
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar
from tinygrad.dtype import ConstType, ImageDType, dtypes
from tinygrad.dtype import ConstType, DType, ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.device import Buffer
@ -15,6 +15,85 @@ sys.setrecursionlimit(10000)
BUF_LIMIT = {"METAL":32}
# **** big graph spec
# tensor_uop_spec is a PatternMatcher of (UPat, predicate) pairs describing well-formed tensor-level UOps.
# create_schedule_with_vars hands it to type_verify as extra_spec (unless skip_check is set); type_verify
# raises RuntimeError on the first UOp for which the combined matcher's rewrite result is falsy.
# NOTE(review): entries are listed in a deliberate order (e.g. the single-source VIEW(BUFFER) rule comes
# before the two-source VIEW rules) -- presumably the matcher tries them in order; confirm before reordering.
tensor_uop_spec = PatternMatcher([
# ** stable and well understood specs
# DEVICE and BUFFER
# DEVICE is matched with dtypes.void and an empty src tuple; its arg must be a str
(UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),), name="buf"), lambda buf:
# arg: (number, size)
isinstance(buf.arg, tuple) and len(buf.arg) == 2 and all_int(buf.arg) and \
# dtype
isinstance(buf.dtype, (DType, ImageDType))),
# movement ops
(UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)), lambda mv,x:
# naturally correct
(isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
# TODO: "make things that can't be images not images" can override the source dtype
# is there a clean way to update its _mop children?
# this exception only applies when the movement op is an ImageDType, the source has its base dtype, and the source is realized
(isinstance(mv.dtype, ImageDType) and x.dtype == mv.dtype.base and x.is_realized)),
# tensor variable bindings
# the predicate is unconditionally True: matching the pattern (int BIND of a DEFINE_VAR and an int const, arg=None) is the whole check
(UPat(Ops.BIND, dtype=dtypes.int, src=(UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
(UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
# ** specs with room for refactoring and improving
# COPY
(UPat(Ops.COPY, name="copy", src=(UPat.var("copyin"),)), lambda copy,copyin:
# arg (device, clone?)
isinstance(copy.arg, tuple) and len(copy.arg) == 2 and isinstance(copy.arg[0], str) and isinstance(copy.arg[1], bool) and \
# dtype
copy.dtype == copyin.dtype),
# VIEW(BUFFER) applies a ShapeTracker on top of the underlying device buffer
# NOTE: VIEW size exactly matches the underlying BUFFER, tensor doesn't apply movement ops to the VIEW
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),
lambda view,buf: view.dtype == buf.dtype and view.size == buf.size and view.st.contiguous),
# ASSIGN changes the value of an existing buffer
(UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))), lambda assign,target,new_val:
# target must be a realized device buffer
(target.op is Ops.BUFFER or target.is_realized) and
# dtype
(assign.dtype == target.dtype == new_val.dtype) and
# arg (TODO: replace this ShapeTracker arg with a VIEW on the target BUFFER)
# NOTE: this ShapeTracker must not change shape, but it's free to swizzle the STORE st
# a ShapeTracker arg is accepted only when it is NOT contiguous and keeps the ASSIGN's shape
# NOTE(review): presumably the contiguous case is expected to be expressed as arg=None -- confirm against the caller
(assign.arg is None or (isinstance(assign.arg, ShapeTracker) and not assign.arg.contiguous and assign.arg.shape == assign.shape))),
# ** TODO: these UOps need new specs, the current representation relies on hacks
# BUFFER and VIEW specify shape and device for meta ops
# unlike the single-source VIEW(BUFFER) rule, this one does not require a contiguous ShapeTracker
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"), UPat(GroupOp.Meta, name="uop"))),
lambda view,buf,uop: view.dtype == buf.dtype == uop.dtype and view.size == buf.size),
# Tensor const has a ShapeTracker of shape=() and fake buffer of size 1
# the fake BUFFER is matched by its exact arg (-1, 1); only the three dtypes are compared in the predicate
(UPat(Ops.VIEW, name="view", arg=ShapeTracker.from_shape(()), src=(UPat(Ops.BUFFER, name="fake", arg=(-1, 1)),
UPat({Ops.CONST, Ops.BIND}, name="const_uop"))),
lambda view,fake,const_uop: view.dtype == fake.dtype == const_uop.dtype),
# NOTE: EMPTY just ensures the source BUFFER is allocated before children run
# TODO: this should be EMPTY(VIEW(BUFFER))
# as written, EMPTY must have no sources and arg=None; nothing else is checked
(UPat(Ops.EMPTY, src=(), arg=None), lambda: True),
# TODO: BUFFER_VIEW is overloaded, can we break it into multiple well defined UOps?
# BUFFER_VIEW shares the device buffer with its source, it uses a subbuffer of the underlying source buffer
(UPat(Ops.BUFFER_VIEW, name="root", src=(UPat.var("x"),)), lambda root,x:
# BUFFER_VIEW can replace contiguous, keeping dtype the same
(root.dtype == x.dtype) or
# it can also replace bitcast, this changes the dtype, but the itemsize stays the same
(root.dtype != x.dtype and root.dtype.itemsize == x.dtype.itemsize) or
# it can also represent shape changing bitcast (only on DISK)
(root.dtype != x.dtype and root.dtype.itemsize != x.dtype.itemsize and x.device.startswith("DISK"))),
])
# **** ScheduleItem return type
@dataclass(frozen=True)
@ -498,7 +577,8 @@ create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER,
# rewrites any movement op x into a view of x.base, using x's (unwrapped) ShapeTracker as the view argument
remove_movement_ops = PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:list[UOp]) -> tuple[list[ScheduleItem], dict[Variable, int]]:
def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int]]:
if not skip_check: type_verify(list(UOp.sink(*outs).toposort), extra_spec=tensor_uop_spec)
if len(outs:=dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
# create the big graph
ctx = ScheduleContext()

View file

@ -987,9 +987,10 @@ spec = PatternMatcher([
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
])
# NOTE(review): this hunk is a rendered diff without +/- markers; where two near-identical lines are adjacent,
# the first appears to be the pre-change line and the second its replacement -- confirm against the repository.
def type_verify(uops:list[UOp]):
def type_verify(uops:list[UOp], extra_spec:Optional[PatternMatcher]=None):
# check every UOp in uops against spec, extended with extra_spec when one is supplied
spec_pm = spec if extra_spec is None else spec+extra_spec
for i,u in enumerate(uops):
if not spec.rewrite(u):
if not spec_pm.rewrite(u):
# first UOp with a falsy rewrite result: print all uops, then raise with its index, op, dtype, source count, source ops and arg
print_uops(uops)
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")