Fixed SDE: sampling had dimension mismatches
parent e1c59cffd0 · commit 28d0c609bc
@@ -249,7 +249,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.distribution.cov_sqrt = cov_sqrt
         return self

-    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: nn.Module) -> "UniversalGaussianDistribution":
+    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: th.Tensor) -> "UniversalGaussianDistribution":
         """
         Create the distribution given its parameters (mean, chol)

@@ -300,12 +300,18 @@ class UniversalGaussianDistribution(SB3_Distribution):
         return self.distribution.entropy()

     def sample(self) -> th.Tensor:
+        if self.use_sde:
+            return self._sample_sde()
+        else:
+            return self._sample_normal()
+
+    def _sample_normal(self) -> th.Tensor:
         # Reparametrization trick to pass gradients
         sample = self.distribution.rsample()
         self.gaussian_actions = sample
         return self.prob_squashing_type.apply(sample)

-    def sample_sde(self) -> th.Tensor:
+    def _sample_sde(self) -> th.Tensor:
         noise = self.get_noise(self._latent_sde)
         actions = self.distribution.mean + noise
         self.gaussian_actions = actions
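Note on the new dispatch in `sample()`: both branches must produce actions of the same shape `(batch, action_dim)`; the plain path draws a fresh `rsample()`, while the SDE path adds state-dependent noise to the mean. A minimal shape sketch (the sizes and the standalone `Normal` below are made up for illustration, not taken from the class):

```python
import torch as th

batch, action_dim = 5, 3
mean = th.zeros(batch, action_dim)

# SDE path: noise comes from get_noise(latent_sde), no fresh rsample()
state_dependent_noise = th.randn(batch, action_dim)  # stand-in for get_noise()
sde_actions = mean + state_dependent_noise

# Plain path: reparameterised sample from the distribution itself
normal_actions = th.distributions.Normal(mean, th.ones_like(mean)).rsample()

assert sde_actions.shape == normal_actions.shape == (batch, action_dim)
```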
@@ -334,7 +340,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
         log_prob = self.log_prob(actions, self.gaussian_actions)
         return actions, log_prob

-    def sample_weights(self, num_dims, batch_size=1):
+    def sample_weights(self, batch_size=1):
+        num_dims = (self.latent_sde_dim, self.action_dim)
         self.weights_dist = Normal(th.zeros(num_dims), th.ones(num_dims))
         # Reparametrization trick to pass gradients
         self.exploration_mat = self.weights_dist.rsample()
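With `num_dims` now derived from `self.latent_sde_dim` and `self.action_dim`, the exploration matrix always matches the shapes `get_noise()` expects. A rough sketch of the resulting shapes (sizes are invented; the batched `exploration_matrices` draw is presumably done in the rest of `sample_weights`, which this hunk does not show):

```python
import torch as th
from torch.distributions import Normal

latent_sde_dim, action_dim, batch_size = 8, 3, 5

num_dims = (latent_sde_dim, action_dim)
weights_dist = Normal(th.zeros(num_dims), th.ones(num_dims))

exploration_mat = weights_dist.rsample()                    # (latent_sde_dim, action_dim)
exploration_matrices = weights_dist.rsample((batch_size,))  # (batch_size, latent_sde_dim, action_dim)

assert exploration_mat.shape == (latent_sde_dim, action_dim)
assert exploration_matrices.shape == (batch_size, latent_sde_dim, action_dim)
```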
@@ -345,15 +352,16 @@ class UniversalGaussianDistribution(SB3_Distribution):
         latent_sde = latent_sde if self.learn_features else latent_sde.detach()
         # # TODO: Good idea?
         latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
-        chol = self.distribution.scale_tril
         # Default case: only one exploration matrix
         if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
-            return th.mm(chol, th.mm(latent_sde, self.exploration_mat))
+            chol = th.diag_embed(self.distribution.stddev)
+            return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]
+        chol = self.distribution.scale_tril
         # Use batch matrix multiplication for efficient computation
         # (batch_size, n_features) -> (batch_size, 1, n_features)
         latent_sde = latent_sde.unsqueeze(dim=1)
         # (batch_size, 1, n_actions)
-        noise = th.bmm(chol, th.bmm(latent_sde, self.exploration_matrices))
+        noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
         return noise.squeeze(dim=1)

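The batched branch is where the dimension mismatch showed up: `self.exploration_matrices` maps latent features to actions, so the Cholesky factor has to be applied on the action side (right-multiplication), not on the left. A self-contained shape check under assumed sizes:

```python
import torch as th

batch, latent_dim, n_actions = 5, 8, 3

latent_sde = th.randn(batch, 1, latent_dim)                    # (batch, 1, n_features)
exploration_matrices = th.randn(batch, latent_dim, n_actions)  # (batch, n_features, n_actions)
chol = th.eye(n_actions).expand(batch, n_actions, n_actions)   # (batch, n_actions, n_actions)

# Old order: th.bmm(chol, th.bmm(latent_sde, exploration_matrices)) tries to
# multiply (n_actions, n_actions) by (1, n_actions) per batch entry -> RuntimeError.
noise = th.bmm(th.bmm(latent_sde, exploration_matrices), chol)  # (batch, 1, n_actions)

assert noise.squeeze(dim=1).shape == (batch, n_actions)
```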
@@ -203,8 +203,7 @@ class ActorCriticPolicy(BasePolicy):

         if isinstance(
                 self.action_dist, UniversalGaussianDistribution):
-            self.action_dist.sample_weights(
-                get_action_dim(self.action_space), batch_size=n_envs)
+            self.action_dist.sample_weights(batch_size=n_envs)

     def _build_mlp_extractor(self) -> None:
         """
@@ -185,8 +185,7 @@ class Actor(BasePolicy):

         if isinstance(
                 self.action_dist, UniversalGaussianDistribution):
-            self.action_dist.sample_weights(
-                get_action_dim(self.action_space), batch_size=n_envs)
+            self.action_dist.sample_weights(batch_size=n_envs)

     def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """