Merge remote-tracking branch 'upstream/master' into new_x86_backend

This commit is contained in:
ttomsa 2026-01-31 19:42:22 +00:00
commit f1327ebff6
57 changed files with 13278 additions and 2216 deletions

View file

@ -26,7 +26,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: llvm-speed
deps: testing_minimal
deps: testing_unit
llvm: 'true'
- name: Speed Test
run: CPU=1 CPU_LLVM=1 python3 test/speed/external_test_speed_v_torch.py
@ -98,7 +98,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: torch-backend-pillow-torchvision-et-pt
deps: testing_minimal
deps: testing_unit
pydeps: "pillow torchvision expecttest"
llvm: 'true'
- name: Install ninja
@ -134,7 +134,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: torch-backend-pillow-torchvision-et-pt
deps: testing_minimal
deps: testing_unit
llvm: 'true'
- name: Install ninja
run: |
@ -156,7 +156,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: be-minimal
deps: testing_minimal
deps: testing_unit
- name: Test dtype with Python emulator
run: DEBUG=1 PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
- name: Test ops with Python emulator
@ -348,7 +348,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: gpu-image
deps: testing_minimal
deps: testing_unit
opencl: 'true'
- name: Test CL IMAGE=2 ops
run: |
@ -424,7 +424,7 @@ jobs:
with:
key: onnxoptc
deps: testing
python-version: '3.11'
python-version: '3.12'
llvm: 'true'
- name: Test ONNX (CPU)
run: CPU=1 CPU_LLVM=0 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
@ -452,7 +452,7 @@ jobs:
key: onnxoptl
deps: testing
pydeps: "tensorflow==2.19"
python-version: '3.11'
python-version: '3.12'
opencl: 'true'
- name: Test ONNX (CL)
run: CL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
@ -526,7 +526,7 @@ jobs:
with:
key: metal
deps: testing
python-version: '3.11'
python-version: '3.12'
- name: Test models (Metal)
run: METAL=1 python -m pytest -n=auto test/models --durations=20
- name: Test LLaMA compile speed
@ -545,7 +545,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: devectorize-minimal
deps: testing_minimal
deps: testing_unit
pydeps: "pillow"
llvm: "true"
- name: Test LLVM=1 DEVECTORIZE=0
@ -566,7 +566,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: dsp-minimal
deps: testing_minimal
deps: testing_unit
pydeps: "onnx==1.18.0 onnxruntime"
llvm: "true"
- name: Set up Docker Buildx
@ -600,8 +600,8 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: webgpu-minimal
deps: testing_minimal
python-version: '3.11'
deps: testing_unit
python-version: '3.12'
webgpu: 'true'
- name: Check Device.DEFAULT (WEBGPU) and print some source
run: |
@ -634,7 +634,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: ${{ matrix.backend }}-minimal
deps: testing_minimal
deps: testing_unit
amd: 'true'
llvm: ${{ matrix.backend == 'amdllvm' && 'true' }}
- name: Check Device.DEFAULT and print some source
@ -676,9 +676,9 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: rdna3-emu
deps: testing_minimal
deps: testing_unit
amd: 'true'
python-version: '3.13'
python-version: '3.14'
- name: Verify AMD autogen is up to date
run: |
python -m extra.assembly.amd.generate
@ -724,7 +724,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: ${{ matrix.backend }}-minimal
deps: testing_minimal
deps: testing_unit
cuda: 'true'
ocelot: 'true'
- name: Set env
@ -757,7 +757,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: ${{ matrix.backend }}-minimal
deps: testing_minimal
deps: testing_unit
opencl: ${{ matrix.backend == 'opencl' && 'true' }}
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
mesa: ${{ matrix.backend == 'lvp' && 'true' }}
@ -771,9 +771,6 @@ jobs:
run: python -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit --durations=20
- name: Run TRANSCENDENTAL math
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
- name: Test dtype with emulated long
if: matrix.backend != 'lvp' && matrix.backend != 'llvm'
run: EMULATED_DTYPES=long python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
- name: Run process replay tests
uses: ./.github/actions/process-replay
@ -791,7 +788,7 @@ jobs:
with:
key: metal
deps: testing
python-version: '3.11'
python-version: '3.12'
amd: 'true'
cuda: 'true'
ocelot: 'true'
@ -889,7 +886,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: macos-${{ matrix.backend }}-minimal
deps: testing_minimal
deps: testing_unit
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
mesa: ${{ matrix.backend == 'lvp' && 'true' }}
- name: Set env
@ -956,7 +953,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: compile-${{ matrix.backend }}
deps: testing_minimal
deps: testing_unit
mesa: ${{ (matrix.backend == 'ir3' || matrix.backend == 'nak') && 'true' }}
python-version: '3.14'
- name: Set env

View file

@ -9,9 +9,11 @@ export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
export DEBUG=${DEBUG:-0}
export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
export ALL2ALL=${ALL2ALL:-1}
export USE_ATOMICS=${USE_ATOMICS:-1}
export ASM_GEMM=${ASM_GEMM:-1}
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
export DP=${DP:-8} BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
export GBS=$((BS * GRADIENT_ACC_STEPS))
export MODEL="llama3"

View file

@ -0,0 +1,9 @@
#!/bin/bash
export BENCHMARK=5
export EVAL_BS=0
export FAKEDATA=1
export HIP_VISIBLE_DEVICES=""
export DEV=NULL
export JITBEAM=0
export LLAMA_LAYERS=${LLAMA_LAYERS:-"2"}
time examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh

View file

@ -427,7 +427,8 @@ class _Ctx:
pcode = get_pcode(op)
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
if 'VCC' not in srcs: srcs['VCC'] = self.rsgpr_dyn(_c(vcc_reg))
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane})
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane,
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}) # rounding mode: 0=RNE, RTZ constant
_, assigns = parse_pcode(pcode, srcs)
raw_stores: list = []
@ -796,10 +797,13 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
return bits
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)
srcs = {'S0': s0_new, 'S1': s1_new, 'S2': s2_new}
# DOT IU instructions use NEG bits for signed/unsigned selection, not fp16 negation
is_dot_iu = 'DOT' in op_name and 'IU' in op_name
n0, n1, n2, nh0, nh1, nh2 = (0, 0, 0, 0, 0, 0) if is_dot_iu else (neg & 1, neg & 2, neg & 4, neg_hi & 1, neg_hi & 2, neg_hi & 4)
srcs = {'S0': build_remapped_src(src0, opsel & 1, opsel_hi & 1, n0, nh0),
'S1': build_remapped_src(src1, opsel & 2, opsel_hi & 2, n1, nh1),
'S2': build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, n2, nh2)}
if is_dot_iu: srcs['NEG'] = UOp.const(dtypes.uint32, neg)
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
@ -1004,7 +1008,7 @@ def _get_runner(inst_bytes: bytes):
canonical_name = f"{_op_name(inst).lower()}_{base.to_bytes(size, 'little').hex()}"
sink = sink.replace(arg=KernelInfo(name=canonical_name)).rtag(1)
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0):
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES=""):
runner = get_runner('CPU', sink)
_canonical_runner_cache.append((base, mask, size, runner))
return runner, True

View file

@ -94,13 +94,19 @@ def _trig_reduce(x, phase=0.0):
return UOp(Ops.SIN, x.dtype, (x - n * _const(x.dtype, 6.283185307179586),))
def _signext(val: UOp) -> UOp:
for bits, mask, ext in [(8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]:
for bits, mask, ext in [(4, 0xF, 0xFFFFFFF0), (8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]:
if (val.op == Ops.AND and len(val.src) == 2 and val.src[1].op == Ops.CONST and val.src[1].arg == mask) or val.dtype.itemsize == bits // 8:
v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
sb = (v32 >> _u32(bits - 1)) & _u32(1)
return sb.ne(_u32(0)).where(v32 | _u32(ext), v32).cast(dtypes.int)
return val.cast(dtypes.int64) if val.dtype in (dtypes.int, dtypes.int32) else val
def _signext_4bit(val: UOp) -> UOp:
"""Sign extend a 4-bit value to 32-bit signed integer."""
v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
sb = (v32 >> _u32(3)) & _u32(1) # sign bit at position 3
return sb.ne(_u32(0)).where(v32 | _u32(0xFFFFFFF0), v32).bitcast(dtypes.int)
def _abs(val: UOp) -> UOp:
if val.dtype not in (dtypes.float32, dtypes.float64, dtypes.half): return val
_, _, _, _, shift = _float_info(val)
@ -227,11 +233,44 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
'signext_from_bit': _signext_from_bit, 'ldexp': _ldexp, 'frexp_mant': _frexp_mant, 'mantissa': _frexp_mant,
'frexp_exp': _frexp_exp, 'trig_preop_result': _trig_preop,
's_ff1_i32_b32': lambda a: _ff1(a, 32), 's_ff1_i32_b64': lambda a: _ff1(a, 64),
# Normalization conversions: map [-1,1] or [0,1] to integer range
# Use floor(x + 0.5) for round-to-nearest
# SNORM: round(value * 32767), range is [-32767, 32767] (hardware behavior)
'f16_to_snorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
'f16_to_unorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
'f32_to_snorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
'f32_to_unorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
'f32_to_u8': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint8),
# Integer truncation conversions
'i32_to_i16': lambda a: a.cast(dtypes.int).cast(dtypes.int16),
'u32_to_u16': lambda a: a.cast(dtypes.uint32).cast(dtypes.uint16),
'u16_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFFFF)),
'u8_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFF)),
'u4_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xF)),
# Signed extraction with sign extension for dot products
'i16_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFFFF)),
'i8_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFF)),
'i4_to_i32': lambda a: _signext_4bit(a.cast(dtypes.uint32) & _u32(0xF)),
# Float to int16 conversions
'v_cvt_i16_f32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int16),
'v_cvt_u16_f32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint16),
}
for is_max, name in [(False, 'min'), (True, 'max')]:
for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]:
_FUNCS[f'v_{name}_{sfx}'] = lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a)
_FUNCS[f'v_{name}3_{sfx}'] = lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a)
# f16 min/max/min3/max3/med3
for is_max, name in [(False, 'min'), (True, 'max')]:
_FUNCS[f'v_{name}_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
_FUNCS[f'v_{name}3_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
_FUNCS[f'v_{name}_num_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
_FUNCS[f'v_{name}_num_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)
_FUNCS[f'v_{name}3_num_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
_FUNCS[f'v_{name}3_num_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)
_FUNCS[f'v_{name}imum_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
_FUNCS[f'v_{name}imum_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)
_FUNCS[f'v_{name}imum3_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
_FUNCS[f'v_{name}imum3_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)
# ═══════════════════════════════════════════════════════════════════════════════
# TOKENIZER/PARSER
@ -239,7 +278,7 @@ for is_max, name in [(False, 'min'), (True, 'max')]:
DTYPES = {'u32': dtypes.uint32, 'i32': dtypes.int, 'f32': dtypes.float32, 'b32': dtypes.uint32, 'u64': dtypes.uint64, 'i64': dtypes.int64,
'f64': dtypes.float64, 'b64': dtypes.uint64, 'u16': dtypes.uint16, 'i16': dtypes.short, 'f16': dtypes.half, 'b16': dtypes.uint16,
'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u1': dtypes.uint32}
'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u4': dtypes.uint8, 'i4': dtypes.int8, 'u1': dtypes.uint32}
_BITS_DT = {8: dtypes.uint8, 16: dtypes.uint16, 32: dtypes.uint32, 64: dtypes.uint64}
_NUM_SUFFIXES = ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f')
def _strip_suffix(num: str) -> tuple[str, str]:
@ -432,14 +471,6 @@ class Parser:
return elem
if self.at('LBRACKET') and name not in self.vars:
self.eat('LBRACKET')
if self.at('NUM'):
idx_num = int(self.peek().val)
if f'{name}{idx_num}' in self.vars:
self.eat('NUM')
self.eat('RBRACKET')
elem = self.vars[f'{name}{idx_num}']
if self.try_eat('DOT'): return _cast_to(elem, DTYPES.get(self.eat('IDENT').val, dtypes.uint32))
return elem
first = self.parse()
return self._handle_bracket_rest(first, _u32(0), name)
if name in self.vars:
@ -467,6 +498,7 @@ class Parser:
if dt == base.dtype: return base
if dt.itemsize == 2 and base.dtype.itemsize == 4:
return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16) if dt == dtypes.uint16 else (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt)
if field == 'i4': return _signext_4bit(base)
return _cast_to(base, dt)
def _handle_bracket(self, base, var_name: str | None = None) -> UOp:
@ -509,8 +541,8 @@ class Parser:
var_name = self._find_var_name(base)
if first.op == Ops.CONST:
idx = int(first.arg)
if var_name and f'{var_name}{idx}' in self.vars:
v = self.vars[f'{var_name}{idx}']
if var_name and f'{var_name}@{idx}' in self.vars:
v = self.vars[f'{var_name}@{idx}']
return _cast_to(v, dt_suffix) if dt_suffix else v
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
base_cast = base.cast(dt) if base.dtype != dt else base
@ -518,7 +550,7 @@ class Parser:
return _cast_to(result, dt_suffix) if dt_suffix else result
if var_name:
idx_u32 = _to_u32(first)
elems = [(i, self.vars[f'{var_name}{i}']) for i in range(256) if f'{var_name}{i}' in self.vars]
elems = [(i, self.vars[f'{var_name}@{i}']) for i in range(256) if f'{var_name}@{i}' in self.vars]
if elems:
result = elems[-1][1]
for ei, ev in reversed(elems[:-1]):
@ -537,7 +569,7 @@ class Parser:
self.eat('RBRACE')
var_name = self._find_var_name(base)
if var_name:
elem = self.vars.get(f'{var_name}{idx}', _u32(0))
elem = self.vars.get(f'{var_name}@{idx}', _u32(0)) # use @ to avoid collision with temps like A4
if self.try_eat('DOT'):
dt_name = self.eat('IDENT').val
return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32))
@ -787,7 +819,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
for loop_i in range(start_val, end_val + 1):
subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
_, iter_assigns, _ = parse_block(subst_lines, 0, vars, funcs, assigns)
_, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
if has_break:
assert found_var is not None
found = block_assigns.get(found_var, vars.get(found_var))
@ -944,7 +976,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if existing is not None and isinstance(existing, UOp):
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
else:
block_assigns[f'{var}{idx}'] = vars[f'{var}{idx}'] = val
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
i += 1; continue
# Compound assignment: var += or var -=

View file

@ -13,7 +13,7 @@ def _i32(f: float) -> int: return struct.unpack('<I', struct.pack('<f', f))[0]
def _f32(i: int) -> float: return struct.unpack('<f', struct.pack('<I', i & 0xFFFFFFFF))[0]
# f16 conversion helpers
def _f16(i: int) -> float: return struct.unpack('<e', struct.pack('<H', i & 0xFFFF))[0]
def f16(i: int) -> float: return struct.unpack('<e', struct.pack('<H', i & 0xFFFF))[0]
def f32_to_f16(f: float) -> int:
f = float(f)
if math.isnan(f): return 0x7e00

View file

@ -255,7 +255,6 @@ class TestF16Conversions(unittest.TestCase):
def test_v_cvt_f16_f32_small(self):
"""V_CVT_F16_F32 converts small f32 value."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
instructions = [
v_mov_b32_e32(v[0], 0.5),
v_cvt_f16_f32_e32(v[1], v[0]),
@ -293,7 +292,6 @@ class TestF16Conversions(unittest.TestCase):
def test_v_cvt_f16_f32_reads_full_32bit_source(self):
"""V_CVT_F16_F32 must read full 32-bit f32 source."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0x3fc00000), # f32 1.5
v_mov_b32_e32(v[0], s[0]),
@ -302,7 +300,7 @@ class TestF16Conversions(unittest.TestCase):
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1]
lo_bits = result & 0xffff
self.assertEqual(lo_bits, 0x3e00, f"Expected f16(1.5)=0x3e00, got 0x{lo_bits:04x} ({_f16(lo_bits)})")
self.assertEqual(lo_bits, 0x3e00, f"Expected f16(1.5)=0x3e00, got 0x{lo_bits:04x} ({f16(lo_bits)})")
def test_v_cvt_i16_f16_zero(self):
"""V_CVT_I16_F16 converts f16 zero to i16 zero."""
@ -696,7 +694,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f32_f16_abs_negative(self):
"""V_CVT_F32_F16 with |abs| on negative value."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
instructions = [
s_mov_b32(s[0], f16_neg1),
@ -709,7 +706,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f32_f16_abs_positive(self):
"""V_CVT_F32_F16 with |abs| on positive value (should stay positive)."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_2 = f32_to_f16(2.0) # 0x4000
instructions = [
s_mov_b32(s[0], f16_2),
@ -722,7 +718,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f32_f16_neg_positive(self):
"""V_CVT_F32_F16 with neg on positive value."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_2 = f32_to_f16(2.0) # 0x4000
instructions = [
s_mov_b32(s[0], f16_2),
@ -735,7 +730,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f32_f16_neg_negative(self):
"""V_CVT_F32_F16 with neg on negative value (double negative)."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_neg2 = f32_to_f16(-2.0) # 0xc000
instructions = [
s_mov_b32(s[0], f16_neg2),
@ -748,7 +742,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f16_f32_then_pack_for_wmma(self):
"""CVT F32->F16 followed by pack (common WMMA pattern)."""
from extra.assembly.amd.test.hw.helpers import _f16
f32_val = 3.5
instructions = [
s_mov_b32(s[0], f2i(f32_val)),
@ -757,8 +750,8 @@ class TestCvtF16Modifiers(unittest.TestCase):
v_pack_b32_f16(v[2], v[1], v[1]), # Pack same value
]
st = run_program(instructions, n_lanes=1)
lo = _f16(st.vgpr[0][2] & 0xffff)
hi = _f16((st.vgpr[0][2] >> 16) & 0xffff)
lo = f16(st.vgpr[0][2] & 0xffff)
hi = f16((st.vgpr[0][2] >> 16) & 0xffff)
self.assertAlmostEqual(lo, f32_val, places=1)
self.assertAlmostEqual(hi, f32_val, places=1)
@ -804,7 +797,6 @@ class TestConversionRounding(unittest.TestCase):
def test_f16_to_f32_precision(self):
"""F16 to F32 conversion precision."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_val = f32_to_f16(1.5)
instructions = [
s_mov_b32(s[0], f16_val),
@ -816,7 +808,6 @@ class TestConversionRounding(unittest.TestCase):
def test_f16_denormal_to_f32(self):
"""F16 denormal converts to small positive f32."""
from extra.assembly.amd.test.hw.helpers import _f16
f16_denorm = 0x0001 # Smallest positive f16 denormal
instructions = [
v_mov_b32_e32(v[0], f16_denorm),
@ -1512,5 +1503,63 @@ class TestReciprocalF16(unittest.TestCase):
self.assertAlmostEqual(result, 0.25, places=2, msg="1/4.0 should be 0.25")
class TestCvtNormF16(unittest.TestCase):
"""Tests for V_CVT_NORM_I16_F16 and V_CVT_NORM_U16_F16."""
def test_cvt_norm_i16_f16_positive(self):
"""V_CVT_NORM_I16_F16: f16 1.0 -> i16 max (32767)."""
instructions = [
s_mov_b32(s[0], f32_to_f16(1.0)),
v_mov_b32_e32(v[0], s[0]),
v_cvt_norm_i16_f16_e32(v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1] & 0xffff
self.assertEqual(result, 32767)
def test_cvt_norm_i16_f16_negative(self):
"""V_CVT_NORM_I16_F16: f16 -1.0 -> i16 -32767 (0x8001)."""
instructions = [
s_mov_b32(s[0], f32_to_f16(-1.0)),
v_mov_b32_e32(v[0], s[0]),
v_cvt_norm_i16_f16_e32(v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1] & 0xffff
self.assertEqual(result, 0x8001) # -32767, hardware uses symmetric range
def test_cvt_norm_i16_f16_zero(self):
"""V_CVT_NORM_I16_F16: f16 0.0 -> i16 0."""
instructions = [
v_mov_b32_e32(v[0], 0),
v_cvt_norm_i16_f16_e32(v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1] & 0xffff
self.assertEqual(result, 0)
def test_cvt_norm_u16_f16_one(self):
"""V_CVT_NORM_U16_F16: f16 1.0 -> u16 max (65535)."""
instructions = [
s_mov_b32(s[0], f32_to_f16(1.0)),
v_mov_b32_e32(v[0], s[0]),
v_cvt_norm_u16_f16_e32(v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1] & 0xffff
self.assertEqual(result, 65535)
def test_cvt_norm_u16_f16_half(self):
"""V_CVT_NORM_U16_F16: f16 0.5 -> u16 ~32768."""
instructions = [
s_mov_b32(s[0], f32_to_f16(0.5)),
v_mov_b32_e32(v[0], s[0]),
v_cvt_norm_u16_f16_e32(v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1] & 0xffff
self.assertAlmostEqual(result, 32768, delta=1)
if __name__ == '__main__':
unittest.main()

View file

@ -857,7 +857,6 @@ class TestF16Modifiers(unittest.TestCase):
def test_v_fma_f16_inline_const_1_0(self):
"""V_FMA_F16: a*b + 1.0 should use f16 inline constant."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
f16_a = f32_to_f16(0.325928) # ~0x3537
f16_b = f32_to_f16(-0.486572) # ~0xb7c9
instructions = [
@ -868,13 +867,12 @@ class TestF16Modifiers(unittest.TestCase):
v_fma_f16(v[4], v[4], v[6], 1.0), # 1.0 is inline constant
]
st = run_program(instructions, n_lanes=1)
result = _f16(st.vgpr[0][4] & 0xffff)
result = f16(st.vgpr[0][4] & 0xffff)
expected = 0.325928 * (-0.486572) + 1.0
self.assertAlmostEqual(result, expected, delta=0.01)
def test_v_fma_f16_inline_const_0_5(self):
"""V_FMA_F16: a*b + 0.5 should use f16 inline constant."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
f16_a = f32_to_f16(2.0)
f16_b = f32_to_f16(3.0)
instructions = [
@ -885,13 +883,12 @@ class TestF16Modifiers(unittest.TestCase):
v_fma_f16(v[2], v[0], v[1], 0.5), # 0.5 is inline constant
]
st = run_program(instructions, n_lanes=1)
result = _f16(st.vgpr[0][2] & 0xffff)
result = f16(st.vgpr[0][2] & 0xffff)
expected = 2.0 * 3.0 + 0.5
self.assertAlmostEqual(result, expected, delta=0.01)
def test_v_fma_f16_inline_const_neg_1_0(self):
"""V_FMA_F16: a*b + (-1.0) should use f16 inline constant."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
f16_a = f32_to_f16(2.0)
f16_b = f32_to_f16(3.0)
instructions = [
@ -902,13 +899,12 @@ class TestF16Modifiers(unittest.TestCase):
v_fma_f16(v[2], v[0], v[1], -1.0), # -1.0 is inline constant
]
st = run_program(instructions, n_lanes=1)
result = _f16(st.vgpr[0][2] & 0xffff)
result = f16(st.vgpr[0][2] & 0xffff)
expected = 2.0 * 3.0 + (-1.0)
self.assertAlmostEqual(result, expected, delta=0.01)
def test_v_add_f16_abs_both(self):
"""V_ADD_F16 with abs on both operands."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
f16_neg2 = f32_to_f16(-2.0)
f16_neg3 = f32_to_f16(-3.0)
instructions = [
@ -919,12 +915,11 @@ class TestF16Modifiers(unittest.TestCase):
v_add_f16_e64(v[2], abs(v[0]), abs(v[1])), # |-2| + |-3| = 5
]
st = run_program(instructions, n_lanes=1)
result = _f16(st.vgpr[0][2] & 0xffff)
result = f16(st.vgpr[0][2] & 0xffff)
self.assertAlmostEqual(result, 5.0, delta=0.01)
def test_v_mul_f16_neg_abs(self):
"""V_MUL_F16 with neg on one operand and abs on another."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
f16_2 = f32_to_f16(2.0)
f16_neg3 = f32_to_f16(-3.0)
instructions = [
@ -935,7 +930,7 @@ class TestF16Modifiers(unittest.TestCase):
v_mul_f16_e64(v[2], -v[0], abs(v[1])), # -(2) * |-3| = -6
]
st = run_program(instructions, n_lanes=1)
result = _f16(st.vgpr[0][2] & 0xffff)
result = f16(st.vgpr[0][2] & 0xffff)
self.assertAlmostEqual(result, -6.0, delta=0.01)
def test_v_fmac_f16_hi_dest(self):
@ -943,7 +938,6 @@ class TestF16Modifiers(unittest.TestCase):
This tests the case from AMD_LLVM sin(0) where V_FMAC_F16 writes to v0.h.
"""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0x38003c00), # v0 = {hi=0.5, lo=1.0}
v_mov_b32_e32(v[0], s[0]),
@ -954,8 +948,8 @@ class TestF16Modifiers(unittest.TestCase):
]
st = run_program(instructions, n_lanes=1)
v0 = st.vgpr[0][0]
result_hi = _f16((v0 >> 16) & 0xffff)
result_lo = _f16(v0 & 0xffff)
result_hi = f16((v0 >> 16) & 0xffff)
result_lo = f16(v0 & 0xffff)
self.assertAlmostEqual(result_hi, 0.5, delta=0.01, msg=f"Expected hi=0.5, got {result_hi}")
self.assertAlmostEqual(result_lo, 1.0, delta=0.01, msg=f"Expected lo=1.0, got {result_lo}")
@ -2955,5 +2949,318 @@ class TestVOP3Clamp(unittest.TestCase):
self.assertAlmostEqual(i2f(st.vgpr[3][1]), 1.0, places=5, msg="lane 3: 2.5 should clamp to 1.0")
class TestCvtPkF16(unittest.TestCase):
"""Tests for V_CVT_PK_RTZ_F16_F32 - pack two f32 to f16 with round toward zero."""
def test_cvt_pk_rtz_f16_f32_basic(self):
"""V_CVT_PK_RTZ_F16_F32: basic pack of two f32 values."""
instructions = [
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], 2.0),
v_cvt_pk_rtz_f16_f32_e64(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo_f16 = f16(result & 0xffff)
hi_f16 = f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo_f16, 1.0, delta=0.01)
self.assertAlmostEqual(hi_f16, 2.0, delta=0.01)
class TestCvtPkNorm(unittest.TestCase):
"""Tests for V_CVT_PK_NORM_I16_F32 and V_CVT_PK_NORM_U16_F32."""
def test_cvt_pk_norm_i16_f32_basic(self):
"""V_CVT_PK_NORM_I16_F32: pack two f32 to normalized i16."""
instructions = [
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], -1.0),
v_cvt_pk_norm_i16_f32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = result & 0xffff
hi = (result >> 16) & 0xffff
self.assertEqual(lo, 32767)
self.assertEqual(hi, 0x8001) # -32767, hardware uses symmetric range
def test_cvt_pk_norm_u16_f32_basic(self):
"""V_CVT_PK_NORM_U16_F32: pack two f32 to normalized u16."""
instructions = [
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], 0.5),
v_cvt_pk_norm_u16_f32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = result & 0xffff
hi = (result >> 16) & 0xffff
self.assertEqual(lo, 65535)
self.assertAlmostEqual(hi, 32768, delta=1)
class TestCvtPkInt(unittest.TestCase):
"""Tests for V_CVT_PK_I16_I32, V_CVT_PK_U16_U32, V_CVT_PK_I16_F32, V_CVT_PK_U16_F32."""
def test_cvt_pk_i16_i32_basic(self):
"""V_CVT_PK_I16_I32: pack two i32 to i16."""
instructions = [
s_mov_b32(s[0], 100),
s_mov_b32(s[1], -100 & 0xffffffff),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_cvt_pk_i16_i32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = result & 0xffff
hi = (result >> 16) & 0xffff
lo_signed = lo if lo < 32768 else lo - 65536
hi_signed = hi if hi < 32768 else hi - 65536
self.assertEqual(lo_signed, 100)
self.assertEqual(hi_signed, -100)
def test_cvt_pk_u16_u32_basic(self):
"""V_CVT_PK_U16_U32: pack two u32 to u16."""
instructions = [
s_mov_b32(s[0], 1000),
s_mov_b32(s[1], 2000),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_cvt_pk_u16_u32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = result & 0xffff
hi = (result >> 16) & 0xffff
self.assertEqual(lo, 1000)
self.assertEqual(hi, 2000)
def test_cvt_pk_i16_f32_basic(self):
"""V_CVT_PK_I16_F32: convert two f32 to packed i16."""
instructions = [
v_mov_b32_e32(v[0], 100.5),
v_mov_b32_e32(v[1], -50.7),
v_cvt_pk_i16_f32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = result & 0xffff
hi = (result >> 16) & 0xffff
lo_signed = lo if lo < 32768 else lo - 65536
hi_signed = hi if hi < 32768 else hi - 65536
self.assertEqual(lo_signed, 100)
self.assertEqual(hi_signed, -50)
def test_cvt_pk_u16_f32_basic(self):
"""V_CVT_PK_U16_F32: convert two f32 to packed u16."""
instructions = [
v_mov_b32_e32(v[0], 100.9),
v_mov_b32_e32(v[1], 200.1),
v_cvt_pk_u16_f32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = result & 0xffff
hi = (result >> 16) & 0xffff
self.assertEqual(lo, 100)
self.assertEqual(hi, 200)
def test_cvt_pk_u8_f32_basic(self):
"""V_CVT_PK_U8_F32: convert f32 to u8 and pack at byte position."""
instructions = [
v_mov_b32_e32(v[0], 128.5),
v_mov_b32_e32(v[1], 0),
v_mov_b32_e32(v[2], 0),
v_cvt_pk_u8_f32(v[2], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
byte0 = result & 0xff
self.assertEqual(byte0, 128)
class TestDotProduct(unittest.TestCase):
"""Tests for dot product instructions V_DOT4_U32_U8, V_DOT8_U32_U4."""
def test_v_dot4_u32_u8_basic(self):
"""V_DOT4_U32_U8: 4-element dot product of u8 vectors."""
src0 = 0x04030201 # {4, 3, 2, 1}
src1 = 0x01010101 # {1, 1, 1, 1}
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot4_u32_u8(v[2], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 10)
def test_v_dot4_u32_u8_with_accumulator(self):
"""V_DOT4_U32_U8 with non-zero accumulator."""
src0 = 0x02020202 # {2, 2, 2, 2}
src1 = 0x03030303 # {3, 3, 3, 3}
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 100),
v_dot4_u32_u8(v[2], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 124)
def test_v_dot8_u32_u4_basic(self):
"""V_DOT8_U32_U4: 8-element dot product of u4 vectors."""
# src0 = 8 nibbles: {1,2,3,4,5,6,7,8} packed as 0x87654321
# src1 = 8 nibbles: {1,1,1,1,1,1,1,1} packed as 0x11111111
# result = 1+2+3+4+5+6+7+8 = 36
src0 = 0x87654321
src1 = 0x11111111
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot8_u32_u4(v[2], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 36)
class TestMinMaxF16Vop3(unittest.TestCase):
"""Tests for V_MIN3_F16, V_MAX3_F16, V_MED3_F16, V_MINMAX_F16, V_MAXMIN_F16."""
def test_v_min3_f16_basic(self):
"""V_MIN3_F16: minimum of three f16 values."""
instructions = [
s_mov_b32(s[0], f32_to_f16(3.0)),
s_mov_b32(s[1], f32_to_f16(1.0)),
s_mov_b32(s[2], f32_to_f16(2.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_min3_f16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 1.0, delta=0.01)
def test_v_max3_f16_basic(self):
"""V_MAX3_F16: maximum of three f16 values."""
instructions = [
s_mov_b32(s[0], f32_to_f16(1.0)),
s_mov_b32(s[1], f32_to_f16(3.0)),
s_mov_b32(s[2], f32_to_f16(2.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_max3_f16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 3.0, delta=0.01)
def test_v_med3_f16_basic(self):
"""V_MED3_F16: median of three f16 values."""
instructions = [
s_mov_b32(s[0], f32_to_f16(3.0)),
s_mov_b32(s[1], f32_to_f16(1.0)),
s_mov_b32(s[2], f32_to_f16(2.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_med3_f16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 2.0, delta=0.01)
def test_v_minmax_f16_basic(self):
"""V_MINMAX_F16: clamp(src0, min=src1, max=src2)."""
instructions = [
s_mov_b32(s[0], f32_to_f16(2.5)),
s_mov_b32(s[1], f32_to_f16(1.0)),
s_mov_b32(s[2], f32_to_f16(2.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_minmax_f16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 2.0, delta=0.01)
def test_v_maxmin_f16_basic(self):
"""V_MAXMIN_F16: clamp(src0, min=src2, max=src1)."""
instructions = [
s_mov_b32(s[0], f32_to_f16(0.5)),
s_mov_b32(s[1], f32_to_f16(2.0)),
s_mov_b32(s[2], f32_to_f16(1.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_maxmin_f16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 1.0, delta=0.01)
def test_v_min3_f16_with_neg(self):
"""V_MIN3_F16 with neg modifier: min(-3, 1, 2) = -3."""
instructions = [
s_mov_b32(s[0], f32_to_f16(3.0)),
s_mov_b32(s[1], f32_to_f16(1.0)),
s_mov_b32(s[2], f32_to_f16(2.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_min3_f16(v[3], -v[0], v[1], v[2]), # neg on first operand
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, -3.0, delta=0.01)
def test_v_max3_f16_with_abs(self):
"""V_MAX3_F16 with abs modifier: max(|-3|, 1, 2) = 3."""
instructions = [
s_mov_b32(s[0], f32_to_f16(-3.0)),
s_mov_b32(s[1], f32_to_f16(1.0)),
s_mov_b32(s[2], f32_to_f16(2.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_max3_f16(v[3], abs(v[0]), v[1], v[2]), # abs on first operand
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 3.0, delta=0.01)
def test_v_med3_f16_opsel_hi(self):
"""V_MED3_F16 with opsel reading from hi half."""
# Pack two f16 values: hi=5.0, lo=1.0
packed = (f32_to_f16(5.0) << 16) | f32_to_f16(1.0)
instructions = [
s_mov_b32(s[0], packed),
s_mov_b32(s[1], f32_to_f16(3.0)),
s_mov_b32(s[2], f32_to_f16(4.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
# Read hi half of v[0] (5.0), med3(5, 3, 4) = 4
v_med3_f16(v[3], v[0].h, v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 4.0, delta=0.01)
if __name__ == '__main__':
unittest.main()

View file

@ -149,7 +149,6 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mix_f32_src2_f16_lo(self):
"""V_FMA_MIX_F32 with src2 as f16 from lo bits."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_2 = f32_to_f16(2.0)
instructions = [
s_mov_b32(s[0], f2i(1.0)),
@ -166,7 +165,6 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mix_f32_src2_f16_hi(self):
"""V_FMA_MIX_F32 with src2 as f16 from hi bits."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_2 = f32_to_f16(2.0)
val = (f16_2 << 16) | 0
instructions = [
@ -199,7 +197,6 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mix_f32_with_abs_f16_src2_lo(self):
"""V_FMA_MIX_F32 with abs modifier on f16 src2 (lo half). Regression test for sin(1.0) bug."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
instructions = [
s_mov_b32(s[0], f2i(0.0)), # src0 = 0.0 (f32)
@ -217,7 +214,6 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mix_f32_with_neg_f16_src2_lo(self):
"""V_FMA_MIX_F32 with neg modifier on f16 src2 (lo half)."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_1 = f32_to_f16(1.0) # 0x3c00
instructions = [
s_mov_b32(s[0], f2i(0.0)), # src0 = 0.0 (f32)
@ -235,7 +231,6 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mix_f32_with_abs_f16_src2_hi(self):
"""V_FMA_MIX_F32 with abs modifier on f16 src2 (hi half)."""
from extra.assembly.amd.test.hw.helpers import f32_to_f16
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
val = (f16_neg1 << 16) | 0 # -1.0 in hi, 0 in lo
instructions = [
@ -254,7 +249,6 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mixlo_f16(self):
"""V_FMA_MIXLO_F16 writes to low 16 bits of destination."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], f2i(2.0)),
v_mov_b32_e32(v[0], s[0]),
@ -267,14 +261,13 @@ class TestFmaMix(unittest.TestCase):
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
]
st = run_program(instructions, n_lanes=1)
lo = _f16(st.vgpr[0][3] & 0xffff)
lo = f16(st.vgpr[0][3] & 0xffff)
hi = (st.vgpr[0][3] >> 16) & 0xffff
self.assertAlmostEqual(lo, 7.0, places=1)
self.assertEqual(hi, 0xdead, f"hi should be preserved, got 0x{hi:04x}")
def test_v_fma_mixlo_f16_all_f32_sources(self):
"""V_FMA_MIXLO_F16 with all f32 sources."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], f2i(1.0)),
v_mov_b32_e32(v[0], s[0]),
@ -286,13 +279,12 @@ class TestFmaMix(unittest.TestCase):
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
]
st = run_program(instructions, n_lanes=1)
lo = _f16(st.vgpr[0][3] & 0xffff)
lo = f16(st.vgpr[0][3] & 0xffff)
# 1*2+3 = 5
self.assertAlmostEqual(lo, 5.0, places=1)
def test_v_fma_mixlo_f16_sin_case(self):
"""V_FMA_MIXLO_F16 case from sin kernel."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0x3f800000), # f32 1.0
v_mov_b32_e32(v[3], s[0]),
@ -305,7 +297,7 @@ class TestFmaMix(unittest.TestCase):
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[3], src1=s[6], src2=v[5], opsel=0, opsel_hi=0, opsel_hi2=0),
]
st = run_program(instructions, n_lanes=1)
lo = _f16(st.vgpr[0][3] & 0xffff)
lo = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(lo, -3.14159, delta=0.01)
@ -314,7 +306,6 @@ class TestVOP3P(unittest.TestCase):
def test_v_pk_add_f16_basic(self):
"""V_PK_ADD_F16 adds two packed f16 values."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
@ -324,14 +315,13 @@ class TestVOP3P(unittest.TestCase):
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
lo = f16(result & 0xffff)
hi = f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 4.0, places=2)
self.assertAlmostEqual(hi, 6.0, places=2)
def test_v_pk_mul_f16_basic(self):
"""V_PK_MUL_F16 multiplies two packed f16 values."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0x42004000), # hi=3.0, lo=2.0
s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0
@ -341,14 +331,13 @@ class TestVOP3P(unittest.TestCase):
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
lo = f16(result & 0xffff)
hi = f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 8.0, places=1)
self.assertAlmostEqual(hi, 15.0, places=1)
def test_v_pk_fma_f16_basic(self):
"""V_PK_FMA_F16: D = A * B + C for packed f16."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0x42004000), # A: hi=3.0, lo=2.0
s_mov_b32(s[1], 0x45004400), # B: hi=5.0, lo=4.0
@ -360,8 +349,8 @@ class TestVOP3P(unittest.TestCase):
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
lo = f16(result & 0xffff)
hi = f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 9.0, places=1) # 2*4+1
self.assertAlmostEqual(hi, 16.0, places=0) # 3*5+1
@ -370,7 +359,6 @@ class TestVOP3P(unittest.TestCase):
Inline constants for VOP3P are f16 values in the low 16 bits only.
hi half of inline constant is 0, so hi result = v0.hi + 0 = 1.0.
"""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0
v_mov_b32_e32(v[0], s[0]),
@ -378,8 +366,8 @@ class TestVOP3P(unittest.TestCase):
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
lo = f16(result & 0xffff)
hi = f16((result >> 16) & 0xffff)
# lo = 1.0 + 1.0 = 2.0, hi = 1.0 + 0.0 = 1.0 (inline const hi half is 0)
self.assertAlmostEqual(lo, 2.0, places=2)
self.assertAlmostEqual(hi, 1.0, places=2)
@ -388,7 +376,6 @@ class TestVOP3P(unittest.TestCase):
"""V_PK_MUL_F16 with inline constant POS_TWO (2.0).
Inline constant has value only in low 16 bits, hi is 0.
"""
from extra.assembly.amd.test.hw.helpers import _f16
# v0 = packed (3.0, 4.0), multiply by POS_TWO
# lo = 3.0 * 2.0 = 6.0, hi = 4.0 * 0.0 = 0.0 (inline const hi is 0)
instructions = [
@ -398,8 +385,8 @@ class TestVOP3P(unittest.TestCase):
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
lo = f16(result & 0xffff)
hi = f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 6.0, places=1)
self.assertAlmostEqual(hi, 0.0, places=1)
@ -413,7 +400,6 @@ class TestWMMAF16(unittest.TestCase):
def test_v_wmma_f16_16x16x16_f16_all_ones(self):
"""V_WMMA_F16_16X16X16_F16 with all ones produces 16.0 in f16."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
# Initialize A matrix in v[16:23] (8 regs)
@ -432,13 +418,12 @@ class TestWMMAF16(unittest.TestCase):
for lane in range(32):
for reg in range(8):
result = st.vgpr[lane][reg]
lo = _f16(result & 0xffff)
lo = f16(result & 0xffff)
self.assertAlmostEqual(lo, 16.0, places=1, msg=f"v[{reg}] lane {lane}: expected 16.0, got {lo}")
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
def test_v_wmma_f16_16x16x16_f16_with_accumulator(self):
"""V_WMMA_F16_16X16X16_F16 with non-zero accumulator."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
instructions.append(s_mov_b32(s[1], 0x4500)) # f16 5.0 in lo bits only
@ -458,7 +443,7 @@ class TestWMMAF16(unittest.TestCase):
for lane in range(32):
for reg in range(8):
result = st.vgpr[lane][reg]
lo = _f16(result & 0xffff)
lo = f16(result & 0xffff)
self.assertAlmostEqual(lo, 21.0, places=0, msg=f"v[{reg}] lane {lane}: expected 21.0, got {lo}")
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
@ -468,7 +453,6 @@ class TestWMMAF16(unittest.TestCase):
Regression test: WMMA was using static register indices instead of dynamic.
This test uses v[64:71] for A, v[80:87] for B, v[96:103] for C/D.
"""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
# Initialize A matrix in v[64:71] (8 regs)
@ -490,7 +474,7 @@ class TestWMMAF16(unittest.TestCase):
for lane in range(32):
for reg in range(8):
result = st.vgpr[lane][reg]
lo = _f16(result & 0xffff)
lo = f16(result & 0xffff)
self.assertAlmostEqual(lo, 16.0, places=1, msg=f"v[{reg}] lane {lane}: expected 16.0, got {lo}")
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
@ -713,7 +697,6 @@ class TestPackedMixedSigns(unittest.TestCase):
def test_pk_add_f16_mixed_signs(self):
"""V_PK_ADD_F16 with mixed positive/negative values."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0xc0003c00), # packed: hi=-2.0, lo=1.0
s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0
@ -723,14 +706,13 @@ class TestPackedMixedSigns(unittest.TestCase):
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
lo = f16(result & 0xffff)
hi = f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 2.0, places=2) # 1.0 + 1.0
self.assertAlmostEqual(hi, -1.0, places=2) # -2.0 + 1.0
def test_pk_mul_f16_zero(self):
"""V_PK_MUL_F16 with zero."""
from extra.assembly.amd.test.hw.helpers import _f16
instructions = [
s_mov_b32(s[0], 0x40004000), # packed: 2.0, 2.0
s_mov_b32(s[1], 0x00000000), # packed: 0.0, 0.0
@ -743,5 +725,276 @@ class TestPackedMixedSigns(unittest.TestCase):
self.assertEqual(result, 0x00000000, "2.0 * 0.0 should be 0.0")
class TestDot2F32F16(unittest.TestCase):
"""Tests for V_DOT2_F32_F16 - dot product of f16 pairs producing f32."""
def test_v_dot2_f32_f16_basic(self):
"""V_DOT2_F32_F16: dot product of two packed f16 pairs -> f32."""
# src0 = {hi=2.0, lo=1.0}, src1 = {hi=4.0, lo=3.0}
# result = 1.0*3.0 + 2.0*4.0 + 0 = 3 + 8 = 11.0
src0 = (f32_to_f16(2.0) << 16) | f32_to_f16(1.0)
src1 = (f32_to_f16(4.0) << 16) | f32_to_f16(3.0)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 11.0, places=2)
def test_v_dot2_f32_f16_with_accumulator(self):
"""V_DOT2_F32_F16 with non-zero f32 accumulator."""
# src0 = {hi=1.0, lo=1.0}, src1 = {hi=1.0, lo=1.0}, acc = 5.0
# result = 1.0*1.0 + 1.0*1.0 + 5.0 = 7.0
src0 = (f32_to_f16(1.0) << 16) | f32_to_f16(1.0)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], f2i(5.0)),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[0]), # same as src0
v_mov_b32_e32(v[2], s[1]),
v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 7.0, places=2)
def test_v_dot2_f32_f16_negative_values(self):
"""V_DOT2_F32_F16 with negative f16 values."""
# src0 = {hi=-2.0, lo=3.0}, src1 = {hi=1.0, lo=2.0}
# result = 3.0*2.0 + (-2.0)*1.0 + 0 = 6 - 2 = 4.0
src0 = (f32_to_f16(-2.0) << 16) | f32_to_f16(3.0)
src1 = (f32_to_f16(1.0) << 16) | f32_to_f16(2.0)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 4.0, places=2)
class TestDot2F16F16(unittest.TestCase):
"""Tests for V_DOT2_F16_F16 - dot product of f16 pairs producing f16."""
def test_v_dot2_f16_f16_basic(self):
"""V_DOT2_F16_F16: dot product of two packed f16 pairs -> f16."""
# src0 = {hi=2.0, lo=1.0}, src1 = {hi=3.0, lo=2.0}
# result = 1.0*2.0 + 2.0*3.0 + 0 = 2 + 6 = 8.0 (f16)
src0 = (f32_to_f16(2.0) << 16) | f32_to_f16(1.0)
src1 = (f32_to_f16(3.0) << 16) | f32_to_f16(2.0)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot2_f16_f16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 8.0, places=1)
def test_v_dot2_f16_f16_with_accumulator(self):
"""V_DOT2_F16_F16 with non-zero f16 accumulator."""
# src0 = {hi=1.0, lo=1.0}, src1 = {hi=1.0, lo=1.0}, acc = 3.0 (f16)
# result = 1.0*1.0 + 1.0*1.0 + 3.0 = 5.0 (f16)
src0 = (f32_to_f16(1.0) << 16) | f32_to_f16(1.0)
acc = f32_to_f16(3.0)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[2], acc),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[0]), # same as src0
v_mov_b32_e32(v[2], s[2]),
v_dot2_f16_f16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(result, 5.0, places=1)
class TestSignedDotProducts(unittest.TestCase):
"""Tests for V_DOT4_I32_IU8 and V_DOT8_I32_IU4 with signed inputs."""
def test_v_dot4_i32_iu8_signed_both(self):
"""V_DOT4_I32_IU8 with both inputs signed (neg=0b011)."""
# src0 = {-1, -2, 3, 4} as i8 = {0xff, 0xfe, 0x03, 0x04}
# src1 = {1, 1, 1, 1} as i8
# result = (-1)*1 + (-2)*1 + 3*1 + 4*1 = -1 - 2 + 3 + 4 = 4
src0 = (0xff << 24) | (0xfe << 16) | (0x03 << 8) | 0x04 # -1, -2, 3, 4
src1 = 0x01010101 # 1, 1, 1, 1
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b011), # both signed
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
# Result is i32, interpret as signed
if result >= 0x80000000:
result = result - 0x100000000
self.assertEqual(result, 4)
def test_v_dot4_i32_iu8_src0_signed(self):
"""V_DOT4_I32_IU8 with only src0 signed (neg=0b001)."""
# src0 = {-1, -1, -1, -1} as i8 = {0xff, 0xff, 0xff, 0xff}
# src1 = {2, 2, 2, 2} as u8
# result = (-1)*2 + (-1)*2 + (-1)*2 + (-1)*2 = -8
src0 = 0xffffffff # -1, -1, -1, -1 (as i8)
src1 = 0x02020202 # 2, 2, 2, 2 (as u8)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b001), # src0 signed
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
if result >= 0x80000000:
result = result - 0x100000000
self.assertEqual(result, -8)
def test_v_dot4_i32_iu8_src1_signed(self):
"""V_DOT4_I32_IU8 with only src1 signed (neg=0b010)."""
# src0 = {2, 2, 2, 2} as u8
# src1 = {-1, -1, -1, -1} as i8 = {0xff, 0xff, 0xff, 0xff}
# result = 2*(-1) + 2*(-1) + 2*(-1) + 2*(-1) = -8
src0 = 0x02020202 # 2, 2, 2, 2 (as u8)
src1 = 0xffffffff # -1, -1, -1, -1 (as i8)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b010), # src1 signed
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
if result >= 0x80000000:
result = result - 0x100000000
self.assertEqual(result, -8)
def test_v_dot4_i32_iu8_unsigned_as_reference(self):
"""V_DOT4_I32_IU8 with both unsigned (neg=0) - same as V_DOT4_U32_U8."""
# src0 = {0xff, 0xff, 0xff, 0xff} = 255 each as u8
# src1 = {1, 1, 1, 1}
# result = 255*1 + 255*1 + 255*1 + 255*1 = 1020
src0 = 0xffffffff
src1 = 0x01010101
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0), # both unsigned
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][3], 1020)
def test_v_dot8_i32_iu4_signed_both(self):
"""V_DOT8_I32_IU4 with both inputs signed (neg=0b011)."""
# src0 = 8 nibbles: {-1, -2, 3, 4, -1, -2, 3, 4} as i4
# i4 -1 = 0xf, -2 = 0xe, 3 = 0x3, 4 = 0x4
# src0 = 0xfe34fe34
# src1 = {1, 1, 1, 1, 1, 1, 1, 1} as i4 = 0x11111111
# result = 2 * ((-1)*1 + (-2)*1 + 3*1 + 4*1) = 2 * 4 = 8
src0 = 0xfe34fe34
src1 = 0x11111111
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot8_i32_iu4(v[3], v[0], v[1], v[2], neg=0b011), # both signed
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
if result >= 0x80000000:
result = result - 0x100000000
self.assertEqual(result, 8)
def test_v_dot8_i32_iu4_all_negative(self):
"""V_DOT8_I32_IU4 with all negative signed values."""
# src0 = 8 nibbles all -1 (0xf) = 0xffffffff
# src1 = 8 nibbles all 1 = 0x11111111
# result = 8 * ((-1)*1) = -8
src0 = 0xffffffff # all -1 as i4
src1 = 0x11111111 # all 1
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot8_i32_iu4(v[3], v[0], v[1], v[2], neg=0b011), # both signed
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
if result >= 0x80000000:
result = result - 0x100000000
self.assertEqual(result, -8)
class TestPkMinMaxF16(unittest.TestCase):
"""Tests for V_PK_MIN_F16 and V_PK_MAX_F16."""
def test_v_pk_min_f16_basic(self):
"""V_PK_MIN_F16: packed min of two f16 pairs."""
# src0 = {hi=3.0, lo=1.0}, src1 = {hi=2.0, lo=4.0}
# result = {min(3,2)=2, min(1,4)=1}
src0 = (f32_to_f16(3.0) << 16) | f32_to_f16(1.0)
src1 = (f32_to_f16(2.0) << 16) | f32_to_f16(4.0)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_min_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = f16(result & 0xffff)
hi = f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 1.0, delta=0.01)
self.assertAlmostEqual(hi, 2.0, delta=0.01)
def test_v_pk_max_f16_basic(self):
"""V_PK_MAX_F16: packed max of two f16 pairs."""
# src0 = {hi=3.0, lo=1.0}, src1 = {hi=2.0, lo=4.0}
# result = {max(3,2)=3, max(1,4)=4}
src0 = (f32_to_f16(3.0) << 16) | f32_to_f16(1.0)
src1 = (f32_to_f16(2.0) << 16) | f32_to_f16(4.0)
instructions = [
s_mov_b32(s[0], src0),
s_mov_b32(s[1], src1),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_max_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = f16(result & 0xffff)
hi = f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 4.0, delta=0.01)
self.assertAlmostEqual(hi, 3.0, delta=0.01)
if __name__ == '__main__':
unittest.main()

