Allow skipping conditioning and minor additions
This commit is contained in:
parent 5f2d27efce
commit d04f245e9b
@ -79,7 +79,8 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
init_std: int = 1,
|
init_std: int = 1,
|
||||||
window: int = 64,
|
window: int = 64,
|
||||||
epsilon: float = 1e-6,
|
epsilon: float = 1e-6,
|
||||||
Base_Noise=noise.White_Noise
|
skip_conditioning: bool = False,
|
||||||
|
Base_Noise=noise.White_Noise,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -89,12 +90,13 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
self.init_std = init_std
|
self.init_std = init_std
|
||||||
self.window = window
|
self.window = window
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
self.skip_conditioning = skip_conditioning
|
||||||
|
|
||||||
if Base_Noise.__class__ != noise.White_Noise:
|
self.base_noise = Base_Noise((1, action_dim))
|
||||||
|
|
||||||
|
if not isinstance(self.base_noise, noise.White_Noise):
|
||||||
print('[!] Non-White Noise was not yet tested!')
|
print('[!] Non-White Noise was not yet tested!')
|
||||||
|
|
||||||
self.base_noise = Base_Noise((1, )+action_dim)
|
|
||||||
|
|
||||||
# Premature optimization is the root of all evil
|
# Premature optimization is the root of all evil
|
||||||
self._build_conditioner()
|
self._build_conditioner()
|
||||||
# *Optimizes it anyways*
|
# *Optimizes it anyways*
|
||||||
@ -118,9 +120,10 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
def entropy(self) -> th.Tensor:
|
def entropy(self) -> th.Tensor:
|
||||||
return sum_independent_dims(self.distribution.entropy())
|
return sum_independent_dims(self.distribution.entropy())
|
||||||
|
|
||||||
def sample(self, traj: th.Tensor, epsilon=None) -> th.Tensor:
|
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
|
||||||
pi_mean, pi_std = self.distribution.mean, self.distribution.scale
|
pi_mean, pi_std = self.distribution.mean, self.distribution.scale
|
||||||
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
||||||
|
rho_std *= f_sigma
|
||||||
eta = self._get_rigged(pi_mean, pi_std,
|
eta = self._get_rigged(pi_mean, pi_std,
|
||||||
rho_mean, rho_std,
|
rho_mean, rho_std,
|
||||||
epsilon)
|
epsilon)
|
||||||
@ -137,6 +140,9 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
if epsilon == None:
|
if epsilon == None:
|
||||||
epsilon = self.base_noise(pi_mean.shape)
|
epsilon = self.base_noise(pi_mean.shape)
|
||||||
|
|
||||||
|
if self.skip_conditioning:
|
||||||
|
return epsilon.detach()
|
||||||
|
|
||||||
Delta = rho_mean - pi_mean
|
Delta = rho_mean - pi_mean
|
||||||
Pi_mu = 1 / pi_std
|
Pi_mu = 1 / pi_std
|
||||||
Pi_sigma = rho_std / pi_std
|
Pi_sigma = rho_std / pi_std
|
||||||
|
Loading…
Reference in New Issue
Block a user