Implemented SDE
This commit is contained in:
parent
12e422aec7
commit
520dc98eb5
@ -160,8 +160,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
self.distribution = None
|
self.distribution = None
|
||||||
self.gaussian_actions = None
|
self.gaussian_actions = None
|
||||||
|
|
||||||
if use_sde:
|
self.use_sde = use_sde
|
||||||
raise Exception('SDE is not yet implemented')
|
|
||||||
|
|
||||||
assert (self.par_type != ParametrizationType.NONE) == (
|
assert (self.par_type != ParametrizationType.NONE) == (
|
||||||
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||||
@ -214,6 +213,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
|
chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
|
||||||
self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type)
|
self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type)
|
||||||
|
|
||||||
|
if self.use_sde:
|
||||||
|
self.sample_weights(self.action_dim)
|
||||||
|
|
||||||
return mean_actions, chol
|
return mean_actions, chol
|
||||||
|
|
||||||
def _sqrt_to_chol(self, cov_sqrt):
|
def _sqrt_to_chol(self, cov_sqrt):
|
||||||
@ -246,7 +248,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
self.distribution.cov_sqrt = cov_sqrt
|
self.distribution.cov_sqrt = cov_sqrt
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
|
def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_sde: nn.Module) -> "UniversalGaussianDistribution":
|
||||||
"""
|
"""
|
||||||
Create the distribution given its parameters (mean, chol)
|
Create the distribution given its parameters (mean, chol)
|
||||||
|
|
||||||
@ -254,7 +256,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:param chol:
|
:param chol:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# TODO: latent_pi is for SDE, implement.
|
if self.use_sde:
|
||||||
|
self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
|
||||||
|
# TODO: Change variance of dist to include sde-spread
|
||||||
|
|
||||||
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
|
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
|
||||||
self.distribution = Independent(Normal(mean_actions, chol), 1)
|
self.distribution = Independent(Normal(mean_actions, chol), 1)
|
||||||
@ -300,6 +304,12 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
self.gaussian_actions = sample
|
self.gaussian_actions = sample
|
||||||
return self.prob_squashing_type.apply(sample)
|
return self.prob_squashing_type.apply(sample)
|
||||||
|
|
||||||
|
def sample_sde(self) -> th.Tensor:
    """
    Draw an action by adding state-dependent noise to the distribution mean.

    The noise is obtained from ``self.get_noise`` applied to the stored
    ``self._latent_sde``. The un-squashed action is cached in
    ``self.gaussian_actions`` before the squashing function is applied.

    :return: result of ``self.prob_squashing_type.apply`` on the noisy action
    """
    perturbation = self.get_noise(self._latent_sde)
    # Cache the raw (pre-squashing) action on the instance, as sample()/mode() do.
    self.gaussian_actions = self.distribution.mean + perturbation
    return self.prob_squashing_type.apply(self.gaussian_actions)
|
||||||
|
|
||||||
def mode(self) -> th.Tensor:
|
def mode(self) -> th.Tensor:
|
||||||
mode = self.distribution.mean
|
mode = self.distribution.mean
|
||||||
self.gaussian_actions = mode
|
self.gaussian_actions = mode
|
||||||
@ -323,6 +333,28 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
log_prob = self.log_prob(actions, self.gaussian_actions)
|
log_prob = self.log_prob(actions, self.gaussian_actions)
|
||||||
return actions, log_prob
|
return actions, log_prob
|
||||||
|
|
||||||
|
def sample_weights(self, num_dims, batch_size=1):
    """
    Draw fresh exploration weights from a standard normal distribution.

    Stores the distribution in ``self.weights_dist``, a single draw in
    ``self.exploration_mat`` and ``batch_size`` stacked draws in
    ``self.exploration_matrices``.

    NOTE(review): with an int ``num_dims`` the draws are 1-D / 2-D tensors,
    while ``get_noise`` feeds them to ``th.mm`` / ``th.bmm`` -- confirm the
    intended shape (possibly ``(n_features, n_actions)``).

    :param num_dims: shape handed to ``th.zeros`` / ``th.ones`` for mean and std
    :param batch_size: number of draws stacked into ``self.exploration_matrices``
    """
    zero_mean = th.zeros(num_dims)
    unit_std = th.ones(num_dims)
    standard_normal = Normal(zero_mean, unit_std)
    self.weights_dist = standard_normal
    # rsample() uses the reparametrization trick so gradients can pass through
    self.exploration_mat = standard_normal.rsample()
    # Pre-computed stack of draws in case of parallel exploration
    stacked_shape = (batch_size,)
    self.exploration_matrices = standard_normal.rsample(stacked_shape)
