mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
kernel 5
This commit is contained in:
parent
536556434b
commit
afd61730b4
3 changed files with 21 additions and 7 deletions
|
|
@@ -745,8 +745,15 @@ def get_onnx_ops():
|
|||
# 3x3 depthwise (C,1,3,3)
|
||||
# "width multiple of 4 depth multiple of 32 aligned to 128bytes"
|
||||
x = x.pad(((0,0), (0,0), (0,0), (0,1)))
|
||||
order = (2,0,1,3)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
if x.shape[0]%32 == 0 and False:
|
||||
print("HERE")
|
||||
x = x.reshape(-1, 32, 1, 3, 4)
|
||||
order = (0,3,1,2,4)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x.reshape(-1, 1, 3, 4)
|
||||
else:
|
||||
order = (2,0,1,3)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x[:, :, :, :3]
|
||||
# we increase the filts to 4-aligned for speed (75% util)
|
||||
WEIGHT_SHIFT = 4
|
||||
|
|
|
|||
|
|
@@ -119,7 +119,9 @@ if __name__ == "__main__":
|
|||
k.apply_opt(Opt(OptOps.UPCAST, 1, 96))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
elif knum == 5:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 96))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 1, 0))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 0))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
|
||||
# this breaks something
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
|
||||
elif knum in [8, 12]:
|
||||
|
|
|
|||
|
|
@@ -33,10 +33,15 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
|||
idxs: list[int|None] = [None]*vec.dtype.count
|
||||
global_offset = 0
|
||||
for rootsrc, offsets in offsets_rootsrc.items():
|
||||
if len(offsets) == 96 and 3 not in offsets and 0 in offsets and 1 in offsets and 2 in offsets:
|
||||
for i in range(3, 128, 4):
|
||||
assert i not in offsets
|
||||
offsets[i] = {}
|
||||
if 0 in offsets:
|
||||
match = True
|
||||
for i in range(0, max(offsets.keys()), 4):
|
||||
if i in offsets and i+1 in offsets and i+2 in offsets and i+3 not in offsets: pass
|
||||
else: match = False
|
||||
if match:
|
||||
for i in range(0, max(offsets.keys()), 4):
|
||||
assert i+3 not in offsets
|
||||
offsets[i+3] = {}
|
||||
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
||||
for grp in grouped_offsets:
|
||||
# get the index offset for this element. using [0] is okay, because they are the same
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue