diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py index db13ef9..25592a8 100644 --- a/priorConditionedAnnealing/pca.py +++ b/priorConditionedAnnealing/pca.py @@ -114,7 +114,6 @@ class PCA_Distribution(SB3_Distribution): # Premature optimization is the root of all evil self._build_conditioner() # *Optimizes it anyways* - print('[i] PCA-Distribution initialized') def proba_distribution_net(self, latent_dim: int): mu_net = nn.Linear(latent_dim, self.action_dim) @@ -129,7 +128,7 @@ class PCA_Distribution(SB3_Distribution): return self def log_prob(self, actions: th.Tensor) -> th.Tensor: - return sum_independent_dims(self.distribution.log_prob(actions)) + return sum_independent_dims(self.distribution.log_prob(actions.to(self.distribution.mean.device))) def entropy(self) -> th.Tensor: return sum_independent_dims(self.distribution.entropy()) @@ -146,7 +145,7 @@ class PCA_Distribution(SB3_Distribution): return self.sample(traj=trajectory) def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor: - pi_mean, pi_std = self.distribution.mean, self.distribution.scale + pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu() rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std) rho_std *= f_sigma eta = self._get_rigged(pi_mean, pi_std, @@ -187,8 +186,9 @@ class PCA_Distribution(SB3_Distribution): return pi_mean, pi_std traj = self._pad_and_cut_trajectory(trajectory) + # Numpy is fun y_np = np.append(np.swapaxes(traj, -1, -2), - np.expand_dims(pi_mean, -1), -1) + np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1) with th.no_grad(): conditioners = th.Tensor(self._adapt_conditioner(pi_std)) @@ -300,7 +300,7 @@ class StdNet(nn.Module): return self._ensure_positive_func(self.param) elif self.par_strength == Par_Strength.CONT_SCALAR: cont = self.net(x) - diag_chol = th.ones(self.action_dim) * cont * self.std_init + diag_chol = th.ones(self.action_dim, device=cont.device) * cont * self.std_init return self._ensure_positive_func(diag_chol) elif self.par_strength == Par_Strength.CONT_HYBRID: cont = self.net(x)