mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix FUSE_ARANGE=1 for bert (#10255)
This commit is contained in:
parent
7c4b381fbf
commit
95c6a736a9
2 changed files with 6 additions and 3 deletions
|
|
@@ -167,7 +167,6 @@ class TestRealWorld(unittest.TestCase):
|
|||
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 346)
|
||||
|
||||
-@unittest.expectedFailure # TODO: fix FUSE_ARANGE
|
||||
def test_bert_fuse_arange(self):
|
||||
with Context(FUSE_ARANGE=1):
|
||||
self.test_bert()
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,7 @@ from collections import defaultdict, deque
|
|||
from dataclasses import dataclass
|
||||
from tinygrad.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
||||
from tinygrad.ops import can_pad, sint, track_rewrites, _substitute
|
||||
-from tinygrad.codegen.lowerer import get_contraction_with_reduce
|
||||
+from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contraction
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY
|
||||
|
|
@@ -325,7 +325,11 @@ def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
|||
|
||||
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
|
||||
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
|
||||
-return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if resolve(s != u))).view(ShapeTracker.from_shape(r.shape))
|
||||
+if (contraction:=get_contraction(v.shape, src.shape)) is None: return None
|
||||
+new_axis: list[int] = []
|
||||
+for i,pairs in enumerate(contraction):
|
||||
+if any(x in r.axis_arg for x in pairs): new_axis.append(i)
|
||||
+return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
|
||||
|
||||
def elementwise_view_right(root:UOp):
|
||||
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue