mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
More WIP
This commit is contained in:
parent
3d4bbc3a35
commit
04041b1342
3 changed files with 30 additions and 96 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -35,8 +35,9 @@ extra/datasets/COCO/
|
|||
extra/datasets/audio*
|
||||
extra/weights
|
||||
venv
|
||||
examples/**/net.*[js,json]
|
||||
examples/**/**/net*.*[js,json]
|
||||
examples/**/*.safetensors
|
||||
examples/webgpu/stable_diffusion/*.*[js,mjs]
|
||||
node_modules
|
||||
package.json
|
||||
package-lock.json
|
||||
|
|
|
|||
|
|
@ -58,12 +58,12 @@ if __name__ == "__main__":
|
|||
]
|
||||
|
||||
for model in model_parts:
|
||||
prg, inp_sizes, out_sizes, state = export_model(model, Device.DEFAULT.lower(), *model.input)
|
||||
prg, inp_sizes, out_sizes, state = export_model(model, Device.DEFAULT.lower(), *model.input, model_name=model.name)
|
||||
dirname = Path(__file__).parent
|
||||
weight_loc = (dirname / f"net_{model.name}.safetensors").as_posix()
|
||||
safe_save(state, weight_loc)
|
||||
if model.name == "diffusor":
|
||||
convert_f32_to_f16(weight_loc, (dirname / f"net_diffuso_f16.safetensors").as_posix())
|
||||
convert_f32_to_f16(weight_loc, (dirname / f"net_diffusor_f16.safetensors").as_posix())
|
||||
|
||||
with open(dirname / f"net_{model.name}.js", "w") as text_file:
|
||||
text_file.write(prg)
|
||||
text_file.write(prg)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>tinygrad has WebGPU</title>
|
||||
<style>
|
||||
/* General Reset */
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
|
|
@ -163,12 +162,17 @@
|
|||
|
||||
<script type="module">
|
||||
import ClipTokenizer from './clip_tokenizer.js';
|
||||
import textModel from './net_textModel.js';
|
||||
import diffusor from './net_diffusor.js';
|
||||
import decoder from './net_decoder.js';
|
||||
import f16tof32 from './net_f16tof32.js';
|
||||
|
||||
window.clipTokenizer = new ClipTokenizer();
|
||||
window.textModel = textModel;
|
||||
window.diffusor = diffusor;
|
||||
window.decoder = decoder;
|
||||
window.f16tof32 = f16tof32;
|
||||
</script>
|
||||
<script src="./net_text.js"></script>
|
||||
<script src="./net_diffusor.js"></script>
|
||||
<script src="./net_f16tof32.js"></script>
|
||||
<script src="./net_decoder.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
|
||||
|
|
@ -359,123 +363,52 @@
|
|||
return res.arrayBuffer();
|
||||
};
|
||||
|
||||
const getAndDecompressF16Safetensors = async (device, progress) => {
|
||||
const getAndDecompressF16Safetensors = async (device, progress, f16safeTensor) => {
|
||||
let totalLoaded = 0;
|
||||
let totalSize = 0;
|
||||
let partSize = {};
|
||||
|
||||
const getPart = async(key) => {
|
||||
let part = await readTensorFromDb(db, key);
|
||||
|
||||
if (part) {
|
||||
console.log(`Cache hit: ${key}`);
|
||||
return Promise.resolve(part.content);
|
||||
} else {
|
||||
console.log(`Cache miss: ${key}`);
|
||||
return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
|
||||
}
|
||||
}
|
||||
|
||||
const progressCallback = (part, loaded, total) => {
|
||||
totalLoaded += loaded;
|
||||
|
||||
if (!partSize[part]) {
|
||||
totalSize += total;
|
||||
partSize[part] = true;
|
||||
}
|
||||
|
||||
progress(totalLoaded, totalSize);
|
||||
};
|
||||
|
||||
let netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
|
||||
let buffers = await Promise.all(netKeys.map(key => getPart(key)));
|
||||
|
||||
// Combine everything except for text model, since that's already f32
|
||||
const totalLength = buffers.reduce((acc, buffer, index, array) => {
|
||||
if (index < 4) {
|
||||
return acc + buffer.byteLength;
|
||||
} else {
|
||||
return acc;
|
||||
}
|
||||
}, 0
|
||||
);
|
||||
|
||||
combinedBuffer = new Uint8Array(totalLength);
|
||||
let offset = 0;
|
||||
buffers.forEach((buffer, index) => {
|
||||
saveTensorToDb(db, netKeys[index], new Uint8Array(buffer));
|
||||
if (index < 4) {
|
||||
combinedBuffer.set(new Uint8Array(buffer), offset);
|
||||
offset += buffer.byteLength;
|
||||
buffer = null;
|
||||
}
|
||||
});
|
||||
|
||||
let textModelU8 = new Uint8Array(buffers[4]);
|
||||
document.getElementById("modelDlTitle").innerHTML = "Decompressing model";
|
||||
|
||||
const textModelOffset = 3772703308;
|
||||
const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
|
||||
const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
|
||||
const metadataLength = Number(new DataView(f16safeTensor.buffer).getBigUint64(0, true));
|
||||
const metadata = JSON.parse(new TextDecoder("utf8").decode(f16safeTensor.subarray(8, 8 + metadataLength)));
|
||||
|
||||
const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
|
||||
const allToDecomp = f16safeTensor.byteLength - (8 + metadataLength);
|
||||
const decodeChunkSize = 8388480;
|
||||
const numChunks = Math.ceil(allToDecomp/decodeChunkSize);
|
||||
|
||||
console.log(allToDecomp + " bytes to decompress");
|
||||
console.log("Will be decompressed in " + numChunks+ " chunks");
|
||||
|
||||
let partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
|
||||
let parts = [];
|
||||
|
||||
for (let offsets of partOffsets) {
|
||||
parts.push(new Uint8Array(offsets.end-offsets.start));
|
||||
}
|
||||
parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
|
||||
parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
|
||||
parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);
|
||||
|
||||
f32safeTensor = new Uint8Array(allToDecomp*2);
|
||||
f32safeTensor.set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
|
||||
f32safeTensor.set(f16safeTensor.subarray(8, 8 + metadataLength), 8);
|
||||
|
||||
let start = Date.now();
|
||||
let cursor = 0;
|
||||
|
||||
for (let i = 0; i < numChunks; i++) {
|
||||
progress(i, numChunks);
|
||||
let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
|
||||
let metaOffset = 8 + metadataLength;
|
||||
let chunkStartF16 = metaOffset + (decodeChunkSize * i);
|
||||
let chunkEndF16 = chunkStartF16 + decodeChunkSize;
|
||||
let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
|
||||
let chunk = f16safeTensor.subarray(chunkStartF16, chunkEndF16);
|
||||
let uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
|
||||
let result = await f16decomp(uint32Chunk);
|
||||
let resultUint8 = new Uint8Array(result.buffer);
|
||||
let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
|
||||
let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
|
||||
let offsetInPart = chunkStartF32 - partOffsets[cursor].start;
|
||||
|
||||
if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
|
||||
parts[cursor].set(resultUint8, offsetInPart);
|
||||
} else {
|
||||
let spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
|
||||
parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);
|
||||
|
||||
cursor++;
|
||||
|
||||
if (cursor < parts.length) {
|
||||
let nextPartOffset = spaceLeftInCurrentPart;
|
||||
let nextPartLength = resultUint8.length - nextPartOffset;
|
||||
parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
|
||||
}
|
||||
}
|
||||
|
||||
let f32offset = metaOffset + (decodeChunkSize * i * 2);
|
||||
f32safeTensor.set(resultUint8, f32offset);
|
||||
resultUint8 = null;
|
||||
result = null;
|
||||
}
|
||||
|
||||
combinedBuffer = null;
|
||||
f16safeTensor = null;
|
||||
|
||||
let end = Date.now();
|
||||
console.log("Decoding took: " + ((end - start) / 1000) + " s");
|
||||
console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");
|
||||
|
||||
return parts;
|
||||
return f32safeTensor;
|
||||
};
|
||||
|
||||
const loadNet = async () => {
|
||||
|
|
@ -495,7 +428,7 @@
|
|||
let models = ["textModel", "diffusor", "decoder"];
|
||||
|
||||
nets = await timer(() => Promise.all([
|
||||
textModel().setup(device, safetensorParts),
|
||||
textModel.load(device, safetensorParts),
|
||||
diffusor().setup(device, safetensorParts),
|
||||
decoder().setup(device, safetensorParts)
|
||||
]).then((loadedModels) => loadedModels.reduce((acc, model, index) => { acc[models[index]] = model; return acc; }, {})), "(compilation)")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue