mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
huh, this prevents an extra kernel
This commit is contained in:
parent
487685919b
commit
f6fc2a0d98
1 changed file with 3 additions and 2 deletions
|
|
@@ -250,7 +250,6 @@ class LazyBuffer:
|
|||
x = self
|
||||
|
||||
if IMAGE >= 1:
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W))
|
||||
added_output_channels = 0
|
||||
|
||||
|
|
@@ -258,8 +257,10 @@ class LazyBuffer:
|
|||
if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
|
||||
to_add = 4 - (C.cin % 4)
|
||||
w = w.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(w.shape))])
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
x = x.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(x.shape))])
|
||||
C = C._replace(cin = C.cin + to_add)
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix))
|
||||
|
||||
# hack for non multiples of 4 on C.rcout
|
||||
if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
|
||||
|
|
@@ -269,7 +270,7 @@ class LazyBuffer:
|
|||
|
||||
# packed
|
||||
assert (C.groups*C.cin) % 4 == 0
|
||||
x = x.movement_op(MovementOps.PERMUTE, (0,3,4,1,2))
|
||||
x = x.movement_op(MovementOps.PERMUTE, (0,2,3,1))
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4))
|
||||
|
||||
assert C.cout % 4 == 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue