Experimental Support for PCA
parent 76ea3a6326
commit 2efa6c18fb
@@ -213,7 +213,10 @@ class GaussianRolloutCollectorAuxclass():
             with th.no_grad():
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
-                actions, values, log_probs = self.policy(obs_tensor)
+                if 'use_pca' in self.policy and self.policy['use_pca']:
+                    actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories())
+                else:
+                    actions, values, log_probs = self.policy(obs_tensor)
                 dist = self.policy.get_distribution(obs_tensor).distribution
                 mean, chol = get_mean_and_chol(dist)
             actions = actions.cpu().numpy()
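Note on the guard above: `'use_pca' in self.policy` relies on the policy object supporting membership tests. A minimal sketch of the same sampling step with a plain attribute check (the helper name `sample_actions` and the `getattr` guard are illustrative assumptions, not part of this commit):

```python
import torch as th
from stable_baselines3.common.utils import obs_as_tensor


def sample_actions(policy, last_obs, device, past_trajectories=None):
    """Sketch of the PCA-aware sampling step in the rollout collector."""
    obs_tensor = obs_as_tensor(last_obs, device)
    with th.no_grad():
        if getattr(policy, "use_pca", False):
            # PCA conditions its exploration noise on the trajectory collected so far
            actions, values, log_probs = policy(obs_tensor, trajectory=past_trajectories)
        else:
            actions, values, log_probs = policy(obs_tensor)
    return actions.cpu().numpy(), values, log_probs
```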
@@ -274,3 +277,7 @@ class GaussianRolloutCollectorAuxclass():
         callback.on_rollout_end()
 
         return True
+
+    def get_past_trajectories(self):
+        # TODO: Respect Episode Boundaries
+        return self.actions
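The `TODO: Respect Episode Boundaries` is left open in this commit. One possible way to honor boundaries, sketched under the assumption that the collector tracks a per-step `episode_starts` flag (an assumed name, not part of this commit), is to return only the actions since the most recent episode start:

```python
import numpy as np


def get_past_trajectories(actions: np.ndarray, episode_starts: np.ndarray) -> np.ndarray:
    """Return only the actions collected since the last episode boundary.

    `actions` has shape (n_steps, action_dim); `episode_starts` is a boolean
    array of shape (n_steps,) marking steps that began a new episode. Both
    names are assumptions for this sketch.
    """
    start_indices = np.flatnonzero(episode_starts)
    last_start = start_indices[-1] if len(start_indices) > 0 else 0
    return actions[last_start:]
```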
@@ -42,6 +42,8 @@ from metastable_projections.projections.w2_projection_layer import WassersteinPr
 from ..distributions import UniversalGaussianDistribution, make_proba_distribution
 from ..misc.distTools import get_mean_and_chol
 
+from priorConditionedAnnealing.pca import PCA_Distribution
+
 
 class ActorCriticPolicy(BasePolicy):
     """
@@ -195,6 +197,7 @@ class ActorCriticPolicy(BasePolicy):
     def reset_noise(self, n_envs: int = 1) -> None:
         """
         Sample new weights for the exploration matrix.
+        TODO: Support for SDE under PCA
 
         :param n_envs:
         """
@@ -251,6 +254,10 @@ class ActorCriticPolicy(BasePolicy):
                 latent_dim=latent_dim_pi, latent_sde_dim=self.latent_dim_sde or latent_dim_pi, std_init=math.exp(
                     self.log_std_init)
             )
+        elif isinstance(self.action_dist, PCA_Distribution):
+            self.action_net, self.chol_net = self.action_dist.proba_distribution_net(
+                latent_dim=latent_dim_pi
+            )
         else:
             raise NotImplementedError(
                 f"Unsupported distribution '{self.action_dist}'.")
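For context, the PCA branch expects `proba_distribution_net` to return two heads: a mean head (`action_net`) and a Cholesky head (`chol_net`). A rough sketch of that shape, assumed for illustration rather than taken from the actual `PCA_Distribution` implementation:

```python
import torch.nn as nn


def proba_distribution_net(latent_dim: int, action_dim: int):
    """Sketch: one linear head for the action mean, one for a flattened
    Cholesky factor. The real PCA_Distribution may parameterize the
    Cholesky differently (e.g. diagonal or lower-triangular entries only)."""
    action_net = nn.Linear(latent_dim, action_dim)
    chol_net = nn.Linear(latent_dim, action_dim * action_dim)
    return action_net, chol_net
```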
@@ -276,7 +283,7 @@ class ActorCriticPolicy(BasePolicy):
         self.optimizer = self.optimizer_class(
             self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+    def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
         Forward pass in all the networks (actor and critic)
 
@@ -290,7 +297,11 @@ class ActorCriticPolicy(BasePolicy):
         # Evaluate the values for the given observations
         values = self.value_net(latent_vf)
         distribution = self._get_action_dist_from_latent(latent_pi)
-        actions = distribution.get_actions(deterministic=deterministic)
+        if self.use_pca:
+            assert trajectory is not None, 'Past trajectory has to be provided when using PCA.'
+            actions = distribution.get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            actions = distribution.get_actions(deterministic=deterministic)
         log_prob = distribution.log_prob(actions)
         return actions, values, log_prob
 
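A usage sketch of the extended `forward` signature; the `use_pca` attribute and the argument shapes are assumptions about how a PCA policy is configured:

```python
import torch as th


def act(policy, obs: th.Tensor, past_actions: th.Tensor, deterministic: bool = False):
    """Usage sketch: pass the past trajectory to forward() only when the
    policy samples through a PCA_Distribution, mirroring the check above."""
    if getattr(policy, "use_pca", False):
        actions, values, log_prob = policy(obs, deterministic=deterministic, trajectory=past_actions)
    else:
        actions, values, log_prob = policy(obs, deterministic=deterministic)
    return actions, values, log_prob
```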
@@ -341,6 +352,10 @@ class ActorCriticPolicy(BasePolicy):
             chol = self.chol_net(latent_pi)
             self.chol = chol
             return self.action_dist.proba_distribution(mean_actions, chol, latent_pi)
+        elif isinstance(self.action_dist, PCA_Distribution):
+            chol = self.chol_net(latent_pi)
+            self.chol = chol
+            return self.action_dist.proba_distribution(mean_actions, self.chol)
         else:
             raise ValueError("Invalid action distribution")
 