Implemented sqrt-induced-gaussian for W2-Projection
This commit is contained in:
parent
fcd9953b37
commit
508ebf51f0
@ -206,6 +206,21 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
return mean_actions, chol
|
return mean_actions, chol
|
||||||
|
|
||||||
|
def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
    """
    Create the distribution given its parameters (mean, cov_sqrt)

    The covariance is formed as ``cov_sqrt^T @ cov_sqrt``; its Cholesky
    factor is then handed to ``proba_distribution``. The raw ``cov_sqrt``
    is also stored on ``self.cov_sqrt``.

    :param mean_actions: forwarded unchanged to ``proba_distribution``
    :param cov_sqrt: square-root factor of the covariance. The last two
        dimensions are treated as the matrix; any leading dimensions are
        treated as batch dimensions.
    :param latent_pi: forwarded unchanged to ``proba_distribution``
    :return: whatever ``proba_distribution`` returns for (mean, chol)
    """
    # Transpose only the two trailing (matrix) dims. ``Tensor.T`` would
    # reverse *every* dim, which breaks for batched (B, D, D) input.
    cov = cov_sqrt.transpose(-2, -1) @ cov_sqrt
    # NOTE: th.linalg.cholesky raises if cov is not positive-definite
    # (e.g. when cov_sqrt is singular).
    chol = th.linalg.cholesky(cov)

    self.cov_sqrt = cov_sqrt

    return self.proba_distribution(mean_actions, chol, latent_pi)
|
||||||
|
|
||||||
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
||||||
"""
|
"""
|
||||||
Create the distribution given its parameters (mean, chol)
|
Create the distribution given its parameters (mean, chol)
|
||||||
|
@ -35,6 +35,8 @@ from stable_baselines3.common.torch_layers import (
|
|||||||
NatureCNN,
|
NatureCNN,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
|
||||||
|
|
||||||
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
|
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
|
||||||
|
|
||||||
|
|
||||||
@ -286,6 +288,17 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
"""
|
"""
|
||||||
mean_actions = self.action_net(latent_pi)
|
mean_actions = self.action_net(latent_pi)
|
||||||
|
|
||||||
|
if isinstance(self.projection, WassersteinProjectionLayer):
|
||||||
|
if isinstance(self.action_dist, UniversalGaussianDistribution):
|
||||||
|
cov_sqrt = self.chol_net(latent_pi)
|
||||||
|
dist = self.action_dist.proba_distribution_from_sqrt(
|
||||||
|
mean_actions, cov_sqrt, latent_pi)
|
||||||
|
self.chol = dist.chol
|
||||||
|
return dist
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
'Need to use UniversalGaussianDistribution to use WassersteinProjection (uses sqrt-induced-cov)')
|
||||||
|
|
||||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||||
return self.action_dist.proba_distribution(mean_actions, self.log_std)
|
return self.action_dist.proba_distribution(mean_actions, self.log_std)
|
||||||
elif isinstance(self.action_dist, CategoricalDistribution):
|
elif isinstance(self.action_dist, CategoricalDistribution):
|
||||||
|
Loading…
Reference in New Issue
Block a user