see-through/training/tests/test_layerdiff.py
Miaomiao Li 9ded69536e public: release training scripts, configs, and data pipeline (V3)
Co-Authored-By: dmMaze <beneathlimbo@gmail.com>
2026-04-14 15:16:51 +08:00

495 lines
No EOL
20 KiB
Python

from numpy._typing._array_like import NDArray
import os
import os.path as osp
import sys
from typing import Any
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from omegaconf import OmegaConf
import numpy as np
from pathlib import Path
from tqdm import tqdm
import click
import cv2
import torch
from omegaconf import OmegaConf
from PIL import Image
from utils.torch_utils import seed_everything
from utils.cv import checkerboard, checkerboard_vis, batch_load_masks, pad_rgb
from utils.io_utils import find_all_imgs, pil_pad_square, pil_ensure_rgb, imglist2imgrid, json2dict, dict2json
@click.group()
def cli():
"""layer diff test scripts.
"""
@cli.command('test_layerdiff')
@click.option('--config')
@click.option('--rank_to_worldsize', default=None)
def test_layerdiff(config, rank_to_worldsize):
import random
from safetensors.torch import load_file
from utils.cv import visualize_rgba, center_square_pad_resize
from utils.torch_utils import init_model_from_pretrained, tensor2img, seed_everything
from modules.layerdiffuse.vae import TransparentVAEDecoder, TransparentVAEEncoder, vae_encode
from modules.layerdiffuse.utils import patch_transvae_sd, patch_unet_convin
from modules.layerdiffuse.diffusers_kdiffusion_sdxl import KDiffusionStableDiffusionXLPipeline
from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel
from live2d.scrap_model import VALID_BODY_PARTS_V2
from utils.io_utils import load_exec_list
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps, rescale_noise_cfg
from train.dataset_layerdiff import LayerDiffDataset
from utils.torch_utils import img2tensor, tensor2img
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from diffusers import (
AutoencoderKL,
EulerDiscreteScheduler,
UNet2DConditionModel,
)
config = OmegaConf.load(config)
if rank_to_worldsize is not None:
config.rank_to_worldsize = rank_to_worldsize
seed_list = config.get('seed_list', None)
device = 'cuda'
dtype = torch.bfloat16
args = config
save_dir = config.save_dir
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
seed = config.get('seed', 1)
seed_everything(seed)
is_3d = args.get('is_3d', False)
pretrained_model_name_or_path = args.pretrained_model_name_or_path
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path,
subfolder="vae",
revision=None,
variant=None,
)
scheduler = EulerDiscreteScheduler.from_pretrained(
pretrained_model_name_or_path,
subfolder="scheduler",
)
trans_vae_encoder: TransparentVAEEncoder = init_model_from_pretrained(
'lllyasviel/LayerDiffuse_Diffusers',
TransparentVAEEncoder, weights_name='ld_diffusers_sdxl_vae_transparent_encoder.safetensors',
patch_state_dict_func=patch_transvae_sd
)
trans_vae_decoder: TransparentVAEDecoder = init_model_from_pretrained(
'lllyasviel/LayerDiffuse_Diffusers',
TransparentVAEDecoder, weights_name='ld_diffusers_sdxl_vae_transparent_decoder.safetensors',
patch_state_dict_func=patch_transvae_sd
)
tokenizer_one = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
use_fast=False,
)
tokenizer_two = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
use_fast=False,
)
text_encoder_one = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder"
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2"
)
if config.vae_ckpt:
sd = load_file(config.vae_ckpt)
td_sd = {}
vae_sd = {}
for k, v in sd.items():
if k.startswith('trans_decoder.'):
td_sd[k.lstrip('trans_decoder.')] = v
elif k.startswith('vae.'):
vae_sd[k.replace('vae.', '')] = v
trans_vae_decoder.load_state_dict(td_sd)
vae.load_state_dict(vae_sd)
vae = vae.to(device=device, dtype=torch.float32)
trans_vae_encoder = trans_vae_encoder.to(device=device, dtype=torch.float32)
trans_vae_decoder = trans_vae_decoder.to(device=device, dtype=torch.float32)
else:
trans_vae_encoder: TransparentVAEEncoder = init_model_from_pretrained(
'lllyasviel/LayerDiffuse_Diffusers',
TransparentVAEEncoder, weights_name='ld_diffusers_sdxl_vae_transparent_encoder.safetensors',
patch_state_dict_func=patch_transvae_sd
).to(device=device, dtype=torch.float32)
trans_vae_decoder: TransparentVAEDecoder = init_model_from_pretrained(
'lllyasviel/LayerDiffuse_Diffusers',
TransparentVAEDecoder, weights_name='ld_diffusers_sdxl_vae_transparent_decoder.safetensors',
patch_state_dict_func=patch_transvae_sd
).to(device=device, dtype=torch.float32)
vae = AutoencoderKL.from_pretrained(
"cagliostrolab/animagine-xl-4.0",
subfolder="vae").to(device=device, dtype=torch.float32)
unet_cls = UNetFrameConditionModel if is_3d else UNet2DConditionModel
unet = unet_cls.from_pretrained(
**OmegaConf.to_container(config.pretrained_unet)
)
unet = unet.to(device=device, dtype=dtype)
unet.eval()
weight_dtype = unet.dtype
rng = torch.Generator(device=unet.device).manual_seed(seed)
pipeline: KDiffusionStableDiffusionXLPipeline = KDiffusionStableDiffusionXLPipeline(
vae=vae,
text_encoder=text_encoder_one,
tokenizer=tokenizer_one,
text_encoder_2=text_encoder_two,
tokenizer_2=tokenizer_two,
unet=unet,
scheduler=scheduler
)
target_tag_list = VALID_BODY_PARTS_V2
exec_list = []
rank_to_worldsize = config.get('rank_to_worldsize', None)
create_sub_dir = config.get('create_sub_dir', False)
if isinstance(config.exec_list, str):
if osp.isdir(config.exec_list):
exec_list = find_all_imgs(config.exec_list, abs_path=True)
else:
exec_list = load_exec_list(config.exec_list, rank_to_worldsize=rank_to_worldsize)
else:
for p in config.exec_list:
exec_list += load_exec_list(p, rank_to_worldsize=rank_to_worldsize)
if config.shuffle:
random.shuffle(exec_list)
mask_cond = config.get('mask_cond', False)
if seed_list is not None:
num_data = len(exec_list)
exec_list = exec_list * len(seed_list)
for ii, p in enumerate(tqdm(exec_list)):
if mask_cond:
mask_list = batch_load_masks(p + '_masks.json')
if create_sub_dir:
# saved = osp.join(save_dir, osp.basename(osp.dirname(srcp)), srcname)
saved = osp.join(save_dir, osp.basename(osp.dirname(p)), osp.splitext(osp.basename(p))[0])
else:
saved = osp.join(save_dir, osp.splitext(osp.basename(p))[0])
current_seed = None
if seed_list is not None:
current_seed = seed_list[ii // num_data]
saved += f'_seed{current_seed}'
seed_everything(current_seed)
os.makedirs(saved, exist_ok=True)
fullpage = center_square_pad_resize(np.array(Image.open(p).convert('RGBA')), 1024)
Image.fromarray(fullpage).save(osp.join(saved, 'src_img.png'))
page_alpha = img2tensor(fullpage[..., -1] / 255., device=vae.device, dtype=vae.dtype)[0][..., None]
fullpage = fullpage[..., :3]
c_concat = np.concatenate([np.full_like(fullpage[..., :1], fill_value=255), fullpage], axis=2)
c_concat = img2tensor(c_concat, normalize=True)
if trans_vae_decoder.model.config.in_channels == 6:
rgb_cond = c_concat[:, 1:].to(device=vae.device, dtype=vae.dtype)
else:
rgb_cond = None
empty_tensor = torch.zeros_like(c_concat)
c_concat = vae_encode(vae, trans_vae_encoder, c_concat, use_offset=False).to(device=device, dtype=dtype)
if unet.config.in_channels == 12:
c_concat = torch.cat([c_concat, vae_encode(vae, trans_vae_encoder, empty_tensor).to(device=device, dtype=dtype)], dim=1)
vis_list = [fullpage]
# target_tag_list = ['hair']
res_list = []
for tag in target_tag_list:
if is_3d:
tag = target_tag_list
if mask_cond:
visible_mask = mask_list[VALID_BODY_PARTS_V2.index(tag)].astype(np.uint8) * 255
visible_mask = center_square_pad_resize(visible_mask, 1024)
visible_mask = visible_mask[..., None]
cond_tag_img = fullpage * (visible_mask.astype(np.float32) / 255.)
# cond_tag_img[..., [-1]] = visible_mask.astype(np.float32)
cond_tag_img = np.concatenate([cond_tag_img, visible_mask], axis=2)
cond_tag_img: NDArray[Any] = np.concatenate([visible_mask, (pad_rgb(cond_tag_img) * 255).astype(np.uint8)], axis=2)
cond_tag_img = img2tensor(img=cond_tag_img, normalize=True, mean=[0., 0., 0., 0.], std=[255., 255., 255., 255.], dim_order='chw')
c_concat[:, -4:] = vae_encode(vae, trans_vae_encoder, cond_tag_img[None].to(dtype=vae.dtype, device=vae.device), use_offset=False).to(dtype=weight_dtype)
latents = pipeline(
strength=1.0,
num_inference_steps=config.num_inference_steps,
batch_size=1,
generator=rng,
guidance_scale=args.guidance_scale, c_concat=c_concat,
prompt=tag,
negative_prompt=''
).images
if latents.ndim == 5:
latents = latents[0]
latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
for latent in latents:
latent = latent[None]
# latent = scheduler.add_noise(latent, torch.randn_like(latent), timesteps=torch.tensor([1], device=latent.device))
result_list, vis_list_batch = trans_vae_decoder(vae, latent, rgb_cond=rgb_cond, mask=page_alpha)
vis_list += vis_list_batch
res_list += result_list
if is_3d:
break
for rst, tag in zip(res_list, target_tag_list):
savename = osp.join(saved, f'{tag}.png')
Image.fromarray(rst).save(savename)
@cli.command('test_layerdiff_vae')
@click.option('--config')
@click.option('--save_dir', default=None)
@click.option('--shuffle', is_flag=True, default=False)
def test_layerdiff_vae(config, save_dir, shuffle):
from train.dataset_layerdiff import LayerDiffDataset
from utils.torch_utils import img2tensor, tensor2img
from utils.torch_utils import init_model_from_pretrained, tensor2img, seed_everything
from modules.layerdiffuse.vae import TransparentVAEDecoder, TransparentVAEEncoder, vae_encode
from modules.layerdiffuse.utils import patch_unet_convin, patch_transvae_sd
from diffusers import AutoencoderKL
seed_everything(1)
if save_dir is None:
save_dir = osp.join(osp.dirname(config),
osp.splitext(osp.basename(config))[0] + '_vaedecodevis')
os.makedirs(save_dir, exist_ok=True)
config = OmegaConf.load(config)
shuffle = True
dataset = LayerDiffDataset(**OmegaConf.to_container(config.dataset_args))
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=shuffle,
batch_size=config.get('batch_size', 1),
num_workers=config.get('dataloader_num_workers', 0),
drop_last=True
)
vae = AutoencoderKL.from_pretrained(
'cagliostrolab/animagine-xl-4.0',
subfolder="vae",
)
trans_vae_encoder: TransparentVAEEncoder = init_model_from_pretrained(
'lllyasviel/LayerDiffuse_Diffusers',
TransparentVAEEncoder, weights_name='ld_diffusers_sdxl_vae_transparent_encoder.safetensors',
patch_state_dict_func=patch_transvae_sd
)
trans_vae_decoder: TransparentVAEDecoder = init_model_from_pretrained(
'lllyasviel/LayerDiffuse_Diffusers',
TransparentVAEDecoder, weights_name='ld_diffusers_sdxl_vae_transparent_decoder.safetensors',
patch_state_dict_func=patch_transvae_sd
)
device = 'cuda'
vae.requires_grad_(False)
trans_vae_encoder.requires_grad_(False)
trans_vae_decoder.requires_grad_(False)
vae.to(device, dtype=torch.float32)
trans_vae_decoder.to(device, dtype=torch.float32)
trans_vae_encoder.to(device, dtype=torch.float32)
def _cvt_img(img):
img_np = tensor2img(img, mean=dataset.pixel_mean, std=dataset.pixel_std, denormalize=True)
img_np = np.concatenate([img_np[..., 1:], img_np[..., [0, 0, 0]]])
return Image.fromarray(img_np)
for batch in dataloader:
img = batch['img'][0]
cond_fullpage = batch['cond_fullpage'][0]
cond_tag_img = batch['cond_tag_img'][0]
srcp = batch['srcp'][0]
caption = batch['caption'][0]
savename = osp.join(save_dir, osp.splitext(osp.basename(srcp))[0])
latent_scaled = vae_encode(vae, trans_vae_encoder, img[None].to(device=device, dtype=vae.dtype), use_offset=True)
latents = latent_scaled.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
result_list, vis_list_batch = trans_vae_decoder(vae, latents)
# trans_vae_decoder()
Image.fromarray(vis_list_batch[0]).save(savename + '_vaedecoded.png')
_cvt_img(img).save(savename + '_img.png')
_cvt_img(cond_fullpage).save(savename + '_cond_fullpage.png')
_cvt_img(cond_tag_img).save(savename + '_cond_tag_img.png')
with open(savename + '.txt', 'w', encoding='utf8') as f:
f.write(caption)
@cli.command('marigold_blend')
@click.option('--config')
@click.option('--rank_to_worldsize', default=None)
def marigold_blend(config, rank_to_worldsize):
from modules.marigold import MarigoldDepthPipeline
from modules.marigold.marigold_depth_pipeline import encode_empty_text, encode_argb_list
from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel
from live2d.scrap_model import VALID_BODY_PARTS_V2
from utils.cv import img_alpha_blending, center_square_pad_resize, argb2rgba
from utils.torch_utils import img2tensor, tensor2img
from utils.io_utils import load_exec_list
# model = MarigoldDepthPipeline.from
config = OmegaConf.load(config)
if rank_to_worldsize is not None:
config.rank_to_worldsize = rank_to_worldsize
device = 'cuda'
dtype = torch.bfloat16
unet = UNetFrameConditionModel.from_pretrained(config.depth_unet)
pipe = MarigoldDepthPipeline.from_pretrained('prs-eth/marigold-depth-v1-1', unet=unet)
pipe.empty_text_embed = encode_empty_text()
pipe.to(device=device, dtype=dtype)
# srcp = 'workspace/datasets/live2d_bodysamples_chunk3/sekai_jxl_crop____08shizuku_archery_t03-NONE-8.png'
# srcp = 'workspace/datasets/live2d_bodysamples_chunk3/ARUZ_jxl_crop____aierdeliqi_4-NONE-5.png'
# part_dir = 'local_configs/layerdiff3d_womix20k_vaewgan1024_10k'
save_dir = config.save_dir
exec_list = config.exec_list
rank_to_worldsize = config.get('rank_to_worldsize', None)
create_sub_dir = config.get('create_sub_dir', False)
exec_list = load_exec_list(exec_list, rank_to_worldsize=rank_to_worldsize)
seed_list = config.get('seed_list', None)
if seed_list is not None:
nlist = []
for p in exec_list:
nlist += [p] * len(seed_list)
exec_list = nlist
for i, srcp in enumerate(tqdm(exec_list)):
srcname = osp.basename(osp.splitext(srcp)[0])
img_list = []
vis_list = []
caption_list = []
exist_list = []
empty_array = np.zeros((1024, 1024, 4), dtype=np.uint8)
blended_alpha = np.zeros((1024, 1024), dtype=np.float32)
fullpage = fullpage = center_square_pad_resize(np.array(Image.open(srcp).convert('RGBA')), 1024)
if create_sub_dir:
saved = osp.join(save_dir, osp.basename(osp.dirname(srcp)), srcname)
else:
saved = osp.join(save_dir, srcname)
if seed_list is not None:
saved = saved + '_seed'+str(seed_list[i % len(seed_list)])
for tag in VALID_BODY_PARTS_V2:
tagp = osp.join(saved, f'{tag}.png')
if osp.exists(tagp):
exist_list.append(True)
caption_list.append(tag)
tag_arr = np.array(Image.open(tagp))
tag_arr[..., -1][tag_arr[..., -1] < 15] = 0
blended_alpha += tag_arr[..., -1].astype(np.float32) / 255
img_list.append(tag_arr)
else:
img_list.append(empty_array)
exist_list.append(False)
blended_alpha = np.clip(blended_alpha, 0, 1) * 255
blended_alpha = blended_alpha.astype(np.uint8)
fullpage[..., -1] = blended_alpha
img_list.append(fullpage)
vae = pipe.vae
def _np_transform(img):
img = np.concatenate([img[..., 3:], img[..., :3]], axis=2).astype(np.float32) / 255.
img = img2tensor(img=img, dim_order='chw')
return img
img_list_tensor = torch.stack([_np_transform(img) for img in img_list])
cond_full_page = img_list_tensor[-1][None]
ncls = img_list_tensor.shape[0]
with torch.no_grad():
rgb_latent = [encode_argb_list(vae, img[None, None].to(device=device, dtype=dtype), pad_argb=True, dtype=dtype) for img in img_list_tensor]
rgb_latent = torch.cat(rgb_latent, dim=1)
rgb_cond_latent = encode_argb_list(vae, cond_full_page[None], pad_argb=True, dtype=dtype)
rgb_latent = torch.cat(
[rgb_cond_latent.expand(-1, ncls, -1, -1, -1), rgb_latent], dim=2
)
pipe_out = pipe(
# tensor2img(img, 'pil', denormalize=True, mean=127.5, std=127.5),
color_map=None,
cond_latent=rgb_latent[0],
)
depth_pred: np.ndarray = pipe_out.depth_tensor
depth_pred = depth_pred.to(device='cpu', dtype=torch.float32).numpy()
drawables = [{'img': argb2rgba(tensor2img(img, denormalize=True)), 'depth': depth} for img, depth in zip(img_list_tensor, depth_pred)]
drawables = drawables[:-1]
blended = img_alpha_blending(drawables)
# vis_list.append(checkerboard_vis(img=blended))
infop = osp.join(saved, 'info.json')
if osp.exists(infop):
info = json2dict(infop)
else:
info = {'parts': {}}
parts = info['parts']
for ii, depth in enumerate(depth_pred[:-1]):
depth_max, depth_min = depth.max(), depth.min()
depth = np.clip((depth - depth_min) / (depth_max - depth_min + 1e-7) * 255, 0, 255).astype(np.uint8)
# depth = depth[..., None][..., [-1] * 3].copy()
tag = VALID_BODY_PARTS_V2[ii]
parts_info = parts.get(tag, {})
savep = osp.join(saved, f'{tag}_depth.png')
Image.fromarray(depth).save(savep)
vis_list.append(depth)
parts[tag] = parts_info
parts_info['depth_max'] = depth_max
parts_info['depth_min'] = depth_min
dict2json(info, infop)
Image.fromarray(blended).save(osp.join(saved, 'reconstruction.png'))
# for tagimg in img_list[:-1]:
# vis_list.append(checkerboard_vis(tagimg))
# caption_list.insert(0, 'reconstruction')
# caption_list.insert(0, 'input')
# vis_list.insert(0, fullpage[..., :3])
# imglist2imgrid_with_tags(vis_list, caption_list, fix_size=512, output_type='pil').save(osp.join(save_dir, osp.basename(f'{srcname}.png')))
# # blended = img_alpha_blending(drawables)
# img_list.append()
# Image.fromarray(np.concat([blended, fullpage])).save('local_tst.png')
if __name__ == '__main__':
cli()