Fixed PCA bug

This commit is contained in:
Dominik Moritz Roth 2023-05-21 16:32:27 +02:00
parent ea4b9851d8
commit a1df15f7bc

View File

@ -213,7 +213,7 @@ class GaussianRolloutCollectorAuxclass():
with th.no_grad(): with th.no_grad():
# Convert to pytorch tensor or to TensorDict # Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device) obs_tensor = obs_as_tensor(self._last_obs, self.device)
if 'use_pca' in self.policy and self.policy['use_pca']: if self.policy['use_pca']:
actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories()) actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories())
else: else:
actions, values, log_probs = self.policy(obs_tensor) actions, values, log_probs = self.policy(obs_tensor)