From 28d0c609bc3247904a81f8ebba8153fec27dd352 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sun, 14 Aug 2022 20:09:10 +0200
Subject: [PATCH] Fixed SDE: sampling had dimension mismatches

---
 .../distributions/distributions.py   | 20 +++++++++++++------
 metastable_baselines/ppo/policies.py |  3 +--
 metastable_baselines/sac/policies.py |  3 +--
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
index c6a3cb2..7521051 100644
--- a/metastable_baselines/distributions/distributions.py
+++ b/metastable_baselines/distributions/distributions.py
@@ -249,7 +249,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.distribution.cov_sqrt = cov_sqrt
         return self
 
-    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: nn.Module) -> "UniversalGaussianDistribution":
+    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: th.Tensor) -> "UniversalGaussianDistribution":
         """
         Create the distribution given its parameters (mean, chol)
 
@@ -300,12 +300,18 @@ class UniversalGaussianDistribution(SB3_Distribution):
         return self.distribution.entropy()
 
     def sample(self) -> th.Tensor:
+        if self.use_sde:
+            return self._sample_sde()
+        else:
+            return self._sample_normal()
+
+    def _sample_normal(self) -> th.Tensor:
         # Reparametrization trick to pass gradients
         sample = self.distribution.rsample()
         self.gaussian_actions = sample
         return self.prob_squashing_type.apply(sample)
 
-    def sample_sde(self) -> th.Tensor:
+    def _sample_sde(self) -> th.Tensor:
         noise = self.get_noise(self._latent_sde)
         actions = self.distribution.mean + noise
         self.gaussian_actions = actions
@@ -334,7 +340,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
         log_prob = self.log_prob(actions, self.gaussian_actions)
         return actions, log_prob
 
-    def sample_weights(self, num_dims, batch_size=1):
+    def sample_weights(self, batch_size=1):
+        num_dims = (self.latent_sde_dim, self.action_dim)
         self.weights_dist = Normal(th.zeros(num_dims), th.ones(num_dims))
         # Reparametrization trick to pass gradients
         self.exploration_mat = self.weights_dist.rsample()
@@ -345,15 +352,16 @@ class UniversalGaussianDistribution(SB3_Distribution):
         latent_sde = latent_sde if self.learn_features else latent_sde.detach()
         # # TODO: Good idea?
         latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
-        chol = self.distribution.scale_tril
         # Default case: only one exploration matrix
         if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
-            return th.mm(chol, th.mm(latent_sde, self.exploration_mat))
+            chol = th.diag_embed(self.distribution.stddev)
+            return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]
+        chol = self.distribution.scale_tril
         # Use batch matrix multiplication for efficient computation
         # (batch_size, n_features) -> (batch_size, 1, n_features)
         latent_sde = latent_sde.unsqueeze(dim=1)
         # (batch_size, 1, n_actions)
-        noise = th.bmm(chol, th.bmm(latent_sde, self.exploration_matrices))
+        noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
         return noise.squeeze(dim=1)
 
 
diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py
index fe36e49..eb9d137 100644
--- a/metastable_baselines/ppo/policies.py
+++ b/metastable_baselines/ppo/policies.py
@@ -203,8 +203,7 @@ class ActorCriticPolicy(BasePolicy):
 
         if isinstance(
                 self.action_dist, UniversalGaussianDistribution):
-            self.action_dist.sample_weights(
-                get_action_dim(self.action_space), batch_size=n_envs)
+            self.action_dist.sample_weights(batch_size=n_envs)
 
     def _build_mlp_extractor(self) -> None:
         """
diff --git a/metastable_baselines/sac/policies.py b/metastable_baselines/sac/policies.py
index b816025..530702e 100644
--- a/metastable_baselines/sac/policies.py
+++ b/metastable_baselines/sac/policies.py
@@ -185,8 +185,7 @@ class Actor(BasePolicy):
 
         if isinstance(
                 self.action_dist, UniversalGaussianDistribution):
-            self.action_dist.sample_weights(
-                get_action_dim(self.action_space), batch_size=n_envs)
+            self.action_dist.sample_weights(batch_size=n_envs)
 
     def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """
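
For reference, a minimal shape sketch of the batched noise path that the `get_noise` hunk changes. The toy sizes (latent_sde_dim = 5, action_dim = 3, a batch of 4 envs) and the free-standing tensor names are illustrative assumptions, not code from the repo; only the multiplication order mirrors the patched lines.

import torch as th

batch_size, latent_sde_dim, action_dim = 4, 5, 3

# Per-env exploration matrices, as sample_weights() now draws them with
# num_dims = (latent_sde_dim, action_dim).
exploration_matrices = th.randn(batch_size, latent_sde_dim, action_dim)

# Normalized latent features and a batch of lower-triangular scale factors
# (stand-ins for latent_sde and self.distribution.scale_tril).
latent_sde = th.nn.functional.normalize(th.randn(batch_size, latent_sde_dim), dim=-1)
chol = th.tril(th.randn(batch_size, action_dim, action_dim))

# (batch, 1, latent_dim) @ (batch, latent_dim, action_dim) -> (batch, 1, action_dim)
raw_noise = th.bmm(latent_sde.unsqueeze(dim=1), exploration_matrices)

# Patched order: multiply by chol on the right,
# (batch, 1, action_dim) @ (batch, action_dim, action_dim) -> (batch, 1, action_dim).
# The pre-patch order th.bmm(chol, raw_noise) attempts
# (action_dim, action_dim) @ (1, action_dim) per batch element, which is the
# dimension mismatch named in the commit message.
noise = th.bmm(raw_noise, chol).squeeze(dim=1)

print(noise.shape)  # torch.Size([4, 3])

Squeezing dim=1 recovers the (batch_size, action_dim) noise that _sample_sde adds to the distribution mean.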