mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
TRAINING ContextVar (#16703)
This commit is contained in:
parent
fe9b19b12d
commit
625d8bbd0d
2 changed files with 13 additions and 6 deletions
|
|
@ -240,6 +240,7 @@ DEV, DEBUG, BEAM, NOOPT = _DEV("DEV", ""), ContextVar("DEBUG", 0), ContextVar("B
|
|||
IMAGE, FLOAT16, OPENPILOT_HACKS = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0), ContextVar("OPENPILOT_HACKS", 0)
|
||||
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
|
||||
WINO, CAPTURING, TRACEMETA, NO_COLOR = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("NO_COLOR", 0)
|
||||
TRAINING = ContextVar("TRAINING", 0)
|
||||
USE_TC, TC_SELECT, TC_OPT = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0)
|
||||
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
|
||||
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@
|
|||
from __future__ import annotations
|
||||
import time, math, itertools, functools, sys, inspect, pathlib, hashlib, weakref
|
||||
from contextlib import ContextDecorator
|
||||
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
from typing import Any, Callable, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING
|
||||
if TYPE_CHECKING: import numpy
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid
|
||||
from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, fully_flatten, ceildiv, fetch, flat_to_grouped
|
||||
from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc, TRAINING
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
|
||||
from tinygrad.mixin.rand import RandMixin
|
||||
from tinygrad.schedule import create_linear_with_vars
|
||||
|
|
@ -59,7 +59,14 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
|
|||
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
|
||||
return ret
|
||||
|
||||
class Tensor(RandMixin):
|
||||
# TODO: deprecate this, always use TRAINING
|
||||
class TensorMeta(type):
|
||||
@property
|
||||
def training(cls) -> bool: return bool(TRAINING.value)
|
||||
@training.setter
|
||||
def training(cls, mode:bool): TRAINING.value = int(mode)
|
||||
|
||||
class Tensor(RandMixin, metaclass=TensorMeta):
|
||||
"""
|
||||
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
||||
|
||||
|
|
@ -71,7 +78,6 @@ class Tensor(RandMixin):
|
|||
```
|
||||
"""
|
||||
__slots__ = "uop", "is_param", "grad"
|
||||
training: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None):
|
||||
|
|
@ -150,8 +156,8 @@ class Tensor(RandMixin):
|
|||
|
||||
class train(ContextDecorator):
|
||||
def __init__(self, mode:bool = True): self.mode = mode
|
||||
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
|
||||
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
|
||||
def __enter__(self): self.prev, TRAINING.value = TRAINING.value, int(self.mode)
|
||||
def __exit__(self, exc_type, exc_value, traceback): TRAINING.value = self.prev
|
||||
|
||||
def __repr__(self):
|
||||
ld = self.uop
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue