From f37d3215a6aa0c280939726e2e51c7ffb7468cda Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 23 Sep 2022 23:06:19 +0200 Subject: [PATCH] Bug Fix for Full Cov --- .../distributions/distributions.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index f5acfa0..951eca4 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -383,6 +383,15 @@ class CholNet(nn.Module): self._flat_chol_len = action_dim * (action_dim + 1) // 2 + if self.par_type == ParametrizationType.CHOL: + self._full_params_len = self._flat_chol_len + elif self.par_type == ParametrizationType.SPHERICAL_CHOL: + self._full_params_len = self._flat_chol_len + elif self.par_type == ParametrizationType.EIGEN: + self._full_params_len = self.action_dim * 2 + elif self.par_type == ParametrizationType.EIGEN_BIJECT: + self._full_params_len = self.action_dim * 2 + self._givens_rotator = givens.Rotation(action_dim) self._givens_ident = th.eye(action_dim) @@ -491,18 +500,6 @@ class CholNet(nn.Module): return chol raise Exception() - @property - def _full_params_len(self): - if self.par_type == ParametrizationType.CHOL: - return self._flat_chol_len - elif self.par_type == ParametrizationType.SPHERICAL_CHOL: - return self._flat_chol_len - elif self.par_type == ParametrizationType.EIGEN: - return self.action_dim * 2 - elif self.par_type == ParametrizationType.EIGEN_BIJECT: - return self.action_dim * 2 - raise Exception() - def _parameterize_full(self, params): if self.par_type == ParametrizationType.CHOL: return self._chol_from_flat(params)