2022-06-26 16:39:37 +02:00
|
|
|
import torch as th
|
|
|
|
|
|
|
|
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
|
|
|
|
2022-07-13 19:38:20 +02:00
|
|
|
from ..distributions import UniversalGaussianDistribution, AnyDistribution
|
2022-07-09 14:45:35 +02:00
|
|
|
|
2022-06-26 16:39:37 +02:00
|
|
|
|
2022-07-13 19:38:20 +02:00
|
|
|
def get_mean_and_chol(p: AnyDistribution, expand=False):
    """Return the mean of ``p`` together with its scale factor.

    Args:
        p: A ``torch.distributions`` ``Normal``, ``Independent`` or
            ``MultivariateNormal`` instance, or an SB3 distribution whose
            ``distribution`` attribute holds one of those.
        expand: Only used for ``Normal`` / ``Independent``. When true, the
            standard-deviation vector is embedded into a diagonal matrix via
            ``th.diag_embed``; otherwise ``p.stddev`` is returned as-is.

    Returns:
        A ``(mean, scale)`` tuple. For ``MultivariateNormal`` the scale is
        ``p.scale_tril``.

    Raises:
        Exception: If the type of ``p`` is not handled.
    """
    diagonal_families = (th.distributions.Normal, th.distributions.Independent)

    if isinstance(p, diagonal_families):
        mean = p.mean
        std = p.stddev
        scale = th.diag_embed(std) if expand else std
        return mean, scale

    if isinstance(p, th.distributions.MultivariateNormal):
        return p.mean, p.scale_tril

    if isinstance(p, SB3_Distribution):
        # SB3 wrappers carry the torch distribution in ``.distribution``.
        wrapped = p.distribution
        return get_mean_and_chol(wrapped, expand=expand)

    raise Exception('Dist-Type not implemented')
|
|
|
|
|
|
|
|
|
2022-07-13 19:38:20 +02:00
|
|
|
def get_mean_and_sqrt(p: UniversalGaussianDistribution):
    """Placeholder for retrieving the mean and a square-root factor of ``p``.

    This function is not implemented yet; every call raises. The code that
    used to follow the ``raise`` was unreachable and only returned the
    standard deviation / Cholesky factor, so it has been removed rather than
    left behind as a misleading draft.

    Args:
        p: The distribution to inspect (currently unused).

    Raises:
        NotImplementedError: Always. ``NotImplementedError`` derives from
            ``Exception``, so callers catching ``Exception`` are unaffected.
    """
    raise NotImplementedError('Not yet implemented...')
|
|
|
|
|
|
|
|
|
2022-07-13 19:38:20 +02:00
|
|
|
def get_cov(p: AnyDistribution):
    """Return the covariance of ``p`` as a (batched) square matrix.

    Args:
        p: A ``torch.distributions`` ``Normal``, ``Independent`` or
            ``MultivariateNormal`` instance, or an SB3 distribution whose
            ``distribution`` attribute holds one of those.

    Returns:
        ``th.diag_embed(p.variance)`` for ``Normal`` / ``Independent`` and
        ``p.covariance_matrix`` for ``MultivariateNormal``.

    Raises:
        Exception: If the type of ``p`` is not handled.
    """
    if isinstance(p, (th.distributions.Normal, th.distributions.Independent)):
        # Per-dimension variances become the diagonal of the matrix.
        variances = p.variance
        return th.diag_embed(variances)

    if isinstance(p, th.distributions.MultivariateNormal):
        return p.covariance_matrix

    if isinstance(p, SB3_Distribution):
        # SB3 wrappers carry the torch distribution in ``.distribution``.
        return get_cov(p.distribution)

    raise Exception('Dist-Type not implemented')
|
|
|
|
|
|
|
|
|
2022-07-15 18:46:17 +02:00
|
|
|
def has_diag_cov(p: AnyDistribution, numerical_check=False):
    """Tell whether the covariance of ``p`` is diagonal.

    Args:
        p: A torch distribution, or an SB3 distribution wrapping one in its
            ``distribution`` attribute.
        numerical_check: When false, only the distribution type is
            considered and anything that is not ``Normal`` / ``Independent``
            is reported as non-diagonal. When true, the covariance matrix of
            such a distribution is inspected element-wise instead.

    Returns:
        ``True`` if the covariance is diagonal (by type, or by exact
        comparison of all off-diagonal entries with zero), else ``False``.
    """
    # Peel off SB3 wrappers until the underlying distribution is reached.
    while isinstance(p, SB3_Distribution):
        p = p.distribution

    diagonal_by_type = isinstance(p, th.distributions.Normal) \
        or isinstance(p, th.distributions.Independent)
    if diagonal_by_type:
        return True

    if not numerical_check:
        return False

    # Check if matrix is diag: removing the diagonal must leave only zeros.
    cov = get_cov(p)
    diag_only = th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1))
    off_diagonal = cov - diag_only
    return th.equal(off_diagonal, th.zeros_like(cov))
|
2022-06-29 12:44:13 +02:00
|
|
|
|
|
|
|
|
2022-07-13 19:38:20 +02:00
|
|
|
def is_contextual(p: AnyDistribution):
    """Report whether ``p`` is contextual; currently always ``False``.

    Args:
        p: The distribution to inspect (ignored for now).

    Returns:
        ``False`` for every input.
    """
    # TODO: Implement for UniversalGaussianDistribution
    contextual = False
    return contextual
|
|
|
|
|
|
|
|
|
2022-07-15 18:46:17 +02:00
|
|
|
def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=False):
    """Return the diagonal of the covariance of ``p`` as a vector.

    Args:
        p: The distribution whose covariance diagonal is wanted.
        check_diag: When true, first verify via ``has_diag_cov`` that the
            covariance is diagonal and refuse to continue otherwise.
        numerical_check: Forwarded to ``has_diag_cov``.

    Returns:
        The diagonal taken over the last two dimensions of ``get_cov(p)``.

    Raises:
        Exception: If ``check_diag`` is set and the covariance is not
            considered diagonal.
    """
    if check_diag:
        is_diagonal = has_diag_cov(p, numerical_check=numerical_check)
        if not is_diagonal:
            raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')

    cov = get_cov(p)
    return th.diagonal(cov, dim1=-2, dim2=-1)
|
|
|
|
|
|
|
|
|
2022-07-13 19:38:20 +02:00
|
|
|
def _rebuild_torch_dist(template, mean, chol):
    """Build a torch distribution of the same family as ``template``.

    Returns ``None`` when ``template`` is not a ``Normal``, ``Independent``
    or ``MultivariateNormal``, so the caller can pick its own error message.
    """
    is_independent = isinstance(template, th.distributions.Independent)
    if isinstance(template, th.distributions.Normal) or is_independent:
        if template.stddev.shape != chol.shape:
            # A full factor was passed for a diagonal distribution: keep only
            # its diagonal. Negative dims address the trailing matrix axes,
            # so this works for any number of leading batch dimensions.
            chol = th.diagonal(chol, dim1=-2, dim2=-1)
        normal = th.distributions.Normal(mean, chol)
        if is_independent:
            return th.distributions.Independent(normal, 1)
        return normal

    if isinstance(template, th.distributions.MultivariateNormal):
        return th.distributions.MultivariateNormal(mean, scale_tril=chol)

    return None


def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
    """Create a distribution of the same kind as ``orig_p`` with new parameters.

    Args:
        orig_p: Template distribution. Supported are
            ``UniversalGaussianDistribution`` (delegates to its
            ``new_dist_like_me``), torch ``Normal`` / ``Independent`` /
            ``MultivariateNormal``, and SB3 distributions wrapping one of the
            torch types in their ``distribution`` attribute.
        mean: Mean of the new distribution.
        chol: Scale of the new distribution. For ``Normal`` / ``Independent``
            templates this may be either a standard-deviation tensor shaped
            like the template's ``stddev`` or a full factor, whose diagonal
            is then used. This also holds when the template is wrapped in an
            SB3 distribution.

    Returns:
        A new distribution object of the same type as ``orig_p``.

    Raises:
        Exception: If ``orig_p`` (or the distribution wrapped by an SB3
            distribution) is of an unsupported type.
    """
    if isinstance(orig_p, UniversalGaussianDistribution):
        return orig_p.new_dist_like_me(mean, chol)

    rebuilt = _rebuild_torch_dist(orig_p, mean, chol)
    if rebuilt is not None:
        return rebuilt

    if isinstance(orig_p, SB3_Distribution):
        inner = _rebuild_torch_dist(orig_p.distribution, mean, chol)
        if inner is None:
            raise Exception('Dist-Type not implemented (of sb3 dist)')
        # Re-create the SB3 wrapper and attach the freshly built torch dist.
        p_out = orig_p.__class__(orig_p.action_dim)
        p_out.distribution = inner
        return p_out

    raise Exception('Dist-Type not implemented')
|