diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py index e46c228..8a8234d 100644 --- a/priorConditionedAnnealing/pca.py +++ b/priorConditionedAnnealing/pca.py @@ -60,11 +60,12 @@ def cast_to_enum(inp, Class): def cast_to_kernel(inp): - if isinstance(inp, function): + if callable(inp): return inp else: func, *pars = inp.split('_') - return Avaible_Kernel_Funcs[func](*pars) + pars = [float(par) for par in pars] + return Avaible_Kernel_Funcs[func].get_func()(*pars) class PCA_Distribution(SB3_Distribution): @@ -84,7 +85,7 @@ class PCA_Distribution(SB3_Distribution): self.action_dim = action_dim self.kernel_func = cast_to_kernel(kernel_func) - self.par_strength = cast_to_enum(Par_Strength, par_strength) + self.par_strength = cast_to_enum(par_strength, Par_Strength) self.init_std = init_std self.window = window self.epsilon = epsilon @@ -138,11 +139,10 @@ class PCA_Distribution(SB3_Distribution): return eta.detach() def _pad_and_cut_trajectory(self, traj, value=0): - cut = traj[:self.window] if traj.shape[-2] < self.window: missing = self.window - traj.shape[-2] - return F.pad(input=cut, pad=(missing, 0), value=value) - return cut + return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value) + return traj[:, :self.window, :] def _conditioning_engine(self, trajectory, pi_mean, pi_std): traj = self._pad_and_cut_trajectory(trajectory) @@ -188,6 +188,7 @@ class PCA_Distribution(SB3_Distribution): # https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png # This way conditioning of the GP can be done in O(dim(A)) time. if not self.is_contextual(): + # TODO: fix, this does not work # safe inplace self.conditioner[-1, - 1] = np.sqrt(pi_std**2 + self.Sig22 - self.adapt_norm)