mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
int32, and refactor pad/shrink
This commit is contained in:
parent
fb5ee9260f
commit
dbbaa0bdd7
2 changed files with 14 additions and 11 deletions
|
|
@ -14,7 +14,7 @@ def shapetracker_getitem(st, val):
|
|||
class CheckingShapeTracker:
|
||||
def __init__(self, shape):
|
||||
self.st = ShapeTracker(shape)
|
||||
self.t = np.arange(prod(shape), dtype=np.int).reshape(shape)
|
||||
self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
|
|
|||
|
|
@ -165,22 +165,25 @@ class ShapeTracker:
|
|||
|
||||
# *** under this line are not invertible ***
|
||||
|
||||
# TODO: take this functionality out of slice
|
||||
def _resize(self, arg : Tuple[Tuple[int, int], ...]):
|
||||
offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
|
||||
self.views[-1] = View(tuple(y-x for x,y in arg), self.strides, self.offset+offset)
|
||||
|
||||
def pad(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
|
||||
assert isinstance(arg, tuple)
|
||||
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
|
||||
return self.shrink(tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg)))
|
||||
if all(b==0 and e==0 for b,e in arg): return self # ZeroView is expensive if we don't need it
|
||||
zvarg = tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg))
|
||||
zeroview = ZeroView(self.shape, zvarg)
|
||||
self._resize(zvarg)
|
||||
# if we add a ZeroView, we add another (stock) view also for modding
|
||||
self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]
|
||||
return self
|
||||
|
||||
# TODO: take the pad functionality out of shrink
|
||||
def shrink(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
|
||||
assert isinstance(arg, tuple)
|
||||
assert len(arg) == len(self.shape)
|
||||
offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
|
||||
zeroview = ZeroView(self.shape, arg)
|
||||
self.views[-1] = View(tuple(y-x for x,y in arg), self.strides, self.offset+offset)
|
||||
if zeroview.expr_node().min == 0: # may be invalid
|
||||
# if we add a ZeroView, we add another (stock) view also for modding
|
||||
self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]
|
||||
self._resize(arg)
|
||||
return self
|
||||
|
||||
def expand(self, new_shape : Tuple[int, ...]) -> ShapeTracker:
|
||||
|
|
@ -191,7 +194,7 @@ class ShapeTracker:
|
|||
self.views[-1] = View(new_shape, strides, self.offset)
|
||||
return self
|
||||
|
||||
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
|
||||
# TODO: combine with flip? this is more generic than we need
|
||||
def stride(self, mul : Tuple[int, ...]) -> ShapeTracker:
|
||||
assert isinstance(mul, tuple)
|
||||
assert all(isinstance(x, int) for x in mul)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue