mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove memory peak for quantized llama (#1720)
This commit is contained in:
parent
e4eb5d55c7
commit
b5cf274da3
1 changed files with 3 additions and 2 deletions
|
|
@ -262,8 +262,8 @@ class AbsmaxQuantizedLinear:
|
|||
if 'feed_forward' in name or ('attention.w') in name or name == 'output.weight':
|
||||
scale = v.abs().max(axis=1) / 127.0
|
||||
int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
|
||||
new_tensors[name] = int8_weight.realize()
|
||||
new_tensors[name.replace('weight', 'scale')] = scale.realize()
|
||||
new_tensors[name] = int8_weight
|
||||
new_tensors[name.replace('weight', 'scale')] = scale
|
||||
else:
|
||||
new_tensors[name] = v
|
||||
return new_tensors
|
||||
|
|
@ -287,6 +287,7 @@ class LLaMa:
|
|||
|
||||
if quantize:
|
||||
weights = AbsmaxQuantizedLinear.quantize(weights)
|
||||
for _,v in weights.items(): v.realize()
|
||||
load_state_dict(model, weights, strict=False)
|
||||
|
||||
return LLaMa(model, sp_model)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue