Allow reduced latent sde dim
This commit is contained in:
parent
f421dc2ab5
commit
ffbf2b3fe5
@ -348,6 +348,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
|
def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
|
||||||
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
|
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
|
||||||
|
latent_sde = latent_sde[..., -self.latent_sde_dim:]
|
||||||
latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
|
latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
|
||||||
# Default case: only one exploration matrix
|
# Default case: only one exploration matrix
|
||||||
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
|
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
|
||||||
|
@ -100,6 +100,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
dist_kwargs: Optional[Dict[str, Any]] = None,
|
dist_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
sqrt_induced_gaussian: bool = False,
|
sqrt_induced_gaussian: bool = False,
|
||||||
|
latent_dim_sde=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if optimizer_kwargs is None:
|
if optimizer_kwargs is None:
|
||||||
@ -155,6 +156,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
self.dist_kwargs = dist_kwargs
|
self.dist_kwargs = dist_kwargs
|
||||||
|
|
||||||
self.sqrt_induced_gaussian = sqrt_induced_gaussian
|
self.sqrt_induced_gaussian = sqrt_induced_gaussian
|
||||||
|
self.latent_dim_sde = latent_dim_sde
|
||||||
|
|
||||||
# Action distribution
|
# Action distribution
|
||||||
self.action_dist = make_proba_distribution(
|
self.action_dist = make_proba_distribution(
|
||||||
@ -244,7 +246,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
latent_dim=latent_dim_pi)
|
latent_dim=latent_dim_pi)
|
||||||
elif isinstance(self.action_dist, UniversalGaussianDistribution):
|
elif isinstance(self.action_dist, UniversalGaussianDistribution):
|
||||||
self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
|
self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
|
||||||
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, std_init=math.exp(
|
latent_dim=latent_dim_pi, latent_sde_dim=self.latent_dim_sde or latent_dim_pi, std_init=math.exp(
|
||||||
self.log_std_init)
|
self.log_std_init)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -69,6 +69,7 @@ class Actor(BasePolicy):
|
|||||||
clip_mean: float = 2.0,
|
clip_mean: float = 2.0,
|
||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
dist_kwargs={},
|
dist_kwargs={},
|
||||||
|
latent_dim_sde=None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
@ -90,6 +91,8 @@ class Actor(BasePolicy):
|
|||||||
self.full_std = full_std
|
self.full_std = full_std
|
||||||
self.clip_mean = clip_mean
|
self.clip_mean = clip_mean
|
||||||
|
|
||||||
|
self.latent_dim_sde = latent_dim_sde
|
||||||
|
|
||||||
if sde_net_arch is not None:
|
if sde_net_arch is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
"sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
|
||||||
@ -111,7 +114,7 @@ class Actor(BasePolicy):
|
|||||||
self.action_dist = UniversalGaussianDistribution(
|
self.action_dist = UniversalGaussianDistribution(
|
||||||
action_dim, **dist_kwargs)
|
action_dim, **dist_kwargs)
|
||||||
self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
|
self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
|
||||||
latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, std_init=math.exp(
|
latent_dim=last_layer_dim, latent_sde_dim=self.latent_dim_sde or last_layer_dim, std_init=math.exp(
|
||||||
self.log_std_init)
|
self.log_std_init)
|
||||||
)
|
)
|
||||||
# self.action_dist = StateDependentNoiseDistribution(
|
# self.action_dist = StateDependentNoiseDistribution(
|
||||||
@ -274,6 +277,7 @@ class SACPolicy(BasePolicy):
|
|||||||
n_critics: int = 2,
|
n_critics: int = 2,
|
||||||
share_features_extractor: bool = True,
|
share_features_extractor: bool = True,
|
||||||
dist_kwargs={},
|
dist_kwargs={},
|
||||||
|
latent_dim_sde=None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
observation_space,
|
observation_space,
|
||||||
@ -316,6 +320,7 @@ class SACPolicy(BasePolicy):
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.actor_kwargs.update(sde_kwargs)
|
self.actor_kwargs.update(sde_kwargs)
|
||||||
|
self.actor.kwargs.update({'latent_dim_sde': latent_dim_sde})
|
||||||
self.critic_kwargs = self.net_args.copy()
|
self.critic_kwargs = self.net_args.copy()
|
||||||
self.critic_kwargs.update(
|
self.critic_kwargs.update(
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user