Tiny Bugfix

This commit is contained in:
Dominik Moritz Roth 2023-05-21 16:40:13 +02:00
parent a1df15f7bc
commit 43418a5e53

View File

@ -213,7 +213,7 @@ class GaussianRolloutCollectorAuxclass():
with th.no_grad(): with th.no_grad():
# Convert to pytorch tensor or to TensorDict # Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device) obs_tensor = obs_as_tensor(self._last_obs, self.device)
if self.policy['use_pca']: if self.policy.use_pca:
actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories()) actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories())
else: else:
actions, values, log_probs = self.policy(obs_tensor) actions, values, log_probs = self.policy(obs_tensor)