llama: update local amax implementation after ParamArgs change (#16446)

* local amax failing test

* update _local_abs_max_fxn
This commit is contained in:
qazal 2026-05-30 10:55:43 +03:00 committed by GitHub
commit 29b47a0057
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 3 deletions

View file

@ -1,7 +1,8 @@
from __future__ import annotations
import functools, pathlib
from dataclasses import replace
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import Ops
from tinygrad.uop.ops import shape_to_shape_arg
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
FP8_MAX = 448.0
@ -11,7 +12,7 @@ NUM_WG, THREADS_PER_WG = 1024, 256
@functools.cache
def _local_abs_max_fxn(x_p, device):
x = Tensor(x_p, device=device)
inner = Tensor(x.uop.src[0]) if x.uop.op is Ops.MULTI else x
inner = Tensor(x.uop.replace(src=(shape_to_shape_arg(x.uop.shard_shape),), arg=replace(x.uop.arg, axis=None))) if x.uop.axis is not None else x
return (inner.abs().max(),)
def local_abs_max(x:Tensor) -> Tensor:

View file

@ -1,8 +1,9 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad import Tensor, Device, dtypes, Context, GlobalCounters
from tinygrad.helpers import getenv
from examples.mlperf.models.flat_llama import FP8_DTYPE, quantize_fp8
from extra.llama_kernels.fused_ce import fused_ce_loss
from extra.llama_kernels import local_abs_max
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed, quantize_fp8_scalar
from test.helpers import needs_second_gpu
@ -82,5 +83,14 @@ class TestQuantizeFP8(unittest.TestCase):
assert fp8.uop.shape == x.uop.shape
assert new_amax.shape == ()
class TestLocalAmax(unittest.TestCase):
def test_multi_tensor_local_shard_amax(self):
devices = ("CPU:0", "CPU:1")
x = Tensor.arange(16, device=devices[0]).reshape(4, 4).cast(dtypes.float).contiguous().realize().shard(devices, axis=0).realize()
GlobalCounters.reset()
out = (x * local_abs_max(x)).contiguous().realize()
self.assertEqual(GlobalCounters.kernel_count, 4)
self.assertEqual(out.tolist(), [[0., 7., 14., 21.], [28., 35., 42., 49.], [120., 135., 150., 165.], [180., 195., 210., 225.]])
if __name__ == '__main__':
unittest.main()