|
||||||
|
|
||||||
|
def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
    """
    Compute state-dependent exploration noise shaped by the current covariance.

    The latent features are L2-normalized along the last dimension, projected
    through the sampled exploration weights and then scaled by the
    distribution's covariance factor (``scale_tril`` when the distribution
    exposes one, otherwise its per-dimension ``stddev``).

    NOTE(review): this expects ``self.exploration_mat`` to be
    ``(n_features, n_actions)`` and ``self.exploration_matrices`` to be
    ``(batch_size, n_features, n_actions)``; ``sample_weights`` called with
    an int does not produce these shapes -- confirm against the caller.

    :param latent_sde: latent features, assumed ``(batch_size, n_features)``
    :return: noise, one row per entry of ``latent_sde``
    """
    if not self.learn_features:
        # Block gradients from flowing into the feature extractor.
        latent_sde = latent_sde.detach()
    # TODO: Good idea?
    latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)

    n_rows = len(latent_sde)
    if n_rows == 1 or n_rows != len(self.exploration_matrices):
        # Default case: only one exploration matrix, shared by all rows.
        # (batch_size, n_features) @ (n_features, n_actions)
        eps = th.mm(latent_sde, self.exploration_mat)
    else:
        # One exploration matrix per row, via batch matrix multiplication.
        # (batch_size, 1, n_features) @ (batch_size, n_features, n_actions)
        eps = th.bmm(latent_sde.unsqueeze(dim=1), self.exploration_matrices)
        eps = eps.squeeze(dim=1)

    chol = getattr(self.distribution, "scale_tril", None)
    if chol is None:
        # Factorized distributions (e.g. Independent(Normal)) expose no
        # scale_tril; their covariance factor is diagonal, so scale per dim.
        return self.distribution.stddev * eps

    # Apply the factor to each noise vector as a column: L @ eps.
    # matmul broadcasts, so chol may be (n_actions, n_actions) or batched.
    return th.matmul(chol, eps.unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||||
|
|
||||||
|
|
||||||
class CholNet(nn.Module):
|
class CholNet(nn.Module):
|
||||||
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType):
|
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType):
|
||||||
|
@ -35,6 +35,8 @@ from stable_baselines3.common.torch_layers import (
|
|||||||
NatureCNN,
|
NatureCNN,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
|
|
||||||
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
|
from metastable_baselines.projections.w2_projection_layer import WassersteinProjectionLayer
|
||||||
|
|
||||||
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
|
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
|
||||||
@ -196,8 +198,15 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
assert isinstance(
|
assert isinstance(
|
||||||
self.action_dist, StateDependentNoiseDistribution) or isinstance(
|
self.action_dist, StateDependentNoiseDistribution) or isinstance(
|
||||||
self.action_dist, UniversalGaussianDistribution), "reset_noise() is only available when using gSDE"
|
self.action_dist, UniversalGaussianDistribution), "reset_noise() is only available when using gSDE"
|
||||||
|
if isinstance(
|
||||||
|
self.action_dist, StateDependentNoiseDistribution):
|
||||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
self.action_dist, UniversalGaussianDistribution):
|
||||||
|
self.action_dist.sample_weights(
|
||||||
|
get_action_dim(self.action_space), batch_size=n_envs)
|
||||||
|
|
||||||
def _build_mlp_extractor(self) -> None:
|
def _build_mlp_extractor(self) -> None:
|
||||||
"""
|
"""
|
||||||
Create the policy and value networks.
|
Create the policy and value networks.
|
||||||
|
@ -105,7 +105,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
target_kl: Optional[float] = None,
|
target_kl: Optional[float] = None,
|
||||||
tensorboard_log: Optional[str] = None,
|
tensorboard_log: Optional[str] = None,
|
||||||
create_eval_env: bool = False,
|
create_eval_env: bool = False,
|
||||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
policy_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
device: Union[th.device, str] = "auto",
|
device: Union[th.device, str] = "auto",
|
||||||
|
@ -157,16 +157,23 @@ class Actor(BasePolicy):
|
|||||||
StateDependentNoiseDistribution), msg
|
StateDependentNoiseDistribution), msg
|
||||||
return self.chol
|
return self.chol
|
||||||
|
|
||||||
def reset_noise(self, batch_size: int = 1) -> None:
|
def reset_noise(self, n_envs: int = 1) -> None:
|
||||||
"""
|
"""
|
||||||
Sample new weights for the exploration matrix, when using gSDE.
|
Sample new weights for the exploration matrix.
|
||||||
|
|
||||||
:param batch_size:
|
:param n_envs:
|
||||||
"""
|
"""
|
||||||
msg = "reset_noise() is only available when using gSDE"
|
assert isinstance(
|
||||||
assert isinstance(self.action_dist,
|
self.action_dist, StateDependentNoiseDistribution) or isinstance(
|
||||||
StateDependentNoiseDistribution), msg
|
self.action_dist, UniversalGaussianDistribution), "reset_noise() is only available when using gSDE"
|
||||||
self.action_dist.sample_weights(self.chol, batch_size=batch_size)
|
if isinstance(
|
||||||
|
self.action_dist, StateDependentNoiseDistribution):
|
||||||
|
self.action_dist.sample_weights(self.chol, batch_size=n_envs)
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
self.action_dist, UniversalGaussianDistribution):
|
||||||
|
self.action_dist.sample_weights(
|
||||||
|
get_action_dim(self.action_space), batch_size=n_envs)
|
||||||
|
|
||||||
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user