Experimental Support for PCA

This commit is contained in:
Dominik Moritz Roth 2023-05-21 14:27:09 +02:00
parent 76ea3a6326
commit 2efa6c18fb
2 changed files with 25 additions and 3 deletions

View File

@ -213,7 +213,10 @@ class GaussianRolloutCollectorAuxclass():
with th.no_grad():
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
actions, values, log_probs = self.policy(obs_tensor)
if 'use_pca' in self.policy and self.policy['use_pca']:
actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories())
else:
actions, values, log_probs = self.policy(obs_tensor)
dist = self.policy.get_distribution(obs_tensor).distribution
mean, chol = get_mean_and_chol(dist)
actions = actions.cpu().numpy()
@ -274,3 +277,7 @@ class GaussianRolloutCollectorAuxclass():
callback.on_rollout_end()
return True
def get_past_trajectories(self):
# TODO: Respect Episode Boundaries
return self.actions

View File

@ -42,6 +42,8 @@ from metastable_projections.projections.w2_projection_layer import WassersteinPr
from ..distributions import UniversalGaussianDistribution, make_proba_distribution
from ..misc.distTools import get_mean_and_chol
from priorConditionedAnnealing.pca import PCA_Distribution
class ActorCriticPolicy(BasePolicy):
"""
@ -195,6 +197,7 @@ class ActorCriticPolicy(BasePolicy):
def reset_noise(self, n_envs: int = 1) -> None:
"""
Sample new weights for the exploration matrix.
TODO: Support for SDE under PCA
:param n_envs:
"""
@ -251,6 +254,10 @@ class ActorCriticPolicy(BasePolicy):
latent_dim=latent_dim_pi, latent_sde_dim=self.latent_dim_sde or latent_dim_pi, std_init=math.exp(
self.log_std_init)
)
elif isinstance(self.action_dist, PCA_Distribution):
self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi
)
else:
raise NotImplementedError(
f"Unsupported distribution '{self.action_dist}'.")
@ -276,7 +283,7 @@ class ActorCriticPolicy(BasePolicy):
self.optimizer = self.optimizer_class(
self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Forward pass in all the networks (actor and critic)
@ -290,7 +297,11 @@ class ActorCriticPolicy(BasePolicy):
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
if self.use_pca:
assert trajectory, 'Past trajetcory has to be provided when using PCA.'
actions = distribution.get_actions(deterministic=deterministic, trajectory=trajectory)
else:
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
return actions, values, log_prob
@ -341,6 +352,10 @@ class ActorCriticPolicy(BasePolicy):
chol = self.chol_net(latent_pi)
self.chol = chol
return self.action_dist.proba_distribution(mean_actions, chol, latent_pi)
elif isinstance(self.action_dist, PCA_Distribution):
chol = self.chol_net(latent_pi)
self.chol = chol
return self.action_dist.proba_distribution(mean_actions, self.chol)
else:
raise ValueError("Invalid action distribution")