mirror of
https://github.com/shitagaki-lab/see-through.git
synced 2026-05-05 19:58:57 +00:00
292 lines
No EOL
13 KiB
Python
292 lines
No EOL
13 KiB
Python
from typing import List, Tuple
|
|
|
|
from torch import nn
|
|
import torch
|
|
from einops import rearrange
|
|
|
|
from .sam.modeling import Sam
|
|
from .sam.modeling.mask_decoder import MLP
|
|
from .sam import sam_model_registry
|
|
from .extend_sam import BaseExtendSam, BaseMaskDecoderAdapter, MaskDecoder
|
|
|
|
|
|
class SemMaskDecoderAdapter(BaseMaskDecoderAdapter):
|
|
|
|
def __init__(self, sam_mask_decoder: MaskDecoder, fix=False, class_num=20, init_from_sam=True):
|
|
super(SemMaskDecoderAdapter, self).__init__(sam_mask_decoder, fix)
|
|
|
|
self.class_num = class_num
|
|
self.is_hq = self.sam_mask_decoder.is_hq
|
|
self.num_mask_tokens = self.sam_mask_decoder.num_mask_tokens
|
|
|
|
transformer_dim = self.sam_mask_decoder.transformer_dim
|
|
|
|
self.output_hypernetworks_mlps = nn.ModuleList(
|
|
[
|
|
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
|
for _ in range(self.class_num)
|
|
]
|
|
)
|
|
if init_from_sam:
|
|
target_sd = self.sam_mask_decoder.output_hypernetworks_mlps[1].state_dict()
|
|
for ii in range(class_num):
|
|
self.output_hypernetworks_mlps[ii].load_state_dict(target_sd)
|
|
del self.sam_mask_decoder.output_hypernetworks_mlps
|
|
|
|
if self.is_hq:
|
|
self.hf_mlps = nn.ModuleList(
|
|
[
|
|
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
|
for _ in range(self.class_num)
|
|
]
|
|
)
|
|
if init_from_sam:
|
|
target_sd = self.sam_mask_decoder.hf_mlp.state_dict()
|
|
for ii in range(class_num):
|
|
self.hf_mlps[ii].load_state_dict(target_sd)
|
|
del self.sam_mask_decoder.hf_mlp
|
|
# input cond tokens: cat[1 x iou tokens, 4 x original mask tokens, 1 x hf token]
|
|
# num_mask_tokens: 4 + 1
|
|
self.hf_token_idx = self.num_mask_tokens
|
|
|
|
def forward(
|
|
self,
|
|
image_embeddings: torch.Tensor,
|
|
image_pe: torch.Tensor,
|
|
sparse_prompt_embeddings: torch.Tensor,
|
|
dense_prompt_embeddings: torch.Tensor,
|
|
multimask_output: bool = True,
|
|
hq_token_only: bool = False,
|
|
interm_embeddings: torch.Tensor = None,
|
|
mask_scale=1,
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Predict masks given image and prompt embeddings.
|
|
|
|
Arguments:
|
|
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
|
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
|
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
|
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
|
multimask_output (bool): Whether to return multiple masks or a single
|
|
mask.
|
|
|
|
Returns:
|
|
torch.Tensor: batched predicted masks
|
|
torch.Tensor: batched predictions of mask quality
|
|
"""
|
|
hq_features = None
|
|
|
|
# token processing
|
|
if self.is_hq:
|
|
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
|
|
hq_features = self.sam_mask_decoder.embedding_encoder(image_embeddings) + self.sam_mask_decoder.compress_vit_feat(vit_features)
|
|
output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sam_mask_decoder.mask_tokens.weight, self.sam_mask_decoder.hf_token.weight]
|
|
else:
|
|
output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sam_mask_decoder.mask_tokens.weight]
|
|
output_tokens = torch.cat(output_tokens, dim=0)
|
|
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
|
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
|
# tokens: (batch size, model preserved tokens (iou*1, mask*4, hf token * 1) + user prompts, token dim)
|
|
# Expand per-image data in batch direction to be per-mask. multiple user prompts for the same image are divide along batch channel
|
|
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
|
src = src + dense_prompt_embeddings
|
|
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
|
src = src.to(dtype=pos_src.dtype)
|
|
tokens = tokens.to(dtype=pos_src.dtype)
|
|
b, c, h, w = src.shape
|
|
|
|
# Run the transformer
|
|
hs, src = self.sam_mask_decoder.transformer(src, pos_src, tokens)
|
|
iou_token_out = hs[:, 0, :]
|
|
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
|
|
|
# Decode tokens, mask tokens -> iou preds, src tokens (input image tokens) to masks
|
|
# Upscale mask embeddings and predict masks using the mask tokens
|
|
src = src.transpose(1, 2).view(b, c, h, w)
|
|
upscaled_embedding_sam = self.sam_mask_decoder.output_upscaling(src)
|
|
|
|
hyper_in_list: List[torch.Tensor] = []
|
|
hyper_hq_list: List[torch.Tensor] = []
|
|
for i in range(self.class_num):
|
|
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, mask_scale, :]))
|
|
if self.is_hq:
|
|
hyper_hq_list.append(self.hf_mlps[i](hs[:, self.hf_token_idx, :]))
|
|
|
|
hyper_in = torch.stack(hyper_in_list, dim=1)
|
|
b, c, h, w = upscaled_embedding_sam.shape
|
|
masks_sam = (hyper_in @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
|
|
iou_pred = self.sam_mask_decoder.iou_prediction_head(iou_token_out)
|
|
|
|
if self.is_hq:
|
|
hyper_hq = torch.stack(hyper_hq_list, dim=1)
|
|
upscaled_embedding_hq = self.sam_mask_decoder.embedding_maskfeature(upscaled_embedding_sam) + hq_features
|
|
masks_hq = (hyper_hq @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
|
|
if hq_token_only:
|
|
masks = masks_hq
|
|
else:
|
|
masks = masks_sam + masks_hq
|
|
|
|
else:
|
|
masks = masks_sam
|
|
|
|
# Generate mask quality predictions
|
|
# Prepare output
|
|
return masks, iou_pred
|
|
|
|
|
|
class SemMaskDecoderAdapterTokenVariant(BaseMaskDecoderAdapter):
|
|
|
|
def __init__(self, sam_mask_decoder: MaskDecoder, fix=False, class_num=20, init_from_sam=True):
|
|
super(SemMaskDecoderAdapterTokenVariant, self).__init__(sam_mask_decoder, fix)
|
|
|
|
self.class_num = class_num
|
|
self.is_hq = self.sam_mask_decoder.is_hq
|
|
# self.num_mask_tokens = self.sam_mask_decoder.num_mask_tokens
|
|
|
|
|
|
|
|
transformer_dim = self.sam_mask_decoder.transformer_dim
|
|
self.sem_mask_tokens = nn.Embedding(class_num, transformer_dim)
|
|
|
|
self.output_hypernetworks_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
|
if init_from_sam:
|
|
target_sd = self.sam_mask_decoder.output_hypernetworks_mlps[1].state_dict()
|
|
self.output_hypernetworks_mlp.load_state_dict(target_sd)
|
|
target_sd = self.sam_mask_decoder.mask_tokens.state_dict()
|
|
target_sd = {'weight': target_sd['weight'][[1]].repeat(class_num, 1)}
|
|
self.sem_mask_tokens.load_state_dict(target_sd)
|
|
pass
|
|
|
|
del self.sam_mask_decoder.mask_tokens
|
|
del self.sam_mask_decoder.output_hypernetworks_mlps
|
|
|
|
if self.is_hq:
|
|
self.hq_mask_tokens = nn.Embedding(class_num, transformer_dim)
|
|
self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
|
if init_from_sam:
|
|
target_sd = self.sam_mask_decoder.hf_mlp.state_dict()
|
|
self.hf_mlp.load_state_dict(target_sd)
|
|
target_sd = self.sam_mask_decoder.hf_token.state_dict()
|
|
target_sd = {'weight': target_sd['weight'].repeat(class_num, 1)}
|
|
self.hq_mask_tokens.load_state_dict(target_sd)
|
|
del self.sam_mask_decoder.hf_mlp
|
|
del self.sam_mask_decoder.hf_token
|
|
# input cond tokens: cat[1 x iou tokens, 4 x original mask tokens, 1 x hf token]
|
|
# num_mask_tokens: 4 + 1
|
|
# self.hf_token_idx = self.num_mask_tokens
|
|
|
|
def forward(
|
|
self,
|
|
image_embeddings: torch.Tensor,
|
|
image_pe: torch.Tensor,
|
|
sparse_prompt_embeddings: torch.Tensor,
|
|
dense_prompt_embeddings: torch.Tensor,
|
|
multimask_output: bool = True,
|
|
hq_token_only: bool = False,
|
|
interm_embeddings: torch.Tensor = None,
|
|
mask_scale=1,
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Predict masks given image and prompt embeddings.
|
|
|
|
Arguments:
|
|
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
|
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
|
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
|
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
|
multimask_output (bool): Whether to return multiple masks or a single
|
|
mask.
|
|
|
|
Returns:
|
|
torch.Tensor: batched predicted masks
|
|
torch.Tensor: batched predictions of mask quality
|
|
"""
|
|
hq_features = None
|
|
|
|
# token processing
|
|
if self.is_hq:
|
|
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
|
|
hq_features = self.sam_mask_decoder.embedding_encoder(image_embeddings) + self.sam_mask_decoder.compress_vit_feat(vit_features)
|
|
output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sem_mask_tokens.weight, self.hq_mask_tokens.weight]
|
|
else:
|
|
output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sem_mask_tokens.weight]
|
|
output_tokens = torch.cat(output_tokens, dim=0)
|
|
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
|
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
|
# tokens: (batch size, model preserved tokens (iou*1, mask*4, hf token * 1) + user prompts, token dim)
|
|
# Expand per-image data in batch direction to be per-mask. multiple user prompts for the same image are divide along batch channel
|
|
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
|
src = src + dense_prompt_embeddings
|
|
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
|
src = src.to(dtype=pos_src.dtype)
|
|
tokens = tokens.to(dtype=pos_src.dtype)
|
|
b, c, h, w = src.shape
|
|
|
|
# Run the transformer
|
|
hs, src = self.sam_mask_decoder.transformer(src, pos_src, tokens)
|
|
iou_token_out = hs[:, 0, :]
|
|
mask_tokens_out = hs[:, 1 : (1 + self.class_num), :]
|
|
|
|
# Decode tokens, mask tokens -> iou preds, src tokens (input image tokens) to masks
|
|
# Upscale mask embeddings and predict masks using the mask tokens
|
|
src = src.transpose(1, 2).view(b, c, h, w)
|
|
upscaled_embedding_sam = self.sam_mask_decoder.output_upscaling(src)
|
|
|
|
hyper_in = self.output_hypernetworks_mlp(rearrange(mask_tokens_out, 'b c d -> (b c) d'))
|
|
hyper_in = rearrange(hyper_in, '(b c) d -> b c d', b=b)
|
|
# for i in range(self.class_num):
|
|
# hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, mask_scale, :]))
|
|
if self.is_hq:
|
|
# hyper_hq_list.append(self.hf_mlps[i](hs[:, self.hf_token_idx, :]))
|
|
hyper_hq = self.hf_mlp(rearrange(hs[:, 1 + self.class_num: (1 + 2 * self.class_num), :], 'b c d -> (b c) d'))
|
|
hyper_hq = rearrange(hyper_hq, '(b c) d -> b c d', b=b)
|
|
|
|
# hyper_in = torch.stack(hyper_in_list, dim=1)
|
|
b, c, h, w = upscaled_embedding_sam.shape
|
|
masks_sam = (hyper_in @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
|
|
iou_pred = self.sam_mask_decoder.iou_prediction_head(iou_token_out)
|
|
|
|
if self.is_hq:
|
|
# hyper_hq = torch.stack(hyper_hq_list, dim=1)
|
|
upscaled_embedding_hq = self.sam_mask_decoder.embedding_maskfeature(upscaled_embedding_sam) + hq_features
|
|
masks_hq = (hyper_hq @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
|
|
if hq_token_only:
|
|
masks = masks_hq
|
|
else:
|
|
masks = masks_sam + masks_hq
|
|
|
|
else:
|
|
masks = masks_sam
|
|
|
|
# Generate mask quality predictions
|
|
# Prepare output
|
|
return masks, iou_pred
|
|
|
|
|
|
class SemanticSam(BaseExtendSam):
|
|
|
|
def __init__(self,
|
|
class_num,
|
|
sam: Sam = None,
|
|
fix_img_en=False,
|
|
fix_prompt_en=False,
|
|
fix_mask_de=False,
|
|
model_type: str = 'h_hq',
|
|
mask_decoder='mlp_variant',
|
|
**kwargs):
|
|
|
|
init_from_sam = sam is not None
|
|
if sam is None:
|
|
build_sam = sam_model_registry[model_type]['build']
|
|
sam = build_sam()
|
|
|
|
super().__init__(sam=sam, fix_img_en=fix_img_en, fix_mask_de=fix_mask_de, fix_prompt_en=fix_prompt_en)
|
|
sam_mask_decoder = self.mask_adapter.sam_mask_decoder
|
|
del self.mask_adapter
|
|
if mask_decoder == 'mlp_variant':
|
|
self.mask_adapter = SemMaskDecoderAdapter(sam_mask_decoder=sam_mask_decoder, fix=fix_mask_de, class_num=class_num, init_from_sam=init_from_sam)
|
|
elif mask_decoder == 'token_variant':
|
|
self.mask_adapter = SemMaskDecoderAdapterTokenVariant(sam_mask_decoder=sam_mask_decoder, fix=fix_mask_de, class_num=class_num, init_from_sam=init_from_sam)
|
|
else:
|
|
raise Exception(f'Invalid mask decoder: {mask_decoder}') |