Implemented prior conditioned annealing (untested)

This commit is contained in:
Dominik Moritz Roth 2023-04-25 17:05:34 +02:00
parent 09159774d9
commit 76ea3a6326
4 changed files with 694 additions and 6 deletions

View File

@ -15,7 +15,7 @@ from stable_baselines3.common.distributions import (
BernoulliDistribution, BernoulliDistribution,
CategoricalDistribution, CategoricalDistribution,
MultiCategoricalDistribution, MultiCategoricalDistribution,
# StateDependentNoiseDistribution, # StateDependentNoiseDistribution,
) )
from stable_baselines3.common.distributions import DiagGaussianDistribution from stable_baselines3.common.distributions import DiagGaussianDistribution
@ -23,6 +23,8 @@ from ..misc.tensor_ops import fill_triangular
from ..misc.tanhBijector import TanhBijector from ..misc.tanhBijector import TanhBijector
from ..misc import givens from ..misc import givens
from pca import PCA_Distribution
class Strength(Enum): class Strength(Enum):
NONE = 0 NONE = 0
@ -96,7 +98,7 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
def make_proba_distribution( def make_proba_distribution(
action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None action_space: gym.spaces.Space, use_sde: bool = False, use_pca=False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> SB3_Distribution: ) -> SB3_Distribution:
""" """
Return an instance of Distribution for the correct type of action space Return an instance of Distribution for the correct type of action space
@ -114,7 +116,10 @@ def make_proba_distribution(
if isinstance(action_space, gym.spaces.Box): if isinstance(action_space, gym.spaces.Box):
assert len( assert len(
action_space.shape) == 1, "Error: the action space must be a vector" action_space.shape) == 1, "Error: the action space must be a vector"
return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs) if use_pca:
return PCA_Distribution(get_action_dim(action_space), **dist_kwargs)
else:
return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, gym.spaces.Discrete): elif isinstance(action_space, gym.spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs) return CategoricalDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, gym.spaces.MultiDiscrete): elif isinstance(action_space, gym.spaces.MultiDiscrete):
@ -632,4 +637,5 @@ class CholNet(nn.Module):
return '<CholNet />' return '<CholNet />'
AnyDistribution = Union[SB3_Distribution, UniversalGaussianDistribution] AnyDistribution = Union[SB3_Distribution,
UniversalGaussianDistribution, PCA_Distribution]

View File

@ -0,0 +1,677 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from enum import Enum
import gym
import torch as th
from torch import nn
from torch.distributions import Normal, Independent, MultivariateNormal
from math import pi
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.distributions import sum_independent_dims
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
from stable_baselines3.common.distributions import (
BernoulliDistribution,
CategoricalDistribution,
MultiCategoricalDistribution,
# StateDependentNoiseDistribution,
)
from stable_baselines3.common.distributions import DiagGaussianDistribution
from ..misc.tensor_ops import fill_triangular
from ..misc.tanhBijector import TanhBijector
from ..misc import givens
class Strength(Enum):
NONE = 0
SCALAR = 1
DIAG = 2
FULL = 3
class ParametrizationType(Enum):
NONE = 0
CHOL = 1
SPHERICAL_CHOL = 2
EIGEN = 3
EIGEN_RAW = 4
class EnforcePositiveType(Enum):
# This need to be implemented in this ugly fashion,
# because cloudpickle does not like more complex enums
NONE = 0
SOFTPLUS = 1
ABS = 2
RELU = 3
LOG = 4
def apply(self, x):
# aaaaaa
return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
class ProbSquashingType(Enum):
NONE = 0
TANH = 1
def apply(self, x):
return [nn.Identity(), th.tanh][self.value](x)
def apply_inv(self, x):
return [nn.Identity(), TanhBijector.inverse][self.value](x)
def cast_to_enum(inp, Class):
if isinstance(inp, Enum):
return inp
else:
return Class[inp]
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
allowedEPTs = allowedEPTs or EnforcePositiveType
allowedParStrength = allowedParStrength or Strength
allowedCovStrength = allowedCovStrength or Strength
allowedPTs = allowedPTs or ParametrizationType
allowedPSTs = allowedPSTs or ProbSquashingType
for ps in allowedParStrength:
for cs in allowedCovStrength:
if ps.value > cs.value:
continue
if cs == Strength.NONE:
yield (ps, cs, EnforcePositiveType.NONE, ParametrizationType.NONE)
else:
for ept in allowedEPTs:
if cs == Strength.FULL:
for pt in allowedPTs:
if pt != ParametrizationType.NONE:
yield (ps, cs, ept, pt)
else:
yield (ps, cs, ept, ParametrizationType.NONE)
def make_proba_distribution(
action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> SB3_Distribution:
"""
Return an instance of Distribution for the correct type of action space
:param action_space: the input action space
:param use_sde: Force the use of StateDependentNoiseDistribution
instead of DiagGaussianDistribution
:param dist_kwargs: Keyword arguments to pass to the probability distribution
:return: the appropriate Distribution object
"""
if dist_kwargs is None:
dist_kwargs = {}
dist_kwargs['use_sde'] = use_sde
if isinstance(action_space, gym.spaces.Box):
assert len(
action_space.shape) == 1, "Error: the action space must be a vector"
return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, gym.spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, gym.spaces.MultiDiscrete):
return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
elif isinstance(action_space, gym.spaces.MultiBinary):
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError(
"Error: probability distribution, not implemented for action space"
f"of type {type(action_space)}."
" Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
)
class UniversalGaussianDistribution(SB3_Distribution):
"""
Gaussian distribution with configurable covariance matrix shape and optional contextual parametrization mechanism, for continuous actions.
:param action_dim: Dimension of the action space.
"""
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-3, sde_learn_features=False, sde_latent_softmax=False, use_hybrid=False, hybrid_rex_fac=0.5, use_pca=False):
super(UniversalGaussianDistribution, self).__init__()
self.action_dim = action_dim
self.par_strength = cast_to_enum(neural_strength, Strength)
self.cov_strength = cast_to_enum(cov_strength, Strength)
self.par_type = cast_to_enum(
parameterization_type, ParametrizationType)
self.enforce_positive_type = cast_to_enum(
enforce_positive_type, EnforcePositiveType)
self.prob_squashing_type = cast_to_enum(
prob_squashing_type, ProbSquashingType)
self.epsilon = epsilon
self.distribution = None
self.gaussian_actions = None
self.use_sde = use_sde
self.use_hybrid = use_hybrid
self.hybrid_rex_fac = hybrid_rex_fac
self.learn_features = sde_learn_features
self.sde_latent_softmax = sde_latent_softmax
self.use_pca = use_pca
if self.use_hybrid:
assert self.use_sde, 'use_sde has to be set to use use_hybrid'
assert (self.par_type != ParametrizationType.NONE) == (
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
raise Exception(
'You need to specify an enforce_positive_type for spherical_cholesky')
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
p = self.distribution
if isinstance(p, Independent):
if p.stddev.shape != chol.shape:
chol = th.diagonal(chol, dim1=1, dim2=2)
np = Independent(Normal(mean, chol), 1)
elif isinstance(p, MultivariateNormal):
np = MultivariateNormal(mean, scale_tril=chol)
new = UniversalGaussianDistribution(self.action_dim, use_sde=self.use_sde, neural_strength=self.par_strength, cov_strength=self.cov_strength,
parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type, epsilon=self.epsilon, sde_learn_features=self.learn_features)
new.distribution = np
return new
def new_dist_like_me_from_sqrt(self, mean: th.Tensor, cov_sqrt: th.Tensor):
chol = self._sqrt_to_chol(cov_sqrt)
new = self.new_dist_like_me(mean, chol)
new.cov_sqrt = cov_sqrt
new.distribution.cov_sqrt = cov_sqrt
return new
def proba_distribution_net(self, latent_dim: int, latent_sde_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
"""
Create the layers and parameter that represent the distribution:
one output will be the mean of the Gaussian, the other parameter will be the
standard deviation
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param std_init: Initial value for the standard deviation
:return: We return two nn.Modules (mean, chol). chol can be a vector if the full chol would be a diagonal.
"""
assert std_init >= 0.0, "std can not be initialized to a negative value."
self.latent_sde_dim = latent_sde_dim
mean_actions = nn.Linear(latent_dim, self.action_dim)
chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type, self.epsilon)
if self.use_sde:
self.sample_weights(self.action_dim)
return mean_actions, chol
def _sqrt_to_chol(self, cov_sqrt):
vec = self.cov_strength != Strength.FULL
batch_dims = len(cov_sqrt.shape) - 2 + 1*vec
if vec:
cov_sqrt = th.diag_embed(cov_sqrt)
if batch_dims == 0:
cov = th.mm(cov_sqrt.mT, cov_sqrt)
cov += th.eye(cov.shape[-1])*(self.epsilon)
else:
cov = th.bmm(cov_sqrt.mT, cov_sqrt)
cov += th.eye(cov.shape[-1]).expand(cov.shape)*(self.epsilon)
chol = th.linalg.cholesky(cov)
if vec:
chol = th.diagonal(chol, dim1=-2, dim2=-1)
return chol
def proba_distribution_from_sqrt(self, mean_actions: th.Tensor, cov_sqrt: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
"""
Create the distribution given its parameters (mean, cov_sqrt)
:param mean_actions:
:param cov_sqrt:
:return:
"""
self.cov_sqrt = cov_sqrt
chol = self._sqrt_to_chol(cov_sqrt)
self.proba_distribution(mean_actions, chol, latent_pi)
self.distribution.cov_sqrt = cov_sqrt
return self
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: th.Tensor) -> "UniversalGaussianDistribution":
"""
Create the distribution given its parameters (mean, chol)
:param mean_actions:
:param chol:
:return:
"""
if self.use_sde:
self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
# TODO: Change variance of dist to include sde-spread
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
self.distribution = Independent(Normal(mean_actions, chol), 1)
elif self.cov_strength in [Strength.FULL]:
self.distribution = MultivariateNormal(
mean_actions, scale_tril=chol)
if self.distribution == None:
raise Exception('Unable to create torch distribution')
return self
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
"""
Get the log probabilities of actions according to the distribution.
Note that you must first call the ``proba_distribution()`` method.
:param actions:
:return:
"""
if self.prob_squashing_type == ProbSquashingType.NONE:
log_prob = self.distribution.log_prob(actions)
return log_prob
if gaussian_actions is None:
# It will be clipped to avoid NaN when inversing tanh
gaussian_actions = self.prob_squashing_type.apply_inv(actions)
log_prob = self.distribution.log_prob(gaussian_actions)
if self.prob_squashing_type == ProbSquashingType.TANH:
log_prob -= th.sum(th.log(1 - actions **
2 + self.epsilon), dim=1)
return log_prob
raise Exception()
def entropy(self) -> th.Tensor:
# TODO: This will return incorrect results when using prob-squashing
return self.distribution.entropy()
def _init_pca(self):
pass
def _apply_pca(self, mu, chol):
return mu, chol
def sample(self) -> th.Tensor:
if self.use_hybrid:
return self._sample_hybrid()
elif self.use_sde:
return self._sample_sde()
else:
return self._sample_normal()
def _standard_normal(shape, dtype, device):
if th._C._get_tracing_state():
# [JIT WORKAROUND] lack of support for .normal_()
return th.normal(th.zeros(shape, dtype=dtype, device=device),
th.ones(shape, dtype=dtype, device=device))
return th.empty(shape, dtype=dtype, device=device).normal_()
def _batch_mv(bmat, bvec):
return th.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
def _rsample(self, mu=None, chol=None):
if mu == None:
mu = self.distribution.loc
if chol == None:
if isinstance(self.distribution, Independent):
chol = self.distribution.scale
elif isinstance(self.distribution, MultivariateNormal):
chol = self.distribution._unbroadcasted_scale_tril
if self.use_pca:
assert isinstance(
self.distribution, Independent), 'PCA not avaible for full covariances'
mu, chol = self._apply_pca(mu, chol)
sample_shape = th.size()
shape = self.distribution._extended_shape(sample_shape)
if isinstance(self.distribution, Independent):
eps = self._standard_normal(
shape, dtype=self.loc.dtype, device=self.loc.device)
return mu + eps * chol
elif isinstance(self.distribution, MultivariateNormal):
eps = self._standard_normal(
shape, dtype=self.loc.dtype, device=self.loc.device)
return mu + self._batch_mv(chol, eps)
def _sample_normal(self) -> th.Tensor:
# Reparametrization trick to pass gradients
sample = self.distribution.rsample()
self.gaussian_actions = sample
return self.prob_squashing_type.apply(sample)
def _sample_sde(self) -> th.Tensor:
# More Reparametrization trick to pass gradients
noise = self.get_noise(self._latent_sde)
actions = self.distribution.mean + noise
self.gaussian_actions = actions
return self.prob_squashing_type.apply(actions)
def _sample_hybrid(self) -> th.Tensor:
f = self.hybrid_rex_fac
rex_sample = self.distribution.rsample()
noise = self.get_noise(self._latent_sde)
sde_sample = self.distribution.mean + noise
actions = rex_sample*f + sde_sample*(1-f)
self.gaussian_actions = actions
return self.prob_squashing_type.apply(actions)
def mode(self) -> th.Tensor:
mode = self.distribution.mean
self.gaussian_actions = mode
return self.prob_squashing_type.apply(mode)
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_sde=None) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std, latent_sde=latent_sde)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde=None) -> Tuple[th.Tensor, th.Tensor]:
"""
Compute the log probability of taking an action
given the distribution parameters.
:param mean_actions:
:param log_std:
:return:
"""
actions = self.actions_from_params(
mean_actions, log_std, latent_sde=latent_sde)
log_prob = self.log_prob(actions, self.gaussian_actions)
return actions, log_prob
def sample_weights(self, batch_size=1):
num_dims = (self.latent_sde_dim, self.action_dim)
self.weights_dist = Normal(th.zeros(num_dims), th.ones(num_dims))
# Reparametrization trick to pass gradients
self.exploration_mat = self.weights_dist.rsample()
# Pre-compute matrices in case of parallel exploration
self.exploration_matrices = self.weights_dist.rsample((batch_size,))
def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
latent_sde = latent_sde[..., -self.latent_sde_dim:]
if self.sde_latent_softmax:
latent_sde = latent_sde.softmax(-1)
latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
# Default case: only one exploration matrix
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
chol = th.diag_embed(self.distribution.stddev)
return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]
p = self.distribution
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
chol = th.diag_embed(self.distribution.stddev)
elif isinstance(p, th.distributions.MultivariateNormal):
chol = p.scale_tril
# Use batch matrix multiplication for efficient computation
# (batch_size, n_features) -> (batch_size, 1, n_features)
latent_sde = latent_sde.unsqueeze(dim=1)
# (batch_size, 1, n_actions)
noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
return noise.squeeze(dim=1)
class CholNet(nn.Module):
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType, epsilon):
super().__init__()
self.latent_dim = latent_dim
self.action_dim = action_dim
self.par_strength = par_strength
self.cov_strength = cov_strength
self.par_type = par_type
self.enforce_positive_type = enforce_positive_type
self.prob_squashing_type = prob_squashing_type
self.epsilon = epsilon
self._flat_chol_len = action_dim * (action_dim + 1) // 2
if self.par_type == ParametrizationType.CHOL:
self._full_params_len = self._flat_chol_len
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
self._full_params_len = self._flat_chol_len
elif self.par_type == ParametrizationType.EIGEN:
self._full_params_len = self.action_dim * 2
elif self.par_type == ParametrizationType.EIGEN_RAW:
self._full_params_len = self.action_dim * 2
self._givens_rotator = givens.Rotation(action_dim)
self._givens_ident = th.eye(action_dim)
# Yes, this is ugly.
# But I don't know how this mess could be elegantly abstracted away...
if self.par_strength == Strength.NONE:
if self.cov_strength == Strength.NONE:
self.chol = th.ones(self.action_dim) * std_init
elif self.cov_strength == Strength.SCALAR:
self.param = nn.Parameter(
th.Tensor([std_init]), requires_grad=True)
elif self.cov_strength == Strength.DIAG:
self.params = nn.Parameter(
th.ones(self.action_dim) * std_init, requires_grad=True)
elif self.cov_strength == Strength.FULL:
# TODO: Init Off-axis differently?
self.params = nn.Parameter(
th.ones(self._full_params_len) * std_init, requires_grad=True)
elif self.par_strength == self.cov_strength:
if self.par_strength == Strength.SCALAR:
self.std = nn.Linear(latent_dim, 1)
elif self.par_strength == Strength.DIAG:
self.diag_chol = nn.Linear(latent_dim, self.action_dim)
elif self.par_strength == Strength.FULL:
self.params = nn.Linear(latent_dim, self._full_params_len)
elif self.par_strength.value > self.cov_strength.value:
raise Exception(
'The parameterization can not be stronger than the actual covariance.')
else:
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
self.factor = nn.Linear(latent_dim, 1)
self.param = nn.Parameter(
th.ones(self.action_dim), requires_grad=True)
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
if self.enforce_positive_type == EnforcePositiveType.NONE:
raise Exception(
'For Hybrid[Diag=>Full] enforce_positive_type has to be not NONE. Otherwise required SPD-contraint can not be ensured for cov.')
self.stds = nn.Linear(latent_dim, self.action_dim)
self.padder = th.nn.ZeroPad2d((0, 1, 1, 0))
# TODO: Init Non-zero?
self.params = nn.Parameter(
th.ones(self._full_params_len - self.action_dim) * 0, requires_grad=True)
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
self.factor = nn.Linear(latent_dim, 1)
# TODO: Init Off-axis differently?
self.params = nn.Parameter(
th.ones(self._full_params_len) * std_init, requires_grad=True)
else:
raise Exception("This Exception can't happen")
def forward(self, x: th.Tensor) -> th.Tensor:
# Ugly mess pt.2:
if self.par_strength == Strength.NONE:
if self.cov_strength == Strength.NONE:
return self.chol
elif self.cov_strength == Strength.SCALAR:
return self._ensure_positive_func(
th.ones(self.action_dim) * self.param[0])
elif self.cov_strength == Strength.DIAG:
return self._ensure_positive_func(self.params)
elif self.cov_strength == Strength.FULL:
return self._parameterize_full(self.params)
elif self.par_strength == self.cov_strength:
if self.par_strength == Strength.SCALAR:
std = self.std(x)
diag_chol = th.ones(self.action_dim) * std
return self._ensure_positive_func(diag_chol)
elif self.par_strength == Strength.DIAG:
diag_chol = self.diag_chol(x)
return self._ensure_positive_func(diag_chol)
elif self.par_strength == Strength.FULL:
params = self.params(x)
return self._parameterize_full(params)
else:
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
factor = self.factor(x)[0]
diag_chol = self._ensure_positive_func(
self.param * factor)
return diag_chol
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
# TODO: Maybe possible to improve speed and stability by making conversion from pearson correlation + stds to cov in cholesky-form.
stds = self._ensure_positive_func(self.stds(x))
smol = self._parameterize_full(self.params)
big = self.padder(smol)
pearson_cor_chol = big + th.eye(stds.shape[-1])
pearson_cor = (pearson_cor_chol.T @
pearson_cor_chol)
if len(stds.shape) > 1:
# batched operation, we need to expand
pearson_cor = pearson_cor.expand(
(stds.shape[0],)+pearson_cor.shape)
stds = stds.unsqueeze(2)
cov = stds.mT * pearson_cor * stds
chol = th.linalg.cholesky(cov)
return chol
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
# TODO: Maybe possible to improve speed and stability by multiplying with factor in cholesky-form.
factor = self._ensure_positive_func(self.factor(x))
par_chol = self._parameterize_full(self.params)
cov = (par_chol.T @ par_chol)
if len(factor) > 1:
factor = factor.unsqueeze(2)
cov = cov * factor
chol = th.linalg.cholesky(cov)
return chol
raise Exception()
def _parameterize_full(self, params):
if self.par_type == ParametrizationType.CHOL:
return self._chol_from_flat(params)
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
return self._chol_from_flat_sphe_chol(params)
elif self.par_type == ParametrizationType.EIGEN:
return self._chol_from_givens_params(params, True)
elif self.par_type == ParametrizationType.EIGEN_RAW:
return self._chol_from_givens_params(params, False)
raise Exception()
def _chol_from_flat(self, flat_chol):
chol = fill_triangular(flat_chol)
return self._ensure_diagonal_positive(chol)
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
sphe_chol = fill_triangular(pos_flat_sphe_chol)
chol = self._chol_from_sphe_chol(sphe_chol)
return chol
def _chol_from_sphe_chol(self, sphe_chol):
# TODO: Make efficient more
# Note:
# We must should ensure:
# S[i,1] > 0 where i = 1..n
# S[i,j] e (0, pi) where i = 2..n, j = 2..i
# We already ensure S > 0 in _chol_from_flat_sphe_chol
# We ensure < pi by applying tanh*pi to all applicable elements
vec = self.cov_strength != Strength.FULL
batch_dims = len(sphe_chol.shape) - 2 + 1*vec
batch = batch_dims != 0
batch_shape = sphe_chol.shape[:batch_dims]
batch_shape_scalar = batch_shape + (1,)
S = sphe_chol
n = sphe_chol.shape[-1]
L = th.zeros_like(sphe_chol)
for i in range(n):
#t = 1
t = th.Tensor([1])[0]
if batch:
t = t.expand(batch_shape_scalar)
#s = ''
for j in range(i+1):
#maybe_cos = 1
maybe_cos = th.Tensor([1])[0]
if batch:
maybe_cos = maybe_cos.expand(batch_shape_scalar)
#s_maybe_cos = ''
if i != j and j < n-1 and i < n:
if batch:
maybe_cos = th.cos(th.tanh(S[:, i, j+1])*pi)
else:
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
#s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
if batch:
L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T
else:
L[i, j] = S[i, 0] * t * maybe_cos
# print('[L_'+str(i+1)+']_'+str(j+1) +
# '=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
if j <= i and j < n-1 and i < n:
if batch:
tc = t.clone()
t = (tc.T * th.sin(th.tanh(S[:, i, j+1])*pi)).T
else:
t *= th.sin(th.tanh(S[i, j+1])*pi)
#s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
return L
def _ensure_positive_func(self, x):
return self.enforce_positive_type.apply(x) + self.epsilon
def _ensure_diagonal_positive(self, chol):
if len(chol.shape) == 1:
# If our chol is a vector (representing a diagonal chol)
return self._ensure_positive_func(chol)
return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
dim2=-1)).diag_embed() + chol.triu(1)
def _chol_from_givens_params(self, params, bijection=False):
theta, eigenv = params[:,
:self.action_dim], params[:, self.action_dim:]
eigenv = self._ensure_positive_func(eigenv)
if bijection:
eigenv = th.cumsum(eigenv, -1)
# reverse order, oh well...
Q = th.zeros((theta.shape[0], self.action_dim,
self.action_dim), device=eigenv.device)
for b in range(theta.shape[0]):
self._givens_rotator.theta = theta[b]
Q[b] = self._givens_rotator(self._givens_ident)
Qinv = Q.transpose(dim0=-2, dim1=-1)
cov = Q * th.diag(eigenv) * Qinv
chol = th.linalg.cholesky(cov)
return chol
def string(self):
return '<CholNet />'
AnyDistribution = Union[SB3_Distribution, UniversalGaussianDistribution]

