mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama: update local amax implementation after ParamArgs change (#16446)
* local amax failing test * update _local_abs_max_fxn
This commit is contained in:
parent
6795c2d5c9
commit
29b47a0057
2 changed files with 14 additions and 3 deletions
|
|
@ -1,7 +1,8 @@
|
|||
from __future__ import annotations
|
||||
import functools, pathlib
|
||||
from dataclasses import replace
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.uop.ops import shape_to_shape_arg
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
|
||||
|
||||
FP8_MAX = 448.0
|
||||
|
|
@ -11,7 +12,7 @@ NUM_WG, THREADS_PER_WG = 1024, 256
|
|||
@functools.cache
|
||||
def _local_abs_max_fxn(x_p, device):
|
||||
x = Tensor(x_p, device=device)
|
||||
inner = Tensor(x.uop.src[0]) if x.uop.op is Ops.MULTI else x
|
||||
inner = Tensor(x.uop.replace(src=(shape_to_shape_arg(x.uop.shard_shape),), arg=replace(x.uop.arg, axis=None))) if x.uop.axis is not None else x
|
||||
return (inner.abs().max(),)
|
||||
|
||||
def local_abs_max(x:Tensor) -> Tensor:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad import Tensor, Device, dtypes, Context, GlobalCounters
|
||||
from tinygrad.helpers import getenv
|
||||
from examples.mlperf.models.flat_llama import FP8_DTYPE, quantize_fp8
|
||||
from extra.llama_kernels.fused_ce import fused_ce_loss
|
||||
from extra.llama_kernels import local_abs_max
|
||||
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed, quantize_fp8_scalar
|
||||
from test.helpers import needs_second_gpu
|
||||
|
||||
|
|
@ -82,5 +83,14 @@ class TestQuantizeFP8(unittest.TestCase):
|
|||
assert fp8.uop.shape == x.uop.shape
|
||||
assert new_amax.shape == ()
|
||||
|
||||
class TestLocalAmax(unittest.TestCase):
|
||||
def test_multi_tensor_local_shard_amax(self):
|
||||
devices = ("CPU:0", "CPU:1")
|
||||
x = Tensor.arange(16, device=devices[0]).reshape(4, 4).cast(dtypes.float).contiguous().realize().shard(devices, axis=0).realize()
|
||||
GlobalCounters.reset()
|
||||
out = (x * local_abs_max(x)).contiguous().realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 4)
|
||||
self.assertEqual(out.tolist(), [[0., 7., 14., 21.], [28., 35., 42., 49.], [120., 135., 150., 165.], [180., 195., 210., 225.]])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue