From d4003e1a680cf2cbbd94f65c473c7c749e6419d4 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 21 May 2023 17:29:03 +0200 Subject: [PATCH] get_past_trajectories will now return a th.Tensor --- metastable_baselines/misc/rollout_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py index 85eed07..b96e628 100644 --- a/metastable_baselines/misc/rollout_buffer.py +++ b/metastable_baselines/misc/rollout_buffer.py @@ -278,6 +278,6 @@ class GaussianRolloutCollectorAuxclass(): return True - def get_past_trajectories(self): + def get_past_trajectories(self) -> th.Tensor: # TODO: Respect Episode Boundaries - return self.rollout_buffer.actions + return th.Tensor(self.rollout_buffer.actions)