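Fixing bugs (keep the last `window` trajectory steps when truncating; annotate `init_std` as float) and allowing noisy conditioning via the new `cond_noise` parameter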
This commit is contained in:
parent
cdea897a66
commit
0383ac39ed
@ -76,7 +76,8 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
action_dim: int,
|
action_dim: int,
|
||||||
par_strength: Par_Strength = Par_Strength.CONT_DIAG,
|
par_strength: Par_Strength = Par_Strength.CONT_DIAG,
|
||||||
kernel_func=rbf(),
|
kernel_func=rbf(),
|
||||||
init_std: int = 1,
|
init_std: float = 1,
|
||||||
|
cond_noise: float = 0,
|
||||||
window: int = 64,
|
window: int = 64,
|
||||||
epsilon: float = 1e-6,
|
epsilon: float = 1e-6,
|
||||||
skip_conditioning: bool = False,
|
skip_conditioning: bool = False,
|
||||||
@ -88,6 +89,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
self.kernel_func = cast_to_kernel(kernel_func)
|
self.kernel_func = cast_to_kernel(kernel_func)
|
||||||
self.par_strength = cast_to_enum(par_strength, Par_Strength)
|
self.par_strength = cast_to_enum(par_strength, Par_Strength)
|
||||||
self.init_std = init_std
|
self.init_std = init_std
|
||||||
|
self.cond_noise = cond_noise
|
||||||
self.window = window
|
self.window = window
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.skip_conditioning = skip_conditioning
|
self.skip_conditioning = skip_conditioning
|
||||||
@ -155,7 +157,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
if traj.shape[-2] < self.window:
|
if traj.shape[-2] < self.window:
|
||||||
missing = self.window - traj.shape[-2]
|
missing = self.window - traj.shape[-2]
|
||||||
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
|
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
|
||||||
return traj[:, :self.window, :]
|
return traj[:, -self.window:, :]
|
||||||
|
|
||||||
def _conditioning_engine(self, trajectory, pi_mean, pi_std):
|
def _conditioning_engine(self, trajectory, pi_mean, pi_std):
|
||||||
traj = self._pad_and_cut_trajectory(trajectory)
|
traj = self._pad_and_cut_trajectory(trajectory)
|
||||||
@ -181,7 +183,8 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
Z = np.linspace(0, w, w+1).reshape(-1, 1)
|
Z = np.linspace(0, w, w+1).reshape(-1, 1)
|
||||||
X = np.array([w]).reshape(-1, 1)
|
X = np.array([w]).reshape(-1, 1)
|
||||||
|
|
||||||
Sig11 = self.kernel_func(Z, Z)
|
Sig11 = self.kernel_func(
|
||||||
|
Z, Z) + np.diag(np.hstack((np.repeat(self.cond_noise**2, w), 0)))
|
||||||
self.Sig12 = th.Tensor(self.kernel_func(Z, X)).squeeze(-1)
|
self.Sig12 = th.Tensor(self.kernel_func(Z, X)).squeeze(-1)
|
||||||
self.Sig22 = th.Tensor(self.kernel_func(
|
self.Sig22 = th.Tensor(self.kernel_func(
|
||||||
X, X)).squeeze(-1).squeeze(-1)
|
X, X)).squeeze(-1).squeeze(-1)
|
||||||
|
Loading…
Reference in New Issue
Block a user