initial support for sharding

This commit is contained in:
Francis Lata 2024-01-14 23:05:19 +00:00
commit e3670813b8
4 changed files with 50 additions and 7 deletions

View file

@ -1,6 +1,13 @@
from examples.mlperf.metrics import dice_score
def cross_entropy_loss(x, y, reduction='mean', label_smoothing=0.0):
  """Cross-entropy between logits `x` and per-class target weights `y`.

  Axis 1 is the class axis: `x` is log-softmaxed along it, the weighted
  log-probabilities are summed over it, and `y.shape[1]` is the divisor used
  to spread the label-smoothing mass.

  Args:
    x: logits tensor.
    y: target weights, multiplied elementwise with the log-probabilities.
    reduction: "none" keeps one loss per remaining position, "sum" adds them
      up, "mean" averages them.
    label_smoothing: fraction of the target mass replaced by a uniform share
      of `label_smoothing / y.shape[1]` per class.

  Returns:
    The loss tensor, reduced according to `reduction`.

  Raises:
    ValueError: if `reduction` is not "none", "sum" or "mean".
  """
  # Reject typos such as "avg" instead of silently falling back to the mean.
  if reduction not in ("none", "sum", "mean"):
    raise ValueError(f"unsupported reduction {reduction!r}, expected 'none', 'sum' or 'mean'")
  num_classes = y.shape[1]
  smoothed = (1 - label_smoothing)*y + label_smoothing / num_classes
  # Computed once and shared by every reduction mode.
  per_position = -x.log_softmax(axis=1).mul(smoothed).sum(axis=1)
  if reduction == "none": return per_position
  if reduction == "sum": return per_position.sum()
  return per_position.mean()
def dice_ce_loss(pred, tgt):
  """Return the average of the Dice loss and the cross-entropy loss.

  Args:
    pred: network output, handed unchanged to both `cross_entropy_loss` and
      `dice_score` (the latter with `argmax=False, to_one_hot_x=False`).
    tgt: target tensor, handed unchanged to both helpers.
      NOTE(review): `cross_entropy_loss` multiplies `tgt` elementwise against
      log-probabilities along axis 1 - confirm the expected label encoding.

  Returns:
    `(dice + ce) / 2`, where `dice` is the mean of `1 - dice_score(...)`.
  """
  ce = cross_entropy_loss(pred, tgt)
  dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
  return (dice + ce) / 2

View file

@ -30,8 +30,27 @@ def word_error_rate(x, y):
return float(scores) / words, float(scores), words
# one_hot: turn the class indices in `arr` into a one-hot float tensor.
# Rows of an identity matrix (`Tensor.eye(num_classes)`) are selected with
# `arr`, the result is reshaped to `arr.shape + [num_classes]`, and the new
# trailing class axis is permuted to position 1 before casting to float.
# The 5-element permute means `arr` must be rank 4 at that point; inputs of
# rank >= 5 have `channel_axis` squeezed first.
# NOTE(review): indentation is not visible in this view, so the exact extent
# of each `if gpus > 1` body below should be confirmed against the real file.
def one_hot(arr, num_classes=3, channel_axis=1):
# Helper: when the GPUS env value is > 1, move `x` to Device.DEFAULT, flatten
# it with reshape(-1) and return the result of shard_ over "GPU:0".."GPU:{gpus-1}";
# otherwise return `x` unchanged.
def _unshard_reshape_shard(x):
from tinygrad.helpers import getenv
if (gpus:=getenv("GPUS")) > 1:
from tinygrad import Device
x = x.to(Device.DEFAULT)
x = x.reshape(-1)
return x.shard_([f"GPU:{i}" for i in range(gpus)])
return x
# Drop the channel axis when the input still carries it (rank >= 5).
if len(arr.shape) >= 5: arr = arr.squeeze(dim=channel_axis)
# NOTE(review): this `res` is reassigned two lines below without being read.
res = Tensor.eye(num_classes)[arr.reshape(-1)]
# NOTE(review): this `arr_reshape` is reassigned by `arr_reshape = arr` below
# without being read, so the helper call currently has no visible effect on the result.
arr_reshape = _unshard_reshape_shard(arr)
# Identity matrix whose row i is the one-hot vector for class i.
res = Tensor.eye(num_classes)
from tinygrad.helpers import getenv
arr_reshape = arr
# Multi-GPU path (GPUS > 1): the index tensor is moved to Device.DEFAULT and
# flattened, then both it and the identity matrix have shard_ called with the
# same "GPU:{i}" device list. Presumably all five statements after the `if`
# belong to its body - TODO confirm.
if (gpus:=getenv("GPUS")) > 1:
from tinygrad import Device
arr_reshape = arr_reshape.to(Device.DEFAULT)
arr_reshape = arr_reshape.reshape(-1)
arr_reshape.shard_([f"GPU:{i}" for i in range(gpus)])
res.shard_([f"GPU:{i}" for i in range(gpus)])
# Row lookup: one identity row per index in `arr_reshape`.
res = res[arr_reshape]
# Restore the original index shape with a trailing class axis.
arr = res.reshape(list(arr.shape) + [num_classes])
# Move the class axis from last position to axis 1 and cast to float.
arr = arr.permute((0, 4, 1, 2, 3)).cast(dtypes.float)
return arr
@ -48,6 +67,12 @@ def dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-
target_sum = target.sum(axis=reduce_axis)
prediction_sum = prediction.sum(axis=reduce_axis)
result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
from tinygrad.helpers import getenv
if (gpus:=getenv("GPUS")) > 1:
from tinygrad import Device
result = result.to(Device.DEFAULT)
result = result[0]
result.shard_([f"GPU:{i}" for i in range(gpus)])
return result[0]
def normalize_string(s):

View file

@ -13,25 +13,29 @@ def train_unet3d():
from examples.mlperf.losses import dice_ce_loss
from extra.models.unet3d import UNet3D
from extra.datasets.kits19 import iterate, get_train_files
from tinygrad import dtypes
from tinygrad import TinyJit
from tinygrad import dtypes, TinyJit
import tinygrad.nn as nn
from tqdm import tqdm
epochs = getenv("NUM_EPOCHS", 4000)
bs = getenv("BS", 2)
lr = getenv("LR", 0.8)
momentum = getenv("MOMENTUM", 0.9)
lr_warmup_epochs = getenv("LR_WARMUP_EPOCHS", 200)
lr_warmup_init_lr = getenv("LR_WARMUP_INIT_LR", 0.0001)
model = UNet3D()
optim = nn.optim.SGD(nn.state.get_parameters(model), lr=1.0, momentum=0.9, nesterov=True)
params = nn.state.get_parameters(model)
if (gpus:=getenv("GPUS")) > 1:
for p in params: p.shard_([f"GPU:{i}" for i in range(gpus)]).realize()
optim = nn.optim.SGD(params, lr=lr, momentum=momentum, nesterov=True)
def _lr_warm_up(optim, init_lr, lr, current_epoch, warmup_epochs):
  """Set `optim.lr` by linear interpolation from `init_lr` towards `lr`.

  The interpolation factor is `current_epoch / warmup_epochs`, so the
  assigned value equals `lr` when `current_epoch == warmup_epochs`.
  """
  warmup_fraction = current_epoch / warmup_epochs
  warmed_lr = init_lr + (lr - init_lr) * warmup_fraction
  # The optimizer keeps its learning rate as a tensor; overwrite it in place.
  optim.lr.assign(Tensor([warmed_lr]))
@TinyJit
# TODO: enable jit when it is supported with multitensor
# @TinyJit
def _train_step(x, y):
y_hat = model(x)
loss = dice_ce_loss(y_hat, y)
@ -41,13 +45,16 @@ def train_unet3d():
optim.step()
return loss.realize()
for epoch in range(1, epochs + 1):
if epoch <= lr_warmup_epochs and lr_warmup_epochs > 0:
_lr_warm_up(optim, lr_warmup_init_lr, lr, epoch, lr_warmup_epochs)
for x, y in (t:=tqdm(iterate(val=False, shuffle=True, bs=bs), desc=f"[Epoch {epoch}]", total=len(get_train_files()))):
x, y = Tensor(x, dtype=dtypes.float32), Tensor(y, dtype=dtypes.uint8)
if (gpus:=getenv("GPUS")) > 1:
x.shard_([f"GPU:{i}" for i in range(gpus)], axis=0)
y.shard_([f"GPU:{i}" for i in range(gpus)], axis=0)
loss = _train_step(x, y)
t.set_description(f"[Epoch {epoch}][Loss: {loss.item():.3f}]")

View file

@ -54,6 +54,10 @@ class UNet3D:
assert obj.shape == v.shape, (k, obj.shape, v.shape)
obj.assign(v.numpy())
def shard(self, devices=("GPU:0", "GPU:1")):
  """Shard every parameter of this model across `devices` and realize it.

  Args:
    devices: iterable of device names passed to each parameter's `shard_`.
      Defaults to the two-GPU tuple that was previously hard-coded, so
      existing `model.shard()` calls behave as before. No `axis` argument
      is passed to `shard_`.
  """
  # Materialise once so a generator argument is not exhausted by the first parameter.
  devices = tuple(devices)
  for p in nn.state.get_parameters(self):
    p.shard_(devices).realize()
if __name__ == "__main__":
mdl = UNet3D()
mdl.load_from_pretrained()