mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
first LazyBuffer optimizations
This commit is contained in:
parent
a1a20891ef
commit
57ebce8d67
3 changed files with 25 additions and 9 deletions
|
|
@ -60,8 +60,8 @@ class GPUBuffer:
|
|||
if self._buf is None: self._buf = CL.malloc(4*prod(self.shape))
|
||||
return self._buf
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GPUBuffer with shape {self.shape!r}>"
|
||||
def __repr__(self): return f"<GPUBuffer with shape {self.shape!r}>"
|
||||
def shapeTrackerView(self, st:ShapeTracker): return GPUBuffer(st, hostbuf=self)
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple, NamedTuple, Union, Any, List, Dict, Type
|
||||
import functools, operator
|
||||
import sys, functools, operator
|
||||
from tinygrad.helpers import ConvArgs
|
||||
from tinygrad.shapetracker import ShapeTracker
|
||||
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
|
||||
|
|
@ -13,8 +13,11 @@ LoadOps = Enum("LoadOps", ["FROMCPU"])
|
|||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps]
|
||||
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[ProcessingOps], Type[LoadOps]]
|
||||
|
||||
# -O1
|
||||
MERGE_MOVEMENT_OPS = True
|
||||
REMOVE_MOVEMENT_NOPS = True
|
||||
|
||||
# lazy can recurse a lot
|
||||
import sys
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
import os
|
||||
|
|
@ -99,8 +102,11 @@ def _realize(self:LazyBuffer) -> DeviceBuffer:
|
|||
real_src = self.op.src[0].realize(self.device)
|
||||
return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
|
||||
elif self.optype == MovementOps:
|
||||
real_src = self.op.src[0].realize(self.device)
|
||||
return real_src.movement_op(self.op.op, self.op.arg), [real_src]
|
||||
real_src = get_lazybuffers(self.op)[0].realize()
|
||||
if getattr(real_src, "shapeTrackerView", None) is not None:
|
||||
return real_src.shapeTrackerView(self.st), [real_src]
|
||||
else:
|
||||
return functools.reduce(lambda x,o: x.movement_op(o.op, o.arg), get_lazyops(self.op)[::-1], real_src), [real_src]
|
||||
elif self.optype == UnaryOps:
|
||||
real_src_x = self.op.src[0].realize(self.device)
|
||||
return real_src_x.unary_op(self.op.op), [real_src_x]
|
||||
|
|
@ -167,7 +173,16 @@ class LazyBuffer:
|
|||
return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape)))
|
||||
|
||||
def movement_op(x:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
|
||||
return LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps, LazyOp(op, (x,), arg))
|
||||
# if a MovementOp is applied to a MovementOp, merge them and use one buffer
|
||||
ret = LazyBuffer(x.device, ShapeTracker(x.st).movement_op(op, arg), MovementOps,
|
||||
LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg))
|
||||
|
||||
if REMOVE_MOVEMENT_NOPS and x.realized is None and ret.st.contiguous:
|
||||
root = get_lazybuffers(ret.op)[0]
|
||||
if ret.st.shape == root.shape:
|
||||
return root
|
||||
|
||||
return ret
|
||||
|
||||
def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
|
||||
return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))
|
||||
|
|
|
|||
|
|
@ -29,10 +29,11 @@ class Tensor:
|
|||
raise Exception(f"can't create Tensor from {data}")
|
||||
|
||||
# tensors have gradients, buffers do not
|
||||
self.grad, self.requires_grad = None, requires_grad
|
||||
self.grad : Optional[Tensor] = None
|
||||
self.requires_grad = requires_grad
|
||||
|
||||
# internal variables used for autograd graph construction
|
||||
self._ctx = None
|
||||
self._ctx : Optional[Function] = None
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tensor {self.lazydata if self.lazydata.realized is None else self.lazydata.realized!r} with grad {(self.grad.lazydata if self.grad else None)!r}>"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue