mirror of
https://github.com/shitagaki-lab/see-through.git
synced 2026-05-05 19:58:57 +00:00
452 lines
19 KiB
Python
452 lines
19 KiB
Python
"""Quantized inference for See-through full pipeline (layerdiff body -> head -> marigold depth -> PSD).

Supports NF4 (default, 4-bit) and bf16 (baseline) modes. HF repos are auto-selected
based on quant_mode. Builds pipelines directly without using inference_utils global singletons.

Usage (from repo root):
    python inference/scripts/inference_psd_quantized.py --srcp image.png --save_to_psd
    python inference/scripts/inference_psd_quantized.py --quant_mode none --no_group_offload
"""
|
|
|
|
import os.path as osp
|
|
import argparse
|
|
import sys
|
|
import os
|
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
|
|
|
# Cap the OpenBLAS / MKL / OpenMP thread pools. These assignments sit ahead of
# the numpy / cv2 / torch imports in this file, presumably so the libraries
# read the limits when they are first loaded.
default_n_threads = 8
os.environ.update({
    env_name: f"{default_n_threads}"
    for env_name in ('OPENBLAS_NUM_THREADS', 'MKL_NUM_THREADS', 'OMP_NUM_THREADS')
})
|
|
|
|
import json
|
|
import time
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from modules.layerdiffuse.diffusers_kdiffusion_sdxl import KDiffusionStableDiffusionXLPipeline
|
|
from modules.layerdiffuse.vae import TransparentVAE
|
|
from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel
|
|
from modules.marigold import MarigoldDepthPipeline
|
|
from utils.cv import center_square_pad_resize, smart_resize, img_alpha_blending
|
|
from utils.torch_utils import seed_everything
|
|
from utils.io_utils import json2dict, dict2json
|
|
from utils.inference_utils import further_extr
|
|
from utils.cv import validate_resolution
|
|
|
|
|
|
# Ordered list of layer tags handed to the depth stage. The position of a tag
# in this list is also the index of its image / depth map in run_marigold, so
# the order must not change.
VALID_BODY_PARTS_V2 = [
    # head region
    'hair', 'headwear', 'face', 'eyes', 'eyewear', 'ears', 'earwear', 'nose', 'mouth',
    # body and clothing
    'neck', 'neckwear', 'topwear', 'handwear', 'bottomwear', 'legwear', 'footwear',
    # extras
    'tail', 'wings', 'objects',
]
|
|
|
|
|
|
def build_layerdiff_pipeline(args):
    """Assemble the LayerDiff3D pipeline for the requested quantization mode.

    Args:
        args: Parsed CLI namespace. Reads ``quant_mode``, ``repo_id_layerdiff``,
            ``cpu_offload`` and, if present, ``group_offload``.

    Returns:
        The ``KDiffusionStableDiffusionXLPipeline`` with its tag embeddings
        already cached via ``cache_tag_embeds()``.

    With ``quant_mode == 'none'`` every component (VAE, TransparentVAE, UNet and
    both text encoders) is cast to bf16. In any other mode only the VAE and the
    TransparentVAE are cast; the remaining components come from the
    pre-quantized repo and are not moved by hand.
    """
    # NOTE(review): block nesting was reconstructed from a copy of this file
    # whose indentation had been stripped. Group offload is assumed to apply
    # only when CPU offload is off, and cache_tag_embeds() is assumed to run in
    # every case -- confirm against the upstream file.
    repo = args.repo_id_layerdiff
    full_precision = args.quant_mode == 'none'

    def load_unet():
        return UNetFrameConditionModel.from_pretrained(repo, subfolder='unet')

    def load_trans_vae():
        return TransparentVAE.from_pretrained(repo, subfolder='trans_vae')

    # Keep the original per-mode load order (tuple items evaluate left to right).
    if full_precision:
        trans_vae, unet = load_trans_vae(), load_unet()
    else:
        unet, trans_vae = load_unet(), load_trans_vae()

    pipeline = KDiffusionStableDiffusionXLPipeline.from_pretrained(
        repo, trans_vae=trans_vae, unet=unet, scheduler=None)

    # Components that are cast to bf16 (and placed on the GPU when not offloading).
    bf16_components = ['vae', 'trans_vae']
    if full_precision:
        bf16_components += ['unet', 'text_encoder', 'text_encoder_2']

    placement = {'dtype': torch.bfloat16}
    if not args.cpu_offload:
        placement['device'] = 'cuda'
    for component_name in bf16_components:
        getattr(pipeline, component_name).to(**placement)

    if args.cpu_offload:
        pipeline.enable_model_cpu_offload()
    elif getattr(args, 'group_offload', False):
        pipeline.enable_group_offload('cuda', num_blocks_per_group=1)

    # Cache tag embeddings and unload text encoders to save VRAM
    pipeline.cache_tag_embeds()
    return pipeline
