Allow reduced latent sde dim
This commit is contained in:
parent
f421dc2ab5
commit
ffbf2b3fe5
@ -348,6 +348,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
||||
def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
|
||||
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
|
||||
latent_sde = latent_sde[..., -self.latent_sde_dim:]
|
||||
latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
|
||||
# Default case: only one exploration matrix
|
||||
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
|
||||
|
@ -100,6 +100,7 @@ class ActorCriticPolicy(BasePolicy):
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
dist_kwargs: Optional[Dict[str, Any]] = None,
|
||||
sqrt_induced_gaussian: bool = False,
|
||||
latent_dim_sde=None,
|
||||
):
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
@ -155,6 +156,7 @@ class ActorCriticPolicy(BasePolicy):
|
||||
self.dist_kwargs = dist_kwargs
|
||||
|
||||
self.sqrt_induced_gaussian = sqrt_induced_gaussian
|
||||
self.latent_dim_sde = latent_dim_sde
|
||||
|
||||
# Action distribution
|
||||
self.action_dist = make_proba_distribution(
|
||||
@ -244,7 +246,7 @@ class ActorCriticPolicy(BasePolicy):
|
||||
latent_dim=latent_dim_pi)
|
||||
elif isinstance(self.action_dist, UniversalGaussianDistribution):
|
||||
self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
|
||||
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, std_init=math.exp(
|
||||
latent_dim=latent_dim_pi, latent_sde_dim=self.latent_dim_sde or latent_dim_pi, std_init=math.exp(
|
||||
self.log_std_init)
|
||||
)
|
||||
else:
|
||||
|
@ -69,6 +69,7 @@ class Actor(BasePolicy):
|
||||
clip_mean: float = 2.0,
|
||||
normalize_images: bool = True,
|
||||
dist_kwargs={},
|
||||
latent_dim_sde=None,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
@ -90,6 +91,8 @@ class Actor(BasePolicy):
|
||||
self.full_std = full_std
|
||||
self.clip_mean = clip_mean
|
||||
|
||||
self.latent_dim_sde = latent_dim_sde
|
||||
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn(
|
||||
"sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||
@ -111,7 +114,7 @@ class Actor(BasePolicy):
|
||||
self.action_dist = UniversalGaussianDistribution(
|
||||
action_dim, **dist_kwargs)
|
||||
self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
|
||||
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, std_init=math.exp(
|
||||
latent_dim=last_layer_dim, latent_sde_dim=self.latent_dim_sde or last_layer_dim, std_init=math.exp(
|
||||
self.log_std_init)
|
||||
)
|
||||
# self.action_dist = StateDependentNoiseDistribution(
|
||||
@ -274,6 +277,7 @@ class SACPolicy(BasePolicy):
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
dist_kwargs={},
|
||||
latent_dim_sde=None,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
@ -316,6 +320,7 @@ class SACPolicy(BasePolicy):
|
||||
}
|
||||
|
||||
self.actor_kwargs.update(sde_kwargs)
|
||||
self.actor.kwargs.update({'latent_dim_sde': latent_dim_sde})
|
||||
self.critic_kwargs = self.net_args.copy()
|
||||
self.critic_kwargs.update(
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user