see-through/annotators/lang_sam/models/utils.py

23 lines
600 B
Python

import logging
import torch
def get_device_type() -> str:
    """Return the name of the preferred available torch device type.

    Apple's Metal backend ("mps") is preferred, then CUDA ("cuda"). When
    neither is available a warning is logged and "cpu" is returned.

    Returns:
        One of "mps", "cuda" or "cpu", suitable for ``torch.device(...)``.
    """
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    # Log through a named logger rather than the module-level
    # ``logging.warning`` helper. That helper calls ``logging.basicConfig()``
    # when the root logger has no handlers, which (since this runs at import
    # time) would silently configure the host application's logging.
    logging.getLogger(__name__).warning("No GPU found, using CPU instead")
    return "cpu"
# Resolve the compute device once, at import time.
device_type = get_device_type()
DEVICE = torch.device(device_type)

# CUDA-only global settings. Gate on the device actually selected above so
# these settings always agree with DEVICE.
if device_type == "cuda":
    if torch.cuda.is_bf16_supported():
        # Deliberately enter bfloat16 autocast with no matching __exit__, so
        # the CUDA ops that autocast covers run in bf16 from here on.
        # NOTE(review): PyTorch documents autocast state as thread-local, so
        # threads other than the importing one are presumably unaffected —
        # confirm if inference runs in worker threads.
        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    else:
        # torch.autocast raises RuntimeError when asked for bfloat16 on a
        # CUDA device that does not support it. Fall back to full precision
        # instead of failing the whole import.
        logging.getLogger(__name__).warning(
            "CUDA device does not support bfloat16; autocast disabled"
        )
    # Enable TF32 matmul/convolution only on compute capability 8.x or newer
    # (TF32 tensor cores first appeared with Ampere).
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True