remove memory peak for quantized llama (#1720)

This commit is contained in:
nimlgen 2023-08-30 23:32:30 +03:00 committed by GitHub
commit b5cf274da3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -262,8 +262,8 @@ class AbsmaxQuantizedLinear:
if 'feed_forward' in name or ('attention.w') in name or name == 'output.weight':
scale = v.abs().max(axis=1) / 127.0
int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
- new_tensors[name] = int8_weight.realize()
- new_tensors[name.replace('weight', 'scale')] = scale.realize()
+ new_tensors[name] = int8_weight
+ new_tensors[name.replace('weight', 'scale')] = scale
else:
new_tensors[name] = v
return new_tensors
@@ -287,6 +287,7 @@ class LLaMa:
if quantize:
weights = AbsmaxQuantizedLinear.quantize(weights)
+ for _,v in weights.items(): v.realize()
load_state_dict(model, weights, strict=False)
return LLaMa(model, sp_model)