Implemented sqrt-induced-gaussian for W2-Projection
This commit is contained in:
parent
fcd9953b37
commit
508ebf51f0
@ -206,6 +206,21 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
||||
return mean_actions, chol
|
||||
|
||||
def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
||||
"""
|
||||
Create the distribution given its parameters (mean, cov_sqrt)
|
||||
|
||||
:param mean_actions:
|
||||
:param cov_sqrt:
|
||||
:return:
|
||||
"""
|
||||
cov = cov_sqrt.T @ cov_sqrt
|
||||
chol = th.linalg.cholesky(cov)
|
||||
|
||||
self.cov_sqrt = cov_sqrt
|
||||
|
||||
return self.proba_distribution(mean_actions, chol, latent_pi)
|
||||
|
||||
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
||||
"""
|
||||
Create the distribution given its parameters (mean, chol)
|
||||
|
@ -35,6 +35,8 @@ from stable_baselines3.common.torch_layers import (
|
||||
NatureCNN,
|
||||
)
|
||||
|
||||
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
|
||||
|
||||
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
|
||||
|
||||
|
||||
@ -286,6 +288,17 @@ class ActorCriticPolicy(BasePolicy):
|
||||
"""
|
||||
mean_actions = self.action_net(latent_pi)
|
||||
|
||||
if isinstance(self.projection, WassersteinProjectionLayer):
|
||||
if isinstance(self.action_dist, UniversalGaussianDistribution):
|
||||
cov_sqrt = self.chol_net(latent_pi)
|
||||
dist = self.action_dist.proba_distribution_from_sqrt(
|
||||
mean_actions, cov_sqrt, latent_pi)
|
||||
self.chol = dist.chol
|
||||
return dist
|
||||
else:
|
||||
raise Exception(
|
||||
'Need to use UniversalGaussianDistribution to use WassersteinProjection (uses sqrt-induced-cov)')
|
||||
|
||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std)
|
||||
elif isinstance(self.action_dist, CategoricalDistribution):
|
||||
|
Loading…
Reference in New Issue
Block a user