From dd14c5887c222ea1abf55187fcd60bf0291a4e6c Mon Sep 17 00:00:00 2001 From: allenzren Date: Thu, 26 Sep 2024 01:15:10 -0400 Subject: [PATCH] set `deterministic=True` when sampling in diffusion evaluation --- agent/eval/eval_diffusion_agent.py | 2 +- agent/eval/eval_diffusion_img_agent.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/agent/eval/eval_diffusion_agent.py b/agent/eval/eval_diffusion_agent.py index b3f998f..9daa70f 100644 --- a/agent/eval/eval_diffusion_agent.py +++ b/agent/eval/eval_diffusion_agent.py @@ -50,7 +50,7 @@ class EvalDiffusionAgent(EvalAgent): .float() .to(self.device) } - samples = self.model(cond=cond) + samples = self.model(cond=cond, deterministic=True) output_venv = ( samples.trajectories.cpu().numpy() ) # n_env x horizon x act diff --git a/agent/eval/eval_diffusion_img_agent.py b/agent/eval/eval_diffusion_img_agent.py index 2582c36..e8eacbe 100644 --- a/agent/eval/eval_diffusion_img_agent.py +++ b/agent/eval/eval_diffusion_img_agent.py @@ -53,7 +53,7 @@ class EvalImgDiffusionAgent(EvalAgent): key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device) for key in self.obs_dims } # batch each type of obs and put into dict - samples = self.model(cond=cond) + samples = self.model(cond=cond, deterministic=True) output_venv = ( samples.trajectories.cpu().numpy() ) # n_env x horizon x act