From 520dc98eb52f1744c9aea1d311e503f6ef926b3f Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 10 Aug 2022 11:54:52 +0200 Subject: [PATCH] Implemented SDE --- .../distributions/distributions.py | 40 +++++++++++++++++-- metastable_baselines/ppo/policies.py | 11 ++++- metastable_baselines/ppo/ppo.py | 2 +- metastable_baselines/sac/policies.py | 21 ++++++---- 4 files changed, 61 insertions(+), 13 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 35ae753..b66cea8 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -160,8 +160,7 @@ class UniversalGaussianDistribution(SB3_Distribution): self.distribution = None self.gaussian_actions = None - if use_sde: - raise Exception('SDE is not yet implemented') + self.use_sde = use_sde assert (self.par_type != ParametrizationType.NONE) == ( self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' @@ -214,6 +213,9 @@ class UniversalGaussianDistribution(SB3_Distribution): chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength, self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type) + if self.use_sde: + self.sample_weights(self.action_dim) + return mean_actions, chol def _sqrt_to_chol(self, cov_sqrt): @@ -246,7 +248,7 @@ class UniversalGaussianDistribution(SB3_Distribution): self.distribution.cov_sqrt = cov_sqrt return self - def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution": + def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: nn.Module) -> "UniversalGaussianDistribution": """ Create the distribution given its parameters (mean, chol) @@ -254,7 +256,9 @@ class UniversalGaussianDistribution(SB3_Distribution): :param chol: :return: """ - # TODO: latent_pi is for SDE, implement. + if self.use_sde: + self._latent_sde = latent_sde if self.learn_features else latent_sde.detach() + # TODO: Change variance of dist to include sde-spread if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]: self.distribution = Independent(Normal(mean_actions, chol), 1) @@ -300,6 +304,12 @@ class UniversalGaussianDistribution(SB3_Distribution): self.gaussian_actions = sample return self.prob_squashing_type.apply(sample) + def sample_sde(self) -> th.Tensor: + noise = self.get_noise(self._latent_sde) + actions = self.distribution.mean + noise + self.gaussian_actions = actions + return self.prob_squashing_type.apply(actions) + def mode(self) -> th.Tensor: mode = self.distribution.mean self.gaussian_actions = mode @@ -323,6 +333,28 @@ class UniversalGaussianDistribution(SB3_Distribution): log_prob = self.log_prob(actions, self.gaussian_actions) return actions, log_prob + def sample_weights(self, num_dims, batch_size=1): + self.weights_dist = Normal(th.zeros(num_dims), th.ones(num_dims)) + # Reparametrization trick to pass gradients + self.exploration_mat = self.weights_dist.rsample() + # Pre-compute matrices in case of parallel exploration + self.exploration_matrices = self.weights_dist.rsample((batch_size,)) + + def get_noise(self, latent_sde: th.Tensor) -> th.Tensor: + latent_sde = latent_sde if self.learn_features else latent_sde.detach() + # # TODO: Good idea? + latent_sde = th.nn.functional.normalize(latent_sde, dim=-1) + chol = self.distribution.scale_tril + # Default case: only one exploration matrix + if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices): + return th.mm(chol, th.mm(latent_sde, self.exploration_mat)) + # Use batch matrix multiplication for efficient computation + # (batch_size, n_features) -> (batch_size, 1, n_features) + latent_sde = latent_sde.unsqueeze(dim=1) + # (batch_size, 1, n_actions) + noise = th.bmm(chol, th.bmm(latent_sde, self.exploration_matrices)) + return noise.squeeze(dim=1) + class CholNet(nn.Module): def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType): diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index d42dc7f..096de75 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -35,6 +35,8 @@ from stable_baselines3.common.torch_layers import ( NatureCNN, ) +from stable_baselines3.common.preprocessing import get_action_dim + from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer from ..distributions import UniversalGaussianDistribution, make_proba_distribution @@ -196,7 +198,14 @@ class ActorCriticPolicy(BasePolicy): assert isinstance( self.action_dist, StateDependentNoiseDistribution) or isinstance( self.action_dist, UniversalGaussianDistribution), "reset_noise() is only available when using gSDE" - self.action_dist.sample_weights(self.log_std, batch_size=n_envs) + if isinstance( + self.action_dist, StateDependentNoiseDistribution): + self.action_dist.sample_weights(self.log_std, batch_size=n_envs) + + if isinstance( + self.action_dist, UniversalGaussianDistribution): + self.action_dist.sample_weights( + get_action_dim(self.action_space), batch_size=n_envs) def _build_mlp_extractor(self) -> None: """ diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index 1e854aa..ceed889 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -105,7 +105,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): target_kl: Optional[float] = None, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[Dict[str, Any]] = {}, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", diff --git a/metastable_baselines/sac/policies.py b/metastable_baselines/sac/policies.py index 47515dd..82eaac4 100644 --- a/metastable_baselines/sac/policies.py +++ b/metastable_baselines/sac/policies.py @@ -157,16 +157,23 @@ class Actor(BasePolicy): StateDependentNoiseDistribution), msg return self.chol - def reset_noise(self, batch_size: int = 1) -> None: + def reset_noise(self, n_envs: int = 1) -> None: """ - Sample new weights for the exploration matrix, when using gSDE. + Sample new weights for the exploration matrix. - :param batch_size: + :param n_envs: """ - msg = "reset_noise() is only available when using gSDE" - assert isinstance(self.action_dist, - StateDependentNoiseDistribution), msg - self.action_dist.sample_weights(self.chol, batch_size=batch_size) + assert isinstance( + self.action_dist, StateDependentNoiseDistribution) or isinstance( + self.action_dist, UniversalGaussianDistribution), "reset_noise() is only available when using gSDE" + if isinstance( + self.action_dist, StateDependentNoiseDistribution): + self.action_dist.sample_weights(self.chol, batch_size=n_envs) + + if isinstance( + self.action_dist, UniversalGaussianDistribution): + self.action_dist.sample_weights( + get_action_dim(self.action_space), batch_size=n_envs) def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: """