set deterministic=True when sampling in diffusion evaluation

This commit is contained in:
allenzren 2024-09-26 01:15:10 -04:00
parent 4962bbce38
commit dd14c5887c
2 changed files with 2 additions and 2 deletions

View File

@ -50,7 +50,7 @@ class EvalDiffusionAgent(EvalAgent):
.float()
.to(self.device)
}
samples = self.model(cond=cond)
samples = self.model(cond=cond, deterministic=True)
output_venv = (
samples.trajectories.cpu().numpy()
) # n_env x horizon x act

View File

@ -53,7 +53,7 @@ class EvalImgDiffusionAgent(EvalAgent):
key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device)
for key in self.obs_dims
} # batch each type of obs and put into dict
samples = self.model(cond=cond)
samples = self.model(cond=cond, deterministic=True)
output_venv = (
samples.trajectories.cpu().numpy()
) # n_env x horizon x act