View file

@ -294,7 +294,7 @@ class TestAllPcode(unittest.TestCase):
'ADDR': u32(), 'ADDR_BASE': u32(), 'TADDR': u32(), 'DATA': u32(), 'DATA0': u32(), 'DATA1': u32(), 'DATA2': u32(),
'VDATA': u32(), 'VDATA0': u32(), 'VDATA1': u32(), 'VDATA2': u32(), 'VDATA3': u32(),
'OPSEL': u32(), 'OPSEL_HI': u32(), 'NEG': u32(), 'NEG_HI': u32(), 'CLAMP': u32(),
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'WAVE_STATUS': u32(),
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(), 'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
'MAX_FLOAT_F32': u32(0x7f7fffff), 'Unsigned': u32(1), 'clampedLOD': u32(),
'_lds': lds, '_vmem': lds, '_active': UOp.const(dtypes.bool, True)}

View file

@ -471,7 +471,7 @@ THREADS = 128
def test_matmul():
dev = Device[Device.DEFAULT]
print(f"Device arch: {dev.arch}")
print(f"Device arch: {dev.renderer.arch}")
if getenv("STOCK", 0):
# Load the stock kernel from amd_seb/kernel8_batched_gmem.s
@ -479,7 +479,7 @@ def test_matmul():
asm = stock_path.read_text()
print(f"Loaded stock kernel from {stock_path}")
else:
asm = build_kernel(dev.arch)
asm = build_kernel(dev.renderer.arch)
binary = dev.compiler.compile(asm)
print(f"Compiled! Binary size: {len(binary)} bytes")

11517
extra/gemm/asm/cdna/asm.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,95 @@
import atexit, functools
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.renderer import Estimates
from tinygrad.helpers import getenv, all_same, dedup
from extra.gemm.asm.cdna.asm import build_kernel, GEMM_ARGS
# ** CDNA4 assembly gemm
WORKGROUP_SIZE = 256
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp:
batch, M, K = A.shape
K2, N = B.shape[(1 if B.ndim == 3 else 0):]
assert K == K2
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
gidx = UOp.special(wg, "gidx0")
k = build_kernel(batch, M, N, K, A.dtype.base)
sink = UOp.sink(C.base, A.base, B.base, lidx, gidx,
arg=KernelInfo(name=k.name, estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
binary = HIPCompiler(arch).compile(k.to_asm())
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=k.to_text()), UOp(Ops.BINARY, arg=binary)))
counters = {"used":0, "todos":[]}
def todo(msg:str) -> bool: counters["todos"].append(msg); return False
atexit.register(lambda: print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used'))
def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}")
if a.dtype not in {dtypes.bfloat16, dtypes.float16}: return todo(f"only bfloat16/float16, got {a.dtype}")
# only sharding on the batch is tested, others might work too
if isinstance(a.device, tuple) and not (a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None):
return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}")
batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
N = b.shape[1]
if isinstance(a.device, tuple): batch //= len(a.device)
if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
if (key:=(M, N, K)) not in GEMM_ARGS: return todo(f"GEMM shape not supported {key}")
return True
# ** UOp gemm to test Tensor.custom_kernel multi and backward correctness on non cdna4
# note: this can be removed after we have GEMM on mixins
def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
M, K = A.shape[0]*A.shape[1], A.shape[2]
K2, N = B.shape[(1 if B.ndim == 3 else 0):]
assert K == K2
m = UOp.range(M, 1, AxisType.LOOP)
n = UOp.range(N, 2, AxisType.LOOP)
k = UOp.range(K, 0, AxisType.REDUCE)
mul = (A.index((m*UOp.const(dtypes.index, K)+k))*B.index((k*UOp.const(dtypes.index, N)+n))).cast(dtypes.float32)
red = mul.reduce(k, arg=Ops.ADD, dtype=dtypes.float32).cast(C.dtype.base)
store = C.index((m*UOp.const(dtypes.index, N)+n), ptr=True).store(red).end(m, n)
return store.sink(arg=KernelInfo(name=f'uop_gemm_{M}_{N}_{K}'))
# ** backward gemm, might use the asm gemm
def custom_gemm_bw(gradient:UOp, kernel:UOp):
out, a, b = kernel.src
assert all_same([gradient.device, a.device, b.device, out.device])
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
grad_a = (g_t @ b_t.T).uop
a_T = a_t.transpose(-2, -1)
a_T = a_T.reshape(*a_T.shape[:-1], 1, a_T.shape[-1])
g_r = g_t.reshape(*g_t.shape[:-2], 1, *g_t.shape[-2:]).transpose(-1, -2)
grad_b = (a_T * g_r).sum((-1, 0)).uop
return (None, grad_a, grad_b)
# ** main gemm function
def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
counters["used"] += 1
squeeze = a.ndim == 2
if squeeze: a = a.unsqueeze(0)
batch, M, K = a.shape
N = b.shape[1]
is_multi = isinstance(a.device, tuple)
if is_multi:
out = Tensor(Tensor.empty(batch//len(a.device), M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
else:
out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device)
dname = a.device[0] if is_multi else a.device
arch = getattr(Device[dname].renderer, "arch", None)
if arch.startswith("gfx950") and getenv("USE_ASM", 1):
numWG = GEMM_ARGS[(M, N, K)][0]
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=numWG, arch=arch), grad_fxn=custom_gemm_bw)[0]
else:
out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
return out.squeeze(0) if squeeze else out

File diff suppressed because it is too large Load diff

View file

@ -1,78 +0,0 @@
.text
.section .text.
.global gemm
.p2align 8
.type gemm,@function
gemm:
INSTRUCTIONS
.section .rodata,"a",@progbits
.p2align 6, 0x0
.amdhsa_kernel gemm
# basic memory requirements
.amdhsa_group_segment_fixed_size 133120
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 28
# register usage (RSRC1)
.amdhsa_next_free_vgpr 504
.amdhsa_next_free_sgpr 96
# workgroup / workitem IDs (RSRC2)
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 1
# user SGPRs, we only specify the kernel args ptr in s[0:1]
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_count 2
.amdhsa_user_sgpr_kernarg_preload_length 0
.amdhsa_user_sgpr_kernarg_preload_offset 0
# gfx90a / gfx940 specifics (RSRC3)
.amdhsa_accum_offset 248
.amdhsa_uses_dynamic_stack 0
.amdhsa_tg_split 0
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.kernels:
- .name: gemm
.symbol: gemm.kd
.args:
- .name: C
.address_space: global
.offset: 0
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: B
.address_space: global
.offset: 8
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: A
.address_space: global
.offset: 16
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: sz
.offset: 24
.size: 4
.value_kind: by_value
.value_type: u32
.group_segment_fixed_size: 133120
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.kernarg_segment_size: 28
.max_flat_workgroup_size: 256
.sgpr_count: 88
.sgpr_spill_count: 0
.vgpr_count: 248
.vgpr_spill_count: 0
.wavefront_size: 64
amdhsa.version:
- 1
- 0
...
.end_amdgpu_metadata

View file

@ -1,73 +0,0 @@
# Run assembly on the AMD runtime and check correctness
# VIZ=2 to profile
import pathlib
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.engine.realize import Estimates
from tinygrad.helpers import getenv
fp = pathlib.Path(__file__).parent/"gemm.s"
N = getenv("N", 8192)
THREADS_PER_WG = 256
NUM_WG = N//THREADS_PER_WG * N//THREADS_PER_WG
assert N % THREADS_PER_WG == 0, "N must be divisible by THREADS_PER_WG"
# ** generate inputs on CPU
scale = 10.0
import torch
torch.manual_seed(0)
A = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
B = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
Bt = B.t().contiguous() # transpose B for the asm gemm
C_torch = A@B
# ** copy buffers to AMD
# input creation and validation run on the copy engine for simpler tracing
def from_torch(t:torch.Tensor) -> Tensor:
return Tensor.from_blob(t.data_ptr(), t.shape, dtype=dtypes.bfloat16, device="cpu").to(Device.DEFAULT).realize()
C_tiny = from_torch(A) @ from_torch(B)
C_asm = Tensor.empty_like(C_tiny)
# ** assembly custom kernel
def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
lidx = UOp.special(THREADS_PER_WG, "lidx0")
gidx = UOp.special(NUM_WG, "gidx0")
src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
sz = UOp.variable("SZ", 256, 8192)
sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm", estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
# ** run gemms
sched = Tensor.schedule(C_tiny, C_asm)
eis = [si.lower() for si in sched]
with Context(DEBUG=2):
for ei in eis:
et = ei.run({"SZ":N}, wait=True)
print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")
# ** correctness
import ctypes
def torch_bf16(t:Tensor) -> torch.tensor:
asm_out = t.to("cpu").realize().uop.buffer._buf
buf = (ctypes.c_uint16*C_asm.uop.size).from_address(asm_out.va_addr)
return torch.frombuffer(buf, dtype=torch.bfloat16, count=C_asm.uop.size).reshape(C_asm.shape)
assert torch.allclose(torch_bf16(C_asm), C_torch, rtol=1e-2, atol=1e-3)
assert torch.allclose(torch_bf16(C_tiny), C_torch, rtol=1e-2, atol=1e-3)

View file

@ -0,0 +1,46 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.helpers import getenv
from extra.gemm.asm.cdna.gemm import asm_gemm
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, multi=False) -> None:
Tensor.manual_seed(0)
a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype)
b_rand = Tensor.randn((K, N), dtype=dtypes.float).sub(0.5).cast(dtype)
with Context(DEBUG=0):
Tensor.realize(a_rand, b_rand)
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8)) if multi else None
a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None)
tst = asm_gemm(a, b)
tst.sum().backward()
Tensor.realize(tst, a.grad, b.grad)
a_ref, b_ref = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
if multi: a_ref, b_ref = a_ref.shard(devs, axis=0), b_ref.shard(devs, axis=None)
with Context(ASM_GEMM=0): ref = a_ref @ b_ref
ref.sum().backward()
Tensor.realize(ref, a_ref.grad, b_ref.grad)
with Context(DEBUG=0):
assert (tst - ref).square().max().float().item() < 1e-6, "forward mismatch"
assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
class TestGemm(unittest.TestCase):
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, multi=True)
def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, multi=True)
def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, multi=True)
def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, multi=True)
def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, multi=True)
def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, multi=True)
def test_gemm_unsupported(self):
with self.assertRaisesRegex(AssertionError, "shape not supported"):
verify_asm_gemm(8, 8192, 1024, 4096, multi=True)
if __name__ == "__main__":
unittest.main()

View file

@ -25,7 +25,7 @@ class StallReason(enum.IntEnum):
OTHER = 11 # misc, dispatch_stall
SLEEPING = 12 # sleeping
STALL_KEY_MAP: dict[int, StallReason] = {
STALL_KEY_MAP_AMPERE: dict[int, StallReason] = {
1: StallReason.MEMORY_THROTTLE, 15: StallReason.MEMORY_THROTTLE,
2: StallReason.CONSTANT_MEMORY,
3: StallReason.SYNC,
@ -37,14 +37,25 @@ STALL_KEY_MAP: dict[int, StallReason] = {
18: StallReason.NONE,
}
STALL_KEY_MAP_BLACKWELL: dict[int, StallReason] = {
0x01: StallReason.MEMORY_THROTTLE, 0x0e: StallReason.MEMORY_THROTTLE,
0x02: StallReason.SYNC,
0x05: StallReason.INST_FETCH, 0x0a: StallReason.INST_FETCH,
0x06: StallReason.EXEC_DEPENDENCY, 0x09: StallReason.EXEC_DEPENDENCY,
0x08: StallReason.MEMORY_DEPENDENCY,
0x0b: StallReason.PIPE_BUSY, 0x0f: StallReason.PIPE_BUSY,
0x10: StallReason.OTHER, 0x13: StallReason.OTHER,
0x11: StallReason.NONE,
}
# Lookup table for extracting sample bytes from 32-byte packet (bytes 0-3, 8-31, skipping header at 4-7)
LOOKUP_28B = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
# ═══════════════════════════════════════════════════════════════════════════════
# AMPERE PACKET DEFINITIONS (8-byte aligned)
# PACKET HEADER
# ═══════════════════════════════════════════════════════════════════════════════
# Lookup table for extracting sample bytes from 32-byte packet
LOOKUP_8B = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
class PMAHeaderAmpere8B(PacketType):
class PMAHeader(PacketType):
num_bytes = bits[4:0] # number of sample bytes in this packet
tpc_id_lo = bits[15:8] # TPC identifier low 8 bits
tpc_id_hi = bits[27:25] # TPC identifier high 3 bits
@ -52,33 +63,57 @@ class PMAHeaderAmpere8B(PacketType):
@property
def tpc_id(self) -> int: return self.tpc_id_lo | (self.tpc_id_hi << 8)
# ═══════════════════════════════════════════════════════════════════════════════
# 8-BYTE SAMPLE FORMAT (Ampere/Ada/Hopper)
# ═══════════════════════════════════════════════════════════════════════════════
class PMASampleAmpere8B(PacketType):
pc_raw = bits[44:0] # raw PC value (actual PC = pc_raw << 4)
stall_lo = bits[47:45] # stall key low 3 bits
stall_hi = bits[49:48] # stall key high 2 bits
wave_id = bits[55:50] # warp/wave identifier (0-63)
active = bits[62:62] # active flag (warp was executing)
pc_raw = bits[44:0] # raw PC value (pc_offset = pc_raw << 4)
stall_key = bits[49:45] # stall reason key
wave_id = bits[55:50] # warp/wave identifier
active = bits[62:62] # 1 if warp was executing, 0 if scheduled but not issued
@property
def pc_offset(self) -> int: return self.pc_raw << 4
@property
def stall_key(self) -> int: return self.stall_lo | (self.stall_hi << 3)
@property
def stall_reason(self) -> StallReason: return STALL_KEY_MAP.get(self.stall_key, StallReason.OTHER)
def stall_reason(self) -> StallReason: return STALL_KEY_MAP_AMPERE.get(self.stall_key, StallReason.OTHER)
# ═══════════════════════════════════════════════════════════════════════════════
# 9-BYTE SAMPLE FORMAT (Blackwell+)
# ═══════════════════════════════════════════════════════════════════════════════
class PMASampleBlackwell9B(PacketType):
stall_key = bits[5:0] # stall reason key
pc_raw = bits[60:8] # raw PC value (pc_offset = pc_raw << 4)
wave_hi = bits[7:6] # wave_id high 2 bits
wave_lo = bits[71:68] # wave_id low 4 bits
active = bits[67:67] # 1 if warp was executing, 0 if scheduled but not issued
@property
def pc_offset(self) -> int: return self.pc_raw << 4
@property
def stall_reason(self) -> StallReason: return STALL_KEY_MAP_BLACKWELL.get(self.stall_key, StallReason.OTHER)
@property
def wave_id(self) -> int: return (self.wave_hi << 4) | self.wave_lo
PMASample = PMASampleAmpere8B|PMASampleBlackwell9B
def decode(data: bytes, sm_version: int = 0x800) -> Iterator[tuple[PMASample, int]]:
use_9byte = sm_version >= 0xa04
record_size = 9 if use_9byte else 8
sample_cls = PMASampleBlackwell9B if use_9byte else PMASampleAmpere8B
def decode(data: bytes) -> Iterator[tuple[PMASampleAmpere8B, int]]:
tpc_state: dict[int, list[int]] = collections.defaultdict(list)
for pkt_idx in range(len(data) // 32):
pkt = data[pkt_idx * 32:(pkt_idx + 1) * 32]
hdr = PMAHeaderAmpere8B.from_raw(int.from_bytes(pkt[4:8], 'little'))
hdr = PMAHeader.from_raw(int.from_bytes(pkt[4:8], 'little'))
if hdr.dropped: tpc_state[hdr.tpc_id].clear()
for i in range(hdr.num_bytes):
tpc_state[hdr.tpc_id].append(pkt[LOOKUP_8B[i]])
tpc_state[hdr.tpc_id].append(pkt[LOOKUP_28B[i]])
while len(tpc_state[hdr.tpc_id]) >= 8:
yield PMASampleAmpere8B.from_raw(int.from_bytes(bytes(tpc_state[hdr.tpc_id][:8]), 'little')), hdr.tpc_id
del tpc_state[hdr.tpc_id][:8]
while len(tpc_state[hdr.tpc_id]) >= record_size:
yield sample_cls.from_raw(int.from_bytes(bytes(tpc_state[hdr.tpc_id][:record_size]), 'little')), hdr.tpc_id
del tpc_state[hdr.tpc_id][:record_size]
# ═══════════════════════════════════════════════════════════════════════════════
# CLI
@ -90,11 +125,11 @@ STALL_COLORS = {
StallReason.PIPE_BUSY: "yellow", StallReason.MEMORY_THROTTLE: "RED", StallReason.OTHER: "white",
}
def decode_tpc_id(tpc_id: int) -> tuple[int, int, int]:
def decode_tpc_id(tpc_id:int) -> tuple[int, int, int]:
# NOTE: valid only for ops_nv, cuda encoding is different
return (tpc_id >> 5, (tpc_id >> 1) & 0xf, tpc_id & 1)
def print_samples(samples: list[tuple[PMASampleAmpere8B, int]]) -> None:
def print_samples(samples:list[tuple[PMASample, int]]) -> None:
if not samples: return
base_pc = min(s.pc_offset for s, _ in samples)
for s, tpc_id in samples:
@ -102,40 +137,57 @@ def print_samples(samples: list[tuple[PMASampleAmpere8B, int]]) -> None:
stall_str = colored(f"{s.stall_reason.name:17}", STALL_COLORS.get(s.stall_reason, "white"))
print(f"pc=0x{s.pc_offset - base_pc:06x} {stall_str} ev={s.stall_key:2d} active={s.active} wave={s.wave_id:2d} gpc={gpc} tpc={tpc} sm={sm}")
def print_packets(data: bytes) -> None:
def print_packets(data:bytes, sm_version:int=0x800) -> None:
record_size = 9 if sm_version >= 0x890 else 8
tpc_state: dict[int, list[int]] = collections.defaultdict(list)
for i in range(len(data) // 32):
pkt = data[i * 32:(i + 1) * 32]
hdr = PMAHeaderAmpere8B.from_raw(int.from_bytes(pkt[4:8], 'little'))
print(f"Pkt {i:3d}: tpc={hdr.tpc_id} bytes={hdr.num_bytes} drop={hdr.dropped} | {pkt.hex()}")
hdr = PMAHeader.from_raw(int.from_bytes(pkt[4:8], 'little'))
if hdr.dropped: tpc_state[hdr.tpc_id].clear()
for j in range(hdr.num_bytes): tpc_state[hdr.tpc_id].append(pkt[LOOKUP_28B[j]])
# Show complete records extracted from this packet
records = []
while len(tpc_state[hdr.tpc_id]) >= record_size:
records.append(bytes(tpc_state[hdr.tpc_id][:record_size]).hex())
del tpc_state[hdr.tpc_id][:record_size]
leftover = len(tpc_state[hdr.tpc_id])
print(f"Pkt {i:3d}: tpc={hdr.tpc_id:4d} n={hdr.num_bytes:2d} drop={hdr.dropped} left={leftover} | {' '.join(records)}")
def print_aggregated(samples: list[tuple[PMASampleAmpere8B, int]]) -> None:
def print_aggregated(samples:list[tuple[PMASample, int]]) -> None:
if not samples: return
base_pc = min(s.pc_offset for s, _ in samples)
counter: collections.Counter[tuple[int, int]] = collections.Counter((s.pc_offset, s.stall_key) for s, _ in samples)
counter: collections.Counter[tuple[int, StallReason]] = collections.Counter((s.pc_offset, s.stall_reason) for s, _ in samples)
print(f"\nAggregated samples (base_pc=0x{base_pc:x}):")
for (pc, key), cnt in sorted(counter.items()):
reason = STALL_KEY_MAP.get(key, StallReason.OTHER)
for (pc, reason), cnt in sorted(counter.items()):
stall_str = colored(f"{reason.name:17}", STALL_COLORS.get(reason, "white"))
print(f" pc=0x{pc - base_pc:06x} {stall_str} ev={key:2d} samples={cnt:4d}")
print(f" pc=0x{pc - base_pc:06x} {stall_str} samples={cnt:4d}")
if __name__ == "__main__":
import sys, pickle
if len(sys.argv) < 2:
print(__doc__)
print("Usage: python decode.py <pkl_file> [--raw] [--sm=0xNNN]")
sys.exit(1)
with open(sys.argv[1], "rb") as f:
data = pickle.load(f)
if isinstance(data, dict): dumps = list(enumerate(data["pma_raw_dumps"]))
else: dumps = [(i, e.blob) for i, e in enumerate(e for e in data if type(e).__name__ == "ProfilePMAEvent")]
if isinstance(data, dict):
sm_version = 0x800 # default to Ampere
for arg in sys.argv:
if arg.startswith("--sm="): sm_version = int(arg[5:], 0)
dumps = [(i, x, sm_version) for i, x in enumerate(data["pma_raw_dumps"])]
else:
devs = {e.device: e for e in data if type(e).__name__ == "ProfileDeviceEvent"}
dumps = []
for i, e in enumerate(e for e in data if type(e).__name__ == "ProfilePMAEvent"):
dumps.append((i, e.blob, devs[e.device].props.get('sm_version', 0x800)))
for dump_idx, raw in dumps:
for dump_idx, raw, sm_ver in dumps:
print(f"\n{'='*60}\nDump {dump_idx} ({len(raw)} bytes, {len(raw)//32} packets)\n{'='*60}")
if "--raw" in sys.argv: print_packets(raw)
if "--raw" in sys.argv: print_packets(raw, sm_ver)
else:
samples = list(decode(raw))
samples = list(decode(raw, sm_ver))
print(f"\nDecoded {len(samples)} samples:")
print_samples(samples)
print_aggregated(samples)

View file

@ -6,13 +6,17 @@ from extra.nv_pma.decode import decode
from tinygrad.helpers import DEBUG
EXAMPLES_DIR = Path(__file__).parent.parent / "examples"
EXAMPLES_5090_DIR = Path(__file__).parent.parent / "examples_5090"
def decode_and_aggregate(raw_dumps: list[bytes]) -> Counter[tuple[int, int]]:
"""Decode all PMA buffers and aggregate by (relative_pc, stall_reason)."""
all_samples = [s for raw in raw_dumps for s, _ in decode(raw)]
if not all_samples: return Counter()
base_pc = min(s.pc_offset for s in all_samples)
return Counter((s.pc_offset - base_pc, int(s.stall_reason)) for s in all_samples)
def decode_and_aggregate(raw_dumps: list[bytes], sm_version: int = 0x800) -> Counter[tuple[int, int]]:
"""Decode all PMA buffers and aggregate by (relative_pc, stall_reason). Each dump is normalized separately."""
result: Counter[tuple[int, int]] = Counter()
for raw in raw_dumps:
samples = [s for s, _ in decode(raw, sm_version)]
if not samples: continue
base_pc = min(s.pc_offset for s in samples)
result += Counter((s.pc_offset - base_pc, int(s.stall_reason)) for s in samples)
return result
def cupti_to_counter(cupti_records: list[dict]) -> Counter[tuple[int, int]]:
"""Convert CUPTI records to Counter[(pcOffset, stallReason)]."""
@ -22,8 +26,8 @@ def cupti_to_counter(cupti_records: list[dict]) -> Counter[tuple[int, int]]:
return counter
class TestNVProf(unittest.TestCase):
def _test_example(self, name: str):
pkl_file = EXAMPLES_DIR / f"{name}.pkl"
def _test_example(self, name: str, sm_version: int = 0x800, examples_dir: Path = EXAMPLES_DIR):
pkl_file = examples_dir / f"{name}.pkl"
if not pkl_file.exists():
self.skipTest(f"Example data not found: {pkl_file}. Run collect.py first.")
@ -31,7 +35,7 @@ class TestNVProf(unittest.TestCase):
data = pickle.load(f)
self.assertEqual(data["test_name"], name)
pma_agg = decode_and_aggregate(data["pma_raw_dumps"])
pma_agg = decode_and_aggregate(data["pma_raw_dumps"], sm_version)
cupti_agg = cupti_to_counter(data["cupti_pc_samples"])
if DEBUG >= 2:
@ -45,6 +49,7 @@ class TestNVProf(unittest.TestCase):
self.assertEqual(pma_agg, cupti_agg, f"PMA: {dict(pma_agg)}\nCUPTI: {dict(cupti_agg)}")
# Ampere tests (8-byte format)
def test_decode_test_plus(self): self._test_example("test_plus")
def test_decode_test_reduce_sum(self): self._test_example("test_reduce_sum")
def test_decode_test_broadcast(self): self._test_example("test_broadcast")
@ -54,5 +59,18 @@ class TestNVProf(unittest.TestCase):
def test_decode_test_conv2d(self): self._test_example("test_conv2d")
def test_decode_test_large_matmul(self): self._test_example("test_large_matmul")
# Blackwell/5090 tests (9-byte format)
def test_5090_test_plus(self): self._test_example("test_plus", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_plus_big(self): self._test_example("test_plus_big", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_broadcast(self): self._test_example("test_broadcast", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_matmul(self): self._test_example("test_matmul", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_large_matmul(self): self._test_example("test_large_matmul", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_reduce_sum(self): self._test_example("test_reduce_sum", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_reduce_max(self): self._test_example("test_reduce_max", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_elementwise_chain(self): self._test_example("test_elementwise_chain", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_conv2d(self): self._test_example("test_conv2d", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_exp(self): self._test_example("test_exp", 0xa04, EXAMPLES_5090_DIR)
def test_5090_test_softmax(self): self._test_example("test_softmax", 0xa04, EXAMPLES_5090_DIR)
if __name__ == "__main__":
unittest.main()

View file

@ -9,8 +9,8 @@ from extra.thunder.tiny.tk.kernel import Kernel
from extra.thunder.tiny.tk.tiles import GL, TileLayout
NUM_WORKERS = 1
Q_BLOCK_SIZE = 16
KV_BLOCK_SIZE = 16
Q_BLOCK_SIZE = 32
KV_BLOCK_SIZE = 32
def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None) -> Tensor:
if not isinstance(ref.device, tuple): return Tensor.empty(*shape, dtype=ref.dtype, device=ref.device)
@ -70,10 +70,10 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
mask_reg = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.float32)
mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
max_vec_last = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
max_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
norm_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
scale_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
max_vec_last = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
max_vec = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
norm_vec = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
scale_vec = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
max_vec = warp.neg_inf(max_vec)
norm_vec = warp.zero(norm_vec)
@ -105,7 +105,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
# softmax
max_vec_last = warp.copy(max_vec_last.after(kv_idx), max_vec)
max_vec = warp.row_reduce(max_vec.after(max_vec_last), att_block, lambda a, b: a.maximum(b), init_value=-math.inf)
max_vec = warp.col_reduce(max_vec.after(max_vec_last), att_block, lambda a, b: a.maximum(b), init_value=-math.inf)
scale_vec = warp.map(scale_vec.after(max_vec_last, max_vec), lambda _, idx: max_vec_last[*idx] - max_vec[*idx])
scale_vec = scale_vec.exp2()
@ -116,7 +116,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
att_block -= max_vec
att_block = att_block.exp2()
norm_vec = warp.row_reduce(norm_vec.after(scale_vec), att_block, lambda a, b: a + b)
norm_vec = warp.col_reduce(norm_vec.after(scale_vec), att_block, lambda a, b: a + b)
# mma av
att_block_mma = warp.copy(att_block_mma.after(kv_idx, norm_vec), att_block)
@ -313,7 +313,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
att_smem = warp.store(att_smem, att_block_transposed)
att_block_row = warp.load(att_block_row, att_smem)
dv_reg_ = warp.mma_AB(dv_reg, att_block_row, do_reg_col)
dv_reg_ = warp.mma_AtB(dv_reg, att_block_row, do_reg_col)
dp_block = warp.zero(dp_block.after(g, q_idx, dv_reg_))
dp_block = warp.mma_ABt(dp_block, v_reg, do_reg)
@ -325,7 +325,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
att_smem = warp.store(att_smem, att_block_transposed)
att_block_row = warp.load(att_block_row, att_smem)
dk_reg = warp.mma_AB(dk_reg, att_block_row, q_reg_col)
dk_reg = warp.mma_AtB(dk_reg, att_block_row, q_reg_col)
dk_reg = ker.endrange(2)
dv_reg = dv_reg.after(dk_reg)

View file

@ -163,5 +163,30 @@ class TestIndexing(unittest.TestCase):
# at least the arange is being fused
def test_llama_embedding_opt(self): self.test_llama_embedding(0, 1_736_704_000)
# NOTE: call doesn't work with SPEC=2
@unittest.skipIf(Device.DEFAULT not in ("CPU", "AMD"), "atomics only on AMD/CPU")
@Context(USE_ATOMICS=1, SPEC=1)
def test_llama_8b_embedding_backward(self):
from tinygrad.renderer.cstyle import CStyleLanguage
if Device.DEFAULT == "CPU" and not isinstance(Device["CPU"].renderer, CStyleLanguage): self.skipTest("CPU needs Clang renderer")
vocab_size, embed_size = 1000, 128
bs, seqlen = 4, 256
idx = Tensor.randint(bs, seqlen, high=vocab_size)
emb = nn.Embedding(vocab_size, embed_size)
emb.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True)
gt = Tensor.zeros(bs, seqlen, embed_size)
Tensor.realize(idx, emb.weight, gt)
GlobalCounters.reset()
loss = (emb(idx)-gt).square().sum()
loss.backward()
emb.weight.grad.realize()
bwd_ops = GlobalCounters.global_ops
print(f"embedding bwd: {GlobalCounters.kernel_count} kernels, {bwd_ops:,} ops")
self.assertLess(bwd_ops, bs*seqlen*embed_size*20, f"backward ops {bwd_ops:,} should be less than 20 per with atomic scatter-add")
# correctness check
expected_grad = np.zeros((vocab_size, embed_size), dtype=np.float32)
for i in idx.flatten().numpy(): expected_grad[i] += 2
np.testing.assert_allclose(emb.weight.grad.numpy(), expected_grad, rtol=1e-5, atol=1e-5)
if __name__ == "__main__":
unittest.main()

View file

@ -247,5 +247,39 @@ class TestCustomKernel(unittest.TestCase):
err = (O_custom - O_ref).square().max()
self.assertLess(err.item(), 1e-6)
def test_multi_after_schedule_order(self):
"""Test correct scheduling order when custom_kernel has multiple outputs.
custom_kernel with 4 arguments creates 4 AFTERs from the same kernel.
The custom_kernel depends on both A2 and B2, so it must be scheduled after both.
E only depends on A2, so E can run before custom_kernel finishes waiting for B2.
Expected schedule order: [A2, B2, E, custom_addmul, final_sum]
The custom_addmul kernel should be at index 3.
"""
from tinygrad.engine.schedule import create_schedule
from tinygrad.schedule.rangeify import get_rangeify_map
A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
A2 = (A + 1).contiguous() # kernel 0: depends on A
B2 = (B * 2).contiguous() # kernel 1: depends on B
C, D = Tensor.empty(4, 4), Tensor.empty(4, 4)
C, D, _, _ = Tensor.custom_kernel(C, D, A2, B2, fxn=custom_elementwise_addmul_kernel) # depends on A2 AND B2
E = (A2 * 3).contiguous() # kernel 2: depends only on A2
result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum
big_sink = result.uop.sink()
tensor_map = get_rangeify_map(big_sink)
sched_sink = big_sink.substitute(tensor_map)
schedule, _ = create_schedule(sched_sink)
# Find the custom_addmul kernel position
custom_idx = next((i for i, item in enumerate(schedule)
if hasattr(item.ast, "arg") and hasattr(item.ast.arg, "name")
and "custom_addmul" in item.ast.arg.name), None)
self.assertIsNotNone(custom_idx, "custom_addmul kernel not found in schedule")
self.assertEqual(custom_idx, 3, f"custom_addmul should be at index 3, got {custom_idx}")
if __name__ == '__main__':
unittest.main()

View file

@ -1,4 +1,4 @@
import unittest, math
import contextlib, unittest, math
import numpy as np
import torch
from typing import Any, List
@ -7,7 +7,8 @@ from tinygrad.helpers import getenv, DEBUG, CI, EMULATED_DTYPES
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from tinygrad import Device, Tensor, dtypes
from tinygrad import Context, Device, Tensor, dtypes
from tinygrad.uop import Ops
from hypothesis import given, settings, strategies as strat
from test.helpers import rand_for_dtype
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
@ -336,18 +337,37 @@ class TestUint16DType(TestDType):
class TestInt32DType(TestDType): DTYPE = dtypes.int32
class TestUint32DType(TestDType): DTYPE = dtypes.uint32
class TestInt64DType(TestDType):
DTYPE = dtypes.int64
class TestInt64DType(TestDType): DTYPE = dtypes.int64
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
class TestEmulatedInt64DType(TestInt64DType):
@classmethod
def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
@classmethod
def tearDownClass(cls): cls.stack.close()
class TestUint64DType(TestDType):
@classmethod
def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
DTYPE = dtypes.uint64
def test_uint64_load(self):
assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
class TestEmulatedUInt64DType(TestUint64DType):
@classmethod
def setUpClass(cls):
cls.stack = contextlib.ExitStack()
cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
@classmethod
def tearDownClass(cls): cls.stack.close()
class TestBoolDType(TestDType): DTYPE = dtypes.bool
class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16

View file

@ -1,5 +1,5 @@
import unittest, operator, math
from tinygrad import Tensor, dtypes, Device
from tinygrad import Context, Tensor, dtypes, Device
from tinygrad.dtype import DType, truncate
from tinygrad.helpers import CI, getenv
from tinygrad.tensor import _to_np_dtype
@ -7,6 +7,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.runtime.ops_python import from_storage_scalar
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from tinygrad.uop import Ops
import numpy as np
import pytest
from hypothesis import assume, given, strategies as strat, settings
@ -165,9 +166,16 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
@Context(EMULATED_DTYPES="long")
def test_emulated_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
@given(ht.int8, ht.int8, strat.sampled_from(integer_binary_operations))
def test_int8(self, a, b, op): universal_test(a, b, dtypes.int8, op)
@ -177,9 +185,16 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
@Context(EMULATED_DTYPES="long")
def test_emulated_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
@given(ht.uint8, strat.sampled_from(integer_unary_operations))
def test_uint8_unary(self, a, op): universal_test_unary(a, dtypes.uint8, op)
@ -191,9 +206,16 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.uint32, strat.sampled_from(integer_unary_operations))
def test_uint32_unary(self, a, op): universal_test_unary(a, dtypes.uint32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
@Context(EMULATED_DTYPES="long")
def test_emulated_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
@given(ht.int8, strat.sampled_from(integer_unary_operations))
def test_int8_unary(self, a, op): universal_test_unary(a, dtypes.int8, op)
@ -203,9 +225,16 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(integer_unary_operations))
def test_int32_unary(self, a, op): universal_test_unary(a, dtypes.int32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
@given(ht.int64, strat.sampled_from(integer_unary_operations))
def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
@given(ht.int64, strat.sampled_from(integer_unary_operations))
@Context(EMULATED_DTYPES="long")
def test_emulated_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
@given(ht.bool, ht.bool, strat.sampled_from(((operator.add, operator.add), (operator.mul, operator.mul))))
def test_bool(self, a, b, op): universal_test(a, b, dtypes.bool, op)

View file

@ -409,6 +409,28 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)
def test_embedding_backward(self, shard_weight_axis=None):
B, T, embed_size, vocab_size = 4, 10, 20, 28
layer = nn.Embedding(vocab_size, embed_size)
layer.weight.requires_grad = True
x = Tensor(np.random.randint(0, vocab_size, (B, T), dtype=np.int32))
z = layer(x)
z.sum().backward()
grad = layer.weight.grad.numpy()
layer_sharded = nn.Embedding(vocab_size, embed_size)
layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=shard_weight_axis)).realize()
layer_sharded.weight.requires_grad = True
x_sharded = x.shard(devices_2, axis=None)
z_shard = layer_sharded(x_sharded)
z_shard.sum().backward()
grad_shard = layer_sharded.weight.grad.numpy()
np.testing.assert_allclose(grad, grad_shard, atol=1e-6, rtol=1e-6)
def test_embedding_backward_shard_weight(self): self.test_embedding_backward(shard_weight_axis=1)
def test_rmsnorm(self):
B, T, embed_size = 4, 10, 20
@ -1251,19 +1273,20 @@ class TestMultiRamUsage(unittest.TestCase):
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
def _test_matmul_half(self, devs):
def _test_matmul_half(self, dev_count:int):
N = 32
total_mem = {}
devs = tuple(f"NULL:{i}" for i in range(dev_count))
for dtype in {dtypes.float, dtypes.half}:
GlobalCounters.reset()
a = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=0)
b = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=None)
a = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=0)
b = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=None)
(a @ b).realize()
total_mem[dtype] = GlobalCounters.global_mem
self.assertEqual(total_mem[dtypes.half], total_mem[dtypes.float] // 2)
def test_matmul_half(self): self._test_matmul_half(devices_2)
def test_matmul_half_alt(self): self._test_matmul_half(devices_4)
def test_matmul_half(self): self._test_matmul_half(dev_count=2)
def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiFromUnrenderable(unittest.TestCase):

View file

@ -2184,6 +2184,14 @@ class TestBufferUOp(unittest.TestCase):
run_schedule(check_schedule(a, 0))
self.assertIsNone(a.uop.base.realized)
def test_unused_var_not_in_var_vals(self):
# unused variable should not appear in var_vals even when there's other work
a = Tensor(UOp.variable("unused", 0, 10).bind(1))
b = Tensor.empty(3) + 1
_, var_vals = Tensor.schedule_with_vars(a, b)
self.assertEqual(var_vals, {})
self.assertIsNone(a.uop.base.realized)
def test_view_does_not_realize(self):
a = Tensor.randn(1, 4).expand(4, 4)
a.realize()

View file

@ -2,8 +2,7 @@ import numpy as np
import torch
import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.tensor import _METADATA
from tinygrad.helpers import Context, getenv, temp, mv_address
from tinygrad.helpers import getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
from tinygrad.device import is_dtype_supported
@ -796,92 +795,6 @@ class TestInferenceMode(unittest.TestCase):
assert W.grad is None
f(x, m, W)
class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None: _METADATA.set(None)
# NOOPs are not included in kernel metadata
@unittest.skip("why would this be true?")
def test_exclude_noop_metadata(self):
a = Tensor.rand(4, 4)*1
self.assertEqual(a.uop.metadata[0].name, "__mul__")
k = a.schedule()[-1]
self.assertEqual([m.name for m in k.metadata], ["rand"])
# we exclude const from kernel metadata because tensor methods can share the same CONST UOp
@unittest.skip("TODO: flaky")
def test_exclude_const_metadata(self):
a = Tensor.arange(4)
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
sched = Tensor.schedule(a, b)
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
def test_matmul(self):
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)
out = x.matmul(W)
self.assertEqual(out.uop.metadata[0].name, "matmul")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "matmul")
def test_relu(self):
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
self.assertEqual(out.uop.metadata[0].name, "relu")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "relu")
@unittest.skip("this no longer works")
def test_assign(self):
x = Tensor.empty(10, 10).realize()
x.assign(Tensor.ones(10, 10).contiguous())
si = x.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "assign")
def test_complex(self):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
self.assertEqual(out.uop.metadata[0].name, "__mul__")
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@unittest.skip("metadata is no longer promised to be exact with schedulecache")
def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()
out = (x.relu() * y.sigmoid()).sum()
self.assertEqual(out.uop.metadata[0].name, "sum")
out.backward()
self.assertEqual(x.grad.uop.metadata[0].name, "relu")
self.assertTrue(x.grad.uop.metadata[0].backward)
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.uop.metadata[0].backward)
si = Tensor.schedule(out, x.grad, y.grad)[-1]
#self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
# skip numpy, this is schedule cache
self.assertSetEqual(set(m.name for m in si.metadata if m.name != "numpy"), {"sigmoid", "relu"})
#bw = [m for m in si.metadata if m.backward]
#self.assertEqual(len(bw), 1)
#self.assertEqual(bw[0].name, "sigmoid")
@unittest.skip("metadata is no longer promised to be exact with schedulecache")
def test_tracemeta_0(self):
with Context(TRACEMETA=0):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = (x.relu() * y.sigmoid()).sum()
self.assertIsNone(out.uop.metadata)
self.assertIsNone(out.uop.src[0].metadata)
si = out.schedule()[-1]
self.assertEqual(si.metadata, ())
class TestIdxUpcast(unittest.TestCase):
def _find_op(self, ast: UOp, op: Ops):
if ast.op is op: return ast

