Mirror of https://github.com/shitagaki-lab/see-through.git, synced 2026-05-05 19:58:57 +00:00.
Python, 23 lines (600 B).
import logging
|
|
|
|
import torch
|
|
|
|
|
|
def get_device_type() -> str:
    """Return the name of the preferred torch device type.

    Apple's Metal backend ("mps") is checked first, then CUDA ("cuda").
    If neither is available, a warning is logged and "cpu" is returned.

    Returns:
        One of "mps", "cuda" or "cpu", suitable for ``torch.device(...)``.
    """
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    # Use a named logger instead of the module-level ``logging.warning``
    # helper. That helper calls ``logging.basicConfig()`` when the root logger
    # has no handlers, which would silently configure logging for the whole
    # application as a side effect of device probing. With no handlers
    # configured, logging's last-resort handler still prints this warning.
    logging.getLogger(__name__).warning("No GPU found, using CPU instead")
    return "cpu"
|
|
|
|
|
|
# Device selected once at import time; other modules can use DEVICE directly.
device_type = get_device_type()
DEVICE = torch.device(device_type)

# Handle to the process-wide autocast context entered below (None when the
# selected device is not CUDA). It is kept so the context can be exited
# explicitly if ever needed, instead of being discarded after __enter__().
_CUDA_AUTOCAST = None

# Gate CUDA-specific setup on the device actually selected above, rather than
# re-probing CUDA availability independently of DEVICE.
if DEVICE.type == "cuda":
    # Enable bfloat16 autocast for everything that runs after this import.
    # NOTE: PyTorch autocast state is thread-local, so this only affects the
    # thread that imports this module; other threads must enter their own
    # autocast context.
    _CUDA_AUTOCAST = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    _CUDA_AUTOCAST.__enter__()

    # TF32 math is enabled only on compute capability >= 8 (Ampere and
    # newer). get_device_capability() queries the current CUDA device rather
    # than a hard-coded index 0.
    if torch.cuda.get_device_capability()[0] >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True