Fixing bugs with w2 and sqrt_induced_gaussian
This commit is contained in:
parent
802094a50f
commit
75d73049b4
@ -184,6 +184,16 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
def new_dist_like_me_from_sqrt(self, mean: th.Tensor, cov_sqrt: th.Tensor):
    """
    Build a new distribution like this one from a mean and a covariance square root.

    The square root is first converted to a Cholesky factor via
    ``self._sqrt_to_chol`` and handed to ``self.new_dist_like_me``. The
    original ``cov_sqrt`` is then stored on the returned object and on its
    ``distribution`` attribute so it can be read back later.

    :param mean: mean passed through to ``new_dist_like_me``
    :param cov_sqrt: covariance square root the new distribution is induced from
    :return: the object returned by ``new_dist_like_me``, annotated with ``cov_sqrt``
    """
    induced = self.new_dist_like_me(mean, self._sqrt_to_chol(cov_sqrt))

    # Keep the square root on the wrapper first, then on the wrapped distribution.
    induced.cov_sqrt = cov_sqrt
    induced.distribution.cov_sqrt = cov_sqrt

    return induced
|
||||||
|
|
||||||
def proba_distribution_net(self, latent_dim: int, latent_sde_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
|
def proba_distribution_net(self, latent_dim: int, latent_sde_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
|
||||||
"""
|
"""
|
||||||
Create the layers and parameter that represent the distribution:
|
Create the layers and parameter that represent the distribution:
|
||||||
@ -206,6 +216,22 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
return mean_actions, chol
|
return mean_actions, chol
|
||||||
|
|
||||||
|
def _sqrt_to_chol(self, cov_sqrt):
|
||||||
|
vec = False
|
||||||
|
if len(cov_sqrt.shape) == 2:
|
||||||
|
vec = True
|
||||||
|
|
||||||
|
if vec:
|
||||||
|
cov_sqrt = th.diag_embed(cov_sqrt)
|
||||||
|
|
||||||
|
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
||||||
|
chol = th.linalg.cholesky(cov)
|
||||||
|
|
||||||
|
if vec:
|
||||||
|
chol = th.diagonal(chol, dim1=-2, dim2=-1)
|
||||||
|
|
||||||
|
return chol
|
||||||
|
|
||||||
def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
||||||
"""
|
"""
|
||||||
Create the distribution given its parameters (mean, cov_sqrt)
|
Create the distribution given its parameters (mean, cov_sqrt)
|
||||||
@ -214,12 +240,11 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:param cov_sqrt:
|
:param cov_sqrt:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
cov = cov_sqrt.T @ cov_sqrt
|
|
||||||
chol = th.linalg.cholesky(cov)
|
|
||||||
|
|
||||||
self.cov_sqrt = cov_sqrt
|
self.cov_sqrt = cov_sqrt
|
||||||
|
chol = self._sqrt_to_chol(cov_sqrt)
|
||||||
return self.proba_distribution(mean_actions, chol, latent_pi)
|
self.proba_distribution(mean_actions, chol, latent_pi)
|
||||||
|
self.distribution.cov_sqrt = cov_sqrt
|
||||||
|
return self
|
||||||
|
|
||||||
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
||||||
"""
|
"""
|
||||||
|
@ -19,17 +19,16 @@ def get_mean_and_chol(p: AnyDistribution, expand=False):
|
|||||||
raise Exception('Dist-Type not implemented')
|
raise Exception('Dist-Type not implemented')
|
||||||
|
|
||||||
|
|
||||||
def get_mean_and_sqrt(p: UniversalGaussianDistribution):
|
def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False):
|
||||||
if isinstance(p, UniversalGaussianDistribution):
|
if not hasattr(p, 'cov_sqrt'):
|
||||||
if not hasattr(p, 'cov_sqrt'):
|
raise Exception(
|
||||||
raise Exception(
|
'Distribution was not induced from sqrt. On-demand calculation is not supported.')
|
||||||
'Distribution was not induced from sqrt. On-demand calculation is not supported.')
|
|
||||||
else:
|
|
||||||
mean, chol = get_mean_and_chol(p)
|
|
||||||
sqrt_cov = p.cov_sqrt
|
|
||||||
return mean, sqrt_cov
|
|
||||||
else:
|
else:
|
||||||
raise Exception('Dist-Type not implemented')
|
mean, chol = get_mean_and_chol(p, expand=False)
|
||||||
|
sqrt_cov = p.cov_sqrt
|
||||||
|
if expand and len(sqrt_cov.shape) == 2:
|
||||||
|
sqrt_cov = th.diag_embed(sqrt_cov)
|
||||||
|
return mean, sqrt_cov
|
||||||
|
|
||||||
|
|
||||||
def get_cov(p: AnyDistribution):
|
def get_cov(p: AnyDistribution):
|
||||||
@ -97,3 +96,32 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
|||||||
return p_out
|
return p_out
|
||||||
else:
|
else:
|
||||||
raise Exception('Dist-Type not implemented')
|
raise Exception('Dist-Type not implemented')
|
||||||
|
|
||||||
|
|
||||||
|
def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor):
    """
    Build a distribution like ``orig_p`` from a mean and a covariance square root.

    The square root is converted with ``_sqrt_to_chol`` and the result is
    passed to ``new_dist_like``. The original ``cov_sqrt`` is attached to
    the returned object and, when it has a ``distribution`` attribute, to
    that inner object as well.

    :param orig_p: distribution used as the template for ``new_dist_like``
    :param mean: mean of the new distribution
    :param cov_sqrt: covariance square root the new distribution is induced from
    :return: the object produced by ``new_dist_like``, annotated with ``cov_sqrt``
    """
    induced = new_dist_like(orig_p, mean, _sqrt_to_chol(cov_sqrt))

    induced.cov_sqrt = cov_sqrt
    # Only wrapper-style results carry an inner ``distribution`` to annotate.
    if hasattr(induced, 'distribution'):
        induced.distribution.cov_sqrt = cov_sqrt

    return induced
