mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
These casts should only happen if these are supported (#7644)
This commit is contained in:
parent
a88a15c7e8
commit
9c63c3d8ab
1 changed files with 8 additions and 6 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm
|
||||
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from typing import Optional, Union, List, Any, Tuple
|
||||
import math
|
||||
|
||||
|
|
@ -9,7 +9,8 @@ def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
|
|||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
|
||||
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)
|
||||
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
|
||||
return out.cast(dtypes.float16) if is_dtype_supported(dtypes.float16) else out
|
||||
|
||||
class ResBlock:
|
||||
def __init__(self, channels:int, emb_channels:int, out_channels:int):
|
||||
|
|
@ -222,16 +223,17 @@ class UNetModel:
|
|||
]
|
||||
|
||||
def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
|
||||
t_emb = timestep_embedding(tms, self.model_ch).cast(dtypes.float16)
|
||||
t_emb = timestep_embedding(tms, self.model_ch)
|
||||
emb = t_emb.sequential(self.time_embed)
|
||||
|
||||
if y is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + y.sequential(self.label_emb[0])
|
||||
|
||||
emb = emb.cast(dtypes.float16)
|
||||
ctx = ctx.cast(dtypes.float16)
|
||||
x = x .cast(dtypes.float16)
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
emb = emb.cast(dtypes.float16)
|
||||
ctx = ctx.cast(dtypes.float16)
|
||||
x = x .cast(dtypes.float16)
|
||||
|
||||
def run(x:Tensor, bb) -> Tensor:
|
||||
if isinstance(bb, ResBlock): x = bb(x, emb)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue