Allow skipping conditioning and minor additions
This commit is contained in:
parent
5f2d27efce
commit
d04f245e9b
@ -79,7 +79,8 @@ class PCA_Distribution(SB3_Distribution):
|
||||
init_std: int = 1,
|
||||
window: int = 64,
|
||||
epsilon: float = 1e-6,
|
||||
Base_Noise=noise.White_Noise
|
||||
skip_conditioning: bool = False,
|
||||
Base_Noise=noise.White_Noise,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -89,12 +90,13 @@ class PCA_Distribution(SB3_Distribution):
|
||||
self.init_std = init_std
|
||||
self.window = window
|
||||
self.epsilon = epsilon
|
||||
self.skip_conditioning = skip_conditioning
|
||||
|
||||
if Base_Noise.__class__ != noise.White_Noise:
|
||||
self.base_noise = Base_Noise((1, action_dim))
|
||||
|
||||
if not isinstance(self.base_noise, noise.White_Noise):
|
||||
print('[!] Non-White Noise was not yet tested!')
|
||||
|
||||
self.base_noise = Base_Noise((1, )+action_dim)
|
||||
|
||||
# Premature optimization is the root of all evil
|
||||
self._build_conditioner()
|
||||
# *Optimizes it anyways*
|
||||
@ -118,9 +120,10 @@ class PCA_Distribution(SB3_Distribution):
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def sample(self, traj: th.Tensor, epsilon=None) -> th.Tensor:
|
||||
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
|
||||
pi_mean, pi_std = self.distribution.mean, self.distribution.scale
|
||||
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
||||
rho_std *= f_sigma
|
||||
eta = self._get_rigged(pi_mean, pi_std,
|
||||
rho_mean, rho_std,
|
||||
epsilon)
|
||||
@ -137,6 +140,9 @@ class PCA_Distribution(SB3_Distribution):
|
||||
if epsilon == None:
|
||||
epsilon = self.base_noise(pi_mean.shape)
|
||||
|
||||
if self.skip_conditioning:
|
||||
return epsilon.detach()
|
||||
|
||||
Delta = rho_mean - pi_mean
|
||||
Pi_mu = 1 / pi_std
|
||||
Pi_sigma = rho_std / pi_std
|
||||
|
Loading…
Reference in New Issue
Block a user