Squashing Bugs
This commit is contained in:
parent
89bc80ed3a
commit
21de8f418b
@ -60,11 +60,12 @@ def cast_to_enum(inp, Class):
|
||||
|
||||
|
||||
def cast_to_kernel(inp):
|
||||
if isinstance(inp, function):
|
||||
if callable(inp):
|
||||
return inp
|
||||
else:
|
||||
func, *pars = inp.split('_')
|
||||
return Avaible_Kernel_Funcs[func](*pars)
|
||||
pars = [float(par) for par in pars]
|
||||
return Avaible_Kernel_Funcs[func].get_func()(*pars)
|
||||
|
||||
|
||||
class PCA_Distribution(SB3_Distribution):
|
||||
@ -84,7 +85,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
|
||||
self.action_dim = action_dim
|
||||
self.kernel_func = cast_to_kernel(kernel_func)
|
||||
self.par_strength = cast_to_enum(Par_Strength, par_strength)
|
||||
self.par_strength = cast_to_enum(par_strength, Par_Strength)
|
||||
self.init_std = init_std
|
||||
self.window = window
|
||||
self.epsilon = epsilon
|
||||
@ -138,11 +139,10 @@ class PCA_Distribution(SB3_Distribution):
|
||||
return eta.detach()
|
||||
|
||||
def _pad_and_cut_trajectory(self, traj, value=0):
|
||||
cut = traj[:self.window]
|
||||
if traj.shape[-2] < self.window:
|
||||
missing = self.window - traj.shape[-2]
|
||||
return F.pad(input=cut, pad=(missing, 0), value=value)
|
||||
return cut
|
||||
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
|
||||
return traj[:, :self.window, :]
|
||||
|
||||
def _conditioning_engine(self, trajectory, pi_mean, pi_std):
|
||||
traj = self._pad_and_cut_trajectory(trajectory)
|
||||
@ -188,6 +188,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
# https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png
|
||||
# This way conditioning of the GP can be done in O(dim(A)) time.
|
||||
if not self.is_contextual():
|
||||
# TODO: fix, this does not work
|
||||
# safe inplace
|
||||
self.conditioner[-1, -
|
||||
1] = np.sqrt(pi_std**2 + self.Sig22 - self.adapt_norm)
|
||||
|
Loading…
Reference in New Issue
Block a user