"""Helpers for inspecting and rebuilding Gaussian distributions.

Module path: metastable_projections/misc/distTools.py

The helpers accept plain ``torch.distributions`` objects (``Normal``,
``Independent``, ``MultivariateNormal``), Stable-Baselines3 distributions
(unwrapped through their ``.distribution`` attribute) and the project's
``UniversalGaussianDistribution``.
"""
import torch as th

from stable_baselines3.common.distributions import Distribution as SB3_Distribution

from ..distributions import UniversalGaussianDistribution, AnyDistribution

# torch distribution types this module treats as having a diagonal covariance.
_DIAG_TYPES = (th.distributions.Normal, th.distributions.Independent)


def get_mean_and_chol(p: AnyDistribution, expand=False):
    """Return ``(mean, chol)`` for the distribution ``p``.

    For ``Normal`` / ``Independent`` the second element is ``p.stddev``; it is
    wrapped into a diagonal matrix via ``th.diag_embed`` only when ``expand``
    is true. For ``MultivariateNormal`` it is ``p.scale_tril``. SB3
    distributions are unwrapped through ``p.distribution``.

    Raises:
        NotImplementedError: if ``p`` is none of the supported types.
    """
    if isinstance(p, _DIAG_TYPES):
        return p.mean, (th.diag_embed(p.stddev) if expand else p.stddev)
    if isinstance(p, th.distributions.MultivariateNormal):
        return p.mean, p.scale_tril
    if isinstance(p, SB3_Distribution):
        return get_mean_and_chol(p.distribution, expand=expand)
    raise NotImplementedError('Dist-Type not implemented')


def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False):
    """Return ``(mean, cov_sqrt)`` for a distribution carrying ``cov_sqrt``.

    ``cov_sqrt`` is read from the attribute of the same name on ``p``. If its
    leading dimension differs from the mean's, it is expanded along dim 0 to
    match. With ``expand=True`` a ``cov_sqrt`` of at most two dimensions is
    turned into diagonal matrices via ``th.diag_embed``.

    Raises:
        ValueError: if ``p`` has no ``cov_sqrt`` attribute.
    """
    if not hasattr(p, 'cov_sqrt'):
        raise ValueError(
            'Distribution was not induced from sqrt. On-demand calculation is not supported.')

    # Only the mean is needed here; the Cholesky factor is discarded.
    mean, _ = get_mean_and_chol(p, expand=False)
    sqrt_cov = p.cov_sqrt
    if mean.shape[0] != sqrt_cov.shape[0]:
        sqrt_cov = sqrt_cov.expand(mean.shape[0], *sqrt_cov.shape[1:])
    if expand and sqrt_cov.dim() <= 2:
        sqrt_cov = th.diag_embed(sqrt_cov)
    return mean, sqrt_cov


def get_cov(p: AnyDistribution):
    """Return the covariance of ``p`` as a matrix.

    ``Normal`` / ``Independent`` yield ``th.diag_embed(p.variance)``,
    ``MultivariateNormal`` yields ``p.covariance_matrix`` and SB3
    distributions are unwrapped through ``p.distribution``.

    Raises:
        NotImplementedError: if ``p`` is none of the supported types.
    """
    if isinstance(p, _DIAG_TYPES):
        return th.diag_embed(p.variance)
    if isinstance(p, th.distributions.MultivariateNormal):
        return p.covariance_matrix
    if isinstance(p, SB3_Distribution):
        return get_cov(p.distribution)
    raise NotImplementedError('Dist-Type not implemented')


def has_diag_cov(p: AnyDistribution, numerical_check=False):
    """Tell whether ``p`` has a diagonal covariance.

    ``Normal`` / ``Independent`` always count as diagonal. Any other
    distribution counts as diagonal only if ``numerical_check`` is true and
    every off-diagonal entry of ``get_cov(p)`` is exactly zero.
    """
    if isinstance(p, SB3_Distribution):
        return has_diag_cov(p.distribution, numerical_check=numerical_check)
    if isinstance(p, _DIAG_TYPES):
        return True
    if not numerical_check:
        return False
    # Subtract the diagonal part; a diagonal matrix leaves exact zeros behind.
    cov = get_cov(p)
    off_diag = cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1))
    return th.equal(off_diag, th.zeros_like(cov))


def is_contextual(p: AnyDistribution):
    """Return whether ``p`` is contextual; currently always ``False``."""
    # TODO: Implement for UniversalGaussianDistribution
    return False


def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=False):
    """Return the diagonal of the covariance of ``p`` as a vector.

    Raises:
        ValueError: if ``check_diag`` is true and ``has_diag_cov`` reports a
            non-diagonal covariance.
    """
    if check_diag and not has_diag_cov(p, numerical_check=numerical_check):
        raise ValueError('Cannot reduce cov-mat to diag-vec: Is not diagonal')
    return th.diagonal(get_cov(p), dim1=-2, dim2=-1)


def _new_torch_dist_like(p, mean: th.Tensor, chol: th.Tensor):
    """Build a torch distribution of the same type as ``p``; ``None`` if unsupported."""
    if isinstance(p, _DIAG_TYPES):
        if p.stddev.shape != chol.shape and chol.dim() > mean.dim():
            # chol arrived as full matrices; a factorised Normal only takes
            # the stddev vector, i.e. the diagonal of the last two dims.
            chol = th.diagonal(chol, dim1=-2, dim2=-1)
        normal = th.distributions.Normal(mean, chol)
        if isinstance(p, th.distributions.Independent):
            return th.distributions.Independent(normal, 1)
        return normal
    if isinstance(p, th.distributions.MultivariateNormal):
        return th.distributions.MultivariateNormal(mean, scale_tril=chol)
    return None


