mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
57 commits
master
...
shrink_in_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e76de41110 | ||
|
|
5768042e3f | ||
|
|
ef9c60238e |
||
|
|
5825d4d833 |
||
|
|
394afe40c5 | ||
|
|
774847a54d | ||
|
|
a35c1c9c38 |
||
|
|
2bcd46d946 |
||
|
|
1ffe08bab9 | ||
|
|
894f1221c2 | ||
|
|
f3ecb4f8e8 | ||
|
|
3dbaa526fa | ||
|
|
31a87addca | ||
|
|
c64a37fa7d |
||
|
|
d6f1aadeb7 | ||
|
|
beb6c3ab3e | ||
|
|
08fe658f74 | ||
|
|
f532e1b2d0 | ||
|
|
1c18e1bae8 | ||
|
|
50ac2872b3 | ||
|
|
8d327d4877 | ||
|
|
bdbee57f34 | ||
|
|
7b00120d92 | ||
|
|
46541d70f4 |
||
|
|
8850ce9380 | ||
|
|
4571b0d98a | ||
|
|
9f78877d14 | ||
|
|
6f506dc55e | ||
|
|
12752b8a44 | ||
|
|
e808f698bc | ||
|
|
27835b5a31 | ||
|
|
9ccee6aae7 |
||
|
|
fdc7d4c0af | ||
|
|
ea70715344 |
||
|
|
8ba3ee138e |
||
|
|
604b35aa67 |
||
|
|
7754025f2a | ||
|
|
5ac25ba991 | ||
|
|
507c68dbc8 | ||
|
|
3e0335a4a0 | ||
|
|
cb755bded6 | ||
|
|
10c2a50e79 | ||
|
|
7b951e691e | ||
|
|
90b2c7e115 | ||
|
|
14394eb97d | ||
|
|
213cc5b6b0 | ||
|
|
1670dbfacd |
||
|
|
927e16fbdb |
||
|
|
f165839386 |
||
|
|
3baba3f23f |
||
|
|
a75ad9fbaa | ||
|
|
4115d330ab |
||
|
|
7da2c151be |
||
|
|
19246184d3 |
||
|
|
f17eb03634 |
||
|
|
1c10882bf0 | ||
|
|
6881b32e84 |
9 changed files with 222 additions and 116 deletions
|
|
@ -3,16 +3,16 @@ from dataclasses import replace
|
|||
import itertools
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
|
||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
|
||||
from tinygrad.uop.render import pyrender
|
||||
from tinygrad.uop.spec import type_verify, spec_tensor, spec_program
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
|
||||
|
|
@ -24,6 +24,25 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
|
|||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||
|
||||
# NOTE: this is temporary until we fix the devectorizer
|
||||
pm_index_is_shrink = PatternMatcher([
|
||||
# rewrite non-image INDEX to SHRINK
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).cast(name="x"), lambda buf,idx,x:
|
||||
UOp(Ops.SHRINK, dtype=buf.dtype.base, src=(buf, idx, UOp.const(dtypes.int, x.dtype.count))) if isinstance(buf.dtype, PtrDType) else None),
|
||||
# rewrite GEP to INDEX
|
||||
(UPat(Ops.GEP, name="x"), lambda x: x.replace(op=Ops.INDEX, src=x.src+(UOp.const(dtypes.int, x.arg),), arg=None)),
|
||||
])
|
||||
|
||||
pm_remove_vec_dtypes = PatternMatcher([
|
||||
# rewrite PARAM to non pointer
|
||||
(UPat((Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
|
||||
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
|
||||
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
|
||||
# remove all vec dtypes
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}, name="x"),
|
||||
lambda x: x.replace(dtype=x.dtype.base.scalar().base)),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
if DEBUG >= 5: print(pyrender(ast))
|
||||
|
|
@ -100,6 +119,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
|
||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
|
||||
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, name="remove vec dtypes")
|
||||
|
||||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Callable, cast
|
|||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
|
||||
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
|
||||
from tinygrad.dtype import AddrSpace, DType, dtypes
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.device import Compiler
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ class Estimates:
|
|||
while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0]
|
||||
if buf.op is Ops.PARAM:
|
||||
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
|
||||
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
|
||||
accessed = mem.get((buf, u.op), 0) + u.max_numel() * u.src[0].dtype.itemsize * mults
|
||||
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
|
||||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
|
|
@ -51,10 +51,10 @@ class Estimates:
|
|||
elif u.op is Ops.END: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
|
||||
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.src[1].dtype.itemsize * mults
|
||||
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.src[1].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return Estimates(flops, lds, sum(mem.values()))
|
||||
|
|
|
|||
|
|
@ -8,9 +8,17 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
|
|||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
|
||||
|
||||
def render_index(ctx,buf,idx):
|
||||
base = buf
|
||||
while base.op is Ops.AFTER: base = base.src[0]
|
||||
if base.addrspace == AddrSpace.ANON:
|
||||
assert idx.op is Ops.CONST, f"{idx.op} must be CONST"
|
||||
return f"{ctx[buf]}[{idx.arg}]"
|
||||
else:
|
||||
return f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"
|
||||
|
||||
base_rewrite = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{ctx[x.src[0]]}];"),
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
|
||||
(UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"),
|
||||
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
|
||||
|
|
@ -18,14 +26,14 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(Ops.RANGE, name="x"),
|
||||
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
|
||||
(UPat(Ops.STACK, name="x"),
|
||||
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
|
||||
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
|
||||
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(ctx.render_dtype_with_shape(x)))}" + \
|
||||
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{ctx[x.src[0]]}];"),
|
||||
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"),
|
||||
# const
|
||||
|
|
@ -43,24 +51,26 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
|
||||
# default const render
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
# SHRINK/INDEX
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), render_index),
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var('idx'), UPat.cvar())), render_index),
|
||||
# new load/store
|
||||
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), name="x"), lambda ctx,x,bidx: ctx.render_access(bidx, ctx.render_dtype_with_shape(x))),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,bidx,var,gate: f"({ctx[gate]}?{ctx.render_access(bidx, ctx.render_dtype_with_shape(x))}:{ctx[var]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))),
|
||||
lambda ctx,bidx,var: f"{ctx.render_access(bidx, ctx.render_dtype_with_shape(var))} = {ctx[var]};"),
|
||||
# alu/gep
|
||||
# TODO: look for left-associative
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
|
||||
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
# custom passes through with format
|
||||
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
|
||||
])
|
||||
|
||||
extra_pm = PatternMatcher([
|
||||
# devectorize any bools
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.SHRINK), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
|
||||
# CAST (from bool) can't be vectorized
|
||||
(UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
|
||||
# WHERE can't be vectorized
|
||||
|
|
@ -95,7 +105,11 @@ pm_manual_bf16_cast = PatternMatcher([
|
|||
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16),
|
||||
])
|
||||
|
||||
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
|
||||
def dtype_with_shape(dtype:DType, shape:tuple) -> DType:
|
||||
return dtype.scalar().vec(prod(shape)) if dtype.count == 1 and len(shape) == 1 and isinstance(shape[0], int) and shape[0] > 1 else dtype
|
||||
def uops_to_dtypes(uops:list[UOp]) -> list[DType]:
|
||||
return dedup(dtype_with_shape(u.dtype, u._shape or ()) if u.addrspace not in {AddrSpace.GLOBAL, AddrSpace.LOCAL} else u.dtype
|
||||
for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
|
||||
|
||||
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
|
||||
def wmma_args(uops:list[UOp]):
|
||||
|
|
@ -133,10 +147,12 @@ class CStyleLanguage(Renderer):
|
|||
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
tmp = ""
|
||||
if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs):
|
||||
def arg_dtype(u:UOp) -> DType:
|
||||
return u.dtype if isinstance(u.dtype, (ImageDType, PtrDType)) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
|
||||
if any(isinstance(arg_dtype(u), ImageDType) for _,(u,_) in bufs):
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
|
||||
buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs]
|
||||
buftypes = [(name, self.render_dtype(dt, mutable)+self.buffer_suffix if isinstance(dt, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if dt == dtypes.int else None) for name,(u,mutable) in bufs for dt in (arg_dtype(u),)]
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
launch_bounds = prod([d.vmax for d in local_dims])
|
||||
prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
|
||||
|
|
@ -145,6 +161,9 @@ class CStyleLanguage(Renderer):
|
|||
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
|
||||
|
||||
def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
|
||||
def render_dtype_with_shape(self, u:UOp) -> DType: return dtype_with_shape(u.dtype, u.shape)
|
||||
def render_access(self, bidx:UOp, dtype:DType) -> str:
|
||||
return f"(*(({self.render_dtype(dtype.ptr(addrspace=bidx.addrspace))})({self[bidx]})))" if dtype.count > 1 else f"(*{self[bidx]})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str:
|
||||
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
|
||||
if isinstance(dt, PtrDType):
|
||||
|
|
@ -190,21 +209,24 @@ class CStyleLanguage(Renderer):
|
|||
else:
|
||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
|
||||
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
Ops.INDEX: "bidx", Ops.SHRINK: "bidx",
|
||||
Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
r[u] = f"{prefix}{c[prefix]}"
|
||||
|
||||
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
|
||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||
|
||||
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.SHRINK, Ops.INDEX, Ops.CUSTOMI} or \
|
||||
(u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \
|
||||
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
|
||||
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
r[u] = l
|
||||
else:
|
||||
if u.op is Ops.SHRINK or u._shape is None: u_dtype = u.src[0].dtype
|
||||
else: u_dtype = self.render_dtype_with_shape(u)
|
||||
if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} and u.dtype != dtypes.void:
|
||||
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
|
||||
l = f"{self.render_dtype(u_dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
|
||||
kernel.append(" "*depth + l)
|
||||
if prefix: c[prefix] += 1 # if it was used, increment
|
||||
if u.op in {Ops.IF, Ops.RANGE}: depth += 1
|
||||
|
|
|
|||
|
|
@ -5,12 +5,14 @@ from tinygrad.renderer import Renderer
|
|||
from tinygrad.renderer.cstyle import HIPRenderer, create_non_native_float_pats, pm_manual_bf16_cast
|
||||
from tinygrad.uop.decompositions import xexp2, xlog2
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
|
||||
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
|
||||
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate, AddrSpace
|
||||
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv, OSX
|
||||
|
||||
def ldt(dt:DType):
|
||||
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
|
||||
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
|
||||
def ldt(dt:DType, count=1, ptr=False):
|
||||
#if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
|
||||
#if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
|
||||
if ptr: return ldt(dt, count) + "*"
|
||||
if count > 1: return f"<{count} x {ldt(dt, 1, ptr)}>"
|
||||
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
|
||||
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", dtypes.fp8e4m3: "i8", dtypes.fp8e5m2: "i8",
|
||||
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
|
||||
|
|
@ -54,12 +56,13 @@ def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
|
|||
N,M,K = wmma.arg[1]
|
||||
if cdna:
|
||||
if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"})
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
|
||||
f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype, count=wmma.max_numel())} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
|
||||
f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype, count=w.max_numel())} {ctx[w]}" for w in wmma.src]) + \
|
||||
", i32 0, i32 0, i32 0)"
|
||||
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
|
||||
# example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
|
||||
f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype, count=wmma.max_numel())} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
|
||||
f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype, count=w.max_numel())} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
|
||||
if wmma.dtype.scalar() != dtypes.float else ")")
|
||||
|
||||
# llvm ops, lop[<dtype>][<op>]
|
||||
|
|
@ -75,25 +78,31 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
|
|||
base_rewrite = PatternMatcher([
|
||||
# memory load/store
|
||||
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
||||
f" {ctx[x]} = extractelement {ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]}, i32 {x.src[1].arg}" \
|
||||
if x.addrspace == AddrSpace.ANON else None),
|
||||
(UPat((Ops.INDEX, Ops.SHRINK), name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype)}, {ldt(x.dtype, ptr=True)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
|
||||
lambda ctx,x,idx,alt,mask:
|
||||
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
|
||||
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
|
||||
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
|
||||
f" {ctx[x]}_yes = load {ldt(x.dtype, idx.max_numel())}, {ldt(idx.dtype, x.max_numel(), True)} {ctx[idx]}\n"
|
||||
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
|
||||
(UPat.var('idx').load(name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
|
||||
(UPat.var('idx').store(UPat.var("var")), lambda ctx,idx,var: f" store {ldt(var.dtype)} {ctx[var]}, {ldt(idx.dtype)} {ctx[idx]}"),
|
||||
|
||||
f" {ctx[x]} = phi {ldt(x.dtype, idx.max_numel())} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
|
||||
(UPat.var('idx').load(name="x"),
|
||||
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(idx.dtype, idx.max_numel())}, {ldt(idx.dtype, idx.max_numel(), True)} {ctx[idx]}"),
|
||||
(UPat.var('idx').store(UPat.var("var")),
|
||||
lambda ctx,idx,var:
|
||||
f" store {ldt(var.dtype, idx.max_numel())} {ctx[var]}, {ldt(idx.dtype, idx.max_numel(), True)} {ctx[idx]}"),
|
||||
# GEP/VECTORIZE/CAST for float4 support
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
|
||||
#(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
|
||||
(UPat(Ops.STACK, src=UPat.var('y'), name="x"), lambda ctx,x,y:
|
||||
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
|
||||
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
|
||||
f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
|
||||
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
|
||||
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.max_numel()} x i32> zeroinitializer"),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(
|
||||
f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
|
||||
f" = insertelement {ldt(x.dtype, x.max_numel())} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
|
||||
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
|
||||
# unary/binary/ternary ops
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
|
|
@ -137,7 +146,8 @@ class LLVMRenderer(Renderer):
|
|||
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
|
||||
def _render_fn(self, name:str, args:list[tuple[str,UOp]], kernel:list[str], prefix:list[str]|None=None) -> str:
|
||||
# NOTE: CPUAllocator promises 0x20 alignment
|
||||
sargs = ", ".join([f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {name}" for name,u in args])
|
||||
sargs = ", ".join([f"{ldt(u.dtype, ptr=u.addrspace == AddrSpace.GLOBAL)}{' noalias align 32' if u.addrspace == AddrSpace.GLOBAL else ''} {name}"
|
||||
for name,u in args])
|
||||
return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
|
||||
def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
|
||||
r: dict[UOp, str] = {}
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
from typing import Callable, cast, Any
|
||||
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
|
||||
from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
|
||||
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
|
||||
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
|
||||
from tinygrad.runtime.autogen import mesa
|
||||
from tinygrad.runtime.support.c import POINTER
|
||||
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools
|
||||
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools, os, warnings
|
||||
|
||||
def g(s:str): return getattr(mesa, s)
|
||||
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
|
||||
|
||||
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
|
||||
def glsl_type(t:DType): return {
|
||||
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
|
||||
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
|
||||
|
||||
|
|
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
|
|||
|
||||
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
|
||||
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
|
||||
if isinstance(it, PtrDType) and ot == dtypes.long: return src
|
||||
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
|
||||
|
||||
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
|
||||
|
|
@ -86,9 +85,9 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
|||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
|
||||
nload = nir_instr(nc=lambda count:count, bs=lambda dtype:dtype.bitsize, num_components=lambda count:count,
|
||||
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
|
||||
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
lambda b, space, addr, dtype, count=1: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
|
||||
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
|
||||
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
|
||||
|
|
@ -104,16 +103,31 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
|
|||
|
||||
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
|
||||
|
||||
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
|
||||
def _load_count(x:UOp) -> int: return x.max_numel() if 1 < x.max_numel() <= 4 else 1
|
||||
def _pad_count(b:mesa.nir_builder, dtype:DType, count:int, val):
|
||||
return val if val.num_components == count else nalu(b, f"vec{count}", val, *[nundef(b, dtype) for _ in range(count-1)])
|
||||
|
||||
def nidx(b:mesa.nir_builder, buf, off, dtype, addrspace, gate=None) -> mesa.nir_def:
|
||||
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
|
||||
def reg(b, buf):
|
||||
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
|
||||
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
|
||||
return deref
|
||||
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else
|
||||
f = (functools.partial(reg, b, buf) if addrspace == AddrSpace.REG else
|
||||
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
|
||||
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
|
||||
|
||||
def ngated_load_index(ctx, x, buf, off, alt, gate):
|
||||
cnt = _load_count(x)
|
||||
return if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace, ctx.r[gate]), x.dtype, cnt),
|
||||
lambda: _pad_count(ctx.b, x.dtype, cnt, ctx.r[alt]))
|
||||
|
||||
def ngated_load_shrink(ctx, x, idx, alt, gate):
|
||||
cnt = _load_count(idx)
|
||||
return if_phi(ctx.b, ctx.r[gate], lambda: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, cnt),
|
||||
lambda: _pad_count(ctx.b, x.dtype, cnt, ctx.r[alt]))
|
||||
|
||||
class NIRRenderer(Renderer):
|
||||
suffix = "NIR"
|
||||
nir_options: bytes
|
||||
|
|
@ -137,7 +151,7 @@ class NIRRenderer(Renderer):
|
|||
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
|
||||
# load/store use pointer arithmetic, and the cast does nothing. NOTE: this doesn't apply to image indexing cause it's 1-D
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
|
||||
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
src=(buf,off.cast(dtypes.long))) if buf.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
# images need index to be int for nir
|
||||
(UPat.var("buf").index(UPat.var("idx_y"), UPat.var("idx_x")),
|
||||
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int))),
|
||||
|
|
@ -149,18 +163,28 @@ class NIRRenderer(Renderer):
|
|||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("val"))),
|
||||
lambda ctx,idx,val: nstore(ctx.b, idx.addrspace, ctx.r[idx], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
|
||||
ngated_load_index),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
ngated_load_shrink),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), x.dtype, _load_count(x))),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"),), name="x"),
|
||||
lambda ctx,x,idx: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, _load_count(idx))),
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("off"), UPat.cvar()), name="x"),
|
||||
lambda ctx,x,buf,off: nidx(ctx.b, ctx.r[buf], ctx.r[off], x.dtype, x.addrspace)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: ctx.r[x.src[0]] if len(x.src) == 1 else
|
||||
nalu(ctx.b, f"vec{len(x.src)}", *[ctx.r[src] for src in x.src])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
|
||||
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
|
||||
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("a"), UPat.cvar("idx"))),
|
||||
lambda ctx,a,idx: nchannel(ctx.b, ctx.r[a], idx.arg) if a.addrspace == AddrSpace.ANON else None),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
|
||||
mesa.glsl_array_type(glsl_type(x.dtype), x.src[0].arg, 0), f"acc{x.arg}".encode()).contents),
|
||||
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
|
||||
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
|
||||
|
|
@ -189,8 +213,7 @@ class NIRRenderer(Renderer):
|
|||
self.param_idx, ranges = 0, []
|
||||
|
||||
for u in uops:
|
||||
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
|
||||
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
|
||||
if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.INDEX and u.src[0].addrspace != AddrSpace.ANON): pass
|
||||
elif u.op is Ops.AFTER:
|
||||
self.r[u] = self.r[u.src[0]]
|
||||
elif u.op == Ops.SINK:
|
||||
|
|
@ -198,7 +221,7 @@ class NIRRenderer(Renderer):
|
|||
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
|
||||
elif u.op == Ops.DEFINE_LOCAL:
|
||||
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
|
||||
self.b.shader.contents.info.shared_size += u.dtype.nbytes()
|
||||
self.b.shader.contents.info.shared_size += u.src[0].arg * u.dtype.itemsize
|
||||
elif u.op == Ops.RANGE:
|
||||
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
|
||||
|
|
@ -217,7 +240,17 @@ class NIRRenderer(Renderer):
|
|||
self.r[u] = cast(mesa.nir_def, d)
|
||||
self.postrender(uops)
|
||||
|
||||
mesa.nir_validate_shader(self.b.shader, b"after render")
|
||||
if DEBUG >= 2 and hasattr(os, "fork"):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
pid = os.fork()
|
||||
if pid == 0:
|
||||
mesa.nir_validate_shader(self.b.shader, b"after render")
|
||||
os._exit(0)
|
||||
_, status = os.waitpid(pid, 0)
|
||||
if os.WIFSIGNALED(status): raise RuntimeError(f"NIR validation failed after render with signal {os.WTERMSIG(status)}")
|
||||
if os.WEXITSTATUS(status) != 0: raise RuntimeError(f"NIR validation failed after render with exit code {os.WEXITSTATUS(status)}")
|
||||
else: mesa.nir_validate_shader(self.b.shader, b"after render")
|
||||
if DEBUG >= 4: mesa.nir_print_shader(self.b.shader, ctypes.POINTER(mesa.struct__IO_FILE).in_dll(ctypes.CDLL(ctypes.util.find_library('c')),
|
||||
"__stdoutp" if OSX else "stdout"))
|
||||
mesa.nir_serialize(blob:=mesa.struct_blob(), self.b.shader, False)
|
||||
|
|
@ -257,9 +290,10 @@ class LVPRenderer(NIRRenderer):
|
|||
|
||||
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
|
||||
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
|
||||
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components,
|
||||
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=4,
|
||||
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
|
||||
srcs=lambda b,img,idx_y,idx_x,val:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
|
||||
srcs=lambda b,img,idx_y,idx_x,val,dtype:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int),
|
||||
val if val.num_components == 4 else nalu(b, "vec4", val, nundef(b, dtype), nundef(b, dtype), nundef(b, dtype)), nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,idx_y,idx_x,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
|
||||
|
||||
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
|
||||
|
|
@ -277,8 +311,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
|||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("val")), allow_any_len=True),
|
||||
lambda ctx,img,idx_y,idx_x,val: nstore_img(ctx.b, ctx.r[img], ctx.r[idx_y], ctx.r[idx_x], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate"))),
|
||||
lambda ctx,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, idx_y, idx_x), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: ctx.nload_img(img, idx_y, idx_x) if len(x.shape) > 0 and x.shape[-1] == 4 else nchannel(ctx.b, ctx.nload_img(img, idx_y, idx_x), 0),
|
||||
lambda: ctx.r[alt] if len(x.shape) == 0 or x.shape[-1] != 4 or ctx.r[alt].num_components == 4 else
|
||||
nalu(ctx.b, "vec4", ctx.r[alt], nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype)))),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')),)), nload_img),
|
||||
]) + NIRRenderer.def_rewrite
|
||||
|
||||
|
|
|
|||
|
|
@ -27,12 +27,19 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
|
|||
val = (load.cast(dtypes.uint32) >> shift_am) & mask
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
def is_packed(dt:DType, odt:DType|None = None) -> bool:
|
||||
if odt is None: odt = dt
|
||||
# registers aren't packed
|
||||
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False
|
||||
return dt.itemsize < 4 and dt.base != dtypes.half
|
||||
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
|
||||
def is_packed(x:UOp) -> bool:
|
||||
if x.op is Ops.LOAD: dt, addrspace = x.dtype, x.src[0].addrspace
|
||||
elif x.op is Ops.STORE: dt, addrspace = x.src[1].dtype, x.src[0].addrspace
|
||||
else: dt, addrspace = x.dtype.base, x.dtype.addrspace if isinstance(x.dtype, PtrDType) else x.addrspace
|
||||
return dt.itemsize < 4 and dt.base != dtypes.half and addrspace != AddrSpace.REG
|
||||
|
||||
def _packed_size(ctx, x:UOp):
|
||||
size = ctx[x.src[0]]
|
||||
if not is_packed(x): return size
|
||||
elems = 4 // x.dtype.base.itemsize
|
||||
return str((x.src[0].arg + elems - 1) // elems) if x.src[0].op is Ops.CONST else f"(({size}+{elems-1})/{elems})"
|
||||
|
||||
def _buf_map(ctx, x:UOp): return ctx.type_map[x.dtype.base] if x.addrspace == AddrSpace.REG else ctx.buf_map(x.dtype.base)
|
||||
|
||||
def is_nan(a):
|
||||
bs, (exp, mant) = a.dtype.bitsize, dtypes.finfo(a.dtype)
|
||||
|
|
@ -43,12 +50,12 @@ wgsl_matcher = PatternMatcher([
|
|||
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
|
||||
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")),
|
||||
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var")),
|
||||
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l) else None),
|
||||
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"), name="s"),
|
||||
lambda s,bidx,var,gate: packed_store(bidx,var,gate) if is_packed(s) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), name="s"),
|
||||
lambda s,bidx,var: packed_store(bidx,var) if is_packed(s) else None),
|
||||
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
|
||||
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
|
||||
# fix nan check: 'a != a -> is_nan()'
|
||||
|
|
@ -73,8 +80,8 @@ class WGSLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
|
||||
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{_buf_map(ctx,x)},{_packed_size(ctx,x)}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{_buf_map(ctx,x)},{_packed_size(ctx,x)}>;"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
|
||||
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
|
||||
|
|
@ -85,12 +92,12 @@ class WGSLRenderer(CStyleLanguage):
|
|||
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")),
|
||||
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"),
|
||||
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
|
||||
(UPat.store(UPat.var("b"), UPat.var("v")), lambda ctx,b,v:\
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate"), name="l"),
|
||||
lambda ctx,l,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b)}, {ctx[gate]})"),
|
||||
(UPat.load(UPat.var("b"), name="l"), lambda ctx,l,b: ctx.render_load(ctx[b], b)),
|
||||
(UPat.store(UPat.var("b"), UPat.var("v"), name="s"), lambda ctx,s,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b) and b.addrspace != AddrSpace.REG \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
||||
|
|
@ -98,9 +105,11 @@ class WGSLRenderer(CStyleLanguage):
|
|||
|
||||
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
|
||||
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
|
||||
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
|
||||
def render_load(self, x:str, uop:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(uop) and uop.addrspace != AddrSpace.REG else x
|
||||
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt.itemsize < 4 and dt != dtypes.half else self.type_map[dt.base]
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
def arg_dtype(u:UOp) -> DType:
|
||||
return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
|
||||
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
|
||||
if not local_size: local_size = [1]
|
||||
bind_it = iter(range(len(bufs)))
|
||||
|
|
@ -110,8 +119,9 @@ class WGSLRenderer(CStyleLanguage):
|
|||
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
||||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
|
||||
f"{'var<storage,read_write>' if isinstance(dt, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(dt.base)}>' if isinstance(dt,PtrDType) else self.buf_map(dt)};"
|
||||
for name,(u,_) in bufs for dt in (arg_dtype(u),)])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
|
|
|
|||
|
|
@ -54,18 +54,20 @@ class DSPRenderer(ClangRenderer):
|
|||
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
|
||||
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
|
||||
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
|
||||
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
|
||||
'HAP_power_set((void*)handle, (void*)&req);']
|
||||
msrc += ['if ((sc>>24) != 2) return 0;']
|
||||
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
|
||||
if isinstance(b[1][0].dtype, PtrDType)]
|
||||
if isinstance(arg_dtype(b[1][0]), PtrDType)]
|
||||
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
params = [(f'buf_{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)]
|
||||
msrc += [f"{function_name}({', '.join(params)});"]
|
||||
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
|
||||
msrc += ["return 0; }"]
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
@ -275,22 +277,23 @@ class MockDSPRenderer(DSPRenderer):
|
|||
def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True)
|
||||
def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops)
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
|
||||
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
|
||||
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
|
||||
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
|
||||
msrc = [mockdsp_boilerplate, 'void _start(void) {']
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0].dtype, PtrDType):
|
||||
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize
|
||||
if isinstance(dt:=arg_dtype(b[1][0]), PtrDType):
|
||||
sz = dt.size*dt.itemsize
|
||||
# for loop for big reads
|
||||
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
|
||||
else:
|
||||
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
|
||||
msrc.append("unsigned int st = inscount();")
|
||||
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
|
||||
params = [(f'(void*)buf{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
|
||||
msrc.append(f"{function_name}({', '.join(params)});")
|
||||
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});")
|
||||
if isinstance(dt:=arg_dtype(b[1][0]), PtrDType): msrc.append(f"write(1, buf{i}, {dt.size*dt.itemsize});")
|
||||
msrc.append('exit(0); }')
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@
|
|||
from typing import Any, TYPE_CHECKING
|
||||
import pickle, base64, itertools, time, sys, functools
|
||||
from dataclasses import replace
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
|
||||
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, Target, IMAGE
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, truncate, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar, AddrSpace
|
||||
from tinygrad.helpers import all_same, getenv, flatten, Target, IMAGE
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.codegen.opt import tc
|
||||
from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp, bitcast
|
||||
|
|
@ -101,19 +101,21 @@ class PythonProgram:
|
|||
if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size
|
||||
elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp]
|
||||
elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size
|
||||
elif u.op is Ops.INDEX:
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
for m,oy,ox in zip(*src_values):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
elif u.op is Ops.SHRINK or (u.op is Ops.INDEX and len(src_values) == 2):
|
||||
if u.addrspace == AddrSpace.ANON:
|
||||
# old GEP
|
||||
assert all_same(src_values[1]), "all index must be the same"
|
||||
values[u] = src_values[0][src_values[1][0]]
|
||||
else:
|
||||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
# normal index
|
||||
values[u] = [(m,o) for m,o in zip(src_values[0], src_values[1])]
|
||||
elif u.op is Ops.INDEX and len(src_values) == 3:
|
||||
assert isinstance(src_dtypes[0], ImageDType), "3 src index is only for Image"
|
||||
ret:list = []
|
||||
for m,oy,ox in zip(*src_values):
|
||||
if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
|
||||
else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
|
||||
values[u] = ret
|
||||
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType):
|
||||
values[u] = src_values[0]
|
||||
elif u.op is Ops.RANGE:
|
||||
if u not in values: values[u] = [0] * warp_size
|
||||
else:
|
||||
|
|
@ -134,7 +136,6 @@ class PythonProgram:
|
|||
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(load_sz)]
|
||||
else:
|
||||
values[u] = load(src_values, 0, u.dtype)
|
||||
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]
|
||||
elif u.op is Ops.WMMA:
|
||||
first_src_dtype = u.src[0].dtype
|
||||
assert isinstance(first_src_dtype, DType) # mypy
|
||||
|
|
|
|||
|
|
@ -114,6 +114,8 @@ class DLL(ctypes.CDLL):
|
|||
|
||||
def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs):
|
||||
self.nm, self.emsg = nm, emsg or f"try setting {nm.upper()+'_PATH'}?"
|
||||
if nm == 'llvm' and (ver:=getenv("LLVM_VERSION", "")):
|
||||
paths = ([f"/opt/homebrew/opt/llvm@{ver}/lib/libLLVM.dylib"] if OSX else [f"LLVM-{ver}"]) + (paths if isinstance(paths, list) else [paths])
|
||||
if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])):
|
||||
if DEBUG >= 3: print(f"loading {nm} from {path}")
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue