get_past_trajectories will now return a th.Tensor

This commit is contained in:
Dominik Moritz Roth 2023-05-21 17:29:03 +02:00
parent ea7634ca42
commit d4003e1a68

View File

@ -278,6 +278,6 @@ class GaussianRolloutCollectorAuxclass():
return True
def get_past_trajectories(self):
def get_past_trajectories(self) -> th.Tensor:
# TODO: Respect Episode Boundaries
return self.rollout_buffer.actions
return th.Tensor(self.rollout_buffer.actions)