From 802094a50f409478d83bd0b46ab071e52735001c Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 6 Aug 2022 14:54:59 +0200 Subject: [PATCH] Enabled w2 (can now get sqrt from dist) --- metastable_baselines/misc/distTools.py | 15 ++++++++------- metastable_baselines/ppo/policies.py | 4 +++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py index 9ec8c3d..97e4b4f 100644 --- a/metastable_baselines/misc/distTools.py +++ b/metastable_baselines/misc/distTools.py @@ -20,13 +20,14 @@ def get_mean_and_chol(p: AnyDistribution, expand=False): def get_mean_and_sqrt(p: UniversalGaussianDistribution): - raise Exception('Not yet implemented...') - if isinstance(p, th.distributions.Normal): - return p.mean, p.stddev - elif isinstance(p, th.distributions.MultivariateNormal): - return p.mean, p.scale_tril - elif isinstance(p, SB3_Distribution): - return get_mean_and_chol(p.distribution) + if isinstance(p, UniversalGaussianDistribution): + if not hasattr(p, 'cov_sqrt'): + raise Exception( + 'Distribution was not induced from sqrt. On-demand calculation is not supported.') + else: + mean, chol = get_mean_and_chol(p) + sqrt_cov = p.cov_sqrt + return mean, sqrt_cov else: raise Exception('Dist-Type not implemented') diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index f476b94..ea5ed75 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -38,6 +38,7 @@ from stable_baselines3.common.torch_layers import ( from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer from ..distributions import UniversalGaussianDistribution, make_proba_distribution +from ..misc.distTools import get_mean_and_chol class ActorCriticPolicy(BasePolicy): @@ -293,7 +294,8 @@ class ActorCriticPolicy(BasePolicy): cov_sqrt = self.chol_net(latent_pi) dist = self.action_dist.proba_distribution_from_sqrt( mean_actions, cov_sqrt, latent_pi) - self.chol = dist.chol + mean, chol = get_mean_and_chol(dist, expand=False) + self.chol = chol return dist else: raise Exception(