This commit is contained in:
George Hotz 2025-03-24 16:10:08 +08:00
commit afd61730b4
3 changed files with 21 additions and 7 deletions

View file

@ -745,8 +745,15 @@ def get_onnx_ops():
# 3x3 depthwise (C,1,3,3)
# "width multiple of 4 depth multiple of 32 aligned to 128bytes"
x = x.pad(((0,0), (0,0), (0,0), (0,1)))
order = (2,0,1,3)
x = x.permute(*order).contiguous().permute(*argsort(order))
if x.shape[0]%32 == 0 and False:
print("HERE")
x = x.reshape(-1, 32, 1, 3, 4)
order = (0,3,1,2,4)
x = x.permute(*order).contiguous().permute(*argsort(order))
x = x.reshape(-1, 1, 3, 4)
else:
order = (2,0,1,3)
x = x.permute(*order).contiguous().permute(*argsort(order))
x = x[:, :, :, :3]
# we increase the filts to 4-aligned for speed (75% util)
WEIGHT_SHIFT = 4

View file

@ -119,7 +119,9 @@ if __name__ == "__main__":
k.apply_opt(Opt(OptOps.UPCAST, 1, 96))
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
elif knum == 5:
k.apply_opt(Opt(OptOps.UPCAST, 2, 96))
k.apply_opt(Opt(OptOps.UNROLL, 1, 0))
k.apply_opt(Opt(OptOps.UPCAST, 2, 0))
k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
# this breaks something
#k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
elif knum in [8, 12]:

View file

@ -33,10 +33,15 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
idxs: list[int|None] = [None]*vec.dtype.count
global_offset = 0
for rootsrc, offsets in offsets_rootsrc.items():
if len(offsets) == 96 and 3 not in offsets and 0 in offsets and 1 in offsets and 2 in offsets:
for i in range(3, 128, 4):
assert i not in offsets
offsets[i] = {}
if 0 in offsets:
match = True
for i in range(0, max(offsets.keys()), 4):
if i in offsets and i+1 in offsets and i+2 in offsets and i+3 not in offsets: pass
else: match = False
if match:
for i in range(0, max(offsets.keys()), 4):
assert i+3 not in offsets
offsets[i+3] = {}
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for grp in grouped_offsets:
# get the index offset for this element. using [0] is okay, because they are the same