diff --git a/metastable_baselines/sac/policies.py b/metastable_baselines/sac/policies.py
index 6fcbea1..47515dd 100644
--- a/metastable_baselines/sac/policies.py
+++ b/metastable_baselines/sac/policies.py
@@ -1,6 +1,8 @@
 import warnings
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
+import math
+
 import gym
 import torch as th
 from torch import nn
@@ -18,6 +20,8 @@ from stable_baselines3.common.torch_layers import (
 )
 from stable_baselines3.common.type_aliases import Schedule
 
+from ..distributions import UniversalGaussianDistribution
+
 # CAP the standard deviation of the actor
 LOG_STD_MAX = 2
 LOG_STD_MIN = -20
@@ -64,6 +68,7 @@ class Actor(BasePolicy):
         use_expln: bool = False,
         clip_mean: float = 2.0,
         normalize_images: bool = True,
+        dist_kwargs={},
     ):
         super().__init__(
             observation_space,
@@ -86,7 +91,8 @@ class Actor(BasePolicy):
         self.clip_mean = clip_mean
 
         if sde_net_arch is not None:
-            warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
+            warnings.warn(
+                "sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
 
         action_dim = get_action_dim(self.action_space)
         latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
@@ -94,20 +100,29 @@ class Actor(BasePolicy):
         last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
 
         if self.use_sde:
+            # TODO: Port to UGD
             self.action_dist = StateDependentNoiseDistribution(
                 action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
             )
-            self.mu, self.log_std = self.action_dist.proba_distribution_net(
+            self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
                 latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
             )
             # Avoid numerical issues by limiting the mean of the Gaussian
             # to be in [-clip_mean, clip_mean]
             if clip_mean > 0.0:
-                self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
+                self.mu_net = nn.Sequential(self.mu_net, nn.Hardtanh(
+                    min_val=-clip_mean, max_val=clip_mean))
         else:
-            self.action_dist = SquashedDiagGaussianDistribution(action_dim)
-            self.mu = nn.Linear(last_layer_dim, action_dim)
-            self.log_std = nn.Linear(last_layer_dim, action_dim)
+            self.action_dist = UniversalGaussianDistribution(
+                action_dim, **dist_kwargs)
+            self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
+                latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, std_init=math.exp(
+                    self.log_std_init)
+            )
+
+            #self.action_dist = SquashedDiagGaussianDistribution(action_dim)
+            #self.mu = nn.Linear(last_layer_dim, action_dim)
+            #self.log_std = nn.Linear(last_layer_dim, action_dim)
 
     def _get_constructor_parameters(self) -> Dict[str, Any]:
         data = super()._get_constructor_parameters()
@@ -138,8 +153,9 @@ class Actor(BasePolicy):
         :return:
         """
         msg = "get_std() is only available when using gSDE"
-        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
-        return self.action_dist.get_std(self.log_std)
+        assert isinstance(self.action_dist,
+                          StateDependentNoiseDistribution), msg
+        return self.action_dist.get_std(self.chol_net)
 
     def reset_noise(self, batch_size: int = 1) -> None:
         """
@@ -148,8 +164,9 @@ class Actor(BasePolicy):
         :param batch_size:
         """
         msg = "reset_noise() is only available when using gSDE"
-        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
-        self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
+        assert isinstance(self.action_dist,
+                          StateDependentNoiseDistribution), msg
+        self.action_dist.sample_weights(self.chol_net, batch_size=batch_size)
 
     def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """
@@ -161,25 +178,25 @@ class Actor(BasePolicy):
         """
         features = self.extract_features(obs)
         latent_pi = self.latent_pi(features)
-        mean_actions = self.mu(latent_pi)
+        mean_actions = self.mu_net(latent_pi)
 
         if self.use_sde:
-            return mean_actions, self.log_std, dict(latent_sde=latent_pi)
+            return mean_actions, self.chol_net, dict(latent_sde=latent_pi)
         # Unstructured exploration (Original implementation)
-        log_std = self.log_std(latent_pi)
+        chol = self.chol_net(latent_pi)
         # Original Implementation to cap the standard deviation
-        log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
-        return mean_actions, log_std, {}
+        self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+        return mean_actions, self.chol, {}
 
     def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
-        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
+        mean_actions, chol, kwargs = self.get_action_dist_params(obs)
         # Note: the action is squashed
-        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
+        return self.action_dist.actions_from_params(mean_actions, chol, deterministic=deterministic, **kwargs)
 
     def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
-        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
+        mean_actions, chol, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
-        return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
+        return self.action_dist.log_prob_from_params(mean_actions, chol, **kwargs)
 
     def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self(observation, deterministic)
@@ -236,6 +253,7 @@ class SACPolicy(BasePolicy):
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = True,
+        dist_kwargs={},
     ):
         super().__init__(
             observation_space,
@@ -267,7 +285,8 @@ class SACPolicy(BasePolicy):
         self.actor_kwargs = self.net_args.copy()
 
         if sde_net_arch is not None:
-            warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
+            warnings.warn(
+                "sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
 
         sde_kwargs = {
             "use_sde": use_sde,
@@ -275,6 +294,7 @@ class SACPolicy(BasePolicy):
             "use_expln": use_expln,
             "clip_mean": clip_mean,
         }
+
         self.actor_kwargs.update(sde_kwargs)
         self.critic_kwargs = self.net_args.copy()
         self.critic_kwargs.update(
@@ -289,17 +309,22 @@ class SACPolicy(BasePolicy):
         self.critic, self.critic_target = None, None
         self.share_features_extractor = share_features_extractor
 
+        self.dist_kwargs = dist_kwargs
+
         self._build(lr_schedule)
 
     def _build(self, lr_schedule: Schedule) -> None:
         self.actor = self.make_actor()
-        self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
+        self.actor.optimizer = self.optimizer_class(
+            self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
 
         if self.share_features_extractor:
-            self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
+            self.critic = self.make_critic(
+                features_extractor=self.actor.features_extractor)
             # Do not optimize the shared features extractor with the critic loss
             # otherwise, there are gradient computation issues
-            critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
+            critic_parameters = [param for name, param in self.critic.named_parameters(
+            ) if "features_extractor" not in name]
         else:
             # Create a separate features extractor for the critic
             # this requires more memory and computation
@@ -310,7 +335,8 @@ class SACPolicy(BasePolicy):
             self.critic_target = self.make_critic(features_extractor=None)
             self.critic_target.load_state_dict(self.critic.state_dict())
 
-        self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
+        self.critic.optimizer = self.optimizer_class(
+            critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
 
         # Target networks should always be in eval mode
         self.critic_target.set_training_mode(False)
@@ -327,7 +353,8 @@ class SACPolicy(BasePolicy):
                 use_expln=self.actor_kwargs["use_expln"],
                 clip_mean=self.actor_kwargs["clip_mean"],
                 n_critics=self.critic_kwargs["n_critics"],
-                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
+                # dummy lr schedule, not needed for loading policy alone
+                lr_schedule=self._dummy_schedule,
                 optimizer_class=self.optimizer_class,
                 optimizer_kwargs=self.optimizer_kwargs,
                 features_extractor_class=self.features_extractor_class,
@@ -345,11 +372,13 @@ class SACPolicy(BasePolicy):
         self.actor.reset_noise(batch_size=batch_size)
 
     def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
-        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
-        return Actor(**actor_kwargs).to(self.device)
+        actor_kwargs = self._update_features_extractor(
+            self.actor_kwargs, features_extractor)
+        return Actor(**actor_kwargs, dist_kwargs=self.dist_kwargs).to(self.device)
 
     def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
-        critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
+        critic_kwargs = self._update_features_extractor(
+            self.critic_kwargs, features_extractor)
         return ContinuousCritic(**critic_kwargs).to(self.device)
 
     def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
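
Note for reviewers (not part of the patch): the rewritten Actor only relies on the distribution surface shown below. This is a sketch reconstructed from the call sites in the hunks above, assuming an SB3-style distribution API; the real UniversalGaussianDistribution lives in metastable_baselines/distributions and may differ.

# Sketch only: reconstructed from how Actor calls the distribution in this patch.
# Names, signatures and return types are assumptions, not the actual class.
from typing import Any, Tuple

import torch as th
from torch import nn


class UniversalGaussianDistributionLike:
    def __init__(self, action_dim: int, **dist_kwargs: Any) -> None: ...

    def proba_distribution_net(
        self, latent_dim: int, latent_sde_dim: int, std_init: float
    ) -> Tuple[nn.Module, nn.Module]:
        """Return (mu_net, chol_net): a mean head and a scale/Cholesky head."""
        ...

    def actions_from_params(
        self, mean_actions: th.Tensor, chol: th.Tensor, deterministic: bool = False, **kwargs: Any
    ) -> th.Tensor: ...

    def log_prob_from_params(
        self, mean_actions: th.Tensor, chol: th.Tensor, **kwargs: Any
    ) -> Tuple[th.Tensor, th.Tensor]: ...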
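
Usage sketch (also not part of the patch): dist_kwargs flows from SACPolicy through make_actor() into Actor and finally into UniversalGaussianDistribution. Assuming the package keeps SB3's convention of forwarding policy_kwargs to the policy constructor, and that it exposes an SB3-style SAC class under metastable_baselines.sac (an assumption, not shown in this diff), the new plumbing could be exercised like this; the contents of dist_kwargs are placeholders whose valid keys are defined by UniversalGaussianDistribution.

# Illustrative only; the SAC import path and the empty dist_kwargs are assumptions.
from metastable_baselines.sac import SAC

model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    # Forwarded as SACPolicy(dist_kwargs=...) -> Actor(dist_kwargs=...)
    # -> UniversalGaussianDistribution(action_dim, **dist_kwargs).
    policy_kwargs=dict(dist_kwargs={}),
    verbose=1,
)
model.learn(total_timesteps=1_000)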