mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
more_ren_c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7b6ee0c7d | ||
|
|
13f5d39fcf | ||
|
|
431accc9b7 | ||
|
|
df000116ea |
6 changed files with 34 additions and 36 deletions
|
|
@ -56,7 +56,7 @@ class AddrSpace(Enum):
|
|||
@dataclass(frozen=True, eq=False)
|
||||
class DType(metaclass=DTypeMetaClass):
|
||||
priority: int # this determines when things get upcasted
|
||||
bitsize: int
|
||||
bitsize: int # this is the bitsize of the base dtype
|
||||
name: str
|
||||
fmt: FmtStr|None
|
||||
count: int
|
||||
|
|
@ -76,7 +76,7 @@ class DType(metaclass=DTypeMetaClass):
|
|||
def vec(self, sz:int) -> DType:
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.bitsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
|
||||
return DType(self.priority, self.bitsize, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
|
||||
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
|
||||
return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
|
||||
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Callable, cast
|
|||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
|
||||
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
|
||||
from tinygrad.dtype import AddrSpace, DType, dtypes
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.device import Compiler
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ class Estimates:
|
|||
while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0]
|
||||
if buf.op is Ops.PARAM:
|
||||
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
|
||||
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
|
||||
accessed = mem.get((buf, u.op), 0) + u.max_numel() * u.src[0].dtype.itemsize * mults
|
||||
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
|
||||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
|
|
@ -51,10 +51,10 @@ class Estimates:
|
|||
elif u.op is Ops.END: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
|
||||
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.src[1].dtype.itemsize * mults
|
||||
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.src[1].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return Estimates(flops, lds, sum(mem.values()))
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ class ClangRenderer(CStyleLanguage):
|
|||
kernel_typedef = "__attribute__((ms_abi)) void"
|
||||
def render_vector_prefix(self, dt:DType) -> str:
|
||||
# round (down) to power of two (this is actually the default clang behavior)
|
||||
alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
|
||||
alignment = 2**int(math.log2(dt.itemsize*dt.count)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
|
||||
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),ext_vector_type({dt.count})));"
|
||||
|
||||
def _render_defines(self, uops) -> list[str]:
|
||||
|
|
@ -433,7 +433,8 @@ class CUDARenderer(CStyleLanguage):
|
|||
def render_vector_prefix(self, dt:DType) -> str:
|
||||
vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
|
||||
elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
|
||||
return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
|
||||
return f"struct __align__({dt.itemsize*dt.count}) {vec} {{ {scal} {elems}; }};" + \
|
||||
f"__device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
|
||||
# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
|
||||
|
|
@ -450,7 +451,8 @@ class CUDARenderer(CStyleLanguage):
|
|||
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
|
||||
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
|
||||
wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
|
||||
n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
|
||||
# 4 => CUDA reg size in bytes
|
||||
n_operands = [size*dtype.itemsize*dtype.count//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
|
||||
operands = [f"%{i}" for i in range(sum(n_operands))]
|
||||
|
||||
# mma operands => {c}, {a}, {b}, {c}
|
||||
|
|
|
|||
|
|
@ -633,8 +633,8 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) ->
|
|||
reg = cast(int, cast(Register, reg_uop.reg).index if reg_uop is not None else reg)
|
||||
rm = cast(Register, rm_uop.reg).index
|
||||
idx = cast(Register, idx_uop.reg).index if idx_uop is not None and idx_uop.reg is not None else 4
|
||||
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else rm_uop.dtype.itemsize
|
||||
reg_sz = (reg_uop.dtype.itemsize if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
|
||||
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else (rm_uop.dtype.itemsize*rm_uop.dtype.count)
|
||||
reg_sz = ((reg_uop.dtype.itemsize*reg_uop.dtype.count) if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
|
||||
sz = reg_sz or rm_sz
|
||||
|
||||
# encode instruction
|
||||
|
|
|
|||
|
|
@ -71,12 +71,12 @@ def nimm_set(imm:mesa.nir_def, x, dtype:DType):
|
|||
instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr))
|
||||
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, truncate[dtype](x))
|
||||
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize*dtype.count)
|
||||
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:
|
||||
nimm_set((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize)).contents._def, x, dtype)
|
||||
nimm_set((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize*dtype.count)).contents._def, x, dtype)
|
||||
return instr
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
|
||||
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize)
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize*dtype.count)
|
||||
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize*dtype.count)
|
||||
|
||||
deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108
|
||||
lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var))
|
||||
|
|
@ -86,7 +86,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
|||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize, num_components=lambda dtype:dtype.count,
|
||||
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
|
||||
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
|
||||
|
|
|
|||
|
|
@ -27,12 +27,8 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
|
|||
val = (load.cast(dtypes.uint32) >> shift_am) & mask
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
def is_packed(dt:DType, odt:DType|None = None) -> bool:
|
||||
if odt is None: odt = dt
|
||||
# registers aren't packed
|
||||
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False
|
||||
return dt.itemsize < 4 and dt.base != dtypes.half
|
||||
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
|
||||
def is_packed(u:UOp) -> bool: return u.dtype.itemsize < 4 and u.dtype.base != dtypes.half and u.addrspace != AddrSpace.REG
|
||||
def _packed_size(u:UOp): return u.max_numel() // (4//u.dtype.itemsize) if is_packed(u) else u.max_numel()
|
||||
|
||||
def is_nan(a):
|
||||
bs, (exp, mant) = a.dtype.bitsize, dtypes.finfo(a.dtype)
|
||||
|
|
@ -43,12 +39,12 @@ wgsl_matcher = PatternMatcher([
|
|||
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
|
||||
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
|
||||
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(b) else None),
|
||||
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(b) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")),
|
||||
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(bidx) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var")),
|
||||
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
lambda bidx,var: packed_store(bidx,var) if is_packed(bidx) else None),
|
||||
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
|
||||
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
|
||||
# fix nan check: 'a != a -> is_nan()'
|
||||
|
|
@ -73,8 +69,8 @@ class WGSLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
|
||||
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x)},{_packed_size(x)}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x)},{_packed_size(x)}>;"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
|
||||
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
|
||||
|
|
@ -86,11 +82,11 @@ class WGSLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")),
|
||||
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"),
|
||||
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
|
||||
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.src[0])}, {ctx[gate]})"),
|
||||
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b)),
|
||||
(UPat.store(UPat.var("b"), UPat.var("v")), lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0]) \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
||||
|
|
@ -98,8 +94,8 @@ class WGSLRenderer(CStyleLanguage):
|
|||
|
||||
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
|
||||
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
|
||||
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
|
||||
def render_load(self, x:str, u:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(u) else x
|
||||
def buf_map(self, u:UOp) -> str: return "atomic<u32>" if is_packed(u) else self.type_map[u.dtype.base]
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
|
||||
if not local_size: local_size = [1]
|
||||
|
|
@ -111,7 +107,7 @@ class WGSLRenderer(CStyleLanguage):
|
|||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
|
||||
f"{name}:{f'array<{self.buf_map(u)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u)};" for name,(u,_) in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue