Enabled w2 (can now get sqrt from dist)
This commit is contained in:
parent
508ebf51f0
commit
802094a50f
@ -20,13 +20,14 @@ def get_mean_and_chol(p: AnyDistribution, expand=False):
|
|||||||
|
|
||||||
|
|
||||||
def get_mean_and_sqrt(p: UniversalGaussianDistribution):
|
def get_mean_and_sqrt(p: UniversalGaussianDistribution):
|
||||||
raise Exception('Not yet implemented...')
|
if isinstance(p, UniversalGaussianDistribution):
|
||||||
if isinstance(p, th.distributions.Normal):
|
if not hasattr(p, 'cov_sqrt'):
|
||||||
return p.mean, p.stddev
|
raise Exception(
|
||||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
'Distribution was not induced from sqrt. On-demand calculation is not supported.')
|
||||||
return p.mean, p.scale_tril
|
else:
|
||||||
elif isinstance(p, SB3_Distribution):
|
mean, chol = get_mean_and_chol(p)
|
||||||
return get_mean_and_chol(p.distribution)
|
sqrt_cov = p.cov_sqrt
|
||||||
|
return mean, sqrt_cov
|
||||||
else:
|
else:
|
||||||
raise Exception('Dist-Type not implemented')
|
raise Exception('Dist-Type not implemented')
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ from stable_baselines3.common.torch_layers import (
|
|||||||
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
|
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
|
||||||
|
|
||||||
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
|
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
|
||||||
|
from ..misc.distTools import get_mean_and_chol
|
||||||
|
|
||||||
|
|
||||||
class ActorCriticPolicy(BasePolicy):
|
class ActorCriticPolicy(BasePolicy):
|
||||||
@ -293,7 +294,8 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
cov_sqrt = self.chol_net(latent_pi)
|
cov_sqrt = self.chol_net(latent_pi)
|
||||||
dist = self.action_dist.proba_distribution_from_sqrt(
|
dist = self.action_dist.proba_distribution_from_sqrt(
|
||||||
mean_actions, cov_sqrt, latent_pi)
|
mean_actions, cov_sqrt, latent_pi)
|
||||||
self.chol = dist.chol
|
mean, chol = get_mean_and_chol(dist, expand=False)
|
||||||
|
self.chol = chol
|
||||||
return dist
|
return dist
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
Loading…
Reference in New Issue
Block a user