|
||||||
|
|
||||||
|
|
||||||
|
def _sqrt_to_chol(cov_sqrt):
|
||||||
|
vec = False
|
||||||
|
if len(cov_sqrt.shape) == 2:
|
||||||
|
vec = True
|
||||||
|
|
||||||
|
if vec:
|
||||||
|
cov_sqrt = th.diag_embed(cov_sqrt)
|
||||||
|
|
||||||
|
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
|
||||||
|
chol = th.linalg.cholesky(cov)
|
||||||
|
|
||||||
|
if vec:
|
||||||
|
chol = th.diagonal(chol, dim1=-2, dim2=-1)
|
||||||
|
|
||||||
|
return chol
|
||||||
|
@ -99,6 +99,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
dist_kwargs: Optional[Dict[str, Any]] = None,
|
dist_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
sqrt_induced_gaussian=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
if optimizer_kwargs is None:
|
if optimizer_kwargs is None:
|
||||||
@ -152,6 +153,8 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
self.use_sde = use_sde
|
self.use_sde = use_sde
|
||||||
self.dist_kwargs = dist_kwargs
|
self.dist_kwargs = dist_kwargs
|
||||||
|
|
||||||
|
self.sqrt_induced_gaussian = sqrt_induced_gaussian
|
||||||
|
|
||||||
# Action distribution
|
# Action distribution
|
||||||
self.action_dist = make_proba_distribution(
|
self.action_dist = make_proba_distribution(
|
||||||
action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
|
action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
|
||||||
@ -289,18 +292,6 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
"""
|
"""
|
||||||
mean_actions = self.action_net(latent_pi)
|
mean_actions = self.action_net(latent_pi)
|
||||||
|
|
||||||
if isinstance(self.projection, WassersteinProjectionLayer):
|
|
||||||
if isinstance(self.action_dist, UniversalGaussianDistribution):
|
|
||||||
cov_sqrt = self.chol_net(latent_pi)
|
|
||||||
dist = self.action_dist.proba_distribution_from_sqrt(
|
|
||||||
mean_actions, cov_sqrt, latent_pi)
|
|
||||||
mean, chol = get_mean_and_chol(dist, expand=False)
|
|
||||||
self.chol = chol
|
|
||||||
return dist
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
'Need to use UniversalGaussianDistribution to use WassersteinProjection (uses sqrt-induced-cov)')
|
|
||||||
|
|
||||||
if isinstance(self.action_dist, DiagGaussianDistribution):
|
if isinstance(self.action_dist, DiagGaussianDistribution):
|
||||||
return self.action_dist.proba_distribution(mean_actions, self.log_std)
|
return self.action_dist.proba_distribution(mean_actions, self.log_std)
|
||||||
elif isinstance(self.action_dist, CategoricalDistribution):
|
elif isinstance(self.action_dist, CategoricalDistribution):
|
||||||
@ -315,9 +306,17 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
|
||||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
|
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
|
||||||
elif isinstance(self.action_dist, UniversalGaussianDistribution):
|
elif isinstance(self.action_dist, UniversalGaussianDistribution):
|
||||||
chol = self.chol_net(latent_pi)
|
if self.sqrt_induced_gaussian:
|
||||||
self.chol = chol
|
cov_sqrt = self.chol_net(latent_pi)
|
||||||
return self.action_dist.proba_distribution(mean_actions, chol, latent_pi)
|
dist = self.action_dist.proba_distribution_from_sqrt(
|
||||||
|
mean_actions, cov_sqrt, latent_pi)
|
||||||
|
mean, chol = get_mean_and_chol(dist, expand=False)
|
||||||
|
self.chol = chol
|
||||||
|
return dist
|
||||||
|
else:
|
||||||
|
chol = self.chol_net(latent_pi)
|
||||||
|
self.chol = chol
|
||||||
|
return self.action_dist.proba_distribution(mean_actions, chol, latent_pi)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid action distribution")
|
raise ValueError("Invalid action distribution")
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ from stable_baselines3.common.callbacks import BaseCallback
|
|||||||
from stable_baselines3.common.utils import obs_as_tensor
|
from stable_baselines3.common.utils import obs_as_tensor
|
||||||
from stable_baselines3.common.vec_env import VecNormalize
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
||||||
from ..misc.distTools import new_dist_like
|
from ..misc.distTools import new_dist_like, new_dist_like_from_sqrt
|
||||||
|
|
||||||
from ..projections.base_projection_layer import BaseProjectionLayer
|
from ..projections.base_projection_layer import BaseProjectionLayer
|
||||||
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
||||||
@ -133,7 +133,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
use_sde=use_sde,
|
use_sde=use_sde,
|
||||||
sde_sample_freq=sde_sample_freq,
|
sde_sample_freq=sde_sample_freq,
|
||||||
tensorboard_log=tensorboard_log,
|
tensorboard_log=tensorboard_log,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs |
|
||||||
|
{'sqrt_induced_gaussian': isinstance(
|
||||||
|
projection, WassersteinProjectionLayer)},
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
device=device,
|
device=device,
|
||||||
create_eval_env=create_eval_env,
|
create_eval_env=create_eval_env,
|
||||||
@ -245,8 +247,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
latent_pi, latent_vf = pol.mlp_extractor(features)
|
latent_pi, latent_vf = pol.mlp_extractor(features)
|
||||||
p = pol._get_action_dist_from_latent(latent_pi)
|
p = pol._get_action_dist_from_latent(latent_pi)
|
||||||
p_dist = p.distribution
|
p_dist = p.distribution
|
||||||
q_dist = new_dist_like(
|
if isinstance(self.projection, WassersteinProjectionLayer):
|
||||||
p_dist, rollout_data.means, rollout_data.chols)
|
q_dist = new_dist_like_from_sqrt(
|
||||||
|
p_dist, rollout_data.means, rollout_data.chols)
|
||||||
|
else:
|
||||||
|
q_dist = new_dist_like(
|
||||||
|
p_dist, rollout_data.means, rollout_data.chols)
|
||||||
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
||||||
if isinstance(p_dist, th.distributions.Normal):
|
if isinstance(p_dist, th.distributions.Normal):
|
||||||
# Normal uses a weird mapping from dimensions into batch_shape
|
# Normal uses a weird mapping from dimensions into batch_shape
|
||||||
|
Loading…
Reference in New Issue
Block a user