diff --git a/priorConditionedAnnealing/noise.py b/priorConditionedAnnealing/noise.py
index 39ac441..dadf7a7 100644
--- a/priorConditionedAnnealing/noise.py
+++ b/priorConditionedAnnealing/noise.py
@@ -15,7 +15,7 @@ class Colored_Noise():
         self.reset(random_state=random_state)

     def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
-        assert shape == self.shape
+        assert shape == self.knonw_shape
         sample = self.samples[:, self.index]
         self.index = (self.index+1) % self.num_samples
         return sample
diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py
index 25592a8..a1dfaa9 100644
--- a/priorConditionedAnnealing/pca.py
+++ b/priorConditionedAnnealing/pca.py
@@ -145,6 +145,7 @@ class PCA_Distribution(SB3_Distribution):
         return self.sample(traj=trajectory)

     def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
+        assert self.skip_conditioning or type(traj) != type(None), 'A past trajectory has to be supplied if conditioning is performed'
         pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
         rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
         rho_std *= f_sigma
@@ -157,6 +158,8 @@ class PCA_Distribution(SB3_Distribution):
         return actions

     def is_contextual(self):
+        return True  # TODO: Remove when the bug for the non-contextual case is fixed
+        # Always returning True will merely waste CPU cycles
         return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]

     def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
@@ -177,6 +180,10 @@ class PCA_Distribution(SB3_Distribution):

     def _pad_and_cut_trajectory(self, traj, value=0):
         if traj.shape[-2] < self.window:
+            if traj.shape[-2] == 0:
+                shape = list(traj.shape)
+                shape[-2] = 1
+                traj = th.ones(shape)*value
             missing = self.window - traj.shape[-2]
             return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
         return traj[:, -self.window:, :]
@@ -186,16 +193,15 @@ class PCA_Distribution(SB3_Distribution):
             return pi_mean, pi_std

         traj = self._pad_and_cut_trajectory(trajectory)
+
         # Numpy is fun
-        y_np = np.append(np.swapaxes(traj, -1, -2),
-                         np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
+        y_np = np.append(np.swapaxes(traj, -1, -2), np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
         with th.no_grad():
             conditioners = th.Tensor(self._adapt_conditioner(pi_std))
             y = th.Tensor(y_np)
-            S = th.cholesky_solve(self.Sig12.expand(
-                conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
+            S = th.cholesky_solve(self.Sig12.expand(conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
             rho_mean = th.einsum('bai,bai->ba', S, y)
             rho_std = self.Sig22 - (S @ self.Sig12)
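
Below is a minimal, self-contained sketch (not part of the patch) of the empty-trajectory edge case handled by the new guard in _pad_and_cut_trajectory. The free-standing function, the window value, and the tensor shapes are assumptions made for illustration only; the real method lives on PCA_Distribution and reads the window size from self.window.

    import torch as th
    import torch.nn.functional as F

    def pad_and_cut_trajectory(traj, window, value=0):
        # Mirrors the patched method, with the window passed explicitly.
        if traj.shape[-2] < window:
            if traj.shape[-2] == 0:
                # New guard: seed an empty trajectory with a single constant
                # timestep before padding it up to the window length.
                shape = list(traj.shape)
                shape[-2] = 1
                traj = th.ones(shape) * value
            missing = window - traj.shape[-2]
            # Pad along the time axis (dim -2), prepending `missing` steps.
            return F.pad(input=traj, pad=(0, 0, missing, 0, 0, 0), value=value)
        # Trajectory already long enough: keep only the last `window` steps.
        return traj[:, -window:, :]

    # Hypothetical shapes: batch of 2, no past timesteps, 3 action dimensions.
    empty = th.zeros((2, 0, 3))
    padded = pad_and_cut_trajectory(empty, window=4)
    print(padded.shape)  # torch.Size([2, 4, 3])

With the guard, an entirely empty trajectory is first turned into one constant timestep and then padded like any other short trajectory, so the conditioning engine always receives a window-length input.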