Squashing Bugs
This commit is contained in:
parent
89bc80ed3a
commit
21de8f418b
@ -60,11 +60,12 @@ def cast_to_enum(inp, Class):
|
|||||||
|
|
||||||
|
|
||||||
def cast_to_kernel(inp):
|
def cast_to_kernel(inp):
|
||||||
if isinstance(inp, function):
|
if callable(inp):
|
||||||
return inp
|
return inp
|
||||||
else:
|
else:
|
||||||
func, *pars = inp.split('_')
|
func, *pars = inp.split('_')
|
||||||
return Avaible_Kernel_Funcs[func](*pars)
|
pars = [float(par) for par in pars]
|
||||||
|
return Avaible_Kernel_Funcs[func].get_func()(*pars)
|
||||||
|
|
||||||
|
|
||||||
class PCA_Distribution(SB3_Distribution):
|
class PCA_Distribution(SB3_Distribution):
|
||||||
@ -84,7 +85,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
|
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.kernel_func = cast_to_kernel(kernel_func)
|
self.kernel_func = cast_to_kernel(kernel_func)
|
||||||
self.par_strength = cast_to_enum(Par_Strength, par_strength)
|
self.par_strength = cast_to_enum(par_strength, Par_Strength)
|
||||||
self.init_std = init_std
|
self.init_std = init_std
|
||||||
self.window = window
|
self.window = window
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
@ -138,11 +139,10 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return eta.detach()
|
return eta.detach()
|
||||||
|
|
||||||
def _pad_and_cut_trajectory(self, traj, value=0):
|
def _pad_and_cut_trajectory(self, traj, value=0):
|
||||||
cut = traj[:self.window]
|
|
||||||
if traj.shape[-2] < self.window:
|
if traj.shape[-2] < self.window:
|
||||||
missing = self.window - traj.shape[-2]
|
missing = self.window - traj.shape[-2]
|
||||||
return F.pad(input=cut, pad=(missing, 0), value=value)
|
return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
|
||||||
return cut
|
return traj[:, :self.window, :]
|
||||||
|
|
||||||
def _conditioning_engine(self, trajectory, pi_mean, pi_std):
|
def _conditioning_engine(self, trajectory, pi_mean, pi_std):
|
||||||
traj = self._pad_and_cut_trajectory(trajectory)
|
traj = self._pad_and_cut_trajectory(trajectory)
|
||||||
@ -188,6 +188,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
# https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png
|
# https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png
|
||||||
# This way conditioning of the GP can be done in O(dim(A)) time.
|
# This way conditioning of the GP can be done in O(dim(A)) time.
|
||||||
if not self.is_contextual():
|
if not self.is_contextual():
|
||||||
|
# TODO: fix, this does not work
|
||||||
# safe inplace
|
# safe inplace
|
||||||
self.conditioner[-1, -
|
self.conditioner[-1, -
|
||||||
1] = np.sqrt(pi_std**2 + self.Sig22 - self.adapt_norm)
|
1] = np.sqrt(pi_std**2 + self.Sig22 - self.adapt_norm)
|
||||||
|
Loading…
Reference in New Issue
Block a user