fix FUSE_ARANGE=1 for bert (#10255)

This commit is contained in:
qazal 2025-05-12 14:44:05 +03:00 committed by GitHub
commit 95c6a736a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 6 additions and 3 deletions

View file

@ -167,7 +167,6 @@ class TestRealWorld(unittest.TestCase):
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 346)
@unittest.expectedFailure # TODO: fix FUSE_ARANGE
def test_bert_fuse_arange(self):
with Context(FUSE_ARANGE=1):
self.test_bert()

View file

@ -2,7 +2,7 @@ from collections import defaultdict, deque
from dataclasses import dataclass
from tinygrad.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
from tinygrad.ops import can_pad, sint, track_rewrites, _substitute
from tinygrad.codegen.lowerer import get_contraction_with_reduce
from tinygrad.codegen.lowerer import get_contraction_with_reduce, get_contraction
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, ContextVar, Context, diskcache_put
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY
@ -325,7 +325,11 @@ def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if resolve(s != u))).view(ShapeTracker.from_shape(r.shape))
if (contraction:=get_contraction(v.shape, src.shape)) is None: return None
new_axis: list[int] = []
for i,pairs in enumerate(contraction):
if any(x in r.axis_arg for x in pairs): new_axis.append(i)
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
def elementwise_view_right(root:UOp):
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None