Enabled w2 (can now get sqrt from dist)

This commit is contained in:
Dominik Moritz Roth 2022-08-06 14:54:59 +02:00
parent 508ebf51f0
commit 802094a50f
2 changed files with 11 additions and 8 deletions

View File

@ -20,13 +20,14 @@ def get_mean_and_chol(p: AnyDistribution, expand=False):
def get_mean_and_sqrt(p: UniversalGaussianDistribution): def get_mean_and_sqrt(p: UniversalGaussianDistribution):
raise Exception('Not yet implemented...') if isinstance(p, UniversalGaussianDistribution):
if isinstance(p, th.distributions.Normal): if not hasattr(p, 'cov_sqrt'):
return p.mean, p.stddev raise Exception(
elif isinstance(p, th.distributions.MultivariateNormal): 'Distribution was not induced from sqrt. On-demand calculation is not supported.')
return p.mean, p.scale_tril else:
elif isinstance(p, SB3_Distribution): mean, chol = get_mean_and_chol(p)
return get_mean_and_chol(p.distribution) sqrt_cov = p.cov_sqrt
return mean, sqrt_cov
else: else:
raise Exception('Dist-Type not implemented') raise Exception('Dist-Type not implemented')

View File

@ -38,6 +38,7 @@ from stable_baselines3.common.torch_layers import (
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
from ..distributions import UniversalGaussianDistribution, make_proba_distribution from ..distributions import UniversalGaussianDistribution, make_proba_distribution
from ..misc.distTools import get_mean_and_chol
class ActorCriticPolicy(BasePolicy): class ActorCriticPolicy(BasePolicy):
@ -293,7 +294,8 @@ class ActorCriticPolicy(BasePolicy):
cov_sqrt = self.chol_net(latent_pi) cov_sqrt = self.chol_net(latent_pi)
dist = self.action_dist.proba_distribution_from_sqrt( dist = self.action_dist.proba_distribution_from_sqrt(
mean_actions, cov_sqrt, latent_pi) mean_actions, cov_sqrt, latent_pi)
self.chol = dist.chol mean, chol = get_mean_and_chol(dist, expand=False)
self.chol = chol
return dist return dist
else: else:
raise Exception( raise Exception(