View File

@ -25,7 +25,7 @@ from stable_baselines3.common.torch_layers import (
MlpExtractor, MlpExtractor,
NatureCNN, NatureCNN,
) )
from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.type_aliases import Schedules
from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.torch_layers import (
@ -88,6 +88,7 @@ class ActorCriticPolicy(BasePolicy):
activation_fn: Type[nn.Module] = nn.Tanh, activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True, ortho_init: bool = True,
use_sde: bool = False, use_sde: bool = False,
use_pca: bool = False,
std_init: float = 1.0, std_init: float = 1.0,
full_std: bool = True, full_std: bool = True,
sde_net_arch: Optional[List[int]] = None, sde_net_arch: Optional[List[int]] = None,
@ -153,6 +154,7 @@ class ActorCriticPolicy(BasePolicy):
"sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) "sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
self.use_sde = use_sde self.use_sde = use_sde
self.use_pca = use_pca
self.dist_kwargs = dist_kwargs self.dist_kwargs = dist_kwargs
self.sqrt_induced_gaussian = sqrt_induced_gaussian self.sqrt_induced_gaussian = sqrt_induced_gaussian
@ -160,7 +162,7 @@ class ActorCriticPolicy(BasePolicy):
# Action distribution # Action distribution
self.action_dist = make_proba_distribution( self.action_dist = make_proba_distribution(
action_space, use_sde=use_sde, dist_kwargs=dist_kwargs) action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
self._build(lr_schedule) self._build(lr_schedule)

View File

@ -62,6 +62,7 @@ class Actor(BasePolicy):
features_dim: int, features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU, activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False, use_sde: bool = False,
use_pca: bool = False,
log_std_init: float = -3, log_std_init: float = -3,
full_std: bool = True, full_std: bool = True,
sde_net_arch: Optional[List[int]] = None, sde_net_arch: Optional[List[int]] = None,
@ -79,6 +80,8 @@ class Actor(BasePolicy):
squash_output=True, squash_output=True,
) )
assert use_pca == False, 'PCA is not implemented for SAC'
# Save arguments to re-create object at loading # Save arguments to re-create object at loading
self.use_sde = use_sde self.use_sde = use_sde
self.sde_features_extractor = None self.sde_features_extractor = None