mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Tensor uop spec (#8311)
* Tensor uop spec
* minor
* feedback
* restrict ShapeTracker of VIEW(BUFFER) to contiguous
* in image base mutates, how do we rewrite the view?
* cast post realize
* now ucache errors
* how strict can this be?
* put constraints on EMPTY
* merge
* save lines
* import import
* overloaded assign target
* more strict
* fine don't overload it
* more
* actually, this is better
* and it even exists
* this way it works for BUFFER
* Revert "this way it works for BUFFER"
This reverts commit 71c15f6b14.
* make it like linearize.py
* assign take 4
* minor
* all int, space and that's already base
* target
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
5776ea9386
commit
59f4b8da95
2 changed files with 86 additions and 5 deletions
|
|
@ -2,10 +2,10 @@ import sys, atexit, functools, pickle
|
|||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
|
||||
from tinygrad.ops import identity_element, buffers, exec_alu
|
||||
from tinygrad.ops import identity_element, buffers, exec_alu, type_verify
|
||||
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes
|
||||
from tinygrad.dtype import ConstType, DType, ImageDType, dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.device import Buffer
|
||||
|
|
@ -15,6 +15,85 @@ sys.setrecursionlimit(10000)
|
|||
|
||||
BUF_LIMIT = {"METAL":32}
|
||||
|
||||
# **** big graph spec
|
||||
|
||||
tensor_uop_spec = PatternMatcher([
|
||||
# ** stable and well understood specs
|
||||
|
||||
# DEVICE and BUFFER
|
||||
(UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE),), name="buf"), lambda buf:
|
||||
# arg: (number, size), a 2-tuple of ints
|
||||
isinstance(buf.arg, tuple) and len(buf.arg) == 2 and all_int(buf.arg) and \
|
||||
# dtype
|
||||
isinstance(buf.dtype, (DType, ImageDType))),
|
||||
|
||||
# movement ops
|
||||
(UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)), lambda mv,x:
|
||||
# naturally correct
|
||||
(isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
|
||||
# TODO: "make things that can't be images not images" can override the source dtype
|
||||
# is there a clean way to update its _mop children?
|
||||
(isinstance(mv.dtype, ImageDType) and x.dtype == mv.dtype.base and x.is_realized)),
|
||||
|
||||
# tensor variable bindings
|
||||
(UPat(Ops.BIND, dtype=dtypes.int, src=(UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
|
||||
|
||||
# DETACH and CONTIGUOUS change how we interpret the source UOp
|
||||
# CONTIGUOUS ensures the source UOp realizes
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
|
||||
|
||||
# ** specs with room for refactoring and improving
|
||||
|
||||
# COPY
|
||||
(UPat(Ops.COPY, name="copy", src=(UPat.var("copyin"),)), lambda copy,copyin:
|
||||
# arg (device, clone?)
|
||||
isinstance(copy.arg, tuple) and len(copy.arg) == 2 and isinstance(copy.arg[0], str) and isinstance(copy.arg[1], bool) and \
|
||||
# dtype
|
||||
copy.dtype == copyin.dtype),
|
||||
|
||||
# VIEW(BUFFER) applies a ShapeTracker on top of the underlying device buffer
|
||||
# NOTE: the VIEW's size exactly matches the underlying BUFFER and its ShapeTracker is contiguous; Tensor doesn't apply movement ops to the VIEW
|
||||
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),
|
||||
lambda view,buf: view.dtype == buf.dtype and view.size == buf.size and view.st.contiguous),
|
||||
|
||||
# ASSIGN changes the value of an existing buffer
|
||||
(UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))), lambda assign,target,new_val:
|
||||
# target must be a BUFFER or an already realized UOp
|
||||
(target.op is Ops.BUFFER or target.is_realized) and
|
||||
# dtype
|
||||
(assign.dtype == target.dtype == new_val.dtype) and
|
||||
# arg (TODO: replace this ShapeTracker arg with a VIEW on the target BUFFER)
|
||||
# NOTE: when present, this ShapeTracker must be non-contiguous and keep the ASSIGN's shape; otherwise it's free to swizzle the STORE st
|
||||
(assign.arg is None or (isinstance(assign.arg, ShapeTracker) and not assign.arg.contiguous and assign.arg.shape == assign.shape))),
|
||||
|
||||
# ** TODO: these UOps need new specs, the current representation relies on hacks
|
||||
|
||||
# BUFFER and VIEW specify shape and device for meta ops
|
||||
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"), UPat(GroupOp.Meta, name="uop"))),
|
||||
lambda view,buf,uop: view.dtype == buf.dtype == uop.dtype and view.size == buf.size),
|
||||
|
||||
# Tensor const has a ShapeTracker of shape=() and a fake BUFFER of size 1 (arg=(-1, 1))
|
||||
(UPat(Ops.VIEW, name="view", arg=ShapeTracker.from_shape(()), src=(UPat(Ops.BUFFER, name="fake", arg=(-1, 1)),
|
||||
UPat({Ops.CONST, Ops.BIND}, name="const_uop"))),
|
||||
lambda view,fake,const_uop: view.dtype == fake.dtype == const_uop.dtype),
|
||||
|
||||
# NOTE: EMPTY just ensures the source BUFFER is allocated before children run
|
||||
# TODO: this should be EMPTY(VIEW(BUFFER))
|
||||
(UPat(Ops.EMPTY, src=(), arg=None), lambda: True),
|
||||
|
||||
# TODO: BUFFER_VIEW is overloaded, can we break it into multiple well defined UOps?
|
||||
# BUFFER_VIEW shares the device buffer with its source, it uses a subbuffer of the underlying source buffer
|
||||
|
||||
(UPat(Ops.BUFFER_VIEW, name="root", src=(UPat.var("x"),)), lambda root,x:
|
||||
# BUFFER_VIEW can replace contiguous, keeping dtype the same
|
||||
(root.dtype == x.dtype) or
|
||||
# it can also replace bitcast, this changes the dtype, but the itemsize stays the same
|
||||
(root.dtype != x.dtype and root.dtype.itemsize == x.dtype.itemsize) or
|
||||
# it can also represent a shape-changing bitcast (only on DISK)
|
||||
(root.dtype != x.dtype and root.dtype.itemsize != x.dtype.itemsize and x.device.startswith("DISK"))),
|
||||
])
|
||||
|
||||
# **** ScheduleItem return type
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -498,7 +577,8 @@ create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER,
|
|||
remove_movement_ops = PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:list[UOp]) -> tuple[list[ScheduleItem], dict[Variable, int]]:
|
||||
def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int]]:
|
||||
if not skip_check: type_verify(list(UOp.sink(*outs).toposort), extra_spec=tensor_uop_spec)
|
||||
if len(outs:=dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
|
||||
# create the big graph
|
||||
ctx = ScheduleContext()
|
||||
|
|
|
|||
|
|
@ -987,9 +987,10 @@ spec = PatternMatcher([
|
|||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
])
|
||||
|
||||
def type_verify(uops:list[UOp]):
|
||||
def type_verify(uops:list[UOp], extra_spec:Optional[PatternMatcher]=None):
|
||||
spec_pm = spec if extra_spec is None else spec+extra_spec
|
||||
for i,u in enumerate(uops):
|
||||
if not spec.rewrite(u):
|
||||
if not spec_pm.rewrite(u):
|
||||
print_uops(uops)
|
||||
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue