Patch for priorConditionedAnnealing/pca.py (class PCA_Distribution), restored to
the one-entry-per-line unified diff layout.

Review fixes applied to added ('+') lines only:
  * conditioned_log_prob now calls self._log_prob(actions, dist); the helper
    takes (actions, dist), so the one-argument call raised TypeError.
  * _get_emitting_rigged tests `epsilon is None` instead of `epsilon == None`.
  * `n_envs: int = 1` spacing.

NOTE(review): blank context lines and indentation were inferred from the hunk
line counts (the moved rho_mean line is assumed to sit outside th.no_grad());
confirm against the target file before applying. The post-image hash on the
index line predates the fixes above.

diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py
index e90f8ad..3bdf504 100644
--- a/priorConditionedAnnealing/pca.py
+++ b/priorConditionedAnnealing/pca.py
@@ -89,6 +89,7 @@ class PCA_Distribution(SB3_Distribution):
     def __init__(
         self,
         action_dim: int,
+        n_envs: int = 1,
         par_strength: Par_Strength = Par_Strength.CONT_DIAG,
         kernel_func=kernel.rbf(),
         init_std: float = 1,
@@ -96,6 +97,7 @@ class PCA_Distribution(SB3_Distribution):
         window: int = 64,
         epsilon: float = 1e-6,
         skip_conditioning: bool = False,
+        temporal_gradient_emission: bool = False,
         Base_Noise=noise.White_Noise,
     ):
         super().__init__()
@@ -108,8 +110,9 @@ class PCA_Distribution(SB3_Distribution):
         self.window = window
         self.epsilon = epsilon
         self.skip_conditioning = skip_conditioning
+        self.temporal_gradient_emission = temporal_gradient_emission
 
-        self.base_noise = cast_to_Noise(Base_Noise, (1, action_dim))
+        self.base_noise = cast_to_Noise(Base_Noise, (n_envs, action_dim))
 
         # Premature optimization is the root of all evil
         self._build_conditioner()
@@ -128,7 +131,16 @@ class PCA_Distribution(SB3_Distribution):
         return self
 
     def log_prob(self, actions: th.Tensor) -> th.Tensor:
-        return sum_independent_dims(self.distribution.log_prob(actions.to(self.distribution.mean.device)))
+        return self._log_prob(actions, self.distribution)
+
+    def conditioned_log_prob(self, actions: th.Tensor, trajectory: th.Tensor = None) -> th.Tensor:
+        pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
+        rho_mean, rho_std = self._conditioning_engine(trajectory, pi_mean, pi_std)
+        dist = Normal(rho_mean, rho_std)
+        return self._log_prob(actions, dist)
+
+    def _log_prob(self, actions: th.Tensor, dist: Normal):
+        return sum_independent_dims(dist.log_prob(actions.to(dist.mean.device)))
 
     def entropy(self) -> th.Tensor:
         return sum_independent_dims(self.distribution.entropy())
@@ -164,20 +176,26 @@ class PCA_Distribution(SB3_Distribution):
         return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
 
     def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
-        with th.no_grad():
-            if epsilon == None:
-                epsilon = self.base_noise(pi_mean.shape)
+        # Ugly function to ensure, that the gradients flow as intended for each modus operandi
+        if not self.temporal_gradient_emission or self.skip_conditioning:
+            with th.no_grad():
+                return self._get_emitting_rigged(pi_mean, pi_std, rho_mean, rho_std, epsilon=epsilon).detach()
+        return self._get_emitting_rigged(pi_mean.detach(), pi_std.detach(), rho_mean, rho_std, epsilon=epsilon)
 
-            if self.skip_conditioning:
-                return epsilon.detach()
+    def _get_emitting_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
+        if epsilon is None:
+            epsilon = self.base_noise(pi_mean.shape)
 
-            Delta = rho_mean - pi_mean
-            Pi_mu = 1 / pi_std
-            Pi_sigma = rho_std / pi_std
+        if self.skip_conditioning:
+            return epsilon.detach()
 
-            eta = Pi_mu * Delta + Pi_sigma * epsilon
+        Delta = rho_mean - pi_mean
+        Pi_mu = 1 / pi_std
+        Pi_sigma = rho_std / pi_std
 
-            return eta.detach()
+        eta = Pi_mu * Delta + Pi_sigma * epsilon
+
+        return eta
 
     def _pad_and_cut_trajectory(self, traj, value=0):
         if traj.shape[-2] < self.window:
@@ -194,18 +212,15 @@ class PCA_Distribution(SB3_Distribution):
             return pi_mean, pi_std
 
         traj = self._pad_and_cut_trajectory(trajectory)
-
-        # Numpy is fun
-        y_np = np.append(np.swapaxes(traj, -1, -2), np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
+        y = th.cat((traj.transpose(-1, -2), pi_mean.unsqueeze(-1).unsqueeze(0).repeat(traj.shape[0], 1, traj.shape[-2])), dim=1)
 
         with th.no_grad():
             conditioners = th.Tensor(self._adapt_conditioner(pi_std))
-            y = th.Tensor(y_np)
 
             S = th.cholesky_solve(self.Sig12.expand(conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
 
-            rho_mean = th.einsum('bai,bai->ba', S, y)
             rho_std = self.Sig22 - (S @ self.Sig12)
 
+        rho_mean = th.einsum('bai,bai->ba', S, y)
 
         return rho_mean, rho_std
 
@@ -254,13 +269,13 @@ class PCA_Distribution(SB3_Distribution):
         return self.distribution.mean
 
     def actions_from_params(
-        self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False
+        self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None
     ) -> th.Tensor:
         self.proba_distribution(mean, std)
-        return self.get_actions(deterministic=deterministic)
+        return self.get_actions(deterministic=deterministic, trajectory=trajectory)
 
-    def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor):
-        actions = self.actions_from_params(mean, std)
+    def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor, trajectory: th.Tensor = None):
+        actions = self.actions_from_params(mean, std, trajectory=trajectory)
         log_prob = self.log_prob(actions)
         return actions, log_prob
 