diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py index dd0a775..40f6d1c 100644 --- a/priorConditionedAnnealing/pca.py +++ b/priorConditionedAnnealing/pca.py @@ -79,7 +79,8 @@ class PCA_Distribution(SB3_Distribution): init_std: int = 1, window: int = 64, epsilon: float = 1e-6, - Base_Noise=noise.White_Noise + skip_conditioning: bool = False, + Base_Noise=noise.White_Noise, ): super().__init__() @@ -89,12 +90,13 @@ class PCA_Distribution(SB3_Distribution): self.init_std = init_std self.window = window self.epsilon = epsilon + self.skip_conditioning = skip_conditioning - if Base_Noise.__class__ != noise.White_Noise: + self.base_noise = Base_Noise((1, action_dim)) + + if not isinstance(self.base_noise, noise.White_Noise): print('[!] Non-White Noise was not yet tested!') - self.base_noise = Base_Noise((1, )+action_dim) - # Premature optimization is the root of all evil self._build_conditioner() # *Optimizes it anyways* @@ -118,9 +120,10 @@ class PCA_Distribution(SB3_Distribution): def entropy(self) -> th.Tensor: return sum_independent_dims(self.distribution.entropy()) - def sample(self, traj: th.Tensor, epsilon=None) -> th.Tensor: + def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor: pi_mean, pi_std = self.distribution.mean, self.distribution.scale rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std) + rho_std *= f_sigma eta = self._get_rigged(pi_mean, pi_std, rho_mean, rho_std, epsilon) @@ -137,6 +140,9 @@ class PCA_Distribution(SB3_Distribution): if epsilon == None: epsilon = self.base_noise(pi_mean.shape) + if self.skip_conditioning: + return epsilon.detach() + Delta = rho_mean - pi_mean Pi_mu = 1 / pi_std Pi_sigma = rho_std / pi_std