dppo/model/diffusion/exact_likelihood.py
2024-09-03 21:03:27 -04:00

189 lines
5.7 KiB
Python

"""
Solve the probability flow ODE to compute exact likelihoods; adapted from https://github.com/yang-song/score_sde_pytorch
"""
import torch
import numpy as np
from torchdiffeq import odeint
# The adjoint variant can reduce memory usage, but it is not faster.
# from torchdiffeq import odeint_adjoint as odeint
from model.diffusion.sde_lib import get_score_fn
def get_likelihood_fn(
    sde,
    hutchinson_type="Rademacher",
    rtol=1e-5,
    atol=1e-5,
    method="RK45",
    steps=10,  # should not matter, only t_eval
    step_size=1e-3,
    eps=1e-5,
    continuous=False,
    probability_flow=False,
    predict_epsilon=False,
    num_epsilon=1,
):
    """Create a function that estimates the log-likelihood of a batch of data.

    The returned function integrates the drift of ``sde.reverse(...)`` together
    with a Hutchinson-Skilling estimate of its divergence from ``eps`` to
    ``sde.T`` and adds the accumulated term to ``sde.prior_logp`` of the final
    state.

    Args:
        sde: Object exposing ``reverse(score_fn, probability_flow=...)``,
            ``prior_logp(z)`` and the terminal time ``T`` (presumably a
            `sde_lib.SDE`).
        hutchinson_type: "Rademacher" or "Gaussian". The type of noise for the
            Hutchinson-Skilling trace estimator. Any other value makes the
            returned function raise `NotImplementedError`.
        rtol: A `float` number. The relative tolerance passed to `odeint`.
        atol: A `float` number. The absolute tolerance passed to `odeint`.
        method: A `str`. Solver name passed to `odeint`.
            NOTE(review): "RK45" is the scipy spelling; torchdiffeq appears to
            name its solvers "dopri5", "rk4", "euler", ... -- confirm that
            callers always override this default.
        steps: Number of evaluation times between `eps` and `sde.T`; only the
            solution at the last one is used.
        step_size: Passed to `odeint` as ``options={"step_size": step_size}``.
        eps: A `float` number. Integration starts at `eps` instead of 0 for
            numerical stability.
        continuous: Forwarded to `get_score_fn`.
        probability_flow: Forwarded to `sde.reverse`.
            NOTE(review): the change-of-variables likelihood presumably needs
            the probability flow ODE, i.e. ``True`` -- confirm callers set it.
        predict_epsilon: Forwarded to `get_score_fn`.
        num_epsilon: Number of noise samples per data point that are averaged
            in the divergence estimate.

    Returns:
        A function ``likelihood_fn(model, model_ft, data, denoising_steps,
        ft_denoising_steps, cond, **kwargs)`` returning the log-likelihood of
        each data point in bits/dim.
    """

    def drift_fn(model, x, t, **kwargs):
        """The drift function of the reverse-time SDE."""
        score_fn = get_score_fn(
            sde,
            model,
            continuous=continuous,
            predict_epsilon=predict_epsilon,
        )
        # Probability flow ODE is a special case of Reverse SDE
        rsde = sde.reverse(score_fn, probability_flow=probability_flow)
        # Only the first element returned by rsde.sde is used as the drift.
        return rsde.sde(x, t, **kwargs)[0]

    def div_fn(model, x, t, noise, create_graph=False, **kwargs):
        """Hutchinson estimate ``noise^T (d drift / d x) noise`` per sample."""
        with torch.enable_grad():
            x.requires_grad_(True)
            fn_eps = torch.sum(drift_fn(model, x, t, **kwargs) * noise)
            grad_fn_eps = torch.autograd.grad(
                fn_eps,
                x,
                create_graph=create_graph,
            )[0]
        if not create_graph:
            x.requires_grad_(False)
        return torch.sum(grad_fn_eps * noise, dim=(1, 2))

    def likelihood_fn(
        model,
        model_ft,
        data,
        denoising_steps,
        ft_denoising_steps,
        cond,
        **kwargs,
    ):
        """Compute an unbiased estimate to the log-likelihood in bits/dim.

        Args:
            model: A score model, used where the rounded step index
                ``round(t * denoising_steps)`` exceeds `ft_denoising_steps`.
            model_ft: A score model, used where the rounded step index is at
                most `ft_denoising_steps`.
            data: A PyTorch tensor. B x horizon x transition_dim
            denoising_steps: Scale that maps the ODE time `t` to a step index.
            ft_denoising_steps: Largest step index handled by `model_ft`.
            cond: Conditioning passed to the drift as ``cond=``; must support
                ``repeat_interleave(num_epsilon, dim=0)``.
            **kwargs: Forwarded to the reverse SDE's ``sde`` call.

        Returns:
            logprob: B

        Raises:
            NotImplementedError: If `hutchinson_type` is not "Gaussian" or
                "Rademacher".
        """
        shape = data.shape
        B, H, A = shape
        device = data.device

        # sample epsilon
        noise_shape = (B * num_epsilon, H, A)
        if hutchinson_type == "Gaussian":
            epsilon = torch.randn(size=noise_shape, device=device)
        elif hutchinson_type == "Rademacher":
            epsilon = (
                torch.randint(low=0, high=2, size=noise_shape, device=device).float()
                * 2
                - 1.0
            )
        else:
            raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

        # repeat for expectation
        cond_eps = cond.repeat_interleave(num_epsilon, dim=0)

        def ode_func(t, x):
            """Time derivative of the augmented state [flattened x, logp]."""
            x = x[:, :-1]  # drop the trailing log-density column
            step = torch.round(t * denoising_steps)
            vec_t = torch.full(
                (x.shape[0],),
                step,
                device=x.device,
                dtype=int,
            )
            model_fn = model_ft if step <= ft_denoising_steps else model

            x = x.view(shape)  # B x horizon x transition_dim
            drift = drift_fn(
                model_fn,
                x,
                vec_t,
                cond=cond,
                **kwargs,
            ).reshape(B, -1)

            # repeat for expectation
            x = x.repeat_interleave(num_epsilon, dim=0)
            vec_t = vec_t.repeat_interleave(num_epsilon)

            # The divergence must be taken of the same drift that moves x, so
            # use the model selected for this step rather than always `model`.
            logp_grad = div_fn(
                model_fn,
                x,
                vec_t,
                epsilon,
                create_graph=True,
                cond=cond_eps,
                **kwargs,
            )[:, None].reshape(B, num_epsilon, -1)
            logp_grad = logp_grad.mean(dim=1)  # expectation over epsilon

            # Concatenate along the feature dimension
            return torch.cat([drift, logp_grad], dim=-1)

        # Flatten data and append one zero column that accumulates the
        # log-density change; reshape also accepts non-contiguous input.
        init = torch.hstack(
            (
                data.reshape(B, -1),
                torch.zeros((B, 1), dtype=data.dtype, device=device),
            )
        )
        t_eval = torch.linspace(eps, sde.T, steps=steps).to(device)  # eval points
        solution = odeint(
            ode_func,
            init,
            t_eval,
            method=method,
            rtol=rtol,
            atol=atol,
            options={"step_size": step_size},
        )  # steps x B x (H * A + 1)

        zp = solution[-1]  # B x (H * A + 1)
        z = zp[:, :-1].reshape(shape)
        delta_logp = zp[:, -1]
        prior_logp = sde.prior_logp(z)

        # Convert nats to bits (log 2) and normalise by the H * A dimensions.
        nats_to_bits_per_dim = float(np.log(2) * H * A)
        return (prior_logp + delta_logp) / nats_to_bits_per_dim

    return likelihood_fn