mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
simplest migration of indexing [pr] (#7402)
* simplest migration of indexing [pr] * fix locals/barrier
This commit is contained in:
parent
ee9ef93617
commit
f3bd5cbf78
2 changed files with 34 additions and 14 deletions
|
|
@ -868,7 +868,7 @@ class TestOps(unittest.TestCase):
|
|||
np.arange(64,128,dtype=np.float32).reshape(8,8)])
|
||||
def test_small_gemm_eye(self):
|
||||
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"], "not supported on these in CI")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE, "not supported on these in CI/IMAGE")
|
||||
def test_gemm_fp16(self):
|
||||
helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
|
||||
def test_gemm(self):
|
||||
|
|
|
|||
|
|
@ -489,21 +489,37 @@ reducer = PatternMatcher([
|
|||
])
|
||||
|
||||
def idx_load_store(x:UOp):
|
||||
idx = x.src[0].index(x.src[1])
|
||||
idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None)
|
||||
v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
|
||||
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
|
||||
return UOp(x.op, x.dtype, (idx,)+x.src[2:], x.arg)
|
||||
post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is UOps.LOAD else x.src[3:])
|
||||
if x.op is UOps.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg)
|
||||
return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg)
|
||||
|
||||
indexing = PatternMatcher([
|
||||
migrate_indexing = PatternMatcher([
|
||||
# use indexing for LOAD/STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
|
||||
])
|
||||
|
||||
def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp:
|
||||
# this moves the mask from the indexing to the load/store op for rendering
|
||||
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
|
||||
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
|
||||
|
||||
masked_index = UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask")))
|
||||
move_masks = PatternMatcher([
|
||||
# NOTE: this shouldn't be here
|
||||
(UPat(UOps.CONST, name='c'),
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
|
||||
# fix up loads/stores
|
||||
# TODO: this should be an IF instead of a masked STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index, masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
linearize_cnt = 0
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
global linearize_cnt, acc_number
|
||||
global acc_number
|
||||
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
|
||||
# do graph rewrite
|
||||
|
|
@ -511,14 +527,18 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
|||
sink = graph_rewrite(sink, sym)
|
||||
|
||||
# expand
|
||||
linearize_cnt += 1
|
||||
if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
|
||||
sink = graph_rewrite(sink, sym+expander)
|
||||
if getenv("DO_REDUCE", 1):
|
||||
sink = graph_rewrite(sink, sym+just_reduce)
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
|
||||
sink = graph_rewrite(sink, sym+reducer)
|
||||
sink = graph_rewrite(sink, sym+indexing+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
sink = graph_rewrite(sink, sym+expander)
|
||||
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN
|
||||
sink = graph_rewrite(sink, sym+just_reduce)
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
|
||||
sink = graph_rewrite(sink, sym+reducer)
|
||||
|
||||
# temp for indexing migration
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing)
|
||||
|
||||
# finalize
|
||||
sink = graph_rewrite(sink, sym+move_masks+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
|
||||
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
|
||||
return sink
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue