set deterministic=True when sampling in diffusion evaluation

This commit is contained in:
allenzren 2024-09-26 01:15:10 -04:00
parent 4962bbce38
commit dd14c5887c
2 changed files with 2 additions and 2 deletions

View File

@ -50,7 +50,7 @@ class EvalDiffusionAgent(EvalAgent):
.float() .float()
.to(self.device) .to(self.device)
} }
samples = self.model(cond=cond) samples = self.model(cond=cond, deterministic=True)
output_venv = ( output_venv = (
samples.trajectories.cpu().numpy() samples.trajectories.cpu().numpy()
) # n_env x horizon x act ) # n_env x horizon x act

View File

@ -53,7 +53,7 @@ class EvalImgDiffusionAgent(EvalAgent):
key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device) key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device)
for key in self.obs_dims for key in self.obs_dims
} # batch each type of obs and put into dict } # batch each type of obs and put into dict
samples = self.model(cond=cond) samples = self.model(cond=cond, deterministic=True)
output_venv = ( output_venv = (
samples.trajectories.cpu().numpy() samples.trajectories.cpu().numpy()
) # n_env x horizon x act ) # n_env x horizon x act