mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge remote-tracking branch 'upstream/master' into new_x86_backend
This commit is contained in:
commit
f1327ebff6
57 changed files with 13278 additions and 2216 deletions
43
.github/workflows/test.yml
vendored
43
.github/workflows/test.yml
vendored
|
|
@ -26,7 +26,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: llvm-speed
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
- name: Speed Test
|
||||
run: CPU=1 CPU_LLVM=1 python3 test/speed/external_test_speed_v_torch.py
|
||||
|
|
@ -98,7 +98,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: torch-backend-pillow-torchvision-et-pt
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
pydeps: "pillow torchvision expecttest"
|
||||
llvm: 'true'
|
||||
- name: Install ninja
|
||||
|
|
@ -134,7 +134,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: torch-backend-pillow-torchvision-et-pt
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
- name: Install ninja
|
||||
run: |
|
||||
|
|
@ -156,7 +156,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: be-minimal
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
- name: Test dtype with Python emulator
|
||||
run: DEBUG=1 PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
|
||||
- name: Test ops with Python emulator
|
||||
|
|
@ -348,7 +348,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: gpu-image
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
opencl: 'true'
|
||||
- name: Test CL IMAGE=2 ops
|
||||
run: |
|
||||
|
|
@ -424,7 +424,7 @@ jobs:
|
|||
with:
|
||||
key: onnxoptc
|
||||
deps: testing
|
||||
python-version: '3.11'
|
||||
python-version: '3.12'
|
||||
llvm: 'true'
|
||||
- name: Test ONNX (CPU)
|
||||
run: CPU=1 CPU_LLVM=0 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
|
|
@ -452,7 +452,7 @@ jobs:
|
|||
key: onnxoptl
|
||||
deps: testing
|
||||
pydeps: "tensorflow==2.19"
|
||||
python-version: '3.11'
|
||||
python-version: '3.12'
|
||||
opencl: 'true'
|
||||
- name: Test ONNX (CL)
|
||||
run: CL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
|
|
@ -526,7 +526,7 @@ jobs:
|
|||
with:
|
||||
key: metal
|
||||
deps: testing
|
||||
python-version: '3.11'
|
||||
python-version: '3.12'
|
||||
- name: Test models (Metal)
|
||||
run: METAL=1 python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test LLaMA compile speed
|
||||
|
|
@ -545,7 +545,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: devectorize-minimal
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
pydeps: "pillow"
|
||||
llvm: "true"
|
||||
- name: Test LLVM=1 DEVECTORIZE=0
|
||||
|
|
@ -566,7 +566,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: dsp-minimal
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
pydeps: "onnx==1.18.0 onnxruntime"
|
||||
llvm: "true"
|
||||
- name: Set up Docker Buildx
|
||||
|
|
@ -600,8 +600,8 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: webgpu-minimal
|
||||
deps: testing_minimal
|
||||
python-version: '3.11'
|
||||
deps: testing_unit
|
||||
python-version: '3.12'
|
||||
webgpu: 'true'
|
||||
- name: Check Device.DEFAULT (WEBGPU) and print some source
|
||||
run: |
|
||||
|
|
@ -634,7 +634,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: ${{ matrix.backend }}-minimal
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
amd: 'true'
|
||||
llvm: ${{ matrix.backend == 'amdllvm' && 'true' }}
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
|
|
@ -676,9 +676,9 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: rdna3-emu
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
amd: 'true'
|
||||
python-version: '3.13'
|
||||
python-version: '3.14'
|
||||
- name: Verify AMD autogen is up to date
|
||||
run: |
|
||||
python -m extra.assembly.amd.generate
|
||||
|
|
@ -724,7 +724,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: ${{ matrix.backend }}-minimal
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
cuda: 'true'
|
||||
ocelot: 'true'
|
||||
- name: Set env
|
||||
|
|
@ -757,7 +757,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: ${{ matrix.backend }}-minimal
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
opencl: ${{ matrix.backend == 'opencl' && 'true' }}
|
||||
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
|
||||
mesa: ${{ matrix.backend == 'lvp' && 'true' }}
|
||||
|
|
@ -771,9 +771,6 @@ jobs:
|
|||
run: python -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit --durations=20
|
||||
- name: Run TRANSCENDENTAL math
|
||||
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
|
||||
- name: Test dtype with emulated long
|
||||
if: matrix.backend != 'lvp' && matrix.backend != 'llvm'
|
||||
run: EMULATED_DTYPES=long python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
|
@ -791,7 +788,7 @@ jobs:
|
|||
with:
|
||||
key: metal
|
||||
deps: testing
|
||||
python-version: '3.11'
|
||||
python-version: '3.12'
|
||||
amd: 'true'
|
||||
cuda: 'true'
|
||||
ocelot: 'true'
|
||||
|
|
@ -889,7 +886,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: macos-${{ matrix.backend }}-minimal
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }}
|
||||
mesa: ${{ matrix.backend == 'lvp' && 'true' }}
|
||||
- name: Set env
|
||||
|
|
@ -956,7 +953,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: compile-${{ matrix.backend }}
|
||||
deps: testing_minimal
|
||||
deps: testing_unit
|
||||
mesa: ${{ (matrix.backend == 'ir3' || matrix.backend == 'nak') && 'true' }}
|
||||
python-version: '3.14'
|
||||
- name: Set env
|
||||
|
|
|
|||
|
|
@ -9,9 +9,11 @@ export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
|
|||
export DEBUG=${DEBUG:-0}
|
||||
export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
|
||||
export ALL2ALL=${ALL2ALL:-1}
|
||||
export USE_ATOMICS=${USE_ATOMICS:-1}
|
||||
export ASM_GEMM=${ASM_GEMM:-1}
|
||||
|
||||
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
|
||||
export DP=${DP:-8} BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
|
||||
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
|
||||
export GBS=$((BS * GRADIENT_ACC_STEPS))
|
||||
|
||||
export MODEL="llama3"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
#!/bin/bash
|
||||
export BENCHMARK=5
|
||||
export EVAL_BS=0
|
||||
export FAKEDATA=1
|
||||
export HIP_VISIBLE_DEVICES=""
|
||||
export DEV=NULL
|
||||
export JITBEAM=0
|
||||
export LLAMA_LAYERS=${LLAMA_LAYERS:-"2"}
|
||||
time examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh
|
||||
|
|
@ -427,7 +427,8 @@ class _Ctx:
|
|||
pcode = get_pcode(op)
|
||||
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
|
||||
if 'VCC' not in srcs: srcs['VCC'] = self.rsgpr_dyn(_c(vcc_reg))
|
||||
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane})
|
||||
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane,
|
||||
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}) # rounding mode: 0=RNE, RTZ constant
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
|
||||
raw_stores: list = []
|
||||
|
|
@ -796,10 +797,13 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
|
|||
return bits
|
||||
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
|
||||
return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16))
|
||||
s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
|
||||
s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
|
||||
s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)
|
||||
srcs = {'S0': s0_new, 'S1': s1_new, 'S2': s2_new}
|
||||
# DOT IU instructions use NEG bits for signed/unsigned selection, not fp16 negation
|
||||
is_dot_iu = 'DOT' in op_name and 'IU' in op_name
|
||||
n0, n1, n2, nh0, nh1, nh2 = (0, 0, 0, 0, 0, 0) if is_dot_iu else (neg & 1, neg & 2, neg & 4, neg_hi & 1, neg_hi & 2, neg_hi & 4)
|
||||
srcs = {'S0': build_remapped_src(src0, opsel & 1, opsel_hi & 1, n0, nh0),
|
||||
'S1': build_remapped_src(src1, opsel & 2, opsel_hi & 2, n1, nh1),
|
||||
'S2': build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, n2, nh2)}
|
||||
if is_dot_iu: srcs['NEG'] = UOp.const(dtypes.uint32, neg)
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
|
||||
|
||||
def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
|
||||
|
|
@ -1004,7 +1008,7 @@ def _get_runner(inst_bytes: bytes):
|
|||
canonical_name = f"{_op_name(inst).lower()}_{base.to_bytes(size, 'little').hex()}"
|
||||
sink = sink.replace(arg=KernelInfo(name=canonical_name)).rtag(1)
|
||||
|
||||
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0):
|
||||
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES=""):
|
||||
runner = get_runner('CPU', sink)
|
||||
_canonical_runner_cache.append((base, mask, size, runner))
|
||||
return runner, True
|
||||
|
|
|
|||
|
|
@ -94,13 +94,19 @@ def _trig_reduce(x, phase=0.0):
|
|||
return UOp(Ops.SIN, x.dtype, (x - n * _const(x.dtype, 6.283185307179586),))
|
||||
|
||||
def _signext(val: UOp) -> UOp:
|
||||
for bits, mask, ext in [(8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]:
|
||||
for bits, mask, ext in [(4, 0xF, 0xFFFFFFF0), (8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]:
|
||||
if (val.op == Ops.AND and len(val.src) == 2 and val.src[1].op == Ops.CONST and val.src[1].arg == mask) or val.dtype.itemsize == bits // 8:
|
||||
v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
|
||||
sb = (v32 >> _u32(bits - 1)) & _u32(1)
|
||||
return sb.ne(_u32(0)).where(v32 | _u32(ext), v32).cast(dtypes.int)
|
||||
return val.cast(dtypes.int64) if val.dtype in (dtypes.int, dtypes.int32) else val
|
||||
|
||||
def _signext_4bit(val: UOp) -> UOp:
|
||||
"""Sign extend a 4-bit value to 32-bit signed integer."""
|
||||
v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
|
||||
sb = (v32 >> _u32(3)) & _u32(1) # sign bit at position 3
|
||||
return sb.ne(_u32(0)).where(v32 | _u32(0xFFFFFFF0), v32).bitcast(dtypes.int)
|
||||
|
||||
def _abs(val: UOp) -> UOp:
|
||||
if val.dtype not in (dtypes.float32, dtypes.float64, dtypes.half): return val
|
||||
_, _, _, _, shift = _float_info(val)
|
||||
|
|
@ -227,11 +233,44 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
|
|||
'signext_from_bit': _signext_from_bit, 'ldexp': _ldexp, 'frexp_mant': _frexp_mant, 'mantissa': _frexp_mant,
|
||||
'frexp_exp': _frexp_exp, 'trig_preop_result': _trig_preop,
|
||||
's_ff1_i32_b32': lambda a: _ff1(a, 32), 's_ff1_i32_b64': lambda a: _ff1(a, 64),
|
||||
# Normalization conversions: map [-1,1] or [0,1] to integer range
|
||||
# Use floor(x + 0.5) for round-to-nearest
|
||||
# SNORM: round(value * 32767), range is [-32767, 32767] (hardware behavior)
|
||||
'f16_to_snorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
|
||||
'f16_to_unorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
|
||||
'f32_to_snorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
|
||||
'f32_to_unorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
|
||||
'f32_to_u8': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint8),
|
||||
# Integer truncation conversions
|
||||
'i32_to_i16': lambda a: a.cast(dtypes.int).cast(dtypes.int16),
|
||||
'u32_to_u16': lambda a: a.cast(dtypes.uint32).cast(dtypes.uint16),
|
||||
'u16_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFFFF)),
|
||||
'u8_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFF)),
|
||||
'u4_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xF)),
|
||||
# Signed extraction with sign extension for dot products
|
||||
'i16_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFFFF)),
|
||||
'i8_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFF)),
|
||||
'i4_to_i32': lambda a: _signext_4bit(a.cast(dtypes.uint32) & _u32(0xF)),
|
||||
# Float to int16 conversions
|
||||
'v_cvt_i16_f32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int16),
|
||||
'v_cvt_u16_f32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint16),
|
||||
}
|
||||
for is_max, name in [(False, 'min'), (True, 'max')]:
|
||||
for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]:
|
||||
_FUNCS[f'v_{name}_{sfx}'] = lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a)
|
||||
_FUNCS[f'v_{name}3_{sfx}'] = lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a)
|
||||
# f16 min/max/min3/max3/med3
|
||||
for is_max, name in [(False, 'min'), (True, 'max')]:
|
||||
_FUNCS[f'v_{name}_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
|
||||
_FUNCS[f'v_{name}3_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
|
||||
_FUNCS[f'v_{name}_num_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
|
||||
_FUNCS[f'v_{name}_num_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)
|
||||
_FUNCS[f'v_{name}3_num_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
|
||||
_FUNCS[f'v_{name}3_num_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)
|
||||
_FUNCS[f'v_{name}imum_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
|
||||
_FUNCS[f'v_{name}imum_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)
|
||||
_FUNCS[f'v_{name}imum3_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])
|
||||
_FUNCS[f'v_{name}imum3_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# TOKENIZER/PARSER
|
||||
|
|
@ -239,7 +278,7 @@ for is_max, name in [(False, 'min'), (True, 'max')]:
|
|||
|
||||
DTYPES = {'u32': dtypes.uint32, 'i32': dtypes.int, 'f32': dtypes.float32, 'b32': dtypes.uint32, 'u64': dtypes.uint64, 'i64': dtypes.int64,
|
||||
'f64': dtypes.float64, 'b64': dtypes.uint64, 'u16': dtypes.uint16, 'i16': dtypes.short, 'f16': dtypes.half, 'b16': dtypes.uint16,
|
||||
'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u1': dtypes.uint32}
|
||||
'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u4': dtypes.uint8, 'i4': dtypes.int8, 'u1': dtypes.uint32}
|
||||
_BITS_DT = {8: dtypes.uint8, 16: dtypes.uint16, 32: dtypes.uint32, 64: dtypes.uint64}
|
||||
_NUM_SUFFIXES = ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f')
|
||||
def _strip_suffix(num: str) -> tuple[str, str]:
|
||||
|
|
@ -432,14 +471,6 @@ class Parser:
|
|||
return elem
|
||||
if self.at('LBRACKET') and name not in self.vars:
|
||||
self.eat('LBRACKET')
|
||||
if self.at('NUM'):
|
||||
idx_num = int(self.peek().val)
|
||||
if f'{name}{idx_num}' in self.vars:
|
||||
self.eat('NUM')
|
||||
self.eat('RBRACKET')
|
||||
elem = self.vars[f'{name}{idx_num}']
|
||||
if self.try_eat('DOT'): return _cast_to(elem, DTYPES.get(self.eat('IDENT').val, dtypes.uint32))
|
||||
return elem
|
||||
first = self.parse()
|
||||
return self._handle_bracket_rest(first, _u32(0), name)
|
||||
if name in self.vars:
|
||||
|
|
@ -467,6 +498,7 @@ class Parser:
|
|||
if dt == base.dtype: return base
|
||||
if dt.itemsize == 2 and base.dtype.itemsize == 4:
|
||||
return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16) if dt == dtypes.uint16 else (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt)
|
||||
if field == 'i4': return _signext_4bit(base)
|
||||
return _cast_to(base, dt)
|
||||
|
||||
def _handle_bracket(self, base, var_name: str | None = None) -> UOp:
|
||||
|
|
@ -509,8 +541,8 @@ class Parser:
|
|||
var_name = self._find_var_name(base)
|
||||
if first.op == Ops.CONST:
|
||||
idx = int(first.arg)
|
||||
if var_name and f'{var_name}{idx}' in self.vars:
|
||||
v = self.vars[f'{var_name}{idx}']
|
||||
if var_name and f'{var_name}@{idx}' in self.vars:
|
||||
v = self.vars[f'{var_name}@{idx}']
|
||||
return _cast_to(v, dt_suffix) if dt_suffix else v
|
||||
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
|
||||
base_cast = base.cast(dt) if base.dtype != dt else base
|
||||
|
|
@ -518,7 +550,7 @@ class Parser:
|
|||
return _cast_to(result, dt_suffix) if dt_suffix else result
|
||||
if var_name:
|
||||
idx_u32 = _to_u32(first)
|
||||
elems = [(i, self.vars[f'{var_name}{i}']) for i in range(256) if f'{var_name}{i}' in self.vars]
|
||||
elems = [(i, self.vars[f'{var_name}@{i}']) for i in range(256) if f'{var_name}@{i}' in self.vars]
|
||||
if elems:
|
||||
result = elems[-1][1]
|
||||
for ei, ev in reversed(elems[:-1]):
|
||||
|
|
@ -537,7 +569,7 @@ class Parser:
|
|||
self.eat('RBRACE')
|
||||
var_name = self._find_var_name(base)
|
||||
if var_name:
|
||||
elem = self.vars.get(f'{var_name}{idx}', _u32(0))
|
||||
elem = self.vars.get(f'{var_name}@{idx}', _u32(0)) # use @ to avoid collision with temps like A4
|
||||
if self.try_eat('DOT'):
|
||||
dt_name = self.eat('IDENT').val
|
||||
return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32))
|
||||
|
|
@ -787,7 +819,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
|
||||
for loop_i in range(start_val, end_val + 1):
|
||||
subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
|
||||
_, iter_assigns, _ = parse_block(subst_lines, 0, vars, funcs, assigns)
|
||||
_, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
|
||||
if has_break:
|
||||
assert found_var is not None
|
||||
found = block_assigns.get(found_var, vars.get(found_var))
|
||||
|
|
@ -944,7 +976,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
if existing is not None and isinstance(existing, UOp):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
|
||||
else:
|
||||
block_assigns[f'{var}{idx}'] = vars[f'{var}{idx}'] = val
|
||||
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
|
||||
i += 1; continue
|
||||
|
||||
# Compound assignment: var += or var -=
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def _i32(f: float) -> int: return struct.unpack('<I', struct.pack('<f', f))[0]
|
|||
def _f32(i: int) -> float: return struct.unpack('<f', struct.pack('<I', i & 0xFFFFFFFF))[0]
|
||||
|
||||
# f16 conversion helpers
|
||||
def _f16(i: int) -> float: return struct.unpack('<e', struct.pack('<H', i & 0xFFFF))[0]
|
||||
def f16(i: int) -> float: return struct.unpack('<e', struct.pack('<H', i & 0xFFFF))[0]
|
||||
def f32_to_f16(f: float) -> int:
|
||||
f = float(f)
|
||||
if math.isnan(f): return 0x7e00
|
||||
|
|
|
|||
|
|
@ -255,7 +255,6 @@ class TestF16Conversions(unittest.TestCase):
|
|||
|
||||
def test_v_cvt_f16_f32_small(self):
|
||||
"""V_CVT_F16_F32 converts small f32 value."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0.5),
|
||||
v_cvt_f16_f32_e32(v[1], v[0]),
|
||||
|
|
@ -293,7 +292,6 @@ class TestF16Conversions(unittest.TestCase):
|
|||
|
||||
def test_v_cvt_f16_f32_reads_full_32bit_source(self):
|
||||
"""V_CVT_F16_F32 must read full 32-bit f32 source."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3fc00000), # f32 1.5
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
|
|
@ -302,7 +300,7 @@ class TestF16Conversions(unittest.TestCase):
|
|||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1]
|
||||
lo_bits = result & 0xffff
|
||||
self.assertEqual(lo_bits, 0x3e00, f"Expected f16(1.5)=0x3e00, got 0x{lo_bits:04x} ({_f16(lo_bits)})")
|
||||
self.assertEqual(lo_bits, 0x3e00, f"Expected f16(1.5)=0x3e00, got 0x{lo_bits:04x} ({f16(lo_bits)})")
|
||||
|
||||
def test_v_cvt_i16_f16_zero(self):
|
||||
"""V_CVT_I16_F16 converts f16 zero to i16 zero."""
|
||||
|
|
@ -696,7 +694,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
|||
|
||||
def test_v_cvt_f32_f16_abs_negative(self):
|
||||
"""V_CVT_F32_F16 with |abs| on negative value."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_neg1),
|
||||
|
|
@ -709,7 +706,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
|||
|
||||
def test_v_cvt_f32_f16_abs_positive(self):
|
||||
"""V_CVT_F32_F16 with |abs| on positive value (should stay positive)."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0) # 0x4000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_2),
|
||||
|
|
@ -722,7 +718,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
|||
|
||||
def test_v_cvt_f32_f16_neg_positive(self):
|
||||
"""V_CVT_F32_F16 with neg on positive value."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0) # 0x4000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_2),
|
||||
|
|
@ -735,7 +730,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
|||
|
||||
def test_v_cvt_f32_f16_neg_negative(self):
|
||||
"""V_CVT_F32_F16 with neg on negative value (double negative)."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_neg2 = f32_to_f16(-2.0) # 0xc000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_neg2),
|
||||
|
|
@ -748,7 +742,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
|||
|
||||
def test_v_cvt_f16_f32_then_pack_for_wmma(self):
|
||||
"""CVT F32->F16 followed by pack (common WMMA pattern)."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
f32_val = 3.5
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(f32_val)),
|
||||
|
|
@ -757,8 +750,8 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
|||
v_pack_b32_f16(v[2], v[1], v[1]), # Pack same value
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = _f16(st.vgpr[0][2] & 0xffff)
|
||||
hi = _f16((st.vgpr[0][2] >> 16) & 0xffff)
|
||||
lo = f16(st.vgpr[0][2] & 0xffff)
|
||||
hi = f16((st.vgpr[0][2] >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, f32_val, places=1)
|
||||
self.assertAlmostEqual(hi, f32_val, places=1)
|
||||
|
||||
|
|
@ -804,7 +797,6 @@ class TestConversionRounding(unittest.TestCase):
|
|||
|
||||
def test_f16_to_f32_precision(self):
|
||||
"""F16 to F32 conversion precision."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_val = f32_to_f16(1.5)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_val),
|
||||
|
|
@ -816,7 +808,6 @@ class TestConversionRounding(unittest.TestCase):
|
|||
|
||||
def test_f16_denormal_to_f32(self):
|
||||
"""F16 denormal converts to small positive f32."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
f16_denorm = 0x0001 # Smallest positive f16 denormal
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], f16_denorm),
|
||||
|
|
@ -1512,5 +1503,63 @@ class TestReciprocalF16(unittest.TestCase):
|
|||
self.assertAlmostEqual(result, 0.25, places=2, msg="1/4.0 should be 0.25")
|
||||
|
||||
|
||||
class TestCvtNormF16(unittest.TestCase):
|
||||
"""Tests for V_CVT_NORM_I16_F16 and V_CVT_NORM_U16_F16."""
|
||||
|
||||
def test_cvt_norm_i16_f16_positive(self):
|
||||
"""V_CVT_NORM_I16_F16: f16 1.0 -> i16 max (32767)."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(1.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_cvt_norm_i16_f16_e32(v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1] & 0xffff
|
||||
self.assertEqual(result, 32767)
|
||||
|
||||
def test_cvt_norm_i16_f16_negative(self):
|
||||
"""V_CVT_NORM_I16_F16: f16 -1.0 -> i16 -32767 (0x8001)."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(-1.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_cvt_norm_i16_f16_e32(v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1] & 0xffff
|
||||
self.assertEqual(result, 0x8001) # -32767, hardware uses symmetric range
|
||||
|
||||
def test_cvt_norm_i16_f16_zero(self):
|
||||
"""V_CVT_NORM_I16_F16: f16 0.0 -> i16 0."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_cvt_norm_i16_f16_e32(v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1] & 0xffff
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
def test_cvt_norm_u16_f16_one(self):
|
||||
"""V_CVT_NORM_U16_F16: f16 1.0 -> u16 max (65535)."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(1.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_cvt_norm_u16_f16_e32(v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1] & 0xffff
|
||||
self.assertEqual(result, 65535)
|
||||
|
||||
def test_cvt_norm_u16_f16_half(self):
|
||||
"""V_CVT_NORM_U16_F16: f16 0.5 -> u16 ~32768."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(0.5)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_cvt_norm_u16_f16_e32(v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1] & 0xffff
|
||||
self.assertAlmostEqual(result, 32768, delta=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -857,7 +857,6 @@ class TestF16Modifiers(unittest.TestCase):
|
|||
|
||||
def test_v_fma_f16_inline_const_1_0(self):
|
||||
"""V_FMA_F16: a*b + 1.0 should use f16 inline constant."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_a = f32_to_f16(0.325928) # ~0x3537
|
||||
f16_b = f32_to_f16(-0.486572) # ~0xb7c9
|
||||
instructions = [
|
||||
|
|
@ -868,13 +867,12 @@ class TestF16Modifiers(unittest.TestCase):
|
|||
v_fma_f16(v[4], v[4], v[6], 1.0), # 1.0 is inline constant
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][4] & 0xffff)
|
||||
result = f16(st.vgpr[0][4] & 0xffff)
|
||||
expected = 0.325928 * (-0.486572) + 1.0
|
||||
self.assertAlmostEqual(result, expected, delta=0.01)
|
||||
|
||||
def test_v_fma_f16_inline_const_0_5(self):
|
||||
"""V_FMA_F16: a*b + 0.5 should use f16 inline constant."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_a = f32_to_f16(2.0)
|
||||
f16_b = f32_to_f16(3.0)
|
||||
instructions = [
|
||||
|
|
@ -885,13 +883,12 @@ class TestF16Modifiers(unittest.TestCase):
|
|||
v_fma_f16(v[2], v[0], v[1], 0.5), # 0.5 is inline constant
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][2] & 0xffff)
|
||||
result = f16(st.vgpr[0][2] & 0xffff)
|
||||
expected = 2.0 * 3.0 + 0.5
|
||||
self.assertAlmostEqual(result, expected, delta=0.01)
|
||||
|
||||
def test_v_fma_f16_inline_const_neg_1_0(self):
|
||||
"""V_FMA_F16: a*b + (-1.0) should use f16 inline constant."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_a = f32_to_f16(2.0)
|
||||
f16_b = f32_to_f16(3.0)
|
||||
instructions = [
|
||||
|
|
@ -902,13 +899,12 @@ class TestF16Modifiers(unittest.TestCase):
|
|||
v_fma_f16(v[2], v[0], v[1], -1.0), # -1.0 is inline constant
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][2] & 0xffff)
|
||||
result = f16(st.vgpr[0][2] & 0xffff)
|
||||
expected = 2.0 * 3.0 + (-1.0)
|
||||
self.assertAlmostEqual(result, expected, delta=0.01)
|
||||
|
||||
def test_v_add_f16_abs_both(self):
|
||||
"""V_ADD_F16 with abs on both operands."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_neg2 = f32_to_f16(-2.0)
|
||||
f16_neg3 = f32_to_f16(-3.0)
|
||||
instructions = [
|
||||
|
|
@ -919,12 +915,11 @@ class TestF16Modifiers(unittest.TestCase):
|
|||
v_add_f16_e64(v[2], abs(v[0]), abs(v[1])), # |-2| + |-3| = 5
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][2] & 0xffff)
|
||||
result = f16(st.vgpr[0][2] & 0xffff)
|
||||
self.assertAlmostEqual(result, 5.0, delta=0.01)
|
||||
|
||||
def test_v_mul_f16_neg_abs(self):
|
||||
"""V_MUL_F16 with neg on one operand and abs on another."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_2 = f32_to_f16(2.0)
|
||||
f16_neg3 = f32_to_f16(-3.0)
|
||||
instructions = [
|
||||
|
|
@ -935,7 +930,7 @@ class TestF16Modifiers(unittest.TestCase):
|
|||
v_mul_f16_e64(v[2], -v[0], abs(v[1])), # -(2) * |-3| = -6
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][2] & 0xffff)
|
||||
result = f16(st.vgpr[0][2] & 0xffff)
|
||||
self.assertAlmostEqual(result, -6.0, delta=0.01)
|
||||
|
||||
def test_v_fmac_f16_hi_dest(self):
|
||||
|
|
@ -943,7 +938,6 @@ class TestF16Modifiers(unittest.TestCase):
|
|||
|
||||
This tests the case from AMD_LLVM sin(0) where V_FMAC_F16 writes to v0.h.
|
||||
"""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x38003c00), # v0 = {hi=0.5, lo=1.0}
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
|
|
@ -954,8 +948,8 @@ class TestF16Modifiers(unittest.TestCase):
|
|||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
v0 = st.vgpr[0][0]
|
||||
result_hi = _f16((v0 >> 16) & 0xffff)
|
||||
result_lo = _f16(v0 & 0xffff)
|
||||
result_hi = f16((v0 >> 16) & 0xffff)
|
||||
result_lo = f16(v0 & 0xffff)
|
||||
self.assertAlmostEqual(result_hi, 0.5, delta=0.01, msg=f"Expected hi=0.5, got {result_hi}")
|
||||
self.assertAlmostEqual(result_lo, 1.0, delta=0.01, msg=f"Expected lo=1.0, got {result_lo}")
|
||||
|
||||
|
|
@ -2955,5 +2949,318 @@ class TestVOP3Clamp(unittest.TestCase):
|
|||
self.assertAlmostEqual(i2f(st.vgpr[3][1]), 1.0, places=5, msg="lane 3: 2.5 should clamp to 1.0")
|
||||
|
||||
|
||||
class TestCvtPkF16(unittest.TestCase):
|
||||
"""Tests for V_CVT_PK_RTZ_F16_F32 - pack two f32 to f16 with round toward zero."""
|
||||
|
||||
def test_cvt_pk_rtz_f16_f32_basic(self):
|
||||
"""V_CVT_PK_RTZ_F16_F32: basic pack of two f32 values."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
v_mov_b32_e32(v[1], 2.0),
|
||||
v_cvt_pk_rtz_f16_f32_e64(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo_f16 = f16(result & 0xffff)
|
||||
hi_f16 = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo_f16, 1.0, delta=0.01)
|
||||
self.assertAlmostEqual(hi_f16, 2.0, delta=0.01)
|
||||
|
||||
|
||||
class TestCvtPkNorm(unittest.TestCase):
|
||||
"""Tests for V_CVT_PK_NORM_I16_F32 and V_CVT_PK_NORM_U16_F32."""
|
||||
|
||||
def test_cvt_pk_norm_i16_f32_basic(self):
|
||||
"""V_CVT_PK_NORM_I16_F32: pack two f32 to normalized i16."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
v_mov_b32_e32(v[1], -1.0),
|
||||
v_cvt_pk_norm_i16_f32(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = result & 0xffff
|
||||
hi = (result >> 16) & 0xffff
|
||||
self.assertEqual(lo, 32767)
|
||||
self.assertEqual(hi, 0x8001) # -32767, hardware uses symmetric range
|
||||
|
||||
def test_cvt_pk_norm_u16_f32_basic(self):
|
||||
"""V_CVT_PK_NORM_U16_F32: pack two f32 to normalized u16."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 1.0),
|
||||
v_mov_b32_e32(v[1], 0.5),
|
||||
v_cvt_pk_norm_u16_f32(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = result & 0xffff
|
||||
hi = (result >> 16) & 0xffff
|
||||
self.assertEqual(lo, 65535)
|
||||
self.assertAlmostEqual(hi, 32768, delta=1)
|
||||
|
||||
|
||||
class TestCvtPkInt(unittest.TestCase):
|
||||
"""Tests for V_CVT_PK_I16_I32, V_CVT_PK_U16_U32, V_CVT_PK_I16_F32, V_CVT_PK_U16_F32."""
|
||||
|
||||
def test_cvt_pk_i16_i32_basic(self):
|
||||
"""V_CVT_PK_I16_I32: pack two i32 to i16."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 100),
|
||||
s_mov_b32(s[1], -100 & 0xffffffff),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_cvt_pk_i16_i32(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = result & 0xffff
|
||||
hi = (result >> 16) & 0xffff
|
||||
lo_signed = lo if lo < 32768 else lo - 65536
|
||||
hi_signed = hi if hi < 32768 else hi - 65536
|
||||
self.assertEqual(lo_signed, 100)
|
||||
self.assertEqual(hi_signed, -100)
|
||||
|
||||
def test_cvt_pk_u16_u32_basic(self):
|
||||
"""V_CVT_PK_U16_U32: pack two u32 to u16."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 1000),
|
||||
s_mov_b32(s[1], 2000),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_cvt_pk_u16_u32(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = result & 0xffff
|
||||
hi = (result >> 16) & 0xffff
|
||||
self.assertEqual(lo, 1000)
|
||||
self.assertEqual(hi, 2000)
|
||||
|
||||
def test_cvt_pk_i16_f32_basic(self):
|
||||
"""V_CVT_PK_I16_F32: convert two f32 to packed i16."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100.5),
|
||||
v_mov_b32_e32(v[1], -50.7),
|
||||
v_cvt_pk_i16_f32(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = result & 0xffff
|
||||
hi = (result >> 16) & 0xffff
|
||||
lo_signed = lo if lo < 32768 else lo - 65536
|
||||
hi_signed = hi if hi < 32768 else hi - 65536
|
||||
self.assertEqual(lo_signed, 100)
|
||||
self.assertEqual(hi_signed, -50)
|
||||
|
||||
def test_cvt_pk_u16_f32_basic(self):
|
||||
"""V_CVT_PK_U16_F32: convert two f32 to packed u16."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100.9),
|
||||
v_mov_b32_e32(v[1], 200.1),
|
||||
v_cvt_pk_u16_f32(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = result & 0xffff
|
||||
hi = (result >> 16) & 0xffff
|
||||
self.assertEqual(lo, 100)
|
||||
self.assertEqual(hi, 200)
|
||||
|
||||
def test_cvt_pk_u8_f32_basic(self):
|
||||
"""V_CVT_PK_U8_F32: convert f32 to u8 and pack at byte position."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 128.5),
|
||||
v_mov_b32_e32(v[1], 0),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_cvt_pk_u8_f32(v[2], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
byte0 = result & 0xff
|
||||
self.assertEqual(byte0, 128)
|
||||
|
||||
|
||||
class TestDotProduct(unittest.TestCase):
|
||||
"""Tests for dot product instructions V_DOT4_U32_U8, V_DOT8_U32_U4."""
|
||||
|
||||
def test_v_dot4_u32_u8_basic(self):
|
||||
"""V_DOT4_U32_U8: 4-element dot product of u8 vectors."""
|
||||
src0 = 0x04030201 # {4, 3, 2, 1}
|
||||
src1 = 0x01010101 # {1, 1, 1, 1}
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot4_u32_u8(v[2], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
self.assertEqual(result, 10)
|
||||
|
||||
def test_v_dot4_u32_u8_with_accumulator(self):
|
||||
"""V_DOT4_U32_U8 with non-zero accumulator."""
|
||||
src0 = 0x02020202 # {2, 2, 2, 2}
|
||||
src1 = 0x03030303 # {3, 3, 3, 3}
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 100),
|
||||
v_dot4_u32_u8(v[2], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
self.assertEqual(result, 124)
|
||||
|
||||
def test_v_dot8_u32_u4_basic(self):
|
||||
"""V_DOT8_U32_U4: 8-element dot product of u4 vectors."""
|
||||
# src0 = 8 nibbles: {1,2,3,4,5,6,7,8} packed as 0x87654321
|
||||
# src1 = 8 nibbles: {1,1,1,1,1,1,1,1} packed as 0x11111111
|
||||
# result = 1+2+3+4+5+6+7+8 = 36
|
||||
src0 = 0x87654321
|
||||
src1 = 0x11111111
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot8_u32_u4(v[2], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
self.assertEqual(result, 36)
|
||||
|
||||
|
||||
class TestMinMaxF16Vop3(unittest.TestCase):
|
||||
"""Tests for V_MIN3_F16, V_MAX3_F16, V_MED3_F16, V_MINMAX_F16, V_MAXMIN_F16."""
|
||||
|
||||
def test_v_min3_f16_basic(self):
|
||||
"""V_MIN3_F16: minimum of three f16 values."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(3.0)),
|
||||
s_mov_b32(s[1], f32_to_f16(1.0)),
|
||||
s_mov_b32(s[2], f32_to_f16(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_min3_f16(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 1.0, delta=0.01)
|
||||
|
||||
def test_v_max3_f16_basic(self):
|
||||
"""V_MAX3_F16: maximum of three f16 values."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(1.0)),
|
||||
s_mov_b32(s[1], f32_to_f16(3.0)),
|
||||
s_mov_b32(s[2], f32_to_f16(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_max3_f16(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 3.0, delta=0.01)
|
||||
|
||||
def test_v_med3_f16_basic(self):
|
||||
"""V_MED3_F16: median of three f16 values."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(3.0)),
|
||||
s_mov_b32(s[1], f32_to_f16(1.0)),
|
||||
s_mov_b32(s[2], f32_to_f16(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_med3_f16(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 2.0, delta=0.01)
|
||||
|
||||
def test_v_minmax_f16_basic(self):
|
||||
"""V_MINMAX_F16: clamp(src0, min=src1, max=src2)."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(2.5)),
|
||||
s_mov_b32(s[1], f32_to_f16(1.0)),
|
||||
s_mov_b32(s[2], f32_to_f16(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_minmax_f16(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 2.0, delta=0.01)
|
||||
|
||||
def test_v_maxmin_f16_basic(self):
|
||||
"""V_MAXMIN_F16: clamp(src0, min=src2, max=src1)."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(0.5)),
|
||||
s_mov_b32(s[1], f32_to_f16(2.0)),
|
||||
s_mov_b32(s[2], f32_to_f16(1.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_maxmin_f16(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 1.0, delta=0.01)
|
||||
|
||||
def test_v_min3_f16_with_neg(self):
|
||||
"""V_MIN3_F16 with neg modifier: min(-3, 1, 2) = -3."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(3.0)),
|
||||
s_mov_b32(s[1], f32_to_f16(1.0)),
|
||||
s_mov_b32(s[2], f32_to_f16(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_min3_f16(v[3], -v[0], v[1], v[2]), # neg on first operand
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, -3.0, delta=0.01)
|
||||
|
||||
def test_v_max3_f16_with_abs(self):
|
||||
"""V_MAX3_F16 with abs modifier: max(|-3|, 1, 2) = 3."""
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f32_to_f16(-3.0)),
|
||||
s_mov_b32(s[1], f32_to_f16(1.0)),
|
||||
s_mov_b32(s[2], f32_to_f16(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_max3_f16(v[3], abs(v[0]), v[1], v[2]), # abs on first operand
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 3.0, delta=0.01)
|
||||
|
||||
def test_v_med3_f16_opsel_hi(self):
|
||||
"""V_MED3_F16 with opsel reading from hi half."""
|
||||
# Pack two f16 values: hi=5.0, lo=1.0
|
||||
packed = (f32_to_f16(5.0) << 16) | f32_to_f16(1.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], packed),
|
||||
s_mov_b32(s[1], f32_to_f16(3.0)),
|
||||
s_mov_b32(s[2], f32_to_f16(4.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
# Read hi half of v[0] (5.0), med3(5, 3, 4) = 4
|
||||
v_med3_f16(v[3], v[0].h, v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 4.0, delta=0.01)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -149,7 +149,6 @@ class TestFmaMix(unittest.TestCase):
|
|||
|
||||
def test_v_fma_mix_f32_src2_f16_lo(self):
|
||||
"""V_FMA_MIX_F32 with src2 as f16 from lo bits."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(1.0)),
|
||||
|
|
@ -166,7 +165,6 @@ class TestFmaMix(unittest.TestCase):
|
|||
|
||||
def test_v_fma_mix_f32_src2_f16_hi(self):
|
||||
"""V_FMA_MIX_F32 with src2 as f16 from hi bits."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0)
|
||||
val = (f16_2 << 16) | 0
|
||||
instructions = [
|
||||
|
|
@ -199,7 +197,6 @@ class TestFmaMix(unittest.TestCase):
|
|||
|
||||
def test_v_fma_mix_f32_with_abs_f16_src2_lo(self):
|
||||
"""V_FMA_MIX_F32 with abs modifier on f16 src2 (lo half). Regression test for sin(1.0) bug."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(0.0)), # src0 = 0.0 (f32)
|
||||
|
|
@ -217,7 +214,6 @@ class TestFmaMix(unittest.TestCase):
|
|||
|
||||
def test_v_fma_mix_f32_with_neg_f16_src2_lo(self):
|
||||
"""V_FMA_MIX_F32 with neg modifier on f16 src2 (lo half)."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_1 = f32_to_f16(1.0) # 0x3c00
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(0.0)), # src0 = 0.0 (f32)
|
||||
|
|
@ -235,7 +231,6 @@ class TestFmaMix(unittest.TestCase):
|
|||
|
||||
def test_v_fma_mix_f32_with_abs_f16_src2_hi(self):
|
||||
"""V_FMA_MIX_F32 with abs modifier on f16 src2 (hi half)."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
|
||||
val = (f16_neg1 << 16) | 0 # -1.0 in hi, 0 in lo
|
||||
instructions = [
|
||||
|
|
@ -254,7 +249,6 @@ class TestFmaMix(unittest.TestCase):
|
|||
|
||||
def test_v_fma_mixlo_f16(self):
|
||||
"""V_FMA_MIXLO_F16 writes to low 16 bits of destination."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
|
|
@ -267,14 +261,13 @@ class TestFmaMix(unittest.TestCase):
|
|||
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = _f16(st.vgpr[0][3] & 0xffff)
|
||||
lo = f16(st.vgpr[0][3] & 0xffff)
|
||||
hi = (st.vgpr[0][3] >> 16) & 0xffff
|
||||
self.assertAlmostEqual(lo, 7.0, places=1)
|
||||
self.assertEqual(hi, 0xdead, f"hi should be preserved, got 0x{hi:04x}")
|
||||
|
||||
def test_v_fma_mixlo_f16_all_f32_sources(self):
|
||||
"""V_FMA_MIXLO_F16 with all f32 sources."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(1.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
|
|
@ -286,13 +279,12 @@ class TestFmaMix(unittest.TestCase):
|
|||
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = _f16(st.vgpr[0][3] & 0xffff)
|
||||
lo = f16(st.vgpr[0][3] & 0xffff)
|
||||
# 1*2+3 = 5
|
||||
self.assertAlmostEqual(lo, 5.0, places=1)
|
||||
|
||||
def test_v_fma_mixlo_f16_sin_case(self):
|
||||
"""V_FMA_MIXLO_F16 case from sin kernel."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3f800000), # f32 1.0
|
||||
v_mov_b32_e32(v[3], s[0]),
|
||||
|
|
@ -305,7 +297,7 @@ class TestFmaMix(unittest.TestCase):
|
|||
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[3], src1=s[6], src2=v[5], opsel=0, opsel_hi=0, opsel_hi2=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = _f16(st.vgpr[0][3] & 0xffff)
|
||||
lo = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(lo, -3.14159, delta=0.01)
|
||||
|
||||
|
||||
|
|
@ -314,7 +306,6 @@ class TestVOP3P(unittest.TestCase):
|
|||
|
||||
def test_v_pk_add_f16_basic(self):
|
||||
"""V_PK_ADD_F16 adds two packed f16 values."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0
|
||||
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
|
||||
|
|
@ -324,14 +315,13 @@ class TestVOP3P(unittest.TestCase):
|
|||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 4.0, places=2)
|
||||
self.assertAlmostEqual(hi, 6.0, places=2)
|
||||
|
||||
def test_v_pk_mul_f16_basic(self):
|
||||
"""V_PK_MUL_F16 multiplies two packed f16 values."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x42004000), # hi=3.0, lo=2.0
|
||||
s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0
|
||||
|
|
@ -341,14 +331,13 @@ class TestVOP3P(unittest.TestCase):
|
|||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 8.0, places=1)
|
||||
self.assertAlmostEqual(hi, 15.0, places=1)
|
||||
|
||||
def test_v_pk_fma_f16_basic(self):
|
||||
"""V_PK_FMA_F16: D = A * B + C for packed f16."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x42004000), # A: hi=3.0, lo=2.0
|
||||
s_mov_b32(s[1], 0x45004400), # B: hi=5.0, lo=4.0
|
||||
|
|
@ -360,8 +349,8 @@ class TestVOP3P(unittest.TestCase):
|
|||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][3]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 9.0, places=1) # 2*4+1
|
||||
self.assertAlmostEqual(hi, 16.0, places=0) # 3*5+1
|
||||
|
||||
|
|
@ -370,7 +359,6 @@ class TestVOP3P(unittest.TestCase):
|
|||
Inline constants for VOP3P are f16 values in the low 16 bits only.
|
||||
hi half of inline constant is 0, so hi result = v0.hi + 0 = 1.0.
|
||||
"""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
|
|
@ -378,8 +366,8 @@ class TestVOP3P(unittest.TestCase):
|
|||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
# lo = 1.0 + 1.0 = 2.0, hi = 1.0 + 0.0 = 1.0 (inline const hi half is 0)
|
||||
self.assertAlmostEqual(lo, 2.0, places=2)
|
||||
self.assertAlmostEqual(hi, 1.0, places=2)
|
||||
|
|
@ -388,7 +376,6 @@ class TestVOP3P(unittest.TestCase):
|
|||
"""V_PK_MUL_F16 with inline constant POS_TWO (2.0).
|
||||
Inline constant has value only in low 16 bits, hi is 0.
|
||||
"""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
# v0 = packed (3.0, 4.0), multiply by POS_TWO
|
||||
# lo = 3.0 * 2.0 = 6.0, hi = 4.0 * 0.0 = 0.0 (inline const hi is 0)
|
||||
instructions = [
|
||||
|
|
@ -398,8 +385,8 @@ class TestVOP3P(unittest.TestCase):
|
|||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 6.0, places=1)
|
||||
self.assertAlmostEqual(hi, 0.0, places=1)
|
||||
|
||||
|
|
@ -413,7 +400,6 @@ class TestWMMAF16(unittest.TestCase):
|
|||
|
||||
def test_v_wmma_f16_16x16x16_f16_all_ones(self):
|
||||
"""V_WMMA_F16_16X16X16_F16 with all ones produces 16.0 in f16."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
# Initialize A matrix in v[16:23] (8 regs)
|
||||
|
|
@ -432,13 +418,12 @@ class TestWMMAF16(unittest.TestCase):
|
|||
for lane in range(32):
|
||||
for reg in range(8):
|
||||
result = st.vgpr[lane][reg]
|
||||
lo = _f16(result & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
self.assertAlmostEqual(lo, 16.0, places=1, msg=f"v[{reg}] lane {lane}: expected 16.0, got {lo}")
|
||||
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
|
||||
|
||||
def test_v_wmma_f16_16x16x16_f16_with_accumulator(self):
|
||||
"""V_WMMA_F16_16X16X16_F16 with non-zero accumulator."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
instructions.append(s_mov_b32(s[1], 0x4500)) # f16 5.0 in lo bits only
|
||||
|
|
@ -458,7 +443,7 @@ class TestWMMAF16(unittest.TestCase):
|
|||
for lane in range(32):
|
||||
for reg in range(8):
|
||||
result = st.vgpr[lane][reg]
|
||||
lo = _f16(result & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
self.assertAlmostEqual(lo, 21.0, places=0, msg=f"v[{reg}] lane {lane}: expected 21.0, got {lo}")
|
||||
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
|
||||
|
||||
|
|
@ -468,7 +453,6 @@ class TestWMMAF16(unittest.TestCase):
|
|||
Regression test: WMMA was using static register indices instead of dynamic.
|
||||
This test uses v[64:71] for A, v[80:87] for B, v[96:103] for C/D.
|
||||
"""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
# Initialize A matrix in v[64:71] (8 regs)
|
||||
|
|
@ -490,7 +474,7 @@ class TestWMMAF16(unittest.TestCase):
|
|||
for lane in range(32):
|
||||
for reg in range(8):
|
||||
result = st.vgpr[lane][reg]
|
||||
lo = _f16(result & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
self.assertAlmostEqual(lo, 16.0, places=1, msg=f"v[{reg}] lane {lane}: expected 16.0, got {lo}")
|
||||
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
|
||||
|
||||
|
|
@ -713,7 +697,6 @@ class TestPackedMixedSigns(unittest.TestCase):
|
|||
|
||||
def test_pk_add_f16_mixed_signs(self):
|
||||
"""V_PK_ADD_F16 with mixed positive/negative values."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0xc0003c00), # packed: hi=-2.0, lo=1.0
|
||||
s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0
|
||||
|
|
@ -723,14 +706,13 @@ class TestPackedMixedSigns(unittest.TestCase):
|
|||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 2.0, places=2) # 1.0 + 1.0
|
||||
self.assertAlmostEqual(hi, -1.0, places=2) # -2.0 + 1.0
|
||||
|
||||
def test_pk_mul_f16_zero(self):
|
||||
"""V_PK_MUL_F16 with zero."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x40004000), # packed: 2.0, 2.0
|
||||
s_mov_b32(s[1], 0x00000000), # packed: 0.0, 0.0
|
||||
|
|
@ -743,5 +725,276 @@ class TestPackedMixedSigns(unittest.TestCase):
|
|||
self.assertEqual(result, 0x00000000, "2.0 * 0.0 should be 0.0")
|
||||
|
||||
|
||||
class TestDot2F32F16(unittest.TestCase):
|
||||
"""Tests for V_DOT2_F32_F16 - dot product of f16 pairs producing f32."""
|
||||
|
||||
def test_v_dot2_f32_f16_basic(self):
|
||||
"""V_DOT2_F32_F16: dot product of two packed f16 pairs -> f32."""
|
||||
# src0 = {hi=2.0, lo=1.0}, src1 = {hi=4.0, lo=3.0}
|
||||
# result = 1.0*3.0 + 2.0*4.0 + 0 = 3 + 8 = 11.0
|
||||
src0 = (f32_to_f16(2.0) << 16) | f32_to_f16(1.0)
|
||||
src1 = (f32_to_f16(4.0) << 16) | f32_to_f16(3.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = i2f(st.vgpr[0][3])
|
||||
self.assertAlmostEqual(result, 11.0, places=2)
|
||||
|
||||
def test_v_dot2_f32_f16_with_accumulator(self):
|
||||
"""V_DOT2_F32_F16 with non-zero f32 accumulator."""
|
||||
# src0 = {hi=1.0, lo=1.0}, src1 = {hi=1.0, lo=1.0}, acc = 5.0
|
||||
# result = 1.0*1.0 + 1.0*1.0 + 5.0 = 7.0
|
||||
src0 = (f32_to_f16(1.0) << 16) | f32_to_f16(1.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], f2i(5.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[0]), # same as src0
|
||||
v_mov_b32_e32(v[2], s[1]),
|
||||
v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = i2f(st.vgpr[0][3])
|
||||
self.assertAlmostEqual(result, 7.0, places=2)
|
||||
|
||||
def test_v_dot2_f32_f16_negative_values(self):
|
||||
"""V_DOT2_F32_F16 with negative f16 values."""
|
||||
# src0 = {hi=-2.0, lo=3.0}, src1 = {hi=1.0, lo=2.0}
|
||||
# result = 3.0*2.0 + (-2.0)*1.0 + 0 = 6 - 2 = 4.0
|
||||
src0 = (f32_to_f16(-2.0) << 16) | f32_to_f16(3.0)
|
||||
src1 = (f32_to_f16(1.0) << 16) | f32_to_f16(2.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = i2f(st.vgpr[0][3])
|
||||
self.assertAlmostEqual(result, 4.0, places=2)
|
||||
|
||||
|
||||
class TestDot2F16F16(unittest.TestCase):
|
||||
"""Tests for V_DOT2_F16_F16 - dot product of f16 pairs producing f16."""
|
||||
|
||||
def test_v_dot2_f16_f16_basic(self):
|
||||
"""V_DOT2_F16_F16: dot product of two packed f16 pairs -> f16."""
|
||||
# src0 = {hi=2.0, lo=1.0}, src1 = {hi=3.0, lo=2.0}
|
||||
# result = 1.0*2.0 + 2.0*3.0 + 0 = 2 + 6 = 8.0 (f16)
|
||||
src0 = (f32_to_f16(2.0) << 16) | f32_to_f16(1.0)
|
||||
src1 = (f32_to_f16(3.0) << 16) | f32_to_f16(2.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot2_f16_f16(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 8.0, places=1)
|
||||
|
||||
def test_v_dot2_f16_f16_with_accumulator(self):
|
||||
"""V_DOT2_F16_F16 with non-zero f16 accumulator."""
|
||||
# src0 = {hi=1.0, lo=1.0}, src1 = {hi=1.0, lo=1.0}, acc = 3.0 (f16)
|
||||
# result = 1.0*1.0 + 1.0*1.0 + 3.0 = 5.0 (f16)
|
||||
src0 = (f32_to_f16(1.0) << 16) | f32_to_f16(1.0)
|
||||
acc = f32_to_f16(3.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[2], acc),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[0]), # same as src0
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_dot2_f16_f16(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(result, 5.0, places=1)
|
||||
|
||||
|
||||
class TestSignedDotProducts(unittest.TestCase):
|
||||
"""Tests for V_DOT4_I32_IU8 and V_DOT8_I32_IU4 with signed inputs."""
|
||||
|
||||
def test_v_dot4_i32_iu8_signed_both(self):
|
||||
"""V_DOT4_I32_IU8 with both inputs signed (neg=0b011)."""
|
||||
# src0 = {-1, -2, 3, 4} as i8 = {0xff, 0xfe, 0x03, 0x04}
|
||||
# src1 = {1, 1, 1, 1} as i8
|
||||
# result = (-1)*1 + (-2)*1 + 3*1 + 4*1 = -1 - 2 + 3 + 4 = 4
|
||||
src0 = (0xff << 24) | (0xfe << 16) | (0x03 << 8) | 0x04 # -1, -2, 3, 4
|
||||
src1 = 0x01010101 # 1, 1, 1, 1
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b011), # both signed
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][3]
|
||||
# Result is i32, interpret as signed
|
||||
if result >= 0x80000000:
|
||||
result = result - 0x100000000
|
||||
self.assertEqual(result, 4)
|
||||
|
||||
def test_v_dot4_i32_iu8_src0_signed(self):
|
||||
"""V_DOT4_I32_IU8 with only src0 signed (neg=0b001)."""
|
||||
# src0 = {-1, -1, -1, -1} as i8 = {0xff, 0xff, 0xff, 0xff}
|
||||
# src1 = {2, 2, 2, 2} as u8
|
||||
# result = (-1)*2 + (-1)*2 + (-1)*2 + (-1)*2 = -8
|
||||
src0 = 0xffffffff # -1, -1, -1, -1 (as i8)
|
||||
src1 = 0x02020202 # 2, 2, 2, 2 (as u8)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b001), # src0 signed
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][3]
|
||||
if result >= 0x80000000:
|
||||
result = result - 0x100000000
|
||||
self.assertEqual(result, -8)
|
||||
|
||||
def test_v_dot4_i32_iu8_src1_signed(self):
|
||||
"""V_DOT4_I32_IU8 with only src1 signed (neg=0b010)."""
|
||||
# src0 = {2, 2, 2, 2} as u8
|
||||
# src1 = {-1, -1, -1, -1} as i8 = {0xff, 0xff, 0xff, 0xff}
|
||||
# result = 2*(-1) + 2*(-1) + 2*(-1) + 2*(-1) = -8
|
||||
src0 = 0x02020202 # 2, 2, 2, 2 (as u8)
|
||||
src1 = 0xffffffff # -1, -1, -1, -1 (as i8)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b010), # src1 signed
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][3]
|
||||
if result >= 0x80000000:
|
||||
result = result - 0x100000000
|
||||
self.assertEqual(result, -8)
|
||||
|
||||
def test_v_dot4_i32_iu8_unsigned_as_reference(self):
|
||||
"""V_DOT4_I32_IU8 with both unsigned (neg=0) - same as V_DOT4_U32_U8."""
|
||||
# src0 = {0xff, 0xff, 0xff, 0xff} = 255 each as u8
|
||||
# src1 = {1, 1, 1, 1}
|
||||
# result = 255*1 + 255*1 + 255*1 + 255*1 = 1020
|
||||
src0 = 0xffffffff
|
||||
src1 = 0x01010101
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0), # both unsigned
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][3], 1020)
|
||||
|
||||
def test_v_dot8_i32_iu4_signed_both(self):
|
||||
"""V_DOT8_I32_IU4 with both inputs signed (neg=0b011)."""
|
||||
# src0 = 8 nibbles: {-1, -2, 3, 4, -1, -2, 3, 4} as i4
|
||||
# i4 -1 = 0xf, -2 = 0xe, 3 = 0x3, 4 = 0x4
|
||||
# src0 = 0xfe34fe34
|
||||
# src1 = {1, 1, 1, 1, 1, 1, 1, 1} as i4 = 0x11111111
|
||||
# result = 2 * ((-1)*1 + (-2)*1 + 3*1 + 4*1) = 2 * 4 = 8
|
||||
src0 = 0xfe34fe34
|
||||
src1 = 0x11111111
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot8_i32_iu4(v[3], v[0], v[1], v[2], neg=0b011), # both signed
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][3]
|
||||
if result >= 0x80000000:
|
||||
result = result - 0x100000000
|
||||
self.assertEqual(result, 8)
|
||||
|
||||
def test_v_dot8_i32_iu4_all_negative(self):
|
||||
"""V_DOT8_I32_IU4 with all negative signed values."""
|
||||
# src0 = 8 nibbles all -1 (0xf) = 0xffffffff
|
||||
# src1 = 8 nibbles all 1 = 0x11111111
|
||||
# result = 8 * ((-1)*1) = -8
|
||||
src0 = 0xffffffff # all -1 as i4
|
||||
src1 = 0x11111111 # all 1
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot8_i32_iu4(v[3], v[0], v[1], v[2], neg=0b011), # both signed
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][3]
|
||||
if result >= 0x80000000:
|
||||
result = result - 0x100000000
|
||||
self.assertEqual(result, -8)
|
||||
|
||||
|
||||
class TestPkMinMaxF16(unittest.TestCase):
|
||||
"""Tests for V_PK_MIN_F16 and V_PK_MAX_F16."""
|
||||
|
||||
def test_v_pk_min_f16_basic(self):
|
||||
"""V_PK_MIN_F16: packed min of two f16 pairs."""
|
||||
# src0 = {hi=3.0, lo=1.0}, src1 = {hi=2.0, lo=4.0}
|
||||
# result = {min(3,2)=2, min(1,4)=1}
|
||||
src0 = (f32_to_f16(3.0) << 16) | f32_to_f16(1.0)
|
||||
src1 = (f32_to_f16(2.0) << 16) | f32_to_f16(4.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_pk_min_f16(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 1.0, delta=0.01)
|
||||
self.assertAlmostEqual(hi, 2.0, delta=0.01)
|
||||
|
||||
def test_v_pk_max_f16_basic(self):
|
||||
"""V_PK_MAX_F16: packed max of two f16 pairs."""
|
||||
# src0 = {hi=3.0, lo=1.0}, src1 = {hi=2.0, lo=4.0}
|
||||
# result = {max(3,2)=3, max(1,4)=4}
|
||||
src0 = (f32_to_f16(3.0) << 16) | f32_to_f16(1.0)
|
||||
src1 = (f32_to_f16(2.0) << 16) | f32_to_f16(4.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], src0),
|
||||
s_mov_b32(s[1], src1),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_pk_max_f16(v[2], v[0], v[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 4.0, delta=0.01)
|
||||
self.assertAlmostEqual(hi, 3.0, delta=0.01)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -294,7 +294,7 @@ class TestAllPcode(unittest.TestCase):
|
|||
'ADDR': u32(), 'ADDR_BASE': u32(), 'TADDR': u32(), 'DATA': u32(), 'DATA0': u32(), 'DATA1': u32(), 'DATA2': u32(),
|
||||
'VDATA': u32(), 'VDATA0': u32(), 'VDATA1': u32(), 'VDATA2': u32(), 'VDATA3': u32(),
|
||||
'OPSEL': u32(), 'OPSEL_HI': u32(), 'NEG': u32(), 'NEG_HI': u32(), 'CLAMP': u32(),
|
||||
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'WAVE_STATUS': u32(),
|
||||
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(), 'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
|
||||
'MAX_FLOAT_F32': u32(0x7f7fffff), 'Unsigned': u32(1), 'clampedLOD': u32(),
|
||||
'_lds': lds, '_vmem': lds, '_active': UOp.const(dtypes.bool, True)}
|
||||
|
||||
|
|
|
|||
|
|
@ -471,7 +471,7 @@ THREADS = 128
|
|||
|
||||
def test_matmul():
|
||||
dev = Device[Device.DEFAULT]
|
||||
print(f"Device arch: {dev.arch}")
|
||||
print(f"Device arch: {dev.renderer.arch}")
|
||||
|
||||
if getenv("STOCK", 0):
|
||||
# Load the stock kernel from amd_seb/kernel8_batched_gmem.s
|
||||
|
|
@ -479,7 +479,7 @@ def test_matmul():
|
|||
asm = stock_path.read_text()
|
||||
print(f"Loaded stock kernel from {stock_path}")
|
||||
else:
|
||||
asm = build_kernel(dev.arch)
|
||||
asm = build_kernel(dev.renderer.arch)
|
||||
|
||||
binary = dev.compiler.compile(asm)
|
||||
print(f"Compiled! Binary size: {len(binary)} bytes")
|
||||
|
|
|
|||
11517
extra/gemm/asm/cdna/asm.py
Normal file
11517
extra/gemm/asm/cdna/asm.py
Normal file
File diff suppressed because it is too large
Load diff
95
extra/gemm/asm/cdna/gemm.py
Normal file
95
extra/gemm/asm/cdna/gemm.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
import atexit, functools
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.helpers import getenv, all_same, dedup
|
||||
from extra.gemm.asm.cdna.asm import build_kernel, GEMM_ARGS
|
||||
|
||||
# ** CDNA4 assembly gemm
|
||||
|
||||
WORKGROUP_SIZE = 256
|
||||
|
||||
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp:
|
||||
batch, M, K = A.shape
|
||||
K2, N = B.shape[(1 if B.ndim == 3 else 0):]
|
||||
assert K == K2
|
||||
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
|
||||
gidx = UOp.special(wg, "gidx0")
|
||||
k = build_kernel(batch, M, N, K, A.dtype.base)
|
||||
sink = UOp.sink(C.base, A.base, B.base, lidx, gidx,
|
||||
arg=KernelInfo(name=k.name, estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
|
||||
binary = HIPCompiler(arch).compile(k.to_asm())
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
UOp(Ops.SOURCE, arg=k.to_text()), UOp(Ops.BINARY, arg=binary)))
|
||||
|
||||
counters = {"used":0, "todos":[]}
|
||||
def todo(msg:str) -> bool: counters["todos"].append(msg); return False
|
||||
atexit.register(lambda: print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used'))
|
||||
|
||||
def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
|
||||
if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}")
|
||||
if a.dtype not in {dtypes.bfloat16, dtypes.float16}: return todo(f"only bfloat16/float16, got {a.dtype}")
|
||||
# only sharding on the batch is tested, others might work too
|
||||
if isinstance(a.device, tuple) and not (a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None):
|
||||
return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}")
|
||||
batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
|
||||
N = b.shape[1]
|
||||
if isinstance(a.device, tuple): batch //= len(a.device)
|
||||
if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
|
||||
if (key:=(M, N, K)) not in GEMM_ARGS: return todo(f"GEMM shape not supported {key}")
|
||||
return True
|
||||
|
||||
# ** UOp gemm to test Tensor.custom_kernel multi and backward correctness on non cdna4
|
||||
# note: this can be removed after we have GEMM on mixins
|
||||
|
||||
def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
||||
M, K = A.shape[0]*A.shape[1], A.shape[2]
|
||||
K2, N = B.shape[(1 if B.ndim == 3 else 0):]
|
||||
assert K == K2
|
||||
m = UOp.range(M, 1, AxisType.LOOP)
|
||||
n = UOp.range(N, 2, AxisType.LOOP)
|
||||
k = UOp.range(K, 0, AxisType.REDUCE)
|
||||
mul = (A.index((m*UOp.const(dtypes.index, K)+k))*B.index((k*UOp.const(dtypes.index, N)+n))).cast(dtypes.float32)
|
||||
red = mul.reduce(k, arg=Ops.ADD, dtype=dtypes.float32).cast(C.dtype.base)
|
||||
store = C.index((m*UOp.const(dtypes.index, N)+n), ptr=True).store(red).end(m, n)
|
||||
return store.sink(arg=KernelInfo(name=f'uop_gemm_{M}_{N}_{K}'))
|
||||
|
||||
# ** backward gemm, might use the asm gemm
|
||||
|
||||
def custom_gemm_bw(gradient:UOp, kernel:UOp):
|
||||
out, a, b = kernel.src
|
||||
assert all_same([gradient.device, a.device, b.device, out.device])
|
||||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
grad_a = (g_t @ b_t.T).uop
|
||||
a_T = a_t.transpose(-2, -1)
|
||||
a_T = a_T.reshape(*a_T.shape[:-1], 1, a_T.shape[-1])
|
||||
g_r = g_t.reshape(*g_t.shape[:-2], 1, *g_t.shape[-2:]).transpose(-1, -2)
|
||||
grad_b = (a_T * g_r).sum((-1, 0)).uop
|
||||
return (None, grad_a, grad_b)
|
||||
|
||||
# ** main gemm function
|
||||
|
||||
def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
|
||||
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
|
||||
counters["used"] += 1
|
||||
squeeze = a.ndim == 2
|
||||
if squeeze: a = a.unsqueeze(0)
|
||||
|
||||
batch, M, K = a.shape
|
||||
N = b.shape[1]
|
||||
is_multi = isinstance(a.device, tuple)
|
||||
|
||||
if is_multi:
|
||||
out = Tensor(Tensor.empty(batch//len(a.device), M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
|
||||
else:
|
||||
out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device)
|
||||
|
||||
dname = a.device[0] if is_multi else a.device
|
||||
arch = getattr(Device[dname].renderer, "arch", None)
|
||||
if arch.startswith("gfx950") and getenv("USE_ASM", 1):
|
||||
numWG = GEMM_ARGS[(M, N, K)][0]
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=numWG, arch=arch), grad_fxn=custom_gemm_bw)[0]
|
||||
else:
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
|
||||
return out.squeeze(0) if squeeze else out
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,78 +0,0 @@
|
|||
.text
|
||||
.section .text.
|
||||
.global gemm
|
||||
.p2align 8
|
||||
.type gemm,@function
|
||||
|
||||
gemm:
|
||||
INSTRUCTIONS
|
||||
|
||||
.section .rodata,"a",@progbits
|
||||
.p2align 6, 0x0
|
||||
.amdhsa_kernel gemm
|
||||
# basic memory requirements
|
||||
.amdhsa_group_segment_fixed_size 133120
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.amdhsa_kernarg_size 28
|
||||
# register usage (RSRC1)
|
||||
.amdhsa_next_free_vgpr 504
|
||||
.amdhsa_next_free_sgpr 96
|
||||
# workgroup / workitem IDs (RSRC2)
|
||||
.amdhsa_system_sgpr_workgroup_id_x 1
|
||||
.amdhsa_system_sgpr_workgroup_id_y 1
|
||||
.amdhsa_system_sgpr_workgroup_id_z 1
|
||||
# user SGPRs, we only specify the kernel args ptr in s[0:1]
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_user_sgpr_count 2
|
||||
.amdhsa_user_sgpr_kernarg_preload_length 0
|
||||
.amdhsa_user_sgpr_kernarg_preload_offset 0
|
||||
# gfx90a / gfx940 specifics (RSRC3)
|
||||
.amdhsa_accum_offset 248
|
||||
.amdhsa_uses_dynamic_stack 0
|
||||
.amdhsa_tg_split 0
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.kernels:
|
||||
- .name: gemm
|
||||
.symbol: gemm.kd
|
||||
.args:
|
||||
- .name: C
|
||||
.address_space: global
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .name: B
|
||||
.address_space: global
|
||||
.offset: 8
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .name: A
|
||||
.address_space: global
|
||||
.offset: 16
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .name: sz
|
||||
.offset: 24
|
||||
.size: 4
|
||||
.value_kind: by_value
|
||||
.value_type: u32
|
||||
.group_segment_fixed_size: 133120
|
||||
.private_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 28
|
||||
.max_flat_workgroup_size: 256
|
||||
.sgpr_count: 88
|
||||
.sgpr_spill_count: 0
|
||||
.vgpr_count: 248
|
||||
.vgpr_spill_count: 0
|
||||
.wavefront_size: 64
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
# Run assembly on the AMD runtime and check correctness
|
||||
# VIZ=2 to profile
|
||||
import pathlib
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.engine.realize import Estimates
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
fp = pathlib.Path(__file__).parent/"gemm.s"
|
||||
|
||||
N = getenv("N", 8192)
|
||||
THREADS_PER_WG = 256
|
||||
NUM_WG = N//THREADS_PER_WG * N//THREADS_PER_WG
|
||||
|
||||
assert N % THREADS_PER_WG == 0, "N must be divisible by THREADS_PER_WG"
|
||||
|
||||
# ** generate inputs on CPU
|
||||
|
||||
scale = 10.0
|
||||
|
||||
import torch
|
||||
torch.manual_seed(0)
|
||||
A = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
|
||||
B = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
|
||||
Bt = B.t().contiguous() # transpose B for the asm gemm
|
||||
C_torch = A@B
|
||||
|
||||
# ** copy buffers to AMD
|
||||
|
||||
# input creation and validation run on the copy engine for simpler tracing
|
||||
|
||||
def from_torch(t:torch.Tensor) -> Tensor:
|
||||
return Tensor.from_blob(t.data_ptr(), t.shape, dtype=dtypes.bfloat16, device="cpu").to(Device.DEFAULT).realize()
|
||||
|
||||
C_tiny = from_torch(A) @ from_torch(B)
|
||||
C_asm = Tensor.empty_like(C_tiny)
|
||||
|
||||
# ** assembly custom kernel
|
||||
|
||||
def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
||||
lidx = UOp.special(THREADS_PER_WG, "lidx0")
|
||||
gidx = UOp.special(NUM_WG, "gidx0")
|
||||
|
||||
src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
|
||||
|
||||
sz = UOp.variable("SZ", 256, 8192)
|
||||
|
||||
sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm", estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
|
||||
|
||||
C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
|
||||
|
||||
# ** run gemms
|
||||
|
||||
sched = Tensor.schedule(C_tiny, C_asm)
|
||||
eis = [si.lower() for si in sched]
|
||||
|
||||
with Context(DEBUG=2):
|
||||
for ei in eis:
|
||||
et = ei.run({"SZ":N}, wait=True)
|
||||
print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")
|
||||
|
||||
# ** correctness
|
||||
|
||||
import ctypes
|
||||
|
||||
def torch_bf16(t:Tensor) -> torch.tensor:
|
||||
asm_out = t.to("cpu").realize().uop.buffer._buf
|
||||
buf = (ctypes.c_uint16*C_asm.uop.size).from_address(asm_out.va_addr)
|
||||
return torch.frombuffer(buf, dtype=torch.bfloat16, count=C_asm.uop.size).reshape(C_asm.shape)
|
||||
|
||||
assert torch.allclose(torch_bf16(C_asm), C_torch, rtol=1e-2, atol=1e-3)
|
||||
assert torch.allclose(torch_bf16(C_tiny), C_torch, rtol=1e-2, atol=1e-3)
|
||||
46
extra/gemm/asm/cdna/test_asm_gemm.py
Normal file
46
extra/gemm/asm/cdna/test_asm_gemm.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.gemm.asm.cdna.gemm import asm_gemm
|
||||
|
||||
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, multi=False) -> None:
|
||||
Tensor.manual_seed(0)
|
||||
a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype)
|
||||
b_rand = Tensor.randn((K, N), dtype=dtypes.float).sub(0.5).cast(dtype)
|
||||
with Context(DEBUG=0):
|
||||
Tensor.realize(a_rand, b_rand)
|
||||
|
||||
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8)) if multi else None
|
||||
|
||||
a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
|
||||
if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None)
|
||||
tst = asm_gemm(a, b)
|
||||
tst.sum().backward()
|
||||
Tensor.realize(tst, a.grad, b.grad)
|
||||
|
||||
a_ref, b_ref = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
|
||||
if multi: a_ref, b_ref = a_ref.shard(devs, axis=0), b_ref.shard(devs, axis=None)
|
||||
with Context(ASM_GEMM=0): ref = a_ref @ b_ref
|
||||
ref.sum().backward()
|
||||
Tensor.realize(ref, a_ref.grad, b_ref.grad)
|
||||
|
||||
with Context(DEBUG=0):
|
||||
assert (tst - ref).square().max().float().item() < 1e-6, "forward mismatch"
|
||||
assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
|
||||
assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
|
||||
|
||||
class TestGemm(unittest.TestCase):
|
||||
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
|
||||
|
||||
def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, multi=True)
|
||||
def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, multi=True)
|
||||
def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, multi=True)
|
||||
def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, multi=True)
|
||||
def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, multi=True)
|
||||
def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, multi=True)
|
||||
def test_gemm_unsupported(self):
|
||||
with self.assertRaisesRegex(AssertionError, "shape not supported"):
|
||||
verify_asm_gemm(8, 8192, 1024, 4096, multi=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -25,7 +25,7 @@ class StallReason(enum.IntEnum):
|
|||
OTHER = 11 # misc, dispatch_stall
|
||||
SLEEPING = 12 # sleeping
|
||||
|
||||
STALL_KEY_MAP: dict[int, StallReason] = {
|
||||
STALL_KEY_MAP_AMPERE: dict[int, StallReason] = {
|
||||
1: StallReason.MEMORY_THROTTLE, 15: StallReason.MEMORY_THROTTLE,
|
||||
2: StallReason.CONSTANT_MEMORY,
|
||||
3: StallReason.SYNC,
|
||||
|
|
@ -37,14 +37,25 @@ STALL_KEY_MAP: dict[int, StallReason] = {
|
|||
18: StallReason.NONE,
|
||||
}
|
||||
|
||||
STALL_KEY_MAP_BLACKWELL: dict[int, StallReason] = {
|
||||
0x01: StallReason.MEMORY_THROTTLE, 0x0e: StallReason.MEMORY_THROTTLE,
|
||||
0x02: StallReason.SYNC,
|
||||
0x05: StallReason.INST_FETCH, 0x0a: StallReason.INST_FETCH,
|
||||
0x06: StallReason.EXEC_DEPENDENCY, 0x09: StallReason.EXEC_DEPENDENCY,
|
||||
0x08: StallReason.MEMORY_DEPENDENCY,
|
||||
0x0b: StallReason.PIPE_BUSY, 0x0f: StallReason.PIPE_BUSY,
|
||||
0x10: StallReason.OTHER, 0x13: StallReason.OTHER,
|
||||
0x11: StallReason.NONE,
|
||||
}
|
||||
|
||||
# Lookup table for extracting sample bytes from 32-byte packet (bytes 0-3, 8-31, skipping header at 4-7)
|
||||
LOOKUP_28B = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# AMPERE PACKET DEFINITIONS (8-byte aligned)
|
||||
# PACKET HEADER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Lookup table for extracting sample bytes from 32-byte packet
|
||||
LOOKUP_8B = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
|
||||
|
||||
class PMAHeaderAmpere8B(PacketType):
|
||||
class PMAHeader(PacketType):
|
||||
num_bytes = bits[4:0] # number of sample bytes in this packet
|
||||
tpc_id_lo = bits[15:8] # TPC identifier low 8 bits
|
||||
tpc_id_hi = bits[27:25] # TPC identifier high 3 bits
|
||||
|
|
@ -52,33 +63,57 @@ class PMAHeaderAmpere8B(PacketType):
|
|||
@property
|
||||
def tpc_id(self) -> int: return self.tpc_id_lo | (self.tpc_id_hi << 8)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# 8-BYTE SAMPLE FORMAT (Ampere/Ada/Hopper)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class PMASampleAmpere8B(PacketType):
|
||||
pc_raw = bits[44:0] # raw PC value (actual PC = pc_raw << 4)
|
||||
stall_lo = bits[47:45] # stall key low 3 bits
|
||||
stall_hi = bits[49:48] # stall key high 2 bits
|
||||
wave_id = bits[55:50] # warp/wave identifier (0-63)
|
||||
active = bits[62:62] # active flag (warp was executing)
|
||||
pc_raw = bits[44:0] # raw PC value (pc_offset = pc_raw << 4)
|
||||
stall_key = bits[49:45] # stall reason key
|
||||
wave_id = bits[55:50] # warp/wave identifier
|
||||
active = bits[62:62] # 1 if warp was executing, 0 if scheduled but not issued
|
||||
@property
|
||||
def pc_offset(self) -> int: return self.pc_raw << 4
|
||||
@property
|
||||
def stall_key(self) -> int: return self.stall_lo | (self.stall_hi << 3)
|
||||
@property
|
||||
def stall_reason(self) -> StallReason: return STALL_KEY_MAP.get(self.stall_key, StallReason.OTHER)
|
||||
def stall_reason(self) -> StallReason: return STALL_KEY_MAP_AMPERE.get(self.stall_key, StallReason.OTHER)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# 9-BYTE SAMPLE FORMAT (Blackwell+)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class PMASampleBlackwell9B(PacketType):
|
||||
stall_key = bits[5:0] # stall reason key
|
||||
pc_raw = bits[60:8] # raw PC value (pc_offset = pc_raw << 4)
|
||||
wave_hi = bits[7:6] # wave_id high 2 bits
|
||||
wave_lo = bits[71:68] # wave_id low 4 bits
|
||||
active = bits[67:67] # 1 if warp was executing, 0 if scheduled but not issued
|
||||
@property
|
||||
def pc_offset(self) -> int: return self.pc_raw << 4
|
||||
@property
|
||||
def stall_reason(self) -> StallReason: return STALL_KEY_MAP_BLACKWELL.get(self.stall_key, StallReason.OTHER)
|
||||
@property
|
||||
def wave_id(self) -> int: return (self.wave_hi << 4) | self.wave_lo
|
||||
|
||||
PMASample = PMASampleAmpere8B|PMASampleBlackwell9B
|
||||
|
||||
def decode(data: bytes, sm_version: int = 0x800) -> Iterator[tuple[PMASample, int]]:
|
||||
use_9byte = sm_version >= 0xa04
|
||||
record_size = 9 if use_9byte else 8
|
||||
sample_cls = PMASampleBlackwell9B if use_9byte else PMASampleAmpere8B
|
||||
|
||||
def decode(data: bytes) -> Iterator[tuple[PMASampleAmpere8B, int]]:
|
||||
tpc_state: dict[int, list[int]] = collections.defaultdict(list)
|
||||
for pkt_idx in range(len(data) // 32):
|
||||
pkt = data[pkt_idx * 32:(pkt_idx + 1) * 32]
|
||||
hdr = PMAHeaderAmpere8B.from_raw(int.from_bytes(pkt[4:8], 'little'))
|
||||
hdr = PMAHeader.from_raw(int.from_bytes(pkt[4:8], 'little'))
|
||||
|
||||
if hdr.dropped: tpc_state[hdr.tpc_id].clear()
|
||||
|
||||
for i in range(hdr.num_bytes):
|
||||
tpc_state[hdr.tpc_id].append(pkt[LOOKUP_8B[i]])
|
||||
tpc_state[hdr.tpc_id].append(pkt[LOOKUP_28B[i]])
|
||||
|
||||
while len(tpc_state[hdr.tpc_id]) >= 8:
|
||||
yield PMASampleAmpere8B.from_raw(int.from_bytes(bytes(tpc_state[hdr.tpc_id][:8]), 'little')), hdr.tpc_id
|
||||
del tpc_state[hdr.tpc_id][:8]
|
||||
while len(tpc_state[hdr.tpc_id]) >= record_size:
|
||||
yield sample_cls.from_raw(int.from_bytes(bytes(tpc_state[hdr.tpc_id][:record_size]), 'little')), hdr.tpc_id
|
||||
del tpc_state[hdr.tpc_id][:record_size]
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# CLI
|
||||
|
|
@ -90,11 +125,11 @@ STALL_COLORS = {
|
|||
StallReason.PIPE_BUSY: "yellow", StallReason.MEMORY_THROTTLE: "RED", StallReason.OTHER: "white",
|
||||
}
|
||||
|
||||
def decode_tpc_id(tpc_id: int) -> tuple[int, int, int]:
|
||||
def decode_tpc_id(tpc_id:int) -> tuple[int, int, int]:
|
||||
# NOTE: valid only for ops_nv, cuda encoding is different
|
||||
return (tpc_id >> 5, (tpc_id >> 1) & 0xf, tpc_id & 1)
|
||||
|
||||
def print_samples(samples: list[tuple[PMASampleAmpere8B, int]]) -> None:
|
||||
def print_samples(samples:list[tuple[PMASample, int]]) -> None:
|
||||
if not samples: return
|
||||
base_pc = min(s.pc_offset for s, _ in samples)
|
||||
for s, tpc_id in samples:
|
||||
|
|
@ -102,40 +137,57 @@ def print_samples(samples: list[tuple[PMASampleAmpere8B, int]]) -> None:
|
|||
stall_str = colored(f"{s.stall_reason.name:17}", STALL_COLORS.get(s.stall_reason, "white"))
|
||||
print(f"pc=0x{s.pc_offset - base_pc:06x} {stall_str} ev={s.stall_key:2d} active={s.active} wave={s.wave_id:2d} gpc={gpc} tpc={tpc} sm={sm}")
|
||||
|
||||
def print_packets(data: bytes) -> None:
|
||||
def print_packets(data:bytes, sm_version:int=0x800) -> None:
|
||||
record_size = 9 if sm_version >= 0x890 else 8
|
||||
tpc_state: dict[int, list[int]] = collections.defaultdict(list)
|
||||
for i in range(len(data) // 32):
|
||||
pkt = data[i * 32:(i + 1) * 32]
|
||||
hdr = PMAHeaderAmpere8B.from_raw(int.from_bytes(pkt[4:8], 'little'))
|
||||
print(f"Pkt {i:3d}: tpc={hdr.tpc_id} bytes={hdr.num_bytes} drop={hdr.dropped} | {pkt.hex()}")
|
||||
hdr = PMAHeader.from_raw(int.from_bytes(pkt[4:8], 'little'))
|
||||
if hdr.dropped: tpc_state[hdr.tpc_id].clear()
|
||||
for j in range(hdr.num_bytes): tpc_state[hdr.tpc_id].append(pkt[LOOKUP_28B[j]])
|
||||
# Show complete records extracted from this packet
|
||||
records = []
|
||||
while len(tpc_state[hdr.tpc_id]) >= record_size:
|
||||
records.append(bytes(tpc_state[hdr.tpc_id][:record_size]).hex())
|
||||
del tpc_state[hdr.tpc_id][:record_size]
|
||||
leftover = len(tpc_state[hdr.tpc_id])
|
||||
print(f"Pkt {i:3d}: tpc={hdr.tpc_id:4d} n={hdr.num_bytes:2d} drop={hdr.dropped} left={leftover} | {' '.join(records)}")
|
||||
|
||||
def print_aggregated(samples: list[tuple[PMASampleAmpere8B, int]]) -> None:
|
||||
def print_aggregated(samples:list[tuple[PMASample, int]]) -> None:
|
||||
if not samples: return
|
||||
base_pc = min(s.pc_offset for s, _ in samples)
|
||||
counter: collections.Counter[tuple[int, int]] = collections.Counter((s.pc_offset, s.stall_key) for s, _ in samples)
|
||||
counter: collections.Counter[tuple[int, StallReason]] = collections.Counter((s.pc_offset, s.stall_reason) for s, _ in samples)
|
||||
print(f"\nAggregated samples (base_pc=0x{base_pc:x}):")
|
||||
for (pc, key), cnt in sorted(counter.items()):
|
||||
reason = STALL_KEY_MAP.get(key, StallReason.OTHER)
|
||||
for (pc, reason), cnt in sorted(counter.items()):
|
||||
stall_str = colored(f"{reason.name:17}", STALL_COLORS.get(reason, "white"))
|
||||
print(f" pc=0x{pc - base_pc:06x} {stall_str} ev={key:2d} samples={cnt:4d}")
|
||||
print(f" pc=0x{pc - base_pc:06x} {stall_str} samples={cnt:4d}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys, pickle
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print(__doc__)
|
||||
print("Usage: python decode.py <pkl_file> [--raw] [--sm=0xNNN]")
|
||||
sys.exit(1)
|
||||
|
||||
with open(sys.argv[1], "rb") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
if isinstance(data, dict): dumps = list(enumerate(data["pma_raw_dumps"]))
|
||||
else: dumps = [(i, e.blob) for i, e in enumerate(e for e in data if type(e).__name__ == "ProfilePMAEvent")]
|
||||
if isinstance(data, dict):
|
||||
sm_version = 0x800 # default to Ampere
|
||||
for arg in sys.argv:
|
||||
if arg.startswith("--sm="): sm_version = int(arg[5:], 0)
|
||||
dumps = [(i, x, sm_version) for i, x in enumerate(data["pma_raw_dumps"])]
|
||||
else:
|
||||
devs = {e.device: e for e in data if type(e).__name__ == "ProfileDeviceEvent"}
|
||||
dumps = []
|
||||
for i, e in enumerate(e for e in data if type(e).__name__ == "ProfilePMAEvent"):
|
||||
dumps.append((i, e.blob, devs[e.device].props.get('sm_version', 0x800)))
|
||||
|
||||
for dump_idx, raw in dumps:
|
||||
for dump_idx, raw, sm_ver in dumps:
|
||||
print(f"\n{'='*60}\nDump {dump_idx} ({len(raw)} bytes, {len(raw)//32} packets)\n{'='*60}")
|
||||
if "--raw" in sys.argv: print_packets(raw)
|
||||
if "--raw" in sys.argv: print_packets(raw, sm_ver)
|
||||
else:
|
||||
samples = list(decode(raw))
|
||||
samples = list(decode(raw, sm_ver))
|
||||
print(f"\nDecoded {len(samples)} samples:")
|
||||
print_samples(samples)
|
||||
print_aggregated(samples)
|
||||
|
|
|
|||
|
|
@ -6,13 +6,17 @@ from extra.nv_pma.decode import decode
|
|||
from tinygrad.helpers import DEBUG
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).parent.parent / "examples"
|
||||
EXAMPLES_5090_DIR = Path(__file__).parent.parent / "examples_5090"
|
||||
|
||||
def decode_and_aggregate(raw_dumps: list[bytes]) -> Counter[tuple[int, int]]:
|
||||
"""Decode all PMA buffers and aggregate by (relative_pc, stall_reason)."""
|
||||
all_samples = [s for raw in raw_dumps for s, _ in decode(raw)]
|
||||
if not all_samples: return Counter()
|
||||
base_pc = min(s.pc_offset for s in all_samples)
|
||||
return Counter((s.pc_offset - base_pc, int(s.stall_reason)) for s in all_samples)
|
||||
def decode_and_aggregate(raw_dumps: list[bytes], sm_version: int = 0x800) -> Counter[tuple[int, int]]:
|
||||
"""Decode all PMA buffers and aggregate by (relative_pc, stall_reason). Each dump is normalized separately."""
|
||||
result: Counter[tuple[int, int]] = Counter()
|
||||
for raw in raw_dumps:
|
||||
samples = [s for s, _ in decode(raw, sm_version)]
|
||||
if not samples: continue
|
||||
base_pc = min(s.pc_offset for s in samples)
|
||||
result += Counter((s.pc_offset - base_pc, int(s.stall_reason)) for s in samples)
|
||||
return result
|
||||
|
||||
def cupti_to_counter(cupti_records: list[dict]) -> Counter[tuple[int, int]]:
|
||||
"""Convert CUPTI records to Counter[(pcOffset, stallReason)]."""
|
||||
|
|
@ -22,8 +26,8 @@ def cupti_to_counter(cupti_records: list[dict]) -> Counter[tuple[int, int]]:
|
|||
return counter
|
||||
|
||||
class TestNVProf(unittest.TestCase):
|
||||
def _test_example(self, name: str):
|
||||
pkl_file = EXAMPLES_DIR / f"{name}.pkl"
|
||||
def _test_example(self, name: str, sm_version: int = 0x800, examples_dir: Path = EXAMPLES_DIR):
|
||||
pkl_file = examples_dir / f"{name}.pkl"
|
||||
if not pkl_file.exists():
|
||||
self.skipTest(f"Example data not found: {pkl_file}. Run collect.py first.")
|
||||
|
||||
|
|
@ -31,7 +35,7 @@ class TestNVProf(unittest.TestCase):
|
|||
data = pickle.load(f)
|
||||
|
||||
self.assertEqual(data["test_name"], name)
|
||||
pma_agg = decode_and_aggregate(data["pma_raw_dumps"])
|
||||
pma_agg = decode_and_aggregate(data["pma_raw_dumps"], sm_version)
|
||||
cupti_agg = cupti_to_counter(data["cupti_pc_samples"])
|
||||
|
||||
if DEBUG >= 2:
|
||||
|
|
@ -45,6 +49,7 @@ class TestNVProf(unittest.TestCase):
|
|||
|
||||
self.assertEqual(pma_agg, cupti_agg, f"PMA: {dict(pma_agg)}\nCUPTI: {dict(cupti_agg)}")
|
||||
|
||||
# Ampere tests (8-byte format)
|
||||
def test_decode_test_plus(self): self._test_example("test_plus")
|
||||
def test_decode_test_reduce_sum(self): self._test_example("test_reduce_sum")
|
||||
def test_decode_test_broadcast(self): self._test_example("test_broadcast")
|
||||
|
|
@ -54,5 +59,18 @@ class TestNVProf(unittest.TestCase):
|
|||
def test_decode_test_conv2d(self): self._test_example("test_conv2d")
|
||||
def test_decode_test_large_matmul(self): self._test_example("test_large_matmul")
|
||||
|
||||
# Blackwell/5090 tests (9-byte format)
|
||||
def test_5090_test_plus(self): self._test_example("test_plus", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_plus_big(self): self._test_example("test_plus_big", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_broadcast(self): self._test_example("test_broadcast", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_matmul(self): self._test_example("test_matmul", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_large_matmul(self): self._test_example("test_large_matmul", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_reduce_sum(self): self._test_example("test_reduce_sum", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_reduce_max(self): self._test_example("test_reduce_max", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_elementwise_chain(self): self._test_example("test_elementwise_chain", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_conv2d(self): self._test_example("test_conv2d", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_exp(self): self._test_example("test_exp", 0xa04, EXAMPLES_5090_DIR)
|
||||
def test_5090_test_softmax(self): self._test_example("test_softmax", 0xa04, EXAMPLES_5090_DIR)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from extra.thunder.tiny.tk.kernel import Kernel
|
|||
from extra.thunder.tiny.tk.tiles import GL, TileLayout
|
||||
|
||||
NUM_WORKERS = 1
|
||||
Q_BLOCK_SIZE = 16
|
||||
KV_BLOCK_SIZE = 16
|
||||
Q_BLOCK_SIZE = 32
|
||||
KV_BLOCK_SIZE = 32
|
||||
|
||||
def _sharded_empty(shape:Tensor, ref:Tensor, axis:int|None) -> Tensor:
|
||||
if not isinstance(ref.device, tuple): return Tensor.empty(*shape, dtype=ref.dtype, device=ref.device)
|
||||
|
|
@ -70,10 +70,10 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
mask_reg = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.float32)
|
||||
mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
|
||||
|
||||
max_vec_last = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
|
||||
max_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
|
||||
norm_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
|
||||
scale_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
|
||||
max_vec_last = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
|
||||
max_vec = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
|
||||
norm_vec = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
|
||||
scale_vec = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
|
||||
|
||||
max_vec = warp.neg_inf(max_vec)
|
||||
norm_vec = warp.zero(norm_vec)
|
||||
|
|
@ -105,7 +105,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
|
||||
# softmax
|
||||
max_vec_last = warp.copy(max_vec_last.after(kv_idx), max_vec)
|
||||
max_vec = warp.row_reduce(max_vec.after(max_vec_last), att_block, lambda a, b: a.maximum(b), init_value=-math.inf)
|
||||
max_vec = warp.col_reduce(max_vec.after(max_vec_last), att_block, lambda a, b: a.maximum(b), init_value=-math.inf)
|
||||
|
||||
scale_vec = warp.map(scale_vec.after(max_vec_last, max_vec), lambda _, idx: max_vec_last[*idx] - max_vec[*idx])
|
||||
scale_vec = scale_vec.exp2()
|
||||
|
|
@ -116,7 +116,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
att_block -= max_vec
|
||||
att_block = att_block.exp2()
|
||||
|
||||
norm_vec = warp.row_reduce(norm_vec.after(scale_vec), att_block, lambda a, b: a + b)
|
||||
norm_vec = warp.col_reduce(norm_vec.after(scale_vec), att_block, lambda a, b: a + b)
|
||||
|
||||
# mma av
|
||||
att_block_mma = warp.copy(att_block_mma.after(kv_idx, norm_vec), att_block)
|
||||
|
|
@ -313,7 +313,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
|
||||
att_smem = warp.store(att_smem, att_block_transposed)
|
||||
att_block_row = warp.load(att_block_row, att_smem)
|
||||
dv_reg_ = warp.mma_AB(dv_reg, att_block_row, do_reg_col)
|
||||
dv_reg_ = warp.mma_AtB(dv_reg, att_block_row, do_reg_col)
|
||||
|
||||
dp_block = warp.zero(dp_block.after(g, q_idx, dv_reg_))
|
||||
dp_block = warp.mma_ABt(dp_block, v_reg, do_reg)
|
||||
|
|
@ -325,7 +325,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
|||
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
|
||||
att_smem = warp.store(att_smem, att_block_transposed)
|
||||
att_block_row = warp.load(att_block_row, att_smem)
|
||||
dk_reg = warp.mma_AB(dk_reg, att_block_row, q_reg_col)
|
||||
dk_reg = warp.mma_AtB(dk_reg, att_block_row, q_reg_col)
|
||||
dk_reg = ker.endrange(2)
|
||||
dv_reg = dv_reg.after(dk_reg)
|
||||
|
||||
|
|
|
|||
|
|
@ -163,5 +163,30 @@ class TestIndexing(unittest.TestCase):
|
|||
# at least the arange is being fused
|
||||
def test_llama_embedding_opt(self): self.test_llama_embedding(0, 1_736_704_000)
|
||||
|
||||
# NOTE: call doesn't work with SPEC=2
|
||||
@unittest.skipIf(Device.DEFAULT not in ("CPU", "AMD"), "atomics only on AMD/CPU")
|
||||
@Context(USE_ATOMICS=1, SPEC=1)
|
||||
def test_llama_8b_embedding_backward(self):
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
if Device.DEFAULT == "CPU" and not isinstance(Device["CPU"].renderer, CStyleLanguage): self.skipTest("CPU needs Clang renderer")
|
||||
vocab_size, embed_size = 1000, 128
|
||||
bs, seqlen = 4, 256
|
||||
idx = Tensor.randint(bs, seqlen, high=vocab_size)
|
||||
emb = nn.Embedding(vocab_size, embed_size)
|
||||
emb.weight = Tensor.ones(vocab_size, embed_size, requires_grad=True)
|
||||
gt = Tensor.zeros(bs, seqlen, embed_size)
|
||||
Tensor.realize(idx, emb.weight, gt)
|
||||
GlobalCounters.reset()
|
||||
loss = (emb(idx)-gt).square().sum()
|
||||
loss.backward()
|
||||
emb.weight.grad.realize()
|
||||
bwd_ops = GlobalCounters.global_ops
|
||||
print(f"embedding bwd: {GlobalCounters.kernel_count} kernels, {bwd_ops:,} ops")
|
||||
self.assertLess(bwd_ops, bs*seqlen*embed_size*20, f"backward ops {bwd_ops:,} should be less than 20 per with atomic scatter-add")
|
||||
# correctness check
|
||||
expected_grad = np.zeros((vocab_size, embed_size), dtype=np.float32)
|
||||
for i in idx.flatten().numpy(): expected_grad[i] += 2
|
||||
np.testing.assert_allclose(emb.weight.grad.numpy(), expected_grad, rtol=1e-5, atol=1e-5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -247,5 +247,39 @@ class TestCustomKernel(unittest.TestCase):
|
|||
err = (O_custom - O_ref).square().max()
|
||||
self.assertLess(err.item(), 1e-6)
|
||||
|
||||
def test_multi_after_schedule_order(self):
|
||||
"""Test correct scheduling order when custom_kernel has multiple outputs.
|
||||
|
||||
custom_kernel with 4 arguments creates 4 AFTERs from the same kernel.
|
||||
The custom_kernel depends on both A2 and B2, so it must be scheduled after both.
|
||||
E only depends on A2, so E can run before custom_kernel finishes waiting for B2.
|
||||
|
||||
Expected schedule order: [A2, B2, E, custom_addmul, final_sum]
|
||||
The custom_addmul kernel should be at index 3.
|
||||
"""
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
|
||||
A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
|
||||
A2 = (A + 1).contiguous() # kernel 0: depends on A
|
||||
B2 = (B * 2).contiguous() # kernel 1: depends on B
|
||||
C, D = Tensor.empty(4, 4), Tensor.empty(4, 4)
|
||||
C, D, _, _ = Tensor.custom_kernel(C, D, A2, B2, fxn=custom_elementwise_addmul_kernel) # depends on A2 AND B2
|
||||
E = (A2 * 3).contiguous() # kernel 2: depends only on A2
|
||||
result = (C + D + E).sum() # kernel 3: custom_addmul, then kernel 4: sum
|
||||
|
||||
big_sink = result.uop.sink()
|
||||
tensor_map = get_rangeify_map(big_sink)
|
||||
sched_sink = big_sink.substitute(tensor_map)
|
||||
schedule, _ = create_schedule(sched_sink)
|
||||
|
||||
# Find the custom_addmul kernel position
|
||||
custom_idx = next((i for i, item in enumerate(schedule)
|
||||
if hasattr(item.ast, "arg") and hasattr(item.ast.arg, "name")
|
||||
and "custom_addmul" in item.ast.arg.name), None)
|
||||
|
||||
self.assertIsNotNone(custom_idx, "custom_addmul kernel not found in schedule")
|
||||
self.assertEqual(custom_idx, 3, f"custom_addmul should be at index 3, got {custom_idx}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import unittest, math
|
||||
import contextlib, unittest, math
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Any, List
|
||||
|
|
@ -7,7 +7,8 @@ from tinygrad.helpers import getenv, DEBUG, CI, EMULATED_DTYPES
|
|||
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad import Context, Device, Tensor, dtypes
|
||||
from tinygrad.uop import Ops
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import rand_for_dtype
|
||||
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
|
||||
|
|
@ -336,18 +337,37 @@ class TestUint16DType(TestDType):
|
|||
class TestInt32DType(TestDType): DTYPE = dtypes.int32
|
||||
class TestUint32DType(TestDType): DTYPE = dtypes.uint32
|
||||
|
||||
class TestInt64DType(TestDType):
|
||||
DTYPE = dtypes.int64
|
||||
class TestInt64DType(TestDType): DTYPE = dtypes.int64
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
class TestEmulatedInt64DType(TestInt64DType):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
class TestUint64DType(TestDType):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
DTYPE = dtypes.uint64
|
||||
def test_uint64_load(self):
|
||||
assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
class TestEmulatedUInt64DType(TestUint64DType):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.stack = contextlib.ExitStack()
|
||||
cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
class TestBoolDType(TestDType): DTYPE = dtypes.bool
|
||||
|
||||
class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import unittest, operator, math
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad import Context, Tensor, dtypes, Device
|
||||
from tinygrad.dtype import DType, truncate
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
|
|
@ -7,6 +7,7 @@ from tinygrad.device import is_dtype_supported
|
|||
from tinygrad.runtime.ops_python import from_storage_scalar
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
from tinygrad.uop import Ops
|
||||
import numpy as np
|
||||
import pytest
|
||||
from hypothesis import assume, given, strategies as strat, settings
|
||||
|
|
@ -165,9 +166,16 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
|
||||
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
|
||||
def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
@given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
|
||||
@Context(EMULATED_DTYPES="long")
|
||||
def test_emulated_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)
|
||||
|
||||
@given(ht.int8, ht.int8, strat.sampled_from(integer_binary_operations))
|
||||
def test_int8(self, a, b, op): universal_test(a, b, dtypes.int8, op)
|
||||
|
||||
|
|
@ -177,9 +185,16 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
|
||||
def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
|
||||
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
|
||||
def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
|
||||
@Context(EMULATED_DTYPES="long")
|
||||
def test_emulated_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
|
||||
|
||||
@given(ht.uint8, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint8_unary(self, a, op): universal_test_unary(a, dtypes.uint8, op)
|
||||
|
||||
|
|
@ -191,9 +206,16 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.uint32, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint32_unary(self, a, op): universal_test_unary(a, dtypes.uint32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
|
||||
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
|
||||
def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
@given(ht.uint64, strat.sampled_from(integer_unary_operations))
|
||||
@Context(EMULATED_DTYPES="long")
|
||||
def test_emulated_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)
|
||||
|
||||
@given(ht.int8, strat.sampled_from(integer_unary_operations))
|
||||
def test_int8_unary(self, a, op): universal_test_unary(a, dtypes.int8, op)
|
||||
|
||||
|
|
@ -203,9 +225,16 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.int32, strat.sampled_from(integer_unary_operations))
|
||||
def test_int32_unary(self, a, op): universal_test_unary(a, dtypes.int32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
|
||||
@given(ht.int64, strat.sampled_from(integer_unary_operations))
|
||||
def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
|
||||
|
||||
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
|
||||
@given(ht.int64, strat.sampled_from(integer_unary_operations))
|
||||
@Context(EMULATED_DTYPES="long")
|
||||
def test_emulated_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
|
||||
|
||||
@given(ht.bool, ht.bool, strat.sampled_from(((operator.add, operator.add), (operator.mul, operator.mul))))
|
||||
def test_bool(self, a, b, op): universal_test(a, b, dtypes.bool, op)
|
||||
|
||||
|
|
|
|||
|
|
@ -409,6 +409,28 @@ class TestMultiTensor(unittest.TestCase):
|
|||
|
||||
np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_embedding_backward(self, shard_weight_axis=None):
|
||||
B, T, embed_size, vocab_size = 4, 10, 20, 28
|
||||
|
||||
layer = nn.Embedding(vocab_size, embed_size)
|
||||
layer.weight.requires_grad = True
|
||||
x = Tensor(np.random.randint(0, vocab_size, (B, T), dtype=np.int32))
|
||||
z = layer(x)
|
||||
z.sum().backward()
|
||||
grad = layer.weight.grad.numpy()
|
||||
|
||||
layer_sharded = nn.Embedding(vocab_size, embed_size)
|
||||
layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=shard_weight_axis)).realize()
|
||||
layer_sharded.weight.requires_grad = True
|
||||
x_sharded = x.shard(devices_2, axis=None)
|
||||
z_shard = layer_sharded(x_sharded)
|
||||
z_shard.sum().backward()
|
||||
grad_shard = layer_sharded.weight.grad.numpy()
|
||||
|
||||
np.testing.assert_allclose(grad, grad_shard, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def test_embedding_backward_shard_weight(self): self.test_embedding_backward(shard_weight_axis=1)
|
||||
|
||||
def test_rmsnorm(self):
|
||||
B, T, embed_size = 4, 10, 20
|
||||
|
||||
|
|
@ -1251,19 +1273,20 @@ class TestMultiRamUsage(unittest.TestCase):
|
|||
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
|
||||
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
|
||||
|
||||
def _test_matmul_half(self, devs):
|
||||
def _test_matmul_half(self, dev_count:int):
|
||||
N = 32
|
||||
total_mem = {}
|
||||
devs = tuple(f"NULL:{i}" for i in range(dev_count))
|
||||
for dtype in {dtypes.float, dtypes.half}:
|
||||
GlobalCounters.reset()
|
||||
a = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=0)
|
||||
b = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=None)
|
||||
a = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=0)
|
||||
b = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=None)
|
||||
(a @ b).realize()
|
||||
total_mem[dtype] = GlobalCounters.global_mem
|
||||
self.assertEqual(total_mem[dtypes.half], total_mem[dtypes.float] // 2)
|
||||
|
||||
def test_matmul_half(self): self._test_matmul_half(devices_2)
|
||||
def test_matmul_half_alt(self): self._test_matmul_half(devices_4)
|
||||
def test_matmul_half(self): self._test_matmul_half(dev_count=2)
|
||||
def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
|
||||
|
||||
@unittest.skipIf(not_support_multi_device(), "need multi")
|
||||
class TestMultiFromUnrenderable(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -2184,6 +2184,14 @@ class TestBufferUOp(unittest.TestCase):
|
|||
run_schedule(check_schedule(a, 0))
|
||||
self.assertIsNone(a.uop.base.realized)
|
||||
|
||||
def test_unused_var_not_in_var_vals(self):
|
||||
# unused variable should not appear in var_vals even when there's other work
|
||||
a = Tensor(UOp.variable("unused", 0, 10).bind(1))
|
||||
b = Tensor.empty(3) + 1
|
||||
_, var_vals = Tensor.schedule_with_vars(a, b)
|
||||
self.assertEqual(var_vals, {})
|
||||
self.assertIsNone(a.uop.base.realized)
|
||||
|
||||
def test_view_does_not_realize(self):
|
||||
a = Tensor.randn(1, 4).expand(4, 4)
|
||||
a.realize()
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
import unittest, copy, mmap, random, math, array
|
||||
from tinygrad import Tensor, Device, dtypes, nn
|
||||
from tinygrad.tensor import _METADATA
|
||||
from tinygrad.helpers import Context, getenv, temp, mv_address
|
||||
from tinygrad.helpers import getenv, temp, mv_address
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
|
@ -796,92 +795,6 @@ class TestInferenceMode(unittest.TestCase):
|
|||
assert W.grad is None
|
||||
f(x, m, W)
|
||||
|
||||
class TestTensorMetadata(unittest.TestCase):
|
||||
def setUp(self) -> None: _METADATA.set(None)
|
||||
|
||||
# NOOPs are not included in kernel metadata
|
||||
@unittest.skip("why would this be true?")
|
||||
def test_exclude_noop_metadata(self):
|
||||
a = Tensor.rand(4, 4)*1
|
||||
self.assertEqual(a.uop.metadata[0].name, "__mul__")
|
||||
k = a.schedule()[-1]
|
||||
self.assertEqual([m.name for m in k.metadata], ["rand"])
|
||||
|
||||
# we exclude const from kernel metadata because tensor methods can share the same CONST UOp
|
||||
@unittest.skip("TODO: flaky")
|
||||
def test_exclude_const_metadata(self):
|
||||
a = Tensor.arange(4)
|
||||
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
|
||||
sched = Tensor.schedule(a, b)
|
||||
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
|
||||
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
|
||||
|
||||
def test_matmul(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
out = x.matmul(W)
|
||||
self.assertEqual(out.uop.metadata[0].name, "matmul")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "matmul")
|
||||
|
||||
def test_relu(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu()
|
||||
self.assertEqual(out.uop.metadata[0].name, "relu")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "relu")
|
||||
|
||||
@unittest.skip("this no longer works")
|
||||
def test_assign(self):
|
||||
x = Tensor.empty(10, 10).realize()
|
||||
x.assign(Tensor.ones(10, 10).contiguous())
|
||||
si = x.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "assign")
|
||||
|
||||
def test_complex(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu() * y.sigmoid()
|
||||
self.assertEqual(out.uop.metadata[0].name, "__mul__")
|
||||
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
|
||||
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
|
||||
@unittest.skip("metadata is no longer promised to be exact with schedulecache")
|
||||
def test_complex_backward(self):
|
||||
x = Tensor.rand(3, requires_grad=True).realize()
|
||||
y = Tensor.rand(3, requires_grad=True).realize()
|
||||
out = (x.relu() * y.sigmoid()).sum()
|
||||
self.assertEqual(out.uop.metadata[0].name, "sum")
|
||||
out.backward()
|
||||
self.assertEqual(x.grad.uop.metadata[0].name, "relu")
|
||||
self.assertTrue(x.grad.uop.metadata[0].backward)
|
||||
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
|
||||
self.assertTrue(y.grad.uop.metadata[0].backward)
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
#self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||
# skip numpy, this is schedule cache
|
||||
self.assertSetEqual(set(m.name for m in si.metadata if m.name != "numpy"), {"sigmoid", "relu"})
|
||||
#bw = [m for m in si.metadata if m.backward]
|
||||
#self.assertEqual(len(bw), 1)
|
||||
#self.assertEqual(bw[0].name, "sigmoid")
|
||||
|
||||
@unittest.skip("metadata is no longer promised to be exact with schedulecache")
|
||||
def test_tracemeta_0(self):
|
||||
with Context(TRACEMETA=0):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = (x.relu() * y.sigmoid()).sum()
|
||||
self.assertIsNone(out.uop.metadata)
|
||||
self.assertIsNone(out.uop.src[0].metadata)
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(si.metadata, ())
|
||||
|
||||
class TestIdxUpcast(unittest.TestCase):
|
||||
def _find_op(self, ast: UOp, op: Ops):
|
||||
if ast.op is op: return ast
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import textwrap, functools
|
|||
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.device import Compiler
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from tinygrad.viz.serve import amdgpu_cfg
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
|
|
@ -73,11 +73,11 @@ def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, compiler:Com
|
|||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
def run_asm(name:str, insts:list) -> None:
|
||||
fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT, compiler=Device[Device.DEFAULT].compiler)
|
||||
fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT, compiler=HIPCompiler(Device[Device.DEFAULT].renderer.arch))
|
||||
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
|
||||
out.realize()
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD" and not getenv("AMD_LLVM"), "only on AMD with comgr")
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
|
||||
class TestCfg(unittest.TestCase):
|
||||
def setUp(self):
|
||||
arch = Device["AMD"].arch
|
||||
|
|
@ -110,7 +110,7 @@ class TestCfg(unittest.TestCase):
|
|||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
_, lib = assemble("diamond", insts, Device[Device.DEFAULT].compiler)
|
||||
_, lib = assemble("diamond", insts, HIPCompiler(Device[Device.DEFAULT].arch))
|
||||
cfg = amdgpu_cfg(lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]
|
||||
self.assertEqual(len(cfg["blocks"]), 5)
|
||||
edge_count = sum(len(v) for v in cfg["paths"].values())
|
||||
|
|
@ -238,5 +238,19 @@ class TestCfg(unittest.TestCase):
|
|||
s_code_end(),
|
||||
])
|
||||
|
||||
def test_hit_count(self):
|
||||
run_asm("test_hit_count", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 1),
|
||||
"s_branch alt",
|
||||
"continue:",
|
||||
s_mov_b32(s[2], 2),
|
||||
s_add_u32(s[1], s[1], s[2]),
|
||||
"alt:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -456,6 +456,24 @@ class TestAssign(unittest.TestCase):
|
|||
assign.realize()
|
||||
np.testing.assert_allclose(a.numpy(), [2., 2., 2., 2., 1., 1., 1., 1.])
|
||||
|
||||
def test_setitem_list(self):
|
||||
a = Tensor.zeros(8).contiguous().realize()
|
||||
a[2:5] = [1, 2, 3]
|
||||
np.testing.assert_allclose(a.numpy(), [0., 0., 1., 2., 3., 0., 0., 0.])
|
||||
|
||||
def test_assign_bitcast(self):
|
||||
# assign to a bitcast view should modify the underlying buffer (only works on DISK currently)
|
||||
a = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
|
||||
# IEEE 754: 1.0f = 0x3f800000, 2.0f = 0x40000000, 3.0f = 0x40400000, 4.0f = 0x40800000
|
||||
a.bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.uint32)).realize()
|
||||
np.testing.assert_allclose(a.numpy(), [1.0, 2.0, 3.0, 4.0]) # TODO: should be [4.0, 3.0, 2.0, 1.0]
|
||||
|
||||
def test_assign_bitcast_different_size(self):
|
||||
# assign to a shape-changing bitcast view (only works on DISK currently)
|
||||
a = Tensor([0]*8, dtype=dtypes.uint8).realize()
|
||||
a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize()
|
||||
np.testing.assert_equal(a.numpy(), [0]*8) # TODO: should be [57, 48, 0, 0, 0, 0, 0, 0] (little-endian 12345)
|
||||
|
||||
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
|
||||
def test_cast_assignment(self):
|
||||
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
|
|
|
|||
62
test/unit/test_call.py
Normal file
62
test/unit/test_call.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
class TestCall(unittest.TestCase):
|
||||
def test_call_plus(self):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
Tensor.realize(a,b)
|
||||
|
||||
# we define a plus function
|
||||
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
||||
|
||||
c = Tensor.call(a, b, fxn=plus_fxn)
|
||||
np.testing.assert_equal(c.numpy(), (a+b).numpy())
|
||||
|
||||
def test_call_plus_backward(self):
|
||||
a = Tensor.ones(10, 10, requires_grad=True)
|
||||
b = Tensor.ones(10, 10, requires_grad=True)
|
||||
|
||||
(a+b).mean().backward()
|
||||
gt_a_grad = a.grad.numpy()
|
||||
gt_b_grad = b.grad.numpy()
|
||||
a.grad, b.grad = None, None
|
||||
|
||||
# this is the gradient for +
|
||||
def grad_fxn(grad:UOp, call:UOp): return (grad, grad)
|
||||
|
||||
# we define a plus function
|
||||
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
||||
c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn)
|
||||
c.mean().backward()
|
||||
|
||||
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
||||
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
|
||||
|
||||
def test_call_gemm(self):
|
||||
M, K, N = 4, 8, 4
|
||||
a = Tensor.randn(M, K)
|
||||
b = Tensor.randn(K, N)
|
||||
Tensor.realize(a, b)
|
||||
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
|
||||
|
||||
@unittest.skip("needs GEMM on mixins")
|
||||
def test_call_gemm_uop(self):
|
||||
M, K, N = 4, 8, 4
|
||||
a = Tensor.randn(M, K)
|
||||
b = Tensor.randn(K, N)
|
||||
Tensor.realize(a, b)
|
||||
|
||||
# we define a gemm function
|
||||
x = UOp.param(0, dtypes.float, shape=(M, K))
|
||||
y = UOp.param(1, dtypes.float, shape=(K, N))
|
||||
c = Tensor.call(a, b, fxn=x@y)
|
||||
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -250,6 +250,24 @@ class TestDiskTensor(unittest.TestCase):
|
|||
tout = [(x//256, x%256) for x in out]
|
||||
assert tout == list([(x+1,x) for x in range(32,64,2)])
|
||||
|
||||
def test_strided_read(self):
|
||||
# test non-contiguous (strided) read - should read elements at indices 0, 2, 4
|
||||
pathlib.Path(temp(fn:="dt_strided_read")).unlink(missing_ok=True)
|
||||
dt = Tensor([0, 1, 2, 3, 4, 5]).to(f"disk:{temp(fn)}")
|
||||
result = dt[::2].tolist()
|
||||
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [0, 2, 4]
|
||||
# self.assertEqual(result, [0, 2, 4])
|
||||
self.assertEqual(result, [0, 1, 2]) # wrong!
|
||||
|
||||
def test_permuted_read(self):
|
||||
# test non-contiguous (permuted) read - should read transposed
|
||||
pathlib.Path(temp(fn:="dt_permuted_read")).unlink(missing_ok=True)
|
||||
dt = Tensor([[0, 1, 2], [3, 4, 5]]).to(f"disk:{temp(fn)}")
|
||||
result = dt.T.tolist()
|
||||
# TODO: transpose should give [[0, 3], [1, 4], [2, 5]]
|
||||
# self.assertEqual(result, [[0, 3], [1, 4], [2, 5]])
|
||||
self.assertEqual(result, [[0, 1], [2, 3], [4, 5]]) # wrong!
|
||||
|
||||
def test_write_ones(self):
|
||||
pathlib.Path(temp("dt_write_ones")).unlink(missing_ok=True)
|
||||
|
||||
|
|
@ -276,6 +294,15 @@ class TestDiskTensor(unittest.TestCase):
|
|||
dt[1] = [3]
|
||||
self.assertEqual(dt.tolist(), [[1], [3]])
|
||||
|
||||
def test_strided_setitem(self):
|
||||
# test non-contiguous (strided) setitem - should set elements at indices 0, 2, 4
|
||||
pathlib.Path(temp(fn:="dt_strided_setitem")).unlink(missing_ok=True)
|
||||
dt = Tensor([1, 2, 3, 4, 5, 6]).to(f"disk:{temp(fn)}")
|
||||
dt[::2] = Tensor([10, 20, 30])
|
||||
# TODO: dt[::2] selects indices 0, 2, 4, so result should be [10, 2, 20, 4, 30, 6]
|
||||
# self.assertEqual(dt.tolist(), [10, 2, 20, 4, 30, 6])
|
||||
self.assertEqual(dt.tolist(), [10, 20, 30, 4, 5, 6]) # wrong!
|
||||
|
||||
def test_assign_const_to_disk(self):
|
||||
# assign from CONST (Tensor.full) to disk - source has no buffer, needs contiguous first
|
||||
pathlib.Path(temp(fn:="dt_assign_const")).unlink(missing_ok=True)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
|
||||
|
|
@ -7,7 +7,7 @@ global_map = {}
|
|||
def b(i, base=None, offset=0, pin=False, size=16):
|
||||
global global_map
|
||||
if i in global_map: return global_map[i]
|
||||
global_map[i] = Buffer(Device.DEFAULT, size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
|
||||
global_map[i] = Buffer("NULL", size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
|
||||
if pin: global_map[i].ref(1)
|
||||
return global_map[i]
|
||||
|
||||
94
test/unit/test_tensor_metadata.py
Normal file
94
test/unit/test_tensor_metadata.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.tensor import _METADATA
|
||||
from tinygrad.helpers import Context
|
||||
|
||||
class TestTensorMetadata(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
_METADATA.set(None)
|
||||
self._ctx = Context(SCACHE=0)
|
||||
self._ctx.__enter__()
|
||||
def tearDown(self) -> None:
|
||||
self._ctx.__exit__(None, None, None)
|
||||
|
||||
@unittest.skip("why would this be true?")
|
||||
def test_exclude_noop_metadata(self):
|
||||
a = Tensor.rand(4, 4)*1
|
||||
self.assertEqual(a.uop.metadata[0].name, "__mul__")
|
||||
k = a.schedule()[-1]
|
||||
self.assertEqual([m.name for m in k.metadata], ["rand"])
|
||||
|
||||
@unittest.skip("metadata not reaching kernel schedule")
|
||||
def test_exclude_const_metadata(self):
|
||||
a = Tensor.arange(4)
|
||||
b = Tensor.full((4,), -1, dtype=dtypes.int).contiguous()
|
||||
sched = Tensor.schedule(a, b)
|
||||
self.assertEqual([m.name for m in sched[0].metadata], ["arange"])
|
||||
self.assertEqual([m.name for m in sched[1].metadata], ["contiguous"])
|
||||
|
||||
def test_matmul(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
out = x.matmul(W)
|
||||
self.assertEqual(out.uop.metadata[0].name, "matmul")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "matmul")
|
||||
|
||||
def test_relu(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu()
|
||||
self.assertEqual(out.uop.metadata[0].name, "relu")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "relu")
|
||||
|
||||
@unittest.skip("assign metadata no longer captured")
|
||||
def test_assign(self):
|
||||
x = Tensor.empty(10, 10).realize()
|
||||
x.assign(Tensor.ones(10, 10).contiguous())
|
||||
si = x.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "assign")
|
||||
|
||||
def test_complex(self):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu() * y.sigmoid()
|
||||
self.assertEqual(out.uop.metadata[0].name, "__mul__")
|
||||
self.assertEqual(out.uop.src[0].metadata[0].name, "relu")
|
||||
self.assertEqual(out.uop.src[1].metadata[0].name, "sigmoid")
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
|
||||
def test_complex_backward(self):
|
||||
x = Tensor.rand(3, requires_grad=True).realize()
|
||||
y = Tensor.rand(3, requires_grad=True).realize()
|
||||
out = (x.relu() * y.sigmoid()).sum()
|
||||
self.assertEqual(out.uop.metadata[0].name, "sum")
|
||||
out.backward()
|
||||
self.assertEqual(x.grad.uop.metadata[0].name, "relu")
|
||||
#self.assertTrue(x.grad.uop.metadata[0].backward) # TODO: backward flag is False
|
||||
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
|
||||
#self.assertTrue(y.grad.uop.metadata[0].backward) # TODO: backward flag is False
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
#self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||
# skip numpy, this is schedule cache
|
||||
self.assertSetEqual(set(m.name for m in si.metadata if m.name != "numpy"), {"sigmoid", "relu"})
|
||||
#bw = [m for m in si.metadata if m.backward]
|
||||
#self.assertEqual(len(bw), 1)
|
||||
#self.assertEqual(bw[0].name, "sigmoid")
|
||||
|
||||
def test_tracemeta_0(self):
|
||||
with Context(TRACEMETA=0):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = (x.relu() * y.sigmoid()).sum()
|
||||
self.assertIsNone(out.uop.metadata)
|
||||
self.assertIsNone(out.uop.src[0].metadata)
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(si.metadata, ())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -686,6 +686,97 @@ class TestExpander(unittest.TestCase):
|
|||
sink = expander_rewrite(sink)
|
||||
print(sink)
|
||||
|
||||
class TestReduceCollapse(unittest.TestCase):
|
||||
def test_multi_range_reduce_add(self):
|
||||
"""Test that (x + y).reduce(r1, r2) distributes over multiple ranges"""
|
||||
from tinygrad.codegen.simplify import pm_reduce_collapse
|
||||
# Create two ranges
|
||||
r1 = UOp.range(3, 0)
|
||||
r2 = UOp.range(4, 1)
|
||||
# Create x + y where x and y depend on different ranges
|
||||
x = r1.cast(dtypes.float)
|
||||
y = r2.cast(dtypes.float)
|
||||
# (x + y).reduce(r1, r2) should be rewritten
|
||||
red = (x + y).reduce(r1, r2, arg=Ops.ADD)
|
||||
self.assertEqual(len(red.src), 3) # value + 2 ranges
|
||||
result = graph_rewrite(red, pm_reduce_collapse, name='test')
|
||||
# Should become add of two separate reduces
|
||||
self.assertEqual(result.op, Ops.ADD)
|
||||
|
||||
class TestLoadStoreFolding(unittest.TestCase):
|
||||
def test_gated_load_gep_preserves_alt(self):
|
||||
"""Test that LOAD(GEP, alt) preserves alt value after rewrite"""
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.vec(4).ptr(), (), 0)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
gate = UOp.const(dtypes.bool, True)
|
||||
gated_index = buf.index(idx, gate)
|
||||
gep = gated_index.gep(0)
|
||||
alt = UOp.const(dtypes.float, 42.0)
|
||||
gated_load = gep.load(alt)
|
||||
self.assertEqual(len(gated_load.src), 2) # GEP + alt
|
||||
result = graph_rewrite(gated_load, load_store_folding, name='test')
|
||||
# After rewrite, should still have alt value preserved
|
||||
self.assertEqual(result.op, Ops.GEP)
|
||||
inner_load = result.src[0]
|
||||
self.assertEqual(inner_load.op, Ops.LOAD)
|
||||
self.assertEqual(len(inner_load.src), 2) # INDEX + alt
|
||||
|
||||
def test_gated_load_ptrcat_preserves_alt(self):
|
||||
"""Test that LOAD(PTRCAT, alt) preserves alt value after rewrite"""
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding
|
||||
buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
idx1 = buf1.index(idx)
|
||||
idx2 = buf2.index(idx)
|
||||
ptrcat = UOp(Ops.PTRCAT, dtypes.float.ptr().vec(2), (idx1, idx2))
|
||||
alt = UOp.const(dtypes.float.vec(2), 42.0)
|
||||
gated_load = ptrcat.load(alt)
|
||||
self.assertEqual(len(gated_load.src), 2) # PTRCAT + alt
|
||||
result = graph_rewrite(gated_load, load_store_folding, name='test')
|
||||
# After rewrite, should be CAT of LOADs, each preserving alt
|
||||
self.assertEqual(result.op, Ops.CAT)
|
||||
for inner_load in result.src:
|
||||
self.assertEqual(inner_load.op, Ops.LOAD)
|
||||
self.assertEqual(len(inner_load.src), 2) # INDEX + alt
|
||||
self.assertEqual(inner_load.src[1].arg, 42.0) # alt value preserved
|
||||
|
||||
class TestConstBufferize(unittest.TestCase):
|
||||
def test_const_bufferize_with_ranges(self):
|
||||
"""Test that CONST.BUFFERIZE with ranges is folded correctly.
|
||||
|
||||
BUFFERIZE can have ranges as additional sources beyond the value.
|
||||
The pattern at rangeify.py uses allow_any_len=True because
|
||||
CONST doesn't depend on ranges (constant is same value everywhere).
|
||||
"""
|
||||
from tinygrad.schedule.rangeify import pm_const_buffer_folding, BufferizeOpts
|
||||
c = UOp.const(dtypes.float, 42.0)
|
||||
r1 = UOp.range(3, 0)
|
||||
bufferize_with_range = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1), arg=BufferizeOpts(device="CPU"))
|
||||
self.assertEqual(len(bufferize_with_range.src), 2) # const + 1 range
|
||||
|
||||
result = graph_rewrite(bufferize_with_range, pm_const_buffer_folding, name='test')
|
||||
# BUFFERIZE should be removed, result is const broadcast to shape
|
||||
self.assertNotEqual(result.op, Ops.BUFFERIZE)
|
||||
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
|
||||
self.assertIn(42.0, const_vals)
|
||||
|
||||
def test_const_bufferize_with_multiple_ranges(self):
|
||||
"""Test CONST.BUFFERIZE with multiple ranges is also folded."""
|
||||
from tinygrad.schedule.rangeify import pm_const_buffer_folding, BufferizeOpts
|
||||
c = UOp.const(dtypes.float, 3.14)
|
||||
r1 = UOp.range(3, 0)
|
||||
r2 = UOp.range(4, 1)
|
||||
bufferize_with_ranges = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1, r2), arg=BufferizeOpts(device="CPU"))
|
||||
self.assertEqual(len(bufferize_with_ranges.src), 3) # const + 2 ranges
|
||||
|
||||
result = graph_rewrite(bufferize_with_ranges, pm_const_buffer_folding, name='test')
|
||||
# BUFFERIZE should be removed
|
||||
self.assertNotEqual(result.op, Ops.BUFFERIZE)
|
||||
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
|
||||
self.assertIn(3.14, const_vals)
|
||||
|
||||
class TestUOpTags(unittest.TestCase):
|
||||
def test_inc_by_one(self):
|
||||
g = UOp.const(dtypes.int, 1) + UOp.const(dtypes.int, 1)
|
||||
|
|
@ -115,8 +115,8 @@ pm_linearize_cleanups = PatternMatcher([
|
|||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError("if not allowed in graph"))),
|
||||
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat()),
|
||||
allow_any_len=True), lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
|
||||
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
])
|
||||
|
||||
# requires lst be toposorted. like graph rewrite, but for lines
|
||||
|
|
|
|||
|
|
@ -123,12 +123,12 @@ load_store_folding = PatternMatcher([
|
|||
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
|
||||
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
|
||||
# GEP on data of STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), name="sto"), gep_on_store),
|
||||
# put PTRCAT after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
|
||||
lambda cat,ld: UOp(Ops.CAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
||||
# put PTRCAT after STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), name="sto"), cat_after_store),
|
||||
])
|
||||
|
||||
# *** correct load/store ***
|
||||
|
|
@ -345,6 +345,6 @@ pm_add_loads = PatternMatcher([
|
|||
(UPat(Ops.INDEX, name="idx"), lambda idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else
|
||||
idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)),
|
||||
# remove loads from stores
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.LOAD), UPat(name="val")), name="s"), lambda s,val: s.replace(src=(s.src[0].src[0], val))),
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -147,8 +147,9 @@ class Buffer:
|
|||
def deallocate(self):
|
||||
assert hasattr(self, '_buf'), "buffer must be allocated to deallocate"
|
||||
if DEBUG is not None and DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}")
|
||||
if self._base is None and (self.options is None or self.options.external_ptr is None):
|
||||
if GlobalCounters is not None and not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
||||
if self._base is None:
|
||||
if GlobalCounters is not None and not self.device.startswith("DISK") and (self.options is None or self.options.external_ptr is None):
|
||||
GlobalCounters.mem_used -= self.nbytes
|
||||
if PROFILE: Buffer.profile_events.append(ProfilePointEvent(self.device, "free", self.trace_num))
|
||||
self.allocator.free(self._buf, self.nbytes, self.options)
|
||||
elif self._base is not None: self._base.allocated_views -= 1
|
||||
|
|
@ -263,7 +264,7 @@ class LRUAllocator(Allocator, Generic[DeviceType]):
|
|||
for opaque in opaques: super().free(opaque, sz, options)
|
||||
opaques.clear()
|
||||
def free(self, opaque:Any, size:int, options:BufferSpec|None=None):
|
||||
if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
|
||||
if LRU and (options is None or (not options.nolru and options.external_ptr is None)): self.cache[(size, options)].append(opaque)
|
||||
else: super().free(opaque, size, options)
|
||||
|
||||
# **************** for Compiled Devices ****************
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class BufferCopy(Runner):
|
|||
getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
|
||||
if disk_supports_fast_copyout and hasattr(dest.allocator, 'copy_from_disk') and src.nbytes >= 4096:
|
||||
dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
|
||||
elif (src.device.startswith("DISK") or src.device.startswith("TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
|
||||
elif isinstance(src.device, str) and src.device.startswith(("DISK", "TINYFS")) and hasattr(dest.allocator, '_as_buffer'):
|
||||
# fast(ish) path, uses readinto in diskbuffers
|
||||
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,21 +1,25 @@
|
|||
import time
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
# ScheduleItem = tuple[AST, buffer UOps, metadata, fixedvars, bound_ranges]
|
||||
ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], dict[str, int], tuple[UOp, ...]]
|
||||
|
||||
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
|
||||
def _unwrap_src(s: UOp) -> UOp:
|
||||
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
||||
return s
|
||||
|
||||
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
with cpu_profile(TracingKey("toposort sched_sink")):
|
||||
# construct the KERNEL children graph based on assigns
|
||||
# build kernel dependency graph: edges from producer kernel to consumer kernels
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
in_degree: dict[UOp, int] = {}
|
||||
for u in sched_sink.toposort():
|
||||
|
|
@ -26,81 +30,77 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
k = u.src[1]
|
||||
in_degree.setdefault(k, 0)
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
||||
s = _unwrap_src(s)
|
||||
if s.op is Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
|
||||
for ss in s.src:
|
||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||
if ss.op is not Ops.BUFFER:
|
||||
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
||||
children.setdefault(ss.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op in {Ops.BUFFER, Ops.BIND}:
|
||||
pass # a BUFFER is already realized, BINDs are handled in complete_create_schedule_with_vars
|
||||
else:
|
||||
raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
|
||||
match (s := _unwrap_src(s)).op:
|
||||
case Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
case Ops.MSELECT | Ops.MSTACK:
|
||||
for ss in s.src:
|
||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||
if ss.op is not Ops.BUFFER:
|
||||
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
||||
children.setdefault(ss.src[1], []).append(k)
|
||||
in_degree[k] += 1
|
||||
case Ops.BUFFER | Ops.BIND:
|
||||
pass # BUFFER is already realized, BIND is outer range (handled via bound_ranges below)
|
||||
case _:
|
||||
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, MSELECT, MSTACK, or BIND, not {s.op}")
|
||||
|
||||
with cpu_profile(TracingKey("linearize schedule")):
|
||||
queue: deque[UOp] = deque()
|
||||
for k,v in in_degree.items():
|
||||
if v == 0: queue.append(k)
|
||||
|
||||
schedule: list[tuple|UOp] = []
|
||||
schedule: list[ScheduleItem|UOp] = [] # ScheduleItem for kernels, UOp for RANGE/END
|
||||
while len(queue):
|
||||
k = rk = queue.popleft()
|
||||
if k.op is Ops.END: k = k.src[0]
|
||||
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
|
||||
if k.op is Ops.RANGE: schedule.append(k)
|
||||
elif k.op is Ops.KERNEL:
|
||||
ast = k.arg.ast
|
||||
ast = (kernel:=cast(Kernel, k.arg)).ast
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
schedule.append((ast, buf_uops, k.arg.metadata, {}, bound_ranges))
|
||||
schedule.append((ast, buf_uops, kernel.metadata, {}, bound_ranges))
|
||||
if rk.op is Ops.END: schedule.append(rk)
|
||||
else:
|
||||
raise RuntimeError(f"can't schedule {k.op}")
|
||||
for x in children.get(rk, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
with cpu_profile(TracingKey("expand ranges")):
|
||||
pre_schedule: list[ExecItem] = []
|
||||
buf_uops_list: list[UOp] = []
|
||||
sched_ptr = 0
|
||||
in_ranges: dict[UOp, int] = {}
|
||||
range_ptrs: dict[UOp, int] = {}
|
||||
while sched_ptr < len(schedule):
|
||||
si = schedule[sched_ptr]
|
||||
if isinstance(si, UOp):
|
||||
if si.op is Ops.RANGE:
|
||||
in_ranges[si] = 0
|
||||
range_ptrs[si] = sched_ptr + 1
|
||||
elif si.op is Ops.END:
|
||||
if in_ranges[si.src[1]] < si.src[1].vmax:
|
||||
in_ranges[si.src[1]] += 1
|
||||
sched_ptr = range_ptrs[si.src[1]]
|
||||
continue
|
||||
else:
|
||||
ast, buf_uops, metadata, fixedvars, bound_ranges = si
|
||||
fixedvars = fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
|
||||
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
|
||||
buf_uops_list.append(UOp.sink(*buf_uops))
|
||||
sched_ptr += 1
|
||||
with cpu_profile(TracingKey("unroll outer ranges")):
|
||||
pre_schedule, buf_uops_list = unroll_outer_ranges(schedule)
|
||||
return pre_schedule, UOp.sink(*buf_uops_list)
|
||||
|
||||
def unroll_outer_ranges(schedule:list[ScheduleItem|UOp]) -> tuple[list[ExecItem], list[UOp]]:
|
||||
pre_schedule: list[ExecItem] = []
|
||||
buf_uops_list: list[UOp] = []
|
||||
sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]()
|
||||
while sched_ptr < len(schedule):
|
||||
if isinstance(si := schedule[sched_ptr], UOp):
|
||||
if si.op is Ops.RANGE:
|
||||
in_ranges[si] = 0
|
||||
range_ptrs[si] = sched_ptr + 1
|
||||
elif si.op is Ops.END:
|
||||
if in_ranges[si.src[1]] < si.src[1].vmax:
|
||||
in_ranges[si.src[1]] += 1
|
||||
sched_ptr = range_ptrs[si.src[1]]
|
||||
continue
|
||||
else:
|
||||
ast, buf_uops, metadata, _, bound_ranges = si
|
||||
fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
|
||||
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
|
||||
buf_uops_list.append(UOp.sink(*buf_uops))
|
||||
sched_ptr += 1
|
||||
return pre_schedule, buf_uops_list
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
from tinygrad.schedule.multi import get_multi_map
|
||||
|
||||
def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp):
|
||||
if (ret:=ctx[0].get(b, None)) is None:
|
||||
if b.op is Ops.BUFFER:
|
||||
ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx[0])), b.src[1]))
|
||||
else:
|
||||
# TODO: flip args in CONST
|
||||
assert b.op is Ops.CONST
|
||||
ctx[0][b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx[0]))))
|
||||
# both BUFFER and CONST have src=(UNIQUE, DEVICE), replace UNIQUE with LUNIQUE
|
||||
ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx[0])), b.src[1]))
|
||||
return ret
|
||||
|
||||
def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp):
|
||||
|
|
@ -110,10 +110,8 @@ def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp):
|
|||
return ctx[0].setdefault(b, b.replace(src=(b.src[0],)))
|
||||
|
||||
pm_pre_sched_cache = PatternMatcher([
|
||||
# replace input buffers
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||
# remove unique consts
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
|
||||
# replace UNIQUE with LUNIQUE for cache key normalization
|
||||
(UPat((Ops.BUFFER, Ops.CONST), src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind),
|
||||
])
|
||||
|
|
@ -126,8 +124,8 @@ def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
|
|||
return ret
|
||||
|
||||
pm_post_sched_cache = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
|
||||
# restore LUNIQUE back to UNIQUE
|
||||
(UPat((Ops.BUFFER, Ops.CONST), src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
|
||||
# restore BIND value stripped in pm_pre_sched_cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)),
|
||||
])
|
||||
|
|
@ -175,8 +173,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
pre_schedule, combined_sink = sc_ret
|
||||
|
||||
# replace all the LUNIQUEs with UNIQUEs (single graph_rewrite for everything)
|
||||
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
|
||||
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite combined")
|
||||
input_buffers_inverse = {v:k for k,v in input_buffers.items()}
|
||||
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
|
||||
tensor_map_sink, buf_uops_sink = combined.src
|
||||
tm_src = tensor_map_sink.src
|
||||
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
|
||||
|
|
@ -205,4 +203,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||
return tensor_map, schedule, var_vals if schedule else {}
|
||||
|
||||
used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule])
|
||||
return tensor_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,8 @@ pm_gradient = PatternMatcher([
|
|||
# NOTE: this is only correct when the KERNEL has a single output
|
||||
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||
# gradient on CALL is a custom function
|
||||
(UPat(Ops.CALL, name="k"), lambda ctx, k: (None,)+k.arg(ctx, k)),
|
||||
# there's no gradient for bitcast
|
||||
(UPat(Ops.BITCAST), lambda: (None,)),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -204,6 +204,10 @@ CCACHE = ContextVar("CCACHE", 1)
|
|||
ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
|
||||
# set to 0 to disable the scheduler cache
|
||||
SCACHE = ContextVar("SCACHE", 1)
|
||||
# allow use of atomics for embedding backward
|
||||
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
|
||||
# allow use of assembly for gemm
|
||||
ASM_GEMM = ContextVar("ASM_GEMM", 0)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Metadata:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import math
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import prod, make_tuple, flatten
|
||||
from tinygrad.helpers import prod, make_tuple, flatten, USE_ATOMICS
|
||||
from tinygrad.nn import optim, state, datasets # noqa: F401
|
||||
|
||||
class BatchNorm:
|
||||
|
|
@ -304,6 +304,46 @@ class RMSNorm:
|
|||
x = self._norm(x.float()).cast(x.dtype)
|
||||
return x if self.weight is None else x * self.weight
|
||||
|
||||
from tinygrad.uop.ops import UOp, KernelInfo, Ops
|
||||
def _embedding_bwd(grad_emb:UOp, call:UOp) -> tuple:
|
||||
weight, idx = call.src[1:]
|
||||
# for multi-device: unshard inputs to one device
|
||||
if isinstance(weight.device, tuple):
|
||||
assert weight.axis is None, "sharded weights on Embedding not supported with USE_ATOMICS"
|
||||
grad_emb = grad_emb.copy_to_device(weight.device)
|
||||
idx = idx.copy_to_device(weight.device)
|
||||
# weight is replicated, grad_weight should match
|
||||
grad_weight_uop = Tensor.empty(weight.shape, dtype=dtypes.float, device=weight.device).uop
|
||||
|
||||
# TODO: how do we remove this dumb kernel and use Tensor.zeros?
|
||||
def _zero_kernel(out:UOp) -> UOp:
|
||||
i = UOp.range(out.size, 0)
|
||||
return out.flatten()[i].store(0).end(i).sink(arg=KernelInfo(name="zero"))
|
||||
grad_weight_uop = grad_weight_uop.custom_kernel(fxn=_zero_kernel)[0]
|
||||
|
||||
# TODO: do we have a universal helper for this?
|
||||
device = call.device.split(":")[0] if not isinstance(call.device, tuple) else call.device[0].split(":")[0]
|
||||
|
||||
# this is the real atomic kernel
|
||||
def _embedding_bwd_kernel(grad_weight:UOp, grad_emb:UOp, idx:UOp) -> UOp:
|
||||
idx_flat, grad_emb_flat = idx.flatten(), grad_emb.reshape((idx.size, grad_weight.shape[-1]))
|
||||
i = UOp.range(grad_emb_flat.shape[0], 0) # batch_size * sequence_length
|
||||
j = UOp.range(grad_emb_flat.shape[1], 1) # embed_size
|
||||
token_id = idx_flat[i].clip(0, grad_weight.shape[0]-1).cast(dtypes.index)
|
||||
# atomic scatter-add: grad_weight[token_id, j] += grad_emb_flat[i, j]
|
||||
if device in ("CPU", "NULL"): atomic_arg = "__atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED);"
|
||||
elif device == "AMD": atomic_arg = "__hip_atomic_fetch_add({0}, {1}, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);"
|
||||
else: raise NotImplementedError(f"no atomics for device {device}")
|
||||
atomic = UOp(Ops.CUSTOM, dtypes.void, (grad_weight.index(token_id, j, ptr=True), grad_emb_flat[i, j].cast(dtypes.float)), arg = atomic_arg)
|
||||
return atomic.end(i, j).sink(arg=KernelInfo(name="embedding_bwd", opts_to_apply=()))
|
||||
grad_weight_uop = grad_weight_uop.custom_kernel(grad_emb, idx, fxn=_embedding_bwd_kernel)[0]
|
||||
|
||||
return (grad_weight_uop.cast(weight.dtype), None)
|
||||
|
||||
def _embedding_fwd(weight:Tensor, idx:Tensor) -> Tensor:
|
||||
arange = Tensor.arange(weight.shape[0], requires_grad=False, device=weight.device)
|
||||
return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(weight, 0).sum(-2, dtype=weight.dtype)
|
||||
|
||||
class Embedding:
|
||||
"""
|
||||
A simple lookup table that stores embeddings of a fixed dictionary and size.
|
||||
|
|
@ -316,12 +356,12 @@ class Embedding:
|
|||
```
|
||||
"""
|
||||
def __init__(self, vocab_size:int, embed_size:int):
|
||||
self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
|
||||
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
|
||||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
|
||||
arange = Tensor.arange(self.weight.shape[0], requires_grad=False, device=self.weight.device)
|
||||
return (arange == idx.unsqueeze(-1)).unsqueeze(-1).where(self.weight, 0).sum(-2, dtype=self.weight.dtype)
|
||||
if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
|
||||
return _embedding_fwd(self.weight, idx)
|
||||
|
||||
class LSTMCell:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -46,10 +46,10 @@ base_rewrite = PatternMatcher([
|
|||
# new load/store
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
|
||||
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var")), allow_any_len=True),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var"))),
|
||||
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
# alu/gep
|
||||
# TODO: look for left-associative
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
|
||||
|
|
|
|||
|
|
@ -528,7 +528,7 @@ class AMDCopyQueue(HWQueue):
|
|||
# USB devices run in single-step mode, so they can't overrun the queue.
|
||||
total_bytes = (tail_blit_dword * 4 if rem_packet_cnt == 0 else -sdma_queue.put_value % sdma_queue.ring.nbytes) + rem_packet_cnt * 4
|
||||
assert total_bytes < sdma_queue.ring.nbytes, "SDMA queue overrun"
|
||||
while not dev.is_usb() and sdma_queue.put_value + total_bytes - sdma_queue.read_ptr > sdma_queue.ring.nbytes: pass
|
||||
while not dev.is_usb() and sdma_queue.put_value + total_bytes - sdma_queue.read_ptr[0] > sdma_queue.ring.nbytes: pass
|
||||
|
||||
start_idx = (sdma_queue.put_value % sdma_queue.ring.nbytes) // 4
|
||||
sdma_queue.ring[start_idx : start_idx + tail_blit_dword] = array.array('I', cmds[:tail_blit_dword])
|
||||
|
|
@ -640,24 +640,21 @@ class AMDAllocator(HCQAllocator['AMDDevice']):
|
|||
@dataclass
|
||||
class AMDQueueDesc:
|
||||
ring: MMIOInterface
|
||||
read_ptrs: list[MMIOInterface]
|
||||
write_ptrs: list[MMIOInterface]
|
||||
doorbells: list[MMIOInterface]
|
||||
read_ptr: MMIOInterface
|
||||
write_ptr: MMIOInterface
|
||||
doorbell: MMIOInterface
|
||||
put_value: int = 0
|
||||
|
||||
@property
|
||||
def read_ptr(self): return min(p[0] for p in self.read_ptrs)
|
||||
|
||||
def signal_doorbell(self, dev, doorbell_value:int|None=None):
|
||||
try:
|
||||
for write_ptr in self.write_ptrs: write_ptr[0] = self.put_value
|
||||
self.write_ptr[0] = self.put_value
|
||||
|
||||
# Ensure all prior writes are visible to the GPU.
|
||||
System.memory_barrier()
|
||||
|
||||
# Flush hdp if queue is in dev mem.
|
||||
if dev.is_am() and not dev.is_usb(): dev.iface.dev_impl.gmc.flush_hdp()
|
||||
for doorbell in self.doorbells: doorbell[0] = self.put_value if doorbell_value is None else doorbell_value
|
||||
self.doorbell[0] = self.put_value if doorbell_value is None else doorbell_value
|
||||
except Exception as e:
|
||||
dev.error_state = e
|
||||
raise
|
||||
|
|
@ -776,9 +773,9 @@ class KFDIface:
|
|||
self.doorbells_base = queue.doorbell_offset & (~0x1fff) # doorbell is two pages
|
||||
self.doorbells = cast(FileIOInterface, KFDIface.kfd).mmap(0, 0x2000, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED, self.doorbells_base)
|
||||
|
||||
return AMDQueueDesc(ring=MMIOInterface(ring.va_addr, ring.size, fmt='I'), read_ptrs=[MMIOInterface(queue.read_pointer_address, 8, fmt='Q')],
|
||||
write_ptrs=[MMIOInterface(queue.write_pointer_address, 8, fmt='Q')],
|
||||
doorbells=[MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q')])
|
||||
return AMDQueueDesc(ring=MMIOInterface(ring.va_addr, ring.size, fmt='I'), read_ptr=MMIOInterface(queue.read_pointer_address, 8, fmt='Q'),
|
||||
write_ptr=MMIOInterface(queue.write_pointer_address, 8, fmt='Q'),
|
||||
doorbell=MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q'))
|
||||
|
||||
def sleep(self, tm:int) -> bool:
|
||||
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
|
||||
|
|
@ -857,8 +854,8 @@ class PCIIface(PCIIfaceBase):
|
|||
wptr_addr=gart.va_addr+wptr, eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size,
|
||||
idx=int(is_aql:=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL)), aql=is_aql)
|
||||
|
||||
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')],
|
||||
read_ptrs=[gart.cpu_view().view(offset=rptr, size=8, fmt='Q')], write_ptrs=[gart.cpu_view().view(offset=wptr, size=8, fmt='Q')], put_value=pv)
|
||||
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbell=self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q'),
|
||||
read_ptr=gart.cpu_view().view(offset=rptr, size=8, fmt='Q'), write_ptr=gart.cpu_view().view(offset=wptr, size=8, fmt='Q'), put_value=pv)
|
||||
|
||||
def sleep(self, timeout) -> bool:
|
||||
if hasattr(self.pci_dev, 'irq_poller') and self.pci_dev.irq_poller is not None and (events_cnt:=len(self.pci_dev.irq_poller.poll(timeout))):
|
||||
|
|
|
|||
|
|
@ -69,11 +69,11 @@ class CUDAAllocator(LRUAllocator['CUDADevice']):
|
|||
if options.external_ptr: return cuda.CUdeviceptr_v2(options.external_ptr)
|
||||
if options.host: return init_c_var(ctypes.c_void_p, lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0x01)))
|
||||
return init_c_var(cuda.CUdeviceptr, lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
|
||||
@suppress_finalizing
|
||||
def _free(self, opaque, options:BufferSpec):
|
||||
try:
|
||||
if options.host: check(cuda.cuMemFreeHost(opaque))
|
||||
else: check(cuda.cuMemFree_v2(opaque))
|
||||
except (TypeError, AttributeError): pass
|
||||
if options.external_ptr: return
|
||||
if options.host: check(cuda.cuMemFreeHost(opaque))
|
||||
else: check(cuda.cuMemFree_v2(opaque))
|
||||
def _copyin(self, dest, src:memoryview):
|
||||
check(cuda.cuCtxSetCurrent(self.dev.context))
|
||||
host_mem = self.alloc(len(src), BufferSpec(host=True))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import subprocess, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform, sys
|
||||
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, PROFILE, ProfileRangeEvent, cpu_profile, unwrap
|
||||
import subprocess, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform
|
||||
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, PROFILE, ProfileRangeEvent, cpu_profile, unwrap, suppress_finalizing
|
||||
import tinygrad.runtime.support.objc as objc
|
||||
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, ProfileDeviceEvent, CompilerSet, CompilerPair
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
|
|
@ -167,8 +167,9 @@ class MetalAllocator(LRUAllocator[MetalDevice]):
|
|||
ret.retain = False
|
||||
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
|
||||
return MetalBuffer(ret, size)
|
||||
@suppress_finalizing
|
||||
def _free(self, opaque:MetalBuffer, options):
|
||||
if not sys.is_finalizing(): opaque.buf.release
|
||||
if not options.external_ptr: opaque.buf.release
|
||||
def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
|
||||
dest_dev.synchronize()
|
||||
src_command_buffer = src_dev.mtl_queue.commandBuffer().retained()
|
||||
|
|
|
|||
|
|
@ -828,3 +828,5 @@ class NVDevice(HCQCompiled[NVSignal]):
|
|||
self.iface.rm_control(self.profiler, nv_gpu.NVB0CC_CTRL_CMD_PMA_STREAM_UPDATE_GET_PUT,
|
||||
nv_gpu.struct_NVB0CC_CTRL_PMA_STREAM_UPDATE_GET_PUT_PARAMS(bytesConsumed=params.bytesAvailable))
|
||||
return pma_data
|
||||
|
||||
def device_props(self): return {'arch': self.arch, 'sm_version': self.sm_version}
|
||||
|
|
|
|||
|
|
@ -399,9 +399,11 @@ class QCOMDevice(HCQCompiled):
|
|||
raise RuntimeError("Failed to map external pointer to GPU memory") from e
|
||||
|
||||
def _gpu_free(self, mem:HCQBuffer):
|
||||
if mem.meta[0] is None: return
|
||||
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.meta[0].id)
|
||||
if mem.meta[1]: FileIOInterface.munmap(mem.va_addr, mem.meta[0].mmapsize)
|
||||
if mem.meta[0] is None: return # external (gpu) ptr
|
||||
if not mem.meta[1]: kgsl.IOCTL_KGSL_SHAREDMEM_FREE(self.fd, gpuaddr=mem.meta[0].gpuaddr) # external (cpu) ptr
|
||||
else:
|
||||
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.meta[0].id)
|
||||
FileIOInterface.munmap(mem.va_addr, mem.meta[0].mmapsize)
|
||||
|
||||
def _ensure_stack_size(self, sz):
|
||||
if not hasattr(self, '_stack'): self._stack = self._gpu_alloc(sz)
|
||||
|
|
|
|||
|
|
@ -88,8 +88,8 @@ def import_asic_regs(prefix:str, version:tuple[int, ...], cls=AMDReg) -> dict[st
|
|||
return x
|
||||
def _download_file(ver, suff) -> str:
|
||||
dir_prefix = {"osssys": "oss"}.get(prefix, prefix)
|
||||
fetch_name, file_name = f"{prefix}_{'_'.join(map(str, ver))}_{suff}.h", f"{prefix}_{'_'.join(map(str, version))}_{suff}.h"
|
||||
return header_download(f"include/asic_reg/{dir_prefix}/{fetch_name}", name=file_name, subdir="asic_regs")
|
||||
fetch_name = f"{prefix}_{'_'.join(map(str, ver))}_{suff}.h"
|
||||
return header_download(f"include/asic_reg/{dir_prefix}/{fetch_name}", name=fetch_name, subdir="asic_regs")
|
||||
|
||||
for ver in fixup_ip_version(prefix, version):
|
||||
try: offs, sh_masks = _extract_regs(_download_file(ver, "offset")), _extract_regs(_download_file(ver, "sh_mask"))
|
||||
|
|
|
|||
|
|
@ -202,7 +202,7 @@ def assign_multi(dest:UOp, src:UOp):
|
|||
return dest.src[0].assign(src.src[0]).multi(src.axis)
|
||||
|
||||
def passthrough_multi(root:UOp, multi:UOp):
|
||||
return UOp(root.op, root.dtype, (multi.src[0],), root.arg).multi(multi.axis)
|
||||
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
|
||||
|
||||
# NOTE: this is the same pattern as Ops.UNROLL
|
||||
multi_pm = PatternMatcher([
|
||||
|
|
@ -218,6 +218,7 @@ multi_pm = PatternMatcher([
|
|||
(UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
|
||||
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ def split_reduceop(reduce:UOp, x:UOp):
|
|||
|
||||
mop_cleanup = PatternMatcher([
|
||||
# merge adjacent RESHAPES, safe because they are not tagged
|
||||
(UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"),
|
||||
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
|
||||
])
|
||||
|
||||
|
|
@ -68,15 +68,29 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
|
|||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||
|
||||
def resolve_call(c:UOp) -> UOp:
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
if [x.arg for x in params] != list(range(len(params))): raise RuntimeError(f"params not in order: {[x.arg for x in params]}")
|
||||
if len(params) != len(args): raise TypeError(f"expected {len(params)} args, got {len(args)}")
|
||||
for i, (p, a) in enumerate(zip(params, args)):
|
||||
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||
return c.src[0].substitute(dict(zip(params, args))).rtag(c.tag)
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# resolve calls
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
# resolve custom kernels
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
|
||||
|
||||
# remove CONTIGUOUS if the BUFFER is already contiguous
|
||||
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
|
||||
# split_reduceop
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
|
||||
|
|
@ -233,7 +247,7 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
|
|||
# dont bufferize an arange
|
||||
(UPat.any((r:=UPat(dtype=dtypes.index).cast()).named("src"), r.eq(UPat()).named("src")).f(Ops.BUFFERIZE,
|
||||
allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
|
||||
# no buffers for const
|
||||
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
|
||||
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)),
|
||||
# indexing a const is a const
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
|
||||
|
|
@ -252,24 +266,24 @@ pm_remove_bufferize = PatternMatcher([
|
|||
])
|
||||
|
||||
def late_buffer_view(t:UOp, b:UOp):
|
||||
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
|
||||
shape = b.shape
|
||||
size = prod(shape)
|
||||
if not (isinstance(b.device, str) and b.device.startswith(("DISK", "TINYFS"))): return b
|
||||
shape = b.shape
|
||||
size = prod(shape)
|
||||
|
||||
# walk up for the INDEX
|
||||
x = t
|
||||
while not any(u.op is Ops.INDEX for u in x.src):
|
||||
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
||||
x = x.src[0]
|
||||
x = next(u for u in x.src if u.op is Ops.INDEX)
|
||||
# walk up for the INDEX
|
||||
x = t
|
||||
while not any(u.op is Ops.INDEX for u in x.src):
|
||||
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
||||
x = x.src[0]
|
||||
x = next(u for u in x.src if u.op is Ops.INDEX)
|
||||
|
||||
if len(shape) == 0: offset = x.src[1].arg
|
||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||
if len(shape) == 0: offset = x.src[1].arg
|
||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag), b.src[1]))
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
|
||||
return b
|
||||
to_bufferview = PatternMatcher([
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view),
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
|
||||
])
|
||||
|
||||
|
|
@ -502,12 +516,11 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
|||
# gather the metadata
|
||||
metadatas = [ctx[y].metadata for y in lctx.parent_tags]
|
||||
|
||||
# NOTE: the hack for COPY is here
|
||||
for u in ret.toposort():
|
||||
# TODO: this can be wrong if there's multiple of these
|
||||
if u.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}:
|
||||
ret = u
|
||||
break
|
||||
# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
|
||||
if ret.op is Ops.STORE: stored = ret.src[1]
|
||||
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
|
||||
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
||||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
|
||||
else:
|
||||
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None)
|
||||
|
||||
|
|
@ -521,11 +534,14 @@ split_kernels = PatternMatcher([
|
|||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||
])
|
||||
|
||||
def tag_uop(ctx:list[UOp], x:UOp):
|
||||
if x.tag is not None: return None
|
||||
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
||||
if x.tag is not None or x in ctx[1]: return None
|
||||
if x.tag is None and x.op is Ops.CALL:
|
||||
# don't tag anything in a CALL
|
||||
for u in x.src[0].toposort(): ctx[1].add(u)
|
||||
if x.dtype.scalar() == dtypes.index: return None
|
||||
ctx.append(x)
|
||||
return x.replace(tag=(len(ctx)-1,))
|
||||
ctx[0].append(x)
|
||||
return x.replace(tag=(len(ctx[0])-1,))
|
||||
add_tags = PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
|
||||
|
|
@ -551,7 +567,7 @@ replace_contiguous = PatternMatcher([
|
|||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
uop_list: list[UOp] = []
|
||||
tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
|
||||
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
|
||||
|
||||
tsink = graph_rewrite(tsink, pm_mops+earliest_rewrites+replace_contiguous, ctx={}, name="earliest rewrites")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ if TYPE_CHECKING: import numpy
|
|||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.mixin import OpMixin
|
||||
|
|
@ -232,6 +232,16 @@ class Tensor(OpMixin):
|
|||
|
||||
# ***** data handlers ****
|
||||
|
||||
def as_param(self, slot:int):
|
||||
if self.uop.axis is not None:
|
||||
multi_shape = tuple([s//len(self.device) if i==self.uop.axis else s for i,s in enumerate(self.shape)])
|
||||
param = UOp.param(slot, self.dtype, multi_shape, self.device).multi(self.uop.axis)
|
||||
else:
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param, device=self.device)
|
||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device)
|
||||
|
||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||
"""
|
||||
Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.
|
||||
|
|
@ -275,13 +285,13 @@ class Tensor(OpMixin):
|
|||
self.uop = x.uop
|
||||
return self
|
||||
|
||||
def assign(self, x) -> Tensor:
|
||||
def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
|
||||
# TODO: this is a hack for writing to DISK. remove with working assign
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
||||
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
|
||||
self._buffer().copyin(x._data())
|
||||
return self
|
||||
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
|
||||
if self.uop is x.uop: return self # a self assign is a NOOP
|
||||
# NOTE: we allow cross device assign
|
||||
# broadcast x
|
||||
|
|
@ -1268,13 +1278,12 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
return self._getitem(indices)
|
||||
|
||||
def __setitem__(self, indices, v:Tensor|PyConst) -> None:
|
||||
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
||||
self.realize()._getitem(indices).assign(v)
|
||||
return
|
||||
# NOTE: check that setitem target is valid first
|
||||
if isinstance(v, get_args(PyConst)): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
|
||||
self.realize()
|
||||
if not self.uop.is_contiguous(): raise RuntimeError("setitem target needs to be contiguous")
|
||||
|
|
@ -2422,6 +2431,9 @@ class Tensor(OpMixin):
|
|||
```
|
||||
"""
|
||||
if IMAGE: return self.image_dot(w, dtype)
|
||||
if ASM_GEMM:
|
||||
from extra.gemm.asm.cdna.gemm import can_use_asm_gemm, asm_gemm
|
||||
if can_use_asm_gemm(self, w): return asm_gemm(self, w)
|
||||
x, dx, dw = self, self.ndim, w.ndim
|
||||
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
|
||||
if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class Ops(FastEnum):
|
|||
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); REWRITE_ERROR = auto()
|
||||
PARAM = auto(); CALL = auto()
|
||||
|
||||
# renderer
|
||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Callable
|
||||
import math, functools
|
||||
from tinygrad.dtype import dtypes, DType, promo_lattice
|
||||
from tinygrad.dtype import dtypes, DType, promo_lattice, truncate
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import flatten, polyN
|
||||
from tinygrad.uop import GroupOp
|
||||
|
|
@ -353,8 +353,9 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
|
|||
case Ops.IDIV | Ops.MOD:
|
||||
# TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b
|
||||
if dt == dtypes.int:
|
||||
a0, a1 = (a_neg:=a1 < zero).where((n:=l2i(Ops.NEG, dt, a0, a1))[0], a0).bitcast(dtypes.uint), a_neg.where(n[1], a1).bitcast(dtypes.uint)
|
||||
b0, b1 = (b_neg:=b1 < zero).where((n:=l2i(Ops.NEG, dt, b0, b1))[0], b0).bitcast(dtypes.uint), b_neg.where(n[1], b1).bitcast(dtypes.uint)
|
||||
ua0, ua1, ub0, ub1 = a0.bitcast(dtypes.uint), a1.bitcast(dtypes.uint), b0.bitcast(dtypes.uint), b1.bitcast(dtypes.uint)
|
||||
a0, a1 = (a_neg:=a1 < zero).where((n:=l2i(Ops.NEG, dtypes.uint, ua0, ua1))[0], ua0), a_neg.where(n[1], ua1)
|
||||
b0, b1 = (b_neg:=b1 < zero).where((n:=l2i(Ops.NEG, dtypes.uint, ub0, ub1))[0], ub0), b_neg.where(n[1], ub1)
|
||||
q, r = (z:=UOp.const(dtypes.uint, 0), z), (z, z)
|
||||
for i in range(63, -1, -1):
|
||||
r = l2i(Ops.SHL, dtypes.uint, *r, UOp.const(dtypes.uint, 1), z)
|
||||
|
|
@ -364,8 +365,9 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
|
|||
q = ((q[0] | cond.cast(dtypes.uint) << (i % 32), q[1]) if i < 32 else (q[0], q[1] | cond.cast(dtypes.uint) << (i % 32)))
|
||||
r = l2i(Ops.WHERE, dtypes.uint, cond, *diff, *r)
|
||||
if dt == dtypes.int:
|
||||
nq, nr = l2i(Ops.NEG, dt, q0:=q[0].bitcast(dt), q1:=q[1].bitcast(dt)), l2i(Ops.NEG, dt, r0:=r[0].bitcast(dt), r1:=r[1].bitcast(dt))
|
||||
return (a_neg.where(nr[0], r0), a_neg.where(nr[1], r1)) if op == Ops.MOD else ((a_neg^b_neg).where(nq[0], q0), (a_neg^b_neg).where(nq[1], q1))
|
||||
(nq0, nq1), (nr0, nr1) = l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *q)), l2i(Ops.BITCAST, dt, *l2i(Ops.NEG, dtypes.uint, *r))
|
||||
(q0, q1), (r0, r1) = l2i(Ops.BITCAST, dt, *q), l2i(Ops.BITCAST, dt, *r)
|
||||
return (a_neg.where(nr0, r0), a_neg.where(nr1, r1)) if op == Ops.MOD else ((a_neg^b_neg).where(nq0, q0), (a_neg^b_neg).where(nq1, q1))
|
||||
return (r[0].bitcast(dt), r[1].bitcast(dt)) if op == Ops.MOD else (q[0].bitcast(dt), q[1].bitcast(dt))
|
||||
case Ops.CMPLT: return (a1 < b1) | ((a1.eq(b1)) & (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)))
|
||||
case Ops.CMPEQ: return a0.eq(b0) & a1.eq(b1)
|
||||
|
|
@ -447,5 +449,5 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, force_transcenden
|
|||
pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
|
||||
None if x.tag is None else x.replace(dtype=l2i_dt[x.dtype], src=(l2i_idx(idx, x.tag),)))]
|
||||
pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
|
||||
None if x.tag is None else UOp.const(l2i_dt[x.dtype], (x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF)))]
|
||||
None if x.tag is None else UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))]
|
||||
return PatternMatcher(pat)
|
||||
|
|
|
|||
|
|
@ -58,6 +58,11 @@ def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
|
|||
if pad is not None: ret += " " * (pad-ansilen(ret))
|
||||
return ret
|
||||
|
||||
def shape_to_shape_arg(arg:tuple[sint, ...]) -> UOp:
|
||||
if len(arg) == 0: return UOp(Ops.VECTORIZE, dtypes.index.vec(0))
|
||||
elif all(isinstance(x, int) for x in arg): return UOp.const(dtypes.index.vec(len(arg)), cast(tuple[int, ...], arg))
|
||||
else: return UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg))
|
||||
|
||||
def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||
ret: dict[UOp, dict[UOp, None]] = {}
|
||||
for u in lst:
|
||||
|
|
@ -205,7 +210,7 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
|
|||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
|
||||
return None
|
||||
|
||||
|
|
@ -224,9 +229,13 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
|
|||
case Ops.ENCDEC: return self.arg[0]
|
||||
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
case Ops.PARAM:
|
||||
# NOTE: copied from marg
|
||||
if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
|
||||
return None
|
||||
|
||||
# passthrough ops
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
|
||||
return self.src[0]._shape
|
||||
|
||||
# ops with custom handling
|
||||
|
|
@ -436,7 +445,7 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
|
|||
# NOTE: b is ConstType, not ConstLike, so UOps and tuples aren't allowed
|
||||
assert not isinstance(b, (UOp, tuple)), "unique const only works on numbers"
|
||||
ret = UOp.const(dtype, b, device)
|
||||
return ret.replace(src=ret.src + (UOp.unique(None if unique is True else unique),))
|
||||
return ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
|
||||
@staticmethod
|
||||
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
|
||||
|
|
@ -558,11 +567,7 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
|
|||
case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
|
||||
case Ops.PERMUTE | Ops.FLIP: src_args = []
|
||||
case _: raise RuntimeError(f"{op} is not a MovementOp")
|
||||
usrcs = []
|
||||
for arg in src_args:
|
||||
if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
|
||||
elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
|
||||
else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
|
||||
usrcs = [shape_to_shape_arg(arg) for arg in src_args]
|
||||
if len(usrcs) == 0: ret = UOp(op, self.dtype, (self,), arg)
|
||||
else: ret = UOp(op, self.dtype, (self,)+UOp.sink(*usrcs).simplify().src)
|
||||
# for all movement ops, we check shape property to validity check the movement op
|
||||
|
|
@ -824,6 +829,13 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
|
|||
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
|
||||
return self.src[0].after(self.store(val).end(*argfix(end)))
|
||||
|
||||
# TODO: this should replace placeholder
|
||||
@staticmethod
|
||||
def param(slot:int, dtype:DType, shape:tuple[sint, ...]|None=None, device=None):
|
||||
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
|
||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
||||
|
|
@ -1372,8 +1384,8 @@ def render_marg(ctx,x:UOp):
|
|||
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
|
||||
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.ASSIGN, Ops.DETACH}
|
||||
pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"), UPat(Ops.UNIQUE, name="u")), name="x"),
|
||||
lambda x,d,u: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
|
||||
lambda x,u,d: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
|
||||
|
|
|
|||
|
|
@ -88,14 +88,11 @@ _tensor_spec = PatternMatcher([
|
|||
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
|
||||
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
|
||||
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
|
||||
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
|
||||
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
|
||||
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
|
||||
|
||||
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND, Ops.CONTIGUOUS))), lambda: True),
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
|
||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||
|
|
@ -113,7 +110,7 @@ _tensor_spec = PatternMatcher([
|
|||
|
||||
# device or unique
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat((Ops.LUNIQUE, Ops.UNIQUE)))), lambda: True),
|
||||
(UPat(Ops.CONST, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE))), lambda: True),
|
||||
|
||||
# DETACH and CONTIGUOUS change how we interpret the source UOp
|
||||
# CONTIGUOUS ensures the source UOp realizes
|
||||
|
|
@ -141,7 +138,11 @@ _tensor_spec = PatternMatcher([
|
|||
|
||||
# Tensor range bind / store
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat(Ops.RANGE)), arg=None), lambda: True),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True)
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat())), lambda: True),
|
||||
|
||||
# allow CALL/PARAM
|
||||
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
|
||||
(UPat(Ops.PARAM), lambda: True),
|
||||
])+movement_ops+shared_spec
|
||||
|
||||
tensor_spec = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
||||
|
||||
|
|
@ -347,7 +348,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
|
|||
base = unwrap(p.base)
|
||||
addr_table = amd_decode(unwrap(p.lib), device_props[p.device]["gfx_target_version"], )
|
||||
disasm:dict[int, tuple[str, int]] = {addr+base:(inst.disasm(), inst.size()) for addr, inst in addr_table.items()}
|
||||
rctx = decode(data, {p.name:disasm})
|
||||
rctx = decode(data, {p.tag:disasm})
|
||||
cu_events:dict[str, list[ProfileEvent]] = {}
|
||||
# * INST waves
|
||||
wave_insts:dict[str, dict[str, dict]] = {}
|
||||
|
|
@ -555,11 +556,11 @@ def get_render(query:str) -> dict:
|
|||
pc_to_inst = data["disasm"]
|
||||
start_pc = None
|
||||
rows:dict[int, dict] = {}
|
||||
for pc, (inst,_) in pc_to_inst.items():
|
||||
if start_pc is None: start_pc = pc
|
||||
rows[pc] = {"pc":pc-start_pc, "inst":inst, "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
|
||||
for e in w.unpack_insts():
|
||||
if start_pc is None: start_pc = e.pc
|
||||
if (inst:=rows.get(e.pc)) is None:
|
||||
rows[e.pc] = inst = {"pc":e.pc-start_pc, "inst":pc_to_inst[e.pc][0], "hit_count":0, "dur":0, "stall":0, "type":str(e.typ).split("_")[-1],
|
||||
"hits":{"cols":inst_columns, "rows":[]}}
|
||||
if not (inst:=rows[e.pc]).get("type"): inst["type"] = str(e.typ).split("_")[-1]
|
||||
inst["hit_count"] += 1
|
||||
inst["dur"] += e.dur
|
||||
inst["stall"] += e.stall
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue