From a5309e0fb8977c975cadeb5c2a5abb7b0bf059b4 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 16 Jan 2024 15:25:34 +0100 Subject: [PATCH] Clean up and make cpp_projections optional --- metastable_projections/__init__.py | 1 + .../projections/__init__.py | 10 +++- .../projections/base_projection_layer.py | 47 +++++++++++++++++-- .../projections/identity_projection_layer.py | 5 ++ .../projections/kl_projection_layer.py | 5 +- .../projections/w2_projection_layer.py | 20 ++++++++ 6 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 metastable_projections/__init__.py create mode 100644 metastable_projections/projections/identity_projection_layer.py diff --git a/metastable_projections/__init__.py b/metastable_projections/__init__.py new file mode 100644 index 0000000..f95bd25 --- /dev/null +++ b/metastable_projections/__init__.py @@ -0,0 +1 @@ +from .projections import * \ No newline at end of file diff --git a/metastable_projections/projections/__init__.py b/metastable_projections/projections/__init__.py index e309e39..ad37886 100644 --- a/metastable_projections/projections/__init__.py +++ b/metastable_projections/projections/__init__.py @@ -1,5 +1,13 @@ #TODO: License or such from .base_projection_layer import BaseProjectionLayer +from .identity_projection_layer import IdentityProjectionLayer from .frob_projection_layer import FrobeniusProjectionLayer -from .kl_projection_layer import KLProjectionLayer from .w2_projection_layer import WassersteinProjectionLayer + +try: + import cpp_projection +except ModuleNotFoundError: + print('[MSB] ITPAL is not installed; KL projections not avaible.') + from .base_projection_layer import ExceptionProjectionLayer as KLProjectionLayer +else: + from .kl_projection_layer import KLProjectionLayer \ No newline at end of file diff --git a/metastable_projections/projections/base_projection_layer.py b/metastable_projections/projections/base_projection_layer.py index 31950d7..6da8269 100644 --- a/metastable_projections/projections/base_projection_layer.py +++ b/metastable_projections/projections/base_projection_layer.py @@ -9,7 +9,6 @@ from ..misc.distTools import * class BaseProjectionLayer(object): - def __init__(self, mean_bound: float = 0.03, cov_bound: float = 1e-3, @@ -30,10 +29,8 @@ class BaseProjectionLayer(object): self.entropy_first = entropy_first self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection - def __call__(self, p, q, step, *args, **kwargs): - # TODO: self.entropy_schedule(self.initial_entropy, self.target_entropy, self.temperature, step) * p[0].new_ones(p[0].shape[0]) - entropy_bound = 'lol' - return self._projection(p, q, eps=self.mean_bound, eps_cov=self.cov_bound, beta=entropy_bound, **kwargs) + def __call__(self, p, q, **kwargs): + return self._projection(p, q, eps=self.mean_bound, eps_cov=self.cov_bound, beta=None, **kwargs) @final def _projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, beta: th.Tensor, **kwargs): @@ -71,6 +68,22 @@ class BaseProjectionLayer(object): return self.entropy_proj(new_p, beta) + def project_from_rollouts(self, dist, rollout_data, **kwargs): + """ + Hook for implementing the specific trust region projection + Args: + p: current distribution + q: old distribution + eps: mean trust region bound + eps_cov: covariance trust region bound + **kwargs: + + Returns: + projected_dist, old_dist (from rollouts) + """ + old_distribution = self.new_dist_like(dist, rollout_data.means, rollout_data.cov_decomps) + return self(dist, old_distribution, **kwargs), old_distribution + def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs): """ Hook for implementing the specific trust region projection @@ -108,6 +121,23 @@ class BaseProjectionLayer(object): """ return kl_divergence(p, q) + def new_dist_like(orig_p, mean, cov_cholesky): + assert isinstance(orig_p, Distribution) + p = orig_p.distribution + if isinstance(p, th.distributions.Normal): + p_out = orig_p.__class__(orig_p.action_dim) + p_out.distribution = th.distributions.Normal(mean, cov_cholesky) + elif isinstance(p, th.distributions.Independent): + p_out = orig_p.__class__(orig_p.action_dim) + p_out.distribution = th.distributions.Independent( + th.distributions.Normal(mean, cov_cholesky), 1) + elif isinstance(p, th.distributions.MultivariateNormal): + p_out = orig_p.__class__(orig_p.action_dim) + p_out.distribution = th.distributions.MultivariateNormal( + mean, scale_tril=cov_cholesky) + else: + raise Exception('Dist-Type not implemented (of sb3 dist)') + return p_out def entropy_inequality_projection(p: th.distributions.Normal, beta: Union[float, th.Tensor]): @@ -219,3 +249,10 @@ def mean_equality_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tens proj_mean = (mean + omega * old_mean) / (1 + omega + 1e-16) return proj_mean + + +class ExceptionProjectionLayer(BaseProjectionLayer): + def __init__(self, + *args, **kwargs + ): + raise Exception('To be able to use KL projections, ITPAL must be installed: https://github.com/ALRhub/ITPAL (Private Repo).') \ No newline at end of file diff --git a/metastable_projections/projections/identity_projection_layer.py b/metastable_projections/projections/identity_projection_layer.py new file mode 100644 index 0000000..62216db --- /dev/null +++ b/metastable_projections/projections/identity_projection_layer.py @@ -0,0 +1,5 @@ +from .base_projection_layer import BaseProjectionLayer + +class IdentityProjectionLayer(BaseProjectionLayer): + def project_from_rollouts(self, dist, rollout_data, **kwargs): + return dist, dist diff --git a/metastable_projections/projections/kl_projection_layer.py b/metastable_projections/projections/kl_projection_layer.py index d50367d..cc38126 100644 --- a/metastable_projections/projections/kl_projection_layer.py +++ b/metastable_projections/projections/kl_projection_layer.py @@ -138,7 +138,10 @@ class KLProjectionGradFunctionDiagCovOnly(th.autograd.Function): batch_shape, dim) ctx.proj = p_op - proj_std = p_op.forward(eps, old_std_np, std_np) + try: + proj_std = p_op.forward(eps, old_std_np, std_np) + except: + proj_std = std_np return cov.new(proj_std) diff --git a/metastable_projections/projections/w2_projection_layer.py b/metastable_projections/projections/w2_projection_layer.py index b2065ab..7c99f59 100644 --- a/metastable_projections/projections/w2_projection_layer.py +++ b/metastable_projections/projections/w2_projection_layer.py @@ -9,6 +9,8 @@ from .base_projection_layer import BaseProjectionLayer, mean_projection from ..misc.norm import mahalanobis, _batch_trace from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov +from stable_baselines3.common.distributions import Distribution + class WassersteinProjectionLayer(BaseProjectionLayer): """ @@ -92,6 +94,24 @@ class WassersteinProjectionLayer(BaseProjectionLayer): return kl_loss * self.trust_region_coeff + def new_dist_like(orig_p, mean, cov_sqrt): + assert isinstance(orig_p, Distribution) + p = orig_p.distribution + if isinstance(p, th.distributions.Normal): + p_out = orig_p.__class__(orig_p.action_dim) + p_out.distribution = th.distributions.Normal(mean, cov_sqrt) + elif isinstance(p, th.distributions.Independent): + p_out = orig_p.__class__(orig_p.action_dim) + p_out.distribution = th.distributions.Independent( + th.distributions.Normal(mean, cov_sqrt), 1) + elif isinstance(p, th.distributions.MultivariateNormal): + p_out = orig_p.__class__(orig_p.action_dim) + p_out.distribution = th.distributions.MultivariateNormal( + mean, scale_tril=cov_sqrt) + else: + raise Exception('Dist-Type not implemented (of sb3 dist)') + return p_out + def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor, th.Tensor]: """