Experimental Support for PCA
parent 76ea3a6326
commit 2efa6c18fb
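Adds experimental support for Prior Conditioned Annealing (PCA): when a policy sets use_pca, the rollout collector hands the actions gathered so far to the policy's forward pass via a new trajectory argument, and the policy can build a PCA_Distribution (imported from priorConditionedAnnealing) alongside the existing Gaussian options, with its own mean and Cholesky heads.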
@@ -213,7 +213,10 @@ class GaussianRolloutCollectorAuxclass():
             with th.no_grad():
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
-                actions, values, log_probs = self.policy(obs_tensor)
+                if getattr(self.policy, 'use_pca', False):
+                    actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories())
+                else:
+                    actions, values, log_probs = self.policy(obs_tensor)
                 dist = self.policy.get_distribution(obs_tensor).distribution
                 mean, chol = get_mean_and_chol(dist)
             actions = actions.cpu().numpy()
@@ -274,3 +277,7 @@ class GaussianRolloutCollectorAuxclass():
         callback.on_rollout_end()
 
         return True
+
+    def get_past_trajectories(self):
+        # TODO: Respect Episode Boundaries
+        return self.actions
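The TODO above notes that get_past_trajectories() returns the collector's full action history without cutting it at episode boundaries. A minimal sketch of a boundary-aware lookup, assuming the collector keeps track of the step at which each environment's current episode began (the episode_start bookkeeping below is hypothetical, not part of this commit):

import numpy as np

def past_trajectory_for_env(actions: np.ndarray, episode_start: int, step: int) -> np.ndarray:
    # actions:       collected actions for one environment, shape (n_steps, action_dim)
    # episode_start: index of the first step of the ongoing episode
    # step:          index of the current step (exclusive upper bound)
    return actions[episode_start:step]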
@@ -42,6 +42,8 @@ from metastable_projections.projections.w2_projection_layer import WassersteinPr
 from ..distributions import UniversalGaussianDistribution, make_proba_distribution
 from ..misc.distTools import get_mean_and_chol
 
+from priorConditionedAnnealing.pca import PCA_Distribution
+
 
 class ActorCriticPolicy(BasePolicy):
     """
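priorConditionedAnnealing is an additional dependency; one way (not part of this commit) to keep the policy module importable when the package is absent would be to guard the import:

try:
    from priorConditionedAnnealing.pca import PCA_Distribution
except ImportError:
    class PCA_Distribution:  # placeholder type so the isinstance() checks below stay valid and simply never match
        pass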
@@ -195,6 +197,7 @@ class ActorCriticPolicy(BasePolicy):
     def reset_noise(self, n_envs: int = 1) -> None:
         """
         Sample new weights for the exploration matrix.
+        TODO: Support for SDE under PCA
 
         :param n_envs:
         """
@@ -251,6 +254,10 @@ class ActorCriticPolicy(BasePolicy):
                 latent_dim=latent_dim_pi, latent_sde_dim=self.latent_dim_sde or latent_dim_pi, std_init=math.exp(
                     self.log_std_init)
             )
+        elif isinstance(self.action_dist, PCA_Distribution):
+            self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
+                latent_dim=latent_dim_pi
+            )
         else:
             raise NotImplementedError(
                 f"Unsupported distribution '{self.action_dist}'.")
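The PCA branch mirrors the Gaussian branches above: proba_distribution_net is expected to return a mean head (action_net) and a Cholesky head (chol_net) built from the policy latent. A rough stand-in for what such a pair of heads could look like (layer shapes and choices are assumptions, not the library's actual implementation):

from torch import nn

def pca_heads_sketch(latent_dim: int, action_dim: int):
    action_net = nn.Linear(latent_dim, action_dim)  # predicts the distribution mean
    chol_net = nn.Linear(latent_dim, action_dim)    # predicts the (e.g. diagonal) Cholesky factor
    return action_net, chol_net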
@@ -276,7 +283,7 @@ class ActorCriticPolicy(BasePolicy):
         self.optimizer = self.optimizer_class(
             self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+    def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
         Forward pass in all the networks (actor and critic)
 
@@ -290,7 +297,11 @@ class ActorCriticPolicy(BasePolicy):
         # Evaluate the values for the given observations
         values = self.value_net(latent_vf)
         distribution = self._get_action_dist_from_latent(latent_pi)
-        actions = distribution.get_actions(deterministic=deterministic)
+        if self.use_pca:
+            assert trajectory is not None, 'Past trajectory has to be provided when using PCA.'
+            actions = distribution.get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            actions = distribution.get_actions(deterministic=deterministic)
         log_prob = distribution.log_prob(actions)
         return actions, values, log_prob
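With the new signature, callers that enable PCA are responsible for supplying the rollout so far; the assertion makes that explicit. A call-pattern sketch, where the policy argument is assumed to be an already-constructed ActorCriticPolicy and past_actions the actions collected earlier in the current rollout:

import torch as th

def act(policy, obs: th.Tensor, past_actions: th.Tensor):
    # Dispatch on the policy's use_pca flag, mirroring the rollout collector above.
    if getattr(policy, 'use_pca', False):
        return policy(obs, trajectory=past_actions)
    return policy(obs)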
@@ -341,6 +352,10 @@ class ActorCriticPolicy(BasePolicy):
             chol = self.chol_net(latent_pi)
             self.chol = chol
             return self.action_dist.proba_distribution(mean_actions, chol, latent_pi)
+        elif isinstance(self.action_dist, PCA_Distribution):
+            chol = self.chol_net(latent_pi)
+            self.chol = chol
+            return self.action_dist.proba_distribution(mean_actions, self.chol)
         else:
             raise ValueError("Invalid action distribution")