diff --git a/test/utils.py b/test/utils.py index 01e33fe..427622d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -31,11 +31,12 @@ def run_env(env_id: str, iterations: int = None, seed: int = 0, wrappers: List[T terminations = [] truncations = [] obs, _ = env.reset(seed=seed) + env.action_space.seed(seed) verify_observations(obs, env.observation_space, "reset()") iterations = iterations or (env.spec.max_episode_steps or 1) - # number of samples(multiple environment steps) + # number of samples (multiple environment steps) for i in range(iterations): observations.append(obs)