diff --git a/icon.svg b/icon.svg
new file mode 100644
index 0000000..1658c07
--- /dev/null
+++ b/icon.svg
@@ -0,0 +1,97 @@
+
+
+
+
diff --git a/sb3_trl/__init__.py b/metastable_baselines/__init__.py
similarity index 100%
rename from sb3_trl/__init__.py
rename to metastable_baselines/__init__.py
diff --git a/metastable_baselines/distributions/__init__.py b/metastable_baselines/distributions/__init__.py
new file mode 100644
index 0000000..7f0891a
--- /dev/null
+++ b/metastable_baselines/distributions/__init__.py
@@ -0,0 +1 @@
+#TODO: License or such
diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
new file mode 100644
index 0000000..b7c32ab
--- /dev/null
+++ b/metastable_baselines/distributions/distributions.py
@@ -0,0 +1,197 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch as th
+from torch import nn
+from torch.distributions import Normal, MultivariateNormal
+
+from stable_baselines3.common.preprocessing import get_action_dim
+
+from stable_baselines3.common.distributions import Distribution as SB3_Distribution
+from stable_baselines3.common.distributions import DiagGaussianDistribution
+from stable_baselines3.common.distributions import sum_independent_dims
+
+
+class ContextualCovDiagonalGaussianDistribution(DiagGaussianDistribution):
+ """
+ Gaussian distribution with diagonal covariance matrix, for continuous actions.
+ Includes contextual parametrization of the covariance matrix.
+
+ :param action_dim: Dimension of the action space.
+ """
+
+ def __init__(self, action_dim: int):
+        super(ContextualCovDiagonalGaussianDistribution, self).__init__(action_dim)
+
+    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
+        """
+        Create the layers that represent the distribution:
+        one output will be the mean of the Gaussian, the other will be the
+        state-dependent standard deviation (log std in fact to allow negative values)
+
+ :param latent_dim: Dimension of the last layer of the policy (before the action layer)
+ :param log_std_init: Initial value for the log standard deviation
+ :return:
+ """
+ mean_actions = nn.Linear(latent_dim, self.action_dim)
+ log_std = nn.Linear(latent_dim, self.action_dim)
+ return mean_actions, log_std
+
+
+class ContextualSqrtCovDiagonalGaussianDistribution(DiagGaussianDistribution):
+ """
+ Gaussian distribution induced by its sqrt(cov), for continuous actions.
+
+ :param action_dim: Dimension of the action space.
+ """
+
+ def __init__(self, action_dim: int):
+        super(ContextualSqrtCovDiagonalGaussianDistribution, self).__init__(action_dim)
+
+    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
+        """
+        Create the layers that represent the distribution:
+        one output will be the mean of the Gaussian, the other will be the
+        (flattened) square root of the covariance matrix
+
+ :param latent_dim: Dimension of the last layer of the policy (before the action layer)
+ :param log_std_init: Initial value for the log standard deviation
+ :return:
+ """
+ mean_actions = nn.Linear(latent_dim, self.action_dim)
+        # nn.Linear needs an int for out_features; emit the flattened
+        # (action_dim x action_dim) sqrt factor. Note that proba_distribution()
+        # below still treats log_std as a per-dimension (diagonal) log std.
+        log_std = nn.Linear(latent_dim, self.action_dim * self.action_dim)
+ return mean_actions, log_std
+
+ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
+ """
+ Create the distribution given its parameters (mean, std)
+
+ :param mean_actions:
+ :param log_std:
+ :return:
+ """
+ action_std = th.ones_like(mean_actions) * log_std.exp()
+ self.distribution = Normal(mean_actions, action_std)
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ """
+ Get the log probabilities of actions according to the distribution.
+ Note that you must first call the ``proba_distribution()`` method.
+
+ :param actions:
+ :return:
+ """
+ log_prob = self.distribution.log_prob(actions)
+ return sum_independent_dims(log_prob)
+
+ def entropy(self) -> th.Tensor:
+ return sum_independent_dims(self.distribution.entropy())
+
+ def sample(self) -> th.Tensor:
+ # Reparametrization trick to pass gradients
+ return self.distribution.rsample()
+
+ def mode(self) -> th.Tensor:
+ return self.distribution.mean
+
+ def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ # Update the proba distribution
+ self.proba_distribution(mean_actions, log_std)
+ return self.get_actions(deterministic=deterministic)
+
+ def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ """
+ Compute the log probability of taking an action
+ given the distribution parameters.
+
+ :param mean_actions:
+ :param log_std:
+ :return:
+ """
+ actions = self.actions_from_params(mean_actions, log_std)
+ log_prob = self.log_prob(actions)
+ return actions, log_prob
+
+
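+# NOTE: this class shadows the SB3 DiagGaussianDistribution imported above for the rest
+# of this module; the contextual classes defined earlier keep the SB3 class as their base.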
+class DiagGaussianDistribution(SB3_Distribution):
+ """
+    Gaussian distribution with diagonal covariance matrix, for continuous actions.
+
+ :param action_dim: Dimension of the action space.
+ """
+
+ def __init__(self, action_dim: int):
+ super(DiagGaussianDistribution, self).__init__()
+ self.action_dim = action_dim
+ self.mean_actions = None
+ self.log_std = None
+
+ def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
+ """
+ Create the layers and parameter that represent the distribution:
+ one output will be the mean of the Gaussian, the other parameter will be the
+ standard deviation (log std in fact to allow negative values)
+
+ :param latent_dim: Dimension of the last layer of the policy (before the action layer)
+ :param log_std_init: Initial value for the log standard deviation
+ :return:
+ """
+ mean_actions = nn.Linear(latent_dim, self.action_dim)
+ # TODO: allow action dependent std
+ log_std = nn.Parameter(th.ones(self.action_dim)
+ * log_std_init, requires_grad=True)
+ return mean_actions, log_std
+
+ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
+ """
+ Create the distribution given its parameters (mean, std)
+
+ :param mean_actions:
+ :param log_std:
+ :return:
+ """
+ action_std = th.ones_like(mean_actions) * log_std.exp()
+ self.distribution = Normal(mean_actions, action_std)
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ """
+ Get the log probabilities of actions according to the distribution.
+ Note that you must first call the ``proba_distribution()`` method.
+
+ :param actions:
+ :return:
+ """
+ log_prob = self.distribution.log_prob(actions)
+ return sum_independent_dims(log_prob)
+
+ def entropy(self) -> th.Tensor:
+ return sum_independent_dims(self.distribution.entropy())
+
+ def sample(self) -> th.Tensor:
+ # Reparametrization trick to pass gradients
+ return self.distribution.rsample()
+
+ def mode(self) -> th.Tensor:
+ return self.distribution.mean
+
+ def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ # Update the proba distribution
+ self.proba_distribution(mean_actions, log_std)
+ return self.get_actions(deterministic=deterministic)
+
+ def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ """
+ Compute the log probability of taking an action
+ given the distribution parameters.
+
+ :param mean_actions:
+ :param log_std:
+ :return:
+ """
+ actions = self.actions_from_params(mean_actions, log_std)
+ log_prob = self.log_prob(actions)
+ return actions, log_prob
diff --git a/sb3_trl/misc/distTools.py b/metastable_baselines/misc/distTools.py
similarity index 86%
rename from sb3_trl/misc/distTools.py
rename to metastable_baselines/misc/distTools.py
index 6496ea8..633992e 100644
--- a/sb3_trl/misc/distTools.py
+++ b/metastable_baselines/misc/distTools.py
@@ -17,6 +17,18 @@ def get_mean_and_chol(p, expand=False):
raise Exception('Dist-Type not implemented')
+def get_mean_and_sqrt(p):
+ raise Exception('Not yet implemented...')
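+    # NOTE: unreachable until implemented; the branches below currently mirror
+    # get_mean_and_chol and return Cholesky factors rather than matrix square roots.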
+ if isinstance(p, th.distributions.Normal):
+ return p.mean, p.stddev
+ elif isinstance(p, th.distributions.MultivariateNormal):
+ return p.mean, p.scale_tril
+ elif isinstance(p, SB3_Distribution):
+ return get_mean_and_chol(p.distribution)
+ else:
+ raise Exception('Dist-Type not implemented')
+
+
def get_cov(p):
if isinstance(p, th.distributions.Normal):
return th.diag_embed(p.variance)
diff --git a/sb3_trl/misc/norm.py b/metastable_baselines/misc/norm.py
similarity index 89%
rename from sb3_trl/misc/norm.py
rename to metastable_baselines/misc/norm.py
index 74d1f66..f40d319 100644
--- a/sb3_trl/misc/norm.py
+++ b/metastable_baselines/misc/norm.py
@@ -7,9 +7,9 @@ def mahalanobis_alt(u, v, std):
return th.triangular_solve(delta, std, upper=False)[0].pow(2).sum([-2, -1])
-def mahalanobis(u, v, cov):
+def mahalanobis(u, v, chol):
delta = u - v
- return _batch_mahalanobis(cov, delta)
+ return _batch_mahalanobis(chol, delta)
def frob_sq(diff, is_spd=False):
diff --git a/sb3_trl/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py
similarity index 100%
rename from sb3_trl/misc/rollout_buffer.py
rename to metastable_baselines/misc/rollout_buffer.py
diff --git a/metastable_baselines/projections_orig/__init__.py b/metastable_baselines/projections_orig/__init__.py
new file mode 100644
index 0000000..a578185
--- /dev/null
+++ b/metastable_baselines/projections_orig/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 Robert Bosch GmbH
+# Author: Fabian Otto
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
diff --git a/metastable_baselines/projections_orig/base_projection_layer.py b/metastable_baselines/projections_orig/base_projection_layer.py
new file mode 100644
index 0000000..3a881af
--- /dev/null
+++ b/metastable_baselines/projections_orig/base_projection_layer.py
@@ -0,0 +1,374 @@
+# Copyright (c) 2021 Robert Bosch GmbH
+# Author: Fabian Otto
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import copy
+import math
+import torch as ch
+from typing import Tuple, Union
+
+from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
+from trust_region_projections.utils.network_utils import get_optimizer
+from trust_region_projections.utils.projection_utils import gaussian_kl, get_entropy_schedule
+from trust_region_projections.utils.torch_utils import generate_minibatches, select_batch, tensorize
+
+
+def entropy_inequality_projection(policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ beta: Union[float, ch.Tensor]):
+ """
+ Projects std to satisfy an entropy INEQUALITY constraint.
+ Args:
+ policy: policy instance
+ p: current distribution
+ beta: target entropy for EACH std or general bound for all stds
+
+ Returns:
+ projected std that satisfies the entropy bound
+ """
+ mean, std = p
+ k = std.shape[-1]
+ batch_shape = std.shape[:-2]
+
+ ent = policy.entropy(p)
+ mask = ent < beta
+
+ # if nothing has to be projected skip computation
+ if (~mask).all():
+ return p
+
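+    # Closed form: scaling the std factor by alpha = exp((beta - ent) / k) changes the
+    # Gaussian entropy by k * log(alpha) = beta - ent, i.e. exactly up to the bound.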
+ alpha = ch.ones(batch_shape, dtype=std.dtype, device=std.device)
+ alpha[mask] = ch.exp((beta[mask] - ent[mask]) / k)
+
+ proj_std = ch.einsum('ijk,i->ijk', std, alpha)
+ return mean, ch.where(mask[..., None, None], proj_std, std)
+
+
+def entropy_equality_projection(policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ beta: Union[float, ch.Tensor]):
+ """
+ Projects std to satisfy an entropy EQUALITY constraint.
+ Args:
+ policy: policy instance
+ p: current distribution
+ beta: target entropy for EACH std or general bound for all stds
+
+ Returns:
+ projected std that satisfies the entropy bound
+ """
+ mean, std = p
+ k = std.shape[-1]
+
+ ent = policy.entropy(p)
+ alpha = ch.exp((beta - ent) / k)
+ proj_std = ch.einsum('ijk,i->ijk', std, alpha)
+ return mean, proj_std
+
+
+def mean_projection(mean: ch.Tensor, old_mean: ch.Tensor, maha: ch.Tensor, eps: ch.Tensor):
+ """
+ Projects the mean based on the Mahalanobis objective and trust region.
+ Args:
+ mean: current mean vectors
+ old_mean: old mean vectors
+ maha: Mahalanobis distance between the two mean vectors
+ eps: trust region bound
+
+ Returns:
+ projected mean that satisfies the trust region
+ """
+ batch_shape = mean.shape[:-1]
+ mask = maha > eps
+
+ ################################################################################################################
+ # mean projection maha
+
+ # if nothing has to be projected skip computation
+ if mask.any():
+ omega = ch.ones(batch_shape, dtype=mean.dtype, device=mean.device)
+ omega[mask] = ch.sqrt(maha[mask] / eps) - 1.
+ omega = ch.max(-omega, omega)[..., None]
+
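+        # Interpolate towards the old mean: with omega = sqrt(maha / eps) - 1 the point
+        # m = (mean + omega * old_mean) / (1 + omega) satisfies maha(m, old_mean) = eps,
+        # because the Mahalanobis distance is quadratic in (m - old_mean).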
+ m = (mean + omega * old_mean) / (1 + omega + 1e-16)
+ proj_mean = ch.where(mask[..., None], m, mean)
+ else:
+ proj_mean = mean
+
+ return proj_mean
+
+
+class BaseProjectionLayer(object):
+
+ def __init__(self,
+ proj_type: str = "",
+ mean_bound: float = 0.03,
+ cov_bound: float = 1e-3,
+ trust_region_coeff: float = 0.0,
+ scale_prec: bool = True,
+
+ entropy_schedule: Union[None, str] = None,
+ action_dim: Union[None, int] = None,
+ total_train_steps: Union[None, int] = None,
+ target_entropy: float = 0.0,
+ temperature: float = 0.5,
+ entropy_eq: bool = False,
+ entropy_first: bool = False,
+
+ do_regression: bool = False,
+ regression_iters: int = 1000,
+                 regression_lr: float = 3e-4,
+ optimizer_type_reg: str = "adam",
+
+ cpu: bool = True,
+ dtype: ch.dtype = ch.float32,
+ ):
+
+ """
+ Base projection layer, which can be used to compute metrics for non-projection approaches.
+ Args:
+ proj_type: Which type of projection to use. None specifies no projection and uses the TRPO objective.
+ mean_bound: projection bound for the step size w.r.t. mean
+ cov_bound: projection bound for the step size w.r.t. covariance matrix
+ trust_region_coeff: Coefficient for projection regularization loss term.
+            scale_prec: If True, use the Mahalanobis distance (i.e. scaled with Sigma_old^-1) for the projections instead of the Euclidean distance.
+ entropy_schedule: Schedule type for entropy projection, one of 'linear', 'exp', None.
+ action_dim: number of action dimensions to scale exp decay correctly.
+ total_train_steps: total number of training steps to compute appropriate decay over time.
+ target_entropy: projection bound for the entropy of the covariance matrix
+ temperature: temperature decay for exponential entropy bound
+ entropy_eq: Use entropy equality constraints.
+ entropy_first: Project entropy before trust region.
+            do_regression: Conduct additional regression steps after the policy steps to match projection and policy.
+ regression_iters: Number of regression steps.
+ regression_lr: Regression learning rate.
+ optimizer_type_reg: Optimizer for regression.
+ cpu: Compute on CPU only.
+            dtype: Data type to use, either of float32 or float64. The latter might be necessary for higher
+ dimensions in order to learn the full covariance.
+ """
+
+ # projection and bounds
+ self.proj_type = proj_type
+ self.mean_bound = tensorize(mean_bound, cpu=cpu, dtype=dtype)
+ self.cov_bound = tensorize(cov_bound, cpu=cpu, dtype=dtype)
+ self.trust_region_coeff = trust_region_coeff
+ self.scale_prec = scale_prec
+
+ # projection utils
+ assert (action_dim and total_train_steps) if entropy_schedule else True
+ self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection
+ self.entropy_schedule = get_entropy_schedule(entropy_schedule, total_train_steps, dim=action_dim)
+ self.target_entropy = tensorize(target_entropy, cpu=cpu, dtype=dtype)
+ self.entropy_first = entropy_first
+ self.entropy_eq = entropy_eq
+ self.temperature = temperature
+ self._initial_entropy = None
+
+ # regression
+ self.do_regression = do_regression
+ self.regression_iters = regression_iters
+ self.lr_reg = regression_lr
+ self.optimizer_type_reg = optimizer_type_reg
+
+ def __call__(self, policy, p: Tuple[ch.Tensor, ch.Tensor], q, step, *args, **kwargs):
+ # entropy_bound = self.policy.entropy(q) - self.target_entropy
+ entropy_bound = self.entropy_schedule(self.initial_entropy, self.target_entropy, self.temperature,
+ step) * p[0].new_ones(p[0].shape[0])
+ return self._projection(policy, p, q, self.mean_bound, self.cov_bound, entropy_bound, **kwargs)
+
+ def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, **kwargs):
+ """
+ Hook for implementing the specific trust region projection
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ eps: mean trust region bound
+ eps_cov: covariance trust region bound
+ **kwargs:
+
+ Returns:
+ projected
+ """
+ return p
+
+ # @final
+ def _projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, beta: ch.Tensor, **kwargs):
+ """
+ Template method with hook _trust_region_projection() to encode specific functionality.
+ (Optional) entropy projection is executed before or after as specified by entropy_first.
+ Do not override this. For Python >= 3.8 you can use the @final decorator to enforce not overwriting.
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ eps: mean trust region bound
+ eps_cov: covariance trust region bound
+ beta: entropy bound
+ **kwargs:
+
+ Returns:
+ projected mean, projected std
+ """
+
+ ####################################################################################################################
+ # entropy projection in the beginning
+ if self.entropy_first:
+ p = self.entropy_proj(policy, p, beta)
+
+ ####################################################################################################################
+ # trust region projection for mean and cov bounds
+ proj_mean, proj_std = self._trust_region_projection(policy, p, q, eps, eps_cov, **kwargs)
+
+ ####################################################################################################################
+ # entropy projection in the end
+ if self.entropy_first:
+ return proj_mean, proj_std
+
+ return self.entropy_proj(policy, (proj_mean, proj_std), beta)
+
+ @property
+ def initial_entropy(self):
+ return self._initial_entropy
+
+ @initial_entropy.setter
+ def initial_entropy(self, entropy):
+ if self.initial_entropy is None:
+ self._initial_entropy = entropy
+
+ def trust_region_value(self, policy, p, q):
+ """
+ Computes the KL divergence between two Gaussian distributions p and q.
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ Returns:
+ Mean and covariance part of the trust region metric.
+ """
+ return gaussian_kl(policy, p, q)
+
+ def get_trust_region_loss(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ proj_p: Tuple[ch.Tensor, ch.Tensor]):
+ """
+ Compute the trust region loss to ensure policy output and projection stay close.
+ Args:
+ policy: policy instance
+ proj_p: projected distribution
+ p: predicted distribution from network output
+
+ Returns:
+ trust region loss
+ """
+ p_target = (proj_p[0].detach(), proj_p[1].detach())
+ mean_diff, cov_diff = self.trust_region_value(policy, p, p_target)
+
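+        # With a non-contextual (global) std, the covariance is not predicted by the
+        # network, so only the mean difference contributes to the trust region loss.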
+ delta_loss = (mean_diff + cov_diff if policy.contextual_std else mean_diff).mean()
+
+ return delta_loss * self.trust_region_coeff
+
+ def compute_metrics(self, policy, p, q) -> dict:
+ """
+ Returns dict with constraint metrics.
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+
+ Returns:
+ dict with constraint metrics
+ """
+ with ch.no_grad():
+ entropy_old = policy.entropy(q)
+ entropy = policy.entropy(p)
+ mean_kl, cov_kl = gaussian_kl(policy, p, q)
+ kl = mean_kl + cov_kl
+
+ mean_diff, cov_diff = self.trust_region_value(policy, p, q)
+
+ combined_constraint = mean_diff + cov_diff
+ entropy_diff = entropy_old - entropy
+
+ return {'kl': kl.detach().mean(),
+ 'constraint': combined_constraint.mean(),
+ 'mean_constraint': mean_diff.mean(),
+ 'cov_constraint': cov_diff.mean(),
+ 'entropy': entropy.mean(),
+ 'entropy_diff': entropy_diff.mean(),
+ 'kl_max': kl.max(),
+ 'constraint_max': combined_constraint.max(),
+ 'mean_constraint_max': mean_diff.max(),
+ 'cov_constraint_max': cov_diff.max(),
+ 'entropy_max': entropy.max(),
+ 'entropy_diff_max': entropy_diff.max()
+ }
+
+ def trust_region_regression(self, policy: AbstractGaussianPolicy, obs: ch.Tensor, q: Tuple[ch.Tensor, ch.Tensor],
+ n_minibatches: int, global_steps: int):
+ """
+ Take additional regression steps to match projection output and policy output.
+ The policy parameters are updated in-place.
+ Args:
+ policy: policy instance
+ obs: collected observations from trajectories
+ q: old distributions
+ n_minibatches: split the rollouts into n_minibatches.
+ global_steps: current number of steps, required for projection
+ Returns:
+            dict with mean of regression loss
+ """
+
+ if not self.do_regression:
+ return {}
+
+ policy_unprojected = copy.deepcopy(policy)
+ optim_reg = get_optimizer(self.optimizer_type_reg, policy_unprojected.parameters(), learning_rate=self.lr_reg)
+ optim_reg.reset()
+
+ reg_losses = obs.new_tensor(0.)
+
+ # get current projected values --> targets for regression
+ p_flat = policy(obs)
+ p_target = self(policy, p_flat, q, global_steps)
+
+ for _ in range(self.regression_iters):
+ batch_indices = generate_minibatches(obs.shape[0], n_minibatches)
+
+ # Minibatches SGD
+ for indices in batch_indices:
+ batch = select_batch(indices, obs, p_target[0], p_target[1])
+ b_obs, b_target_mean, b_target_std = batch
+ proj_p = (b_target_mean.detach(), b_target_std.detach())
+
+ p = policy_unprojected(b_obs)
+
+ # invert scaling with coeff here as we do not have to balance with other losses
+ loss = self.get_trust_region_loss(policy, p, proj_p) / self.trust_region_coeff
+
+ optim_reg.zero_grad()
+ loss.backward()
+ optim_reg.step()
+ reg_losses += loss.detach()
+
+ policy.load_state_dict(policy_unprojected.state_dict())
+
+ if not policy.contextual_std:
+ # set policy with projection value.
+ # In non-contextual cases we have only one cov, so the projection is the same.
+ policy.set_std(p_target[1][0])
+
+ steps = self.regression_iters * (math.ceil(obs.shape[0] / n_minibatches))
+ return {"regression_loss": (reg_losses / steps).detach()}
diff --git a/metastable_baselines/projections_orig/frob_projection_layer.py b/metastable_baselines/projections_orig/frob_projection_layer.py
new file mode 100644
index 0000000..8d338ce
--- /dev/null
+++ b/metastable_baselines/projections_orig/frob_projection_layer.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 Robert Bosch GmbH
+# Author: Fabian Otto
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import torch as ch
+from typing import Tuple
+
+from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
+from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer, mean_projection
+from trust_region_projections.utils.projection_utils import gaussian_frobenius
+
+
+class FrobeniusProjectionLayer(BaseProjectionLayer):
+
+ def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, **kwargs):
+ """
+ Runs Frobenius projection layer and constructs cholesky of covariance
+
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ eps: (modified) kl bound/ kl bound for mean part
+ eps_cov: (modified) kl bound for cov part
+ beta: (modified) entropy bound
+ **kwargs:
+ Returns: mean, cov cholesky
+ """
+
+ mean, chol = p
+ old_mean, old_chol = q
+ batch_shape = mean.shape[:-1]
+
+ ####################################################################################################################
+ # precompute mean and cov part of frob projection, which are used for the projection.
+ mean_part, cov_part, cov, cov_old = gaussian_frobenius(policy, p, q, self.scale_prec, True)
+
+ ################################################################################################################
+ # mean projection maha/euclidean
+
+ proj_mean = mean_projection(mean, old_mean, mean_part, eps)
+
+ ################################################################################################################
+ # cov projection frobenius
+
+ cov_mask = cov_part > eps_cov
+
+ if cov_mask.any():
+ # alpha = ch.where(fro_norm_sq > eps_cov, ch.sqrt(fro_norm_sq / eps_cov) - 1., ch.tensor(1.))
+ eta = ch.ones(batch_shape, dtype=chol.dtype, device=chol.device)
+ eta[cov_mask] = ch.sqrt(cov_part[cov_mask] / eps_cov) - 1.
+ eta = ch.max(-eta, eta)
+
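+            # Interpolate covariances: new_cov - cov_old = (cov - cov_old) / (1 + eta), so the
+            # squared Frobenius deviation cov_part shrinks by (1 + eta)^-2 onto the bound eps_cov.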
+ new_cov = (cov + ch.einsum('i,ijk->ijk', eta, cov_old)) / (1. + eta + 1e-16)[..., None, None]
+ proj_chol = ch.where(cov_mask[..., None, None], ch.cholesky(new_cov), chol)
+ else:
+ proj_chol = chol
+
+ return proj_mean, proj_chol
+
+ def trust_region_value(self, policy, p, q):
+ """
+ Computes the Frobenius metric between two Gaussian distributions p and q.
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ Returns:
+ mean and covariance part of Frobenius metric
+ """
+ return gaussian_frobenius(policy, p, q, self.scale_prec)
+
+ def get_trust_region_loss(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ proj_p: Tuple[ch.Tensor, ch.Tensor]):
+
+ mean_diff, _ = self.trust_region_value(policy, p, proj_p)
+ if policy.contextual_std:
+ # Compute MSE here, because we found the Frobenius norm tends to generate values that explode for the cov
+ cov_diff = (p[1] - proj_p[1]).pow(2).sum([-1, -2])
+ delta_loss = (mean_diff + cov_diff).mean()
+ else:
+ delta_loss = mean_diff.mean()
+
+ return delta_loss * self.trust_region_coeff
diff --git a/metastable_baselines/projections_orig/kl_projection_layer.py b/metastable_baselines/projections_orig/kl_projection_layer.py
new file mode 100644
index 0000000..ca5acd5
--- /dev/null
+++ b/metastable_baselines/projections_orig/kl_projection_layer.py
@@ -0,0 +1,101 @@
+import cpp_projection
+import numpy as np
+import torch as ch
+from typing import Any, Tuple
+
+from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
+from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer, mean_projection
+from trust_region_projections.utils.projection_utils import gaussian_kl
+from trust_region_projections.utils.torch_utils import get_numpy
+
+
+class KLProjectionLayer(BaseProjectionLayer):
+
+ def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, **kwargs):
+ """
+ Runs KL projection layer and constructs cholesky of covariance
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ eps: (modified) kl bound/ kl bound for mean part
+ eps_cov: (modified) kl bound for cov part
+ **kwargs:
+
+ Returns:
+ projected mean, projected cov cholesky
+ """
+ mean, std = p
+ old_mean, old_std = q
+
+ if not policy.contextual_std:
+ # only project first one to reduce number of numerical optimizations
+ std = std[:1]
+ old_std = old_std[:1]
+
+ ################################################################################################################
+ # project mean with closed form
+ mean_part, _ = gaussian_kl(policy, p, q)
+ proj_mean = mean_projection(mean, old_mean, mean_part, eps)
+
+ cov = policy.covariance(std)
+ old_cov = policy.covariance(old_std)
+
+ if policy.is_diag:
+ proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
+ old_cov.diagonal(dim1=-2, dim2=-1),
+ eps_cov)
+ proj_std = proj_cov.sqrt().diag_embed()
+ else:
+ raise NotImplementedError("The KL projection currently does not support full covariance matrices.")
+
+ if not policy.contextual_std:
+            # expand the single projected std back to the batch size
+ proj_std = proj_std.expand(mean.shape[0], -1, -1)
+
+ return proj_mean, proj_std
+
+
+class KLProjectionGradFunctionDiagCovOnly(ch.autograd.Function):
+ projection_op = None
+
+ @staticmethod
+ def get_projection_op(batch_shape, dim, max_eval=100):
+ if not KLProjectionGradFunctionDiagCovOnly.projection_op:
+ KLProjectionGradFunctionDiagCovOnly.projection_op = \
+ cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
+ return KLProjectionGradFunctionDiagCovOnly.projection_op
+
+ @staticmethod
+ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+ std, old_std, eps_cov = args
+
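+        # The inputs are the diagonal (co)variances; the batched C++ projection op solves
+        # the per-sample KL projection numerically on the NumPy side, and the custom
+        # backward below maps the incoming gradients through the same op.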
+ batch_shape = std.shape[0]
+ dim = std.shape[-1]
+
+ cov_np = get_numpy(std)
+ old_std = get_numpy(old_std)
+ eps = get_numpy(eps_cov) * np.ones(batch_shape)
+
+ # p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim)
+ # ctx.proj = projection_op
+
+ p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
+ ctx.proj = p_op
+
+ proj_std = p_op.forward(eps, old_std, cov_np)
+
+ return std.new(proj_std)
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs: Any) -> Any:
+ projection_op = ctx.proj
+ d_std, = grad_outputs
+
+ d_std_np = get_numpy(d_std)
+ d_std_np = np.atleast_2d(d_std_np)
+ df_stds = projection_op.backward(d_std_np)
+ df_stds = np.atleast_2d(df_stds)
+
+ return d_std.new(df_stds), None, None
diff --git a/metastable_baselines/projections_orig/papi_projection.py b/metastable_baselines/projections_orig/papi_projection.py
new file mode 100644
index 0000000..b52db75
--- /dev/null
+++ b/metastable_baselines/projections_orig/papi_projection.py
@@ -0,0 +1,233 @@
+# Copyright (c) 2021 Robert Bosch GmbH
+# Author: Fabian Otto
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import logging
+
+import copy
+import numpy as np
+import torch as ch
+from typing import Tuple, Union
+
+from trust_region_projections.utils.projection_utils import gaussian_kl
+from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
+from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer
+from trust_region_projections.utils.torch_utils import torch_batched_trace
+
+logger = logging.getLogger("papi_projection")
+
+
+class PAPIProjection(BaseProjectionLayer):
+
+ def __init__(self,
+ proj_type: str = "",
+ mean_bound: float = 0.015,
+ cov_bound: float = 0.0,
+
+ entropy_eq: bool = False,
+ entropy_first: bool = True,
+
+ cpu: bool = True,
+ dtype: ch.dtype = ch.float32,
+ **kwargs
+ ):
+
+ """
+ PAPI projection, which can be used after each training epoch to satisfy the trust regions.
+ Args:
+ proj_type: Which type of projection to use. None specifies no projection and uses the TRPO objective.
+ mean_bound: projection bound for the step size,
+ PAPI only has a joint KL constraint, mean and cov bound are summed for this bound.
+ cov_bound: projection bound for the step size,
+ PAPI only has a joint KL constraint, mean and cov bound are summed for this bound.
+ entropy_eq: Use entropy equality constraints.
+ entropy_first: Project entropy before trust region.
+ cpu: Compute on CPU only.
+            dtype: Data type to use, either of float32 or float64. The latter might be necessary for higher
+ dimensions in order to learn the full covariance.
+ """
+
+ assert entropy_first
+ super().__init__(proj_type, mean_bound, cov_bound, 0.0, False, None, None, None, 0.0, 0.0, entropy_eq,
+ entropy_first, cpu, dtype)
+
+ self.last_policies = []
+
+ def __call__(self, policy, p, q, step=0, *args, **kwargs):
+ if kwargs.get("obs"):
+ self._papi_steps(policy, q, **kwargs)
+ else:
+ return p
+
+ def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ q: Tuple[ch.Tensor, ch.Tensor], eps: Union[ch.Tensor, float],
+ eps_cov: Union[ch.Tensor, float], **kwargs):
+ """
+        Runs PAPI projection layer and constructs sqrt of covariance
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ eps: (modified) kl bound/ kl bound for mean part
+ eps_cov: (modified) kl bound for cov part
+ **kwargs:
+
+ Returns:
+ mean, cov sqrt
+ """
+
+ mean, chol = p
+ old_mean, old_chol = q
+ intermed_mean = kwargs.get('intermed_mean')
+
+ dtype = mean.dtype
+ device = mean.device
+
+ dim = mean.shape[-1]
+
+ ################################################################################################################
+ # Precompute basic matrices
+
+ # Joint bound
+ eps += eps_cov
+
+ I = ch.eye(dim, dtype=dtype, device=device)
+ old_precision = ch.cholesky_solve(I, old_chol)[0]
+ logdet_old = policy.log_determinant(old_chol)
+ cov = policy.covariance(chol)
+
+ ################################################################################################################
+ # compute expected KL
+ maha_part, cov_part = gaussian_kl(policy, p, q)
+ maha_part = maha_part.mean()
+ cov_part = cov_part.mean()
+
+ if intermed_mean is not None:
+ maha_intermediate = 0.5 * policy.maha(intermed_mean, old_mean, old_chol).mean()
+ mm = ch.min(maha_part, maha_intermediate)
+
+ ################################################################################################################
+ # matrix rotation/rescaling projection
+ if maha_part + cov_part > eps + 1e-6:
+ old_cov = policy.covariance(old_chol)
+
+ maha_delta = eps if intermed_mean is None else (eps - mm)
+ eta_rot = maha_delta / ch.max(maha_part + cov_part, ch.tensor(1e-16, dtype=dtype, device=device))
+ new_cov = (1 - eta_rot) * old_cov + eta_rot * cov
+ proj_chol = ch.cholesky(new_cov)
+
+ # recompute covariance part of KL for new chol
+ trace_term = 0.5 * (torch_batched_trace(old_precision @ new_cov) - dim).mean() # rotation difference
+ entropy_diff = 0.5 * (logdet_old - policy.log_determinant(proj_chol)).mean()
+
+ cov_part = trace_term + entropy_diff
+
+ else:
+ proj_chol = chol
+
+ ################################################################################################################
+ # mean interpolation projection
+ if maha_part + cov_part > eps + 1e-6:
+
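+            # Find the interpolation weight eta_mean towards the new mean that keeps the joint
+            # KL within eps: with an intermediate mean this requires solving a quadratic in eta
+            # (coefficients a, b, c below), otherwise the closed form sqrt((eps - cov_part) / maha_part).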
+ if intermed_mean is not None:
+ a = 0.5 * policy.maha(mean, intermed_mean, old_chol).mean()
+ b = 0.5 * ((mean - intermed_mean) @ old_precision @ (intermed_mean - old_mean).T).mean()
+ c = maha_intermediate - ch.max(eps - cov_part, ch.tensor(0., dtype=dtype, device=device))
+ eta_mean = (-b + ch.sqrt(ch.max(b * b - a * c, ch.tensor(1e-16, dtype=dtype, device=device)))) / \
+ ch.max(a, ch.tensor(1e-16, dtype=dtype, device=device))
+ else:
+ eta_mean = ch.sqrt(
+ ch.max(eps - cov_part, ch.tensor(1e-16, dtype=dtype, device=device)) /
+ ch.max(maha_part, ch.tensor(1e-16, dtype=dtype, device=device)))
+ else:
+ eta_mean = ch.tensor(1., dtype=dtype, device=device)
+
+ return eta_mean, proj_chol
+
+ def _papi_steps(self, policy: AbstractGaussianPolicy, q: Tuple[ch.Tensor, ch.Tensor], obs: ch.Tensor, lr_schedule,
+ lr_schedule_vf=None):
+ """
+ Take PAPI steps after PPO finished its steps. Policy parameters are updated in-place.
+ Args:
+ policy: policy instance
+ q: old distribution
+ obs: collected observations from trajectories
+ lr_schedule: lr schedule for policy
+ lr_schedule_vf: lr schedule for vf
+
+ Returns:
+
+ """
+ assert not policy.contextual_std
+
+ # save latest policy in history
+ self.last_policies.append(copy.deepcopy(policy))
+
+ ################################################################################################################
+ # policy backtracking: out of last n policies and current one find one that satisfies the kl constraint
+
+ intermed_policy = None
+ n_backtracks = 0
+
+ for i, pi in enumerate(reversed(self.last_policies)):
+ p_prime = pi(obs)
+ mean_part, cov_part = pi.kl_divergence(p_prime, q)
+ if (mean_part + cov_part).mean() <= self.mean_bound + self.cov_bound:
+ intermed_policy = pi
+ n_backtracks = i
+ break
+
+ ################################################################################################################
+ # LR update
+
+        # reduce learning rate when no appropriate policy was found within the last 4 epochs
+ if n_backtracks >= 4 or intermed_policy is None:
+ # Linear learning rate annealing
+ lr_schedule.step()
+ if lr_schedule_vf:
+ lr_schedule_vf.step()
+
+ if intermed_policy is None:
+ # pop last policy and make it current one, as the updated one was poor
+ # do not keep last policy in history, otherwise we could stack the same policy multiple times.
+ if len(self.last_policies) >= 1:
+ policy.load_state_dict(self.last_policies.pop().state_dict())
+ logger.warning(f"No suitable policy found in backtracking of {len(self.last_policies)} policies.")
+ return
+
+ ################################################################################################################
+ # PAPI iterations
+
+        # We assume only non-contextual covariances here, therefore we only need to project one of them
+ q = (q[0], q[1][:1]) # (means, covs[:1])
+
+ # This is A from Alg. 2 [Akrour et al., 2019]
+ intermed_weight = intermed_policy.get_last_layer().detach().clone()
+ # This is A @ phi(s)
+ intermed_mean = p_prime[0].detach().clone()
+
+ entropy = policy.entropy(q)
+ entropy_bound = obs.new_tensor([-np.inf]) if entropy / self.initial_entropy > 0.5 \
+ else entropy - (self.mean_bound + self.cov_bound)
+
+ for _ in range(20):
+ eta, proj_chol = self._projection(intermed_policy, (p_prime[0], p_prime[1][:1]), q,
+ self.mean_bound, self.cov_bound, entropy_bound,
+ intermed_mean=intermed_mean)
+ intermed_policy.papi_weight_update(eta, intermed_weight)
+ intermed_policy.set_std(proj_chol[0])
+ p_prime = intermed_policy(obs)
+
+ policy.load_state_dict(intermed_policy.state_dict())
diff --git a/metastable_baselines/projections_orig/projection_factory.py b/metastable_baselines/projections_orig/projection_factory.py
new file mode 100644
index 0000000..9c38275
--- /dev/null
+++ b/metastable_baselines/projections_orig/projection_factory.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2021 Robert Bosch GmbH
+# Author: Fabian Otto
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer
+from trust_region_projections.projections.frob_projection_layer import FrobeniusProjectionLayer
+from trust_region_projections.projections.kl_projection_layer import KLProjectionLayer
+from trust_region_projections.projections.papi_projection import PAPIProjection
+from trust_region_projections.projections.w2_projection_layer import WassersteinProjectionLayer
+
+
+def get_projection_layer(proj_type: str = "", **kwargs) -> BaseProjectionLayer:
+ """
+ Factory to generate the projection layers for all projections.
+ Args:
+ proj_type: One of None/' ', 'ppo', 'papi', 'w2', 'w2_non_com', 'frob', 'kl', or 'entropy'
+ **kwargs: arguments for projection layer
+
+ Returns:
+
+ """
+ if not proj_type or proj_type.isspace() or proj_type.lower() in ["ppo", "sac", "td3", "mpo", "entropy"]:
+ return BaseProjectionLayer(proj_type, **kwargs)
+
+ elif proj_type.lower() == "w2":
+ return WassersteinProjectionLayer(proj_type, **kwargs)
+
+ elif proj_type.lower() == "frob":
+ return FrobeniusProjectionLayer(proj_type, **kwargs)
+
+ elif proj_type.lower() == "kl":
+ return KLProjectionLayer(proj_type, **kwargs)
+
+ elif proj_type.lower() == "papi":
+ # papi has a different approach compared to our projections.
+ # It has to be applied after the training with PPO.
+ return PAPIProjection(proj_type, **kwargs)
+
+ else:
+ raise ValueError(
+ f"Invalid projection type {proj_type}."
+ f" Choose one of None/' ', 'ppo', 'papi', 'w2', 'w2_non_com', 'frob', 'kl', or 'entropy'.")
diff --git a/metastable_baselines/projections_orig/w2_projection_layer.py b/metastable_baselines/projections_orig/w2_projection_layer.py
new file mode 100644
index 0000000..bce87a3
--- /dev/null
+++ b/metastable_baselines/projections_orig/w2_projection_layer.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2021 Robert Bosch GmbH
+# Author: Fabian Otto
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import torch as ch
+from typing import Tuple
+
+from trust_region_projections.models.policy.abstract_gaussian_policy import AbstractGaussianPolicy
+from trust_region_projections.projections.base_projection_layer import BaseProjectionLayer, mean_projection
+from trust_region_projections.utils.projection_utils import gaussian_wasserstein_commutative
+
+
+class WassersteinProjectionLayer(BaseProjectionLayer):
+
+ def _trust_region_projection(self, policy: AbstractGaussianPolicy, p: Tuple[ch.Tensor, ch.Tensor],
+ q: Tuple[ch.Tensor, ch.Tensor], eps: ch.Tensor, eps_cov: ch.Tensor, **kwargs):
+ """
+ Runs commutative Wasserstein projection layer and constructs sqrt of covariance
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ eps: (modified) kl bound/ kl bound for mean part
+ eps_cov: (modified) kl bound for cov part
+ **kwargs:
+
+ Returns:
+ mean, cov sqrt
+ """
+ mean, sqrt = p
+ old_mean, old_sqrt = q
+ batch_shape = mean.shape[:-1]
+
+ ####################################################################################################################
+ # precompute mean and cov part of W2, which are used for the projection.
+ # Both parts differ based on precision scaling.
+ # If activated, the mean part is the maha distance and the cov has a more complex term in the inner parenthesis.
+ mean_part, cov_part = gaussian_wasserstein_commutative(policy, p, q, self.scale_prec)
+
+ ####################################################################################################################
+ # project mean (w/ or w/o precision scaling)
+ proj_mean = mean_projection(mean, old_mean, mean_part, eps)
+
+ ####################################################################################################################
+ # project covariance (w/ or w/o precision scaling)
+
+ cov_mask = cov_part > eps_cov
+
+ if cov_mask.any():
+ # gradient issue with ch.where, it executes both paths and gives NaN gradient.
+ eta = ch.ones(batch_shape, dtype=sqrt.dtype, device=sqrt.device)
+ eta[cov_mask] = ch.sqrt(cov_part[cov_mask] / eps_cov) - 1.
+ eta = ch.max(-eta, eta)
+
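+            # Same interpolation as in the Frobenius case, but on the covariance square roots:
+            # assuming cov_part is the (scaled) squared distance between the sqrt factors, the
+            # projected sqrt lands on the bound eps_cov after shrinking by (1 + eta)^-2.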
+ new_sqrt = (sqrt + ch.einsum('i,ijk->ijk', eta, old_sqrt)) / (1. + eta + 1e-16)[..., None, None]
+ proj_sqrt = ch.where(cov_mask[..., None, None], new_sqrt, sqrt)
+ else:
+ proj_sqrt = sqrt
+
+ return proj_mean, proj_sqrt
+
+ def trust_region_value(self, policy, p, q):
+ """
+ Computes the Wasserstein distance between two Gaussian distributions p and q.
+ Args:
+ policy: policy instance
+ p: current distribution
+ q: old distribution
+ Returns:
+ mean and covariance part of Wasserstein distance
+ """
+ return gaussian_wasserstein_commutative(policy, p, q, scale_prec=self.scale_prec)
\ No newline at end of file
diff --git a/replay.py b/replay.py
index 1c1f484..0a4a110 100755
--- a/replay.py
+++ b/replay.py
@@ -10,7 +10,7 @@ from stable_baselines3 import SAC, PPO, A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
-from sb3_trl.trl_pg import TRL_PG
+from metastable_baselines.trl_pg import TRL_PG
import columbus
diff --git a/sb3_trl/trl_pg/__init__.py b/sb3_trl/trl_pg/__init__.py
deleted file mode 100644
index 66cf37e..0000000
--- a/sb3_trl/trl_pg/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from sb3_trl.trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
-from sb3_trl.trl_pg.trl_pg import TRL_PG
diff --git a/sb3_trl/trl_pg/policies.py b/sb3_trl/trl_pg/policies.py
deleted file mode 100644
index 8b784eb..0000000
--- a/sb3_trl/trl_pg/policies.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# This file is here just to define MlpPolicy/CnnPolicy
-# that work for TRL_PG
-from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
-
-MlpPolicy = ActorCriticPolicy
-CnnPolicy = ActorCriticCnnPolicy
-MultiInputPolicy = MultiInputActorCriticPolicy
diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py
deleted file mode 100644
index 6209877..0000000
--- a/sb3_trl/trl_pg/trl_pg.py
+++ /dev/null
@@ -1,520 +0,0 @@
-import warnings
-from typing import Any, Dict, Optional, Type, Union, NamedTuple
-
-import numpy as np
-import torch as th
-from gym import spaces
-from torch.nn import functional as F
-
-from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
-from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import explained_variance, get_schedule_fn
-from stable_baselines3.common.vec_env import VecEnv
-from stable_baselines3.common.buffers import RolloutBuffer
-from stable_baselines3.common.callbacks import BaseCallback
-from stable_baselines3.common.utils import obs_as_tensor
-from stable_baselines3.common.vec_env import VecNormalize
-
-from ..projections.base_projection_layer import BaseProjectionLayer
-from ..projections.frob_projection_layer import FrobeniusProjectionLayer
-from ..projections.w2_projection_layer import WassersteinProjectionLayer
-
-from ..misc.rollout_buffer import GaussianRolloutBuffer, GaussianRolloutBufferSamples
-
-
-class TRL_PG(OnPolicyAlgorithm):
- """
- Differential Trust Region Layer (TRL) for Policy Gradient (PG)
-
- Paper: https://arxiv.org/abs/2101.09207
- Code: This implementation borrows (/steals most) code from SB3's PPO implementation https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py
- The implementation of the TRL-specific parts borrows from https://github.com/boschresearch/trust-region-layers/blob/main/trust_region_projections/algorithms/pg/pg.py
-
- :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
- :param env: The environment to learn from (if registered in Gym, can be str)
- :param learning_rate: The learning rate, it can be a function
- of the current progress remaining (from 1 to 0)
- :param n_steps: The number of steps to run for each environment per update
- (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
- NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
- See https://github.com/pytorch/pytorch/issues/29372
- :param batch_size: Minibatch size
- :param n_epochs: Number of epoch when optimizing the surrogate loss
- :param gamma: Discount factor
- :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
- :param clip_range: Clipping parameter, it can be a function of the current progress
- remaining (from 1 to 0).
- :param clip_range_vf: Clipping parameter for the value function,
- it can be a function of the current progress remaining (from 1 to 0).
- This is a parameter specific to the OpenAI implementation. If None is passed (default),
- no clipping will be done on the value function.
- IMPORTANT: this clipping depends on the reward scaling.
- :param normalize_advantage: Whether to normalize or not the advantage
- :param ent_coef: Entropy coefficient for the loss calculation
- :param vf_coef: Value function coefficient for the loss calculation
- :param max_grad_norm: The maximum value for the gradient clipping
- :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
- instead of action noise exploration (default: False)
- :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
- Default: -1 (only sample at the beginning of the rollout)
- :param target_kl: Limit the KL divergence between updates,
- because the clipping is not enough to prevent large update
-        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
- By default, there is no limit on the kl div.
- :param tensorboard_log: the log location for tensorboard (if None, no logging)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
- :param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
- :param seed: Seed for the pseudo random generators
- :param device: Device (cpu, cuda, ...) on which the code should be run.
- Setting it to auto, the code will be run on the GPU if possible.
- :param projection: What kind of Projection to use
- :param _init_setup_model: Whether or not to build the network at the creation of the instance
- """
-
- policy_aliases: Dict[str, Type[BasePolicy]] = {
- "MlpPolicy": ActorCriticPolicy,
- "CnnPolicy": ActorCriticCnnPolicy,
- "MultiInputPolicy": MultiInputActorCriticPolicy,
- }
-
- def __init__(
- self,
- policy: Union[str, Type[ActorCriticPolicy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Schedule] = 3e-4,
- n_steps: int = 2048,
- batch_size: int = 64,
- n_epochs: int = 10,
- gamma: float = 0.99,
- gae_lambda: float = 0.95,
- clip_range: Union[float, Schedule] = 0.2,
- clip_range_vf: Union[None, float, Schedule] = None,
- normalize_advantage: bool = True,
- ent_coef: float = 0.0,
- vf_coef: float = 0.5,
- max_grad_norm: float = 0.5,
- use_sde: bool = False,
- sde_sample_freq: int = -1,
- target_kl: Optional[float] = None,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- policy_kwargs: Optional[Dict[str, Any]] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = "auto",
-
- # Different from PPO:
- projection: BaseProjectionLayer = WassersteinProjectionLayer(),
- #projection: BaseProjectionLayer = FrobeniusProjectionLayer(),
- #projection: BaseProjectionLayer = BaseProjectionLayer(),
-
- _init_setup_model: bool = True,
- ):
-
- super().__init__(
- policy,
- env,
- learning_rate=learning_rate,
- n_steps=n_steps,
- gamma=gamma,
- gae_lambda=gae_lambda,
- ent_coef=ent_coef,
- vf_coef=vf_coef,
- max_grad_norm=max_grad_norm,
- use_sde=use_sde,
- sde_sample_freq=sde_sample_freq,
- tensorboard_log=tensorboard_log,
- policy_kwargs=policy_kwargs,
- verbose=verbose,
- device=device,
- create_eval_env=create_eval_env,
- seed=seed,
- _init_setup_model=False,
- supported_action_spaces=(
- spaces.Box,
- # spaces.Discrete,
- # spaces.MultiDiscrete,
- # spaces.MultiBinary,
- ),
- )
-
- # Sanity check, otherwise it will lead to noisy gradient and NaN
- # because of the advantage normalization
- if normalize_advantage:
- assert (
- batch_size > 1
- ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"
-
- if self.env is not None:
- # Check that `n_steps * n_envs > 1` to avoid NaN
- # when doing advantage normalization
- buffer_size = self.env.num_envs * self.n_steps
- assert (
- buffer_size > 1
- ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
- # Check that the rollout buffer size is a multiple of the mini-batch size
- untruncated_batches = buffer_size // batch_size
- if buffer_size % batch_size > 0:
- warnings.warn(
- f"You have specified a mini-batch size of {batch_size},"
- f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
- f" after every {untruncated_batches} untruncated mini-batches,"
- f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
- f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
- f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
- )
- self.batch_size = batch_size
- self.n_epochs = n_epochs
- self.clip_range = clip_range
- self.clip_range_vf = clip_range_vf
- self.normalize_advantage = normalize_advantage
- self.target_kl = target_kl
-
- # Different from PPO:
- self.projection = projection
- self._global_steps = 0
-
- if _init_setup_model:
- self._setup_model()
-
- def _setup_model(self) -> None:
- super()._setup_model()
-
- # Initialize schedules for policy/value clipping
- self.clip_range = get_schedule_fn(self.clip_range)
- if self.clip_range_vf is not None:
- if isinstance(self.clip_range_vf, (float, int)):
- assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"
-
- self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
-
- # Changed from PPO: We need a bigger RolloutBuffer
- self.rollout_buffer = GaussianRolloutBuffer(
- self.n_steps,
- self.observation_space,
- self.action_space,
- device=self.device,
- gamma=self.gamma,
- gae_lambda=self.gae_lambda,
- n_envs=self.n_envs,
- )
-
- def train(self) -> None:
- """
- Update policy using the currently gathered rollout buffer.
- """
- # Switch to train mode (this affects batch norm / dropout)
- self.policy.set_training_mode(True)
- # Update optimizer learning rate
- self._update_learning_rate(self.policy.optimizer)
- # Compute current clip range
- clip_range = self.clip_range(self._current_progress_remaining)
- # Optional: clip range for the value function
- if self.clip_range_vf is not None:
- clip_range_vf = self.clip_range_vf(
- self._current_progress_remaining)
-
- surrogate_losses = []
- entropy_losses = []
- trust_region_losses = []
- pg_losses, value_losses = [], []
- clip_fractions = []
-
- continue_training = True
-
- # train for n_epochs epochs
- for epoch in range(self.n_epochs):
- approx_kl_divs = []
- # Do a complete pass on the rollout buffer
- for rollout_data in self.rollout_buffer.get(self.batch_size):
-                # This is new compared to PPO:
-                # the trust-region projection needs to know the current step number
- self._global_steps += 1
-
- actions = rollout_data.actions
- if isinstance(self.action_space, spaces.Discrete):
- # Convert discrete action from float to long
- actions = rollout_data.actions.long().flatten()
-
- # Re-sample the noise matrix because the log_std has changed
- if self.use_sde:
- self.policy.reset_noise(self.batch_size)
-
-                # Original PPO code:
-                # values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
-
-                # Corresponding code in the TRL reference implementation
-                # (adapted from Fabian's public version):
- # p = self.policy(rollout_data.observations)
- # proj_p = self.projection(self.policy, p, b_q = (b_old_mean, b_old_std), self._global_step)
- # new_logpacs = self.policy.log_probability(proj_p, b_actions)
-
-                # For reference, the body of SB3's evaluate_actions:
- # pol = self.policy
- # features = pol.extract_features(rollout_data.observations)
- # latent_pi, latent_vf = pol.mlp_extractor(features)
- # distribution = pol._get_action_dist_from_latent(latent_pi)
- # log_prob = distribution.log_prob(actions)
- # values = pol.value_net(latent_vf)
- # return values, log_prob, distribution.entropy()
- # entropy = distribution.entropy()
-
-                # Actual implementation:
- pol = self.policy
- features = pol.extract_features(rollout_data.observations)
- latent_pi, latent_vf = pol.mlp_extractor(features)
- p = pol._get_action_dist_from_latent(latent_pi)
- p_dist = p.distribution
- # q_means = rollout_data.means
- # if len(rollout_data.stds.shape) == 1: # only diag
- # q_stds = th.diag(rollout_data.stds)
- # else:
- # q_stds = rollout_data.stds
- # q_dist = th.distributions.MultivariateNormal(
- # q_means, q_stds)
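-                # Rebuild the old (behavior) policy q from the means and stds stored in the rollout buffer,
-                # then project the current policy p into the trust region around q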
- q_dist = th.distributions.Normal(
- rollout_data.means, rollout_data.stds)
- proj_p = self.projection(p_dist, q_dist, self._global_steps)
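-                # Log-probs and entropy are computed under the projected distribution proj_p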
- log_prob = proj_p.log_prob(actions).sum(dim=1)
- values = self.policy.value_net(latent_vf)
- entropy = proj_p.entropy()
-
- values = values.flatten()
- # Normalize advantage
- advantages = rollout_data.advantages
- if self.normalize_advantage:
- advantages = (advantages - advantages.mean()
- ) / (advantages.std() + 1e-8)
-
- # ratio between old and new policy, should be one at the first iteration
- ratio = th.exp(log_prob - rollout_data.old_log_prob)
-
- # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
- # clipped surrogate loss
- surrogate_loss_1 = advantages * ratio
- surrogate_loss_2 = advantages * \
- th.clamp(ratio, 1 - clip_range, 1 + clip_range)
- surrogate_loss = - \
- th.min(surrogate_loss_1, surrogate_loss_2).mean()
-
- surrogate_losses.append(surrogate_loss.item())
-
- clip_fraction = th.mean(
- (th.abs(ratio - 1) > clip_range).float()).item()
- clip_fractions.append(clip_fraction)
-
- if self.clip_range_vf is None:
- # No clipping
- values_pred = values
- else:
-                    # Clip the difference between old and new value
- # NOTE: this depends on the reward scaling
- values_pred = rollout_data.old_values + th.clamp(
- values - rollout_data.old_values, -clip_range_vf, clip_range_vf
- )
- # Value loss using the TD(gae_lambda) target
- value_loss = F.mse_loss(rollout_data.returns, values_pred)
- value_losses.append(value_loss.item())
-
-                # Entropy loss favors exploration
- if entropy is None:
- # Approximate entropy when no analytical form
- entropy_loss = -th.mean(-log_prob)
- else:
- entropy_loss = -th.mean(entropy)
-
- entropy_losses.append(entropy_loss.item())
-
-                # Difference from PPO: added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
- trust_region_loss = self.projection.get_trust_region_loss(
- p, proj_p)
-
- trust_region_losses.append(trust_region_loss.item())
-
- policy_loss = surrogate_loss + self.ent_coef * entropy_loss + trust_region_loss
- pg_losses.append(policy_loss.item())
-
- loss = policy_loss + self.vf_coef * value_loss
-
- # Calculate approximate form of reverse KL Divergence for early stopping
- # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
- # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
- # and Schulman blog: http://joschu.net/blog/kl-approx.html
- with th.no_grad():
- log_ratio = log_prob - rollout_data.old_log_prob
- approx_kl_div = th.mean(
- (th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
- approx_kl_divs.append(approx_kl_div)
-
- if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
- continue_training = False
- if self.verbose >= 1:
- print(
- f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
- break
-
- # Optimization step
- self.policy.optimizer.zero_grad()
- loss.backward()
- # Clip grad norm
- th.nn.utils.clip_grad_norm_(
- self.policy.parameters(), self.max_grad_norm)
- self.policy.optimizer.step()
-
- if not continue_training:
- break
-
- self._n_updates += self.n_epochs
- explained_var = explained_variance(
- self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
-
- # Logs
- self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
- self.logger.record("train/entropy_loss", np.mean(entropy_losses))
- self.logger.record("train/trust_region_loss",
- np.mean(trust_region_losses))
- self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
- self.logger.record("train/value_loss", np.mean(value_losses))
- self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
- self.logger.record("train/clip_fraction", np.mean(clip_fractions))
- self.logger.record("train/loss", loss.item())
- self.logger.record("train/explained_variance", explained_var)
- if hasattr(self.policy, "log_std"):
- self.logger.record(
- "train/std", th.exp(self.policy.log_std).mean().item())
-
- self.logger.record("train/n_updates",
- self._n_updates, exclude="tensorboard")
- self.logger.record("train/clip_range", clip_range)
- if self.clip_range_vf is not None:
- self.logger.record("train/clip_range_vf", clip_range_vf)
-
- def learn(
- self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 1,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "TRL_PG",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True,
- ) -> "TRL_PG":
-
- return super().learn(
- total_timesteps=total_timesteps,
- callback=callback,
- log_interval=log_interval,
- eval_env=eval_env,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
- tb_log_name=tb_log_name,
- eval_log_path=eval_log_path,
- reset_num_timesteps=reset_num_timesteps,
- )
-
- # This is new compared to PPO.
- # TRL requires us to also save the original mean and std in our rollouts
- def collect_rollouts(
- self,
- env: VecEnv,
- callback: BaseCallback,
- rollout_buffer: RolloutBuffer,
- n_rollout_steps: int,
- ) -> bool:
- """
- Collect experiences using the current policy and fill a ``RolloutBuffer``.
- The term rollout here refers to the model-free notion and should not
- be used with the concept of rollout used in model-based RL or planning.
- :param env: The training environment
- :param callback: Callback that will be called at each step
- (and at the beginning and end of the rollout)
- :param rollout_buffer: Buffer to fill with rollouts
-        :param n_rollout_steps: Number of experiences to collect per environment
- :return: True if function returned with at least `n_rollout_steps`
- collected, False if callback terminated rollout prematurely.
- """
- assert self._last_obs is not None, "No previous observation was provided"
- # Switch to eval mode (this affects batch norm / dropout)
- self.policy.set_training_mode(False)
-
- n_steps = 0
- rollout_buffer.reset()
- # Sample new weights for the state dependent exploration
- if self.use_sde:
- self.policy.reset_noise(env.num_envs)
-
- callback.on_rollout_start()
-
- while n_steps < n_rollout_steps:
- if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
- # Sample a new noise matrix
- self.policy.reset_noise(env.num_envs)
-
- with th.no_grad():
- # Convert to pytorch tensor or to TensorDict
- obs_tensor = obs_as_tensor(self._last_obs, self.device)
- actions, values, log_probs = self.policy(obs_tensor)
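-                # New compared to PPO: also query the current policy's mean and std
-                # so they can be stored in the rollout buffer for the projection during training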
- dist = self.policy.get_distribution(obs_tensor).distribution
- mean, std = dist.mean, dist.stddev
- actions = actions.cpu().numpy()
-
- # Rescale and perform action
- clipped_actions = actions
- # Clip the actions to avoid out of bound error
- if isinstance(self.action_space, spaces.Box):
- clipped_actions = np.clip(
- actions, self.action_space.low, self.action_space.high)
-
- new_obs, rewards, dones, infos = env.step(clipped_actions)
-
- self.num_timesteps += env.num_envs
-
- # Give access to local variables
- callback.update_locals(locals())
- if callback.on_step() is False:
- return False
-
- self._update_info_buffer(infos)
- n_steps += 1
-
- if isinstance(self.action_space, spaces.Discrete):
- # Reshape in case of discrete action
- actions = actions.reshape(-1, 1)
-
-            # Handle timeout by bootstrapping with value function
- # see GitHub issue #633
- for idx, done in enumerate(dones):
- if (
- done
- and infos[idx].get("terminal_observation") is not None
- and infos[idx].get("TimeLimit.truncated", False)
- ):
- terminal_obs = self.policy.obs_to_tensor(
- infos[idx]["terminal_observation"])[0]
- with th.no_grad():
- terminal_value = self.policy.predict_values(terminal_obs)[
- 0]
- rewards[idx] += self.gamma * terminal_value
-
- rollout_buffer.add(self._last_obs, actions, rewards,
- self._last_episode_starts, values, log_probs, mean, std)
- self._last_obs = new_obs
- self._last_episode_starts = dones
-
- with th.no_grad():
- # Compute value for the last timestep
- values = self.policy.predict_values(
- obs_as_tensor(new_obs, self.device))
-
- rollout_buffer.compute_returns_and_advantage(
- last_values=values, dones=dones)
-
- callback.on_rollout_end()
-
- return True
diff --git a/sb3_trl/trl_sac/__init__.py b/sb3_trl/trl_sac/__init__.py
deleted file mode 100644
index c0e01b7..0000000
--- a/sb3_trl/trl_sac/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from sb3_trl.trl_sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
-from sb3_trl.trl_sac.trl_sac import TRL_SAC
diff --git a/sb3_trl/trl_sac/policies.py b/sb3_trl/trl_sac/policies.py
deleted file mode 100644
index 6fcbea1..0000000
--- a/sb3_trl/trl_sac/policies.py
+++ /dev/null
@@ -1,516 +0,0 @@
-import warnings
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
-
-import gym
-import torch as th
-from torch import nn
-
-from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
-from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
-from stable_baselines3.common.preprocessing import get_action_dim
-from stable_baselines3.common.torch_layers import (
- BaseFeaturesExtractor,
- CombinedExtractor,
- FlattenExtractor,
- NatureCNN,
- create_mlp,
- get_actor_critic_arch,
-)
-from stable_baselines3.common.type_aliases import Schedule
-
-# CAP the standard deviation of the actor
-LOG_STD_MAX = 2
-LOG_STD_MIN = -20
-
-
-class Actor(BasePolicy):
- """
- Actor network (policy) for SAC.
-
-    :param observation_space: Observation space
- :param action_space: Action space
- :param net_arch: Network architecture
- :param features_extractor: Network to extract features
- (a CNN when using images, a nn.Flatten() layer otherwise)
- :param features_dim: Number of features
- :param activation_fn: Activation function
- :param use_sde: Whether to use State Dependent Exploration or not
- :param log_std_init: Initial value for the log standard deviation
- :param full_std: Whether to use (n_features x n_actions) parameters
- for the std instead of only (n_features,) when using gSDE.
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
- :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
-        a positive standard deviation (cf. paper). It keeps the variance
-        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
- :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
- :param normalize_images: Whether to normalize images or not,
- dividing by 255.0 (True by default)
- """
-
- def __init__(
- self,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- net_arch: List[int],
- features_extractor: nn.Module,
- features_dim: int,
- activation_fn: Type[nn.Module] = nn.ReLU,
- use_sde: bool = False,
- log_std_init: float = -3,
- full_std: bool = True,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- clip_mean: float = 2.0,
- normalize_images: bool = True,
- ):
- super().__init__(
- observation_space,
- action_space,
- features_extractor=features_extractor,
- normalize_images=normalize_images,
- squash_output=True,
- )
-
- # Save arguments to re-create object at loading
- self.use_sde = use_sde
- self.sde_features_extractor = None
- self.net_arch = net_arch
- self.features_dim = features_dim
- self.activation_fn = activation_fn
- self.log_std_init = log_std_init
- self.sde_net_arch = sde_net_arch
- self.use_expln = use_expln
- self.full_std = full_std
- self.clip_mean = clip_mean
-
- if sde_net_arch is not None:
- warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
-
- action_dim = get_action_dim(self.action_space)
- latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
- self.latent_pi = nn.Sequential(*latent_pi_net)
- last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
-
- if self.use_sde:
- self.action_dist = StateDependentNoiseDistribution(
- action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
- )
- self.mu, self.log_std = self.action_dist.proba_distribution_net(
- latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
- )
- # Avoid numerical issues by limiting the mean of the Gaussian
- # to be in [-clip_mean, clip_mean]
- if clip_mean > 0.0:
- self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
- else:
- self.action_dist = SquashedDiagGaussianDistribution(action_dim)
- self.mu = nn.Linear(last_layer_dim, action_dim)
- self.log_std = nn.Linear(last_layer_dim, action_dim)
-
- def _get_constructor_parameters(self) -> Dict[str, Any]:
- data = super()._get_constructor_parameters()
-
- data.update(
- dict(
- net_arch=self.net_arch,
- features_dim=self.features_dim,
- activation_fn=self.activation_fn,
- use_sde=self.use_sde,
- log_std_init=self.log_std_init,
- full_std=self.full_std,
- use_expln=self.use_expln,
- features_extractor=self.features_extractor,
- clip_mean=self.clip_mean,
- )
- )
- return data
-
- def get_std(self) -> th.Tensor:
- """
- Retrieve the standard deviation of the action distribution.
- Only useful when using gSDE.
- It corresponds to ``th.exp(log_std)`` in the normal case,
- but is slightly different when using ``expln`` function
- (cf StateDependentNoiseDistribution doc).
-
- :return:
- """
- msg = "get_std() is only available when using gSDE"
- assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
- return self.action_dist.get_std(self.log_std)
-
- def reset_noise(self, batch_size: int = 1) -> None:
- """
- Sample new weights for the exploration matrix, when using gSDE.
-
- :param batch_size:
- """
- msg = "reset_noise() is only available when using gSDE"
- assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
- self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
-
- def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
- """
- Get the parameters for the action distribution.
-
- :param obs:
- :return:
- Mean, standard deviation and optional keyword arguments.
- """
- features = self.extract_features(obs)
- latent_pi = self.latent_pi(features)
- mean_actions = self.mu(latent_pi)
-
- if self.use_sde:
- return mean_actions, self.log_std, dict(latent_sde=latent_pi)
- # Unstructured exploration (Original implementation)
- log_std = self.log_std(latent_pi)
- # Original Implementation to cap the standard deviation
- log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
- return mean_actions, log_std, {}
-
- def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
- mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
- # Note: the action is squashed
- return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
-
- def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
- mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
- # return action and associated log prob
- return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
-
- def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
- return self(observation, deterministic)
-
-
-class SACPolicy(BasePolicy):
- """
- Policy class (with both actor and critic) for SAC.
-
- :param observation_space: Observation space
- :param action_space: Action space
- :param lr_schedule: Learning rate schedule (could be constant)
- :param net_arch: The specification of the policy and value networks.
- :param activation_fn: Activation function
- :param use_sde: Whether to use State Dependent Exploration or not
- :param log_std_init: Initial value for the log standard deviation
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
- :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
-        a positive standard deviation (cf. paper). It keeps the variance
-        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
- :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
- :param features_extractor_class: Features extractor to use.
- :param features_extractor_kwargs: Keyword arguments
- to pass to the features extractor.
- :param normalize_images: Whether to normalize images or not,
- dividing by 255.0 (True by default)
- :param optimizer_class: The optimizer to use,
- ``th.optim.Adam`` by default
- :param optimizer_kwargs: Additional keyword arguments,
- excluding the learning rate, to pass to the optimizer
- :param n_critics: Number of critic networks to create.
-    :param share_features_extractor: Whether or not to share the features extractor
-        between the actor and the critic (this saves computation time)
- """
-
- def __init__(
- self,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Schedule,
- net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
- activation_fn: Type[nn.Module] = nn.ReLU,
- use_sde: bool = False,
- log_std_init: float = -3,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- clip_mean: float = 2.0,
- features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None,
- n_critics: int = 2,
- share_features_extractor: bool = True,
- ):
- super().__init__(
- observation_space,
- action_space,
- features_extractor_class,
- features_extractor_kwargs,
- optimizer_class=optimizer_class,
- optimizer_kwargs=optimizer_kwargs,
- squash_output=True,
- )
-
- if net_arch is None:
- if features_extractor_class == NatureCNN:
- net_arch = []
- else:
- net_arch = [256, 256]
-
- actor_arch, critic_arch = get_actor_critic_arch(net_arch)
-
- self.net_arch = net_arch
- self.activation_fn = activation_fn
- self.net_args = {
- "observation_space": self.observation_space,
- "action_space": self.action_space,
- "net_arch": actor_arch,
- "activation_fn": self.activation_fn,
- "normalize_images": normalize_images,
- }
- self.actor_kwargs = self.net_args.copy()
-
- if sde_net_arch is not None:
- warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
-
- sde_kwargs = {
- "use_sde": use_sde,
- "log_std_init": log_std_init,
- "use_expln": use_expln,
- "clip_mean": clip_mean,
- }
- self.actor_kwargs.update(sde_kwargs)
- self.critic_kwargs = self.net_args.copy()
- self.critic_kwargs.update(
- {
- "n_critics": n_critics,
- "net_arch": critic_arch,
- "share_features_extractor": share_features_extractor,
- }
- )
-
- self.actor, self.actor_target = None, None
- self.critic, self.critic_target = None, None
- self.share_features_extractor = share_features_extractor
-
- self._build(lr_schedule)
-
- def _build(self, lr_schedule: Schedule) -> None:
- self.actor = self.make_actor()
- self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
-
- if self.share_features_extractor:
- self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
- # Do not optimize the shared features extractor with the critic loss
- # otherwise, there are gradient computation issues
- critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
- else:
- # Create a separate features extractor for the critic
- # this requires more memory and computation
- self.critic = self.make_critic(features_extractor=None)
- critic_parameters = self.critic.parameters()
-
- # Critic target should not share the features extractor with critic
- self.critic_target = self.make_critic(features_extractor=None)
- self.critic_target.load_state_dict(self.critic.state_dict())
-
- self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
-
- # Target networks should always be in eval mode
- self.critic_target.set_training_mode(False)
-
- def _get_constructor_parameters(self) -> Dict[str, Any]:
- data = super()._get_constructor_parameters()
-
- data.update(
- dict(
- net_arch=self.net_arch,
- activation_fn=self.net_args["activation_fn"],
- use_sde=self.actor_kwargs["use_sde"],
- log_std_init=self.actor_kwargs["log_std_init"],
- use_expln=self.actor_kwargs["use_expln"],
- clip_mean=self.actor_kwargs["clip_mean"],
- n_critics=self.critic_kwargs["n_critics"],
- lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
- optimizer_class=self.optimizer_class,
- optimizer_kwargs=self.optimizer_kwargs,
- features_extractor_class=self.features_extractor_class,
- features_extractor_kwargs=self.features_extractor_kwargs,
- )
- )
- return data
-
- def reset_noise(self, batch_size: int = 1) -> None:
- """
- Sample new weights for the exploration matrix, when using gSDE.
-
- :param batch_size:
- """
- self.actor.reset_noise(batch_size=batch_size)
-
- def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
- actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
- return Actor(**actor_kwargs).to(self.device)
-
- def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
- critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
- return ContinuousCritic(**critic_kwargs).to(self.device)
-
- def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
- return self._predict(obs, deterministic=deterministic)
-
- def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
- return self.actor(observation, deterministic)
-
- def set_training_mode(self, mode: bool) -> None:
- """
- Put the policy in either training or evaluation mode.
-
- This affects certain modules, such as batch normalisation and dropout.
-
- :param mode: if true, set to training mode, else set to evaluation mode
- """
- self.actor.set_training_mode(mode)
- self.critic.set_training_mode(mode)
- self.training = mode
-
-
-MlpPolicy = SACPolicy
-
-
-class CnnPolicy(SACPolicy):
- """
- Policy class (with both actor and critic) for SAC.
-
- :param observation_space: Observation space
- :param action_space: Action space
- :param lr_schedule: Learning rate schedule (could be constant)
- :param net_arch: The specification of the policy and value networks.
- :param activation_fn: Activation function
- :param use_sde: Whether to use State Dependent Exploration or not
- :param log_std_init: Initial value for the log standard deviation
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
- :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
-        a positive standard deviation (cf. paper). It keeps the variance
-        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
- :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
- :param features_extractor_class: Features extractor to use.
- :param normalize_images: Whether to normalize images or not,
- dividing by 255.0 (True by default)
- :param optimizer_class: The optimizer to use,
- ``th.optim.Adam`` by default
- :param optimizer_kwargs: Additional keyword arguments,
- excluding the learning rate, to pass to the optimizer
- :param n_critics: Number of critic networks to create.
-    :param share_features_extractor: Whether or not to share the features extractor
-        between the actor and the critic (this saves computation time)
- """
-
- def __init__(
- self,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Schedule,
- net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
- activation_fn: Type[nn.Module] = nn.ReLU,
- use_sde: bool = False,
- log_std_init: float = -3,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- clip_mean: float = 2.0,
- features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None,
- n_critics: int = 2,
- share_features_extractor: bool = True,
- ):
- super().__init__(
- observation_space,
- action_space,
- lr_schedule,
- net_arch,
- activation_fn,
- use_sde,
- log_std_init,
- sde_net_arch,
- use_expln,
- clip_mean,
- features_extractor_class,
- features_extractor_kwargs,
- normalize_images,
- optimizer_class,
- optimizer_kwargs,
- n_critics,
- share_features_extractor,
- )
-
-
-class MultiInputPolicy(SACPolicy):
- """
- Policy class (with both actor and critic) for SAC.
-
- :param observation_space: Observation space
- :param action_space: Action space
- :param lr_schedule: Learning rate schedule (could be constant)
- :param net_arch: The specification of the policy and value networks.
- :param activation_fn: Activation function
- :param use_sde: Whether to use State Dependent Exploration or not
- :param log_std_init: Initial value for the log standard deviation
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
- :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
-        a positive standard deviation (cf. paper). It keeps the variance
-        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
- :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
- :param features_extractor_class: Features extractor to use.
- :param normalize_images: Whether to normalize images or not,
- dividing by 255.0 (True by default)
- :param optimizer_class: The optimizer to use,
- ``th.optim.Adam`` by default
- :param optimizer_kwargs: Additional keyword arguments,
- excluding the learning rate, to pass to the optimizer
- :param n_critics: Number of critic networks to create.
-    :param share_features_extractor: Whether or not to share the features extractor
-        between the actor and the critic (this saves computation time)
- """
-
- def __init__(
- self,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Schedule,
- net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
- activation_fn: Type[nn.Module] = nn.ReLU,
- use_sde: bool = False,
- log_std_init: float = -3,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- clip_mean: float = 2.0,
- features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None,
- n_critics: int = 2,
- share_features_extractor: bool = True,
- ):
- super().__init__(
- observation_space,
- action_space,
- lr_schedule,
- net_arch,
- activation_fn,
- use_sde,
- log_std_init,
- sde_net_arch,
- use_expln,
- clip_mean,
- features_extractor_class,
- features_extractor_kwargs,
- normalize_images,
- optimizer_class,
- optimizer_kwargs,
- n_critics,
- share_features_extractor,
- )
diff --git a/sb3_trl/trl_sac/trl_sac.py b/sb3_trl/trl_sac/trl_sac.py
deleted file mode 100644
index 2e884b9..0000000
--- a/sb3_trl/trl_sac/trl_sac.py
+++ /dev/null
@@ -1,324 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
-
-import gym
-import numpy as np
-import torch as th
-from torch.nn import functional as F
-
-from stable_baselines3.common.buffers import ReplayBuffer
-from stable_baselines3.common.noise import ActionNoise
-from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
-from stable_baselines3.common.policies import BasePolicy
-from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import polyak_update
-from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
-
-
-class TRL_SAC(OffPolicyAlgorithm):
- """
- Trust Region Layers (TRL) based on SAC (Soft Actor Critic)
-    This implementation is almost a 1:1 copy of the SB3 code for SAC.
-    Only minor changes have been made to implement Differentiable Trust Region Layers.
-
- Description from original SAC implementation:
- Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
- This implementation borrows code from original implementation (https://github.com/haarnoja/sac)
- from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo
- (https://github.com/rail-berkeley/softlearning/)
- and from Stable Baselines (https://github.com/hill-a/stable-baselines)
- Paper: https://arxiv.org/abs/1801.01290
- Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
-
-    Note: we use a double Q target and not a value target, as discussed
- in https://github.com/hill-a/stable-baselines/issues/270
-
- :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
- :param env: The environment to learn from (if registered in Gym, can be str)
-    :param learning_rate: learning rate for the Adam optimizer;
-        the same learning rate will be used for all networks (Q-Values, Actor and Value function).
-        It can be a function of the current progress remaining (from 1 to 0)
- :param buffer_size: size of the replay buffer
- :param learning_starts: how many steps of the model to collect transitions for before learning starts
- :param batch_size: Minibatch size for each gradient update
- :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
- :param gamma: the discount factor
- :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
- like ``(5, "step")`` or ``(2, "episode")``.
- :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
-        Set it to ``-1`` to do as many gradient steps as steps taken in the environment
-        during the rollout.
-    :param action_noise: the action noise type (None by default), this can help
-        for hard exploration problems. See common.noise for the different action noise types.
- :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
- If ``None``, it will be automatically selected.
- :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
- :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
- at a cost of more complexity.
- See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
- :param ent_coef: Entropy regularization coefficient. (Equivalent to
-        the inverse of the reward scale in the original SAC paper.) It controls the exploration/exploitation trade-off.
- Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
-    :param target_update_interval: update the target network every ``target_update_interval``
- gradient steps.
- :param target_entropy: target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
- :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
- instead of action noise exploration (default: False)
- :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
- Default: -1 (only sample at the beginning of the rollout)
- :param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
- during the warm up phase (before learning starts)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
- :param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
- :param seed: Seed for the pseudo random generators
- :param device: Device (cpu, cuda, ...) on which the code should be run.
- Setting it to auto, the code will be run on the GPU if possible.
- :param _init_setup_model: Whether or not to build the network at the creation of the instance
- """
-
- policy_aliases: Dict[str, Type[BasePolicy]] = {
- "MlpPolicy": MlpPolicy,
- "CnnPolicy": CnnPolicy,
- "MultiInputPolicy": MultiInputPolicy,
- }
-
- def __init__(
- self,
- policy: Union[str, Type[SACPolicy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Schedule] = 3e-4,
- buffer_size: int = 1_000_000, # 1e6
- learning_starts: int = 100,
- batch_size: int = 256,
- tau: float = 0.005,
- gamma: float = 0.99,
- train_freq: Union[int, Tuple[int, str]] = 1,
- gradient_steps: int = 1,
- action_noise: Optional[ActionNoise] = None,
- replay_buffer_class: Optional[ReplayBuffer] = None,
- replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
- optimize_memory_usage: bool = False,
- ent_coef: Union[str, float] = "auto",
- target_update_interval: int = 1,
- target_entropy: Union[str, float] = "auto",
- use_sde: bool = False,
- sde_sample_freq: int = -1,
- use_sde_at_warmup: bool = False,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- policy_kwargs: Optional[Dict[str, Any]] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = "auto",
- _init_setup_model: bool = True,
- ):
-
- super().__init__(
- policy,
- env,
- learning_rate,
- buffer_size,
- learning_starts,
- batch_size,
- tau,
- gamma,
- train_freq,
- gradient_steps,
- action_noise,
- replay_buffer_class=replay_buffer_class,
- replay_buffer_kwargs=replay_buffer_kwargs,
- policy_kwargs=policy_kwargs,
- tensorboard_log=tensorboard_log,
- verbose=verbose,
- device=device,
- create_eval_env=create_eval_env,
- seed=seed,
- use_sde=use_sde,
- sde_sample_freq=sde_sample_freq,
- use_sde_at_warmup=use_sde_at_warmup,
- optimize_memory_usage=optimize_memory_usage,
- supported_action_spaces=(gym.spaces.Box),
- support_multi_env=True,
- )
-
- self.target_entropy = target_entropy
- self.log_ent_coef = None # type: Optional[th.Tensor]
- # Entropy coefficient / Entropy temperature
- # Inverse of the reward scale
- self.ent_coef = ent_coef
- self.target_update_interval = target_update_interval
- self.ent_coef_optimizer = None
-
- if _init_setup_model:
- self._setup_model()
-
- def _setup_model(self) -> None:
- super()._setup_model()
- self._create_aliases()
- # Target entropy is used when learning the entropy coefficient
- if self.target_entropy == "auto":
- # automatically set target entropy if needed
- self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
- else:
- # Force conversion
- # this will also throw an error for unexpected string
- self.target_entropy = float(self.target_entropy)
-
- # The entropy coefficient or entropy can be learned automatically
- # see Automating Entropy Adjustment for Maximum Entropy RL section
- # of https://arxiv.org/abs/1812.05905
- if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
- # Default initial value of ent_coef when learned
- init_value = 1.0
- if "_" in self.ent_coef:
- init_value = float(self.ent_coef.split("_")[1])
- assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"
-
- # Note: we optimize the log of the entropy coeff which is slightly different from the paper
- # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
- self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
- self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
- else:
- # Force conversion to float
- # this will throw an error if a malformed string (different from 'auto')
- # is passed
- self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
-
- def _create_aliases(self) -> None:
- self.actor = self.policy.actor
- self.critic = self.policy.critic
- self.critic_target = self.policy.critic_target
-
- def train(self, gradient_steps: int, batch_size: int = 64) -> None:
- # Switch to train mode (this affects batch norm / dropout)
- self.policy.set_training_mode(True)
- # Update optimizers learning rate
- optimizers = [self.actor.optimizer, self.critic.optimizer]
- if self.ent_coef_optimizer is not None:
- optimizers += [self.ent_coef_optimizer]
-
- # Update learning rate according to lr schedule
- self._update_learning_rate(optimizers)
-
- ent_coef_losses, ent_coefs = [], []
- actor_losses, critic_losses = [], []
-
- for gradient_step in range(gradient_steps):
- # Sample replay buffer
- replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
-
- # We need to sample because `log_std` may have changed between two gradient steps
- if self.use_sde:
- self.actor.reset_noise()
-
- # Action by the current actor for the sampled state
- actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
- log_prob = log_prob.reshape(-1, 1)
-
- ent_coef_loss = None
- if self.ent_coef_optimizer is not None:
- # Important: detach the variable from the graph
- # so we don't change it with other losses
- # see https://github.com/rail-berkeley/softlearning/issues/60
- ent_coef = th.exp(self.log_ent_coef.detach())
- ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
- ent_coef_losses.append(ent_coef_loss.item())
- else:
- ent_coef = self.ent_coef_tensor
-
- ent_coefs.append(ent_coef.item())
-
- # Optimize entropy coefficient, also called
- # entropy temperature or alpha in the paper
- if ent_coef_loss is not None:
- self.ent_coef_optimizer.zero_grad()
- ent_coef_loss.backward()
- self.ent_coef_optimizer.step()
-
- with th.no_grad():
- # Select action according to policy
- next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
- # Compute the next Q values: min over all critics targets
- next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
- next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
- # add entropy term
- next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
- # td error + entropy term
- target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
-
- # Get current Q-values estimates for each critic network
- # using action from the replay buffer
- current_q_values = self.critic(replay_data.observations, replay_data.actions)
-
- # Compute critic loss
- critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
- critic_losses.append(critic_loss.item())
-
- # Optimize the critic
- self.critic.optimizer.zero_grad()
- critic_loss.backward()
- self.critic.optimizer.step()
-
- # Compute actor loss
- # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
- # Mean over all critic networks
- q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
- min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
- actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
- actor_losses.append(actor_loss.item())
-
- # Optimize the actor
- self.actor.optimizer.zero_grad()
- actor_loss.backward()
- self.actor.optimizer.step()
-
- # Update target networks
- if gradient_step % self.target_update_interval == 0:
- polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
-
- self._n_updates += gradient_steps
-
- self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
- self.logger.record("train/ent_coef", np.mean(ent_coefs))
- self.logger.record("train/actor_loss", np.mean(actor_losses))
- self.logger.record("train/critic_loss", np.mean(critic_losses))
- if len(ent_coef_losses) > 0:
- self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
-
- def learn(
- self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "SAC",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True,
- ) -> OffPolicyAlgorithm:
-
- return super().learn(
- total_timesteps=total_timesteps,
- callback=callback,
- log_interval=log_interval,
- eval_env=eval_env,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
- tb_log_name=tb_log_name,
- eval_log_path=eval_log_path,
- reset_num_timesteps=reset_num_timesteps,
- )
-
- def _excluded_save_params(self) -> List[str]:
- return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
-
- def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
- state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
- if self.ent_coef_optimizer is not None:
- saved_pytorch_variables = ["log_ent_coef"]
- state_dicts.append("ent_coef_optimizer")
- else:
- saved_pytorch_variables = ["ent_coef_tensor"]
- return state_dicts, saved_pytorch_variables
diff --git a/test.py b/test.py
index 5e975aa..52e27fd 100755
--- a/test.py
+++ b/test.py
@@ -10,16 +10,16 @@ from stable_baselines3 import SAC, PPO, A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
-from sb3_trl.trl_pg import TRL_PG
+from metastable_baselines.trl_pg import TRL_PG
import columbus
#root_path = os.getcwd()
root_path = '.'
-def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=200_000, showRes=True, saveModel=True, n_eval_episodes=0):
+def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
env = gym.make(env_name)
- use_sde = False
+ use_sde = True
ppo = PPO(
"MlpPolicy",
env,
@@ -54,8 +54,8 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=200_000, showRes=True,
print('TRL_PG:')
testModel(trl_pg, timesteps, showRes,
saveModel, n_eval_episodes)
- #print('PPO:')
- #testModel(ppo, timesteps, showRes,
+ # print('PPO:')
+ # testModel(ppo, timesteps, showRes,
# saveModel, n_eval_episodes)
@@ -100,7 +100,7 @@ def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=
if __name__ == '__main__':
- main('LunarLanderContinuous-v2')
- #main('ColumbusJustState-v0')
- #main('ColumbusStateWithBarriers-v0')
- #main('ColumbusEasierObstacles-v0')
+ # main('LunarLanderContinuous-v2')
+ # main('ColumbusJustState-v0')
+ # main('ColumbusStateWithBarriers-v0')
+ main('ColumbusEasierObstacles-v0')