From a1df15f7bc996093de33c0565904abdd8867c14b Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 21 May 2023 16:32:27 +0200 Subject: [PATCH] Fixed PCA bug --- metastable_baselines/misc/rollout_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index dcbdee6..58d4d42 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -213,7 +213,7 @@ class GaussianRolloutCollectorAuxclass(): with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - if 'use_pca' in self.policy and self.policy['use_pca']: + if self.policy['use_pca']: actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories()) else: actions, values, log_probs = self.policy(obs_tensor)