Compare commits

...

2 commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| George Hotz | 50e5c77814 | sink PTRCAT | 2025-07-25 19:52:25 -07:00 |
| George Hotz | be0c608b0a | store is dtypes.void, no ptr pass through | 2025-07-25 19:43:14 -07:00 |
4 changed files with 7 additions and 16 deletions

View file

@@ -5,7 +5,7 @@ from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.helpers import getenv, flatten, AMX, prod, partition, all_same from tinygrad.helpers import getenv, flatten, AMX, prod, partition
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
# ***** image load valid simplification ***** # ***** image load valid simplification *****
@@ -111,11 +111,7 @@ def cat_after_store(cat:UOp, data:UOp, sto:UOp):
for s in cat.src: for s in cat.src:
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:])) ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
offset += s.dtype.count offset += s.dtype.count
# dtype CAT return ret[0].sink(*ret[1:])
dtypes: list[PtrDType] = [x.dtype for x in ret if isinstance(x.dtype, PtrDType)]
assert len(dtypes) == len(ret) and all_same([(x.size, x.addrspace) for x in dtypes])
out_dtype = dtypes[0].base.scalar().vec(sum([x.count for x in dtypes])).ptr(dtypes[0].size, dtypes[0].addrspace)
return UOp(Ops.PTRCAT, dtype=out_dtype, src=tuple(ret))
def gep_on_store(gep:UOp, st:UOp, sto:UOp): def gep_on_store(gep:UOp, st:UOp, sto:UOp):
# NOTE: we need to invert the gep here, but it may be an expanding gep # NOTE: we need to invert the gep here, but it may be an expanding gep
@@ -287,10 +283,10 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
input_ranges = tuple([x for x in inp.toposort(gate=lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in reduce_range]) input_ranges = tuple([x for x in inp.toposort(gate=lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in reduce_range])
identity = red.const_like(identity_element(red.arg, red.dtype.scalar())) identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0)) acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
lst = [acc.store(identity, UOp(Ops.NOOP, src=input_ranges)).load(*reduce_range)] + lst # put acc as the first element lst = [acc.load(acc.store(identity, UOp(Ops.NOOP, src=input_ranges)), *reduce_range)] + lst # put acc as the first element
ctx.acc_num += 1 ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
def no_vectorized_reduce(inp:UOp, red:UOp): def no_vectorized_reduce(inp:UOp, red:UOp):
if inp.dtype != red.dtype: if inp.dtype != red.dtype:

View file

@@ -167,10 +167,8 @@ class CStyleLanguage(Renderer):
(u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))): (u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
r[u] = l r[u] = l
else: else:
if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void: if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void: pass
if u.op is Ops.STORE: r[u] = r[u.src[0]] else: l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
else:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
kernel.append(" "*depth + l) kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment if prefix: c[prefix] += 1 # if it was used, increment
if u.op in {Ops.IF, Ops.RANGE}: depth += 1 if u.op in {Ops.IF, Ops.RANGE}: depth += 1

View file

@@ -188,9 +188,6 @@ class LLVMRenderer(Renderer):
if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None: if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}") raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
kernel.append(cast(str, l)) kernel.append(cast(str, l))
# stores pass the first arg through
if u.op is Ops.STORE: r[u] = r[u.src[0]]
return tuple(local_args), self._render_fn(name, args, kernel, prefix) return tuple(local_args), self._render_fn(name, args, kernel, prefix)
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n' barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'

View file

@@ -648,7 +648,7 @@ class UPat(MathTrait):
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,)) def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs) def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs) def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.dtype, (self,)+src, **kwargs) def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs) def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs)
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs) def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
def fuse(self): return self.alu(Ops.FUSE) def fuse(self): return self.alu(Ops.FUSE)