Squashing Bugs

This commit is contained in:
Dominik Moritz Roth 2023-05-03 17:02:22 +02:00
parent 89bc80ed3a
commit 21de8f418b

View File

@@ -60,11 +60,12 @@ def cast_to_enum(inp, Class):
def cast_to_kernel(inp): def cast_to_kernel(inp):
if isinstance(inp, function): if callable(inp):
return inp return inp
else: else:
func, *pars = inp.split('_') func, *pars = inp.split('_')
return Avaible_Kernel_Funcs[func](*pars) pars = [float(par) for par in pars]
return Avaible_Kernel_Funcs[func].get_func()(*pars)
class PCA_Distribution(SB3_Distribution): class PCA_Distribution(SB3_Distribution):
@@ -84,7 +85,7 @@ class PCA_Distribution(SB3_Distribution):
self.action_dim = action_dim self.action_dim = action_dim
self.kernel_func = cast_to_kernel(kernel_func) self.kernel_func = cast_to_kernel(kernel_func)
self.par_strength = cast_to_enum(Par_Strength, par_strength) self.par_strength = cast_to_enum(par_strength, Par_Strength)
self.init_std = init_std self.init_std = init_std
self.window = window self.window = window
self.epsilon = epsilon self.epsilon = epsilon
@@ -138,11 +139,10 @@ class PCA_Distribution(SB3_Distribution):
return eta.detach() return eta.detach()
def _pad_and_cut_trajectory(self, traj, value=0): def _pad_and_cut_trajectory(self, traj, value=0):
cut = traj[:self.window]
if traj.shape[-2] < self.window: if traj.shape[-2] < self.window:
missing = self.window - traj.shape[-2] missing = self.window - traj.shape[-2]
return F.pad(input=cut, pad=(missing, 0), value=value) return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
return cut return traj[:, :self.window, :]
def _conditioning_engine(self, trajectory, pi_mean, pi_std): def _conditioning_engine(self, trajectory, pi_mean, pi_std):
traj = self._pad_and_cut_trajectory(trajectory) traj = self._pad_and_cut_trajectory(trajectory)
@@ -188,6 +188,7 @@ class PCA_Distribution(SB3_Distribution):
# https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png # https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png
# This way conditioning of the GP can be done in O(dim(A)) time. # This way conditioning of the GP can be done in O(dim(A)) time.
if not self.is_contextual(): if not self.is_contextual():
# TODO: fix, this does not work
# safe inplace # safe inplace
self.conditioner[-1, - self.conditioner[-1, -
1] = np.sqrt(pi_std**2 + self.Sig22 - self.adapt_norm) 1] = np.sqrt(pi_std**2 + self.Sig22 - self.adapt_norm)