Bug Fix for Full Cov
This commit is contained in:
parent
00dbc9bdd8
commit
f37d3215a6
@ -383,6 +383,15 @@ class CholNet(nn.Module):
|
||||
|
||||
self._flat_chol_len = action_dim * (action_dim + 1) // 2
|
||||
|
||||
if self.par_type == ParametrizationType.CHOL:
|
||||
self._full_params_len = self._flat_chol_len
|
||||
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
|
||||
self._full_params_len = self._flat_chol_len
|
||||
elif self.par_type == ParametrizationType.EIGEN:
|
||||
self._full_params_len = self.action_dim * 2
|
||||
elif self.par_type == ParametrizationType.EIGEN_BIJECT:
|
||||
self._full_params_len = self.action_dim * 2
|
||||
|
||||
self._givens_rotator = givens.Rotation(action_dim)
|
||||
self._givens_ident = th.eye(action_dim)
|
||||
|
||||
@ -491,18 +500,6 @@ class CholNet(nn.Module):
|
||||
return chol
|
||||
raise Exception()
|
||||
|
||||
@property
|
||||
def _full_params_len(self):
|
||||
if self.par_type == ParametrizationType.CHOL:
|
||||
return self._flat_chol_len
|
||||
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
|
||||
return self._flat_chol_len
|
||||
elif self.par_type == ParametrizationType.EIGEN:
|
||||
return self.action_dim * 2
|
||||
elif self.par_type == ParametrizationType.EIGEN_BIJECT:
|
||||
return self.action_dim * 2
|
||||
raise Exception()
|
||||
|
||||
def _parameterize_full(self, params):
|
||||
if self.par_type == ParametrizationType.CHOL:
|
||||
return self._chol_from_flat(params)
|
||||
|
Loading…
Reference in New Issue
Block a user