mirror of
https://github.com/jtydhr88/ComfyUI-See-through.git
synced 2026-05-05 20:00:40 +00:00
965 lines
42 KiB
Python
965 lines
42 KiB
Python
import os
|
|
import sys
|
|
import random
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
print("[SeeThrough] nodes.py: starting imports...", flush=True)
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
import folder_paths
|
|
import comfy.model_management as mm
|
|
import comfy.utils
|
|
|
|
|
|
def _log_vram(label):
|
|
"""Log current GPU VRAM usage for profiling."""
|
|
if torch.cuda.is_available():
|
|
alloc = torch.cuda.memory_allocated() / (1024 ** 3)
|
|
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
|
|
print(f"[SeeThrough VRAM] {label}: allocated={alloc:.2f}GB, reserved={reserved:.2f}GB", flush=True)
|
|
|
|
print("[SeeThrough] nodes.py: comfy imports OK", flush=True)
|
|
|
|
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
SEETHROUGH_ROOT_DIR = os.path.join(CURRENT_DIR, "see-through")
|
|
SEETHROUGH_COMMON_DIR = os.path.join(SEETHROUGH_ROOT_DIR, "common")
|
|
|
|
print(f"[SeeThrough] CURRENT_DIR = {CURRENT_DIR}", flush=True)
|
|
print(f"[SeeThrough] SEETHROUGH_COMMON_DIR = {SEETHROUGH_COMMON_DIR}", flush=True)
|
|
print(f"[SeeThrough] common dir exists = {os.path.isdir(SEETHROUGH_COMMON_DIR)}", flush=True)
|
|
|
|
# Mock pycocotools if not installed (only used for mask RLE, not needed here)
|
|
try:
|
|
import pycocotools # noqa: F401
|
|
print("[SeeThrough] pycocotools found", flush=True)
|
|
except ImportError:
|
|
print("[SeeThrough] pycocotools not found, installing mock...", flush=True)
|
|
import types as _types
|
|
_mock_pycocotools = _types.ModuleType("pycocotools")
|
|
_mock_mask = _types.ModuleType("pycocotools.mask")
|
|
_mock_pycocotools.mask = _mock_mask
|
|
sys.modules["pycocotools"] = _mock_pycocotools
|
|
sys.modules["pycocotools.mask"] = _mock_mask
|
|
|
|
if SEETHROUGH_COMMON_DIR not in sys.path:
|
|
sys.path.insert(0, SEETHROUGH_COMMON_DIR)
|
|
print(f"[SeeThrough] Added to sys.path: {SEETHROUGH_COMMON_DIR}", flush=True)
|
|
|
|
if SEETHROUGH_ROOT_DIR not in sys.path:
|
|
sys.path.insert(1, SEETHROUGH_ROOT_DIR)
|
|
print(f"[SeeThrough] Added to sys.path: {SEETHROUGH_ROOT_DIR}", flush=True)
|
|
|
|
_st_conflict_backup = {}
|
|
for _prefix in ("utils", "modules"):
|
|
for _key in list(sys.modules.keys()):
|
|
if _key == _prefix or _key.startswith(_prefix + "."):
|
|
_st_conflict_backup[_key] = sys.modules.pop(_key)
|
|
if _st_conflict_backup:
|
|
print(f"[SeeThrough] Temporarily removed {len(_st_conflict_backup)} conflicting sys.modules entries: "
|
|
f"{list(_st_conflict_backup.keys())[:10]}{'...' if len(_st_conflict_backup) > 10 else ''}", flush=True)
|
|
|
|
print("[SeeThrough] Importing see-through modules...", flush=True)
|
|
import cv2
|
|
from safetensors.torch import load_file
|
|
|
|
from modules.layerdiffuse.diffusers_kdiffusion_sdxl import KDiffusionStableDiffusionXLPipeline
|
|
from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel
|
|
from modules.layerdiffuse.vae import TransparentVAE
|
|
from modules.marigold import MarigoldDepthPipeline
|
|
from utils.cv import center_square_pad_resize, img_alpha_blending, smart_resize
|
|
from utils.torchcv import cluster_inpaint_part
|
|
|
|
print("[SeeThrough] All see-through imports OK", flush=True)
|
|
|
|
for _key, _mod in _st_conflict_backup.items():
|
|
if _key not in sys.modules:
|
|
sys.modules[_key] = _mod
|
|
del _st_conflict_backup
|
|
|
|
DEFAULT_LAYERDIFF_REPO = "layerdifforg/seethroughv0.0.2_layerdiff3d"
|
|
DEFAULT_LAYERDIFF_NF4_REPO = "24yearsold/seethroughv0.0.2_layerdiff3d_nf4"
|
|
DEFAULT_DEPTH_REPO = "layerdifforg/seethroughv0.0.1_marigold"
|
|
DEFAULT_DEPTH_NF4_REPO = "24yearsold/seethroughv0.0.1_marigold_nf4"
|
|
|
|
QUANT_MODES = ["none", "nf4"]
|
|
|
|
_NF4_REPO_MAP = {
|
|
DEFAULT_LAYERDIFF_REPO: DEFAULT_LAYERDIFF_NF4_REPO,
|
|
DEFAULT_DEPTH_REPO: DEFAULT_DEPTH_NF4_REPO,
|
|
}
|
|
|
|
VALID_BODY_PARTS_V2 = [
|
|
"hair", "headwear", "face", "eyes", "eyewear", "ears", "earwear",
|
|
"nose", "mouth", "neck", "neckwear", "topwear", "handwear",
|
|
"bottomwear", "legwear", "footwear", "tail", "wings", "objects",
|
|
]
|
|
|
|
SEETHROUGH_MODELS_DIR = os.path.join(folder_paths.models_dir, "SeeThrough")
|
|
os.makedirs(SEETHROUGH_MODELS_DIR, exist_ok=True)
|
|
|
|
|
|
class SeeThrough_LayersData:
|
|
"""Output of GenerateLayers: raw RGBA layers + preprocessing info."""
|
|
def __init__(self, layer_dict, fullpage, input_img, resolution, pad_size, pad_pos):
|
|
self.layer_dict = layer_dict # tag -> RGBA numpy (resolution x resolution)
|
|
self.fullpage = fullpage # center-padded input (resolution x resolution, RGBA)
|
|
self.input_img = input_img # original input (RGBA)
|
|
self.resolution = resolution
|
|
self.pad_size = pad_size
|
|
self.pad_pos = pad_pos
|
|
self.scale = pad_size[0] / resolution
|
|
|
|
|
|
class SeeThrough_LayersDepthData:
|
|
"""Output of GenerateDepth: layers + per-tag depth maps."""
|
|
def __init__(self, layer_dict, depth_dict, fullpage, resolution):
|
|
self.layer_dict = layer_dict # tag -> RGBA numpy
|
|
self.depth_dict = depth_dict # tag -> float32 depth [0,1]
|
|
self.fullpage = fullpage
|
|
self.resolution = resolution
|
|
|
|
|
|
def _cast_non_quantized_params(model, dtype):
|
|
"""Cast non-quantized float parameters to dtype, leaving bitsandbytes 4-bit params and integer tensors untouched."""
|
|
try:
|
|
import bitsandbytes as bnb
|
|
bnb_types = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)
|
|
except ImportError:
|
|
bnb_types = ()
|
|
for module in model.modules():
|
|
if bnb_types and isinstance(module, bnb_types):
|
|
continue
|
|
for param in module.parameters(recurse=False):
|
|
if param.data.is_floating_point():
|
|
param.data = param.data.to(dtype=dtype)
|
|
for buf in module.buffers(recurse=False):
|
|
if buf.data.is_floating_point():
|
|
buf.data = buf.data.to(dtype=dtype)
|
|
|
|
|
|
def seed_everything(seed):
|
|
random.seed(seed)
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def _scan_model_dirs():
|
|
"""Recursively find diffusers model dirs (containing model_index.json) up to 2 levels deep.
|
|
Returns relative paths from SEETHROUGH_MODELS_DIR (e.g. 'foo' or 'org/foo')."""
|
|
found = []
|
|
if not os.path.isdir(SEETHROUGH_MODELS_DIR):
|
|
return found
|
|
if os.path.isfile(os.path.join(SEETHROUGH_MODELS_DIR, "model_index.json")):
|
|
found.append(".")
|
|
for name in sorted(os.listdir(SEETHROUGH_MODELS_DIR)):
|
|
d1 = os.path.join(SEETHROUGH_MODELS_DIR, name)
|
|
if not os.path.isdir(d1):
|
|
continue
|
|
if os.path.isfile(os.path.join(d1, "model_index.json")):
|
|
found.append(name)
|
|
continue
|
|
for sub in sorted(os.listdir(d1)):
|
|
d2 = os.path.join(d1, sub)
|
|
if os.path.isdir(d2) and os.path.isfile(os.path.join(d2, "model_index.json")):
|
|
found.append(f"{name}/{sub}")
|
|
return found
|
|
|
|
|
|
def _resolve_model_path(model_name):
|
|
if model_name == ".":
|
|
return SEETHROUGH_MODELS_DIR
|
|
local = os.path.join(SEETHROUGH_MODELS_DIR, model_name)
|
|
if os.path.isdir(local):
|
|
return local
|
|
# When model_name is an org/repo style ID (e.g. "layerdifforg/seethroughv0.0.1_marigold"),
|
|
# check if just the repo part exists locally (git clone creates dirs without the org prefix).
|
|
basename = model_name.split("/")[-1]
|
|
local_basename = os.path.join(SEETHROUGH_MODELS_DIR, basename)
|
|
if basename != model_name and os.path.isdir(local_basename):
|
|
return local_basename
|
|
return model_name
|
|
|
|
|
|
def _label_lr_split(labels, stats, id1, id2):
|
|
label1 = (labels == id1).astype(np.uint8) * 255
|
|
label2 = (labels == id2).astype(np.uint8) * 255
|
|
stats1, stats2 = stats[id1], stats[id2]
|
|
x1 = stats[id1][0] + stats[id1][2] / 2
|
|
x2 = stats[id2][0] + stats[id2][2] / 2
|
|
if x2 < x1:
|
|
return label2, label1, stats2, stats1
|
|
return label1, label2, stats1, stats2
|
|
|
|
|
|
def _process_cuts(img, depth, src_xyxy, tgt_bbox, mask=None):
|
|
tx1, ty1, tx2, ty2 = tgt_bbox[:4]
|
|
tx2 += tx1
|
|
ty2 += ty1
|
|
img = img[ty1:ty2, tx1:tx2].copy()
|
|
depth = depth[ty1:ty2, tx1:tx2]
|
|
depth_median = 1.0
|
|
if mask is not None:
|
|
mask = (mask[ty1:ty2, tx1:tx2].copy() > 15).astype(np.uint8)
|
|
ksize = 1
|
|
element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * ksize + 1, 2 * ksize + 1), (ksize, ksize))
|
|
mask = cv2.dilate(mask, element)
|
|
img[..., -1] *= mask
|
|
depth = 1 - (1 - depth) * mask
|
|
if np.any(mask):
|
|
depth_median = float(np.median(depth[mask > 0]))
|
|
fxyxy = [tx1 + src_xyxy[0], ty1 + src_xyxy[1], tx2 + src_xyxy[0], ty2 + src_xyxy[1]]
|
|
return img, depth, fxyxy, depth_median
|
|
|
|
|
|
def _part_lr_split(tag, part_info):
|
|
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
|
|
part_info["mask"].astype(np.uint8) * 255, connectivity=8)
|
|
tag2pinfo = {}
|
|
if len(stats) > 2:
|
|
stats = np.array(stats)
|
|
stats_order = np.argsort(stats[..., -1])[::-1][1:]
|
|
arml_mask, armr_mask, statsl, statsr = _label_lr_split(labels, stats, stats_order[0], stats_order[1])
|
|
img, depth, xyxy, dm = _process_cuts(part_info["img"], part_info["depth"], part_info["xyxy"], statsl, mask=arml_mask)
|
|
tag2pinfo[f"{tag}-r"] = {"img": img, "xyxy": xyxy, "depth": depth, "depth_median": dm, "tag": f"{tag}-r"}
|
|
img, depth, xyxy, dm = _process_cuts(part_info["img"], part_info["depth"], part_info["xyxy"], statsr, mask=armr_mask)
|
|
tag2pinfo[f"{tag}-l"] = {"img": img, "xyxy": xyxy, "depth": depth, "depth_median": dm, "tag": f"{tag}-l"}
|
|
else:
|
|
tag2pinfo[tag] = part_info
|
|
return tag2pinfo
|
|
|
|
|
|
def _tag_lr_split(tag, tag2pinfo):
|
|
if tag in tag2pinfo:
|
|
tag2pinfo.update(_part_lr_split(tag, tag2pinfo.pop(tag)))
|
|
|
|
|
|
def _compute_depth_median(part_dict):
|
|
img = part_dict.pop("img")
|
|
part_dict.pop("mask", None)
|
|
depth = part_dict.pop("depth")
|
|
mask = img[..., -1] > 10
|
|
depth_median = float(np.median(depth[mask])) if np.any(mask) else 1.0
|
|
nz = cv2.findNonZero(mask.astype(np.uint8))
|
|
if nz is not None:
|
|
xywh = cv2.boundingRect(nz)
|
|
cx1, cy1 = int(xywh[0]), int(xywh[1])
|
|
cx2, cy2 = cx1 + int(xywh[2]), cy1 + int(xywh[3])
|
|
depth = depth[cy1:cy2, cx1:cx2]
|
|
img = img[cy1:cy2, cx1:cx2]
|
|
if "xyxy" in part_dict:
|
|
ox, oy = part_dict["xyxy"][0], part_dict["xyxy"][1]
|
|
part_dict["xyxy"] = [ox + cx1, oy + cy1, ox + cx2, oy + cy2]
|
|
else:
|
|
part_dict["xyxy"] = [cx1, cy1, cx2, cy2]
|
|
depth = np.clip(depth, 0, 1) * 255
|
|
depth = np.round(depth).astype(np.uint8)
|
|
part_dict["depth_median"] = depth_median
|
|
part_dict["img"] = img
|
|
part_dict["depth"] = depth
|
|
return part_dict
|
|
|
|
|
|
def _crop_head(img, xywh):
|
|
x, y, w, h = xywh
|
|
ih, iw = img.shape[:2]
|
|
x1, y1, x2, y2 = x, y, x + w, y + h
|
|
if w < iw // 2:
|
|
px = min(iw - x - w, x, w // 5)
|
|
x1 = min(max(x - px, 0), iw)
|
|
x2 = min(max(x + w + px, 0), iw)
|
|
if h < ih // 2:
|
|
py = min(ih - y - h, y, h // 5)
|
|
y2 = min(max(y + h + py, 0), ih)
|
|
y1 = min(max(y - py, 0), ih)
|
|
return img[y1:y2, x1:x2], (x1, y1, x2, y2)
|
|
|
|
|
|
def _make_preview(tag2pinfo, resolution):
|
|
drawables = list(tag2pinfo.values())
|
|
if drawables:
|
|
blended = img_alpha_blending(drawables, premultiplied=False, final_size=(resolution, resolution))
|
|
else:
|
|
blended = np.zeros((resolution, resolution, 4), dtype=np.uint8)
|
|
preview = blended[..., :3].astype(np.float32) / 255.0
|
|
return torch.from_numpy(preview).unsqueeze(0)
|
|
|
|
|
|
class SeeThrough_LoadLayerDiffModel:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
local_models = _scan_model_dirs()
|
|
model_list = local_models + [DEFAULT_LAYERDIFF_REPO, DEFAULT_LAYERDIFF_NF4_REPO]
|
|
return {
|
|
"required": {
|
|
"model": (model_list, {"default": DEFAULT_LAYERDIFF_REPO,
|
|
"tooltip": "HuggingFace repo ID or local model folder in models/SeeThrough/"}),
|
|
},
|
|
"optional": {
|
|
"vae_ckpt": ("STRING", {"default": "",
|
|
"tooltip": "Optional path to a custom VAE checkpoint (.safetensors)"}),
|
|
"unet_ckpt": ("STRING", {"default": "",
|
|
"tooltip": "Optional path to a custom UNet checkpoint"}),
|
|
"quant_mode": (QUANT_MODES, {"default": "none",
|
|
"tooltip": "Quantization mode: 'none' for bf16, 'nf4' for 4-bit NormalFloat quantization (~8GB VRAM). Requires bitsandbytes."}),
|
|
"cache_tag_embeds": ("BOOLEAN", {"default": True,
|
|
"tooltip": "Pre-compute and cache tag embeddings, then unload text encoders to save VRAM"}),
|
|
"group_offload": ("BOOLEAN", {"default": False,
|
|
"tooltip": "Enable group offload to reduce peak VRAM (~10GB) at cost of ~1.5x slower speed"}),
|
|
"auto_download": ("BOOLEAN", {"default": True,
|
|
"tooltip": "If model is not found locally, download from HuggingFace. Disable to force local-only and error out instead of downloading."}),
|
|
},
|
|
}
|
|
|
|
RETURN_TYPES = ("SEETHROUGH_LAYERDIFF_MODEL",)
|
|
RETURN_NAMES = ("layerdiff_model",)
|
|
FUNCTION = "load_model"
|
|
CATEGORY = "SeeThrough"
|
|
|
|
def load_model(self, model, vae_ckpt="", unet_ckpt="", quant_mode="none", cache_tag_embeds=True, group_offload=False, auto_download=True):
|
|
use_nf4 = quant_mode == "nf4"
|
|
dtype = torch.bfloat16
|
|
|
|
if use_nf4 and model in _NF4_REPO_MAP:
|
|
model = _NF4_REPO_MAP[model]
|
|
print(f"[SeeThrough] quant_mode=nf4: auto-switched to NF4 repo {model}", flush=True)
|
|
|
|
pretrained = _resolve_model_path(model)
|
|
is_local = os.path.isdir(pretrained)
|
|
local_only = is_local or not auto_download
|
|
|
|
print(f"[SeeThrough] Loading LayerDiff model from: {pretrained} "
|
|
f"(quant_mode={quant_mode}, local={is_local}, local_files_only={local_only})", flush=True)
|
|
trans_vae = TransparentVAE.from_pretrained(pretrained, subfolder="trans_vae", local_files_only=local_only)
|
|
|
|
if unet_ckpt:
|
|
print(f"[SeeThrough] Loading custom UNet from: {unet_ckpt}", flush=True)
|
|
unet = UNetFrameConditionModel.from_pretrained(unet_ckpt)
|
|
else:
|
|
unet = UNetFrameConditionModel.from_pretrained(pretrained, subfolder="unet", local_files_only=local_only)
|
|
|
|
pipeline = KDiffusionStableDiffusionXLPipeline.from_pretrained(
|
|
pretrained, trans_vae=trans_vae, unet=unet, scheduler=None, local_files_only=local_only)
|
|
|
|
if vae_ckpt:
|
|
print(f"[SeeThrough] Loading custom VAE from: {vae_ckpt}", flush=True)
|
|
td_sd, vae_sd = {}, {}
|
|
sd = load_file(vae_ckpt)
|
|
for k, v in sd.items():
|
|
if k.startswith("trans_decoder."):
|
|
td_sd[k[len("trans_decoder."):]] = v
|
|
elif k.startswith("vae."):
|
|
vae_sd[k.replace("vae.", "")] = v
|
|
if vae_sd:
|
|
pipeline.vae.load_state_dict(vae_sd)
|
|
if td_sd:
|
|
pipeline.trans_vae.decoder.load_state_dict(td_sd)
|
|
|
|
pipeline.vae.to(dtype=dtype)
|
|
pipeline.trans_vae.to(dtype=dtype)
|
|
|
|
if use_nf4:
|
|
print("[SeeThrough] NF4 mode: casting non-quantized parameters to bf16", flush=True)
|
|
_cast_non_quantized_params(pipeline.unet, dtype)
|
|
_cast_non_quantized_params(pipeline.text_encoder, dtype)
|
|
_cast_non_quantized_params(pipeline.text_encoder_2, dtype)
|
|
else:
|
|
pipeline.unet.to(dtype=dtype)
|
|
pipeline.text_encoder.to(dtype=dtype)
|
|
pipeline.text_encoder_2.to(dtype=dtype)
|
|
|
|
pipeline._st_group_offload = False
|
|
if group_offload:
|
|
if hasattr(pipeline, 'enable_group_offload'):
|
|
print("[SeeThrough] Enabling group offload for LayerDiff pipeline", flush=True)
|
|
pipeline.enable_group_offload('cuda', num_blocks_per_group=1)
|
|
pipeline._st_group_offload = True
|
|
else:
|
|
print("[SeeThrough] WARNING: group_offload requires diffusers >= 0.37.0, skipping. Please upgrade: pip install diffusers>=0.37.0", flush=True)
|
|
|
|
if cache_tag_embeds:
|
|
if not pipeline._st_group_offload:
|
|
device = mm.get_torch_device()
|
|
if not use_nf4:
|
|
pipeline.text_encoder.to(device)
|
|
pipeline.text_encoder_2.to(device)
|
|
print("[SeeThrough] Caching tag embeddings and unloading text encoders...", flush=True)
|
|
pipeline.cache_tag_embeds(unload_textencoders=True)
|
|
_log_vram("After cache_tag_embeds (text encoders unloaded)")
|
|
|
|
_log_vram("LayerDiff model loaded (CPU)")
|
|
print(f"[SeeThrough] LayerDiff model loaded (quant_mode={quant_mode})", flush=True)
|
|
return (pipeline,)
|
|
|
|
class SeeThrough_LoadDepthModel:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
local_models = _scan_model_dirs()
|
|
model_list = local_models + [DEFAULT_DEPTH_REPO, DEFAULT_DEPTH_NF4_REPO]
|
|
return {
|
|
"required": {
|
|
"model": (model_list, {"default": DEFAULT_DEPTH_REPO,
|
|
"tooltip": "HuggingFace repo ID or local model folder in models/SeeThrough/"}),
|
|
},
|
|
"optional": {
|
|
"quant_mode": (QUANT_MODES, {"default": "none",
|
|
"tooltip": "Quantization mode: 'none' for bf16, 'nf4' for 4-bit NormalFloat quantization. Requires bitsandbytes."}),
|
|
"cache_tag_embeds": ("BOOLEAN", {"default": True,
|
|
"tooltip": "Pre-compute empty text embedding and unload text encoder to save VRAM"}),
|
|
"group_offload": ("BOOLEAN", {"default": False,
|
|
"tooltip": "Enable group offload to reduce peak VRAM at cost of slower speed"}),
|
|
"auto_download": ("BOOLEAN", {"default": True,
|
|
"tooltip": "If model is not found locally, download from HuggingFace. Disable to force local-only and error out instead of downloading."}),
|
|
},
|
|
}
|
|
|
|
RETURN_TYPES = ("SEETHROUGH_DEPTH_MODEL",)
|
|
RETURN_NAMES = ("depth_model",)
|
|
FUNCTION = "load_model"
|
|
CATEGORY = "SeeThrough"
|
|
|
|
def load_model(self, model, quant_mode="none", cache_tag_embeds=True, group_offload=False, auto_download=True):
|
|
use_nf4 = quant_mode == "nf4"
|
|
dtype = torch.bfloat16
|
|
|
|
if use_nf4 and model in _NF4_REPO_MAP:
|
|
model = _NF4_REPO_MAP[model]
|
|
print(f"[SeeThrough] quant_mode=nf4: auto-switched to NF4 repo {model}", flush=True)
|
|
|
|
pretrained = _resolve_model_path(model)
|
|
is_local = os.path.isdir(pretrained)
|
|
local_only = is_local or not auto_download
|
|
|
|
print(f"[SeeThrough] Loading Marigold depth model from: {pretrained} "
|
|
f"(quant_mode={quant_mode}, local={is_local}, local_files_only={local_only})", flush=True)
|
|
unet = UNetFrameConditionModel.from_pretrained(pretrained, subfolder="unet", local_files_only=local_only)
|
|
pipeline = MarigoldDepthPipeline.from_pretrained(pretrained, unet=unet, local_files_only=local_only)
|
|
|
|
if use_nf4:
|
|
print("[SeeThrough] NF4 mode: casting non-quantized parameters to bf16", flush=True)
|
|
pipeline.vae.to(dtype=dtype)
|
|
_cast_non_quantized_params(pipeline.unet, dtype)
|
|
_cast_non_quantized_params(pipeline.text_encoder, dtype)
|
|
else:
|
|
pipeline.to(dtype=dtype)
|
|
|
|
pipeline._st_group_offload = False
|
|
if group_offload:
|
|
if hasattr(pipeline, 'enable_group_offload'):
|
|
print("[SeeThrough] Enabling group offload for Marigold pipeline", flush=True)
|
|
pipeline.enable_group_offload('cuda', num_blocks_per_group=1)
|
|
pipeline._st_group_offload = True
|
|
else:
|
|
print("[SeeThrough] WARNING: group_offload requires diffusers >= 0.37.0, skipping. Please upgrade: pip install diffusers>=0.37.0", flush=True)
|
|
|
|
if cache_tag_embeds:
|
|
if not pipeline._st_group_offload:
|
|
device = mm.get_torch_device()
|
|
if not use_nf4:
|
|
pipeline.text_encoder.to(device)
|
|
print("[SeeThrough] Caching empty text embedding and unloading text encoder...", flush=True)
|
|
pipeline.cache_tag_embeds(unload_textencoders=True)
|
|
_log_vram("After Marigold cache_tag_embeds (text encoder unloaded)")
|
|
|
|
_log_vram("Depth model loaded (CPU)")
|
|
print(f"[SeeThrough] Depth model loaded (quant_mode={quant_mode})", flush=True)
|
|
return (pipeline,)
|
|
|
|
class SeeThrough_GenerateLayers:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {
|
|
"required": {
|
|
"image": ("IMAGE",),
|
|
"layerdiff_model": ("SEETHROUGH_LAYERDIFF_MODEL",),
|
|
"seed": ("INT", {"default": 42, "min": 0, "max": 2**32 - 1}),
|
|
"resolution": ("INT", {"default": 1280, "min": 512, "max": 2048, "step": 64}),
|
|
"num_inference_steps": ("INT", {"default": 30, "min": 1, "max": 100}),
|
|
},
|
|
}
|
|
|
|
RETURN_TYPES = ("SEETHROUGH_LAYERS", "IMAGE")
|
|
RETURN_NAMES = ("layers", "preview")
|
|
FUNCTION = "generate"
|
|
CATEGORY = "SeeThrough"
|
|
|
|
def generate(self, image, layerdiff_model, seed=42, resolution=1280, num_inference_steps=30):
|
|
pipeline = layerdiff_model
|
|
device = mm.get_torch_device()
|
|
offload = torch.device("cpu")
|
|
seed_everything(seed)
|
|
|
|
# Convert ComfyUI IMAGE to numpy RGBA
|
|
img_np = (image[0].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
|
if img_np.shape[-1] == 3:
|
|
img_np = np.concatenate([img_np, np.full((*img_np.shape[:2], 1), 255, dtype=np.uint8)], axis=-1)
|
|
input_img = img_np.copy()
|
|
|
|
fullpage, pad_size, pad_pos = center_square_pad_resize(input_img, resolution, return_pad_info=True)
|
|
scale = pad_size[0] / resolution
|
|
|
|
tag_version = pipeline.unet.get_tag_version()
|
|
layer_dict = {}
|
|
|
|
print(f"[SeeThrough] GenerateLayers: tag_version={tag_version}, resolution={resolution}, steps={num_inference_steps}", flush=True)
|
|
_log_vram("GenerateLayers start")
|
|
|
|
is_group_offload = getattr(pipeline, '_st_group_offload', False)
|
|
|
|
has_cached = len(pipeline._cached_prompt_embeds) > 0
|
|
if has_cached:
|
|
print("[SeeThrough] Using cached tag embeddings", flush=True)
|
|
encode_fn = pipeline.encode_cropped_prompt_77tokens_cached
|
|
else:
|
|
if not is_group_offload:
|
|
pipeline.text_encoder.to(device)
|
|
pipeline.text_encoder_2.to(device)
|
|
_log_vram("Text encoders loaded to GPU")
|
|
encode_fn = pipeline.encode_cropped_prompt_77tokens
|
|
|
|
if tag_version == "v2":
|
|
prompt_embeds, pooled_prompt_embeds = encode_fn(VALID_BODY_PARTS_V2)
|
|
elif tag_version == "v3":
|
|
body_tags = ["front hair", "back hair", "head", "neck", "neckwear",
|
|
"topwear", "handwear", "bottomwear", "legwear", "footwear",
|
|
"tail", "wings", "objects"]
|
|
head_tags = ["headwear", "face", "irides", "eyebrow", "eyewhite",
|
|
"eyelash", "eyewear", "ears", "earwear", "nose", "mouth"]
|
|
body_embeds, body_pooled = encode_fn(body_tags)
|
|
head_embeds, head_pooled = encode_fn(head_tags)
|
|
else:
|
|
raise ValueError(f"Unknown tag version: {tag_version}")
|
|
|
|
if not has_cached and not is_group_offload:
|
|
pipeline.text_encoder.to(offload)
|
|
pipeline.text_encoder_2.to(offload)
|
|
_log_vram("Text encoders offloaded to CPU")
|
|
|
|
if not is_group_offload:
|
|
pipeline.unet.to(device)
|
|
pipeline.vae.to(device)
|
|
pipeline.trans_vae.to(device)
|
|
mm.soft_empty_cache()
|
|
_log_vram("UNet+VAE on GPU, ready for diffusion")
|
|
|
|
rng = torch.Generator(device=device).manual_seed(seed)
|
|
|
|
print("[SeeThrough] Text encoded, text encoders offloaded to CPU", flush=True)
|
|
|
|
if tag_version == "v2":
|
|
out = pipeline(strength=1.0, num_inference_steps=num_inference_steps, batch_size=1,
|
|
generator=rng, guidance_scale=1.0,
|
|
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|
|
fullpage=fullpage)
|
|
_log_vram("v2 diffusion complete")
|
|
for rst, tag in zip(out.images, VALID_BODY_PARTS_V2):
|
|
layer_dict[tag] = rst
|
|
|
|
elif tag_version == "v3":
|
|
out = pipeline(strength=1.0, num_inference_steps=num_inference_steps, batch_size=1,
|
|
generator=rng, guidance_scale=1.0,
|
|
prompt_embeds=body_embeds, pooled_prompt_embeds=body_pooled,
|
|
fullpage=fullpage, group_index=0)
|
|
_log_vram("v3 body diffusion complete")
|
|
for rst, tag in zip(out.images, body_tags):
|
|
layer_dict[tag] = rst
|
|
|
|
head_img = out.images[2]
|
|
nz = cv2.findNonZero((head_img[..., -1] > 15).astype(np.uint8))
|
|
if nz is not None:
|
|
hx0, hy0, hw, hh = cv2.boundingRect(nz)
|
|
hx = int(hx0 * scale) - pad_pos[0]
|
|
hy = int(hy0 * scale) - pad_pos[1]
|
|
input_head, (hx1, hy1, hx2, hy2) = _crop_head(input_img, [hx, hy, int(hw * scale), int(hh * scale)])
|
|
hx1 = int(hx1 / scale + pad_pos[0] / scale)
|
|
hy1 = int(hy1 / scale + pad_pos[1] / scale)
|
|
ih, iw = input_head.shape[:2]
|
|
input_head, head_pad_size, head_pad_pos = center_square_pad_resize(input_head, resolution, return_pad_info=True)
|
|
|
|
out = pipeline(strength=1.0, num_inference_steps=num_inference_steps, batch_size=1,
|
|
generator=rng, guidance_scale=1.0,
|
|
prompt_embeds=head_embeds, pooled_prompt_embeds=head_pooled,
|
|
fullpage=input_head, group_index=1)
|
|
_log_vram("v3 head diffusion complete")
|
|
|
|
canvas = np.zeros((resolution, resolution, 4), dtype=np.uint8)
|
|
coords = np.array([head_pad_pos[1], head_pad_pos[1] + ih, head_pad_pos[0], head_pad_pos[0] + iw])
|
|
py1, py2, px1, px2 = (coords / scale).astype(np.int64)
|
|
scale_size = (int(head_pad_size[0] / scale), int(head_pad_size[1] / scale))
|
|
|
|
for rst, tag in zip(out.images, head_tags):
|
|
rst = smart_resize(rst, scale_size)[py1:py2, px1:px2]
|
|
full = canvas.copy()
|
|
full[hy1:hy1 + rst.shape[0], hx1:hx1 + rst.shape[1]] = rst
|
|
layer_dict[tag] = full
|
|
|
|
if not is_group_offload:
|
|
pipeline.unet.to(offload)
|
|
pipeline.vae.to(offload)
|
|
pipeline.trans_vae.to(offload)
|
|
mm.soft_empty_cache()
|
|
_log_vram("GenerateLayers offloaded to CPU")
|
|
print(f"[SeeThrough] GenerateLayers complete: {len(layer_dict)} layers", flush=True)
|
|
|
|
layers_data = SeeThrough_LayersData(layer_dict, fullpage, input_img, resolution, pad_size, pad_pos)
|
|
|
|
preview_dict = {}
|
|
for tag, img in layer_dict.items():
|
|
mask = img[..., -1] > 10
|
|
if np.any(mask):
|
|
preview_dict[tag] = {"img": img, "xyxy": [0, 0, resolution, resolution]}
|
|
preview = _make_preview(preview_dict, resolution)
|
|
|
|
return (layers_data, preview)
|
|
|
|
|
|
class SeeThrough_GenerateDepth:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {
|
|
"required": {
|
|
"layers": ("SEETHROUGH_LAYERS",),
|
|
"depth_model": ("SEETHROUGH_DEPTH_MODEL",),
|
|
"seed": ("INT", {"default": 42, "min": 0, "max": 2**32 - 1}),
|
|
},
|
|
"optional": {
|
|
"resolution_depth": ("INT", {"default": -1, "min": -1, "max": 2048, "step": 64,
|
|
"tooltip": "Resolution for depth inference. -1 uses the same resolution as layers. Lower values save VRAM and speed up inference."}),
|
|
},
|
|
}
|
|
|
|
RETURN_TYPES = ("SEETHROUGH_LAYERS_DEPTH", "IMAGE")
|
|
RETURN_NAMES = ("layers_depth", "preview")
|
|
FUNCTION = "generate"
|
|
CATEGORY = "SeeThrough"
|
|
|
|
def generate(self, layers, depth_model, seed=42, resolution_depth=-1):
|
|
layer_dict = layers.layer_dict
|
|
fullpage = layers.fullpage
|
|
resolution = layers.resolution
|
|
marigold = depth_model
|
|
device = mm.get_torch_device()
|
|
offload = torch.device("cpu")
|
|
is_group_offload = getattr(marigold, '_st_group_offload', False)
|
|
|
|
print("[SeeThrough] GenerateDepth: running Marigold...", flush=True)
|
|
_log_vram("GenerateDepth start")
|
|
|
|
empty_array = np.zeros((resolution, resolution, 4), dtype=np.uint8)
|
|
blended_alpha = np.zeros((resolution, resolution), dtype=np.float32)
|
|
compose_list = {"eyes": ["eyewhite", "irides", "eyelash", "eyebrow"],
|
|
"hair": ["back hair", "front hair"]}
|
|
|
|
img_list = []
|
|
for tag in VALID_BODY_PARTS_V2:
|
|
if tag in layer_dict:
|
|
tag_arr = layer_dict[tag].copy()
|
|
tag_arr[..., -1][tag_arr[..., -1] < 15] = 0
|
|
img_list.append(tag_arr)
|
|
else:
|
|
img_list.append(empty_array.copy())
|
|
|
|
compose_dict = {}
|
|
for c, clist in compose_list.items():
|
|
imlist, taglist = [], []
|
|
for t in clist:
|
|
if t in layer_dict:
|
|
tag_arr = layer_dict[t].copy()
|
|
tag_arr[..., -1][tag_arr[..., -1] < 15] = 0
|
|
imlist.append(tag_arr)
|
|
taglist.append(t)
|
|
if imlist:
|
|
composed = img_alpha_blending(imlist, premultiplied=False)
|
|
img_list[VALID_BODY_PARTS_V2.index(c)] = composed
|
|
compose_dict[c] = {"taglist": taglist, "imlist": imlist}
|
|
|
|
for img in img_list:
|
|
blended_alpha += img[..., -1].astype(np.float32) / 255
|
|
blended_alpha = np.clip(blended_alpha, 0, 1) * 255
|
|
blended_alpha = blended_alpha.astype(np.uint8)
|
|
|
|
fullpage_for_depth = fullpage.copy()
|
|
fullpage_for_depth[..., -1] = blended_alpha
|
|
img_list.append(fullpage_for_depth)
|
|
|
|
src_h, src_w = resolution, resolution
|
|
if resolution_depth > 0 and resolution_depth != resolution:
|
|
depth_res = resolution_depth
|
|
depth_res = max(64, (depth_res // 8) * 8)
|
|
img_list_input = [smart_resize(img, (depth_res, depth_res)) for img in img_list]
|
|
print(f"[SeeThrough] Depth inference at resolution {depth_res} (layers at {resolution})", flush=True)
|
|
else:
|
|
img_list_input = img_list
|
|
depth_res = resolution
|
|
|
|
if not is_group_offload:
|
|
marigold.unet.to(device)
|
|
marigold.vae.to(device)
|
|
mm.soft_empty_cache()
|
|
_log_vram("Marigold on GPU")
|
|
print("[SeeThrough] Marigold pipeline moved to GPU", flush=True)
|
|
|
|
seed_everything(seed)
|
|
pipe_out = marigold(color_map=None, show_progress_bar=False, img_list=img_list_input)
|
|
_log_vram("Marigold inference complete")
|
|
depth_pred = pipe_out.depth_tensor.to(device="cpu", dtype=torch.float32).numpy()
|
|
|
|
if depth_res != resolution:
|
|
depth_pred = np.stack([smart_resize(d, (src_h, src_w)) for d in depth_pred])
|
|
|
|
if not is_group_offload:
|
|
marigold.unet.to(offload)
|
|
marigold.vae.to(offload)
|
|
mm.soft_empty_cache()
|
|
_log_vram("GenerateDepth offloaded to CPU")
|
|
|
|
depth_dict = {}
|
|
for ii, tag in enumerate(VALID_BODY_PARTS_V2):
|
|
depth = depth_pred[ii]
|
|
if tag in compose_dict:
|
|
mask_accum = blended_alpha > 256 # all-False
|
|
for t, im in zip(compose_dict[tag]["taglist"][::-1], compose_dict[tag]["imlist"][::-1]):
|
|
mask_local = im[..., -1] > 15
|
|
mask_invis = np.bitwise_and(mask_accum, mask_local)
|
|
depth_local = np.full((resolution, resolution), fill_value=1.0, dtype=np.float32)
|
|
depth_local[mask_local] = depth[mask_local]
|
|
if np.any(mask_invis):
|
|
vis = np.bitwise_and(mask_local, np.bitwise_not(mask_invis))
|
|
if np.any(vis):
|
|
depth_local[mask_invis] = np.median(depth[vis])
|
|
mask_accum = np.bitwise_or(mask_accum, mask_local)
|
|
depth_dict[t] = depth_local
|
|
else:
|
|
depth_dict[tag] = np.clip(depth, 0, 1).astype(np.float32)
|
|
|
|
print(f"[SeeThrough] GenerateDepth complete: {len(depth_dict)} depth maps, Marigold offloaded to CPU", flush=True)
|
|
|
|
result = SeeThrough_LayersDepthData(layer_dict, depth_dict, fullpage, resolution)
|
|
|
|
# Preview: blend with depth info
|
|
preview_dict = {}
|
|
for tag in layer_dict:
|
|
img = layer_dict[tag]
|
|
if tag in depth_dict and np.any(img[..., -1] > 10):
|
|
preview_dict[tag] = {"img": img, "depth": depth_dict[tag], "xyxy": [0, 0, resolution, resolution]}
|
|
preview = _make_preview(preview_dict, resolution)
|
|
|
|
return (result, preview)
|
|
|
|
class SeeThrough_PostProcess:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {
|
|
"required": {
|
|
"layers_depth": ("SEETHROUGH_LAYERS_DEPTH",),
|
|
"tblr_split": ("BOOLEAN", {"default": True,
|
|
"tooltip": "Split symmetric parts (eyes, ears, handwear) into left/right"}),
|
|
"use_lama": ("BOOLEAN", {"default": True,
|
|
"tooltip": "Use LaMa inpainting for hair splitting (better quality). Falls back to OpenCV if disabled."}),
|
|
},
|
|
}
|
|
|
|
RETURN_TYPES = ("SEETHROUGH_PARTS", "IMAGE")
|
|
RETURN_NAMES = ("parts", "preview")
|
|
FUNCTION = "process"
|
|
CATEGORY = "SeeThrough"
|
|
|
|
def process(self, layers_depth, tblr_split=True, use_lama=True):
|
|
layer_dict = layers_depth.layer_dict
|
|
depth_dict = layers_depth.depth_dict
|
|
fullpage = layers_depth.fullpage
|
|
resolution = layers_depth.resolution
|
|
|
|
print("[SeeThrough] PostProcess: splitting & clustering...", flush=True)
|
|
|
|
# Build tag2pinfo
|
|
tag2pinfo = {}
|
|
for tag in layer_dict:
|
|
img = layer_dict[tag]
|
|
if tag not in depth_dict:
|
|
continue
|
|
depth = depth_dict[tag]
|
|
mask = img[..., -1] > 10
|
|
if not np.any(mask):
|
|
continue
|
|
tag2pinfo[tag] = {"img": img, "depth": depth, "xyxy": [0, 0, resolution, resolution],
|
|
"mask": mask, "tag": tag}
|
|
|
|
# Eye splitting (v2 composite 'eyes')
|
|
if "eyes" in tag2pinfo:
|
|
part_info = tag2pinfo.pop("eyes")
|
|
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
|
|
part_info["mask"].astype(np.uint8) * 255, connectivity=8)
|
|
if len(stats) > 2:
|
|
stats_arr = np.array(stats)
|
|
if len(stats_arr[..., -1]) >= 5:
|
|
stats_order = np.argsort(stats_arr[..., -1])[::-1][1:]
|
|
eyel_mask, eyer_mask, statsl, statsr = _label_lr_split(labels, stats_arr, stats_order[0], stats_order[1])
|
|
img, depth, xyxy, _ = _process_cuts(part_info["img"], part_info["depth"], part_info["xyxy"], statsl)
|
|
tag2pinfo["eyer"] = {"img": img, "xyxy": xyxy, "depth": depth}
|
|
img, depth, xyxy, _ = _process_cuts(part_info["img"], part_info["depth"], part_info["xyxy"], statsr)
|
|
tag2pinfo["eyel"] = {"img": img, "xyxy": xyxy, "depth": depth}
|
|
if len(stats_order) >= 4:
|
|
browl_mask, browr_mask, statsl, statsr = _label_lr_split(labels, stats_arr, stats_order[2], stats_order[3])
|
|
img, depth, xyxy, _ = _process_cuts(part_info["img"], part_info["depth"], part_info["xyxy"], statsl)
|
|
tag2pinfo["browr"] = {"img": img, "xyxy": xyxy, "depth": depth}
|
|
img, depth, xyxy, _ = _process_cuts(part_info["img"], part_info["depth"], part_info["xyxy"], statsr)
|
|
tag2pinfo["browl"] = {"img": img, "xyxy": xyxy, "depth": depth}
|
|
else:
|
|
tag2pinfo["eyes"] = part_info
|
|
else:
|
|
tag2pinfo["eyes"] = part_info
|
|
|
|
# Left-right splitting
|
|
if tblr_split:
|
|
_tag_lr_split("handwear", tag2pinfo)
|
|
for eye_tag in ["eyewhite", "irides", "eyelash", "eyebrow"]:
|
|
_tag_lr_split(eye_tag, tag2pinfo)
|
|
_tag_lr_split("ears", tag2pinfo)
|
|
|
|
if "hair" in tag2pinfo:
|
|
part_info = tag2pinfo.pop("hair")
|
|
try:
|
|
inpaint_mode = "lama" if use_lama else "cv2"
|
|
parts = cluster_inpaint_part(inpaint=inpaint_mode, **part_info)
|
|
parts.sort(key=lambda x: x["depth_median"])
|
|
tag2pinfo["hairf"] = parts[0]
|
|
tag2pinfo["hairb"] = parts[1]
|
|
except Exception as e:
|
|
print(f"[SeeThrough] Hair clustering failed: {e}, keeping as-is", flush=True)
|
|
tag2pinfo["hair"] = part_info
|
|
|
|
# Nose/mouth color restoration
|
|
for restore_tag in ("nose", "mouth"):
|
|
if restore_tag in tag2pinfo:
|
|
pinfo = tag2pinfo[restore_tag]
|
|
src_h, src_w = pinfo["img"].shape[:2]
|
|
fp_h, fp_w = fullpage.shape[:2]
|
|
if src_h == fp_h and src_w == fp_w:
|
|
pinfo["img"][..., :3] = fullpage[..., :3]
|
|
else:
|
|
x1, y1 = pinfo["xyxy"][0], pinfo["xyxy"][1]
|
|
crop = fullpage[y1:min(y1 + src_h, fp_h), x1:min(x1 + src_w, fp_w), :3]
|
|
pinfo["img"][:crop.shape[0], :crop.shape[1], :3] = crop
|
|
|
|
# Crop + depth_median
|
|
for tag in list(tag2pinfo.keys()):
|
|
pinfo = tag2pinfo[tag]
|
|
if "img" in pinfo and "depth" in pinfo:
|
|
_compute_depth_median(pinfo)
|
|
pinfo["tag"] = tag
|
|
|
|
# Depth ordering adjustments
|
|
if "face" in tag2pinfo:
|
|
face_dm = tag2pinfo["face"]["depth_median"]
|
|
for t in ["nose", "mouth", "eyes", "eyel", "eyer"]:
|
|
if t in tag2pinfo and tag2pinfo[t]["depth_median"] > face_dm:
|
|
tag2pinfo[t]["depth_median"] = face_dm - 0.001
|
|
for t in ["earr", "earl", "ears"]:
|
|
if t in tag2pinfo:
|
|
tag2pinfo[t]["depth_median"] = face_dm + 0.001
|
|
|
|
frame_size = fullpage.shape[:2]
|
|
parts_data = {"tag2pinfo": tag2pinfo, "frame_size": frame_size}
|
|
|
|
print(f"[SeeThrough] PostProcess complete: {len(tag2pinfo)} layers", flush=True)
|
|
for tag, pinfo in sorted(tag2pinfo.items(), key=lambda x: x[1].get("depth_median", 1)):
|
|
dm = pinfo.get("depth_median", "?")
|
|
print(f" - {tag}: depth_median={dm:.4f}" if isinstance(dm, float) else f" - {tag}", flush=True)
|
|
|
|
preview = _make_preview(tag2pinfo, resolution)
|
|
return (parts_data, preview)
|
|
|
|
|
|
class SeeThrough_SavePSD:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {
|
|
"required": {
|
|
"parts": ("SEETHROUGH_PARTS",),
|
|
"filename_prefix": ("STRING", {"default": "seethrough"}),
|
|
},
|
|
}
|
|
|
|
RETURN_TYPES = ("STRING",)
|
|
RETURN_NAMES = ("info_file",)
|
|
FUNCTION = "save"
|
|
CATEGORY = "SeeThrough"
|
|
OUTPUT_NODE = True
|
|
|
|
def save(self, parts, filename_prefix="seethrough"):
|
|
from PIL import Image
|
|
import json
|
|
|
|
tag2pinfo = parts["tag2pinfo"]
|
|
frame_size = parts["frame_size"]
|
|
canvas_h, canvas_w = frame_size
|
|
|
|
output_dir = folder_paths.get_output_directory()
|
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
uid = str(uuid.uuid4())[:8]
|
|
|
|
sorted_tags = sorted(tag2pinfo.keys(), key=lambda t: tag2pinfo[t].get("depth_median", 1), reverse=True)
|
|
|
|
layer_info_list = []
|
|
for tag in sorted_tags:
|
|
pinfo = tag2pinfo[tag]
|
|
img = pinfo.get("img")
|
|
depth = pinfo.get("depth")
|
|
if img is None:
|
|
continue
|
|
|
|
xyxy = pinfo.get("xyxy", [0, 0, img.shape[1], img.shape[0]])
|
|
x1, y1, x2, y2 = [int(v) for v in xyxy]
|
|
|
|
layer_filename = f"{filename_prefix}_{ts}_{uid}_{tag}.png"
|
|
Image.fromarray(img).save(os.path.join(output_dir, layer_filename))
|
|
|
|
entry = {"name": tag, "filename": layer_filename,
|
|
"left": x1, "top": y1, "right": x2, "bottom": y2,
|
|
"depth_median": float(pinfo.get("depth_median", 1))}
|
|
|
|
if depth is not None:
|
|
depth_filename = f"{filename_prefix}_{ts}_{uid}_{tag}_depth.png"
|
|
if depth.ndim == 2:
|
|
Image.fromarray(depth, mode="L").save(os.path.join(output_dir, depth_filename))
|
|
else:
|
|
Image.fromarray(depth).save(os.path.join(output_dir, depth_filename))
|
|
entry["depth_filename"] = depth_filename
|
|
|
|
layer_info_list.append(entry)
|
|
|
|
info_filename = f"{filename_prefix}_{ts}_{uid}_layers.json"
|
|
info_path = os.path.join(output_dir, info_filename)
|
|
with open(info_path, "w", encoding="utf-8") as f:
|
|
json.dump({"prefix": filename_prefix, "timestamp": f"{ts}_{uid}",
|
|
"layers": layer_info_list, "width": int(canvas_w), "height": int(canvas_h)}, f, indent=2)
|
|
|
|
log_path = os.path.join(output_dir, "seethrough_psd_info.log")
|
|
with open(log_path, "w") as f:
|
|
f.write(info_filename)
|
|
|
|
print(f"[SeeThrough] {len(layer_info_list)} layers saved. Use 'Download PSD' button to generate PSD.", flush=True)
|
|
return (info_path,)
|
|
|
|
|
|
NODE_CLASS_MAPPINGS = {
|
|
"SeeThrough_LoadLayerDiffModel": SeeThrough_LoadLayerDiffModel,
|
|
"SeeThrough_LoadDepthModel": SeeThrough_LoadDepthModel,
|
|
"SeeThrough_GenerateLayers": SeeThrough_GenerateLayers,
|
|
"SeeThrough_GenerateDepth": SeeThrough_GenerateDepth,
|
|
"SeeThrough_PostProcess": SeeThrough_PostProcess,
|
|
"SeeThrough_SavePSD": SeeThrough_SavePSD,
|
|
}
|
|
|
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
"SeeThrough_LoadLayerDiffModel": "SeeThrough Load LayerDiff Model",
|
|
"SeeThrough_LoadDepthModel": "SeeThrough Load Depth Model",
|
|
"SeeThrough_GenerateLayers": "SeeThrough Generate Layers",
|
|
"SeeThrough_GenerateDepth": "SeeThrough Generate Depth",
|
|
"SeeThrough_PostProcess": "SeeThrough Post Process",
|
|
"SeeThrough_SavePSD": "SeeThrough Save PSD",
|
|
}
|