mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
initial support for sharding
This commit is contained in:
parent
b36301f382
commit
e3670813b8
4 changed files with 50 additions and 7 deletions
|
|
@ -1,6 +1,13 @@
|
|||
from examples.mlperf.metrics import dice_score
|
||||
|
||||
def cross_entropy_loss(x, y, reduction='mean', label_smoothing=0.0):
  """Cross-entropy between logits `x` and per-class target weights `y`.

  The softmax and the class sum both run over axis 1, and `y.shape[1]` is used
  as the number of classes, so `y` has to be laid out like `x` (one-hot or soft
  labels along axis 1), not as integer class indices.

  Args:
    x: logits with the classes on axis 1.
    y: target weights, multiplied elementwise with `log_softmax(x)`.
    reduction: "none" returns the loss with axis 1 summed out, "sum" adds all
      of those values up, "mean" averages them.
    label_smoothing: weight moved from `y` onto a uniform distribution over
      the `y.shape[1]` classes; 0.0 leaves `y` unchanged.

  Returns:
    The loss tensor, reduced as requested.

  Raises:
    ValueError: if `reduction` is not one of "none", "sum" or "mean".
  """
  # Reject typos up front instead of silently falling back to the mean.
  if reduction not in ("none", "sum", "mean"):
    raise ValueError(f"unsupported reduction {reduction!r}, expected 'none', 'sum' or 'mean'")
  num_classes = y.shape[1]
  y = (1 - label_smoothing)*y + label_smoothing / num_classes
  # Built once and shared by all three reduction modes.
  loss = -x.log_softmax(axis=1).mul(y).sum(axis=1)
  if reduction == "none": return loss
  if reduction == "sum": return loss.sum()
  return loss.mean()
|
||||
|
||||
def dice_ce_loss(pred, tgt):
  """Return the average of the Dice loss and the cross-entropy loss.

  Args:
    pred: model output, passed unchanged to both loss terms.
    tgt: target, passed unchanged to both loss terms.

  Returns:
    `(dice + ce) / 2`, where `dice` is the mean of `1 - dice_score(...)`.
  """
  # NOTE(review): cross_entropy_loss reads tgt.shape[1] as the class count and
  # multiplies tgt with log_softmax(pred) over axis 1, so tgt presumably has to
  # be one-hot along axis 1 here - TODO confirm against the data loader.
  ce = cross_entropy_loss(pred, tgt)
  dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
  return (dice + ce) / 2
|
||||
|
|
|
|||
|
|
@ -30,8 +30,27 @@ def word_error_rate(x, y):
|
|||
return float(scores) / words, float(scores), words
|
||||
|
||||
def one_hot(arr, num_classes=3, channel_axis=1):
  """One-hot encode a label tensor and move the class axis to position 1.

  Args:
    arr: integer label tensor. If it has 5 or more dims, `channel_axis` is
      squeezed out first; the final permute needs 4 dims at that point.
    num_classes: size of the one-hot axis (rows of the identity matrix).
    channel_axis: axis removed by the squeeze for inputs with 5 or more dims.

  Returns:
    Float tensor with the 4 label dims plus a class axis of `num_classes` at
    position 1.
  """
  from tinygrad.helpers import getenv

  if len(arr.shape) >= 5: arr = arr.squeeze(dim=channel_axis)
  res = Tensor.eye(num_classes)
  labels = arr
  if (gpus:=getenv("GPUS")) > 1:
    from tinygrad import Device
    devices = [f"GPU:{i}" for i in range(gpus)]
    # Gather the labels on the default device and flatten them, then place both
    # the index and the lookup table on the same GPU set before indexing.
    # NOTE(review): shard_ is called without an axis - presumably that
    # replicates rather than splits; confirm against the tinygrad API.
    labels = labels.to(Device.DEFAULT).reshape(-1)
    labels.shard_(devices)
    res.shard_(devices)
  # Row i of the identity matrix is the one-hot vector for class i.
  res = res[labels]
  arr = res.reshape(list(arr.shape) + [num_classes])
  # The class axis was appended last; move it directly after the batch axis.
  arr = arr.permute((0, 4, 1, 2, 3)).cast(dtypes.float)
  return arr
|
||||
|
|
@ -48,6 +67,12 @@ def dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-
|
|||
target_sum = target.sum(axis=reduce_axis)
|
||||
prediction_sum = prediction.sum(axis=reduce_axis)
|
||||
result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
|
||||
from tinygrad.helpers import getenv
|
||||
if (gpus:=getenv("GPUS")) > 1:
|
||||
from tinygrad import Device
|
||||
result = result.to(Device.DEFAULT)
|
||||
result = result[0]
|
||||
result.shard_([f"GPU:{i}" for i in range(gpus)])
|
||||
return result[0]
|
||||
|
||||
def normalize_string(s):
|
||||
|
|
|
|||
|
|
@ -13,25 +13,29 @@ def train_unet3d():
|
|||
from examples.mlperf.losses import dice_ce_loss
|
||||
from extra.models.unet3d import UNet3D
|
||||
from extra.datasets.kits19 import iterate, get_train_files
|
||||
from tinygrad import dtypes
|
||||
from tinygrad import TinyJit
|
||||
from tinygrad import dtypes, TinyJit
|
||||
import tinygrad.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
epochs = getenv("NUM_EPOCHS", 4000)
|
||||
bs = getenv("BS", 2)
|
||||
lr = getenv("LR", 0.8)
|
||||
momentum = getenv("MOMENTUM", 0.9)
|
||||
lr_warmup_epochs = getenv("LR_WARMUP_EPOCHS", 200)
|
||||
lr_warmup_init_lr = getenv("LR_WARMUP_INIT_LR", 0.0001)
|
||||
|
||||
model = UNet3D()
|
||||
optim = nn.optim.SGD(nn.state.get_parameters(model), lr=1.0, momentum=0.9, nesterov=True)
|
||||
params = nn.state.get_parameters(model)
|
||||
if (gpus:=getenv("GPUS")) > 1:
|
||||
for p in params: p.shard_([f"GPU:{i}" for i in range(gpus)]).realize()
|
||||
optim = nn.optim.SGD(params, lr=lr, momentum=momentum, nesterov=True)
|
||||
|
||||
def _lr_warm_up(optim, init_lr, lr, current_epoch, warmup_epochs):
|
||||
scale = current_epoch / warmup_epochs
|
||||
optim.lr.assign(Tensor([init_lr + (lr - init_lr) * scale]))
|
||||
|
||||
@TinyJit
|
||||
# TODO: enable jit when it is supported with multitensor
|
||||
# @TinyJit
|
||||
def _train_step(x, y):
|
||||
y_hat = model(x)
|
||||
loss = dice_ce_loss(y_hat, y)
|
||||
|
|
@ -41,13 +45,16 @@ def train_unet3d():
|
|||
optim.step()
|
||||
|
||||
return loss.realize()
|
||||
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
if epoch <= lr_warmup_epochs and lr_warmup_epochs > 0:
|
||||
_lr_warm_up(optim, lr_warmup_init_lr, lr, epoch, lr_warmup_epochs)
|
||||
|
||||
for x, y in (t:=tqdm(iterate(val=False, shuffle=True, bs=bs), desc=f"[Epoch {epoch}]", total=len(get_train_files()))):
|
||||
x, y = Tensor(x, dtype=dtypes.float32), Tensor(y, dtype=dtypes.uint8)
|
||||
if (gpus:=getenv("GPUS")) > 1:
|
||||
x.shard_([f"GPU:{i}" for i in range(gpus)], axis=0)
|
||||
y.shard_([f"GPU:{i}" for i in range(gpus)], axis=0)
|
||||
loss = _train_step(x, y)
|
||||
t.set_description(f"[Epoch {epoch}][Loss: {loss.item():.3f}]")
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,10 @@ class UNet3D:
|
|||
assert obj.shape == v.shape, (k, obj.shape, v.shape)
|
||||
obj.assign(v.numpy())
|
||||
|
||||
def shard(self, devices=("GPU:0", "GPU:1")):
  """Shard every parameter of this model across `devices` and realize it.

  Args:
    devices: iterable of device names handed to `shard_` for each parameter.
      Defaults to the two-GPU tuple that used to be hard-coded here.
  """
  # Materialize once so a generator argument can be reused for every parameter.
  devices = tuple(devices)
  for p in nn.state.get_parameters(self):
    p.shard_(devices).realize()
|
||||
|
||||
if __name__ == "__main__":
|
||||
mdl = UNet3D()
|
||||
mdl.load_from_pretrained()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue