mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
simpler-sd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4bce5be8c | ||
|
|
4396ac96b5 | ||
|
|
04041b1342 | ||
|
|
3d4bbc3a35 | ||
|
|
52f4547303 |
4 changed files with 64 additions and 304 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -35,8 +35,9 @@ extra/datasets/COCO/
|
|||
extra/datasets/audio*
|
||||
extra/weights
|
||||
venv
|
||||
examples/**/net.*[js,json]
|
||||
examples/**/**/net*.*[js,json]
|
||||
examples/**/*.safetensors
|
||||
examples/webgpu/stable_diffusion/*.*[js,mjs]
|
||||
node_modules
|
||||
package.json
|
||||
package-lock.json
|
||||
|
|
|
|||
|
|
@ -1,74 +1,28 @@
|
|||
import os
|
||||
from extra.export_model import compile_net, jit_model, dtype_to_js_type
|
||||
from extra.export_model import export_model
|
||||
from extra.f16_decompress import u32_to_f16
|
||||
from examples.stable_diffusion import StableDiffusion
|
||||
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
|
||||
from tinygrad.nn.state import safe_save, torch_load, load_state_dict
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad import Device, dtypes
|
||||
from tinygrad.helpers import fetch
|
||||
from typing import NamedTuple, Any, List
|
||||
import requests
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
def convert_f32_to_f16(input_file, output_file):
|
||||
with open(input_file, 'rb') as f:
|
||||
metadata_length_bytes = f.read(8)
|
||||
metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
|
||||
metadata_json_bytes = f.read(metadata_length)
|
||||
float32_values = np.fromfile(f, dtype=np.float32)
|
||||
values = np.fromfile(f, dtype=np.float32)
|
||||
|
||||
first_text_model_offset = 3772703308
|
||||
num_elements = int((first_text_model_offset)/4)
|
||||
front_float16_values = float32_values[:num_elements].astype(np.float16)
|
||||
rest_float32_values = float32_values[num_elements:]
|
||||
f16_values = values.astype(np.float16)
|
||||
|
||||
with open(output_file, 'wb') as f:
|
||||
f.write(metadata_length_bytes)
|
||||
f.write(metadata_json_bytes)
|
||||
front_float16_values.tofile(f)
|
||||
rest_float32_values.tofile(f)
|
||||
|
||||
def split_safetensor(fn):
|
||||
_, data_start, metadata = safe_load_metadata(fn)
|
||||
text_model_offset = 3772703308
|
||||
chunk_size = 536870912
|
||||
|
||||
for k in metadata:
|
||||
# safetensor is in fp16, except for text moel
|
||||
if (metadata[k]["data_offsets"][0] < text_model_offset):
|
||||
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
|
||||
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)
|
||||
|
||||
last_offset = 0
|
||||
part_end_offsets = []
|
||||
|
||||
for k in metadata:
|
||||
offset = metadata[k]['data_offsets'][0]
|
||||
|
||||
if offset == text_model_offset:
|
||||
break
|
||||
|
||||
part_offset = offset - last_offset
|
||||
|
||||
if (part_offset >= chunk_size):
|
||||
part_end_offsets.append(data_start+offset)
|
||||
last_offset = offset
|
||||
|
||||
text_model_start = int(text_model_offset/2)
|
||||
net_bytes = bytes(open(fn, 'rb').read())
|
||||
part_end_offsets.append(text_model_start+data_start)
|
||||
cur_pos = 0
|
||||
|
||||
for i, end_pos in enumerate(part_end_offsets):
|
||||
with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.safetensors'), "wb+") as f:
|
||||
f.write(net_bytes[cur_pos:end_pos])
|
||||
cur_pos = end_pos
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
|
||||
f.write(net_bytes[text_model_start+data_start:])
|
||||
|
||||
return part_end_offsets
|
||||
f16_values.tofile(f)
|
||||
|
||||
def fetch_dep(file, url):
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
|
|
@ -77,162 +31,29 @@ def fetch_dep(file, url):
|
|||
if __name__ == "__main__":
|
||||
fetch_dep(os.path.join(os.path.dirname(__file__), "clip_tokenizer.js"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/clip_tokenizer.js")
|
||||
fetch_dep(os.path.join(os.path.dirname(__file__), "bpe_simple_vocab_16e6.mjs"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs")
|
||||
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
|
||||
args = parser.parse_args()
|
||||
Device.DEFAULT = "WEBGPU"
|
||||
|
||||
Device.DEFAULT = "WEBGPU"
|
||||
Tensor.no_grad = True
|
||||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
|
||||
|
||||
class Step(NamedTuple):
|
||||
name: str = ""
|
||||
input: List[Tensor] = []
|
||||
forward: Any = None
|
||||
|
||||
sub_steps = [
|
||||
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
|
||||
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
|
||||
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode),
|
||||
Step(name = "f16tof32", input = [Tensor.randn(2097120, dtype=dtypes.uint32)], forward = u32_to_f16)
|
||||
model_parts = [
|
||||
("textModel", [Tensor.randn(1, 77)], model.cond_stage_model.transformer.text_model),
|
||||
("diffusor", [
|
||||
Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64),
|
||||
Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)
|
||||
], model),
|
||||
("decoder", [Tensor.randn(1,4,64,64)], model.decode),
|
||||
("f16tof32", [Tensor.randn(2097120, dtype=dtypes.uint32)], u32_to_f16)
|
||||
]
|
||||
|
||||
prg = ""
|
||||
for model in model_parts:
|
||||
prg, inp_sizes, out_sizes, state = export_model(model[2], Device.DEFAULT.lower(), *model[1], model_name=model[0])
|
||||
dirname = Path(__file__).parent
|
||||
weight_loc = (dirname / f"net_{model[0]}.safetensors").as_posix()
|
||||
safe_save(state, weight_loc)
|
||||
if model[0] == "diffusor":
|
||||
convert_f32_to_f16(weight_loc, (dirname / f"net_diffusor_f16.safetensors").as_posix())
|
||||
|
||||
def fixup_code(code, key):
|
||||
code = code.replace(key, 'main')\
|
||||
.replace("var<uniform> INFINITY : f32;\n", "fn inf(a: f32) -> f32 { return a/0.0; }\n")\
|
||||
.replace("@group(0) @binding(0)", "")\
|
||||
.replace("INFINITY", "inf(1.0)")
|
||||
|
||||
for i in range(1,9): code = code.replace(f"binding({i})", f"binding({i-1})")
|
||||
return code
|
||||
|
||||
def compile_step(model, step: Step):
|
||||
run, special_names = jit_model(step, *step.input)
|
||||
functions, statements, bufs, _ = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
|
||||
kernel_names = ', '.join([name for (name, _, _, _) in statements])
|
||||
input_names = [name for _,name in special_names.items() if "input" in name]
|
||||
output_names = [name for _,name in special_names.items() if "output" in name]
|
||||
input_buf_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
|
||||
output_buf_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
|
||||
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
|
||||
exported_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
|
||||
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
|
||||
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buf_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,_ in enumerate(input_names)])
|
||||
return f"""\n var {step.name} = function() {{
|
||||
|
||||
{kernel_code}
|
||||
|
||||
return {{
|
||||
"setup": async (device, safetensor) => {{
|
||||
const metadata = safetensor ? getTensorMetadata(safetensor[0]) : null;
|
||||
|
||||
{exported_bufs}
|
||||
|
||||
{gpu_write_bufs}
|
||||
const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
|
||||
|
||||
const kernels = [{kernel_names}];
|
||||
const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
|
||||
|
||||
return async ({",".join([f'data{i}' for i,(k,v) in enumerate(special_names.items()) if v != "output0"])}) => {{
|
||||
const commandEncoder = device.createCommandEncoder();
|
||||
|
||||
{input_writer}
|
||||
|
||||
{kernel_calls}
|
||||
commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
|
||||
const gpuCommands = commandEncoder.finish();
|
||||
device.queue.submit([gpuCommands]);
|
||||
|
||||
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
|
||||
const resultBuffer = new {output_buf_types[0]}(gpuReadBuffer.size/{bufs[output_names[0]][1].itemsize});
|
||||
resultBuffer.set(new {output_buf_types[0]}(gpuReadBuffer.getMappedRange()));
|
||||
gpuReadBuffer.unmap();
|
||||
return resultBuffer;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
for step in sub_steps:
|
||||
print(f'Executing step={step.name}')
|
||||
prg += compile_step(model, step)
|
||||
|
||||
if step.name == "diffusor":
|
||||
if args.remoteweights:
|
||||
base_url = "https://huggingface.co/wpmed/stable-diffusion-f16-new/resolve/main"
|
||||
else:
|
||||
state = get_state_dict(model)
|
||||
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
|
||||
convert_f32_to_f16(os.path.join(os.path.dirname(__file__), "./net.safetensors"), os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
|
||||
split_safetensor(os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
|
||||
os.remove(os.path.join(os.path.dirname(__file__), "net.safetensors"))
|
||||
os.remove(os.path.join(os.path.dirname(__file__), "net_conv.safetensors"))
|
||||
base_url = "."
|
||||
|
||||
prekernel = f"""
|
||||
window.MODEL_BASE_URL= "{base_url}";
|
||||
const getTensorMetadata = (safetensorBuffer) => {{
|
||||
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
|
||||
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
|
||||
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
|
||||
}};
|
||||
|
||||
const getTensorBuffer = (safetensorParts, tensorMetadata, key) => {{
|
||||
let selectedPart = 0;
|
||||
let counter = 0;
|
||||
let partStartOffsets = [1131408336, 2227518416, 3308987856, 4265298864];
|
||||
let correctedOffsets = tensorMetadata.data_offsets;
|
||||
let prev_offset = 0;
|
||||
|
||||
for (let start of partStartOffsets) {{
|
||||
prev_offset = (counter == 0) ? 0 : partStartOffsets[counter-1];
|
||||
|
||||
if (tensorMetadata.data_offsets[0] < start) {{
|
||||
selectedPart = counter;
|
||||
correctedOffsets = [correctedOffsets[0]-prev_offset, correctedOffsets[1]-prev_offset];
|
||||
break;
|
||||
}}
|
||||
|
||||
counter++;
|
||||
}}
|
||||
|
||||
return safetensorParts[selectedPart].subarray(...correctedOffsets);
|
||||
}}
|
||||
|
||||
const getWeight = (safetensors, key) => {{
|
||||
let uint8Data = getTensorBuffer(safetensors, getTensorMetadata(safetensors[0])[key], key);
|
||||
return new Float32Array(uint8Data.buffer, uint8Data.byteOffset, uint8Data.byteLength / Float32Array.BYTES_PER_ELEMENT);
|
||||
}}
|
||||
|
||||
const createEmptyBuf = (device, size) => {{
|
||||
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
|
||||
}};
|
||||
|
||||
const createWeightBuf = (device, size, data) => {{
|
||||
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
|
||||
new Uint8Array(buf.getMappedRange()).set(data);
|
||||
buf.unmap();
|
||||
return buf;
|
||||
}};
|
||||
|
||||
const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
|
||||
const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
|
||||
const passEncoder = commandEncoder.beginComputePass();
|
||||
passEncoder.setPipeline(pipeline);
|
||||
passEncoder.setBindGroup(0, bindGroup);
|
||||
passEncoder.dispatchWorkgroups(...workgroup);
|
||||
passEncoder.end();
|
||||
}};"""
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
|
||||
text_file.write(prekernel + prg)
|
||||
with open(dirname / f"net_{model[0]}.js", "w") as text_file:
|
||||
text_file.write(prg)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>tinygrad has WebGPU</title>
|
||||
<style>
|
||||
/* General Reset */
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
|
|
@ -163,9 +162,17 @@
|
|||
|
||||
<script type="module">
|
||||
import ClipTokenizer from './clip_tokenizer.js';
|
||||
import textModel from './net_textModel.js';
|
||||
import diffusor from './net_diffusor.js';
|
||||
import decoder from './net_decoder.js';
|
||||
import f16tof32 from './net_f16tof32.js';
|
||||
|
||||
window.clipTokenizer = new ClipTokenizer();
|
||||
window.textModel = textModel;
|
||||
window.diffusor = diffusor;
|
||||
window.decoder = decoder;
|
||||
window.f16tof32 = f16tof32;
|
||||
</script>
|
||||
<script src="./net.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
|
||||
|
|
@ -356,123 +363,51 @@
|
|||
return res.arrayBuffer();
|
||||
};
|
||||
|
||||
const getAndDecompressF16Safetensors = async (device, progress) => {
|
||||
const decompressf16Safetensor = async (device, progress, f16safeTensor) => {
|
||||
let totalLoaded = 0;
|
||||
let totalSize = 0;
|
||||
let partSize = {};
|
||||
|
||||
const getPart = async(key) => {
|
||||
let part = await readTensorFromDb(db, key);
|
||||
|
||||
if (part) {
|
||||
console.log(`Cache hit: ${key}`);
|
||||
return Promise.resolve(part.content);
|
||||
} else {
|
||||
console.log(`Cache miss: ${key}`);
|
||||
return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
|
||||
}
|
||||
}
|
||||
|
||||
const progressCallback = (part, loaded, total) => {
|
||||
totalLoaded += loaded;
|
||||
|
||||
if (!partSize[part]) {
|
||||
totalSize += total;
|
||||
partSize[part] = true;
|
||||
}
|
||||
|
||||
progress(totalLoaded, totalSize);
|
||||
};
|
||||
|
||||
let netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
|
||||
let buffers = await Promise.all(netKeys.map(key => getPart(key)));
|
||||
|
||||
// Combine everything except for text model, since that's already f32
|
||||
const totalLength = buffers.reduce((acc, buffer, index, array) => {
|
||||
if (index < 4) {
|
||||
return acc + buffer.byteLength;
|
||||
} else {
|
||||
return acc;
|
||||
}
|
||||
}, 0
|
||||
);
|
||||
|
||||
combinedBuffer = new Uint8Array(totalLength);
|
||||
let offset = 0;
|
||||
buffers.forEach((buffer, index) => {
|
||||
saveTensorToDb(db, netKeys[index], new Uint8Array(buffer));
|
||||
if (index < 4) {
|
||||
combinedBuffer.set(new Uint8Array(buffer), offset);
|
||||
offset += buffer.byteLength;
|
||||
buffer = null;
|
||||
}
|
||||
});
|
||||
|
||||
let textModelU8 = new Uint8Array(buffers[4]);
|
||||
document.getElementById("modelDlTitle").innerHTML = "Decompressing model";
|
||||
|
||||
const textModelOffset = 3772703308;
|
||||
const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
|
||||
const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
|
||||
const metadataLength = Number(new DataView(f16safeTensor.buffer).getBigUint64(0, true));
|
||||
const metadata = JSON.parse(new TextDecoder("utf8").decode(f16safeTensor.subarray(8, 8 + metadataLength)));
|
||||
|
||||
const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
|
||||
const allToDecomp = f16safeTensor.byteLength - (8 + metadataLength);
|
||||
const decodeChunkSize = 8388480;
|
||||
const numChunks = Math.ceil(allToDecomp/decodeChunkSize);
|
||||
|
||||
console.log(allToDecomp + " bytes to decompress");
|
||||
console.log("Will be decompressed in " + numChunks+ " chunks");
|
||||
|
||||
let partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
|
||||
let parts = [];
|
||||
|
||||
for (let offsets of partOffsets) {
|
||||
parts.push(new Uint8Array(offsets.end-offsets.start));
|
||||
}
|
||||
parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
|
||||
parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
|
||||
parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);
|
||||
|
||||
f32safeTensor = new Uint8Array(allToDecomp*2);
|
||||
f32safeTensor.set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
|
||||
f32safeTensor.set(f16safeTensor.subarray(8, 8 + metadataLength), 8);
|
||||
|
||||
let start = Date.now();
|
||||
let cursor = 0;
|
||||
|
||||
for (let i = 0; i < numChunks; i++) {
|
||||
progress(i, numChunks);
|
||||
let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
|
||||
let metaOffset = 8 + metadataLength;
|
||||
let chunkStartF16 = metaOffset + (decodeChunkSize * i);
|
||||
let chunkEndF16 = chunkStartF16 + decodeChunkSize;
|
||||
let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
|
||||
let chunk = f16safeTensor.subarray(chunkStartF16, chunkEndF16);
|
||||
let uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
|
||||
let result = await f16decomp(uint32Chunk);
|
||||
let resultUint8 = new Uint8Array(result.buffer);
|
||||
let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
|
||||
let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
|
||||
let offsetInPart = chunkStartF32 - partOffsets[cursor].start;
|
||||
|
||||
if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
|
||||
parts[cursor].set(resultUint8, offsetInPart);
|
||||
} else {
|
||||
let spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
|
||||
parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);
|
||||
|
||||
cursor++;
|
||||
|
||||
if (cursor < parts.length) {
|
||||
let nextPartOffset = spaceLeftInCurrentPart;
|
||||
let nextPartLength = resultUint8.length - nextPartOffset;
|
||||
parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
|
||||
}
|
||||
}
|
||||
|
||||
let f32offset = metaOffset + (decodeChunkSize * i * 2);
|
||||
f32safeTensor.set(resultUint8, f32offset);
|
||||
resultUint8 = null;
|
||||
result = null;
|
||||
}
|
||||
|
||||
combinedBuffer = null;
|
||||
f16safeTensor = null;
|
||||
|
||||
let end = Date.now();
|
||||
console.log("Decoding took: " + ((end - start) / 1000) + " s");
|
||||
console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");
|
||||
|
||||
return parts;
|
||||
return f32safeTensor;
|
||||
};
|
||||
|
||||
const loadNet = async () => {
|
||||
|
|
@ -484,18 +419,15 @@
|
|||
}
|
||||
|
||||
const device = await getDevice();
|
||||
f16decomp = await f16tof32().setup(device, safetensorParts),
|
||||
safetensorParts = await getAndDecompressF16Safetensors(device, progress);
|
||||
|
||||
modelDlTitle.innerHTML = "Compiling model"
|
||||
|
||||
let models = ["textModel", "diffusor", "decoder"];
|
||||
let netText = await textModel.load(device, "./net_textModel.safetensors");
|
||||
let netDiffusor = await diffusor.load(device, "./net_diffusor_f16.safetensors");
|
||||
let netDecoder = await decoder.load(device, "./net_decoder.safetensors");
|
||||
let funcF16Decomp = await f16tof32.load(device);
|
||||
|
||||
nets = await timer(() => Promise.all([
|
||||
textModel().setup(device, safetensorParts),
|
||||
diffusor().setup(device, safetensorParts),
|
||||
decoder().setup(device, safetensorParts)
|
||||
]).then((loadedModels) => loadedModels.reduce((acc, model, index) => { acc[models[index]] = model; return acc; }, {})), "(compilation)")
|
||||
decompressf16Safetensor(device, progress, diffusor.getWeights());
|
||||
|
||||
progress(1, 1);
|
||||
|
||||
|
|
|
|||
|
|
@ -94,6 +94,7 @@ def export_model_webgpu(functions, statements, bufs, weight_names, input_names,
|
|||
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
|
||||
return f"""
|
||||
const {exported_name} = (() => {{
|
||||
let weights = null;
|
||||
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
|
||||
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
|
||||
}};
|
||||
|
|
@ -146,7 +147,8 @@ const addComputePass = (device, commandEncoder, pipeline, layout, infinityUnifor
|
|||
{kernel_code}
|
||||
|
||||
const setupNet = async (device, safetensor) => {{
|
||||
const metadata = getTensorMetadata(safetensor);
|
||||
weights = safetensor;
|
||||
const metadata = safetensor ? getTensorMetadata(safetensor) : null;
|
||||
const infinityBuf = createInfinityUniformBuf(device);
|
||||
|
||||
{layouts}
|
||||
|
|
@ -184,8 +186,12 @@ const setupNet = async (device, safetensor) => {{
|
|||
return {output_return};
|
||||
}}
|
||||
}}
|
||||
const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
|
||||
return {{ load }};
|
||||
const load = async (device, weight_path) =>
|
||||
{{
|
||||
const buffer = weight_path ? await fetch(weight_path).then(x => x.arrayBuffer()) : null;
|
||||
return setupNet(device, buffer ? new Uint8Array(buffer) : null);
|
||||
}}
|
||||
return {{ load, getWeights: () => weights }};
|
||||
}})();
|
||||
export default {exported_name};
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue