diff --git a/fast_td3/fast_td3_utils.py b/fast_td3/fast_td3_utils.py index 6923440..d5c2a1d 100644 --- a/fast_td3/fast_td3_utils.py +++ b/fast_td3/fast_td3_utils.py @@ -388,7 +388,7 @@ class SimpleReplayBuffer(nn.Module): out["critic_observations"] = critic_observations out["next"]["critic_observations"] = next_critic_observations - if self.ptr >= self.buffer_size: + if self.n_steps > 1 and self.ptr >= self.buffer_size: # Roll back the truncation flags introduced for safe sampling self.truncations[:, current_pos - 1] = curr_truncations return out