Fixing bugs, allowing noisy conditioning

This commit is contained in:
Dominik Moritz Roth 2023-05-04 15:39:57 +02:00
parent cdea897a66
commit 0383ac39ed

View File

@ -76,7 +76,8 @@ class PCA_Distribution(SB3_Distribution):
action_dim: int, action_dim: int,
par_strength: Par_Strength = Par_Strength.CONT_DIAG, par_strength: Par_Strength = Par_Strength.CONT_DIAG,
kernel_func=rbf(), kernel_func=rbf(),
init_std: int = 1, init_std: float = 1,
cond_noise: float = 0,
window: int = 64, window: int = 64,
epsilon: float = 1e-6, epsilon: float = 1e-6,
skip_conditioning: bool = False, skip_conditioning: bool = False,
@ -88,6 +89,7 @@ class PCA_Distribution(SB3_Distribution):
self.kernel_func = cast_to_kernel(kernel_func) self.kernel_func = cast_to_kernel(kernel_func)
self.par_strength = cast_to_enum(par_strength, Par_Strength) self.par_strength = cast_to_enum(par_strength, Par_Strength)
self.init_std = init_std self.init_std = init_std
self.cond_noise = cond_noise
self.window = window self.window = window
self.epsilon = epsilon self.epsilon = epsilon
self.skip_conditioning = skip_conditioning self.skip_conditioning = skip_conditioning
@ -155,7 +157,7 @@ class PCA_Distribution(SB3_Distribution):
if traj.shape[-2] < self.window: if traj.shape[-2] < self.window:
missing = self.window - traj.shape[-2] missing = self.window - traj.shape[-2]
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value) return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
return traj[:, :self.window, :] return traj[:, -self.window:, :]
def _conditioning_engine(self, trajectory, pi_mean, pi_std): def _conditioning_engine(self, trajectory, pi_mean, pi_std):
traj = self._pad_and_cut_trajectory(trajectory) traj = self._pad_and_cut_trajectory(trajectory)
@ -181,7 +183,8 @@ class PCA_Distribution(SB3_Distribution):
Z = np.linspace(0, w, w+1).reshape(-1, 1) Z = np.linspace(0, w, w+1).reshape(-1, 1)
X = np.array([w]).reshape(-1, 1) X = np.array([w]).reshape(-1, 1)
Sig11 = self.kernel_func(Z, Z) Sig11 = self.kernel_func(
Z, Z) + np.diag(np.hstack((np.repeat(self.cond_noise**2, w), 0)))
self.Sig12 = th.Tensor(self.kernel_func(Z, X)).squeeze(-1) self.Sig12 = th.Tensor(self.kernel_func(Z, X)).squeeze(-1)
self.Sig22 = th.Tensor(self.kernel_func( self.Sig22 = th.Tensor(self.kernel_func(
X, X)).squeeze(-1).squeeze(-1) X, X)).squeeze(-1).squeeze(-1)