diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index f5acfa0..951eca4 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -383,6 +383,15 @@ class CholNet(nn.Module): self._flat_chol_len = action_dim * (action_dim + 1) // 2 + if self.par_type == ParametrizationType.CHOL: + self._full_params_len = self._flat_chol_len + elif self.par_type == ParametrizationType.SPHERICAL_CHOL: + self._full_params_len = self._flat_chol_len + elif self.par_type == ParametrizationType.EIGEN: + self._full_params_len = self.action_dim * 2 + elif self.par_type == ParametrizationType.EIGEN_BIJECT: + self._full_params_len = self.action_dim * 2 + self._givens_rotator = givens.Rotation(action_dim) self._givens_ident = th.eye(action_dim) @@ -491,18 +500,6 @@ class CholNet(nn.Module): return chol raise Exception() - @property - def _full_params_len(self): - if self.par_type == ParametrizationType.CHOL: - return self._flat_chol_len - elif self.par_type == ParametrizationType.SPHERICAL_CHOL: - return self._flat_chol_len - elif self.par_type == ParametrizationType.EIGEN: - return self.action_dim * 2 - elif self.par_type == ParametrizationType.EIGEN_BIJECT: - return self.action_dim * 2 - raise Exception() - def _parameterize_full(self, params): if self.par_type == ParametrizationType.CHOL: return self._chol_from_flat(params)