add Q4_K GGUF quantization support

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
George Hotz 2025-12-15 10:04:26 -05:00
commit 9ea9c40be3
3 changed files with 11 additions and 2 deletions

View file

@ -58,6 +58,7 @@ class TestGGUF(unittest.TestCase):
def test_dequantization_q4_0(self): self._test_dequantization(ggml.GGML_TYPE_Q4_0)
def test_dequantization_q4_1(self): self._test_dequantization(ggml.GGML_TYPE_Q4_1)
def test_dequantization_q8_0(self): self._test_dequantization(ggml.GGML_TYPE_Q8_0)
def test_dequantization_q4_k(self): self._test_dequantization(ggml.GGML_TYPE_Q4_K)
def test_dequantization_q6_k(self): self._test_dequantization(ggml.GGML_TYPE_Q6_K)
def test_dequantization_mxfp4(self):
MXFP4 = 39

View file

@ -182,6 +182,7 @@ class Transformer:
models = {
"llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
"llama3.2:1b-q4": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
"llama3.2:3b": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
"llama3.2:3b-f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
"llama3.1:8b": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",

View file

@ -308,7 +308,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
Converts ggml tensor data to a tinygrad tensor.
Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14), MXFP4 (id: 39)
Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q4_K (id: 12), Q6_K (id: 14), MXFP4 (id: 39)
"""
# https://github.com/ggerganov/ggml/blob/323951f1bdcdfbd5b5ff3a9a7c3770e63b1a560e/include/ggml.h#L356
@ -322,13 +322,20 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
return t.unsqueeze(-1).expand((*t.shape,8//b)).idiv(shift_tensor).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
# map to (number of elements, number of bytes)
if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34), 39: (32, 17) }.get(ggml_type)) is not None:
if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 8: (32, 34), 12: (256, 144), 14: (256, 210), 39: (32, 17) }.get(ggml_type)) is not None:
blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
if ggml_type == 3:
d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
if ggml_type == 12: # Q4_K: 256 elements per 144-byte block (d:2, dmin:2, scales:12, qs:128)
d, dmin = (blocks[:,i:i+2].bitcast(dtypes.float16).cast(dtypes.float32).unsqueeze(-1) for i in [0, 2])
s = blocks[:,4:16] # 12 bytes: 6-bit scales[0-3], 6-bit mins[0-3], high bits[4-7]
sc = s[:,0:4].bitwise_and(63).cat(s[:,8:12].bitwise_and(0xF).bitwise_or(s[:,0:4].rshift(6).lshift(4)), dim=-1)
mn = s[:,4:8].bitwise_and(63).cat(s[:,8:12].rshift(4).bitwise_or(s[:,4:8].rshift(6).lshift(4)), dim=-1)
q = Tensor.stack((qs:=blocks[:,16:144].reshape(-1,4,32)).bitwise_and(0xF), qs.rshift(4), dim=2).reshape(-1,8,32).cast(dtypes.float32)
return (d * sc.unsqueeze(-1) * q - dmin * mn.unsqueeze(-1)).flatten(-2)
if ggml_type == 14:
xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))