diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 51af0c7..afa8d86 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -143,7 +143,7 @@ class UniversalGaussianDistribution(SB3_Distribution): :param action_dim: Dimension of the action space. """ - def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6): + def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6, sde_learn_features=False, full_sde=None): super(UniversalGaussianDistribution, self).__init__() self.action_dim = action_dim self.par_strength = cast_to_enum(neural_strength, Strength) @@ -161,6 +161,10 @@ class UniversalGaussianDistribution(SB3_Distribution): self.gaussian_actions = None self.use_sde = use_sde + self.learn_features = sde_learn_features + + if full_sde != None: + print('[!] Argument full_sde is only provided to remain compatible with vanilla SB3 PPO. It does not serve any function!') assert (self.par_type != ParametrizationType.NONE) == ( self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index 096de75..fe36e49 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -140,10 +140,9 @@ class ActorCriticPolicy(BasePolicy): # Keyword arguments for gSDE distribution if use_sde: add_dist_kwargs = { - "full_std": full_std, - "squash_output": squash_output, - "use_expln": use_expln, - "learn_features": False, + 'use_sde': True, + # "use_expln": use_expln, + # "learn_features": False, } for k in add_dist_kwargs: dist_kwargs[k] = add_dist_kwargs[k] diff --git a/metastable_baselines/sac/policies.py b/metastable_baselines/sac/policies.py index 82eaac4..b816025 100644 --- a/metastable_baselines/sac/policies.py +++ b/metastable_baselines/sac/policies.py @@ -100,18 +100,31 @@ class Actor(BasePolicy): last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim if self.use_sde: - # TODO: Port to UGD - self.action_dist = StateDependentNoiseDistribution( - action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True - ) + add_dist_kwargs = { + 'use_sde': True, + # "use_expln": use_expln, + # "learn_features": False, + } + for k in add_dist_kwargs: + dist_kwargs[k] = add_dist_kwargs[k] + + self.action_dist = UniversalGaussianDistribution( + action_dim, **dist_kwargs) self.mu_net, self.chol_net = self.action_dist.proba_distribution_net( - latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init + latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, std_init=math.exp( + self.log_std_init) ) + # self.action_dist = StateDependentNoiseDistribution( + # action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True + # ) + # self.mu_net, self.chol_net = self.action_dist.proba_distribution_net( + # latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init + # ) # Avoid numerical issues by limiting the mean of the Gaussian # to be in [-clip_mean, clip_mean] - if clip_mean > 0.0: - self.mu = nn.Sequential(self.mu, nn.Hardtanh( - min_val=-clip_mean, max_val=clip_mean)) + # if clip_mean > 0.0: + # self.mu = nn.Sequential(self.mu, nn.Hardtanh( + # min_val=-clip_mean, max_val=clip_mean)) else: self.action_dist = UniversalGaussianDistribution( action_dim, **dist_kwargs) @@ -120,9 +133,9 @@ class Actor(BasePolicy): self.log_std_init) ) - #self.action_dist = SquashedDiagGaussianDistribution(action_dim) - #self.mu = nn.Linear(last_layer_dim, action_dim) - #self.log_std = nn.Linear(last_layer_dim, action_dim) + # self.action_dist = SquashedDiagGaussianDistribution(action_dim) + # self.mu = nn.Linear(last_layer_dim, action_dim) + # self.log_std = nn.Linear(last_layer_dim, action_dim) def _get_constructor_parameters(self) -> Dict[str, Any]: data = super()._get_constructor_parameters()