Compare commits

...

5 commits

Author SHA1 Message Date
Ahmed Harmouche
d4bce5be8c can't allocate this at once, need parts back 2024-12-13 13:51:45 +01:00
Ahmed Harmouche
4396ac96b5 No NamedTuple in compile.py 2024-12-13 12:38:59 +01:00
Ahmed Harmouche
04041b1342 More WIP 2024-12-13 12:23:02 +01:00
Ahmed Harmouche
3d4bbc3a35 WIP 2024-12-13 11:54:14 +01:00
Ahmed Harmouche
52f4547303 WIP 2024-12-13 11:54:14 +01:00
4 changed files with 64 additions and 304 deletions

3
.gitignore vendored
View file

@ -35,8 +35,9 @@ extra/datasets/COCO/
extra/datasets/audio* extra/datasets/audio*
extra/weights extra/weights
venv venv
examples/**/net.*[js,json] examples/**/**/net*.*[js,json]
examples/**/*.safetensors examples/**/*.safetensors
examples/webgpu/stable_diffusion/*.*[js,mjs]
node_modules node_modules
package.json package.json
package-lock.json package-lock.json

View file

@ -1,74 +1,28 @@
import os import os
from extra.export_model import compile_net, jit_model, dtype_to_js_type from extra.export_model import export_model
from extra.f16_decompress import u32_to_f16 from extra.f16_decompress import u32_to_f16
from examples.stable_diffusion import StableDiffusion from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict from tinygrad.nn.state import safe_save, torch_load, load_state_dict
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad import Device, dtypes from tinygrad import Device, dtypes
from tinygrad.helpers import fetch from tinygrad.helpers import fetch
from typing import NamedTuple, Any, List
import requests import requests
import argparse
import numpy as np import numpy as np
from pathlib import Path
def convert_f32_to_f16(input_file, output_file): def convert_f32_to_f16(input_file, output_file):
with open(input_file, 'rb') as f: with open(input_file, 'rb') as f:
metadata_length_bytes = f.read(8) metadata_length_bytes = f.read(8)
metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False) metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
metadata_json_bytes = f.read(metadata_length) metadata_json_bytes = f.read(metadata_length)
float32_values = np.fromfile(f, dtype=np.float32) values = np.fromfile(f, dtype=np.float32)
first_text_model_offset = 3772703308 f16_values = values.astype(np.float16)
num_elements = int((first_text_model_offset)/4)
front_float16_values = float32_values[:num_elements].astype(np.float16)
rest_float32_values = float32_values[num_elements:]
with open(output_file, 'wb') as f: with open(output_file, 'wb') as f:
f.write(metadata_length_bytes) f.write(metadata_length_bytes)
f.write(metadata_json_bytes) f.write(metadata_json_bytes)
front_float16_values.tofile(f) f16_values.tofile(f)
rest_float32_values.tofile(f)
def split_safetensor(fn):
_, data_start, metadata = safe_load_metadata(fn)
text_model_offset = 3772703308
chunk_size = 536870912
for k in metadata:
# safetensor is in fp16, except for text moel
if (metadata[k]["data_offsets"][0] < text_model_offset):
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)
last_offset = 0
part_end_offsets = []
for k in metadata:
offset = metadata[k]['data_offsets'][0]
if offset == text_model_offset:
break
part_offset = offset - last_offset
if (part_offset >= chunk_size):
part_end_offsets.append(data_start+offset)
last_offset = offset
text_model_start = int(text_model_offset/2)
net_bytes = bytes(open(fn, 'rb').read())
part_end_offsets.append(text_model_start+data_start)
cur_pos = 0
for i, end_pos in enumerate(part_end_offsets):
with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.safetensors'), "wb+") as f:
f.write(net_bytes[cur_pos:end_pos])
cur_pos = end_pos
with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
f.write(net_bytes[text_model_start+data_start:])
return part_end_offsets
def fetch_dep(file, url): def fetch_dep(file, url):
with open(file, "w", encoding="utf-8") as f: with open(file, "w", encoding="utf-8") as f:
@ -77,162 +31,29 @@ def fetch_dep(file, url):
if __name__ == "__main__": if __name__ == "__main__":
fetch_dep(os.path.join(os.path.dirname(__file__), "clip_tokenizer.js"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/clip_tokenizer.js") fetch_dep(os.path.join(os.path.dirname(__file__), "clip_tokenizer.js"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/clip_tokenizer.js")
fetch_dep(os.path.join(os.path.dirname(__file__), "bpe_simple_vocab_16e6.mjs"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs") fetch_dep(os.path.join(os.path.dirname(__file__), "bpe_simple_vocab_16e6.mjs"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs")
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
args = parser.parse_args()
Device.DEFAULT = "WEBGPU"
Device.DEFAULT = "WEBGPU"
Tensor.no_grad = True Tensor.no_grad = True
model = StableDiffusion() model = StableDiffusion()
# load in weights
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False) load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
class Step(NamedTuple): model_parts = [
name: str = "" ("textModel", [Tensor.randn(1, 77)], model.cond_stage_model.transformer.text_model),
input: List[Tensor] = [] ("diffusor", [
forward: Any = None Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64),
Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)
sub_steps = [ ], model),
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model), ("decoder", [Tensor.randn(1,4,64,64)], model.decode),
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model), ("f16tof32", [Tensor.randn(2097120, dtype=dtypes.uint32)], u32_to_f16)
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode),
Step(name = "f16tof32", input = [Tensor.randn(2097120, dtype=dtypes.uint32)], forward = u32_to_f16)
] ]
prg = "" for model in model_parts:
prg, inp_sizes, out_sizes, state = export_model(model[2], Device.DEFAULT.lower(), *model[1], model_name=model[0])
dirname = Path(__file__).parent
weight_loc = (dirname / f"net_{model[0]}.safetensors").as_posix()
safe_save(state, weight_loc)
if model[0] == "diffusor":
convert_f32_to_f16(weight_loc, (dirname / f"net_diffusor_f16.safetensors").as_posix())
def fixup_code(code, key): with open(dirname / f"net_{model[0]}.js", "w") as text_file:
code = code.replace(key, 'main')\ text_file.write(prg)
.replace("var<uniform> INFINITY : f32;\n", "fn inf(a: f32) -> f32 { return a/0.0; }\n")\
.replace("@group(0) @binding(0)", "")\
.replace("INFINITY", "inf(1.0)")
for i in range(1,9): code = code.replace(f"binding({i})", f"binding({i-1})")
return code
def compile_step(model, step: Step):
run, special_names = jit_model(step, *step.input)
functions, statements, bufs, _ = compile_net(run, special_names)
state = get_state_dict(model)
weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
input_names = [name for _,name in special_names.items() if "input" in name]
output_names = [name for _,name in special_names.items() if "output" in name]
input_buf_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buf_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
exported_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buf_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,_ in enumerate(input_names)])
return f"""\n var {step.name} = function() {{
{kernel_code}
return {{
"setup": async (device, safetensor) => {{
const metadata = safetensor ? getTensorMetadata(safetensor[0]) : null;
{exported_bufs}
{gpu_write_bufs}
const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
const kernels = [{kernel_names}];
const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
return async ({",".join([f'data{i}' for i,(k,v) in enumerate(special_names.items()) if v != "output0"])}) => {{
const commandEncoder = device.createCommandEncoder();
{input_writer}
{kernel_calls}
commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const resultBuffer = new {output_buf_types[0]}(gpuReadBuffer.size/{bufs[output_names[0]][1].itemsize});
resultBuffer.set(new {output_buf_types[0]}(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}}
}}
}}
}}
"""
for step in sub_steps:
print(f'Executing step={step.name}')
prg += compile_step(model, step)
if step.name == "diffusor":
if args.remoteweights:
base_url = "https://huggingface.co/wpmed/stable-diffusion-f16-new/resolve/main"
else:
state = get_state_dict(model)
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
convert_f32_to_f16(os.path.join(os.path.dirname(__file__), "./net.safetensors"), os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
split_safetensor(os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
os.remove(os.path.join(os.path.dirname(__file__), "net.safetensors"))
os.remove(os.path.join(os.path.dirname(__file__), "net_conv.safetensors"))
base_url = "."
prekernel = f"""
window.MODEL_BASE_URL= "{base_url}";
const getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};
const getTensorBuffer = (safetensorParts, tensorMetadata, key) => {{
let selectedPart = 0;
let counter = 0;
let partStartOffsets = [1131408336, 2227518416, 3308987856, 4265298864];
let correctedOffsets = tensorMetadata.data_offsets;
let prev_offset = 0;
for (let start of partStartOffsets) {{
prev_offset = (counter == 0) ? 0 : partStartOffsets[counter-1];
if (tensorMetadata.data_offsets[0] < start) {{
selectedPart = counter;
correctedOffsets = [correctedOffsets[0]-prev_offset, correctedOffsets[1]-prev_offset];
break;
}}
counter++;
}}
return safetensorParts[selectedPart].subarray(...correctedOffsets);
}}
const getWeight = (safetensors, key) => {{
let uint8Data = getTensorBuffer(safetensors, getTensorMetadata(safetensors[0])[key], key);
return new Float32Array(uint8Data.buffer, uint8Data.byteOffset, uint8Data.byteLength / Float32Array.BYTES_PER_ELEMENT);
}}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};
const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
return buf;
}};
const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(...workgroup);
passEncoder.end();
}};"""
with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
text_file.write(prekernel + prg)

View file

@ -6,7 +6,6 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>tinygrad has WebGPU</title> <title>tinygrad has WebGPU</title>
<style> <style>
/* General Reset */
* { * {
margin: 0; margin: 0;
padding: 0; padding: 0;
@ -163,9 +162,17 @@
<script type="module"> <script type="module">
import ClipTokenizer from './clip_tokenizer.js'; import ClipTokenizer from './clip_tokenizer.js';
import textModel from './net_textModel.js';
import diffusor from './net_diffusor.js';
import decoder from './net_decoder.js';
import f16tof32 from './net_f16tof32.js';
window.clipTokenizer = new ClipTokenizer(); window.clipTokenizer = new ClipTokenizer();
window.textModel = textModel;
window.diffusor = diffusor;
window.decoder = decoder;
window.f16tof32 = f16tof32;
</script> </script>
<script src="./net.js"></script>
</head> </head>
<body> <body>
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1> <h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
@ -356,123 +363,51 @@
return res.arrayBuffer(); return res.arrayBuffer();
}; };
const getAndDecompressF16Safetensors = async (device, progress) => { const decompressf16Safetensor = async (device, progress, f16safeTensor) => {
let totalLoaded = 0; let totalLoaded = 0;
let totalSize = 0; let totalSize = 0;
let partSize = {}; let partSize = {};
const getPart = async(key) => {
let part = await readTensorFromDb(db, key);
if (part) {
console.log(`Cache hit: ${key}`);
return Promise.resolve(part.content);
} else {
console.log(`Cache miss: ${key}`);
return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
}
}
const progressCallback = (part, loaded, total) => {
totalLoaded += loaded;
if (!partSize[part]) {
totalSize += total;
partSize[part] = true;
}
progress(totalLoaded, totalSize);
};
let netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
let buffers = await Promise.all(netKeys.map(key => getPart(key)));
// Combine everything except for text model, since that's already f32
const totalLength = buffers.reduce((acc, buffer, index, array) => {
if (index < 4) {
return acc + buffer.byteLength;
} else {
return acc;
}
}, 0
);
combinedBuffer = new Uint8Array(totalLength);
let offset = 0;
buffers.forEach((buffer, index) => {
saveTensorToDb(db, netKeys[index], new Uint8Array(buffer));
if (index < 4) {
combinedBuffer.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
buffer = null;
}
});
let textModelU8 = new Uint8Array(buffers[4]);
document.getElementById("modelDlTitle").innerHTML = "Decompressing model"; document.getElementById("modelDlTitle").innerHTML = "Decompressing model";
const textModelOffset = 3772703308; const metadataLength = Number(new DataView(f16safeTensor.buffer).getBigUint64(0, true));
const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true)); const metadata = JSON.parse(new TextDecoder("utf8").decode(f16safeTensor.subarray(8, 8 + metadataLength)));
const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength); const allToDecomp = f16safeTensor.byteLength - (8 + metadataLength);
const decodeChunkSize = 8388480; const decodeChunkSize = 8388480;
const numChunks = Math.ceil(allToDecomp/decodeChunkSize); const numChunks = Math.ceil(allToDecomp/decodeChunkSize);
console.log(allToDecomp + " bytes to decompress"); console.log(allToDecomp + " bytes to decompress");
console.log("Will be decompressed in " + numChunks+ " chunks"); console.log("Will be decompressed in " + numChunks+ " chunks");
let partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}]; f32safeTensor = new Uint8Array(allToDecomp*2);
let parts = []; f32safeTensor.set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
f32safeTensor.set(f16safeTensor.subarray(8, 8 + metadataLength), 8);
for (let offsets of partOffsets) {
parts.push(new Uint8Array(offsets.end-offsets.start));
}
parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);
let start = Date.now(); let start = Date.now();
let cursor = 0;
for (let i = 0; i < numChunks; i++) { for (let i = 0; i < numChunks; i++) {
progress(i, numChunks); progress(i, numChunks);
let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i); let metaOffset = 8 + metadataLength;
let chunkStartF16 = metaOffset + (decodeChunkSize * i);
let chunkEndF16 = chunkStartF16 + decodeChunkSize; let chunkEndF16 = chunkStartF16 + decodeChunkSize;
let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16); let chunk = f16safeTensor.subarray(chunkStartF16, chunkEndF16);
let uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4); let uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
let result = await f16decomp(uint32Chunk); let result = await f16decomp(uint32Chunk);
let resultUint8 = new Uint8Array(result.buffer); let resultUint8 = new Uint8Array(result.buffer);
let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2); let f32offset = metaOffset + (decodeChunkSize * i * 2);
let chunkEndF32 = chunkStartF32 + resultUint8.byteLength; f32safeTensor.set(resultUint8, f32offset);
let offsetInPart = chunkStartF32 - partOffsets[cursor].start;
if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
parts[cursor].set(resultUint8, offsetInPart);
} else {
let spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);
cursor++;
if (cursor < parts.length) {
let nextPartOffset = spaceLeftInCurrentPart;
let nextPartLength = resultUint8.length - nextPartOffset;
parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
}
}
resultUint8 = null; resultUint8 = null;
result = null; result = null;
} }
combinedBuffer = null; f16safeTensor = null;
let end = Date.now(); let end = Date.now();
console.log("Decoding took: " + ((end - start) / 1000) + " s"); console.log("Decoding took: " + ((end - start) / 1000) + " s");
console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk"); console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");
return parts; return f32safeTensor;
}; };
const loadNet = async () => { const loadNet = async () => {
@ -484,18 +419,15 @@
} }
const device = await getDevice(); const device = await getDevice();
f16decomp = await f16tof32().setup(device, safetensorParts),
safetensorParts = await getAndDecompressF16Safetensors(device, progress);
modelDlTitle.innerHTML = "Compiling model" modelDlTitle.innerHTML = "Compiling model"
let models = ["textModel", "diffusor", "decoder"]; let netText = await textModel.load(device, "./net_textModel.safetensors");
let netDiffusor = await diffusor.load(device, "./net_diffusor_f16.safetensors");
let netDecoder = await decoder.load(device, "./net_decoder.safetensors");
let funcF16Decomp = await f16tof32.load(device);
nets = await timer(() => Promise.all([ decompressf16Safetensor(device, progress, diffusor.getWeights());
textModel().setup(device, safetensorParts),
diffusor().setup(device, safetensorParts),
decoder().setup(device, safetensorParts)
]).then((loadedModels) => loadedModels.reduce((acc, model, index) => { acc[models[index]] = model; return acc; }, {})), "(compilation)")
progress(1, 1); progress(1, 1);

View file

@ -94,6 +94,7 @@ def export_model_webgpu(functions, statements, bufs, weight_names, input_names,
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))])) output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
return f""" return f"""
const {exported_name} = (() => {{ const {exported_name} = (() => {{
let weights = null;
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{ const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets); return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}}; }};
@ -146,7 +147,8 @@ const addComputePass = (device, commandEncoder, pipeline, layout, infinityUnifor
{kernel_code} {kernel_code}
const setupNet = async (device, safetensor) => {{ const setupNet = async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor); weights = safetensor;
const metadata = safetensor ? getTensorMetadata(safetensor) : null;
const infinityBuf = createInfinityUniformBuf(device); const infinityBuf = createInfinityUniformBuf(device);
{layouts} {layouts}
@ -184,8 +186,12 @@ const setupNet = async (device, safetensor) => {{
return {output_return}; return {output_return};
}} }}
}} }}
const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }} const load = async (device, weight_path) =>
return {{ load }}; {{
const buffer = weight_path ? await fetch(weight_path).then(x => x.arrayBuffer()) : null;
return setupNet(device, buffer ? new Uint8Array(buffer) : null);
}}
return {{ load, getWeights: () => weights }};
}})(); }})();
export default {exported_name}; export default {exported_name};
""" """