32 lines
878 B
Python
32 lines
878 B
Python
"""
|
|
From Diffuser https://github.com/jannerm/diffuser
|
|
|
|
"""
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
    """
    Build the cosine beta schedule proposed in
    https://openreview.net/forum?id=-NEXDKk8gZ

    Args:
        timesteps: number of diffusion steps; the returned tensor has this
            many entries.
        s: small offset added to the normalized time before the cosine is
            evaluated.
        dtype: torch dtype of the returned tensor.

    Returns:
        A 1-D ``torch.Tensor`` of length ``timesteps`` holding the betas,
        each clipped to the range [0, 0.999].
    """
    # One extra grid point is needed because each beta is derived from a
    # pair of consecutive cumulative alphas.
    steps = timesteps + 1
    grid = np.linspace(0, steps, steps)

    # Squared-cosine curve evaluated on the normalized grid (grid / steps
    # runs from 0 to 1), then rescaled so the first entry equals 1.
    phase = ((grid / steps) + s) / (1 + s) * np.pi * 0.5
    alphas_cumprod = np.cos(phase) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]

    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])

    # The last ratio is ~0 (the cosine reaches pi/2), so cap beta below 1.
    betas_clipped = np.clip(betas, 0, 0.999)
    return torch.tensor(betas_clipped, dtype=dtype)
|
|
|
|
|
|
def extract(a, t, x_shape):
    """
    Gather the entries of ``a`` selected by ``t`` and reshape the result so
    it has the same number of dimensions as ``x_shape``.

    Args:
        a: tensor to index along its last dimension.
        t: integer index tensor; its first dimension is used as the batch
            size. Presumably 1-D (``torch.gather`` needs ``a`` and ``t`` to
            have the same rank) -- confirm against callers.
        x_shape: shape of the target tensor; only its length is used.

    Returns:
        Tensor of shape ``(b, 1, ..., 1)`` with ``len(x_shape) - 1``
        trailing singleton dimensions, where ``b = t.shape[0]``.
    """
    batch = t.shape[0]
    picked = a.gather(-1, t)
    # Pad with singleton dims so the result lines up with a tensor of
    # rank len(x_shape) whose leading dimension is the batch.
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing_ones)
|
|
|
|
|
|
def make_timesteps(batch_size, i, device):
    """
    Create a batch of identical timestep indices.

    Args:
        batch_size: number of entries in the returned tensor.
        i: value every entry is filled with.
        device: torch device on which the tensor is allocated.

    Returns:
        A ``torch.long`` tensor of shape ``(batch_size,)`` filled with ``i``.
    """
    return torch.full((batch_size,), i, device=device, dtype=torch.long)
|