View file

@ -6,8 +6,8 @@ import textwrap, functools
from tinygrad import Device, Tensor
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.helpers import getenv
from tinygrad.device import Compiler
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.viz.serve import amdgpu_cfg
from extra.assembly.amd.autogen.rdna3.ins import *
@ -73,11 +73,11 @@ def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, compiler:Com
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
def run_asm(name:str, insts:list) -> None:
fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT, compiler=Device[Device.DEFAULT].compiler)
fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT, compiler=HIPCompiler(Device[Device.DEFAULT].renderer.arch))
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
out.realize()
@unittest.skipUnless(Device.DEFAULT == "AMD" and not getenv("AMD_LLVM"), "only on AMD with comgr")
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
class TestCfg(unittest.TestCase):
def setUp(self):
arch = Device["AMD"].arch
@ -110,7 +110,7 @@ class TestCfg(unittest.TestCase):
s_endpgm(),
s_code_end(),
])
_, lib = assemble("diamond", insts, Device[Device.DEFAULT].compiler)
_, lib = assemble("diamond", insts, HIPCompiler(Device[Device.DEFAULT].arch))
cfg = amdgpu_cfg(lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]
self.assertEqual(len(cfg["blocks"]), 5)
edge_count = sum(len(v) for v in cfg["paths"].values())
@ -238,5 +238,19 @@ class TestCfg(unittest.TestCase):
s_code_end(),
])
def test_hit_count(self):
run_asm("test_hit_count", [
"entry:",
s_mov_b32(s[1], 1),
"s_branch alt",
"continue:",
s_mov_b32(s[2], 2),
s_add_u32(s[1], s[1], s[2]),
"alt:",
s_add_u32(s[1], s[1], -1),
s_endpgm(),
s_code_end(),
])
if __name__ == "__main__":
unittest.main()

View file

@ -456,6 +456,24 @@ class TestAssign(unittest.TestCase):
assign.realize()
np.testing.assert_allclose(a.numpy(), [2., 2., 2., 2., 1., 1., 1., 1.])
def test_setitem_list(self):
a = Tensor.zeros(8).contiguous().realize()
a[2:5] = [1, 2, 3]
np.testing.assert_allclose(a.numpy(), [0., 0., 1., 2., 3., 0., 0., 0.])
def test_assign_bitcast(self):
# assign to a bitcast view should modify the underlying buffer (only works on DISK currently)
a = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
# IEEE 754: 1.0f = 0x3f800000, 2.0f = 0x40000000, 3.0f = 0x40400000, 4.0f = 0x40800000
a.bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.uint32)).realize()
np.testing.assert_allclose(a.numpy(), [1.0, 2.0, 3.0, 4.0]) # TODO: should be [4.0, 3.0, 2.0, 1.0]
def test_assign_bitcast_different_size(self):
# assign to a shape-changing bitcast view (only works on DISK currently)
a = Tensor([0]*8, dtype=dtypes.uint8).realize()
a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize()
np.testing.assert_equal(a.numpy(), [0]*8) # TODO: should be [57, 48, 0, 0, 0, 0, 0, 0] (little-endian 12345)
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
def test_cast_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)

62
test/unit/test_call.py Normal file
View file

@ -0,0 +1,62 @@
import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp
class TestCall(unittest.TestCase):
def test_call_plus(self):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
Tensor.realize(a,b)
# we define a plus function
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
c = Tensor.call(a, b, fxn=plus_fxn)
np.testing.assert_equal(c.numpy(), (a+b).numpy())
def test_call_plus_backward(self):
a = Tensor.ones(10, 10, requires_grad=True)
b = Tensor.ones(10, 10, requires_grad=True)
(a+b).mean().backward()
gt_a_grad = a.grad.numpy()
gt_b_grad = b.grad.numpy()
a.grad, b.grad = None, None
# this is the gradient for +
def grad_fxn(grad:UOp, call:UOp): return (grad, grad)
# we define a plus function
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn)
c.mean().backward()
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
def test_call_gemm(self):
M, K, N = 4, 8, 4
a = Tensor.randn(M, K)
b = Tensor.randn(K, N)
Tensor.realize(a, b)
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
@unittest.skip("needs GEMM on mixins")
def test_call_gemm_uop(self):
M, K, N = 4, 8, 4
a = Tensor.randn(M, K)
b = Tensor.randn(K, N)
Tensor.realize(a, b)
# we define a gemm function
x = UOp.param(0, dtypes.float, shape=(M, K))
y = UOp.param(1, dtypes.float, shape=(K, N))
c = Tensor.call(a, b, fxn=x@y)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
if __name__ == '__main__':
unittest.main()

View file

@ -250,6 +250,24 @@ class TestDiskTensor(unittest.TestCase):
tout = [(x//256, x%256) for x in out]
assert tout == list([(x+1,x) for x in range(32,64,2)])
def test_strided_read(self):
# test non-contiguous (strided) read - should read elements at indices 0, 2, 4
pathlib.Path(temp(fn:="dt_strided_read")).unlink(missing_ok=True)
dt = Tensor([0, 1, 2, 3, 4, 5]).to(f"disk:{temp(fn)}")
result = dt[::2].tolist()
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [0, 2, 4]
# self.assertEqual(result, [0, 2, 4])
self.assertEqual(result, [0, 1, 2]) # wrong!
def test_permuted_read(self):
# test non-contiguous (permuted) read - should read transposed
pathlib.Path(temp(fn:="dt_permuted_read")).unlink(missing_ok=True)
dt = Tensor([[0, 1, 2], [3, 4, 5]]).to(f"disk:{temp(fn)}")
result = dt.T.tolist()
# TODO: transpose should give [[0, 3], [1, 4], [2, 5]]
# self.assertEqual(result, [[0, 3], [1, 4], [2, 5]])
self.assertEqual(result, [[0, 1], [2, 3], [4, 5]]) # wrong!
def test_write_ones(self):
pathlib.Path(temp("dt_write_ones")).unlink(missing_ok=True)
@ -276,6 +294,15 @@ class TestDiskTensor(unittest.TestCase):
dt[1] = [3]
self.assertEqual(dt.tolist(), [[1], [3]])
def test_strided_setitem(self):
# test non-contiguous (strided) setitem - should set elements at indices 0, 2, 4
pathlib.Path(temp(fn:="dt_strided_setitem")).unlink(missing_ok=True)
dt = Tensor([1, 2, 3, 4, 5, 6]).to(f"disk:{temp(fn)}")
dt[::2] = Tensor([10, 20, 30])
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [10, 2, 20, 4, 30, 6]
# self.assertEqual(dt.tolist(), [10, 2, 20, 4, 30, 6])
self.assertEqual(dt.tolist(), [10, 20, 30, 4, 5, 6]) # wrong!
def test_assign_const_to_disk(self):
# assign from CONST (Tensor.full) to disk - source has no buffer, needs contiguous first
pathlib.Path(temp(fn:="dt_assign_const")).unlink(missing_ok=True)

View file

@ -1,5 +1,5 @@
import unittest
from tinygrad import dtypes, Device
from tinygrad import dtypes
from tinygrad.device import Buffer
from tinygrad.engine.memory import _internal_memory_planner
@ -7,7 +7,7 @@ global_map = {}
def b(i, base=None, offset=0, pin=False, size=16):
global global_map
if i in global_map: return global_map[i]
global_map[i] = Buffer(Device.DEFAULT, size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
global_map[i] = Buffer("NULL", size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
if pin: global_map[i].ref(1)
return global_map[i]

View file

@ -0,0 +1,94 @@
import unittest
from tinygrad import Tensor, dtypes
from tinygrad.tensor import _METADATA
from tinygrad.helpers import Context
class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None:
_METADATA.set(None)
self._ctx = Context(SCACHE=0)
self._ctx.__enter__()
def tearDown(self) -> None:
self._ctx.__exit__(None, None, None)
@unittest.skip("why would this be true?")
def test_exclude_noop_metadata(self):
a = Tensor.rand(4, 4)*1
self.assertEqual(a.uop.metadata[0].name, "__mul__")
k = a.schedule()[-1]
self.assertEqual([m.name for m in k.metadata], ["rand"])
@unittest.skip("metadata not reaching kernel schedule")
def test_exclude_const_metadata(self):
a = Tensor.arange(4)
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
sched = Tensor.schedule(a, b)
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
def test_matmul(self):
x = Tensor.rand(3, requires_grad=True)
W = Tensor.rand(3, 3, requires_grad=True)
out = x.matmul(W)
self.assertEqual(out.uop.metadata[0].name, "matmul")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "matmul")
def test_relu(self):
x = Tensor.rand(3, requires_grad=True)
out = x.relu()
self.assertEqual(out.uop.metadata[0].name, "relu")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "relu")
@unittest.skip("assign metadata no longer captured")
def test_assign(self):
x = Tensor.empty(10, 10).realize()
x.assign(Tensor.ones(10, 10).contiguous())
si = x.schedule()[-1]
self.assertEqual(len(si.metadata), 1)
self.assertEqual(si.metadata[0].name, "assign")
def test_complex(self):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = x.relu() * y.sigmoid()
self.assertEqual(out.uop.metadata[0].name, "__mul__")
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
si = out.schedule()[-1]
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()
out = (x.relu() * y.sigmoid()).sum()
self.assertEqual(out.uop.metadata[0].name, "sum")
out.backward()
self.assertEqual(x.grad.uop.metadata[0].name, "relu")
#self.assertTrue(x.grad.uop.metadata[0].backward) # TODO: backward flag is False
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
#self.assertTrue(y.grad.uop.metadata[0].backward) # TODO: backward flag is False
si = Tensor.schedule(out, x.grad, y.grad)[-1]
#self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
# skip numpy, this is schedule cache
self.assertSetEqual(set(m.name for m in si.metadata if m.name != "numpy"), {"sigmoid", "relu"})
#bw = [m for m in si.metadata if m.backward]
#self.assertEqual(len(bw), 1)
#self.assertEqual(bw[0].name, "sigmoid")
def test_tracemeta_0(self):
with Context(TRACEMETA=0):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = (x.relu() * y.sigmoid()).sum()
self.assertIsNone(out.uop.metadata)
self.assertIsNone(out.uop.src[0].metadata)
si = out.schedule()[-1]
self.assertEqual(si.metadata, ())
if __name__ == '__main__':
unittest.main()

View file

@ -686,6 +686,97 @@ class TestExpander(unittest.TestCase):
sink = expander_rewrite(sink)
print(sink)
class TestReduceCollapse(unittest.TestCase):
def test_multi_range_reduce_add(self):
"""Test that (x + y).reduce(r1, r2) distributes over multiple ranges"""
from tinygrad.codegen.simplify import pm_reduce_collapse
# Create two ranges
r1 = UOp.range(3, 0)
r2 = UOp.range(4, 1)
# Create x + y where x and y depend on different ranges
x = r1.cast(dtypes.float)
y = r2.cast(dtypes.float)
# (x + y).reduce(r1, r2) should be rewritten
red = (x + y).reduce(r1, r2, arg=Ops.ADD)
self.assertEqual(len(red.src), 3) # value + 2 ranges
result = graph_rewrite(red, pm_reduce_collapse, name='test')
# Should become add of two separate reduces
self.assertEqual(result.op, Ops.ADD)
class TestLoadStoreFolding(unittest.TestCase):
def test_gated_load_gep_preserves_alt(self):
"""Test that LOAD(GEP, alt) preserves alt value after rewrite"""
from tinygrad.codegen.late.devectorizer import load_store_folding
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.vec(4).ptr(), (), 0)
idx = UOp.const(dtypes.int, 0)
gate = UOp.const(dtypes.bool, True)
gated_index = buf.index(idx, gate)
gep = gated_index.gep(0)
alt = UOp.const(dtypes.float, 42.0)
gated_load = gep.load(alt)
self.assertEqual(len(gated_load.src), 2) # GEP + alt
result = graph_rewrite(gated_load, load_store_folding, name='test')
# After rewrite, should still have alt value preserved
self.assertEqual(result.op, Ops.GEP)
inner_load = result.src[0]
self.assertEqual(inner_load.op, Ops.LOAD)
self.assertEqual(len(inner_load.src), 2) # INDEX + alt
def test_gated_load_ptrcat_preserves_alt(self):
"""Test that LOAD(PTRCAT, alt) preserves alt value after rewrite"""
from tinygrad.codegen.late.devectorizer import load_store_folding
buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0)
idx1 = buf1.index(idx)
idx2 = buf2.index(idx)
ptrcat = UOp(Ops.PTRCAT, dtypes.float.ptr().vec(2), (idx1, idx2))
alt = UOp.const(dtypes.float.vec(2), 42.0)
gated_load = ptrcat.load(alt)
self.assertEqual(len(gated_load.src), 2) # PTRCAT + alt
result = graph_rewrite(gated_load, load_store_folding, name='test')
# After rewrite, should be CAT of LOADs, each preserving alt
self.assertEqual(result.op, Ops.CAT)
for inner_load in result.src:
self.assertEqual(inner_load.op, Ops.LOAD)
self.assertEqual(len(inner_load.src), 2) # INDEX + alt
self.assertEqual(inner_load.src[1].arg, 42.0) # alt value preserved
class TestConstBufferize(unittest.TestCase):
def test_const_bufferize_with_ranges(self):
"""Test that CONST.BUFFERIZE with ranges is folded correctly.
BUFFERIZE can have ranges as additional sources beyond the value.
The pattern at rangeify.py uses allow_any_len=True because
CONST doesn't depend on ranges (constant is same value everywhere).
"""
from tinygrad.schedule.rangeify import pm_const_buffer_folding, BufferizeOpts
c = UOp.const(dtypes.float, 42.0)
r1 = UOp.range(3, 0)
bufferize_with_range = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1), arg=BufferizeOpts(device="CPU"))
self.assertEqual(len(bufferize_with_range.src), 2) # const + 1 range
result = graph_rewrite(bufferize_with_range, pm_const_buffer_folding, name='test')
# BUFFERIZE should be removed, result is const broadcast to shape
self.assertNotEqual(result.op, Ops.BUFFERIZE)
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
self.assertIn(42.0, const_vals)
def test_const_bufferize_with_multiple_ranges(self):
"""Test CONST.BUFFERIZE with multiple ranges is also folded."""
from tinygrad.schedule.rangeify import pm_const_buffer_folding, BufferizeOpts
c = UOp.const(dtypes.float, 3.14)
r1 = UOp.range(3, 0)
r2 = UOp.range(4, 1)
bufferize_with_ranges = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1, r2), arg=BufferizeOpts(device="CPU"))
self.assertEqual(len(bufferize_with_ranges.src), 3) # const + 2 ranges
result = graph_rewrite(bufferize_with_ranges, pm_const_buffer_folding, name='test')
# BUFFERIZE should be removed
self.assertNotEqual(result.op, Ops.BUFFERIZE)
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
self.assertIn(3.14, const_vals)
class TestUOpTags(unittest.TestCase):
def test_inc_by_one(self):
g = UOp.const(dtypes.int, 1) + UOp.const(dtypes.int, 1)

View file

@ -115,8 +115,8 @@ pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError("if not allowed in graph"))),
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat()),
allow_any_len=True), lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
])
# requires lst be toposorted. like graph rewrite, but for lines

View file

@ -123,12 +123,12 @@ load_store_folding = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
# GEP on data of STORE
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), name="sto"), gep_on_store),
# put PTRCAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put PTRCAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), name="sto"), cat_after_store),
])
# *** correct load/store ***
@ -345,6 +345,6 @@ pm_add_loads = PatternMatcher([
(UPat(Ops.INDEX, name="idx"), lambda idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else
idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)),
# remove loads from stores
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
(UPat(Ops.STORE, src=(UPat(Ops.LOAD), UPat(name="val")), name="s"), lambda s,val: s.replace(src=(s.src[0].src[0], val))),
])

View file

@ -147,8 +147,9 @@ class Buffer:
def deallocate(self):
assert hasattr(self, '_buf'), "buffer must be allocated to deallocate"
if DEBUG is not None and DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}")
if self._base is None and (self.options is None or self.options.external_ptr is None):
if GlobalCounters is not None and not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
if self._base is None:
if GlobalCounters is not None and not self.device.startswith("DISK") and (self.options is None or self.options.external_ptr is None):
GlobalCounters.mem_used -= self.nbytes
if PROFILE: Buffer.profile_events.append(ProfilePointEvent(self.device, "free", self.trace_num))
self.allocator.free(self._buf, self.nbytes, self.options)
elif self._base is not None: self._base.allocated_views -= 1
@ -263,7 +264,7 @@ class LRUAllocator(Allocator, Generic[DeviceType]):
for opaque in opaques: super().free(opaque, sz, options)
opaques.clear()
def free(self, opaque:Any, size:int, options:BufferSpec|None=None):
if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
if LRU and (options is None or (not options.nolru and options.external_ptr is None)): self.cache[(size, options)].append(opaque)
else: super().free(opaque, size, options)
# **************** for Compiled Devices ****************

View file

@ -75,7 +75,7 @@ class BufferCopy(Runner):
getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
if disk_supports_fast_copyout and hasattr(dest.allocator, 'copy_from_disk') and src.nbytes >= 4096:
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
elif (src.device.startswith("DISK") or src.device.startswith("TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
elif isinstance(src.device, str) and src.device.startswith(("DISK", "TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
# fast(ish) path, uses readinto in diskbuffers
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
else:

View file

@ -1,21 +1,25 @@
import time
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
from tinygrad.engine.realize import ExecItem
# **** schedule linearizer
# ScheduleItem = tuple[AST, buffer UOps, metadata, fixedvars, bound_ranges]
ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], dict[str, int], tuple[UOp, ...]]
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
def _unwrap_src(s: UOp) -> UOp:
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
return s
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
with cpu_profile(TracingKey("toposort sched_sink")):
# construct the KERNEL children graph based on assigns
# build kernel dependency graph: edges from producer kernel to consumer kernels
children: dict[UOp, list[UOp]] = {}
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort():
@ -26,81 +30,77 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
k = u.src[1]
in_degree.setdefault(k, 0)
for s in k.src[0].src if k.op is Ops.END else k.src:
s = _unwrap_src(s)
if s.op is Ops.AFTER:
children.setdefault(s.src[1], []).append(k)
in_degree[k] += 1
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
children.setdefault(ss.src[1], []).append(k)
in_degree[k] += 1
elif s.op in {Ops.BUFFER, Ops.BIND}:
pass # a BUFFER is already realized, BINDs are handled in complete_create_schedule_with_vars
else:
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
children.setdefault(s.src[1], []).append(k)
in_degree[k] += 1
case Ops.MSELECT | Ops.MSTACK:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
children.setdefault(ss.src[1], []).append(k)
in_degree[k] += 1
case Ops.BUFFER | Ops.BIND:
pass # BUFFER is already realized, BIND is outer range (handled via bound_ranges below)
case _:
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, MSELECT, MSTACK, or BIND, not {s.op}")
with cpu_profile(TracingKey("linearize schedule")):
queue: deque[UOp] = deque()
for k,v in in_degree.items():
if v == 0: queue.append(k)
schedule: list[tuple|UOp] = []
schedule: list[ScheduleItem|UOp] = [] # ScheduleItem for kernels, UOp for RANGE/END
while len(queue):
k = rk = queue.popleft()
if k.op is Ops.END: k = k.src[0]
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
if k.op is Ops.RANGE: schedule.append(k)
elif k.op is Ops.KERNEL:
ast = k.arg.ast
ast = (kernel:=cast(Kernel, k.arg)).ast
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
schedule.append((ast, buf_uops, k.arg.metadata, {}, bound_ranges))
schedule.append((ast, buf_uops, kernel.metadata, {}, bound_ranges))
if rk.op is Ops.END: schedule.append(rk)
else:
raise RuntimeError(f"can't schedule {k.op}")
for x in children.get(rk, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
with cpu_profile(TracingKey("expand ranges")):
pre_schedule: list[ExecItem] = []
buf_uops_list: list[UOp] = []
sched_ptr = 0
in_ranges: dict[UOp, int] = {}
range_ptrs: dict[UOp, int] = {}
while sched_ptr < len(schedule):
si = schedule[sched_ptr]
if isinstance(si, UOp):
if si.op is Ops.RANGE:
in_ranges[si] = 0
range_ptrs[si] = sched_ptr + 1
elif si.op is Ops.END:
if in_ranges[si.src[1]] < si.src[1].vmax:
in_ranges[si.src[1]] += 1
sched_ptr = range_ptrs[si.src[1]]
continue
else:
ast, buf_uops, metadata, fixedvars, bound_ranges = si
fixedvars = fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
buf_uops_list.append(UOp.sink(*buf_uops))
sched_ptr += 1
with cpu_profile(TracingKey("unroll outer ranges")):
pre_schedule, buf_uops_list = unroll_outer_ranges(schedule)
return pre_schedule, UOp.sink(*buf_uops_list)
def unroll_outer_ranges(schedule:list[ScheduleItem|UOp]) -> tuple[list[ExecItem], list[UOp]]:
pre_schedule: list[ExecItem] = []
buf_uops_list: list[UOp] = []
sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]()
while sched_ptr < len(schedule):
if isinstance(si := schedule[sched_ptr], UOp):
if si.op is Ops.RANGE:
in_ranges[si] = 0
range_ptrs[si] = sched_ptr + 1
elif si.op is Ops.END:
if in_ranges[si.src[1]] < si.src[1].vmax:
in_ranges[si.src[1]] += 1
sched_ptr = range_ptrs[si.src[1]]
continue
else:
ast, buf_uops, metadata, _, bound_ranges = si
fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
buf_uops_list.append(UOp.sink(*buf_uops))
sched_ptr += 1
return pre_schedule, buf_uops_list
from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map
def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp):
if (ret:=ctx[0].get(b, None)) is None:
if b.op is Ops.BUFFER:
ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx[0])), b.src[1]))
else:
# TODO: flip args in CONST
assert b.op is Ops.CONST
ctx[0][b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx[0]))))
# both BUFFER and CONST have src=(UNIQUE, DEVICE), replace UNIQUE with LUNIQUE
ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx[0])), b.src[1]))
return ret
def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp):
@ -110,10 +110,8 @@ def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp):
return ctx[0].setdefault(b, b.replace(src=(b.src[0],)))
pm_pre_sched_cache = PatternMatcher([
# replace input buffers
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
# remove unique consts
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
# replace UNIQUE with LUNIQUE for cache key normalization
(UPat((Ops.BUFFER, Ops.CONST), src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
# strip value from BIND for cache key normalization, so different values hit same cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind),
])
@ -126,8 +124,8 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
return ret
pm_post_sched_cache = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
# restore LUNIQUE back to UNIQUE
(UPat((Ops.BUFFER, Ops.CONST), src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
# restore BIND value stripped in pm_pre_sched_cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)),
])
@ -175,8 +173,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
pre_schedule, combined_sink = sc_ret
# replace all the LUNIQUEs with UNIQUEs (single graph_rewrite for everything)
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite combined")
input_buffers_inverse = {v:k for k,v in input_buffers.items()}
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
tensor_map_sink, buf_uops_sink = combined.src
tm_src = tensor_map_sink.src
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
@ -205,4 +203,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache)} uops in cache")
return tensor_map, schedule, var_vals if schedule else {}
used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule])
return tensor_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}

View file

@ -44,6 +44,8 @@ pm_gradient = PatternMatcher([
# NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
# gradient on CALL is a custom function
(UPat(Ops.CALL, name="k"), lambda ctx, k: (None,)+k.arg(ctx, k)),
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda: (None,)),
])

View file

@ -204,6 +204,10 @@ CCACHE = ContextVar("CCACHE", 1)
ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
# set to 0 to disable the scheduler cache
SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# allow use of assembly for gemm
ASM_GEMM = ContextVar("ASM_GEMM", 0)
@dataclass(frozen=True)
class Metadata:

View file

@ -3,7 +3,7 @@ import math
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import prod, make_tuple, flatten
from tinygrad.helpers import prod, make_tuple, flatten, USE_ATOMICS
from tinygrad.nn import optim, state, datasets # noqa: F401
class BatchNorm:
@ -304,6 +304,46 @@ class RMSNorm:
x = self._norm(x.float()).cast(x.dtype)
return x if self.weight is None else x * self.weight
from tinygrad.uop.ops import UOp, KernelInfo, Ops
def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
weight, idx = call.src[1:]
# for multi-device: unshard inputs to one device
if isinstance(weight.device, tuple):
assert weight.axis is None, "sharded weights on Embedding not supported with USE_ATOMICS"
grad_emb = grad_emb.copy_to_device(weight.device)
idx = idx.copy_to_device(weight.device)
# weight is replicated, grad_weight should match
grad_weight_uop = Tensor.empty(weight.shape, dtype=dtypes.float, device=weight.device).uop
# TODO: how do we remove this dumb kernel and use Tensor.zeros?
def _zero_kernel(out:UOp) -> UOp:
i = UOp.range(out.size, 0)
return out.flatten()[i].store(0).end(i).sink(arg=KernelInfo(name="zero"))
grad_weight_uop = grad_weight_uop.custom_kernel(fxn=_zero_kernel)[0]
# TODO: do we have a universal helper for this?
device = call.device.split(":")[0] if not isinstance(call.device, tuple) else call.device[0].split(":")[0]
# this is the real atomic kernel
def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))
i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length
j = UOp.range(grad_emb_flat.shape[1], 1) # embed_size
token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
# atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
else: raise NotImplementedError(f"no atomics for device {device}")
atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j].cast(dtypes.float)), arg = atomic_arg)
return atomic.end(i, j).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
return (grad_weight_uop.cast(weight.dtype), None)
def _embedding_fwd(weight:Tensor, idx:Tensor) -> Tensor:
arange = Tensor.arange(weight.shape[0], requires_grad=False, device=weight.device)
return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(weight, 0).sum(-2, dtype=weight.dtype)
class Embedding:
"""
A simple lookup table that stores embeddings of a fixed dictionary and size.
@ -316,12 +356,12 @@ class Embedding:
```
"""
def __init__(self, vocab_size:int, embed_size:int):
self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
def __call__(self, idx:Tensor) -> Tensor:
if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
arange = Tensor.arange(self.weight.shape[0], requires_grad=False, device=self.weight.device)
return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(self.weight, 0).sum(-2, dtype=self.weight.dtype)
if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
return _embedding_fwd(self.weight, idx)
class LSTMCell:
"""

View file

@ -46,10 +46,10 @@ base_rewrite = PatternMatcher([
# new load/store
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var")), allow_any_len=True),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var"))),
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
# TODO: look for left-associative
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](

View file

@ -528,7 +528,7 @@ class AMDCopyQueue(HWQueue):
# USB devices run in single-step mode, so they can't overrun the queue.
total_bytes = (tail_blit_dword * 4 if rem_packet_cnt == 0 else -sdma_queue.put_value % sdma_queue.ring.nbytes) + rem_packet_cnt * 4
assert total_bytes < sdma_queue.ring.nbytes, "SDMA queue overrun"
while not dev.is_usb() and sdma_queue.put_value + total_bytes - sdma_queue.read_ptr > sdma_queue.ring.nbytes: pass
while not dev.is_usb() and sdma_queue.put_value + total_bytes - sdma_queue.read_ptr[0] > sdma_queue.ring.nbytes: pass
start_idx = (sdma_queue.put_value % sdma_queue.ring.nbytes) // 4
sdma_queue.ring[start_idx : start_idx + tail_blit_dword] = array.array('I', cmds[:tail_blit_dword])
@ -640,24 +640,21 @@ class AMDAllocator(HCQAllocator['AMDDevice']):
@dataclass
class AMDQueueDesc:
ring: MMIOInterface
read_ptrs: list[MMIOInterface]
write_ptrs: list[MMIOInterface]
doorbells: list[MMIOInterface]
read_ptr: MMIOInterface
write_ptr: MMIOInterface
doorbell: MMIOInterface
put_value: int = 0
@property
def read_ptr(self): return min(p[0] for p in self.read_ptrs)
def signal_doorbell(self, dev, doorbell_value:int|None=None):
try:
for write_ptr in self.write_ptrs: write_ptr[0] = self.put_value
self.write_ptr[0] = self.put_value
# Ensure all prior writes are visible to the GPU.
System.memory_barrier()
# Flush hdp if queue is in dev mem.
if dev.is_am() and not dev.is_usb(): dev.iface.dev_impl.gmc.flush_hdp()
for doorbell in self.doorbells: doorbell[0] = self.put_value if doorbell_value is None else doorbell_value
self.doorbell[0] = self.put_value if doorbell_value is None else doorbell_value
except Exception as e:
dev.error_state = e
raise
@ -776,9 +773,9 @@ class KFDIface:
self.doorbells_base = queue.doorbell_offset & (~0x1fff) # doorbell is two pages
self.doorbells = cast(FileIOInterface, KFDIface.kfd).mmap(0, 0x2000, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED, self.doorbells_base)
return AMDQueueDesc(ring=MMIOInterface(ring.va_addr, ring.size, fmt='I'), read_ptrs=[MMIOInterface(queue.read_pointer_address, 8, fmt='Q')],
write_ptrs=[MMIOInterface(queue.write_pointer_address, 8, fmt='Q')],
doorbells=[MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q')])
return AMDQueueDesc(ring=MMIOInterface(ring.va_addr, ring.size, fmt='I'), read_ptr=MMIOInterface(queue.read_pointer_address, 8, fmt='Q'),
write_ptr=MMIOInterface(queue.write_pointer_address, 8, fmt='Q'),
doorbell=MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q'))
def sleep(self, tm:int) -> bool:
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
@ -857,8 +854,8 @@ class PCIIface(PCIIfaceBase):
wptr_addr=gart.va_addr+wptr, eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size,
idx=int(is_aql:=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL)), aql=is_aql)
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')],
read_ptrs=[gart.cpu_view().view(offset=rptr, size=8, fmt='Q')], write_ptrs=[gart.cpu_view().view(offset=wptr, size=8, fmt='Q')], put_value=pv)
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbell=self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q'),
read_ptr=gart.cpu_view().view(offset=rptr, size=8, fmt='Q'), write_ptr=gart.cpu_view().view(offset=wptr, size=8, fmt='Q'), put_value=pv)
def sleep(self, timeout) -> bool:
if hasattr(self.pci_dev, 'irq_poller') and self.pci_dev.irq_poller is not None and (events_cnt:=len(self.pci_dev.irq_poller.poll(timeout))):

View file

@ -69,11 +69,11 @@ class CUDAAllocator(LRUAllocator['CUDADevice']):
if options.external_ptr: return cuda.CUdeviceptr_v2(options.external_ptr)
if options.host: return init_c_var(ctypes.c_void_p, lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0x01)))
return init_c_var(cuda.CUdeviceptr, lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
@suppress_finalizing
def _free(self, opaque, options:BufferSpec):
try:
if options.host: check(cuda.cuMemFreeHost(opaque))
else: check(cuda.cuMemFree_v2(opaque))
except (TypeError, AttributeError): pass
if options.external_ptr: return
if options.host: check(cuda.cuMemFreeHost(opaque))
else: check(cuda.cuMemFree_v2(opaque))
def _copyin(self, dest, src:memoryview):
check(cuda.cuCtxSetCurrent(self.dev.context))
host_mem = self.alloc(len(src), BufferSpec(host=True))

View file

@ -1,5 +1,5 @@
import subprocess, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform, sys
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, PROFILE, ProfileRangeEvent, cpu_profile, unwrap
import subprocess, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, PROFILE, ProfileRangeEvent, cpu_profile, unwrap, suppress_finalizing
import tinygrad.runtime.support.objc as objc
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, ProfileDeviceEvent, CompilerSet, CompilerPair
from tinygrad.renderer.cstyle import MetalRenderer
@ -167,8 +167,9 @@ class MetalAllocator(LRUAllocator[MetalDevice]):
ret.retain = False
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
return MetalBuffer(ret, size)
@suppress_finalizing
def _free(self, opaque:MetalBuffer, options):
if not sys.is_finalizing(): opaque.buf.release
if not options.external_ptr: opaque.buf.release
def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
dest_dev.synchronize()
src_command_buffer = src_dev.mtl_queue.commandBuffer().retained()

View file

@ -828,3 +828,5 @@ class NVDevice(HCQCompiled[NVSignal]):
self.iface.rm_control(self.profiler, nv_gpu.NVB0CC_CTRL_CMD_PMA_STREAM_UPDATE_GET_PUT,
nv_gpu.struct_NVB0CC_CTRL_PMA_STREAM_UPDATE_GET_PUT_PARAMS(bytesConsumed=params.bytesAvailable))
return pma_data
def device_props(self): return {'arch': self.arch, 'sm_version': self.sm_version}

View file

@ -399,9 +399,11 @@ class QCOMDevice(HCQCompiled):
raise RuntimeError("Failed to map external pointer to GPU memory") from e
def _gpu_free(self, mem:HCQBuffer):
if mem.meta[0] is None: return
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.meta[0].id)
if mem.meta[1]: FileIOInterface.munmap(mem.va_addr, mem.meta[0].mmapsize)
if mem.meta[0] is None: return # external (gpu) ptr
if not mem.meta[1]: kgsl.IOCTL_KGSL_SHAREDMEM_FREE(self.fd, gpuaddr=mem.meta[0].gpuaddr) # external (cpu) ptr
else:
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.meta[0].id)
FileIOInterface.munmap(mem.va_addr, mem.meta[0].mmapsize)
def _ensure_stack_size(self, sz):
if not hasattr(self, '_stack'): self._stack = self._gpu_alloc(sz)

View file

@ -88,8 +88,8 @@ def import_asic_regs(prefix:str, version:tuple[int, ...], cls=AMDReg) -> dict[st
return x
def _download_file(ver, suff) -> str:
dir_prefix = {"osssys": "oss"}.get(prefix, prefix)
fetch_name, file_name = f"{prefix}_{'_'.join(map(str, ver))}_{suff}.h", f"{prefix}_{'_'.join(map(str, version))}_{suff}.h"
return header_download(f"include/asic_reg/{dir_prefix}/{fetch_name}", name=file_name, subdir="asic_regs")
fetch_name = f"{prefix}_{'_'.join(map(str, ver))}_{suff}.h"
return header_download(f"include/asic_reg/{dir_prefix}/{fetch_name}", name=fetch_name, subdir="asic_regs")
for ver in fixup_ip_version(prefix, version):
try: offs, sh_masks = _extract_regs(_download_file(ver, "offset")), _extract_regs(_download_file(ver, "sh_mask"))

View file

@ -202,7 +202,7 @@ def assign_multi(dest:UOp, src:UOp):
return dest.src[0].assign(src.src[0]).multi(src.axis)
def passthrough_multi(root:UOp, multi:UOp):
return UOp(root.op, root.dtype, (multi.src[0],), root.arg).multi(multi.axis)
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([
@ -218,6 +218,7 @@ multi_pm = PatternMatcher([
(UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# multi supports custom kernels with CUSTOM_KERNEL + AFTER

View file

@ -60,7 +60,7 @@ def split_reduceop(reduce:UOp, x:UOp):
mop_cleanup = PatternMatcher([
# merge adjacent RESHAPES, safe because they are not tagged
(UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"),
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
])
@ -68,15 +68,29 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
def resolve_call(c:UOp) -> UOp:
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
for i, (p, a) in enumerate(zip(params, args)):
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
return c.src[0].substitute(dict(zip(params, args))).rtag(c.tag)
earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
# resolve custom kernels
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
# remove CONTIGUOUS if the BUFFER is already contiguous
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
@ -233,7 +247,7 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
# dont bufferize an arange
(UPat.any((r:=UPat(dtype=dtypes.index).cast()).named("src"), r.eq(UPat()).named("src")).f(Ops.BUFFERIZE,
allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
# no buffers for const
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)),
# indexing a const is a const
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
@ -252,24 +266,24 @@ pm_remove_bufferize = PatternMatcher([
])
def late_buffer_view(t:UOp, b:UOp):
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
shape = b.shape
size = prod(shape)
if not (isinstance(b.device, str) and b.device.startswith(("DISK", "TINYFS"))): return b
shape = b.shape
size = prod(shape)
# walk up for the INDEX
x = t
while not any(u.op is Ops.INDEX for u in x.src):
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
x = x.src[0]
x = next(u for u in x.src if u.op is Ops.INDEX)
# walk up for the INDEX
x = t
while not any(u.op is Ops.INDEX for u in x.src):
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
x = x.src[0]
x = next(u for u in x.src if u.op is Ops.INDEX)
if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag), b.src[1]))
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
return b
to_bufferview = PatternMatcher([
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view),
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
])
@ -502,12 +516,11 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
# gather the metadata
metadatas = [ctx[y].metadata for y in lctx.parent_tags]
# NOTE: the hack for COPY is here
for u in ret.toposort():
# TODO: this can be wrong if there's multiple of these
if u.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}:
ret = u
break
# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
if ret.op is Ops.STORE: stored = ret.src[1]
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
else: raise RuntimeError(f"unknown kernel type {ret.op}")
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
else:
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None)
@ -521,11 +534,14 @@ split_kernels = PatternMatcher([
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
])
def tag_uop(ctx:list[UOp], x:UOp):
if x.tag is not None: return None
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
if x.tag is not None or x in ctx[1]: return None
if x.tag is None and x.op is Ops.CALL:
# don't tag anything in a CALL
for u in x.src[0].toposort(): ctx[1].add(u)
if x.dtype.scalar() == dtypes.index: return None
ctx.append(x)
return x.replace(tag=(len(ctx)-1,))
ctx[0].append(x)
return x.replace(tag=(len(ctx[0])-1,))
add_tags = PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
@ -551,7 +567,7 @@ replace_contiguous = PatternMatcher([
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
uop_list: list[UOp] = []
tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")

View file

@ -7,7 +7,7 @@ if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
@ -232,6 +232,16 @@ class Tensor(OpMixin):
# ***** data handlers ****
def as_param(self, slot:int):
if self.uop.axis is not None:
multi_shape = tuple([s//len(self.device) if i==self.uop.axis else s for i,s in enumerate(self.shape)])
param = UOp.param(slot, self.dtype, multi_shape, self.device).multi(self.uop.axis)
else:
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param, device=self.device)
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device)
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
"""
Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.
@ -275,13 +285,13 @@ class Tensor(OpMixin):
self.uop = x.uop
return self
def assign(self, x) -> Tensor:
def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
# TODO: this is a hack for writing to DISK. remove with working assign
if isinstance(self.device, str) and self.device.startswith("DISK"):
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
self._buffer().copyin(x._data())
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
if self.uop is x.uop: return self # a self assign is a NOOP
# NOTE: we allow cross device assign
# broadcast x
@ -1268,13 +1278,12 @@ class Tensor(OpMixin):
"""
return self._getitem(indices)
def __setitem__(self, indices, v:Tensor|PyConst) -> None:
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
if isinstance(self.device, str) and self.device.startswith("DISK"):
self.realize()._getitem(indices).assign(v)
return
# NOTE: check that setitem target is valid first
if isinstance(v, get_args(PyConst)): v = Tensor(v, device=self.device, dtype=self.dtype)
if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
self.realize()
if not self.uop.is_contiguous(): raise RuntimeError("setitem target needs to be contiguous")
@ -2422,6 +2431,9 @@ class Tensor(OpMixin):
```
"""
if IMAGE: return self.image_dot(w, dtype)
if ASM_GEMM:
from extra.gemm.asm.cdna.gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(self, w): return asm_gemm(self, w)
x, dx, dw = self, self.ndim, w.ndim
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")

View file

@ -26,6 +26,7 @@ class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); REWRITE_ERROR = auto()
PARAM = auto(); CALL = auto()
# renderer
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled

View file

@ -1,6 +1,6 @@
from typing import Callable
import math, functools
from tinygrad.dtype import dtypes, DType, promo_lattice
from tinygrad.dtype import dtypes, DType, promo_lattice, truncate
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import flatten, polyN
from tinygrad.uop import GroupOp
@ -353,8 +353,9 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
case Ops.IDIV | Ops.MOD:
# TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b
if dt == dtypes.int:
a0, a1 = (a_neg:=a1 < zero).where((n:=l2i(Ops.NEG, dt, a0, a1))[0], a0).bitcast(dtypes.uint), a_neg.where(n[1], a1).bitcast(dtypes.uint)
b0, b1 = (b_neg:=b1 < zero).where((n:=l2i(Ops.NEG, dt, b0, b1))[0], b0).bitcast(dtypes.uint), b_neg.where(n[1], b1).bitcast(dtypes.uint)
ua0, ua1, ub0, ub1 = a0.bitcast(dtypes.uint), a1.bitcast(dtypes.uint), b0.bitcast(dtypes.uint), b1.bitcast(dtypes.uint)
a0, a1 = (a_neg:=a1 < zero).where((n:=l2i(Ops.NEG, dtypes.uint, ua0, ua1))[0], ua0), a_neg.where(n[1], ua1)
b0, b1 = (b_neg:=b1 < zero).where((n:=l2i(Ops.NEG, dtypes.uint, ub0, ub1))[0], ub0), b_neg.where(n[1], ub1)
q, r = (z:=UOp.const(dtypes.uint, 0), z), (z, z)
for i in range(63, -1, -1):
r = l2i(Ops.SHL, dtypes.uint, *r, UOp.const(dtypes.uint, 1), z)
@ -364,8 +365,9 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
q = ((q[0] | cond.cast(dtypes.uint) << (i % 32), q[1]) if i < 32 else (q[0], q[1] | cond.cast(dtypes.uint) << (i % 32)))
r = l2i(Ops.WHERE, dtypes.uint, cond, *diff, *r)
if dt == dtypes.int:
nq, nr = l2i(Ops.NEG, dt, q0:=q[0].bitcast(dt), q1:=q[1].bitcast(dt)), l2i(Ops.NEG, dt, r0:=r[0].bitcast(dt), r1:=r[1].bitcast(dt))
return (a_neg.where(nr[0], r0), a_neg.where(nr[1], r1)) if op == Ops.MOD else ((a_neg^b_neg).where(nq[0], q0), (a_neg^b_neg).where(nq[1], q1))
(nq0, nq1), (nr0, nr1) = l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *q)), l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *r))
(q0, q1), (r0, r1) = l2i(Ops.BITCAST, dt, *q), l2i(Ops.BITCAST, dt, *r)
return (a_neg.where(nr0, r0), a_neg.where(nr1, r1)) if op == Ops.MOD else ((a_neg^b_neg).where(nq0, q0), (a_neg^b_neg).where(nq1, q1))
return (r[0].bitcast(dt), r[1].bitcast(dt)) if op == Ops.MOD else (q[0].bitcast(dt), q[1].bitcast(dt))
case Ops.CMPLT: return (a1 < b1) | ((a1.eq(b1)) & (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)))
case Ops.CMPEQ: return a0.eq(b0) & a1.eq(b1)
@ -447,5 +449,5 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, force_transcenden
pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
None if x.tag is None else x.replace(dtype=l2i_dt[x.dtype], src=(l2i_idx(idx, x.tag),)))]
pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
None if x.tag is None else UOp.const(l2i_dt[x.dtype], (x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF)))]
None if x.tag is None else UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))]
return PatternMatcher(pat)

View file

@ -58,6 +58,11 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
if pad is not None: ret += " " * (pad-ansilen(ret))
return ret
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.index.vec(0))
elif all(isinstance(x, int) for x in arg): return UOp.const(dtypes.index.vec(len(arg)), cast(tuple[int, ...], arg))
else: return UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg))
def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {}
for u in lst:
@ -205,7 +210,7 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
return None
@ -224,9 +229,13 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
case Ops.ENCDEC: return self.arg[0]
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
case Ops.PARAM:
# NOTE: copied from marg
if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
return None
# passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
return self.src[0]._shape
# ops with custom handling
@ -436,7 +445,7 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
# NOTE: b is ConstType, not ConstLike, so UOps and tuples aren't allowed
assert not isinstance(b, (UOp, tuple)), "unique const only works on numbers"
ret = UOp.const(dtype, b, device)
return ret.replace(src=ret.src + (UOp.unique(None if unique is True else unique),))
return ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@ -558,11 +567,7 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
case Ops.PERMUTE | Ops.FLIP: src_args = []
case _: raise RuntimeError(f"{op} is not a MovementOp")
usrcs = []
for arg in src_args:
if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
usrcs = [shape_to_shape_arg(arg) for arg in src_args]
if len(usrcs) == 0: ret = UOp(op, self.dtype, (self,), arg)
else: ret = UOp(op, self.dtype, (self,)+UOp.sink(*usrcs).simplify().src)
# for all movement ops, we check shape property to validity check the movement op
@ -824,6 +829,13 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
return self.src[0].after(self.store(val).end(*argfix(end)))
# TODO: this should replace placeholder
@staticmethod
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None):
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
@ -1372,8 +1384,8 @@ def render_marg(ctx,x:UOp):
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.ASSIGN, Ops.DETACH}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"), UPat(Ops.UNIQUE, name="u")), name="x"),
lambda x,d,u: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
lambda x,u,d: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:

View file

@ -88,14 +88,11 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND, Ops.CONTIGUOUS))), lambda: True),
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
@ -113,7 +110,7 @@ _tensor_spec = PatternMatcher([
# device or unique
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat((Ops.LUNIQUE, Ops.UNIQUE)))), lambda: True),
(UPat(Ops.CONST, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE))), lambda: True),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
@ -141,7 +138,11 @@ _tensor_spec = PatternMatcher([
# Tensor range bind / store
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True)
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
# allow CALL/PARAM
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
(UPat(Ops.PARAM), lambda: True),
])+movement_ops+shared_spec
tensor_spec = PatternMatcher([

View file

@ -48,6 +48,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
@ -347,7 +348,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
base = unwrap(p.base)
addr_table = amd_decode(unwrap(p.lib), device_props[p.device]["gfx_target_version"], )
disasm:dict[int, tuple[str, int]] = {addr+base:(inst.disasm(), inst.size()) for addr, inst in addr_table.items()}
rctx = decode(data, {p.name:disasm})
rctx = decode(data, {p.tag:disasm})
cu_events:dict[str, list[ProfileEvent]] = {}
# * INST waves
wave_insts:dict[str, dict[str, dict]] = {}
@ -555,11 +556,11 @@ def get_render(query:str) -> dict:
pc_to_inst = data["disasm"]
start_pc = None
rows:dict[int, dict] = {}
for pc, (inst,_) in pc_to_inst.items():
if start_pc is None: start_pc = pc
rows[pc] = {"pc":pc-start_pc, "inst":inst, "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
for e in w.unpack_insts():
if start_pc is None: start_pc = e.pc
if (inst:=rows.get(e.pc)) is None:
rows[e.pc] = inst = {"pc":e.pc-start_pc, "inst":pc_to_inst[e.pc][0], "hit_count":0, "dur":0, "stall":0, "type":str(e.typ).split("_")[-1],
"hits":{"cols":inst_columns, "rows":[]}}
if not (inst:=rows[e.pc]).get("type"): inst["type"] = str(e.typ).split("_")[-1]
inst["hit_count"] += 1
inst["dur"] += e.dur
inst["stall"] += e.stall