Work on Contextual Covariances
This commit is contained in:
parent
aacacebfc4
commit
e09950b30c
@ -11,6 +11,7 @@ from stable_baselines3.common.distributions import sum_independent_dims
|
|||||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
||||||
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
||||||
|
|
||||||
|
from ..misc.fakeModule import FakeModule
|
||||||
|
|
||||||
# TODO: Full Cov Parameter
|
# TODO: Full Cov Parameter
|
||||||
# TODO: Contextual Cov
|
# TODO: Contextual Cov
|
||||||
@ -22,6 +23,7 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
|
|||||||
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
|
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
|
||||||
# TODO: (Support Squased Dists (tanh))
|
# TODO: (Support Squased Dists (tanh))
|
||||||
|
|
||||||
|
|
||||||
class Strength(Enum):
|
class Strength(Enum):
|
||||||
NONE = 0
|
NONE = 0
|
||||||
SCALAR = 1
|
SCALAR = 1
|
||||||
@ -65,7 +67,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
self.distribution = None
|
self.distribution = None
|
||||||
|
|
||||||
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
|
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
|
||||||
"""
|
"""
|
||||||
Create the layers and parameter that represent the distribution:
|
Create the layers and parameter that represent the distribution:
|
||||||
one output will be the mean of the Gaussian, the other parameter will be the
|
one output will be the mean of the Gaussian, the other parameter will be the
|
||||||
@ -73,50 +75,71 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:return:
|
:return: We return two nn.Modules (mean, pseudo_chol). chol can return a díagonal vector when it would only be a diagonal matrix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO: Rename pseudo_cov to pseudo_chol
|
||||||
|
|
||||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||||
|
|
||||||
if self.par_strength == Strength.NONE:
|
if self.par_strength == Strength.NONE:
|
||||||
if self.cov_strength == Strength.NONE:
|
if self.cov_strength == Strength.NONE:
|
||||||
pseudo_cov = th.ones(self.action_dim) * log_std_init
|
pseudo_cov_par = th.ones(self.action_dim) * log_std_init
|
||||||
elif self.cov_strength == Strength.SCALAR:
|
elif self.cov_strength == Strength.SCALAR:
|
||||||
pseudo_cov = th.ones(self.action_dim) * \
|
pseudo_cov_par = th.ones(self.action_dim) * \
|
||||||
nn.Parameter(log_std_init, requires_grad=True)
|
nn.Parameter(log_std_init, requires_grad=True)
|
||||||
elif self.cov_strength == Strength.DIAG:
|
elif self.cov_strength == Strength.DIAG:
|
||||||
pseudo_cov = nn.Parameter(
|
pseudo_cov_par = nn.Parameter(
|
||||||
th.ones(self.action_dim) * log_std_init, requires_grad=True)
|
th.ones(self.action_dim) * log_std_init, requires_grad=True)
|
||||||
elif self.cov_strength == Strength.FULL:
|
elif self.cov_strength == Strength.FULL:
|
||||||
# Off-axis init?
|
# TODO: This won't work, need to ensure SPD!
|
||||||
pseudo_cov = nn.Parameter(
|
# TODO: Off-axis init?
|
||||||
|
pseudo_cov_par = nn.Parameter(
|
||||||
th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
|
th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
|
||||||
|
pseudo_cov = FakeModule(pseudo_cov_par)
|
||||||
elif self.par_strength == self.cov_strength:
|
elif self.par_strength == self.cov_strength:
|
||||||
if self.par_strength == Strength.NONE:
|
if self.par_strength == Strength.NONE:
|
||||||
pseudo_cov = th.ones(self.action_dim)
|
pseudo_cov = FakeModule(th.ones(self.action_dim))
|
||||||
elif self.par_strength == Strength.SCALAR:
|
elif self.par_strength == Strength.SCALAR:
|
||||||
|
# TODO: Does it work like this? Test!
|
||||||
std = nn.Linear(latent_dim, 1)
|
std = nn.Linear(latent_dim, 1)
|
||||||
pseudo_cov = th.ones(self.action_dim) * std
|
pseudo_cov = th.ones(self.action_dim) * std
|
||||||
elif self.par_strength == Strength.DIAG:
|
elif self.par_strength == Strength.DIAG:
|
||||||
pseudo_cov = nn.Linear(latent_dim, self.action_dim)
|
pseudo_cov = nn.Linear(latent_dim, self.action_dim)
|
||||||
elif self.par_strength == Strength.FULL:
|
elif self.par_strength == Strength.FULL:
|
||||||
raise Exception("Don't know how to implement yet...")
|
pseudo_cov = self._parameterize_full(latent_dim)
|
||||||
elif self.par_strength > self.cov_strength:
|
elif self.par_strength > self.cov_strength:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'The parameterization can not be stronger than the actual covariance.')
|
'The parameterization can not be stronger than the actual covariance.')
|
||||||
else:
|
else:
|
||||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||||
factor = nn.Linear(latent_dim, 1)
|
pseudo_cov = self._parameterize_hybrid_from_scalar(latent_dim)
|
||||||
par_cov = th.ones(self.action_dim) * \
|
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||||
nn.Parameter(1, requires_grad=True)
|
pseudo_cov = self._parameterize_hybrid_from_diag(latent_dim)
|
||||||
pseudo_cov = par_cov * factor[0]
|
|
||||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'That does not even make any sense...')
|
'That does not even make any sense...')
|
||||||
else:
|
else:
|
||||||
|
raise Exception("This Exception can't happen (I think)")
|
||||||
|
|
||||||
|
return mean_actions, pseudo_cov
|
||||||
|
|
||||||
|
def _parameterize_full(self, latent_dim):
|
||||||
|
# TODO: Implement various techniques for full parameterization (forcing SPD)
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Programmer-was-to-lazy-to-implement-this-Exception')
|
'Programmer-was-to-lazy-to-implement-this-Exception')
|
||||||
|
|
||||||
return mean_actions, pseudo_cov
|
def _parameterize_hybrid_from_diag(self, latent_dim):
|
||||||
|
# TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix)
|
||||||
|
raise Exception(
|
||||||
|
'Programmer-was-to-lazy-to-implement-this-Exception')
|
||||||
|
|
||||||
|
def _parameterize_hybrid_from_scalar(self, latent_dim):
|
||||||
|
factor = nn.Linear(latent_dim, 1)
|
||||||
|
par_cov = th.ones(self.action_dim) * \
|
||||||
|
nn.Parameter(1, requires_grad=True)
|
||||||
|
pseudo_cov = par_cov * factor[0]
|
||||||
|
return pseudo_cov
|
||||||
|
|
||||||
def proba_distribution(self, mean_actions: th.Tensor, pseudo_cov: th.Tensor) -> "UniversalGaussianDistribution":
|
def proba_distribution(self, mean_actions: th.Tensor, pseudo_cov: th.Tensor) -> "UniversalGaussianDistribution":
|
||||||
"""
|
"""
|
||||||
|
20
metastable_baselines/misc/fakeModule.py
Normal file
20
metastable_baselines/misc/fakeModule.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import torch as th
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class FakeModule(nn.Module):
|
||||||
|
"""
|
||||||
|
A torch.nn Module, that drops the input and returns a tensor given at initialization.
|
||||||
|
Gradients can pass through this Module and affect the given tensor.
|
||||||
|
"""
|
||||||
|
# In order to reduce the code required to allow suppor for contextual covariance and parametric covariance, we just channel the parametric covariance through such a FakeModule
|
||||||
|
|
||||||
|
def __init__(self, tensor):
|
||||||
|
super().__init__()
|
||||||
|
self.tensor = tensor
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.tensor
|
||||||
|
|
||||||
|
def string(self):
|
||||||
|
return '<FakeModule: '+str(self.tensor)+'>'
|
Loading…
Reference in New Issue
Block a user