Fixing bugs, allowing noisy conditioning
This commit is contained in:
parent
cdea897a66
commit
0383ac39ed
@ -76,7 +76,8 @@ class PCA_Distribution(SB3_Distribution):
|
||||
action_dim: int,
|
||||
par_strength: Par_Strength = Par_Strength.CONT_DIAG,
|
||||
kernel_func=rbf(),
|
||||
init_std: int = 1,
|
||||
init_std: float = 1,
|
||||
cond_noise: float = 0,
|
||||
window: int = 64,
|
||||
epsilon: float = 1e-6,
|
||||
skip_conditioning: bool = False,
|
||||
@ -88,6 +89,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
self.kernel_func = cast_to_kernel(kernel_func)
|
||||
self.par_strength = cast_to_enum(par_strength, Par_Strength)
|
||||
self.init_std = init_std
|
||||
self.cond_noise = cond_noise
|
||||
self.window = window
|
||||
self.epsilon = epsilon
|
||||
self.skip_conditioning = skip_conditioning
|
||||
@ -155,7 +157,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
if traj.shape[-2] < self.window:
|
||||
missing = self.window - traj.shape[-2]
|
||||
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
|
||||
return traj[:, :self.window, :]
|
||||
return traj[:, -self.window:, :]
|
||||
|
||||
def _conditioning_engine(self, trajectory, pi_mean, pi_std):
|
||||
traj = self._pad_and_cut_trajectory(trajectory)
|
||||
@ -181,7 +183,8 @@ class PCA_Distribution(SB3_Distribution):
|
||||
Z = np.linspace(0, w, w+1).reshape(-1, 1)
|
||||
X = np.array([w]).reshape(-1, 1)
|
||||
|
||||
Sig11 = self.kernel_func(Z, Z)
|
||||
Sig11 = self.kernel_func(
|
||||
Z, Z) + np.diag(np.hstack((np.repeat(self.cond_noise**2, w), 0)))
|
||||
self.Sig12 = th.Tensor(self.kernel_func(Z, X)).squeeze(-1)
|
||||
self.Sig22 = th.Tensor(self.kernel_func(
|
||||
X, X)).squeeze(-1).squeeze(-1)
|
||||
|
Loading…
Reference in New Issue
Block a user