Fixed SDE: sampling had dimension mismatches

This commit is contained in:
Dominik Moritz Roth 2022-08-14 20:09:10 +02:00
parent e1c59cffd0
commit 28d0c609bc
3 changed files with 16 additions and 10 deletions

View File

@ -249,7 +249,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
self.distribution.cov_sqrt = cov_sqrt self.distribution.cov_sqrt = cov_sqrt
return self return self
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: nn.Module) -> "UniversalGaussianDistribution": def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: th.Tensor) -> "UniversalGaussianDistribution":
""" """
Create the distribution given its parameters (mean, chol) Create the distribution given its parameters (mean, chol)
@ -300,12 +300,18 @@ class UniversalGaussianDistribution(SB3_Distribution):
return self.distribution.entropy() return self.distribution.entropy()
def sample(self) -> th.Tensor: def sample(self) -> th.Tensor:
if self.use_sde:
return self._sample_sde()
else:
return self._sample_normal()
def _sample_normal(self) -> th.Tensor:
# Reparametrization trick to pass gradients # Reparametrization trick to pass gradients
sample = self.distribution.rsample() sample = self.distribution.rsample()
self.gaussian_actions = sample self.gaussian_actions = sample
return self.prob_squashing_type.apply(sample) return self.prob_squashing_type.apply(sample)
def sample_sde(self) -> th.Tensor: def _sample_sde(self) -> th.Tensor:
noise = self.get_noise(self._latent_sde) noise = self.get_noise(self._latent_sde)
actions = self.distribution.mean + noise actions = self.distribution.mean + noise
self.gaussian_actions = actions self.gaussian_actions = actions
@ -334,7 +340,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
log_prob = self.log_prob(actions, self.gaussian_actions) log_prob = self.log_prob(actions, self.gaussian_actions)
return actions, log_prob return actions, log_prob
def sample_weights(self, num_dims, batch_size=1): def sample_weights(self, batch_size=1):
num_dims = (self.latent_sde_dim, self.action_dim)
self.weights_dist = Normal(th.zeros(num_dims), th.ones(num_dims)) self.weights_dist = Normal(th.zeros(num_dims), th.ones(num_dims))
# Reparametrization trick to pass gradients # Reparametrization trick to pass gradients
self.exploration_mat = self.weights_dist.rsample() self.exploration_mat = self.weights_dist.rsample()
@ -345,15 +352,16 @@ class UniversalGaussianDistribution(SB3_Distribution):
latent_sde = latent_sde if self.learn_features else latent_sde.detach() latent_sde = latent_sde if self.learn_features else latent_sde.detach()
# # TODO: Good idea? # # TODO: Good idea?
latent_sde = th.nn.functional.normalize(latent_sde, dim=-1) latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
chol = self.distribution.scale_tril
# Default case: only one exploration matrix # Default case: only one exploration matrix
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices): if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
return th.mm(chol, th.mm(latent_sde, self.exploration_mat)) chol = th.diag_embed(self.distribution.stddev)
return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]
chol = self.distribution.scale_tril
# Use batch matrix multiplication for efficient computation # Use batch matrix multiplication for efficient computation
# (batch_size, n_features) -> (batch_size, 1, n_features) # (batch_size, n_features) -> (batch_size, 1, n_features)
latent_sde = latent_sde.unsqueeze(dim=1) latent_sde = latent_sde.unsqueeze(dim=1)
# (batch_size, 1, n_actions) # (batch_size, 1, n_actions)
noise = th.bmm(chol, th.bmm(latent_sde, self.exploration_matrices)) noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
return noise.squeeze(dim=1) return noise.squeeze(dim=1)

View File

@ -203,8 +203,7 @@ class ActorCriticPolicy(BasePolicy):
if isinstance( if isinstance(
self.action_dist, UniversalGaussianDistribution): self.action_dist, UniversalGaussianDistribution):
self.action_dist.sample_weights( self.action_dist.sample_weights(batch_size=n_envs)
get_action_dim(self.action_space), batch_size=n_envs)
def _build_mlp_extractor(self) -> None: def _build_mlp_extractor(self) -> None:
""" """

View File

@ -185,8 +185,7 @@ class Actor(BasePolicy):
if isinstance( if isinstance(
self.action_dist, UniversalGaussianDistribution): self.action_dist, UniversalGaussianDistribution):
self.action_dist.sample_weights( self.action_dist.sample_weights(batch_size=n_envs)
get_action_dim(self.action_space), batch_size=n_envs)
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
""" """