Implemented sqrt-induced-gaussian for W2-Projection

This commit is contained in:
Dominik Moritz Roth 2022-08-06 14:46:42 +02:00
parent fcd9953b37
commit 508ebf51f0
2 changed files with 28 additions and 0 deletions

View File

@ -206,6 +206,21 @@ class UniversalGaussianDistribution(SB3_Distribution):
return mean_actions, chol
def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
    """
    Create the distribution given its parameters (mean, cov_sqrt).

    The covariance is induced from the given square root as
    ``cov = cov_sqrt^T @ cov_sqrt``; its Cholesky factor is then handed to
    :meth:`proba_distribution`. The square root itself is kept on the
    instance as ``self.cov_sqrt``.

    :param mean_actions: Mean of the distribution; forwarded unchanged.
    :param cov_sqrt: Square root of the covariance. The last two dimensions
        are treated as the matrix dimensions, any leading dimensions as
        batch dimensions.
    :param latent_pi: Forwarded unchanged to :meth:`proba_distribution`.
    :return: Whatever :meth:`proba_distribution` returns for
        ``(mean_actions, chol, latent_pi)``.
    """
    # Transpose only the two trailing (matrix) dimensions. ``Tensor.T``
    # reverses *all* dimensions, which is wrong for batched input such as
    # (batch, dim, dim) and is deprecated by PyTorch for ndim != 2.
    cov = cov_sqrt.transpose(-2, -1) @ cov_sqrt
    chol = th.linalg.cholesky(cov)

    # Stored only after the factorisation succeeded, so a failed call does
    # not leave a stale/invalid square root on the instance.
    self.cov_sqrt = cov_sqrt

    return self.proba_distribution(mean_actions, chol, latent_pi)
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
"""
Create the distribution given its parameters (mean, chol)

View File

@ -35,6 +35,8 @@ from stable_baselines3.common.torch_layers import (
NatureCNN,
)
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
@ -286,6 +288,17 @@ class ActorCriticPolicy(BasePolicy):
"""
mean_actions = self.action_net(latent_pi)
if isinstance(self.projection, WassersteinProjectionLayer):
if isinstance(self.action_dist, UniversalGaussianDistribution):
cov_sqrt = self.chol_net(latent_pi)
dist = self.action_dist.proba_distribution_from_sqrt(
mean_actions, cov_sqrt, latent_pi)
self.chol = dist.chol
return dist
else:
raise Exception(
'Need to use UniversalGaussianDistribution to use WassersteinProjection (uses sqrt-induced-cov)')
if isinstance(self.action_dist, DiagGaussianDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std)
elif isinstance(self.action_dist, CategoricalDistribution):