diff --git a/agent/dataset/sequence.py b/agent/dataset/sequence.py index 28cb886..08349a3 100644 --- a/agent/dataset/sequence.py +++ b/agent/dataset/sequence.py @@ -91,7 +91,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset): actions = self.actions[start:end] states = torch.stack( [ - states[min(num_before_start - t, 0)] + states[max(num_before_start - t, 0)] for t in reversed(range(self.cond_steps)) ] ) # more recent is at the end @@ -100,7 +100,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset): images = self.images[(start - num_before_start) : end] images = torch.stack( [ - images[min(num_before_start - t, 0)] + images[max(num_before_start - t, 0)] for t in reversed(range(self.img_cond_steps)) ] )