# tinygrad/examples/webgpu/stable_diffusion/compile.py
# 2024-12-13 12:38:59 +01:00
#
# 59 lines
# 2.6 KiB
# Python

import os
from extra.export_model import export_model
from extra.f16_decompress import u32_to_f16
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import safe_save, torch_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad import Device, dtypes
from tinygrad.helpers import fetch
import requests
import numpy as np
from pathlib import Path
def convert_f32_to_f16(input_file, output_file):
    """Copy a weights file, narrowing its float32 payload to float16.

    The input is read as three consecutive parts:

    1. an 8-byte little-endian unsigned integer giving the metadata length,
    2. that many bytes of metadata (JSON, going by the original naming),
    3. everything that remains, interpreted as a flat run of float32 values.

    Parts 1 and 2 are written to ``output_file`` unchanged; part 3 is cast
    to float16 and written after them.

    NOTE(review): the metadata is copied verbatim, so whatever it records
    about the payload layout is not updated to match the narrower values --
    confirm that the consumer of the output accounts for this.

    Args:
        input_file: Path of the file to read.
        output_file: Path of the file to create or overwrite.
    """
    with open(input_file, 'rb') as src:
        length_prefix = src.read(8)
        header_size = int.from_bytes(length_prefix, 'little')
        header_blob = src.read(header_size)
        # Everything after the metadata is treated as float32, to end of file.
        payload_f32 = np.fromfile(src, dtype=np.float32)

    # Cast before opening the destination so the output is only touched
    # once the converted data is ready.
    payload_f16 = payload_f32.astype(np.float16)

    with open(output_file, 'wb') as dst:
        dst.write(length_prefix)
        dst.write(header_blob)
        payload_f16.tofile(dst)
def fetch_dep(file, url, timeout=60):
    """Download a text dependency from ``url`` and save it to ``file`` as UTF-8.

    Any occurrence of the hosted ``bpe_simple_vocab_16e6.mjs`` URL in the
    downloaded text is replaced with the relative path
    ``./bpe_simple_vocab_16e6.mjs`` before writing.

    The download is completed and its HTTP status checked *before* the
    destination is opened, so a failed request never truncates an existing
    file or saves an error page in place of the dependency.

    Args:
        file: Destination path; created or overwritten on success.
        url: URL to download.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: For connection failures and timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    text = response.text.replace(
        "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs",
        "./bpe_simple_vocab_16e6.mjs",
    )
    with open(file, "w", encoding="utf-8") as f:
        f.write(text)
if __name__ == "__main__":
    # Every artifact is written next to this script.
    out_dir = Path(__file__).parent

    # Fetch the tokenizer sources that sit alongside the exported model files.
    for dep_name in ("clip_tokenizer.js", "bpe_simple_vocab_16e6.mjs"):
        fetch_dep(
            str(out_dir / dep_name),
            f"https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/{dep_name}",
        )

    Device.DEFAULT = "WEBGPU"
    Tensor.no_grad = True

    model = StableDiffusion()
    load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)

    # Each entry is (export name, example input tensors, callable to export).
    # The random tensors are only passed to export_model as example
    # arguments; their values are not otherwise used here.
    model_parts = [
        ("textModel", [Tensor.randn(1, 77)], model.cond_stage_model.transformer.text_model),
        ("diffusor", [
            Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1, 4, 64, 64),
            Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)
        ], model),
        ("decoder", [Tensor.randn(1, 4, 64, 64)], model.decode),
        ("f16tof32", [Tensor.randn(2097120, dtype=dtypes.uint32)], u32_to_f16)
    ]

    # Distinct loop names so the StableDiffusion instance bound to `model`
    # is not shadowed by the tuples being iterated.
    for part_name, example_inputs, target in model_parts:
        # export_model returns four values; only the program text and the
        # state to serialize are needed here.
        prg, _, _, state = export_model(target, Device.DEFAULT.lower(), *example_inputs, model_name=part_name)

        weight_loc = (out_dir / f"net_{part_name}.safetensors").as_posix()
        safe_save(state, weight_loc)

        # Only the diffusor additionally gets a float16 copy of its weights.
        if part_name == "diffusor":
            convert_f32_to_f16(weight_loc, (out_dir / "net_diffusor_f16.safetensors").as_posix())

        # Explicit UTF-8 so the generated JS does not depend on the platform's
        # default encoding (and matches how fetch_dep writes its files).
        with open(out_dir / f"net_{part_name}.js", "w", encoding="utf-8") as text_file:
            text_file.write(prg)