def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
    """
    Create the distribution given its parameters (mean, cov_sqrt)

    The covariance is formed as ``cov_sqrt^T @ cov_sqrt`` and its Cholesky
    factor is handed to ``proba_distribution``. The supplied ``cov_sqrt`` is
    additionally stored on ``self.cov_sqrt``.

    :param mean_actions: mean of the Gaussian; forwarded unchanged to
        ``proba_distribution``.
    :param cov_sqrt: square matrix (or a batch of them, with the matrix in
        the last two dimensions) whose Gram product is the covariance.
    :param latent_pi: forwarded unchanged to ``proba_distribution``.
    :return: whatever ``proba_distribution`` returns for (mean, chol).
    """
    # Transpose only the two matrix axes. ``Tensor.T`` reverses *every*
    # dimension, which is wrong (and deprecated) for batched input such as
    # (batch, dim, dim); for a plain 2-D matrix both forms are identical.
    cov = cov_sqrt.transpose(-2, -1) @ cov_sqrt
    # NOTE(review): cholesky raises if cov is not positive definite, i.e.
    # if cov_sqrt is rank-deficient - confirm the producer guarantees this.
    chol = th.linalg.cholesky(cov)

    self.cov_sqrt = cov_sqrt

    return self.proba_distribution(mean_actions, chol, latent_pi)
isinstance(self.action_dist, DiagGaussianDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std) elif isinstance(self.action_dist, CategoricalDistribution):