mirror of
https://github.com/shitagaki-lab/see-through.git
synced 2026-05-05 19:58:57 +00:00
221 lines
7.3 KiB
Python
221 lines
7.3 KiB
Python
import os
|
|
|
|
from torch.hub import download_url_to_file
|
|
import torch.nn as nn
|
|
import torch
|
|
from diffusers import UNet2DConditionModel
|
|
from diffusers.configuration_utils import FrozenDict
|
|
|
|
|
|
def patch_transvae_sd(model, state_dict):
|
|
return {'model.' + k: v for k, v in state_dict.items()}
|
|
|
|
|
|
def module_dtype(self):
|
|
return next(self.parameters()).dtype
|
|
|
|
|
|
def module_device(self):
|
|
return next(self.parameters()).device
|
|
|
|
|
|
def conv_add_channels(new_c: int, conv: nn.Conv2d, prepend=False):
|
|
|
|
new_conv = nn.Conv2d(new_c + conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, conv.dilation, conv.groups, conv.bias is not None)
|
|
|
|
sd = conv.state_dict()
|
|
ks = conv.kernel_size[0]
|
|
if prepend:
|
|
sd['weight'] = torch.cat([torch.zeros((conv.out_channels, new_c, ks, ks)), sd['weight']], dim=1)
|
|
else:
|
|
sd['weight'] = torch.cat([sd['weight'], torch.zeros((conv.out_channels, new_c, ks, ks))], dim=1)
|
|
|
|
new_conv.load_state_dict(sd, strict=True)
|
|
new_conv.to(device=module_device(conv), dtype=module_dtype(conv))
|
|
|
|
return new_conv
|
|
|
|
|
|
def update_net_config(net: UNet2DConditionModel, key: str, value):
|
|
new_config = dict(net.config)
|
|
new_config[key] = value
|
|
net._internal_dict = FrozenDict(new_config)
|
|
|
|
|
|
def patch_unet_convin(unet: UNet2DConditionModel, target_in_channels, prepend=False):
|
|
'''
|
|
add new channels to unet.conv_in, weights init to zeros
|
|
'''
|
|
|
|
new_added_conv_channels = target_in_channels - unet.config.in_channels
|
|
if new_added_conv_channels < 1:
|
|
return
|
|
new_conv = conv_add_channels(new_added_conv_channels, unet.conv_in, prepend=prepend)
|
|
del unet.conv_in
|
|
unet.conv_in = new_conv
|
|
update_net_config(unet, "in_channels", new_conv.in_channels)
|
|
|
|
|
|
|
|
def download_model(url, local_path):
|
|
if os.path.exists(local_path):
|
|
return local_path
|
|
|
|
temp_path = local_path + '.tmp'
|
|
download_url_to_file(url=url, dst=temp_path)
|
|
os.rename(temp_path, local_path)
|
|
return local_path
|
|
|
|
|
|
def load_frozen_patcher(filename, state_dict, strength):
|
|
patch_dict = {}
|
|
for k, w in state_dict.items():
|
|
model_key, patch_type, weight_index = k.split('::')
|
|
if model_key not in patch_dict:
|
|
patch_dict[model_key] = {}
|
|
if patch_type not in patch_dict[model_key]:
|
|
patch_dict[model_key][patch_type] = [None] * 16
|
|
patch_dict[model_key][patch_type][int(weight_index)] = w
|
|
|
|
patch_flat = {}
|
|
for model_key, v in patch_dict.items():
|
|
for patch_type, weight_list in v.items():
|
|
patch_flat[model_key] = (patch_type, weight_list)
|
|
|
|
add_patches(filename=filename, patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
|
return
|
|
|
|
|
|
def add_patches(self, *, filename, patches, strength_patch=1.0, strength_model=1.0, online_mode=False):
|
|
lora_identifier = (filename, strength_patch, strength_model, online_mode)
|
|
this_patches = {}
|
|
|
|
p = set()
|
|
model_keys = set(k for k, _ in self.model.named_parameters())
|
|
|
|
for k in patches:
|
|
offset = None
|
|
function = None
|
|
|
|
if isinstance(k, str):
|
|
key = k
|
|
else:
|
|
offset = k[1]
|
|
key = k[0]
|
|
if len(k) > 2:
|
|
function = k[2]
|
|
|
|
if key in model_keys:
|
|
p.add(k)
|
|
current_patches = this_patches.get(key, [])
|
|
current_patches.append([strength_patch, patches[k], strength_model, offset, function])
|
|
this_patches[key] = current_patches
|
|
|
|
self.lora_patches[lora_identifier] = this_patches
|
|
return p
|
|
|
|
|
|
# class LoraLoader:
|
|
# def __init__(self, model):
|
|
# self.model = model
|
|
# self.backup = {}
|
|
# self.online_backup = []
|
|
# self.loaded_hash = str([])
|
|
|
|
# @torch.inference_mode()
|
|
# def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh=False):
|
|
# hashes = str(list(lora_patches.keys()))
|
|
|
|
# if hashes == self.loaded_hash and not force_refresh:
|
|
# return
|
|
|
|
# # Merge Patches
|
|
|
|
# all_patches = {}
|
|
|
|
# for (_, _, _, online_mode), patches in lora_patches.items():
|
|
# for key, current_patches in patches.items():
|
|
# all_patches[(key, online_mode)] = all_patches.get((key, online_mode), []) + current_patches
|
|
|
|
# # Initialize
|
|
|
|
# memory_management.signal_empty_cache = True
|
|
|
|
# parameter_devices = get_parameter_devices(self.model)
|
|
|
|
# # Restore
|
|
|
|
# for m in set(self.online_backup):
|
|
# del m.forge_online_loras
|
|
|
|
# self.online_backup = []
|
|
|
|
# for k, w in self.backup.items():
|
|
# if not isinstance(w, torch.nn.Parameter):
|
|
# # In very few cases
|
|
# w = torch.nn.Parameter(w, requires_grad=False)
|
|
|
|
# utils.set_attr_raw(self.model, k, w)
|
|
|
|
# self.backup = {}
|
|
|
|
# set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
|
|
|
# # Patch
|
|
|
|
# for (key, online_mode), current_patches in all_patches.items():
|
|
# try:
|
|
# parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
|
|
# assert isinstance(weight, torch.nn.Parameter)
|
|
# except:
|
|
# raise ValueError(f"Wrong LoRA Key: {key}")
|
|
|
|
# if online_mode:
|
|
# if not hasattr(parent_layer, 'forge_online_loras'):
|
|
# parent_layer.forge_online_loras = {}
|
|
|
|
# parent_layer.forge_online_loras[child_key] = current_patches
|
|
# self.online_backup.append(parent_layer)
|
|
# continue
|
|
|
|
# if key not in self.backup:
|
|
# self.backup[key] = weight.to(device=offload_device)
|
|
|
|
# bnb_layer = None
|
|
|
|
# if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable:
|
|
# bnb_layer = parent_layer
|
|
# from backend.operations_bnb import functional_dequantize_4bit
|
|
# weight = functional_dequantize_4bit(weight)
|
|
|
|
# gguf_cls = getattr(weight, 'gguf_cls', None)
|
|
# gguf_parameter = None
|
|
|
|
# if gguf_cls is not None:
|
|
# gguf_parameter = weight
|
|
# from backend.operations_gguf import dequantize_tensor
|
|
# weight = dequantize_tensor(weight)
|
|
|
|
# try:
|
|
# weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
|
|
# except:
|
|
# print('Patching LoRA weights out of memory. Retrying by offloading models.')
|
|
# set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()})
|
|
# memory_management.soft_empty_cache()
|
|
# weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
|
|
|
|
# if bnb_layer is not None:
|
|
# bnb_layer.reload_weight(weight)
|
|
# continue
|
|
|
|
# if gguf_cls is not None:
|
|
# gguf_cls.quantize_pytorch(weight, gguf_parameter)
|
|
# continue
|
|
|
|
# utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
|
|
|
|
# # End
|
|
|
|
# set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
|
# self.loaded_hash = hashes
|
|
# return
|