mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
no_tuplize
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af80611ec1 |
5 changed files with 7 additions and 10 deletions
|
|
@@ -1098,7 +1098,7 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
|||
canonical_name = f"{_op_name(inst).lower()}_{base.to_bytes(size, 'little').hex()}"
|
||||
sink = sink.replace(arg=KernelInfo(name=canonical_name)).rtag(1)
|
||||
|
||||
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES=""):
|
||||
with Context(NOOPT=1, CHECK_OOB=0, EMULATED_DTYPES=""):
|
||||
runner = get_runner('CPU', sink)
|
||||
_canonical_runner_cache.append((base, mask, size, runner))
|
||||
return runner, True
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,7 @@ import heapq
|
|||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
||||
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
|
||||
from tinygrad.helpers import prod, getenv
|
||||
|
||||
def linearize(sink:UOp) -> list[UOp]:
|
||||
# this is a toposort with priority
|
||||
|
|
@@ -34,7 +34,7 @@ def linearize(sink:UOp) -> list[UOp]:
|
|||
priorities[u] = (run_count, priority, extra)
|
||||
|
||||
# number the uops in "ideal" order
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER else ())))}
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+((x.op.value, x.arg, x.dtype),)))}
|
||||
|
||||
# then force them to be toposorted in as close to the ideal order as possible
|
||||
heap = [(-nkey[sink], sink)]
|
||||
|
|
|
|||
|
|
@@ -199,8 +199,6 @@ SPEC = ContextVar("SPEC", 1)
|
|||
CHECK_OOB = ContextVar("CHECK_OOB", 0)
|
||||
PCONTIG = ContextVar("PCONTIG", 0) # partial contiguous in rangeify
|
||||
DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)
|
||||
# set to 1, this uses tuplize in the linearizer sort order
|
||||
TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1)
|
||||
# set to 0 to disable the compiler cache
|
||||
CCACHE = ContextVar("CCACHE", 1)
|
||||
# allow tf32 to be used on NVIDIA GPUs
|
||||
|
|
|
|||
|
|
@@ -191,10 +191,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
# returns map of UOps to their consumers in the graph rooted by self
|
||||
def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]: return consumer_map_from_toposort(self.toposort())
|
||||
|
||||
@functools.cached_property
|
||||
def tuplize(self:UOp) -> tuple:
|
||||
return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src])
|
||||
|
||||
@property
|
||||
def ptrdtype(self) -> PtrDType:
|
||||
if not isinstance(self.dtype, PtrDType): raise RuntimeError(f"ptrdtype called on UOp with type {self.dtype}")
|
||||
|
|
|
|||
|
|
@@ -179,7 +179,10 @@ gep_pushing = PatternMatcher([
|
|||
commutative = PatternMatcher([
|
||||
# ** COMMUTATIVE flipping (only for index) **
|
||||
# NOTE: this can break merging vector math by only flipping some of them
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.index, name='x'),
|
||||
lambda x: x.replace(src=x.src[::-1]) if (x.src[1].op.value, x.src[1].arg, x.src[1].dtype,
|
||||
tuple((s.op.value, s.arg, s.dtype) for s in x.src[1].src)) < (x.src[0].op.value, x.src[0].arg, x.src[0].dtype,
|
||||
tuple((s.op.value, s.arg, s.dtype) for s in x.src[0].src)) else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue