get_past_trajectories will now return a th.Tensor
This commit is contained in:
parent
ea7634ca42
commit
d4003e1a68
@ -278,6 +278,6 @@ class GaussianRolloutCollectorAuxclass():
|
||||
|
||||
return True
|
||||
|
||||
def get_past_trajectories(self):
|
||||
def get_past_trajectories(self) -> th.Tensor:
|
||||
# TODO: Respect Episode Boundaries
|
||||
return self.rollout_buffer.actions
|
||||
return th.Tensor(self.rollout_buffer.actions)
|
||||
|
Loading…
Reference in New Issue
Block a user