fix observation history indexing in the dataset
This commit is contained in:
parent
1aaa6c2302
commit
ef5b14f820
@ -91,7 +91,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
|
||||
actions = self.actions[start:end]
|
||||
states = torch.stack(
|
||||
[
|
||||
states[min(num_before_start - t, 0)]
|
||||
states[max(num_before_start - t, 0)]
|
||||
for t in reversed(range(self.cond_steps))
|
||||
]
|
||||
) # more recent is at the end
|
||||
@ -100,7 +100,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
|
||||
images = self.images[(start - num_before_start) : end]
|
||||
images = torch.stack(
|
||||
[
|
||||
images[min(num_before_start - t, 0)]
|
||||
images[max(num_before_start - t, 0)]
|
||||
for t in reversed(range(self.img_cond_steps))
|
||||
]
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user