Bug Fix for Full Cov
This commit is contained in:
parent
00dbc9bdd8
commit
f37d3215a6
@ -383,6 +383,15 @@ class CholNet(nn.Module):
|
|||||||
|
|
||||||
self._flat_chol_len = action_dim * (action_dim + 1) // 2
|
self._flat_chol_len = action_dim * (action_dim + 1) // 2
|
||||||
|
|
||||||
|
if self.par_type == ParametrizationType.CHOL:
|
||||||
|
self._full_params_len = self._flat_chol_len
|
||||||
|
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
|
||||||
|
self._full_params_len = self._flat_chol_len
|
||||||
|
elif self.par_type == ParametrizationType.EIGEN:
|
||||||
|
self._full_params_len = self.action_dim * 2
|
||||||
|
elif self.par_type == ParametrizationType.EIGEN_BIJECT:
|
||||||
|
self._full_params_len = self.action_dim * 2
|
||||||
|
|
||||||
self._givens_rotator = givens.Rotation(action_dim)
|
self._givens_rotator = givens.Rotation(action_dim)
|
||||||
self._givens_ident = th.eye(action_dim)
|
self._givens_ident = th.eye(action_dim)
|
||||||
|
|
||||||
@ -491,18 +500,6 @@ class CholNet(nn.Module):
|
|||||||
return chol
|
return chol
|
||||||
raise Exception()
|
raise Exception()
|
||||||
|
|
||||||
@property
|
|
||||||
def _full_params_len(self):
|
|
||||||
if self.par_type == ParametrizationType.CHOL:
|
|
||||||
return self._flat_chol_len
|
|
||||||
elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
|
|
||||||
return self._flat_chol_len
|
|
||||||
elif self.par_type == ParametrizationType.EIGEN:
|
|
||||||
return self.action_dim * 2
|
|
||||||
elif self.par_type == ParametrizationType.EIGEN_BIJECT:
|
|
||||||
return self.action_dim * 2
|
|
||||||
raise Exception()
|
|
||||||
|
|
||||||
def _parameterize_full(self, params):
|
def _parameterize_full(self, params):
|
||||||
if self.par_type == ParametrizationType.CHOL:
|
if self.par_type == ParametrizationType.CHOL:
|
||||||
return self._chol_from_flat(params)
|
return self._chol_from_flat(params)
|
||||||
|
Loading…
Reference in New Issue
Block a user