simplest migration of indexing [pr] (#7402)

* simplest migration of indexing [pr]

* fix locals/barrier
This commit is contained in:
George Hotz 2024-10-30 19:58:18 +07:00 committed by GitHub
commit f3bd5cbf78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 34 additions and 14 deletions

View file

@ -868,7 +868,7 @@ class TestOps(unittest.TestCase):
np.arange(64,128,dtype=np.float32).reshape(8,8)])
def test_small_gemm_eye(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"], "not supported on these in CI")
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE, "not supported on these in CI/IMAGE")
def test_gemm_fp16(self):
helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
def test_gemm(self):

View file

@ -489,21 +489,37 @@ reducer = PatternMatcher([
])
def idx_load_store(x:UOp):
idx = x.src[0].index(x.src[1])
idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None)
v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
return UOp(x.op, x.dtype, (idx,)+x.src[2:], x.arg)
post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is UOps.LOAD else x.src[3:])
if x.op is UOps.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg)
return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg)
indexing = PatternMatcher([
migrate_indexing = PatternMatcher([
# use indexing for LOAD/STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
])
def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp:
# this moves the mask from the indexing to the load/store op for rendering
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
masked_index = UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask")))
move_masks = PatternMatcher([
# NOTE: this shouldn't be here
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
# fix up loads/stores
# TODO: this should be an IF instead of a masked STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index, masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
])
# *** uop graph ***
linearize_cnt = 0
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
global linearize_cnt, acc_number
global acc_number
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
# do graph rewrite
@ -511,14 +527,18 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
sink = graph_rewrite(sink, sym)
# expand
linearize_cnt += 1
if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
sink = graph_rewrite(sink, sym+expander)
if getenv("DO_REDUCE", 1):
sink = graph_rewrite(sink, sym+just_reduce)
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, sym+reducer)
sink = graph_rewrite(sink, sym+indexing+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
sink = graph_rewrite(sink, sym+expander)
# convert REDUCE to DEFINE_ACC + ASSIGN
sink = graph_rewrite(sink, sym+just_reduce)
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, sym+reducer)
# temp for indexing migration
sink = graph_rewrite(sink, sym+migrate_indexing)
# finalize
sink = graph_rewrite(sink, sym+move_masks+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
return sink