mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
store_is_v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50e5c77814 | ||
|
|
be0c608b0a |
4 changed files with 7 additions and 16 deletions
|
|
@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|||
from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
|
||||
from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.helpers import getenv, flatten, AMX, prod, partition, all_same
|
||||
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
|
|
@ -111,11 +111,7 @@ def cat_after_store(cat:UOp, data:UOp, sto:UOp):
|
|||
for s in cat.src:
|
||||
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
|
||||
offset += s.dtype.count
|
||||
# dtype CAT
|
||||
dtypes: list[PtrDType] = [x.dtype for x in ret if isinstance(x.dtype, PtrDType)]
|
||||
assert len(dtypes) == len(ret) and all_same([(x.size, x.addrspace) for x in dtypes])
|
||||
out_dtype = dtypes[0].base.scalar().vec(sum([x.count for x in dtypes])).ptr(dtypes[0].size, dtypes[0].addrspace)
|
||||
return UOp(Ops.PTRCAT, dtype=out_dtype, src=tuple(ret))
|
||||
return ret[0].sink(*ret[1:])
|
||||
|
||||
def gep_on_store(gep:UOp, st:UOp, sto:UOp):
|
||||
# NOTE: we need to invert the gep here, but it may be an expanding gep
|
||||
|
|
@ -287,10 +283,10 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
|||
input_ranges = tuple([x for x in inp.toposort(gate=lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in reduce_range])
|
||||
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
lst = [acc.store(identity, UOp(Ops.NOOP, src=input_ranges)).load(*reduce_range)] + lst # put acc as the first element
|
||||
lst = [acc.load(acc.store(identity, UOp(Ops.NOOP, src=input_ranges)), *reduce_range)] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
|
||||
return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
|
||||
|
||||
def no_vectorized_reduce(inp:UOp, red:UOp):
|
||||
if inp.dtype != red.dtype:
|
||||
|
|
|
|||
|
|
@ -167,10 +167,8 @@ class CStyleLanguage(Renderer):
|
|||
(u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
r[u] = l
|
||||
else:
|
||||
if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void:
|
||||
if u.op is Ops.STORE: r[u] = r[u.src[0]]
|
||||
else:
|
||||
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
|
||||
if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void: pass
|
||||
else: l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
|
||||
kernel.append(" "*depth + l)
|
||||
if prefix: c[prefix] += 1 # if it was used, increment
|
||||
if u.op in {Ops.IF, Ops.RANGE}: depth += 1
|
||||
|
|
|
|||
|
|
@ -188,9 +188,6 @@ class LLVMRenderer(Renderer):
|
|||
if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
|
||||
raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
|
||||
kernel.append(cast(str, l))
|
||||
|
||||
# stores pass the first arg through
|
||||
if u.op is Ops.STORE: r[u] = r[u.src[0]]
|
||||
return tuple(local_args), self._render_fn(name, args, kernel, prefix)
|
||||
|
||||
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
|
||||
|
|
|
|||
|
|
@ -648,7 +648,7 @@ class UPat(MathTrait):
|
|||
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
|
||||
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.dtype, (self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs)
|
||||
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
|
||||
def fuse(self): return self.alu(Ops.FUSE)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue