Only cast to float16 when that dtype is supported (#7644)

This commit is contained in:
Ahmed Harmouche 2024-11-12 00:56:50 +01:00 committed by GitHub
commit 9c63c3d8ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,6 +1,6 @@
from tinygrad import Tensor, dtypes
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm
from tinygrad.device import is_dtype_supported
from typing import Optional, Union, List, Any, Tuple
import math
@ -9,7 +9,8 @@ def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
half = dim // 2
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
return out.cast(dtypes.float16) if is_dtype_supported(dtypes.float16) else out
class ResBlock:
def __init__(self, channels:int, emb_channels:int, out_channels:int):
@ -222,16 +223,17 @@ class UNetModel:
]
def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
t_emb = timestep_embedding(tms, self.model_ch).cast(dtypes.float16)
t_emb = timestep_embedding(tms, self.model_ch)
emb = t_emb.sequential(self.time_embed)
if y is not None:
assert y.shape[0] == x.shape[0]
emb = emb + y.sequential(self.label_emb[0])
emb = emb.cast(dtypes.float16)
ctx = ctx.cast(dtypes.float16)
x = x .cast(dtypes.float16)
if is_dtype_supported(dtypes.float16):
emb = emb.cast(dtypes.float16)
ctx = ctx.cast(dtypes.float16)
x = x .cast(dtypes.float16)
def run(x:Tensor, bb) -> Tensor:
if isinstance(bb, ResBlock): x = bb(x, emb)