From 508ebf51f0fc6c05b50b89278fe1ba4148687d7c Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 6 Aug 2022 14:46:42 +0200 Subject: [PATCH] Implemented sqrt-induced-gaussian for W2-Projection --- .../distributions/distributions.py | 15 +++++++++++++++ metastable_baselines/ppo/policies.py | 13 +++++++++++++ 2 files changed, 28 insertions(+) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index bff2443..879a8ba 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -206,6 +206,21 @@ class UniversalGaussianDistribution(SB3_Distribution): return mean_actions, chol + def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution": + """ + Create the distribution given its parameters (mean, cov_sqrt) + + :param mean_actions: + :param cov_sqrt: + :return: + """ + cov = cov_sqrt.T @ cov_sqrt + chol = th.linalg.cholesky(cov) + + self.cov_sqrt = cov_sqrt + + return self.proba_distribution(mean_actions, chol, latent_pi) + def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution": """ Create the distribution given its parameters (mean, chol) diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index 63a99c9..f476b94 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -35,6 +35,8 @@ from stable_baselines3.common.torch_layers import ( NatureCNN, ) +from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer + from ..distributions import UniversalGaussianDistribution, make_proba_distribution @@ -286,6 +288,17 @@ class ActorCriticPolicy(BasePolicy): """ mean_actions = self.action_net(latent_pi) + if isinstance(self.projection, WassersteinProjectionLayer): + if isinstance(self.action_dist, UniversalGaussianDistribution): + cov_sqrt = self.chol_net(latent_pi) + dist = self.action_dist.proba_distribution_from_sqrt( + mean_actions, cov_sqrt, latent_pi) + self.chol = dist.chol + return dist + else: + raise Exception( + 'Need to use UniversalGaussianDistribution to use WassersteinProjection (uses sqrt-induced-cov)') + if isinstance(self.action_dist, DiagGaussianDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std) elif isinstance(self.action_dist, CategoricalDistribution):