|
|
|
|
|
|
def build_marigold_pipeline(args):
    """Assemble the Marigold depth pipeline for the requested quantization mode.

    Args:
        args: Parsed CLI namespace. Reads ``quant_mode``, ``repo_id_depth`` and,
            if present, ``group_offload``. ``cpu_offload`` is consulted only
            when ``quant_mode == 'none'``.

    Returns:
        The ``MarigoldDepthPipeline`` with its tag embeddings already cached via
        ``cache_tag_embeds()``.
    """
    # NOTE(review): block nesting was reconstructed from a copy of this file
    # whose indentation had been stripped; group offload is assumed to be
    # skipped when CPU offload is active -- confirm against the upstream file.
    repo = args.repo_id_depth
    wants_group_offload = getattr(args, 'group_offload', False)

    if args.quant_mode == 'none':
        # bf16 baseline
        unet = UNetFrameConditionModel.from_pretrained(repo, subfolder='unet')
        marigold_pipe = MarigoldDepthPipeline.from_pretrained(repo, unet=unet)
        if not args.cpu_offload:
            marigold_pipe.to(device='cuda', dtype=torch.bfloat16)
            if wants_group_offload:
                marigold_pipe.enable_group_offload('cuda', num_blocks_per_group=1)
        else:
            marigold_pipe.to(dtype=torch.bfloat16)
            marigold_pipe.enable_model_cpu_offload()
    else:
        # NF4: load from the pre-quantized repo.
        # NOTE(review): this branch never looks at args.cpu_offload, unlike the
        # LayerDiff builder -- confirm that is intentional.
        unet = UNetFrameConditionModel.from_pretrained(
            repo, subfolder='unet', torch_dtype=torch.bfloat16)
        marigold_pipe = MarigoldDepthPipeline.from_pretrained(
            repo, unet=unet, torch_dtype=torch.bfloat16)
        for module in (marigold_pipe.vae, marigold_pipe.unet):
            module.to(device='cuda')

        # The text encoder may itself be quantized; in that case it is not moved.
        text_encoder = marigold_pipe.text_encoder
        looks_quantized = (getattr(text_encoder, 'is_quantized', False)
                           or getattr(text_encoder, 'quantization_method', None))
        if not looks_quantized:
            text_encoder.to(device='cuda')

        if wants_group_offload:
            marigold_pipe.enable_group_offload('cuda', num_blocks_per_group=1)

    marigold_pipe.cache_tag_embeds()
    return marigold_pipe
