set deterministic=True
when sampling in diffusion evaluation
This commit is contained in:
parent
4962bbce38
commit
dd14c5887c
@ -50,7 +50,7 @@ class EvalDiffusionAgent(EvalAgent):
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = self.model(cond=cond)
|
||||
samples = self.model(cond=cond, deterministic=True)
|
||||
output_venv = (
|
||||
samples.trajectories.cpu().numpy()
|
||||
) # n_env x horizon x act
|
||||
|
@ -53,7 +53,7 @@ class EvalImgDiffusionAgent(EvalAgent):
|
||||
key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device)
|
||||
for key in self.obs_dims
|
||||
} # batch each type of obs and put into dict
|
||||
samples = self.model(cond=cond)
|
||||
samples = self.model(cond=cond, deterministic=True)
|
||||
output_venv = (
|
||||
samples.trajectories.cpu().numpy()
|
||||
) # n_env x horizon x act
|
||||
|
Loading…
Reference in New Issue
Block a user