int32, and refactor pad/shrink

This commit is contained in:
George Hotz 2023-03-09 12:57:17 -08:00
commit dbbaa0bdd7
2 changed files with 14 additions and 11 deletions

View file

@ -14,7 +14,7 @@ def shapetracker_getitem(st, val):
class CheckingShapeTracker:
# Set up a ShapeTracker for `shape` alongside a NumPy reference array `self.t`
# holding the values 0..prod(shape)-1 laid out in that same shape.
# NOTE(review): this listing appears to be a diff whose +/- markers and indentation
# were stripped; the two `self.t` assignments below look like the before/after of a
# single change rather than two live statements — confirm against the repository.
def __init__(self, shape):
self.st = ShapeTracker(shape)
# NOTE(review): presumably the pre-change line. `np.int` was only an alias of the
# builtin int; it was deprecated in NumPy 1.20 and removed in NumPy 1.24.
self.t = np.arange(prod(shape), dtype=np.int).reshape(shape)
# Same construction with an explicit 32-bit integer dtype.
self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape)
@property
def shape(self):

View file

@ -165,22 +165,25 @@ class ShapeTracker:
# *** under this line are not invertible ***
# TODO: take this functionality out of slice
def _resize(self, arg : Tuple[Tuple[int, int], ...]):
  # Replace the last view with one that covers the (start, end) range given for
  # each dimension. The strides are reused unchanged; only the shape and the
  # offset of the view move.
  #   arg: one (start, end) pair per dimension.
  # Each start, scaled by the stride of its dimension, is folded into the offset.
  shift = 0
  for dim, (start, _) in enumerate(arg):
    shift += self.strides[dim] * start
  # The extent of every dimension becomes end - start.
  extents = tuple(end - start for start, end in arg)
  self.views[-1] = View(extents, self.strides, self.offset + shift)
# Pad the tracked shape: `arg` holds one (before, after) pair per dimension.
# Both amounts are asserted non-negative and len(arg) must equal len(self.shape).
# Returns a ShapeTracker (self on the paths that end in `return self`) so calls can be chained.
# NOTE(review): this listing appears to be a diff whose +/- markers and indentation
# were stripped, so the old and the new body of `pad` are interleaved below —
# confirm against the repository before treating every line as live code.
def pad(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
assert isinstance(arg, tuple)
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
# NOTE(review): presumably the pre-change body, which delegated padding to shrink();
# read literally it would make every line after it unreachable.
return self.shrink(tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg)))
# Fast path: all-zero padding leaves the tracker untouched.
if all(b==0 and e==0 for b,e in arg): return self # ZeroView is expensive if we don't need it
# Express the padding as a (start, end) range per dimension: start = -b, end = s+e.
zvarg = tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg))
# The ZeroView is constructed from self.shape before _resize replaces the last view.
# NOTE(review): presumably so it records the unpadded shape — confirm how self.shape is derived.
zeroview = ZeroView(self.shape, zvarg)
self._resize(zvarg)
# if we add a ZeroView, we add another (stock) view also for modding
self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]
return self
# TODO: take the pad functionality out of shrink
# Restrict the tracked shape to a (start, end) range per dimension given by `arg`.
# `arg` must be a tuple with exactly one pair per dimension of self.shape.
# Returns self so calls can be chained.
# NOTE(review): this listing appears to be a diff whose +/- markers and indentation
# were stripped, so the old inline body and its replacement are interleaved below —
# confirm against the repository before treating every line as live code.
def shrink(self, arg : Tuple[Tuple[int, int], ...]) -> ShapeTracker:
assert isinstance(arg, tuple)
assert len(arg) == len(self.shape)
# NOTE(review): from here down to the `self.views += ...` line looks like the
# pre-change body: the offset/View computation matches what _resize does, and the
# ZeroView handling matches what the new pad() body does.
offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
zeroview = ZeroView(self.shape, arg)
self.views[-1] = View(tuple(y-x for x,y in arg), self.strides, self.offset+offset)
if zeroview.expr_node().min == 0: # may be invalid
# if we add a ZeroView, we add another (stock) view also for modding
self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]
# NOTE(review): presumably the post-change body — the whole resize is delegated to _resize.
self._resize(arg)
return self
def expand(self, new_shape : Tuple[int, ...]) -> ShapeTracker:
@ -191,7 +194,7 @@ class ShapeTracker:
self.views[-1] = View(new_shape, strides, self.offset)
return self
# TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
# TODO: combine with flip? this is more generic than we need
def stride(self, mul : Tuple[int, ...]) -> ShapeTracker:
assert isinstance(mul, tuple)
assert all(isinstance(x, int) for x in mul)