fix observation history indexing in the dataset
This commit is contained in:
parent
1aaa6c2302
commit
ef5b14f820
@ -91,7 +91,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
|
|||||||
actions = self.actions[start:end]
|
actions = self.actions[start:end]
|
||||||
states = torch.stack(
|
states = torch.stack(
|
||||||
[
|
[
|
||||||
states[min(num_before_start - t, 0)]
|
states[max(num_before_start - t, 0)]
|
||||||
for t in reversed(range(self.cond_steps))
|
for t in reversed(range(self.cond_steps))
|
||||||
]
|
]
|
||||||
) # more recent is at the end
|
) # more recent is at the end
|
||||||
@ -100,7 +100,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
|
|||||||
images = self.images[(start - num_before_start) : end]
|
images = self.images[(start - num_before_start) : end]
|
||||||
images = torch.stack(
|
images = torch.stack(
|
||||||
[
|
[
|
||||||
images[min(num_before_start - t, 0)]
|
images[max(num_before_start - t, 0)]
|
||||||
for t in reversed(range(self.img_cond_steps))
|
for t in reversed(range(self.img_cond_steps))
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user