This commit is contained in:
Ahmed Harmouche 2024-12-13 12:23:02 +01:00
commit 04041b1342
3 changed files with 30 additions and 96 deletions

3
.gitignore vendored
View file

@ -35,8 +35,9 @@ extra/datasets/COCO/
extra/datasets/audio*
extra/weights
venv
examples/**/net.*[js,json]
examples/**/**/net*.*[js,json]
examples/**/*.safetensors
examples/webgpu/stable_diffusion/*.*[js,mjs]
node_modules
package.json
package-lock.json

View file

@ -58,12 +58,12 @@ if __name__ == "__main__":
]
for model in model_parts:
prg, inp_sizes, out_sizes, state = export_model(model, Device.DEFAULT.lower(), *model.input)
prg, inp_sizes, out_sizes, state = export_model(model, Device.DEFAULT.lower(), *model.input, model_name=model.name)
dirname = Path(__file__).parent
weight_loc = (dirname / f"net_{model.name}.safetensors").as_posix()
safe_save(state, weight_loc)
if model.name == "diffusor":
convert_f32_to_f16(weight_loc, (dirname / f"net_diffuso_f16.safetensors").as_posix())
convert_f32_to_f16(weight_loc, (dirname / f"net_diffusor_f16.safetensors").as_posix())
with open(dirname / f"net_{model.name}.js", "w") as text_file:
text_file.write(prg)
text_file.write(prg)

View file

@ -6,7 +6,6 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>tinygrad has WebGPU</title>
<style>
/* General Reset */
* {
margin: 0;
padding: 0;
@ -163,12 +162,17 @@
<script type="module">
import ClipTokenizer from './clip_tokenizer.js';
import textModel from './net_textModel.js';
import diffusor from './net_diffusor.js';
import decoder from './net_decoder.js';
import f16tof32 from './net_f16tof32.js';
window.clipTokenizer = new ClipTokenizer();
window.textModel = textModel;
window.diffusor = diffusor;
window.decoder = decoder;
window.f16tof32 = f16tof32;
</script>
<script src="./net_text.js"></script>
<script src="./net_diffusor.js"></script>
<script src="./net_f16tof32.js"></script>
<script src="./net_decoder.js"></script>
</head>
<body>
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
@ -359,123 +363,52 @@
return res.arrayBuffer();
};
const getAndDecompressF16Safetensors = async (device, progress) => {
const getAndDecompressF16Safetensors = async (device, progress, f16safeTensor) => {
let totalLoaded = 0;
let totalSize = 0;
let partSize = {};
const getPart = async(key) => {
let part = await readTensorFromDb(db, key);
if (part) {
console.log(`Cache hit: ${key}`);
return Promise.resolve(part.content);
} else {
console.log(`Cache miss: ${key}`);
return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
}
}
const progressCallback = (part, loaded, total) => {
totalLoaded += loaded;
if (!partSize[part]) {
totalSize += total;
partSize[part] = true;
}
progress(totalLoaded, totalSize);
};
let netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
let buffers = await Promise.all(netKeys.map(key => getPart(key)));
// Combine everything except for text model, since that's already f32
const totalLength = buffers.reduce((acc, buffer, index, array) => {
if (index < 4) {
return acc + buffer.byteLength;
} else {
return acc;
}
}, 0
);
combinedBuffer = new Uint8Array(totalLength);
let offset = 0;
buffers.forEach((buffer, index) => {
saveTensorToDb(db, netKeys[index], new Uint8Array(buffer));
if (index < 4) {
combinedBuffer.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
buffer = null;
}
});
let textModelU8 = new Uint8Array(buffers[4]);
document.getElementById("modelDlTitle").innerHTML = "Decompressing model";
const textModelOffset = 3772703308;
const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
const metadataLength = Number(new DataView(f16safeTensor.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(f16safeTensor.subarray(8, 8 + metadataLength)));
const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
const allToDecomp = f16safeTensor.byteLength - (8 + metadataLength);
const decodeChunkSize = 8388480;
const numChunks = Math.ceil(allToDecomp/decodeChunkSize);
console.log(allToDecomp + " bytes to decompress");
console.log("Will be decompressed in " + numChunks+ " chunks");
let partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
let parts = [];
for (let offsets of partOffsets) {
parts.push(new Uint8Array(offsets.end-offsets.start));
}
parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);
f32safeTensor = new Uint8Array(allToDecomp*2);
f32safeTensor.set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
f32safeTensor.set(f16safeTensor.subarray(8, 8 + metadataLength), 8);
let start = Date.now();
let cursor = 0;
for (let i = 0; i < numChunks; i++) {
progress(i, numChunks);
let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
let metaOffset = 8 + metadataLength;
let chunkStartF16 = metaOffset + (decodeChunkSize * i);
let chunkEndF16 = chunkStartF16 + decodeChunkSize;
let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
let chunk = f16safeTensor.subarray(chunkStartF16, chunkEndF16);
let uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
let result = await f16decomp(uint32Chunk);
let resultUint8 = new Uint8Array(result.buffer);
let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
let offsetInPart = chunkStartF32 - partOffsets[cursor].start;
if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
parts[cursor].set(resultUint8, offsetInPart);
} else {
let spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);
cursor++;
if (cursor < parts.length) {
let nextPartOffset = spaceLeftInCurrentPart;
let nextPartLength = resultUint8.length - nextPartOffset;
parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
}
}
let f32offset = metaOffset + (decodeChunkSize * i * 2);
f32safeTensor.set(resultUint8, f32offset);
resultUint8 = null;
result = null;
}
combinedBuffer = null;
f16safeTensor = null;
let end = Date.now();
console.log("Decoding took: " + ((end - start) / 1000) + " s");
console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");
return parts;
return f32safeTensor;
};
const loadNet = async () => {
@ -495,7 +428,7 @@
let models = ["textModel", "diffusor", "decoder"];
nets = await timer(() => Promise.all([
textModel().setup(device, safetensorParts),
textModel.load(device, safetensorParts),
diffusor().setup(device, safetensorParts),
decoder().setup(device, safetensorParts)
]).then((loadedModels) => loadedModels.reduce((acc, model, index) => { acc[models[index]] = model; return acc; }, {})), "(compilation)")