diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 951eca4..282d69c 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -389,7 +389,7 @@ class CholNet(nn.Module): self._full_params_len = self._flat_chol_len elif self.par_type == ParametrizationType.EIGEN: self._full_params_len = self.action_dim * 2 - elif self.par_type == ParametrizationType.EIGEN_BIJECT: + elif self.par_type == ParametrizationType.EIGEN_RAW: self._full_params_len = self.action_dim * 2 self._givens_rotator = givens.Rotation(action_dim)