diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index 58d4d42..2b68ffe 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -213,7 +213,7 @@ class GaussianRolloutCollectorAuxclass(): with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - if self.policy['use_pca']: + if self.policy.use_pca: actions, values, log_probs = self.policy(obs_tensor, trajectory=self.get_past_trajectories()) else: actions, values, log_probs = self.policy(obs_tensor)