Allow reduced latent sde dim

This commit is contained in:
Dominik Moritz Roth 2023-01-27 13:34:28 +01:00
parent f421dc2ab5
commit ffbf2b3fe5
3 changed files with 10 additions and 2 deletions

View File

@ -348,6 +348,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
latent_sde = latent_sde[..., -self.latent_sde_dim:]
latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
# Default case: only one exploration matrix
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):

View File

@ -100,6 +100,7 @@ class ActorCriticPolicy(BasePolicy):
optimizer_kwargs: Optional[Dict[str, Any]] = None,
dist_kwargs: Optional[Dict[str, Any]] = None,
sqrt_induced_gaussian: bool = False,
latent_dim_sde=None,
):
if optimizer_kwargs is None:
@ -155,6 +156,7 @@ class ActorCriticPolicy(BasePolicy):
self.dist_kwargs = dist_kwargs
self.sqrt_induced_gaussian = sqrt_induced_gaussian
self.latent_dim_sde = latent_dim_sde
# Action distribution
self.action_dist = make_proba_distribution(
@ -244,7 +246,7 @@ class ActorCriticPolicy(BasePolicy):
latent_dim=latent_dim_pi)
elif isinstance(self.action_dist, UniversalGaussianDistribution):
self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, std_init=math.exp(
latent_dim=latent_dim_pi, latent_sde_dim=self.latent_dim_sde or latent_dim_pi, std_init=math.exp(
self.log_std_init)
)
else:

View File

@ -69,6 +69,7 @@ class Actor(BasePolicy):
clip_mean: float = 2.0,
normalize_images: bool = True,
dist_kwargs={},
latent_dim_sde=None,
):
super().__init__(
observation_space,
@ -90,6 +91,8 @@ class Actor(BasePolicy):
self.full_std = full_std
self.clip_mean = clip_mean
self.latent_dim_sde = latent_dim_sde
if sde_net_arch is not None:
warnings.warn(
"sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
@ -111,7 +114,7 @@ class Actor(BasePolicy):
self.action_dist = UniversalGaussianDistribution(
action_dim, **dist_kwargs)
self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, std_init=math.exp(
latent_dim=last_layer_dim, latent_sde_dim=self.latent_dim_sde or last_layer_dim, std_init=math.exp(
self.log_std_init)
)
# self.action_dist = StateDependentNoiseDistribution(
@ -274,6 +277,7 @@ class SACPolicy(BasePolicy):
n_critics: int = 2,
share_features_extractor: bool = True,
dist_kwargs={},
latent_dim_sde=None,
):
super().__init__(
observation_space,
@ -316,6 +320,7 @@ class SACPolicy(BasePolicy):
}
self.actor_kwargs.update(sde_kwargs)
self.actor.kwargs.update({'latent_dim_sde': latent_dim_sde})
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update(
{