|
|
|
|
|
|
def run_layerdiff(pipeline, imgp, save_dir, seed, num_inference_steps, resolution):
    """Run the LayerDiff3D body pass, then the head pass, and save every layer.

    Everything is written to ``<save_dir>/<image stem>/``:

    * ``src_img.png``  -- the padded/resized input fed to the body pass
    * ``<tag>.png``    -- one image per body tag
    * ``src_head.png`` -- the padded/resized head crop fed to the head pass
    * ``<tag>.png``    -- one image per head tag, pasted back onto an empty
      ``resolution`` x ``resolution`` RGBA canvas at the head position

    Args:
        pipeline: LayerDiff pipeline (callable; ``pipeline.unet.device`` is read).
        imgp: Path of the input image.
        save_dir: Root output directory.
        seed: Seed for the ``torch.Generator`` shared by both passes.
        num_inference_steps: Number of denoising steps for each pass.
        resolution: Side length of the square model input.

    Raises:
        ValueError: If the generated ``head`` layer has no pixel with
            alpha > 15, so no head bounding box can be derived.
    """
    saved = osp.join(save_dir, osp.splitext(osp.basename(imgp))[0])
    os.makedirs(saved, exist_ok=True)
    with Image.open(imgp) as src:
        input_img = np.array(src.convert('RGBA'))
    fullpage, pad_size, pad_pos = center_square_pad_resize(input_img, resolution, return_pad_info=True)
    # Factor between model-resolution pixels and padded-source pixels.
    # (assumes pad_size is the padded square's size before resizing -- TODO confirm)
    scale = pad_size[0] / resolution
    Image.fromarray(fullpage).save(osp.join(saved, 'src_img.png'))

    # A single generator is used for both passes, so the head pass continues
    # the random stream left by the body pass.
    rng = torch.Generator(device=pipeline.unet.device).manual_seed(seed)

    # Body pass
    body_tag_list = ['front hair', 'back hair', 'head', 'neck', 'neckwear', 'topwear', 'handwear', 'bottomwear', 'legwear', 'footwear', 'tail', 'wings', 'objects']
    pipeline_output = pipeline(
        strength=1.0,
        num_inference_steps=num_inference_steps,
        batch_size=1,
        generator=rng,
        guidance_scale=1.0,
        prompt=body_tag_list,
        negative_prompt='',
        fullpage=fullpage,
        group_index=0
    )
    images = pipeline_output.images
    for rst, tag in zip(images, body_tag_list):
        Image.fromarray(rst).save(osp.join(saved, f'{tag}.png'))
    head_img = images[body_tag_list.index('head')]

    # Head crop
    head_tag_list = ['headwear', 'face', 'irides', 'eyebrow', 'eyewhite', 'eyelash', 'eyewear', 'ears', 'earwear', 'nose', 'mouth']
    # findNonZero yields None when the mask is empty; boundingRect would then
    # fail with an opaque OpenCV error, so report the real cause instead.
    head_pixels = cv2.findNonZero((head_img[..., -1] > 15).astype(np.uint8))
    if head_pixels is None:
        raise ValueError(
            f"LayerDiff produced an empty 'head' layer for {imgp!r} "
            "(no pixel with alpha > 15); cannot locate the head crop.")
    hx0, hy0, hw, hh = cv2.boundingRect(head_pixels)

    # Map the box from model-resolution coordinates back onto the original image.
    hx = int(hx0 * scale) - pad_pos[0]
    hy = int(hy0 * scale) - pad_pos[1]
    hw = int(hw * scale)
    hh = int(hh * scale)

    def _crop_head(img, xywh):
        """Crop ``xywh`` from ``img``; return the crop and its clamped (x1, y1, x2, y2).

        Along an axis where the box covers less than half of the image, a
        margin of up to a fifth of the box size is added on both sides. The
        box is always clamped to the image, so a box that starts in the padded
        border (negative x / y) cannot turn into a wrap-around slice.
        """
        x, y, w, h = xywh
        ih, iw = img.shape[:2]
        px = py = 0
        if w < iw // 2:
            px = min(iw - x - w, x, w // 5)
        if h < ih // 2:
            py = min(ih - y - h, y, h // 5)
        x1 = min(max(x - px, 0), iw)
        x2 = min(max(x + w + px, 0), iw)
        y1 = min(max(y - py, 0), ih)
        y2 = min(max(y + h + py, 0), ih)
        return img[y1: y2, x1: x2], (x1, y1, x2, y2)

    input_head, (hx1, hy1, _, _) = _crop_head(input_img, [hx, hy, hw, hh])
    # Top-left corner of the crop, expressed in body-canvas (model-resolution) pixels.
    hx1 = int(hx1 / scale + pad_pos[0] / scale)
    hy1 = int(hy1 / scale + pad_pos[1] / scale)
    ih, iw = input_head.shape[:2]
    # From here on pad_size / pad_pos describe the head crop; `scale` is still the body scale.
    input_head, pad_size, pad_pos = center_square_pad_resize(input_head, resolution, return_pad_info=True)
    Image.fromarray(input_head).save(osp.join(saved, 'src_head.png'))

    # Head pass
    pipeline_output = pipeline(
        strength=1.0,
        num_inference_steps=num_inference_steps,
        batch_size=1,
        generator=rng,
        guidance_scale=1.0,
        prompt=head_tag_list,
        negative_prompt='',
        fullpage=input_head,
        group_index=1
    )
    canvas = np.zeros((resolution, resolution, 4), dtype=np.uint8)

    # Region of the resized head output that holds the un-padded crop.
    py1, py2, px1, px2 = (np.array([pad_pos[1], pad_pos[1] + ih, pad_pos[0], pad_pos[0] + iw]) / scale).astype(np.int64)

    # Size of the padded head square in body-canvas pixels.
    scale_size = (int(pad_size[0] / scale), int(pad_size[1] / scale))

    for rst, tag in zip(pipeline_output.images, head_tag_list):
        rst = smart_resize(rst, scale_size)[py1: py2, px1: px2]
        full = canvas.copy()
        full[hy1: hy1 + rst.shape[0], hx1: hx1 + rst.shape[1]] = rst
        Image.fromarray(full).save(osp.join(saved, f'{tag}.png'))
|
|
|
|
|
|
def run_marigold(marigold_pipe, srcp, save_dir, seed, resolution_depth):
    """Run Marigold depth estimation. Matches inference_utils.apply_marigold logic.

    Uses resolution_depth to control Marigold inference resolution. If different from
    source image size, images are resized before depth prediction and depth maps are
    resized back after. All frames processed together (no chunking).

    Reads the layer images previously written to ``<save_dir>/<source stem>/`` and
    writes, into the same directory, one ``<tag>_depth.png`` per part, an updated
    ``info.json`` and ``reconstruction.png``.

    Args:
        marigold_pipe: Marigold pipeline; called with ``color_map=None`` and
            ``img_list=...``, its result must expose ``depth_tensor``.
        srcp: Path of the original source image (only its stem is used).
        save_dir: Root output directory used by the LayerDiff stage.
        seed: Value passed to ``seed_everything`` right before inference.
        resolution_depth: Depth inference resolution; the int ``-1`` means
            "use the size of ``src_img.png``". Any other value is handed to
            ``validate_resolution``.
    """
    srcname = osp.basename(osp.splitext(srcp)[0])
    saved = osp.join(save_dir, srcname)

    def _load_layer(path):
        """Load a saved layer and zero every alpha value below 15."""
        with Image.open(path) as layer_img:
            layer = np.array(layer_img)
        layer[..., -1][layer[..., -1] < 15] = 0
        return layer

    # Read source image to get actual size (matches inference_utils approach)
    src_img_p = osp.join(saved, 'src_img.png')
    with Image.open(src_img_p) as src:
        fullpage = np.array(src.convert('RGBA'))
    src_h, src_w = fullpage.shape[:2]

    if isinstance(resolution_depth, int) and resolution_depth == -1:
        resolution_depth = [src_h, src_w]
    resolution_depth = validate_resolution(resolution_depth)
    src_rescaled = resolution_depth[0] != src_h or resolution_depth[1] != src_w

    # One image per tag, in VALID_BODY_PARTS_V2 order; missing layers become
    # a fully transparent placeholder.
    empty_array = np.zeros((src_h, src_w, 4), dtype=np.uint8)
    img_list = []
    for tag in VALID_BODY_PARTS_V2:
        tagp = osp.join(saved, f'{tag}.png')
        img_list.append(_load_layer(tagp) if osp.exists(tagp) else empty_array)

    # 'eyes' and 'hair' are blended from their sub-layers and replace the
    # corresponding slot in img_list.
    # NOTE(review): compose_dict is assumed to be filled only when at least one
    # sub-layer exists (nesting reconstructed from a de-indented copy) -- confirm.
    compose_list = {'eyes': ['eyewhite', 'irides', 'eyelash', 'eyebrow'], 'hair': ['back hair', 'front hair']}
    compose_dict = {}
    for c, clist in compose_list.items():
        taglist = [tag for tag in clist if osp.exists(osp.join(saved, tag + '.png'))]
        imlist = [_load_layer(osp.join(saved, tag + '.png')) for tag in taglist]
        if imlist:
            img_list[VALID_BODY_PARTS_V2.index(c)] = img_alpha_blending(imlist, premultiplied=False)
            compose_dict[c] = {'taglist': taglist, 'imlist': imlist}

    # Union of all layer alphas becomes the alpha channel of the full image,
    # which is appended as an extra, final frame.
    blended_alpha = np.zeros((src_h, src_w), dtype=np.float32)
    for img in img_list:
        blended_alpha += img[..., -1].astype(np.float32) / 255
    blended_alpha = (np.clip(blended_alpha, 0, 1) * 255).astype(np.uint8)
    fullpage[..., -1] = blended_alpha
    img_list.append(fullpage)

    # Resize to depth resolution if needed
    img_list_input = img_list
    if src_rescaled:
        img_list_input = [smart_resize(img, resolution_depth) for img in img_list]

    seed_everything(seed)
    pipe_out = marigold_pipe(color_map=None, img_list=img_list_input)
    depth_pred = pipe_out.depth_tensor
    depth_pred = depth_pred.to(device='cpu', dtype=torch.float32).numpy()

    # Resize depth back to source resolution if needed
    if src_rescaled:
        depth_pred = [smart_resize(d, (src_h, src_w)) for d in depth_pred]

    # The last frame (full image) is excluded from the reconstruction.
    drawables = [{'img': img, 'depth': depth} for img, depth in zip(img_list, depth_pred)]
    drawables = drawables[:-1]
    blended = img_alpha_blending(drawables, premultiplied=False)

    infop = osp.join(saved, 'info.json')
    info = json2dict(infop) if osp.exists(infop) else {'parts': {}}
    # Tolerate an existing info.json that has no 'parts' entry yet.
    parts = info.setdefault('parts', {})

    for ii, depth in enumerate(depth_pred[:-1]):
        depth = (np.clip(depth, 0, 1) * 255).astype(np.uint8)
        tag = VALID_BODY_PARTS_V2[ii]

        if tag not in compose_dict:
            parts.setdefault(tag, {})
            Image.fromarray(depth).save(osp.join(saved, f'{tag}_depth.png'))
            continue

        # Composite part: derive one depth map per sub-layer. Sub-layers are
        # visited in reverse list order; `covered` accumulates the pixels
        # claimed by the sub-layers visited so far.
        covered = np.zeros((src_h, src_w), dtype=bool)
        sub_tags = compose_dict[tag]['taglist'][::-1]
        sub_imgs = compose_dict[tag]['imlist'][::-1]
        for t, im in zip(sub_tags, sub_imgs):
            mask_local = im[..., -1] > 15
            mask_invis = covered & mask_local
            mask_visible = mask_local & ~mask_invis
            depth_local = np.full((src_h, src_w), fill_value=255, dtype=np.uint8)
            depth_local[mask_local] = depth[mask_local]
            # Pixels hidden behind earlier sub-layers get the median depth of
            # this sub-layer's visible pixels. Skip when nothing is visible:
            # the median of an empty selection is NaN, which cannot be stored
            # meaningfully in a uint8 map.
            if mask_invis.any() and mask_visible.any():
                depth_local[mask_invis] = np.median(depth[mask_visible])
            covered |= mask_local

            parts.setdefault(t, {})
            Image.fromarray(depth_local).save(osp.join(saved, f'{t}_depth.png'))

    dict2json(info, infop)
    Image.fromarray(blended).save(osp.join(saved, 'reconstruction.png'))
|
|
|
|
|
|
def build_arg_parser():
    """Return the command-line parser for the quantized inference script.

    ``--cpu_offload`` / ``--no_cpu_offload`` share the ``cpu_offload``
    destination (default ``False``); ``--group_offload`` / ``--no_group_offload``
    share ``group_offload`` (default ``True``).
    """
    parser = argparse.ArgumentParser(
        description="Quantized inference: LayerDiff body+head -> Marigold depth -> PSD"
    )
    parser.add_argument('--srcp', type=str, default='assets/test_image.png', help='input image')
    parser.add_argument('--save_dir', type=str, default='workspace/layerdiff_output')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--resolution', type=int, default=1280)
    parser.add_argument('--save_to_psd', action='store_true')
    parser.add_argument('--tblr_split', action='store_true',
                        help='try split parts (handwear, eyes, etc) into left-right components')
    parser.add_argument('--quant_mode', type=str, default='nf4', choices=['nf4', 'none'],
                        help='quantization mode: nf4 (default, 4-bit) or none (bf16 baseline)')
    parser.add_argument('--repo_id_layerdiff', type=str, default=None,
                        help='Override LayerDiff3D HF repo (auto-selected based on quant_mode)')
    parser.add_argument('--repo_id_depth', type=str, default=None,
                        help='Override Marigold3D HF repo (auto-selected based on quant_mode)')
    # The help text used to say "default: on" although the default is False.
    parser.add_argument('--cpu_offload', action='store_true', default=False,
                        help='enable model CPU offload (default: off)')
    parser.add_argument('--no_cpu_offload', action='store_false', dest='cpu_offload',
                        help='disable model CPU offload')
    parser.add_argument('--num_inference_steps', type=int, default=30)
    parser.add_argument('--resolution_depth', type=int, default=768,
                        help='Marigold depth inference resolution (default 768; -1 to match layerdiff resolution)')
    parser.add_argument('--group_offload', action='store_true', default=True,
                        help='Enable group offload to reduce peak VRAM (default: on)')
    parser.add_argument('--no_group_offload', action='store_false', dest='group_offload',
                        help='Disable group offload for faster inference on high-VRAM GPUs')
    return parser


if __name__ == '__main__':

    args = build_arg_parser().parse_args()

    # Auto-select HF repos based on quant_mode
    REPO_MAP = {
        'nf4': {
            'layerdiff': '24yearsold/seethroughv0.0.2_layerdiff3d_nf4',
            'depth': '24yearsold/seethroughv0.0.1_marigold_nf4',
        },
        'none': {
            'layerdiff': 'layerdifforg/seethroughv0.0.2_layerdiff3d',
            'depth': '24yearsold/seethroughv0.0.1_marigold',
        },
    }
    defaults = REPO_MAP[args.quant_mode]
    if args.repo_id_layerdiff is None:
        args.repo_id_layerdiff = defaults['layerdiff']
    if args.repo_id_depth is None:
        args.repo_id_depth = defaults['depth']

    srcp = args.srcp
    seed = args.seed
    resolution = args.resolution
    num_inference_steps = args.num_inference_steps
    save_dir = args.save_dir
    # Per-image output directory; run_layerdiff creates it.
    srcname = osp.basename(osp.splitext(srcp)[0])
    saved = osp.join(save_dir, srcname)

    print(f"Quantized inference: quant_mode={args.quant_mode}, cpu_offload={args.cpu_offload}")
    print(f"  Source image: {srcp}")
    print(f"  Save dir: {save_dir}")
    print(f"  Resolution: {resolution}, Steps: {num_inference_steps}, Seed: {seed}")

    torch.cuda.reset_peak_memory_stats()
    total_t0 = time.time()

    # --- LayerDiff ---
    print('\nBuilding LayerDiff3D pipeline...')
    seed_everything(seed)
    pipeline = build_layerdiff_pipeline(args)

    print('Running LayerDiff3D (body + head)...')
    layerdiff_t0 = time.time()
    run_layerdiff(pipeline, srcp, save_dir, seed, num_inference_steps, resolution)
    layerdiff_time = time.time() - layerdiff_t0
    print(f'  LayerDiff3D done in {layerdiff_time:.1f}s')

    # Free layerdiff pipeline before loading marigold
    del pipeline
    torch.cuda.empty_cache()

    # --- Marigold ---
    print('\nBuilding Marigold depth pipeline...')
    marigold_pipe = build_marigold_pipeline(args)

    print('Running Marigold depth...')
    marigold_t0 = time.time()
    run_marigold(marigold_pipe, srcp, save_dir, seed, resolution_depth=args.resolution_depth)
    marigold_time = time.time() - marigold_t0
    print(f'  Marigold done in {marigold_time:.1f}s')

    # Free marigold pipeline before PSD assembly
    del marigold_pipe
    torch.cuda.empty_cache()

    # --- PSD assembly ---
    print('\nRunning PSD assembly...')
    psd_t0 = time.time()
    further_extr(saved, rotate=False, save_to_psd=args.save_to_psd, tblr_split=args.tblr_split)
    psd_time = time.time() - psd_t0
    print(f'  PSD assembly done in {psd_time:.1f}s')

    total_time = time.time() - total_t0

    # --- Stats ---
    stats = {
        'quant_mode': args.quant_mode,
        'peak_vram_gb': torch.cuda.max_memory_allocated() / 1024**3,
        'layerdiff_time_s': layerdiff_time,
        'marigold_time_s': marigold_time,
        'psd_time_s': psd_time,
        'total_time_s': total_time,
    }
    print(f'\n{"="*60}')
    print(json.dumps(stats, indent=2))
    print(f'{"="*60}')
    with open(osp.join(saved, 'stats.json'), 'w') as f:
        json.dump(stats, f, indent=2)
    print(f'Stats saved to {osp.join(saved, "stats.json")}')
|