From ffbf2b3fe574de7d05a5813daf1a3fcee67bc4b4 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Fri, 27 Jan 2023 13:34:28 +0100
Subject: [PATCH] Allow reduced latent sde dim

---
Review notes (text in this area is not part of the commit message):
- sac/policies.py, last hunk: the added line called
  `self.actor.kwargs.update(...)`; it now updates `self.actor_kwargs`, the
  same dict the preceding context line updates, so `latent_dim_sde` is
  actually forwarded through the actor kwargs.
- Line breaks, indentation and the position of blank context lines were
  reconstructed from the hunk line counts. TODO confirm the context lines
  against the target tree before applying.
- The post-image blob id on the sac/policies.py `index` line predates the
  correction above.

 metastable_baselines/distributions/distributions.py | 1 +
 metastable_baselines/ppo/policies.py                | 4 +++-
 metastable_baselines/sac/policies.py                | 7 ++++++-
 3 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
index 6c40e00..95fb7f6 100644
--- a/metastable_baselines/distributions/distributions.py
+++ b/metastable_baselines/distributions/distributions.py
@@ -348,6 +348,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
 
     def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
         latent_sde = latent_sde if self.learn_features else latent_sde.detach()
+        latent_sde = latent_sde[..., -self.latent_sde_dim:]
         latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
         # Default case: only one exploration matrix
         if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py
index 74e0ae6..4de69f9 100644
--- a/metastable_baselines/ppo/policies.py
+++ b/metastable_baselines/ppo/policies.py
@@ -100,6 +100,7 @@ class ActorCriticPolicy(BasePolicy):
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         dist_kwargs: Optional[Dict[str, Any]] = None,
         sqrt_induced_gaussian: bool = False,
+        latent_dim_sde=None,
     ):
 
         if optimizer_kwargs is None:
@@ -155,6 +156,7 @@ class ActorCriticPolicy(BasePolicy):
         self.dist_kwargs = dist_kwargs
 
         self.sqrt_induced_gaussian = sqrt_induced_gaussian
+        self.latent_dim_sde = latent_dim_sde
 
         # Action distribution
         self.action_dist = make_proba_distribution(
@@ -244,7 +246,7 @@ class ActorCriticPolicy(BasePolicy):
                 latent_dim=latent_dim_pi)
         elif isinstance(self.action_dist, UniversalGaussianDistribution):
             self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
-                latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, std_init=math.exp(
+                latent_dim=latent_dim_pi, latent_sde_dim=self.latent_dim_sde or latent_dim_pi, std_init=math.exp(
                     self.log_std_init)
             )
         else:
diff --git a/metastable_baselines/sac/policies.py b/metastable_baselines/sac/policies.py
index 9540123..435a5c8 100644
--- a/metastable_baselines/sac/policies.py
+++ b/metastable_baselines/sac/policies.py
@@ -69,6 +69,7 @@ class Actor(BasePolicy):
         clip_mean: float = 2.0,
         normalize_images: bool = True,
         dist_kwargs={},
+        latent_dim_sde=None,
     ):
         super().__init__(
             observation_space,
@@ -90,6 +91,8 @@ class Actor(BasePolicy):
         self.full_std = full_std
         self.clip_mean = clip_mean
 
+        self.latent_dim_sde = latent_dim_sde
+
         if sde_net_arch is not None:
             warnings.warn(
                 "sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
@@ -111,7 +114,7 @@ class Actor(BasePolicy):
         self.action_dist = UniversalGaussianDistribution(
             action_dim, **dist_kwargs)
         self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
-            latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, std_init=math.exp(
+            latent_dim=last_layer_dim, latent_sde_dim=self.latent_dim_sde or last_layer_dim, std_init=math.exp(
                 self.log_std_init)
         )
         # self.action_dist = StateDependentNoiseDistribution(
@@ -274,6 +277,7 @@ class SACPolicy(BasePolicy):
         n_critics: int = 2,
         share_features_extractor: bool = True,
         dist_kwargs={},
+        latent_dim_sde=None,
     ):
         super().__init__(
             observation_space,
@@ -316,6 +320,7 @@ class SACPolicy(BasePolicy):
         }
 
         self.actor_kwargs.update(sde_kwargs)
+        self.actor_kwargs.update({'latent_dim_sde': latent_dim_sde})
         self.critic_kwargs = self.net_args.copy()
         self.critic_kwargs.update(
             {