# Mirror of https://github.com/shitagaki-lab/see-through.git
# (synced 2026-05-05 19:58:57 +00:00; Python, 72 lines, 1.9 KiB, no EOL at end of file)
from functools import partial
|
|
|
|
import torch
|
|
|
|
from .modeling import ImageEncoderViT, PromptEncoder
|
|
from .modeling.tiny_vit_sam import TinyViT
|
|
|
|
# Constructor keyword arguments for each supported ViT image-encoder variant.
# build_image_encoder() unpacks the selected entry with ** into
# ImageEncoderViT(...), so every key here is passed as a keyword argument of
# that constructor.
model_type_registry = {
    "vit_l": {
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "global_attn_indexes": [5, 11, 17, 23],
    },
    "vit_h": {
        "embed_dim": 1280,
        "depth": 32,
        "num_heads": 16,
        "global_attn_indexes": [7, 15, 23, 31],
    },
    "vit_b": {
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "global_attn_indexes": [2, 5, 8, 11],
    },
}
|
|
|
|
def build_image_encoder(model_type: str):
    """Construct the image encoder for the requested model variant.

    Args:
        model_type: ``'vit_t'`` selects a ``TinyViT`` encoder with a fixed
            configuration. Any other value must be a key of
            ``model_type_registry`` and selects an ``ImageEncoderViT`` whose
            variant-specific arguments come from that registry entry.

    Returns:
        The newly constructed ``TinyViT`` or ``ImageEncoderViT`` instance.

    Raises:
        AssertionError: If ``model_type`` is neither ``'vit_t'`` nor a key of
            ``model_type_registry`` (when assertions are enabled; with
            ``python -O`` the registry lookup raises ``KeyError`` instead).
    """
    if model_type == 'vit_t':
        # TinyViT is configured entirely inline; it has no registry entry.
        return TinyViT(
            img_size=1024,
            in_chans=3,
            num_classes=1000,
            embed_dims=[64, 128, 160, 320],
            depths=[2, 2, 6, 2],
            num_heads=[2, 4, 5, 10],
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.,
            drop_rate=0.,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        )

    assert model_type in model_type_registry, (
        f"unknown model_type {model_type!r}; expected 'vit_t' or one of "
        f"{sorted(model_type_registry)}"
    )
    # Shared settings are spelled out here; the per-variant ones (embed_dim,
    # depth, num_heads, global_attn_indexes) are unpacked from the registry.
    return ImageEncoderViT(
        img_size=1024,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        patch_size=16,
        qkv_bias=True,
        use_rel_pos=True,
        window_size=14,
        out_chans=256,
        **model_type_registry[model_type],
    )
|
|
|
|
def build_prompt_encoder(image_size: int = 1024, vit_patch_size: int = 16):
    """Construct a ``PromptEncoder`` sized for a square input image.

    Args:
        image_size: Side length passed (as a square) for ``input_image_size``.
        vit_patch_size: Divisor applied to ``image_size`` (floor division) to
            obtain the side length passed for ``image_embedding_size``.

    Returns:
        The newly constructed ``PromptEncoder`` with ``embed_dim=256`` and
        ``mask_in_chans=16``.
    """
    # Floor division: one embedding cell per full patch along each side.
    embedding_side = image_size // vit_patch_size
    return PromptEncoder(
        embed_dim=256,
        image_embedding_size=(embedding_side, embedding_side),
        input_image_size=(image_size, image_size),
        mask_in_chans=16,
    )
|
|
|
|
|