Allow skipping conditioning and minor additions

This commit is contained in:
Dominik Moritz Roth 2023-05-04 12:18:07 +02:00
parent 5f2d27efce
commit d04f245e9b

View File

@ -79,7 +79,8 @@ class PCA_Distribution(SB3_Distribution):
init_std: int = 1,
window: int = 64,
epsilon: float = 1e-6,
Base_Noise=noise.White_Noise
skip_conditioning: bool = False,
Base_Noise=noise.White_Noise,
):
super().__init__()
@ -89,12 +90,13 @@ class PCA_Distribution(SB3_Distribution):
self.init_std = init_std
self.window = window
self.epsilon = epsilon
self.skip_conditioning = skip_conditioning
if Base_Noise.__class__ != noise.White_Noise:
self.base_noise = Base_Noise((1, action_dim))
if not isinstance(self.base_noise, noise.White_Noise):
print('[!] Non-White Noise was not yet tested!')
self.base_noise = Base_Noise((1, )+action_dim)
# Premature optimization is the root of all evil
self._build_conditioner()
# *Optimizes it anyways*
@ -118,9 +120,10 @@ class PCA_Distribution(SB3_Distribution):
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
def sample(self, traj: th.Tensor, epsilon=None) -> th.Tensor:
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
pi_mean, pi_std = self.distribution.mean, self.distribution.scale
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
rho_std *= f_sigma
eta = self._get_rigged(pi_mean, pi_std,
rho_mean, rho_std,
epsilon)
@ -137,6 +140,9 @@ class PCA_Distribution(SB3_Distribution):
if epsilon == None:
epsilon = self.base_noise(pi_mean.shape)
if self.skip_conditioning:
return epsilon.detach()
Delta = rho_mean - pi_mean
Pi_mu = 1 / pi_std
Pi_sigma = rho_std / pi_std