set deterministic=True
when sampling in diffusion evaluation
This commit is contained in:
parent
4962bbce38
commit
dd14c5887c
@ -50,7 +50,7 @@ class EvalDiffusionAgent(EvalAgent):
|
|||||||
.float()
|
.float()
|
||||||
.to(self.device)
|
.to(self.device)
|
||||||
}
|
}
|
||||||
samples = self.model(cond=cond)
|
samples = self.model(cond=cond, deterministic=True)
|
||||||
output_venv = (
|
output_venv = (
|
||||||
samples.trajectories.cpu().numpy()
|
samples.trajectories.cpu().numpy()
|
||||||
) # n_env x horizon x act
|
) # n_env x horizon x act
|
||||||
|
@ -53,7 +53,7 @@ class EvalImgDiffusionAgent(EvalAgent):
|
|||||||
key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device)
|
key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device)
|
||||||
for key in self.obs_dims
|
for key in self.obs_dims
|
||||||
} # batch each type of obs and put into dict
|
} # batch each type of obs and put into dict
|
||||||
samples = self.model(cond=cond)
|
samples = self.model(cond=cond, deterministic=True)
|
||||||
output_venv = (
|
output_venv = (
|
||||||
samples.trajectories.cpu().numpy()
|
samples.trajectories.cpu().numpy()
|
||||||
) # n_env x horizon x act
|
) # n_env x horizon x act
|
||||||
|
Loading…
Reference in New Issue
Block a user