Bug Fix for Full Cov

This commit is contained in:
Dominik Moritz Roth 2022-09-23 23:06:19 +02:00
parent 00dbc9bdd8
commit f37d3215a6

View File

@ -383,6 +383,15 @@ class CholNet(nn.Module):
self._flat_chol_len = action_dim * (action_dim + 1) // 2 self._flat_chol_len = action_dim * (action_dim + 1) // 2
if self.par_type == ParametrizationType.CHOL:
self._full_params_len = self._flat_chol_len
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
self._full_params_len = self._flat_chol_len
elif self.par_type == ParametrizationType.EIGEN:
self._full_params_len = self.action_dim * 2
elif self.par_type == ParametrizationType.EIGEN_BIJECT:
self._full_params_len = self.action_dim * 2
self._givens_rotator = givens.Rotation(action_dim) self._givens_rotator = givens.Rotation(action_dim)
self._givens_ident = th.eye(action_dim) self._givens_ident = th.eye(action_dim)
@ -491,18 +500,6 @@ class CholNet(nn.Module):
return chol return chol
raise Exception() raise Exception()
@property
def _full_params_len(self):
if self.par_type == ParametrizationType.CHOL:
return self._flat_chol_len
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
return self._flat_chol_len
elif self.par_type == ParametrizationType.EIGEN:
return self.action_dim * 2
elif self.par_type == ParametrizationType.EIGEN_BIJECT:
return self.action_dim * 2
raise Exception()
def _parameterize_full(self, params): def _parameterize_full(self, params):
if self.par_type == ParametrizationType.CHOL: if self.par_type == ParametrizationType.CHOL:
return self._chol_from_flat(params) return self._chol_from_flat(params)