def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
    """Create a distribution of the same type as ``orig_p`` from ``mean`` and ``chol``.

    ``UniversalGaussianDistribution`` delegates to its own
    ``new_dist_like_me``. Torch distributions are rebuilt directly. SB3
    distributions are re-instantiated as ``orig_p.__class__(orig_p.action_dim)``
    and receive a rebuilt inner ``distribution``. For ``Normal`` /
    ``Independent`` targets (plain or wrapped by SB3) a ``chol`` given as full
    matrices is reduced to its diagonal.

    Raises:
        NotImplementedError: if ``orig_p`` (or the distribution wrapped by an
            SB3 distribution) is of an unsupported type.
    """
    if isinstance(orig_p, UniversalGaussianDistribution):
        return orig_p.new_dist_like_me(mean, chol)

    rebuilt = _new_torch_dist_like(orig_p, mean, chol)
    if rebuilt is not None:
        return rebuilt

    if isinstance(orig_p, SB3_Distribution):
        inner = _new_torch_dist_like(orig_p.distribution, mean, chol)
        if inner is None:
            raise NotImplementedError('Dist-Type not implemented (of sb3 dist)')
        p_out = orig_p.__class__(orig_p.action_dim)
        p_out.distribution = inner
        return p_out

    raise NotImplementedError('Dist-Type not implemented')


def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor):
    """Create a distribution like ``orig_p`` from a mean and a covariance square root.

    The Cholesky factor is derived with ``_sqrt_to_chol``. ``cov_sqrt`` is
    stored on the new object, and on its ``distribution`` attribute if it has
    one, so that ``get_mean_and_sqrt`` can read it back.
    """
    chol = _sqrt_to_chol(cov_sqrt)

    new = new_dist_like(orig_p, mean, chol)

    new.cov_sqrt = cov_sqrt
    if hasattr(new, 'distribution'):
        new.distribution.cov_sqrt = cov_sqrt

    return new


def _sqrt_to_chol(cov_sqrt):
    """Convert a covariance square root into a Cholesky factor.

    A 2-D input is treated as a batch of diagonal vectors: it is embedded
    into diagonal matrices first and the result is returned as vectors again.
    The covariance is formed as ``cov_sqrt^T @ cov_sqrt`` with ``1e-6`` added
    to the diagonal before factorising.
    """
    is_vec = cov_sqrt.dim() == 2
    if is_vec:
        cov_sqrt = th.diag_embed(cov_sqrt)

    # matmul instead of bmm so that more than one leading batch dim works too.
    cov = cov_sqrt.mT @ cov_sqrt
    # Create the jitter on the covariance's own device/dtype; a default
    # (CPU, float32) identity fails for CUDA inputs.
    jitter = 1e-6 * th.eye(cov.shape[-1], dtype=cov.dtype, device=cov.device)

    chol = th.linalg.cholesky(cov + jitter)

    if is_vec:
        chol = th.diagonal(chol, dim1=-2, dim2=-1)

    return chol
"""Distance helpers: Mahalanobis distances and squared Frobenius norms.

Module path: metastable_projections/misc/norm.py
"""
import torch as th
from torch.distributions.multivariate_normal import _batch_mahalanobis


def mahalanobis_alt(u, v, std):
    """
    Stolen from Fabian's Code (Public Version)

    Solves ``std @ x = (u - v)`` with ``std`` taken as lower-triangular and
    returns the squared entries of ``x`` summed over the last two dimensions.
    """
    delta = u - v
    # th.triangular_solve is deprecated; solve_triangular takes the matrix
    # first and returns the solution directly instead of a tuple.
    solved = th.linalg.solve_triangular(std, delta, upper=False)
    return solved.pow(2).sum([-2, -1])


def mahalanobis(u, v, chol):
    """Return torch's ``_batch_mahalanobis(chol, u - v)``."""
    delta = u - v
    return _batch_mahalanobis(chol, delta)


def frob_sq(diff, is_spd=False):
    """Return the squared Frobenius norm of ``diff``.

    With ``is_spd=False`` all dimensions except the first are reduced. With
    ``is_spd=True`` the computation is delegated to ``_frob_sq_spd``, which
    reduces the last two dimensions.
    """
    if is_spd:
        return _frob_sq_spd(diff)
    # Sum of squares directly: same value as norm(...).pow(2) without the
    # sqrt round trip and without the deprecated th.norm.
    return diff.pow(2).sum(dim=tuple(range(1, diff.dim())))


def _frob_sq_spd(diff):
    """Return ``trace(diff @ diff)`` per matrix without forming the product.

    ``trace(A @ A) == sum_ij A_ij * A_ji``, so an elementwise product with the
    transpose gives the same value in O(n^2) instead of O(n^3).
    """
    return (diff * diff.transpose(-2, -1)).sum((-2, -1))


def _batch_trace(x):
    """Return the trace over the last two dimensions of ``x``."""
    return th.diagonal(x, dim1=-2, dim2=-1).sum(-1)
"""Packaging script for the ``metastable-projections`` distribution (setup.py)."""
from setuptools import setup, find_packages

setup(
    name='metastable-projections',
    version='1.0.0',
    # url='https://github.com/mypackage.git',
    # author='Author Name',
    # author_email='author@gmail.com',
    # description='Description of my package',
    # Discover the package tree rather than listing '.', which does not name
    # an importable package. NOTE(review): find_packages() only picks up
    # directories containing an __init__.py - confirm every subpackage has one.
    packages=find_packages(),
    install_requires=['torch', 'stable_baselines3'],
)