diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index 846d60a..dcbdee6 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -213,7 +213,10 @@ class GaussianRolloutCollectorAuxclass(): with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - actions, values, log_probs = self.policy(obs_tensor) + if 'use_pca' in self.policy and self.policy['use_pca']: + actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories()) + else: + actions, values, log_probs = self.policy(obs_tensor) dist = self.policy.get_distribution(obs_tensor).distribution mean, chol = get_mean_and_chol(dist) actions = actions.cpu().numpy() @@ -274,3 +277,7 @@ class GaussianRolloutCollectorAuxclass(): callback.on_rollout_end() return True + + def get_past_trajectories(self): + # TODO: Respect Episode Boundaries + return self.actions diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index e7a363a..8bcb2c0 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -42,6 +42,8 @@ from metastable_projections.projections.w2_projection_layer import WassersteinPr from ..distributions import UniversalGaussianDistribution, make_proba_distribution from ..misc.distTools import get_mean_and_chol +from priorConditionedAnnealing.pca import PCA_Distribution + class ActorCriticPolicy(BasePolicy): """ @@ -195,6 +197,7 @@ class ActorCriticPolicy(BasePolicy): def reset_noise(self, n_envs: int = 1) -> None: """ Sample new weights for the exploration matrix. + TODO: Support for SDE under PCA :param n_envs: """ @@ -251,6 +254,10 @@ class ActorCriticPolicy(BasePolicy): latent_dim=latent_dim_pi, latent_sde_dim=self.latent_dim_sde or latent_dim_pi, std_init=math.exp( self.log_std_init) ) + elif isinstance(self.action_dist, PCA_Distribution): + self.action_net, self.chol_net = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi + ) else: raise NotImplementedError( f"Unsupported distribution '{self.action_dist}'.") @@ -276,7 +283,7 @@ class ActorCriticPolicy(BasePolicy): self.optimizer = self.optimizer_class( self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) - def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ Forward pass in all the networks (actor and critic) @@ -290,7 +297,11 @@ class ActorCriticPolicy(BasePolicy): # Evaluate the values for the given observations values = self.value_net(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) - actions = distribution.get_actions(deterministic=deterministic) + if self.use_pca: + assert trajectory, 'Past trajetcory has to be provided when using PCA.' + actions = distribution.get_actions(deterministic=deterministic, trajectory=trajectory) + else: + actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) return actions, values, log_prob @@ -341,6 +352,10 @@ class ActorCriticPolicy(BasePolicy): chol = self.chol_net(latent_pi) self.chol = chol return self.action_dist.proba_distribution(mean_actions, chol, latent_pi) + elif isinstance(self.action_dist, PCA_Distribution): + chol = self.chol_net(latent_pi) + self.chol = chol + return self.action_dist.proba_distribution(mean_actions, self.chol) else: raise ValueError("Invalid action distribution")