From 75d73049b42a6c3e868a3f44a48005ebc00b8adc Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 6 Aug 2022 21:25:49 +0200 Subject: [PATCH] Fixing bugs with w2 and sqrt_induced_gaussian --- .../distributions/distributions.py | 35 ++++++++++++-- metastable_baselines/misc/distTools.py | 48 +++++++++++++++---- metastable_baselines/ppo/policies.py | 29 ++++++----- metastable_baselines/ppo/ppo.py | 14 ++++-- 4 files changed, 92 insertions(+), 34 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 879a8ba..35ae753 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -184,6 +184,16 @@ class UniversalGaussianDistribution(SB3_Distribution): return new + def new_dist_like_me_from_sqrt(self, mean: th.Tensor, cov_sqrt: th.Tensor): + chol = self._sqrt_to_chol(cov_sqrt) + + new = self.new_dist_like_me(mean, chol) + + new.cov_sqrt = cov_sqrt + new.distribution.cov_sqrt = cov_sqrt + + return new + def proba_distribution_net(self, latent_dim: int, latent_sde_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]: """ Create the layers and parameter that represent the distribution: @@ -206,6 +216,22 @@ class UniversalGaussianDistribution(SB3_Distribution): return mean_actions, chol + def _sqrt_to_chol(self, cov_sqrt): + vec = False + if len(cov_sqrt.shape) == 2: + vec = True + + if vec: + cov_sqrt = th.diag_embed(cov_sqrt) + + cov = th.bmm(cov_sqrt.mT, cov_sqrt) + chol = th.linalg.cholesky(cov) + + if vec: + chol = th.diagonal(chol, dim1=-2, dim2=-1) + + return chol + def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution": """ Create the distribution given its parameters (mean, cov_sqrt) @@ -214,12 +240,11 @@ class UniversalGaussianDistribution(SB3_Distribution): :param cov_sqrt: :return: """ - cov = cov_sqrt.T @ cov_sqrt - chol = th.linalg.cholesky(cov) - self.cov_sqrt = cov_sqrt - - return self.proba_distribution(mean_actions, chol, latent_pi) + chol = self._sqrt_to_chol(cov_sqrt) + self.proba_distribution(mean_actions, chol, latent_pi) + self.distribution.cov_sqrt = cov_sqrt + return self def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution": """ diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py index 97e4b4f..4983fc9 100644 --- a/metastable_baselines/misc/distTools.py +++ b/metastable_baselines/misc/distTools.py @@ -19,17 +19,16 @@ def get_mean_and_chol(p: AnyDistribution, expand=False): raise Exception('Dist-Type not implemented') -def get_mean_and_sqrt(p: UniversalGaussianDistribution): - if isinstance(p, UniversalGaussianDistribution): - if not hasattr(p, 'cov_sqrt'): - raise Exception( - 'Distribution was not induced from sqrt. On-demand calculation is not supported.') - else: - mean, chol = get_mean_and_chol(p) - sqrt_cov = p.cov_sqrt - return mean, sqrt_cov +def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False): + if not hasattr(p, 'cov_sqrt'): + raise Exception( + 'Distribution was not induced from sqrt. On-demand calculation is not supported.') else: - raise Exception('Dist-Type not implemented') + mean, chol = get_mean_and_chol(p, expand=False) + sqrt_cov = p.cov_sqrt + if expand and len(sqrt_cov.shape) == 2: + sqrt_cov = th.diag_embed(sqrt_cov) + return mean, sqrt_cov def get_cov(p: AnyDistribution): @@ -97,3 +96,32 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor): return p_out else: raise Exception('Dist-Type not implemented') + + +def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor): + chol = _sqrt_to_chol(cov_sqrt) + + new = new_dist_like(orig_p, mean, chol) + + new.cov_sqrt = cov_sqrt + if hasattr(new, 'distribution'): + new.distribution.cov_sqrt = cov_sqrt + + return new + + +def _sqrt_to_chol(cov_sqrt): + vec = False + if len(cov_sqrt.shape) == 2: + vec = True + + if vec: + cov_sqrt = th.diag_embed(cov_sqrt) + + cov = th.bmm(cov_sqrt.mT, cov_sqrt) + chol = th.linalg.cholesky(cov) + + if vec: + chol = th.diagonal(chol, dim1=-2, dim2=-1) + + return chol diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index ea5ed75..a76886e 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -99,6 +99,7 @@ class ActorCriticPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, dist_kwargs: Optional[Dict[str, Any]] = None, + sqrt_induced_gaussian=False, ): if optimizer_kwargs is None: @@ -152,6 +153,8 @@ class ActorCriticPolicy(BasePolicy): self.use_sde = use_sde self.dist_kwargs = dist_kwargs + self.sqrt_induced_gaussian = sqrt_induced_gaussian + # Action distribution self.action_dist = make_proba_distribution( action_space, use_sde=use_sde, dist_kwargs=dist_kwargs) @@ -289,18 +292,6 @@ class ActorCriticPolicy(BasePolicy): """ mean_actions = self.action_net(latent_pi) - if isinstance(self.projection, WassersteinProjectionLayer): - if isinstance(self.action_dist, UniversalGaussianDistribution): - cov_sqrt = self.chol_net(latent_pi) - dist = self.action_dist.proba_distribution_from_sqrt( - mean_actions, cov_sqrt, latent_pi) - mean, chol = get_mean_and_chol(dist, expand=False) - self.chol = chol - return dist - else: - raise Exception( - 'Need to use UniversalGaussianDistribution to use WassersteinProjection (uses sqrt-induced-cov)') - if isinstance(self.action_dist, DiagGaussianDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std) elif isinstance(self.action_dist, CategoricalDistribution): @@ -315,9 +306,17 @@ class ActorCriticPolicy(BasePolicy): elif isinstance(self.action_dist, StateDependentNoiseDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) elif isinstance(self.action_dist, UniversalGaussianDistribution): - chol = self.chol_net(latent_pi) - self.chol = chol - return self.action_dist.proba_distribution(mean_actions, chol, latent_pi) + if self.sqrt_induced_gaussian: + cov_sqrt = self.chol_net(latent_pi) + dist = self.action_dist.proba_distribution_from_sqrt( + mean_actions, cov_sqrt, latent_pi) + mean, chol = get_mean_and_chol(dist, expand=False) + self.chol = chol + return dist + else: + chol = self.chol_net(latent_pi) + self.chol = chol + return self.action_dist.proba_distribution(mean_actions, chol, latent_pi) else: raise ValueError("Invalid action distribution") diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index 54672f4..1e854aa 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -16,7 +16,7 @@ from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.utils import obs_as_tensor from stable_baselines3.common.vec_env import VecNormalize -from ..misc.distTools import new_dist_like +from ..misc.distTools import new_dist_like, new_dist_like_from_sqrt from ..projections.base_projection_layer import BaseProjectionLayer from ..projections.frob_projection_layer import FrobeniusProjectionLayer @@ -133,7 +133,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): use_sde=use_sde, sde_sample_freq=sde_sample_freq, tensorboard_log=tensorboard_log, - policy_kwargs=policy_kwargs, + policy_kwargs=policy_kwargs | + {'sqrt_induced_gaussian': isinstance( + projection, WassersteinProjectionLayer)}, verbose=verbose, device=device, create_eval_env=create_eval_env, @@ -245,8 +247,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): latent_pi, latent_vf = pol.mlp_extractor(features) p = pol._get_action_dist_from_latent(latent_pi) p_dist = p.distribution - q_dist = new_dist_like( - p_dist, rollout_data.means, rollout_data.chols) + if isinstance(self.projection, WassersteinProjectionLayer): + q_dist = new_dist_like_from_sqrt( + p_dist, rollout_data.means, rollout_data.chols) + else: + q_dist = new_dist_like( + p_dist, rollout_data.means, rollout_data.chols) proj_p = self.projection(p_dist, q_dist, self._global_steps) if isinstance(p_dist, th.distributions.Normal): # Normal uses a weird mapping from dimensions into batch_shape