diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py index 40f6d1c..02623de 100644 --- a/priorConditionedAnnealing/pca.py +++ b/priorConditionedAnnealing/pca.py @@ -76,7 +76,8 @@ class PCA_Distribution(SB3_Distribution): action_dim: int, par_strength: Par_Strength = Par_Strength.CONT_DIAG, kernel_func=rbf(), - init_std: int = 1, + init_std: float = 1, + cond_noise: float = 0, window: int = 64, epsilon: float = 1e-6, skip_conditioning: bool = False, @@ -88,6 +89,7 @@ class PCA_Distribution(SB3_Distribution): self.kernel_func = cast_to_kernel(kernel_func) self.par_strength = cast_to_enum(par_strength, Par_Strength) self.init_std = init_std + self.cond_noise = cond_noise self.window = window self.epsilon = epsilon self.skip_conditioning = skip_conditioning @@ -155,7 +157,7 @@ class PCA_Distribution(SB3_Distribution): if traj.shape[-2] < self.window: missing = self.window - traj.shape[-2] return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value) - return traj[:, :self.window, :] + return traj[:, -self.window:, :] def _conditioning_engine(self, trajectory, pi_mean, pi_std): traj = self._pad_and_cut_trajectory(trajectory) @@ -181,7 +183,8 @@ class PCA_Distribution(SB3_Distribution): Z = np.linspace(0, w, w+1).reshape(-1, 1) X = np.array([w]).reshape(-1, 1) - Sig11 = self.kernel_func(Z, Z) + Sig11 = self.kernel_func( + Z, Z) + np.diag(np.hstack((np.repeat(self.cond_noise**2, w), 0))) self.Sig12 = th.Tensor(self.kernel_func(Z, X)).squeeze(-1) self.Sig22 = th.Tensor(self.kernel_func( X, X)).squeeze(-1).squeeze(-1)