diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 22831d4..06e6807 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -22,13 +22,6 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution from ..misc.tensor_ops import fill_triangular from ..misc.tanhBijector import TanhBijector -# TODO: Integrate and Test what I currently have before adding more complexity -# TODO: Support Squashed Dists (tanh) -# TODO: Contextual Cov -# TODO: - Hybrid -# TODO: Contextual SDE (Scalar + Diag + Full) -# TODO: (SqrtInducedCov (Scalar + Diag + Full)) - class Strength(Enum): NONE = 0 @@ -220,17 +213,13 @@ class UniversalGaussianDistribution(SB3_Distribution): return mean_actions, chol def _sqrt_to_chol(self, cov_sqrt): - vec = False - nobatch = False - if len(cov_sqrt.shape) <= 2: - vec = True - if len(cov_sqrt.shape) == 1: - nobatch = True + vec = self.cov_strength != Strength.FULL + batch_dims = len(cov_sqrt.shape) - 2 + 1*vec if vec: cov_sqrt = th.diag_embed(cov_sqrt) - if nobatch: + if batch_dims == 0: cov = th.mm(cov_sqrt.mT, cov_sqrt) cov += th.eye(cov.shape[-1])*(self.epsilon) else: @@ -533,8 +522,12 @@ class CholNet(nn.Module): # S[i,j] e (0, pi) where i = 2..n, j = 2..i # We already ensure S > 0 in _chol_from_flat_sphe_chol # We ensure < pi by applying tanh*pi to all applicable elements - batch = (len(sphe_chol.shape) == 3) - batch_size = sphe_chol.shape[0] + vec = self.cov_strength != Strength.FULL + batch_dims = len(sphe_chol.shape) - 2 + 1*vec + batch = batch_dims != 0 + batch_shape = sphe_chol.shape[:batch_dims] + batch_shape_scalar = batch_shape + (1,) + S = sphe_chol n = sphe_chol.shape[-1] L = th.zeros_like(sphe_chol) @@ -542,13 +535,13 @@ class CholNet(nn.Module): #t = 1 t = th.Tensor([1])[0] if batch: - t = t.expand((batch_size, 1)) + t = t.expand(batch_shape_scalar) #s = '' for j in range(i+1): #maybe_cos = 1 maybe_cos = th.Tensor([1])[0] if batch: - maybe_cos = maybe_cos.expand((batch_size, 1)) + maybe_cos = maybe_cos.expand(batch_shape_scalar) #s_maybe_cos = '' if i != j and j < n-1 and i < n: if batch: diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index 273aa79..4f66ee2 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -101,7 +101,7 @@ class ActorCriticPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, dist_kwargs: Optional[Dict[str, Any]] = None, - sqrt_induced_gaussian=False, + sqrt_induced_gaussian: bool = False, ): if optimizer_kwargs is None: