mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
renderers_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8a66639a6 | ||
|
|
0095b91ce6 | ||
|
|
3e13225f62 |
7 changed files with 30 additions and 28 deletions
|
|
@ -55,7 +55,7 @@ class Estimates:
|
|||
lds += u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.src[1].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.numel()
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return Estimates(flops, lds, sum(mem.values()))
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ base_rewrite = PatternMatcher([
|
|||
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
|
||||
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
|
||||
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})"
|
||||
if not isinstance(x.dtype, PtrDType) and x._shape is not None and x.numel() > 1 else None),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
|
||||
|
|
@ -55,7 +56,7 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
|
||||
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
(f"[{x.arg[0]}]" if x.src[0].numel() > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
# custom passes through with format
|
||||
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
|
||||
(UPat(Ops.STACK, src=UPat.var('y'), name="x"), lambda ctx,x,y:
|
||||
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
|
||||
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
|
||||
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.numel()} x i32> zeroinitializer"),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
|
||||
f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
|
||||
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
|
||||
|
|
|
|||
|
|
@ -86,9 +86,9 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
|||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
|
||||
nload = nir_instr(nc=lambda x:x.numel(), bs=lambda x:x.dtype.scalar().bitsize, num_components=lambda x:x.numel(),
|
||||
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
|
||||
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
lambda b, space, addr, x: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
|
||||
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
|
||||
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
|
||||
|
|
@ -150,10 +150,10 @@ class NIRRenderer(Renderer):
|
|||
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
|
||||
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), allow_any_len=True, name="x"),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.numel()}", *[ctx.r[src] for src in x.src])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
|
||||
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
|
||||
|
|
@ -200,7 +200,7 @@ class NIRRenderer(Renderer):
|
|||
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
|
||||
mesa.nir_push_loop(self.b)
|
||||
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
|
||||
self.r[u] = nload(self.b, AddrSpace.REG, i, u)
|
||||
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
||||
elif u.op == Ops.END:
|
||||
r = u.src[1]
|
||||
|
|
|
|||
|
|
@ -104,18 +104,18 @@ string_rewrite = PatternMatcher([
|
|||
# store / gated load / load
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc")), allow_any_len=True), UPat.var("var"))),
|
||||
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
|
||||
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
|
||||
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
|
||||
f"{f'.v{cnt}' if ((cnt:=var.numel())>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
|
||||
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.numel() > 1 else ctx.r[var]};"),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True),
|
||||
lambda ctx, x, loc, alt, gate, buf: flatten([
|
||||
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
|
||||
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
|
||||
]) if alt.dtype.count > 1 else [
|
||||
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.numel()}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
|
||||
]) if x.numel() > 1 else [
|
||||
f"@{ctx.r[gate]} ld.{mem_type(buf)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
|
||||
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))),), allow_any_len=True),
|
||||
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
|
||||
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
|
||||
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.numel()}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
|
||||
if x.numel() > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
|
||||
# simple
|
||||
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
|
||||
(UPat(Ops.RANGE, name="r"), lambda ctx, r: [
|
||||
|
|
@ -220,14 +220,14 @@ class PTXRenderer(Renderer):
|
|||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
|
||||
elif u.op is Ops.LOAD:
|
||||
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
|
||||
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
|
||||
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.numel())] if u.numel() > 1 else ssa('val', u)
|
||||
elif u.op is Ops.PARAM: bufs.append((f"data{u.arg}", u.dtype))
|
||||
elif u.op is Ops.WMMA:
|
||||
# registers for packing/unpacking input and acc
|
||||
self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],
|
||||
[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
|
||||
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
|
||||
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
|
||||
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.numel())]
|
||||
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
|
||||
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local", self.types[dtypes.ulong]),
|
||||
Ops.PARAM: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
|
||||
|
|
|
|||
|
|
@ -546,7 +546,7 @@ def split_store(x:UOp) -> UOp|None:
|
|||
# if we have any open ranges here, we don't split
|
||||
if x.ranges: return None
|
||||
# raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently
|
||||
if x.op is Ops.STORE and x.src[0]._shape is not None: return None
|
||||
if x.op is Ops.STORE and x.src[0]._shape is not None and x.src[0].shape != (): return None
|
||||
|
||||
# local kernel rewrite
|
||||
lctx = LocalAddBufferContext()
|
||||
|
|
|
|||
|
|
@ -209,10 +209,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
@recursive_property
|
||||
def _shape(self) -> tuple[sint, ...]|None:
|
||||
if self.dtype.count > 1: return (self.dtype.count,)
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.STACK | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | \
|
||||
Ops.STACK | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
|
||||
return None
|
||||
|
||||
|
|
@ -233,12 +234,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
# non pointer index doesn't have a shape
|
||||
if not isinstance(self.dtype, PtrDType): return None
|
||||
# fully indexed doesn't have a shape. TODO: remove this
|
||||
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
|
||||
# pointer index
|
||||
return self.src[0].shape[len(self.src[1:]):]
|
||||
if self.src[0]._shape is None: return None
|
||||
idx_srcs = self.src[1:-1] if len(self.src) > 1 and self.src[-1].dtype is dtypes.bool else self.src[1:]
|
||||
return self.src[0].shape[len(idx_srcs):]
|
||||
|
||||
case Ops.GEP:
|
||||
return None if self.src[0]._shape is None else self.src[0].shape[:-1] + ((len(self.arg),) if len(self.arg) > 1 else ())
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
|
|
@ -492,7 +493,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
|
||||
arg=dtype.const(b),
|
||||
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
|
||||
@staticmethod
|
||||
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
|
||||
shape:tuple[sint, ...]|None=None, unique=True):
|
||||
|
|
@ -500,7 +501,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
assert not isinstance(fill_value, (UOp, tuple)), "unique const only works on numbers"
|
||||
ret = UOp.const(to_dtype(dtype) if dtype is not None else dtypes.from_py(fill_value), fill_value, canonicalize_device(device))
|
||||
ret = ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
|
||||
@staticmethod
|
||||
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue